From 1b60982946881264248c29131ef8d5ebc7ff9c2e Mon Sep 17 00:00:00 2001 From: WithoutPants <53250216+WithoutPants@users.noreply.github.com> Date: Thu, 16 Oct 2025 18:35:32 +1100 Subject: [PATCH] Update unit tests --- internal/autotag/integration_test.go | 7 +- internal/identify/scene_test.go | 2 +- internal/identify/studio_test.go | 6 +- pkg/gallery/import_test.go | 8 +- pkg/group/import_test.go | 8 +- pkg/image/import_test.go | 8 +- pkg/models/model_scraped_item_test.go | 2 +- pkg/scene/import_test.go | 8 +- pkg/sqlite/setup_test.go | 25 +- pkg/sqlite/studio_test.go | 559 +++++++++++++++++++++++++- pkg/studio/import_test.go | 18 +- 11 files changed, 605 insertions(+), 46 deletions(-) diff --git a/internal/autotag/integration_test.go b/internal/autotag/integration_test.go index fc83df848..605082b98 100644 --- a/internal/autotag/integration_test.go +++ b/internal/autotag/integration_test.go @@ -101,16 +101,15 @@ func createPerformer(ctx context.Context, pqb models.PerformerWriter) error { func createStudio(ctx context.Context, qb models.StudioWriter, name string) (*models.Studio, error) { // create the studio - studio := models.Studio{ - Name: name, - } + studio := models.NewCreateStudioInput() + studio.Name = name err := qb.Create(ctx, &studio) if err != nil { return nil, err } - return &studio, nil + return studio.Studio, nil } func createTag(ctx context.Context, qb models.TagWriter) error { diff --git a/internal/identify/scene_test.go b/internal/identify/scene_test.go index a76aef516..0eec61c4e 100644 --- a/internal/identify/scene_test.go +++ b/internal/identify/scene_test.go @@ -27,7 +27,7 @@ func Test_sceneRelationships_studio(t *testing.T) { db := mocks.NewDatabase() db.Studio.On("Create", testCtx, mock.Anything).Run(func(args mock.Arguments) { - s := args.Get(1).(*models.Studio) + s := args.Get(1).(*models.CreateStudioInput) s.ID = validStoredIDInt }).Return(nil) diff --git a/internal/identify/studio_test.go b/internal/identify/studio_test.go index 5424a6a93..083675650 100644 --- a/internal/identify/studio_test.go +++ b/internal/identify/studio_test.go @@ -21,13 +21,13 @@ func Test_createMissingStudio(t *testing.T) { db := mocks.NewDatabase() - db.Studio.On("Create", testCtx, mock.MatchedBy(func(p *models.Studio) bool { + db.Studio.On("Create", testCtx, mock.MatchedBy(func(p *models.CreateStudioInput) bool { return p.Name == validName })).Run(func(args mock.Arguments) { - s := args.Get(1).(*models.Studio) + s := args.Get(1).(*models.CreateStudioInput) s.ID = createdID }).Return(nil) - db.Studio.On("Create", testCtx, mock.MatchedBy(func(p *models.Studio) bool { + db.Studio.On("Create", testCtx, mock.MatchedBy(func(p *models.CreateStudioInput) bool { return p.Name == invalidName })).Return(errors.New("error creating studio")) diff --git a/pkg/gallery/import_test.go b/pkg/gallery/import_test.go index b64f80d8f..4248f51bc 100644 --- a/pkg/gallery/import_test.go +++ b/pkg/gallery/import_test.go @@ -115,9 +115,9 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) { } db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3) - db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) { - s := args.Get(1).(*models.Studio) - s.ID = existingStudioID + db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Run(func(args mock.Arguments) { + s := args.Get(1).(*models.CreateStudioInput) + s.Studio.ID = existingStudioID }).Return(nil) err := i.PreImport(testCtx) @@ -147,7 +147,7 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) { } db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once() - db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error")) + db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Return(errors.New("Create error")) err := i.PreImport(testCtx) assert.NotNil(t, err) diff --git a/pkg/group/import_test.go b/pkg/group/import_test.go index c4ca47442..50b8b2dd1 100644 --- a/pkg/group/import_test.go +++ b/pkg/group/import_test.go @@ -121,9 +121,9 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) { } db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3) - db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) { - s := args.Get(1).(*models.Studio) - s.ID = existingStudioID + db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Run(func(args mock.Arguments) { + s := args.Get(1).(*models.CreateStudioInput) + s.Studio.ID = existingStudioID }).Return(nil) err := i.PreImport(testCtx) @@ -156,7 +156,7 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) { } db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once() - db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error")) + db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Return(errors.New("Create error")) err := i.PreImport(testCtx) assert.NotNil(t, err) diff --git a/pkg/image/import_test.go b/pkg/image/import_test.go index 286e51fe3..98b3972b9 100644 --- a/pkg/image/import_test.go +++ b/pkg/image/import_test.go @@ -77,9 +77,9 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) { } db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3) - db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) { - s := args.Get(1).(*models.Studio) - s.ID = existingStudioID + db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Run(func(args mock.Arguments) { + s := args.Get(1).(*models.CreateStudioInput) + s.Studio.ID = existingStudioID }).Return(nil) err := i.PreImport(testCtx) @@ -109,7 +109,7 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) { } db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once() - db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error")) + db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Return(errors.New("Create error")) err := i.PreImport(testCtx) assert.NotNil(t, err) diff --git a/pkg/models/model_scraped_item_test.go b/pkg/models/model_scraped_item_test.go index b6b44025f..545543652 100644 --- a/pkg/models/model_scraped_item_test.go +++ b/pkg/models/model_scraped_item_test.go @@ -113,7 +113,7 @@ func Test_scrapedToStudioInput(t *testing.T) { got.StashIDs.List()[stid].UpdatedAt = time.Time{} } } - assert.Equal(t, tt.want, got) + assert.Equal(t, tt.want, got.Studio) }) } } diff --git a/pkg/scene/import_test.go b/pkg/scene/import_test.go index a6e3edcdf..558b72ba2 100644 --- a/pkg/scene/import_test.go +++ b/pkg/scene/import_test.go @@ -241,9 +241,9 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) { } db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3) - db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) { - s := args.Get(1).(*models.Studio) - s.ID = existingStudioID + db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Run(func(args mock.Arguments) { + s := args.Get(1).(*models.CreateStudioInput) + s.Studio.ID = existingStudioID }).Return(nil) err := i.PreImport(testCtx) @@ -273,7 +273,7 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) { } db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once() - db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error")) + db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Return(errors.New("Create error")) err := i.PreImport(testCtx) assert.NotNil(t, err) diff --git a/pkg/sqlite/setup_test.go b/pkg/sqlite/setup_test.go index 843b8b4c2..361b5cb79 100644 --- a/pkg/sqlite/setup_test.go +++ b/pkg/sqlite/setup_test.go @@ -1765,7 +1765,19 @@ func getStudioNullStringValue(index int, field string) string { return ret.String } -func createStudio(ctx context.Context, sqb *sqlite.StudioStore, name string, parentID *int) (*models.Studio, error) { +func getStudioCustomFields(index int) map[string]interface{} { + if index%5 == 0 { + return nil + } + + return map[string]interface{}{ + "string": getStudioStringValue(index, "custom"), + "int": int64(index % 5), + "real": float64(index) / 10, + } +} + +func createStudio(ctx context.Context, sqb *sqlite.StudioStore, name string, parentID *int, customFields map[string]interface{}) (*models.Studio, error) { studio := models.Studio{ Name: name, } @@ -1774,7 +1786,7 @@ func createStudio(ctx context.Context, sqb *sqlite.StudioStore, name string, par studio.ParentID = parentID } - err := createStudioFromModel(ctx, sqb, &studio) + err := createStudioFromModel(ctx, sqb, &studio, customFields) if err != nil { return nil, err } @@ -1782,8 +1794,11 @@ func createStudio(ctx context.Context, sqb *sqlite.StudioStore, name string, par return &studio, nil } -func createStudioFromModel(ctx context.Context, sqb *sqlite.StudioStore, studio *models.Studio) error { - err := sqb.Create(ctx, studio) +func createStudioFromModel(ctx context.Context, sqb *sqlite.StudioStore, studio *models.Studio, customFields map[string]interface{}) error { + err := sqb.Create(ctx, &models.CreateStudioInput{ + Studio: studio, + CustomFields: customFields, + }) if err != nil { return fmt.Errorf("Error creating studio %v+: %s", studio, err.Error()) @@ -1845,7 +1860,7 @@ func createStudios(ctx context.Context, n int, o int) error { alias := getStudioStringValue(i, "Alias") studio.Aliases = models.NewRelatedStrings([]string{alias}) } - err := createStudioFromModel(ctx, sqb, &studio) + err := createStudioFromModel(ctx, sqb, &studio, getStudioCustomFields(i)) if err != nil { return err diff --git a/pkg/sqlite/studio_test.go b/pkg/sqlite/studio_test.go index 003877c77..954d0f5aa 100644 --- a/pkg/sqlite/studio_test.go +++ b/pkg/sqlite/studio_test.go @@ -11,6 +11,7 @@ import ( "strconv" "strings" "testing" + "time" "github.com/stashapp/stash/pkg/models" "github.com/stretchr/testify/assert" @@ -47,6 +48,550 @@ func TestStudioFindByName(t *testing.T) { }) } +func loadStudioRelationships(ctx context.Context, expected models.Studio, actual *models.Studio) error { + if expected.Aliases.Loaded() { + if err := actual.LoadAliases(ctx, db.Studio); err != nil { + return err + } + } + if expected.TagIDs.Loaded() { + if err := actual.LoadTagIDs(ctx, db.Studio); err != nil { + return err + } + } + if expected.StashIDs.Loaded() { + if err := actual.LoadStashIDs(ctx, db.Studio); err != nil { + return err + } + } + + return nil +} + +func Test_StudioStore_Create(t *testing.T) { + var ( + name = "name" + details = "details" + url = "url" + rating = 3 + aliases = []string{"alias1", "alias2"} + ignoreAutoTag = true + favorite = true + endpoint1 = "endpoint1" + endpoint2 = "endpoint2" + stashID1 = "stashid1" + stashID2 = "stashid2" + createdAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) + updatedAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) + ) + + tests := []struct { + name string + newObject models.CreateStudioInput + wantErr bool + }{ + { + "full", + models.CreateStudioInput{ + Studio: &models.Studio{ + Name: name, + URL: url, + Favorite: favorite, + Rating: &rating, + Details: details, + IgnoreAutoTag: ignoreAutoTag, + TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithStudio], tagIDs[tagIdx1WithDupName]}), + Aliases: models.NewRelatedStrings(aliases), + StashIDs: models.NewRelatedStashIDs([]models.StashID{ + { + StashID: stashID1, + Endpoint: endpoint1, + UpdatedAt: epochTime, + }, + { + StashID: stashID2, + Endpoint: endpoint2, + UpdatedAt: epochTime, + }, + }), + CreatedAt: createdAt, + UpdatedAt: updatedAt, + }, + CustomFields: testCustomFields, + }, + false, + }, + { + "invalid tag id", + models.CreateStudioInput{ + Studio: &models.Studio{ + Name: name, + TagIDs: models.NewRelatedIDs([]int{invalidID}), + }, + }, + true, + }, + } + + qb := db.Studio + + for _, tt := range tests { + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + assert := assert.New(t) + + p := tt.newObject + if err := qb.Create(ctx, &p); (err != nil) != tt.wantErr { + t.Errorf("StudioStore.Create() error = %v, wantErr = %v", err, tt.wantErr) + } + + if tt.wantErr { + assert.Zero(p.ID) + return + } + + assert.NotZero(p.ID) + + copy := *tt.newObject.Studio + copy.ID = p.ID + + // load relationships + if err := loadStudioRelationships(ctx, copy, p.Studio); err != nil { + t.Errorf("loadStudioRelationships() error = %v", err) + return + } + + assert.Equal(copy, *p.Studio) + + // ensure can find the Studio + found, err := qb.Find(ctx, p.ID) + if err != nil { + t.Errorf("StudioStore.Find() error = %v", err) + } + + if !assert.NotNil(found) { + return + } + + // load relationships + if err := loadStudioRelationships(ctx, copy, found); err != nil { + t.Errorf("loadStudioRelationships() error = %v", err) + return + } + assert.Equal(copy, *found) + + // ensure custom fields are set + cf, err := qb.GetCustomFields(ctx, p.ID) + if err != nil { + t.Errorf("StudioStore.GetCustomFields() error = %v", err) + return + } + + assert.Equal(tt.newObject.CustomFields, cf) + + return + }) + } +} + +func Test_StudioStore_Update(t *testing.T) { + var ( + name = "name" + details = "details" + url = "url" + rating = 3 + aliases = []string{"aliasX", "aliasY"} + ignoreAutoTag = true + favorite = true + endpoint1 = "endpoint1" + endpoint2 = "endpoint2" + stashID1 = "stashid1" + stashID2 = "stashid2" + createdAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) + updatedAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) + ) + + tests := []struct { + name string + updatedObject models.UpdateStudioInput + wantErr bool + }{ + { + "full", + models.UpdateStudioInput{ + Studio: &models.Studio{ + ID: studioIDs[studioIdxWithGallery], + Name: name, + URL: url, + Favorite: favorite, + Rating: &rating, + Details: details, + IgnoreAutoTag: ignoreAutoTag, + Aliases: models.NewRelatedStrings(aliases), + TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithDupName], tagIDs[tagIdx1WithStudio]}), + StashIDs: models.NewRelatedStashIDs([]models.StashID{ + { + StashID: stashID1, + Endpoint: endpoint1, + UpdatedAt: epochTime, + }, + { + StashID: stashID2, + Endpoint: endpoint2, + UpdatedAt: epochTime, + }, + }), + CreatedAt: createdAt, + UpdatedAt: updatedAt, + }, + }, + false, + }, + { + "clear nullables", + models.UpdateStudioInput{ + Studio: &models.Studio{ + ID: studioIDs[studioIdxWithGallery], + Name: name, // name is mandatory + Aliases: models.NewRelatedStrings([]string{}), + TagIDs: models.NewRelatedIDs([]int{}), + StashIDs: models.NewRelatedStashIDs([]models.StashID{}), + }, + }, + false, + }, + { + "clear tag ids", + models.UpdateStudioInput{ + Studio: &models.Studio{ + ID: studioIDs[sceneIdxWithTag], + Name: name, // name is mandatory + TagIDs: models.NewRelatedIDs([]int{}), + }, + }, + false, + }, + { + "set custom fields", + models.UpdateStudioInput{ + Studio: &models.Studio{ + ID: studioIDs[studioIdxWithGallery], + Name: name, // name is mandatory + }, + CustomFields: models.CustomFieldsInput{ + Full: testCustomFields, + }, + }, + false, + }, + { + "clear custom fields", + models.UpdateStudioInput{ + Studio: &models.Studio{ + ID: studioIDs[studioIdxWithGallery], + Name: name, // name is mandatory + }, + CustomFields: models.CustomFieldsInput{ + Full: map[string]interface{}{}, + }, + }, + false, + }, + { + "invalid tag id", + models.UpdateStudioInput{ + Studio: &models.Studio{ + ID: studioIDs[sceneIdxWithGallery], + Name: name, // name is mandatory + TagIDs: models.NewRelatedIDs([]int{invalidID}), + }, + }, + true, + }, + } + + qb := db.Studio + for _, tt := range tests { + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + assert := assert.New(t) + + copy := *tt.updatedObject.Studio + + if err := qb.Update(ctx, &tt.updatedObject); (err != nil) != tt.wantErr { + t.Errorf("StudioStore.Update() error = %v, wantErr %v", err, tt.wantErr) + } + + if tt.wantErr { + return + } + + s, err := qb.Find(ctx, tt.updatedObject.ID) + if err != nil { + t.Errorf("StudioStore.Find() error = %v", err) + } + + // load relationships + if err := loadStudioRelationships(ctx, copy, s); err != nil { + t.Errorf("loadStudioRelationships() error = %v", err) + return + } + + assert.Equal(copy, *s) + + // ensure custom fields are correct + if tt.updatedObject.CustomFields.Full != nil { + cf, err := qb.GetCustomFields(ctx, tt.updatedObject.ID) + if err != nil { + t.Errorf("StudioStore.GetCustomFields() error = %v", err) + return + } + + assert.Equal(tt.updatedObject.CustomFields.Full, cf) + } + }) + } +} + +func clearStudioPartial() models.StudioPartial { + nullString := models.OptionalString{Set: true, Null: true} + nullInt := models.OptionalInt{Set: true, Null: true} + + // leave mandatory fields + return models.StudioPartial{ + URL: nullString, + Aliases: &models.UpdateStrings{Mode: models.RelationshipUpdateModeSet}, + Rating: nullInt, + Details: nullString, + TagIDs: &models.UpdateIDs{Mode: models.RelationshipUpdateModeSet}, + StashIDs: &models.UpdateStashIDs{Mode: models.RelationshipUpdateModeSet}, + } +} + +func Test_StudioStore_UpdatePartial(t *testing.T) { + var ( + name = "name" + details = "details" + url = "url" + aliases = []string{"aliasX", "aliasY"} + rating = 3 + ignoreAutoTag = true + favorite = true + endpoint1 = "endpoint1" + endpoint2 = "endpoint2" + stashID1 = "stashid1" + stashID2 = "stashid2" + createdAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) + updatedAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) + ) + + tests := []struct { + name string + id int + partial models.StudioPartial + want models.Studio + wantErr bool + }{ + { + "full", + studioIDs[studioIdxWithDupName], + models.StudioPartial{ + Name: models.NewOptionalString(name), + URL: models.NewOptionalString(url), + Aliases: &models.UpdateStrings{ + Values: aliases, + Mode: models.RelationshipUpdateModeSet, + }, + Favorite: models.NewOptionalBool(favorite), + Rating: models.NewOptionalInt(rating), + Details: models.NewOptionalString(details), + IgnoreAutoTag: models.NewOptionalBool(ignoreAutoTag), + TagIDs: &models.UpdateIDs{ + IDs: []int{tagIDs[tagIdx1WithStudio], tagIDs[tagIdx1WithDupName]}, + Mode: models.RelationshipUpdateModeSet, + }, + StashIDs: &models.UpdateStashIDs{ + StashIDs: []models.StashID{ + { + StashID: stashID1, + Endpoint: endpoint1, + UpdatedAt: epochTime, + }, + { + StashID: stashID2, + Endpoint: endpoint2, + UpdatedAt: epochTime, + }, + }, + Mode: models.RelationshipUpdateModeSet, + }, + CreatedAt: models.NewOptionalTime(createdAt), + UpdatedAt: models.NewOptionalTime(updatedAt), + }, + models.Studio{ + ID: studioIDs[studioIdxWithDupName], + Name: name, + URL: url, + Aliases: models.NewRelatedStrings(aliases), + Favorite: favorite, + Rating: &rating, + Details: details, + IgnoreAutoTag: ignoreAutoTag, + TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithDupName], tagIDs[tagIdx1WithStudio]}), + StashIDs: models.NewRelatedStashIDs([]models.StashID{ + { + StashID: stashID1, + Endpoint: endpoint1, + UpdatedAt: epochTime, + }, + { + StashID: stashID2, + Endpoint: endpoint2, + UpdatedAt: epochTime, + }, + }), + CreatedAt: createdAt, + UpdatedAt: updatedAt, + }, + false, + }, + { + "clear all", + studioIDs[studioIdxWithTwoTags], + clearStudioPartial(), + models.Studio{ + ID: studioIDs[studioIdxWithTwoTags], + Name: getStudioStringValue(studioIdxWithTwoTags, "Name"), + Favorite: getStudioBoolValue(studioIdxWithTwoTags), + Aliases: models.NewRelatedStrings([]string{}), + TagIDs: models.NewRelatedIDs([]int{}), + StashIDs: models.NewRelatedStashIDs([]models.StashID{}), + IgnoreAutoTag: getIgnoreAutoTag(studioIdxWithTwoTags), + }, + false, + }, + { + "invalid id", + invalidID, + models.StudioPartial{Name: models.NewOptionalString(name)}, + models.Studio{}, + true, + }, + } + for _, tt := range tests { + qb := db.Studio + + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + assert := assert.New(t) + + tt.partial.ID = tt.id + + got, err := qb.UpdatePartial(ctx, tt.partial) + if (err != nil) != tt.wantErr { + t.Errorf("StudioStore.UpdatePartial() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantErr { + return + } + + if err := loadStudioRelationships(ctx, tt.want, got); err != nil { + t.Errorf("loadStudioRelationships() error = %v", err) + return + } + + assert.Equal(tt.want, *got) + + s, err := qb.Find(ctx, tt.id) + if err != nil { + t.Errorf("StudioStore.Find() error = %v", err) + } + + // load relationships + if err := loadStudioRelationships(ctx, tt.want, s); err != nil { + t.Errorf("loadStudioRelationships() error = %v", err) + return + } + + assert.Equal(tt.want, *s) + }) + } +} + +func Test_StudioStore_UpdatePartialCustomFields(t *testing.T) { + tests := []struct { + name string + id int + partial models.StudioPartial + expected map[string]interface{} // nil to use the partial + }{ + { + "set custom fields", + studioIDs[studioIdxWithGallery], + models.StudioPartial{ + CustomFields: models.CustomFieldsInput{ + Full: testCustomFields, + }, + }, + nil, + }, + { + "clear custom fields", + studioIDs[studioIdxWithGallery], + models.StudioPartial{ + CustomFields: models.CustomFieldsInput{ + Full: map[string]interface{}{}, + }, + }, + nil, + }, + { + "partial custom fields", + studioIDs[studioIdxWithGallery], + models.StudioPartial{ + CustomFields: models.CustomFieldsInput{ + Partial: map[string]interface{}{ + "string": "bbb", + "new_field": "new", + }, + }, + }, + map[string]interface{}{ + "int": int64(2), + "real": 0.7, + "string": "bbb", + "new_field": "new", + }, + }, + } + for _, tt := range tests { + qb := db.Studio + + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + assert := assert.New(t) + + tt.partial.ID = tt.id + + _, err := qb.UpdatePartial(ctx, tt.partial) + if err != nil { + t.Errorf("StudioStore.UpdatePartial() error = %v", err) + return + } + + // ensure custom fields are correct + cf, err := qb.GetCustomFields(ctx, tt.id) + if err != nil { + t.Errorf("StudioStore.GetCustomFields() error = %v", err) + return + } + if tt.expected == nil { + assert.Equal(tt.partial.CustomFields.Full, cf) + } else { + assert.Equal(tt.expected, cf) + } + }) + } +} + func TestStudioQueryNameOr(t *testing.T) { const studio1Idx = 1 const studio2Idx = 2 @@ -311,13 +856,13 @@ func TestStudioDestroyParent(t *testing.T) { // create parent and child studios if err := withTxn(func(ctx context.Context) error { - createdParent, err := createStudio(ctx, db.Studio, parentName, nil) + createdParent, err := createStudio(ctx, db.Studio, parentName, nil, nil) if err != nil { return fmt.Errorf("Error creating parent studio: %s", err.Error()) } parentID := createdParent.ID - createdChild, err := createStudio(ctx, db.Studio, childName, &parentID) + createdChild, err := createStudio(ctx, db.Studio, childName, &parentID, nil) if err != nil { return fmt.Errorf("Error creating child studio: %s", err.Error()) } @@ -373,13 +918,13 @@ func TestStudioUpdateClearParent(t *testing.T) { // create parent and child studios if err := withTxn(func(ctx context.Context) error { - createdParent, err := createStudio(ctx, db.Studio, parentName, nil) + createdParent, err := createStudio(ctx, db.Studio, parentName, nil, nil) if err != nil { return fmt.Errorf("Error creating parent studio: %s", err.Error()) } parentID := createdParent.ID - createdChild, err := createStudio(ctx, db.Studio, childName, &parentID) + createdChild, err := createStudio(ctx, db.Studio, childName, &parentID, nil) if err != nil { return fmt.Errorf("Error creating child studio: %s", err.Error()) } @@ -414,7 +959,7 @@ func TestStudioUpdateStudioImage(t *testing.T) { // create studio to test against const name = "TestStudioUpdateStudioImage" - created, err := createStudio(ctx, db.Studio, name, nil) + created, err := createStudio(ctx, db.Studio, name, nil, nil) if err != nil { return fmt.Errorf("Error creating studio: %s", err.Error()) } @@ -578,7 +1123,7 @@ func TestStudioStashIDs(t *testing.T) { // create studio to test against const name = "TestStudioStashIDs" - created, err := createStudio(ctx, db.Studio, name, nil) + created, err := createStudio(ctx, db.Studio, name, nil, nil) if err != nil { return fmt.Errorf("Error creating studio: %s", err.Error()) } @@ -990,7 +1535,7 @@ func TestStudioAlias(t *testing.T) { // create studio to test against const name = "TestStudioAlias" - created, err := createStudio(ctx, db.Studio, name, nil) + created, err := createStudio(ctx, db.Studio, name, nil, nil) if err != nil { return fmt.Errorf("Error creating studio: %s", err.Error()) } diff --git a/pkg/studio/import_test.go b/pkg/studio/import_test.go index 882b8ca56..6648ebe0d 100644 --- a/pkg/studio/import_test.go +++ b/pkg/studio/import_test.go @@ -206,9 +206,9 @@ func TestImporterPreImportWithMissingParent(t *testing.T) { } db.Studio.On("FindByName", testCtx, missingParentStudioName, false).Return(nil, nil).Times(3) - db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) { - s := args.Get(1).(*models.Studio) - s.ID = existingStudioID + db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Run(func(args mock.Arguments) { + s := args.Get(1).(*models.CreateStudioInput) + s.Studio.ID = existingStudioID }).Return(nil) err := i.PreImport(testCtx) @@ -240,7 +240,7 @@ func TestImporterPreImportWithMissingParentCreateErr(t *testing.T) { } db.Studio.On("FindByName", testCtx, missingParentStudioName, false).Return(nil, nil).Once() - db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error")) + db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Return(errors.New("Create error")) err := i.PreImport(testCtx) assert.NotNil(t, err) @@ -327,11 +327,11 @@ func TestCreate(t *testing.T) { } errCreate := errors.New("Create error") - db.Studio.On("Create", testCtx, &studio).Run(func(args mock.Arguments) { - s := args.Get(1).(*models.Studio) + db.Studio.On("Create", testCtx, &models.CreateStudioInput{Studio: &studio}).Run(func(args mock.Arguments) { + s := args.Get(1).(*models.CreateStudioInput) s.ID = studioID }).Return(nil).Once() - db.Studio.On("Create", testCtx, &studioErr).Return(errCreate).Once() + db.Studio.On("Create", testCtx, &models.CreateStudioInput{Studio: &studioErr}).Return(errCreate).Once() id, err := i.Create(testCtx) assert.Equal(t, studioID, *id) @@ -366,7 +366,7 @@ func TestUpdate(t *testing.T) { // id needs to be set for the mock input studio.ID = studioID - db.Studio.On("Update", testCtx, &studio).Return(nil).Once() + db.Studio.On("Update", testCtx, &models.UpdateStudioInput{Studio: &studio}).Return(nil).Once() err := i.Update(testCtx, studioID) assert.Nil(t, err) @@ -375,7 +375,7 @@ func TestUpdate(t *testing.T) { // need to set id separately studioErr.ID = errImageID - db.Studio.On("Update", testCtx, &studioErr).Return(errUpdate).Once() + db.Studio.On("Update", testCtx, &models.UpdateStudioInput{Studio: &studioErr}).Return(errUpdate).Once() err = i.Update(testCtx, errImageID) assert.NotNil(t, err)