使用适用于 Go 的 AWS SDK v2 进行单元测试 - 适用于 Go 的 AWS SDK v2

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

使用适用于 Go 的 AWS SDK v2 进行单元测试

在应用程序中使用 SDK 时,您需要为应用程序的单元测试模拟 SDK。模拟 SDK 可以让您的测试专注于想要测试的内容,而不是 SDK 的内部细节。

要支持模拟,请使用 Go 接口代替具体的服务客户端、分页器和 Waiter 类型,例如 s3.Client。这样,您的应用程序便可使用依赖注入等模式来测试您的应用程序逻辑。

模拟客户端操作

在本示例中,S3GetObjectAPI 是一个定义 GetObjectFromS3 函数所需 Amazon S3 API 操作集的接口。Amazon S3 客户端的 GetObject 方法能够满足 S3GetObjectAPI 的要求。

import "context" import "github.com/aws/aws-sdk-go-v2/service/s3" // ... type S3GetObjectAPI interface { GetObject(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) } func GetObjectFromS3(ctx context.Context, api S3GetObjectAPI, bucket, key string) ([]byte, error) { object, err := api.GetObject(ctx, &s3.GetObjectInput{ Bucket: &bucket, Key: &key, }) if err != nil { return nil, err } defer object.Body.Close() return ioutil.ReadAll(object.Body) }

要测试 GetObjectFromS3 函数,请使用 mockGetObjectAPI 来满足 S3GetObjectAPI 接口定义。然后使用 mockGetObjectAPI 类型来模拟服务客户端返回的输出和错误响应。

import "testing" import "github.com/aws/aws-sdk-go-v2/service/s3" // ... type mockGetObjectAPI func(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) func (m mockGetObjectAPI) GetObject(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) { return m(ctx, params, optFns...) } func TestGetObjectFromS3(t *testing.T) { cases := []struct { client func(t *testing.T) S3GetObjectAPI bucket string key string expect []byte }{ { client: func(t *testing.T) S3GetObjectAPI { return mockGetObjectAPI(func(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) { t.Helper() if params.Bucket == nil { t.Fatal("expect bucket to not be nil") } if e, a := "fooBucket", *params.Bucket; e != a { t.Errorf("expect %v, got %v", e, a) } if params.Key == nil { t.Fatal("expect key to not be nil") } if e, a := "barKey", *params.Key; e != a { t.Errorf("expect %v, got %v", e, a) } return &s3.GetObjectOutput{ Body: ioutil.NopCloser(bytes.NewReader([]byte("this is the body foo bar baz"))), }, nil }) }, bucket: "amzn-s3-demo-bucket>", key: "barKey", expect: []byte("this is the body foo bar baz"), }, } for i, tt := range cases { t.Run(strconv.Itoa(i), func(t *testing.T) { ctx := context.TODO() content, err := GetObjectFromS3(ctx, tt.client(t), tt.bucket, tt.key) if err != nil { t.Fatalf("expect no error, got %v", err) } if e, a := tt.expect, content; bytes.Compare(e, a) != 0 { t.Errorf("expect %v, got %v", e, a) } }) } }

模拟分页器

与服务客户端类似,可以通过为分页器定义 Go 接口来模拟分页器。您的应用程序代码将使用该接口。这样便可在应用程序运行时使用 SDK 的实现,并允许将模拟实现用于测试。

在以下示例中,ListObjectsV2Pager 是一个定义 CountObjects 函数所需 Amazon S3 ListObjectsV2Paginator 行为的接口。

import "context" import "github.com/aws/aws-sdk-go-v2/service/s3" // ... type ListObjectsV2Pager interface { HasMorePages() bool NextPage(context.Context, ...func(*s3.Options)) (*s3.ListObjectsV2Output, error) } func CountObjects(ctx context.Context, pager ListObjectsV2Pager) (count int, err error) { for pager.HasMorePages() { var output *s3.ListObjectsV2Output output, err = pager.NextPage(ctx) if err != nil { return count, err } count += int(output.KeyCount) } return count, nil }

要测试 CountObjects,请创建满足 ListObjectsV2Pager 接口定义的 mockListObjectsV2Pager 类型。然后使用 mockListObjectsV2Pager 复制来自服务操作分页器的输出和错误响应的分页行为。

import "context" import "fmt" import "testing" import "github.com/aws/aws-sdk-go-v2/service/s3" // ... type mockListObjectsV2Pager struct { PageNum int Pages []*s3.ListObjectsV2Output } func (m *mockListObjectsV2Pager) HasMorePages() bool { return m.PageNum < len(m.Pages) } func (m *mockListObjectsV2Pager) NextPage(ctx context.Context, f ...func(*s3.Options)) (output *s3.ListObjectsV2Output, err error) { if m.PageNum >= len(m.Pages) { return nil, fmt.Errorf("no more pages") } output = m.Pages[m.PageNum] m.PageNum++ return output, nil } func TestCountObjects(t *testing.T) { pager := &mockListObjectsV2Pager{ Pages: []*s3.ListObjectsV2Output{ { KeyCount: 5, }, { KeyCount: 10, }, { KeyCount: 15, }, }, } objects, err := CountObjects(context.TODO(), pager) if err != nil { t.Fatalf("expect no error, got %v", err) } if expect, actual := 30, objects; expect != actual { t.Errorf("expect %v, got %v", expect, actual) } }