Update unit tests

This commit is contained in:
WithoutPants 2025-10-16 18:35:32 +11:00
parent ff4a102a86
commit 1b60982946
11 changed files with 605 additions and 46 deletions

View file

@ -101,16 +101,15 @@ func createPerformer(ctx context.Context, pqb models.PerformerWriter) error {
func createStudio(ctx context.Context, qb models.StudioWriter, name string) (*models.Studio, error) {
// create the studio
studio := models.Studio{
Name: name,
}
studio := models.NewCreateStudioInput()
studio.Name = name
err := qb.Create(ctx, &studio)
if err != nil {
return nil, err
}
return &studio, nil
return studio.Studio, nil
}
func createTag(ctx context.Context, qb models.TagWriter) error {

View file

@ -27,7 +27,7 @@ func Test_sceneRelationships_studio(t *testing.T) {
db := mocks.NewDatabase()
db.Studio.On("Create", testCtx, mock.Anything).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.Studio)
s := args.Get(1).(*models.CreateStudioInput)
s.ID = validStoredIDInt
}).Return(nil)

View file

@ -21,13 +21,13 @@ func Test_createMissingStudio(t *testing.T) {
db := mocks.NewDatabase()
db.Studio.On("Create", testCtx, mock.MatchedBy(func(p *models.Studio) bool {
db.Studio.On("Create", testCtx, mock.MatchedBy(func(p *models.CreateStudioInput) bool {
return p.Name == validName
})).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.Studio)
s := args.Get(1).(*models.CreateStudioInput)
s.ID = createdID
}).Return(nil)
db.Studio.On("Create", testCtx, mock.MatchedBy(func(p *models.Studio) bool {
db.Studio.On("Create", testCtx, mock.MatchedBy(func(p *models.CreateStudioInput) bool {
return p.Name == invalidName
})).Return(errors.New("error creating studio"))

View file

@ -115,9 +115,9 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) {
}
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.Studio)
s.ID = existingStudioID
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.CreateStudioInput)
s.Studio.ID = existingStudioID
}).Return(nil)
err := i.PreImport(testCtx)
@ -147,7 +147,7 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
}
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Return(errors.New("Create error"))
err := i.PreImport(testCtx)
assert.NotNil(t, err)

View file

@ -121,9 +121,9 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) {
}
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.Studio)
s.ID = existingStudioID
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.CreateStudioInput)
s.Studio.ID = existingStudioID
}).Return(nil)
err := i.PreImport(testCtx)
@ -156,7 +156,7 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
}
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Return(errors.New("Create error"))
err := i.PreImport(testCtx)
assert.NotNil(t, err)

View file

@ -77,9 +77,9 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) {
}
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.Studio)
s.ID = existingStudioID
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.CreateStudioInput)
s.Studio.ID = existingStudioID
}).Return(nil)
err := i.PreImport(testCtx)
@ -109,7 +109,7 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
}
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Return(errors.New("Create error"))
err := i.PreImport(testCtx)
assert.NotNil(t, err)

View file

@ -113,7 +113,7 @@ func Test_scrapedToStudioInput(t *testing.T) {
got.StashIDs.List()[stid].UpdatedAt = time.Time{}
}
}
assert.Equal(t, tt.want, got)
assert.Equal(t, tt.want, got.Studio)
})
}
}

View file

@ -241,9 +241,9 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) {
}
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.Studio)
s.ID = existingStudioID
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.CreateStudioInput)
s.Studio.ID = existingStudioID
}).Return(nil)
err := i.PreImport(testCtx)
@ -273,7 +273,7 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
}
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Return(errors.New("Create error"))
err := i.PreImport(testCtx)
assert.NotNil(t, err)

View file

@ -1765,7 +1765,19 @@ func getStudioNullStringValue(index int, field string) string {
return ret.String
}
func createStudio(ctx context.Context, sqb *sqlite.StudioStore, name string, parentID *int) (*models.Studio, error) {
func getStudioCustomFields(index int) map[string]interface{} {
if index%5 == 0 {
return nil
}
return map[string]interface{}{
"string": getStudioStringValue(index, "custom"),
"int": int64(index % 5),
"real": float64(index) / 10,
}
}
func createStudio(ctx context.Context, sqb *sqlite.StudioStore, name string, parentID *int, customFields map[string]interface{}) (*models.Studio, error) {
studio := models.Studio{
Name: name,
}
@ -1774,7 +1786,7 @@ func createStudio(ctx context.Context, sqb *sqlite.StudioStore, name string, par
studio.ParentID = parentID
}
err := createStudioFromModel(ctx, sqb, &studio)
err := createStudioFromModel(ctx, sqb, &studio, customFields)
if err != nil {
return nil, err
}
@ -1782,8 +1794,11 @@ func createStudio(ctx context.Context, sqb *sqlite.StudioStore, name string, par
return &studio, nil
}
func createStudioFromModel(ctx context.Context, sqb *sqlite.StudioStore, studio *models.Studio) error {
err := sqb.Create(ctx, studio)
func createStudioFromModel(ctx context.Context, sqb *sqlite.StudioStore, studio *models.Studio, customFields map[string]interface{}) error {
err := sqb.Create(ctx, &models.CreateStudioInput{
Studio: studio,
CustomFields: customFields,
})
if err != nil {
return fmt.Errorf("Error creating studio %v+: %s", studio, err.Error())
@ -1845,7 +1860,7 @@ func createStudios(ctx context.Context, n int, o int) error {
alias := getStudioStringValue(i, "Alias")
studio.Aliases = models.NewRelatedStrings([]string{alias})
}
err := createStudioFromModel(ctx, sqb, &studio)
err := createStudioFromModel(ctx, sqb, &studio, getStudioCustomFields(i))
if err != nil {
return err

View file

@ -11,6 +11,7 @@ import (
"strconv"
"strings"
"testing"
"time"
"github.com/stashapp/stash/pkg/models"
"github.com/stretchr/testify/assert"
@ -47,6 +48,550 @@ func TestStudioFindByName(t *testing.T) {
})
}
func loadStudioRelationships(ctx context.Context, expected models.Studio, actual *models.Studio) error {
if expected.Aliases.Loaded() {
if err := actual.LoadAliases(ctx, db.Studio); err != nil {
return err
}
}
if expected.TagIDs.Loaded() {
if err := actual.LoadTagIDs(ctx, db.Studio); err != nil {
return err
}
}
if expected.StashIDs.Loaded() {
if err := actual.LoadStashIDs(ctx, db.Studio); err != nil {
return err
}
}
return nil
}
func Test_StudioStore_Create(t *testing.T) {
var (
name = "name"
details = "details"
url = "url"
rating = 3
aliases = []string{"alias1", "alias2"}
ignoreAutoTag = true
favorite = true
endpoint1 = "endpoint1"
endpoint2 = "endpoint2"
stashID1 = "stashid1"
stashID2 = "stashid2"
createdAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC)
updatedAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC)
)
tests := []struct {
name string
newObject models.CreateStudioInput
wantErr bool
}{
{
"full",
models.CreateStudioInput{
Studio: &models.Studio{
Name: name,
URL: url,
Favorite: favorite,
Rating: &rating,
Details: details,
IgnoreAutoTag: ignoreAutoTag,
TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithStudio], tagIDs[tagIdx1WithDupName]}),
Aliases: models.NewRelatedStrings(aliases),
StashIDs: models.NewRelatedStashIDs([]models.StashID{
{
StashID: stashID1,
Endpoint: endpoint1,
UpdatedAt: epochTime,
},
{
StashID: stashID2,
Endpoint: endpoint2,
UpdatedAt: epochTime,
},
}),
CreatedAt: createdAt,
UpdatedAt: updatedAt,
},
CustomFields: testCustomFields,
},
false,
},
{
"invalid tag id",
models.CreateStudioInput{
Studio: &models.Studio{
Name: name,
TagIDs: models.NewRelatedIDs([]int{invalidID}),
},
},
true,
},
}
qb := db.Studio
for _, tt := range tests {
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
assert := assert.New(t)
p := tt.newObject
if err := qb.Create(ctx, &p); (err != nil) != tt.wantErr {
t.Errorf("StudioStore.Create() error = %v, wantErr = %v", err, tt.wantErr)
}
if tt.wantErr {
assert.Zero(p.ID)
return
}
assert.NotZero(p.ID)
copy := *tt.newObject.Studio
copy.ID = p.ID
// load relationships
if err := loadStudioRelationships(ctx, copy, p.Studio); err != nil {
t.Errorf("loadStudioRelationships() error = %v", err)
return
}
assert.Equal(copy, *p.Studio)
// ensure can find the Studio
found, err := qb.Find(ctx, p.ID)
if err != nil {
t.Errorf("StudioStore.Find() error = %v", err)
}
if !assert.NotNil(found) {
return
}
// load relationships
if err := loadStudioRelationships(ctx, copy, found); err != nil {
t.Errorf("loadStudioRelationships() error = %v", err)
return
}
assert.Equal(copy, *found)
// ensure custom fields are set
cf, err := qb.GetCustomFields(ctx, p.ID)
if err != nil {
t.Errorf("StudioStore.GetCustomFields() error = %v", err)
return
}
assert.Equal(tt.newObject.CustomFields, cf)
return
})
}
}
func Test_StudioStore_Update(t *testing.T) {
var (
name = "name"
details = "details"
url = "url"
rating = 3
aliases = []string{"aliasX", "aliasY"}
ignoreAutoTag = true
favorite = true
endpoint1 = "endpoint1"
endpoint2 = "endpoint2"
stashID1 = "stashid1"
stashID2 = "stashid2"
createdAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC)
updatedAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC)
)
tests := []struct {
name string
updatedObject models.UpdateStudioInput
wantErr bool
}{
{
"full",
models.UpdateStudioInput{
Studio: &models.Studio{
ID: studioIDs[studioIdxWithGallery],
Name: name,
URL: url,
Favorite: favorite,
Rating: &rating,
Details: details,
IgnoreAutoTag: ignoreAutoTag,
Aliases: models.NewRelatedStrings(aliases),
TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithDupName], tagIDs[tagIdx1WithStudio]}),
StashIDs: models.NewRelatedStashIDs([]models.StashID{
{
StashID: stashID1,
Endpoint: endpoint1,
UpdatedAt: epochTime,
},
{
StashID: stashID2,
Endpoint: endpoint2,
UpdatedAt: epochTime,
},
}),
CreatedAt: createdAt,
UpdatedAt: updatedAt,
},
},
false,
},
{
"clear nullables",
models.UpdateStudioInput{
Studio: &models.Studio{
ID: studioIDs[studioIdxWithGallery],
Name: name, // name is mandatory
Aliases: models.NewRelatedStrings([]string{}),
TagIDs: models.NewRelatedIDs([]int{}),
StashIDs: models.NewRelatedStashIDs([]models.StashID{}),
},
},
false,
},
{
"clear tag ids",
models.UpdateStudioInput{
Studio: &models.Studio{
ID: studioIDs[sceneIdxWithTag],
Name: name, // name is mandatory
TagIDs: models.NewRelatedIDs([]int{}),
},
},
false,
},
{
"set custom fields",
models.UpdateStudioInput{
Studio: &models.Studio{
ID: studioIDs[studioIdxWithGallery],
Name: name, // name is mandatory
},
CustomFields: models.CustomFieldsInput{
Full: testCustomFields,
},
},
false,
},
{
"clear custom fields",
models.UpdateStudioInput{
Studio: &models.Studio{
ID: studioIDs[studioIdxWithGallery],
Name: name, // name is mandatory
},
CustomFields: models.CustomFieldsInput{
Full: map[string]interface{}{},
},
},
false,
},
{
"invalid tag id",
models.UpdateStudioInput{
Studio: &models.Studio{
ID: studioIDs[sceneIdxWithGallery],
Name: name, // name is mandatory
TagIDs: models.NewRelatedIDs([]int{invalidID}),
},
},
true,
},
}
qb := db.Studio
for _, tt := range tests {
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
assert := assert.New(t)
copy := *tt.updatedObject.Studio
if err := qb.Update(ctx, &tt.updatedObject); (err != nil) != tt.wantErr {
t.Errorf("StudioStore.Update() error = %v, wantErr %v", err, tt.wantErr)
}
if tt.wantErr {
return
}
s, err := qb.Find(ctx, tt.updatedObject.ID)
if err != nil {
t.Errorf("StudioStore.Find() error = %v", err)
}
// load relationships
if err := loadStudioRelationships(ctx, copy, s); err != nil {
t.Errorf("loadStudioRelationships() error = %v", err)
return
}
assert.Equal(copy, *s)
// ensure custom fields are correct
if tt.updatedObject.CustomFields.Full != nil {
cf, err := qb.GetCustomFields(ctx, tt.updatedObject.ID)
if err != nil {
t.Errorf("StudioStore.GetCustomFields() error = %v", err)
return
}
assert.Equal(tt.updatedObject.CustomFields.Full, cf)
}
})
}
}
func clearStudioPartial() models.StudioPartial {
nullString := models.OptionalString{Set: true, Null: true}
nullInt := models.OptionalInt{Set: true, Null: true}
// leave mandatory fields
return models.StudioPartial{
URL: nullString,
Aliases: &models.UpdateStrings{Mode: models.RelationshipUpdateModeSet},
Rating: nullInt,
Details: nullString,
TagIDs: &models.UpdateIDs{Mode: models.RelationshipUpdateModeSet},
StashIDs: &models.UpdateStashIDs{Mode: models.RelationshipUpdateModeSet},
}
}
func Test_StudioStore_UpdatePartial(t *testing.T) {
var (
name = "name"
details = "details"
url = "url"
aliases = []string{"aliasX", "aliasY"}
rating = 3
ignoreAutoTag = true
favorite = true
endpoint1 = "endpoint1"
endpoint2 = "endpoint2"
stashID1 = "stashid1"
stashID2 = "stashid2"
createdAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC)
updatedAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC)
)
tests := []struct {
name string
id int
partial models.StudioPartial
want models.Studio
wantErr bool
}{
{
"full",
studioIDs[studioIdxWithDupName],
models.StudioPartial{
Name: models.NewOptionalString(name),
URL: models.NewOptionalString(url),
Aliases: &models.UpdateStrings{
Values: aliases,
Mode: models.RelationshipUpdateModeSet,
},
Favorite: models.NewOptionalBool(favorite),
Rating: models.NewOptionalInt(rating),
Details: models.NewOptionalString(details),
IgnoreAutoTag: models.NewOptionalBool(ignoreAutoTag),
TagIDs: &models.UpdateIDs{
IDs: []int{tagIDs[tagIdx1WithStudio], tagIDs[tagIdx1WithDupName]},
Mode: models.RelationshipUpdateModeSet,
},
StashIDs: &models.UpdateStashIDs{
StashIDs: []models.StashID{
{
StashID: stashID1,
Endpoint: endpoint1,
UpdatedAt: epochTime,
},
{
StashID: stashID2,
Endpoint: endpoint2,
UpdatedAt: epochTime,
},
},
Mode: models.RelationshipUpdateModeSet,
},
CreatedAt: models.NewOptionalTime(createdAt),
UpdatedAt: models.NewOptionalTime(updatedAt),
},
models.Studio{
ID: studioIDs[studioIdxWithDupName],
Name: name,
URL: url,
Aliases: models.NewRelatedStrings(aliases),
Favorite: favorite,
Rating: &rating,
Details: details,
IgnoreAutoTag: ignoreAutoTag,
TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithDupName], tagIDs[tagIdx1WithStudio]}),
StashIDs: models.NewRelatedStashIDs([]models.StashID{
{
StashID: stashID1,
Endpoint: endpoint1,
UpdatedAt: epochTime,
},
{
StashID: stashID2,
Endpoint: endpoint2,
UpdatedAt: epochTime,
},
}),
CreatedAt: createdAt,
UpdatedAt: updatedAt,
},
false,
},
{
"clear all",
studioIDs[studioIdxWithTwoTags],
clearStudioPartial(),
models.Studio{
ID: studioIDs[studioIdxWithTwoTags],
Name: getStudioStringValue(studioIdxWithTwoTags, "Name"),
Favorite: getStudioBoolValue(studioIdxWithTwoTags),
Aliases: models.NewRelatedStrings([]string{}),
TagIDs: models.NewRelatedIDs([]int{}),
StashIDs: models.NewRelatedStashIDs([]models.StashID{}),
IgnoreAutoTag: getIgnoreAutoTag(studioIdxWithTwoTags),
},
false,
},
{
"invalid id",
invalidID,
models.StudioPartial{Name: models.NewOptionalString(name)},
models.Studio{},
true,
},
}
for _, tt := range tests {
qb := db.Studio
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
assert := assert.New(t)
tt.partial.ID = tt.id
got, err := qb.UpdatePartial(ctx, tt.partial)
if (err != nil) != tt.wantErr {
t.Errorf("StudioStore.UpdatePartial() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
return
}
if err := loadStudioRelationships(ctx, tt.want, got); err != nil {
t.Errorf("loadStudioRelationships() error = %v", err)
return
}
assert.Equal(tt.want, *got)
s, err := qb.Find(ctx, tt.id)
if err != nil {
t.Errorf("StudioStore.Find() error = %v", err)
}
// load relationships
if err := loadStudioRelationships(ctx, tt.want, s); err != nil {
t.Errorf("loadStudioRelationships() error = %v", err)
return
}
assert.Equal(tt.want, *s)
})
}
}
func Test_StudioStore_UpdatePartialCustomFields(t *testing.T) {
tests := []struct {
name string
id int
partial models.StudioPartial
expected map[string]interface{} // nil to use the partial
}{
{
"set custom fields",
studioIDs[studioIdxWithGallery],
models.StudioPartial{
CustomFields: models.CustomFieldsInput{
Full: testCustomFields,
},
},
nil,
},
{
"clear custom fields",
studioIDs[studioIdxWithGallery],
models.StudioPartial{
CustomFields: models.CustomFieldsInput{
Full: map[string]interface{}{},
},
},
nil,
},
{
"partial custom fields",
studioIDs[studioIdxWithGallery],
models.StudioPartial{
CustomFields: models.CustomFieldsInput{
Partial: map[string]interface{}{
"string": "bbb",
"new_field": "new",
},
},
},
map[string]interface{}{
"int": int64(2),
"real": 0.7,
"string": "bbb",
"new_field": "new",
},
},
}
for _, tt := range tests {
qb := db.Studio
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
assert := assert.New(t)
tt.partial.ID = tt.id
_, err := qb.UpdatePartial(ctx, tt.partial)
if err != nil {
t.Errorf("StudioStore.UpdatePartial() error = %v", err)
return
}
// ensure custom fields are correct
cf, err := qb.GetCustomFields(ctx, tt.id)
if err != nil {
t.Errorf("StudioStore.GetCustomFields() error = %v", err)
return
}
if tt.expected == nil {
assert.Equal(tt.partial.CustomFields.Full, cf)
} else {
assert.Equal(tt.expected, cf)
}
})
}
}
func TestStudioQueryNameOr(t *testing.T) {
const studio1Idx = 1
const studio2Idx = 2
@ -311,13 +856,13 @@ func TestStudioDestroyParent(t *testing.T) {
// create parent and child studios
if err := withTxn(func(ctx context.Context) error {
createdParent, err := createStudio(ctx, db.Studio, parentName, nil)
createdParent, err := createStudio(ctx, db.Studio, parentName, nil, nil)
if err != nil {
return fmt.Errorf("Error creating parent studio: %s", err.Error())
}
parentID := createdParent.ID
createdChild, err := createStudio(ctx, db.Studio, childName, &parentID)
createdChild, err := createStudio(ctx, db.Studio, childName, &parentID, nil)
if err != nil {
return fmt.Errorf("Error creating child studio: %s", err.Error())
}
@ -373,13 +918,13 @@ func TestStudioUpdateClearParent(t *testing.T) {
// create parent and child studios
if err := withTxn(func(ctx context.Context) error {
createdParent, err := createStudio(ctx, db.Studio, parentName, nil)
createdParent, err := createStudio(ctx, db.Studio, parentName, nil, nil)
if err != nil {
return fmt.Errorf("Error creating parent studio: %s", err.Error())
}
parentID := createdParent.ID
createdChild, err := createStudio(ctx, db.Studio, childName, &parentID)
createdChild, err := createStudio(ctx, db.Studio, childName, &parentID, nil)
if err != nil {
return fmt.Errorf("Error creating child studio: %s", err.Error())
}
@ -414,7 +959,7 @@ func TestStudioUpdateStudioImage(t *testing.T) {
// create studio to test against
const name = "TestStudioUpdateStudioImage"
created, err := createStudio(ctx, db.Studio, name, nil)
created, err := createStudio(ctx, db.Studio, name, nil, nil)
if err != nil {
return fmt.Errorf("Error creating studio: %s", err.Error())
}
@ -578,7 +1123,7 @@ func TestStudioStashIDs(t *testing.T) {
// create studio to test against
const name = "TestStudioStashIDs"
created, err := createStudio(ctx, db.Studio, name, nil)
created, err := createStudio(ctx, db.Studio, name, nil, nil)
if err != nil {
return fmt.Errorf("Error creating studio: %s", err.Error())
}
@ -990,7 +1535,7 @@ func TestStudioAlias(t *testing.T) {
// create studio to test against
const name = "TestStudioAlias"
created, err := createStudio(ctx, db.Studio, name, nil)
created, err := createStudio(ctx, db.Studio, name, nil, nil)
if err != nil {
return fmt.Errorf("Error creating studio: %s", err.Error())
}

View file

@ -206,9 +206,9 @@ func TestImporterPreImportWithMissingParent(t *testing.T) {
}
db.Studio.On("FindByName", testCtx, missingParentStudioName, false).Return(nil, nil).Times(3)
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.Studio)
s.ID = existingStudioID
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.CreateStudioInput)
s.Studio.ID = existingStudioID
}).Return(nil)
err := i.PreImport(testCtx)
@ -240,7 +240,7 @@ func TestImporterPreImportWithMissingParentCreateErr(t *testing.T) {
}
db.Studio.On("FindByName", testCtx, missingParentStudioName, false).Return(nil, nil).Once()
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Return(errors.New("Create error"))
err := i.PreImport(testCtx)
assert.NotNil(t, err)
@ -327,11 +327,11 @@ func TestCreate(t *testing.T) {
}
errCreate := errors.New("Create error")
db.Studio.On("Create", testCtx, &studio).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.Studio)
db.Studio.On("Create", testCtx, &models.CreateStudioInput{Studio: &studio}).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.CreateStudioInput)
s.ID = studioID
}).Return(nil).Once()
db.Studio.On("Create", testCtx, &studioErr).Return(errCreate).Once()
db.Studio.On("Create", testCtx, &models.CreateStudioInput{Studio: &studioErr}).Return(errCreate).Once()
id, err := i.Create(testCtx)
assert.Equal(t, studioID, *id)
@ -366,7 +366,7 @@ func TestUpdate(t *testing.T) {
// id needs to be set for the mock input
studio.ID = studioID
db.Studio.On("Update", testCtx, &studio).Return(nil).Once()
db.Studio.On("Update", testCtx, &models.UpdateStudioInput{Studio: &studio}).Return(nil).Once()
err := i.Update(testCtx, studioID)
assert.Nil(t, err)
@ -375,7 +375,7 @@ func TestUpdate(t *testing.T) {
// need to set id separately
studioErr.ID = errImageID
db.Studio.On("Update", testCtx, &studioErr).Return(errUpdate).Once()
db.Studio.On("Update", testCtx, &models.UpdateStudioInput{Studio: &studioErr}).Return(errUpdate).Once()
err = i.Update(testCtx, errImageID)
assert.NotNil(t, err)