mirror of
https://github.com/stashapp/stash.git
synced 2026-02-08 16:31:52 +01:00
Update unit tests
This commit is contained in:
parent
ff4a102a86
commit
1b60982946
11 changed files with 605 additions and 46 deletions
|
|
@ -101,16 +101,15 @@ func createPerformer(ctx context.Context, pqb models.PerformerWriter) error {
|
|||
|
||||
func createStudio(ctx context.Context, qb models.StudioWriter, name string) (*models.Studio, error) {
|
||||
// create the studio
|
||||
studio := models.Studio{
|
||||
Name: name,
|
||||
}
|
||||
studio := models.NewCreateStudioInput()
|
||||
studio.Name = name
|
||||
|
||||
err := qb.Create(ctx, &studio)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &studio, nil
|
||||
return studio.Studio, nil
|
||||
}
|
||||
|
||||
func createTag(ctx context.Context, qb models.TagWriter) error {
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ func Test_sceneRelationships_studio(t *testing.T) {
|
|||
db := mocks.NewDatabase()
|
||||
|
||||
db.Studio.On("Create", testCtx, mock.Anything).Run(func(args mock.Arguments) {
|
||||
s := args.Get(1).(*models.Studio)
|
||||
s := args.Get(1).(*models.CreateStudioInput)
|
||||
s.ID = validStoredIDInt
|
||||
}).Return(nil)
|
||||
|
||||
|
|
|
|||
|
|
@ -21,13 +21,13 @@ func Test_createMissingStudio(t *testing.T) {
|
|||
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
db.Studio.On("Create", testCtx, mock.MatchedBy(func(p *models.Studio) bool {
|
||||
db.Studio.On("Create", testCtx, mock.MatchedBy(func(p *models.CreateStudioInput) bool {
|
||||
return p.Name == validName
|
||||
})).Run(func(args mock.Arguments) {
|
||||
s := args.Get(1).(*models.Studio)
|
||||
s := args.Get(1).(*models.CreateStudioInput)
|
||||
s.ID = createdID
|
||||
}).Return(nil)
|
||||
db.Studio.On("Create", testCtx, mock.MatchedBy(func(p *models.Studio) bool {
|
||||
db.Studio.On("Create", testCtx, mock.MatchedBy(func(p *models.CreateStudioInput) bool {
|
||||
return p.Name == invalidName
|
||||
})).Return(errors.New("error creating studio"))
|
||||
|
||||
|
|
|
|||
|
|
@ -115,9 +115,9 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) {
|
|||
}
|
||||
|
||||
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
|
||||
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
|
||||
s := args.Get(1).(*models.Studio)
|
||||
s.ID = existingStudioID
|
||||
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Run(func(args mock.Arguments) {
|
||||
s := args.Get(1).(*models.CreateStudioInput)
|
||||
s.Studio.ID = existingStudioID
|
||||
}).Return(nil)
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
|
|
@ -147,7 +147,7 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
|
|||
}
|
||||
|
||||
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
|
||||
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
|
||||
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Return(errors.New("Create error"))
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
|
|
|||
|
|
@ -121,9 +121,9 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) {
|
|||
}
|
||||
|
||||
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
|
||||
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
|
||||
s := args.Get(1).(*models.Studio)
|
||||
s.ID = existingStudioID
|
||||
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Run(func(args mock.Arguments) {
|
||||
s := args.Get(1).(*models.CreateStudioInput)
|
||||
s.Studio.ID = existingStudioID
|
||||
}).Return(nil)
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
|
|
@ -156,7 +156,7 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
|
|||
}
|
||||
|
||||
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
|
||||
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
|
||||
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Return(errors.New("Create error"))
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
|
|
|||
|
|
@ -77,9 +77,9 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) {
|
|||
}
|
||||
|
||||
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
|
||||
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
|
||||
s := args.Get(1).(*models.Studio)
|
||||
s.ID = existingStudioID
|
||||
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Run(func(args mock.Arguments) {
|
||||
s := args.Get(1).(*models.CreateStudioInput)
|
||||
s.Studio.ID = existingStudioID
|
||||
}).Return(nil)
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
|
|
@ -109,7 +109,7 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
|
|||
}
|
||||
|
||||
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
|
||||
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
|
||||
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Return(errors.New("Create error"))
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
|
|
|||
|
|
@ -113,7 +113,7 @@ func Test_scrapedToStudioInput(t *testing.T) {
|
|||
got.StashIDs.List()[stid].UpdatedAt = time.Time{}
|
||||
}
|
||||
}
|
||||
assert.Equal(t, tt.want, got)
|
||||
assert.Equal(t, tt.want, got.Studio)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -241,9 +241,9 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) {
|
|||
}
|
||||
|
||||
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
|
||||
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
|
||||
s := args.Get(1).(*models.Studio)
|
||||
s.ID = existingStudioID
|
||||
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Run(func(args mock.Arguments) {
|
||||
s := args.Get(1).(*models.CreateStudioInput)
|
||||
s.Studio.ID = existingStudioID
|
||||
}).Return(nil)
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
|
|
@ -273,7 +273,7 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
|
|||
}
|
||||
|
||||
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
|
||||
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
|
||||
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Return(errors.New("Create error"))
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
|
|
|||
|
|
@ -1765,7 +1765,19 @@ func getStudioNullStringValue(index int, field string) string {
|
|||
return ret.String
|
||||
}
|
||||
|
||||
func createStudio(ctx context.Context, sqb *sqlite.StudioStore, name string, parentID *int) (*models.Studio, error) {
|
||||
func getStudioCustomFields(index int) map[string]interface{} {
|
||||
if index%5 == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"string": getStudioStringValue(index, "custom"),
|
||||
"int": int64(index % 5),
|
||||
"real": float64(index) / 10,
|
||||
}
|
||||
}
|
||||
|
||||
func createStudio(ctx context.Context, sqb *sqlite.StudioStore, name string, parentID *int, customFields map[string]interface{}) (*models.Studio, error) {
|
||||
studio := models.Studio{
|
||||
Name: name,
|
||||
}
|
||||
|
|
@ -1774,7 +1786,7 @@ func createStudio(ctx context.Context, sqb *sqlite.StudioStore, name string, par
|
|||
studio.ParentID = parentID
|
||||
}
|
||||
|
||||
err := createStudioFromModel(ctx, sqb, &studio)
|
||||
err := createStudioFromModel(ctx, sqb, &studio, customFields)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -1782,8 +1794,11 @@ func createStudio(ctx context.Context, sqb *sqlite.StudioStore, name string, par
|
|||
return &studio, nil
|
||||
}
|
||||
|
||||
func createStudioFromModel(ctx context.Context, sqb *sqlite.StudioStore, studio *models.Studio) error {
|
||||
err := sqb.Create(ctx, studio)
|
||||
func createStudioFromModel(ctx context.Context, sqb *sqlite.StudioStore, studio *models.Studio, customFields map[string]interface{}) error {
|
||||
err := sqb.Create(ctx, &models.CreateStudioInput{
|
||||
Studio: studio,
|
||||
CustomFields: customFields,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error creating studio %v+: %s", studio, err.Error())
|
||||
|
|
@ -1845,7 +1860,7 @@ func createStudios(ctx context.Context, n int, o int) error {
|
|||
alias := getStudioStringValue(i, "Alias")
|
||||
studio.Aliases = models.NewRelatedStrings([]string{alias})
|
||||
}
|
||||
err := createStudioFromModel(ctx, sqb, &studio)
|
||||
err := createStudioFromModel(ctx, sqb, &studio, getStudioCustomFields(i))
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import (
|
|||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
|
@ -47,6 +48,550 @@ func TestStudioFindByName(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func loadStudioRelationships(ctx context.Context, expected models.Studio, actual *models.Studio) error {
|
||||
if expected.Aliases.Loaded() {
|
||||
if err := actual.LoadAliases(ctx, db.Studio); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if expected.TagIDs.Loaded() {
|
||||
if err := actual.LoadTagIDs(ctx, db.Studio); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if expected.StashIDs.Loaded() {
|
||||
if err := actual.LoadStashIDs(ctx, db.Studio); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func Test_StudioStore_Create(t *testing.T) {
|
||||
var (
|
||||
name = "name"
|
||||
details = "details"
|
||||
url = "url"
|
||||
rating = 3
|
||||
aliases = []string{"alias1", "alias2"}
|
||||
ignoreAutoTag = true
|
||||
favorite = true
|
||||
endpoint1 = "endpoint1"
|
||||
endpoint2 = "endpoint2"
|
||||
stashID1 = "stashid1"
|
||||
stashID2 = "stashid2"
|
||||
createdAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
updatedAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
newObject models.CreateStudioInput
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"full",
|
||||
models.CreateStudioInput{
|
||||
Studio: &models.Studio{
|
||||
Name: name,
|
||||
URL: url,
|
||||
Favorite: favorite,
|
||||
Rating: &rating,
|
||||
Details: details,
|
||||
IgnoreAutoTag: ignoreAutoTag,
|
||||
TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithStudio], tagIDs[tagIdx1WithDupName]}),
|
||||
Aliases: models.NewRelatedStrings(aliases),
|
||||
StashIDs: models.NewRelatedStashIDs([]models.StashID{
|
||||
{
|
||||
StashID: stashID1,
|
||||
Endpoint: endpoint1,
|
||||
UpdatedAt: epochTime,
|
||||
},
|
||||
{
|
||||
StashID: stashID2,
|
||||
Endpoint: endpoint2,
|
||||
UpdatedAt: epochTime,
|
||||
},
|
||||
}),
|
||||
CreatedAt: createdAt,
|
||||
UpdatedAt: updatedAt,
|
||||
},
|
||||
CustomFields: testCustomFields,
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"invalid tag id",
|
||||
models.CreateStudioInput{
|
||||
Studio: &models.Studio{
|
||||
Name: name,
|
||||
TagIDs: models.NewRelatedIDs([]int{invalidID}),
|
||||
},
|
||||
},
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
qb := db.Studio
|
||||
|
||||
for _, tt := range tests {
|
||||
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
|
||||
assert := assert.New(t)
|
||||
|
||||
p := tt.newObject
|
||||
if err := qb.Create(ctx, &p); (err != nil) != tt.wantErr {
|
||||
t.Errorf("StudioStore.Create() error = %v, wantErr = %v", err, tt.wantErr)
|
||||
}
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Zero(p.ID)
|
||||
return
|
||||
}
|
||||
|
||||
assert.NotZero(p.ID)
|
||||
|
||||
copy := *tt.newObject.Studio
|
||||
copy.ID = p.ID
|
||||
|
||||
// load relationships
|
||||
if err := loadStudioRelationships(ctx, copy, p.Studio); err != nil {
|
||||
t.Errorf("loadStudioRelationships() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(copy, *p.Studio)
|
||||
|
||||
// ensure can find the Studio
|
||||
found, err := qb.Find(ctx, p.ID)
|
||||
if err != nil {
|
||||
t.Errorf("StudioStore.Find() error = %v", err)
|
||||
}
|
||||
|
||||
if !assert.NotNil(found) {
|
||||
return
|
||||
}
|
||||
|
||||
// load relationships
|
||||
if err := loadStudioRelationships(ctx, copy, found); err != nil {
|
||||
t.Errorf("loadStudioRelationships() error = %v", err)
|
||||
return
|
||||
}
|
||||
assert.Equal(copy, *found)
|
||||
|
||||
// ensure custom fields are set
|
||||
cf, err := qb.GetCustomFields(ctx, p.ID)
|
||||
if err != nil {
|
||||
t.Errorf("StudioStore.GetCustomFields() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(tt.newObject.CustomFields, cf)
|
||||
|
||||
return
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_StudioStore_Update(t *testing.T) {
|
||||
var (
|
||||
name = "name"
|
||||
details = "details"
|
||||
url = "url"
|
||||
rating = 3
|
||||
aliases = []string{"aliasX", "aliasY"}
|
||||
ignoreAutoTag = true
|
||||
favorite = true
|
||||
endpoint1 = "endpoint1"
|
||||
endpoint2 = "endpoint2"
|
||||
stashID1 = "stashid1"
|
||||
stashID2 = "stashid2"
|
||||
createdAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
updatedAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
updatedObject models.UpdateStudioInput
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"full",
|
||||
models.UpdateStudioInput{
|
||||
Studio: &models.Studio{
|
||||
ID: studioIDs[studioIdxWithGallery],
|
||||
Name: name,
|
||||
URL: url,
|
||||
Favorite: favorite,
|
||||
Rating: &rating,
|
||||
Details: details,
|
||||
IgnoreAutoTag: ignoreAutoTag,
|
||||
Aliases: models.NewRelatedStrings(aliases),
|
||||
TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithDupName], tagIDs[tagIdx1WithStudio]}),
|
||||
StashIDs: models.NewRelatedStashIDs([]models.StashID{
|
||||
{
|
||||
StashID: stashID1,
|
||||
Endpoint: endpoint1,
|
||||
UpdatedAt: epochTime,
|
||||
},
|
||||
{
|
||||
StashID: stashID2,
|
||||
Endpoint: endpoint2,
|
||||
UpdatedAt: epochTime,
|
||||
},
|
||||
}),
|
||||
CreatedAt: createdAt,
|
||||
UpdatedAt: updatedAt,
|
||||
},
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"clear nullables",
|
||||
models.UpdateStudioInput{
|
||||
Studio: &models.Studio{
|
||||
ID: studioIDs[studioIdxWithGallery],
|
||||
Name: name, // name is mandatory
|
||||
Aliases: models.NewRelatedStrings([]string{}),
|
||||
TagIDs: models.NewRelatedIDs([]int{}),
|
||||
StashIDs: models.NewRelatedStashIDs([]models.StashID{}),
|
||||
},
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"clear tag ids",
|
||||
models.UpdateStudioInput{
|
||||
Studio: &models.Studio{
|
||||
ID: studioIDs[sceneIdxWithTag],
|
||||
Name: name, // name is mandatory
|
||||
TagIDs: models.NewRelatedIDs([]int{}),
|
||||
},
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"set custom fields",
|
||||
models.UpdateStudioInput{
|
||||
Studio: &models.Studio{
|
||||
ID: studioIDs[studioIdxWithGallery],
|
||||
Name: name, // name is mandatory
|
||||
},
|
||||
CustomFields: models.CustomFieldsInput{
|
||||
Full: testCustomFields,
|
||||
},
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"clear custom fields",
|
||||
models.UpdateStudioInput{
|
||||
Studio: &models.Studio{
|
||||
ID: studioIDs[studioIdxWithGallery],
|
||||
Name: name, // name is mandatory
|
||||
},
|
||||
CustomFields: models.CustomFieldsInput{
|
||||
Full: map[string]interface{}{},
|
||||
},
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"invalid tag id",
|
||||
models.UpdateStudioInput{
|
||||
Studio: &models.Studio{
|
||||
ID: studioIDs[sceneIdxWithGallery],
|
||||
Name: name, // name is mandatory
|
||||
TagIDs: models.NewRelatedIDs([]int{invalidID}),
|
||||
},
|
||||
},
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
qb := db.Studio
|
||||
for _, tt := range tests {
|
||||
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
|
||||
assert := assert.New(t)
|
||||
|
||||
copy := *tt.updatedObject.Studio
|
||||
|
||||
if err := qb.Update(ctx, &tt.updatedObject); (err != nil) != tt.wantErr {
|
||||
t.Errorf("StudioStore.Update() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
|
||||
if tt.wantErr {
|
||||
return
|
||||
}
|
||||
|
||||
s, err := qb.Find(ctx, tt.updatedObject.ID)
|
||||
if err != nil {
|
||||
t.Errorf("StudioStore.Find() error = %v", err)
|
||||
}
|
||||
|
||||
// load relationships
|
||||
if err := loadStudioRelationships(ctx, copy, s); err != nil {
|
||||
t.Errorf("loadStudioRelationships() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(copy, *s)
|
||||
|
||||
// ensure custom fields are correct
|
||||
if tt.updatedObject.CustomFields.Full != nil {
|
||||
cf, err := qb.GetCustomFields(ctx, tt.updatedObject.ID)
|
||||
if err != nil {
|
||||
t.Errorf("StudioStore.GetCustomFields() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(tt.updatedObject.CustomFields.Full, cf)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func clearStudioPartial() models.StudioPartial {
|
||||
nullString := models.OptionalString{Set: true, Null: true}
|
||||
nullInt := models.OptionalInt{Set: true, Null: true}
|
||||
|
||||
// leave mandatory fields
|
||||
return models.StudioPartial{
|
||||
URL: nullString,
|
||||
Aliases: &models.UpdateStrings{Mode: models.RelationshipUpdateModeSet},
|
||||
Rating: nullInt,
|
||||
Details: nullString,
|
||||
TagIDs: &models.UpdateIDs{Mode: models.RelationshipUpdateModeSet},
|
||||
StashIDs: &models.UpdateStashIDs{Mode: models.RelationshipUpdateModeSet},
|
||||
}
|
||||
}
|
||||
|
||||
func Test_StudioStore_UpdatePartial(t *testing.T) {
|
||||
var (
|
||||
name = "name"
|
||||
details = "details"
|
||||
url = "url"
|
||||
aliases = []string{"aliasX", "aliasY"}
|
||||
rating = 3
|
||||
ignoreAutoTag = true
|
||||
favorite = true
|
||||
endpoint1 = "endpoint1"
|
||||
endpoint2 = "endpoint2"
|
||||
stashID1 = "stashid1"
|
||||
stashID2 = "stashid2"
|
||||
createdAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
updatedAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
id int
|
||||
partial models.StudioPartial
|
||||
want models.Studio
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"full",
|
||||
studioIDs[studioIdxWithDupName],
|
||||
models.StudioPartial{
|
||||
Name: models.NewOptionalString(name),
|
||||
URL: models.NewOptionalString(url),
|
||||
Aliases: &models.UpdateStrings{
|
||||
Values: aliases,
|
||||
Mode: models.RelationshipUpdateModeSet,
|
||||
},
|
||||
Favorite: models.NewOptionalBool(favorite),
|
||||
Rating: models.NewOptionalInt(rating),
|
||||
Details: models.NewOptionalString(details),
|
||||
IgnoreAutoTag: models.NewOptionalBool(ignoreAutoTag),
|
||||
TagIDs: &models.UpdateIDs{
|
||||
IDs: []int{tagIDs[tagIdx1WithStudio], tagIDs[tagIdx1WithDupName]},
|
||||
Mode: models.RelationshipUpdateModeSet,
|
||||
},
|
||||
StashIDs: &models.UpdateStashIDs{
|
||||
StashIDs: []models.StashID{
|
||||
{
|
||||
StashID: stashID1,
|
||||
Endpoint: endpoint1,
|
||||
UpdatedAt: epochTime,
|
||||
},
|
||||
{
|
||||
StashID: stashID2,
|
||||
Endpoint: endpoint2,
|
||||
UpdatedAt: epochTime,
|
||||
},
|
||||
},
|
||||
Mode: models.RelationshipUpdateModeSet,
|
||||
},
|
||||
CreatedAt: models.NewOptionalTime(createdAt),
|
||||
UpdatedAt: models.NewOptionalTime(updatedAt),
|
||||
},
|
||||
models.Studio{
|
||||
ID: studioIDs[studioIdxWithDupName],
|
||||
Name: name,
|
||||
URL: url,
|
||||
Aliases: models.NewRelatedStrings(aliases),
|
||||
Favorite: favorite,
|
||||
Rating: &rating,
|
||||
Details: details,
|
||||
IgnoreAutoTag: ignoreAutoTag,
|
||||
TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithDupName], tagIDs[tagIdx1WithStudio]}),
|
||||
StashIDs: models.NewRelatedStashIDs([]models.StashID{
|
||||
{
|
||||
StashID: stashID1,
|
||||
Endpoint: endpoint1,
|
||||
UpdatedAt: epochTime,
|
||||
},
|
||||
{
|
||||
StashID: stashID2,
|
||||
Endpoint: endpoint2,
|
||||
UpdatedAt: epochTime,
|
||||
},
|
||||
}),
|
||||
CreatedAt: createdAt,
|
||||
UpdatedAt: updatedAt,
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"clear all",
|
||||
studioIDs[studioIdxWithTwoTags],
|
||||
clearStudioPartial(),
|
||||
models.Studio{
|
||||
ID: studioIDs[studioIdxWithTwoTags],
|
||||
Name: getStudioStringValue(studioIdxWithTwoTags, "Name"),
|
||||
Favorite: getStudioBoolValue(studioIdxWithTwoTags),
|
||||
Aliases: models.NewRelatedStrings([]string{}),
|
||||
TagIDs: models.NewRelatedIDs([]int{}),
|
||||
StashIDs: models.NewRelatedStashIDs([]models.StashID{}),
|
||||
IgnoreAutoTag: getIgnoreAutoTag(studioIdxWithTwoTags),
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"invalid id",
|
||||
invalidID,
|
||||
models.StudioPartial{Name: models.NewOptionalString(name)},
|
||||
models.Studio{},
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
qb := db.Studio
|
||||
|
||||
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
|
||||
assert := assert.New(t)
|
||||
|
||||
tt.partial.ID = tt.id
|
||||
|
||||
got, err := qb.UpdatePartial(ctx, tt.partial)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("StudioStore.UpdatePartial() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if tt.wantErr {
|
||||
return
|
||||
}
|
||||
|
||||
if err := loadStudioRelationships(ctx, tt.want, got); err != nil {
|
||||
t.Errorf("loadStudioRelationships() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(tt.want, *got)
|
||||
|
||||
s, err := qb.Find(ctx, tt.id)
|
||||
if err != nil {
|
||||
t.Errorf("StudioStore.Find() error = %v", err)
|
||||
}
|
||||
|
||||
// load relationships
|
||||
if err := loadStudioRelationships(ctx, tt.want, s); err != nil {
|
||||
t.Errorf("loadStudioRelationships() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(tt.want, *s)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_StudioStore_UpdatePartialCustomFields(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
id int
|
||||
partial models.StudioPartial
|
||||
expected map[string]interface{} // nil to use the partial
|
||||
}{
|
||||
{
|
||||
"set custom fields",
|
||||
studioIDs[studioIdxWithGallery],
|
||||
models.StudioPartial{
|
||||
CustomFields: models.CustomFieldsInput{
|
||||
Full: testCustomFields,
|
||||
},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"clear custom fields",
|
||||
studioIDs[studioIdxWithGallery],
|
||||
models.StudioPartial{
|
||||
CustomFields: models.CustomFieldsInput{
|
||||
Full: map[string]interface{}{},
|
||||
},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"partial custom fields",
|
||||
studioIDs[studioIdxWithGallery],
|
||||
models.StudioPartial{
|
||||
CustomFields: models.CustomFieldsInput{
|
||||
Partial: map[string]interface{}{
|
||||
"string": "bbb",
|
||||
"new_field": "new",
|
||||
},
|
||||
},
|
||||
},
|
||||
map[string]interface{}{
|
||||
"int": int64(2),
|
||||
"real": 0.7,
|
||||
"string": "bbb",
|
||||
"new_field": "new",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
qb := db.Studio
|
||||
|
||||
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
|
||||
assert := assert.New(t)
|
||||
|
||||
tt.partial.ID = tt.id
|
||||
|
||||
_, err := qb.UpdatePartial(ctx, tt.partial)
|
||||
if err != nil {
|
||||
t.Errorf("StudioStore.UpdatePartial() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// ensure custom fields are correct
|
||||
cf, err := qb.GetCustomFields(ctx, tt.id)
|
||||
if err != nil {
|
||||
t.Errorf("StudioStore.GetCustomFields() error = %v", err)
|
||||
return
|
||||
}
|
||||
if tt.expected == nil {
|
||||
assert.Equal(tt.partial.CustomFields.Full, cf)
|
||||
} else {
|
||||
assert.Equal(tt.expected, cf)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStudioQueryNameOr(t *testing.T) {
|
||||
const studio1Idx = 1
|
||||
const studio2Idx = 2
|
||||
|
|
@ -311,13 +856,13 @@ func TestStudioDestroyParent(t *testing.T) {
|
|||
|
||||
// create parent and child studios
|
||||
if err := withTxn(func(ctx context.Context) error {
|
||||
createdParent, err := createStudio(ctx, db.Studio, parentName, nil)
|
||||
createdParent, err := createStudio(ctx, db.Studio, parentName, nil, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error creating parent studio: %s", err.Error())
|
||||
}
|
||||
|
||||
parentID := createdParent.ID
|
||||
createdChild, err := createStudio(ctx, db.Studio, childName, &parentID)
|
||||
createdChild, err := createStudio(ctx, db.Studio, childName, &parentID, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error creating child studio: %s", err.Error())
|
||||
}
|
||||
|
|
@ -373,13 +918,13 @@ func TestStudioUpdateClearParent(t *testing.T) {
|
|||
|
||||
// create parent and child studios
|
||||
if err := withTxn(func(ctx context.Context) error {
|
||||
createdParent, err := createStudio(ctx, db.Studio, parentName, nil)
|
||||
createdParent, err := createStudio(ctx, db.Studio, parentName, nil, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error creating parent studio: %s", err.Error())
|
||||
}
|
||||
|
||||
parentID := createdParent.ID
|
||||
createdChild, err := createStudio(ctx, db.Studio, childName, &parentID)
|
||||
createdChild, err := createStudio(ctx, db.Studio, childName, &parentID, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error creating child studio: %s", err.Error())
|
||||
}
|
||||
|
|
@ -414,7 +959,7 @@ func TestStudioUpdateStudioImage(t *testing.T) {
|
|||
|
||||
// create studio to test against
|
||||
const name = "TestStudioUpdateStudioImage"
|
||||
created, err := createStudio(ctx, db.Studio, name, nil)
|
||||
created, err := createStudio(ctx, db.Studio, name, nil, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error creating studio: %s", err.Error())
|
||||
}
|
||||
|
|
@ -578,7 +1123,7 @@ func TestStudioStashIDs(t *testing.T) {
|
|||
|
||||
// create studio to test against
|
||||
const name = "TestStudioStashIDs"
|
||||
created, err := createStudio(ctx, db.Studio, name, nil)
|
||||
created, err := createStudio(ctx, db.Studio, name, nil, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error creating studio: %s", err.Error())
|
||||
}
|
||||
|
|
@ -990,7 +1535,7 @@ func TestStudioAlias(t *testing.T) {
|
|||
|
||||
// create studio to test against
|
||||
const name = "TestStudioAlias"
|
||||
created, err := createStudio(ctx, db.Studio, name, nil)
|
||||
created, err := createStudio(ctx, db.Studio, name, nil, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error creating studio: %s", err.Error())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -206,9 +206,9 @@ func TestImporterPreImportWithMissingParent(t *testing.T) {
|
|||
}
|
||||
|
||||
db.Studio.On("FindByName", testCtx, missingParentStudioName, false).Return(nil, nil).Times(3)
|
||||
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
|
||||
s := args.Get(1).(*models.Studio)
|
||||
s.ID = existingStudioID
|
||||
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Run(func(args mock.Arguments) {
|
||||
s := args.Get(1).(*models.CreateStudioInput)
|
||||
s.Studio.ID = existingStudioID
|
||||
}).Return(nil)
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
|
|
@ -240,7 +240,7 @@ func TestImporterPreImportWithMissingParentCreateErr(t *testing.T) {
|
|||
}
|
||||
|
||||
db.Studio.On("FindByName", testCtx, missingParentStudioName, false).Return(nil, nil).Once()
|
||||
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
|
||||
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.CreateStudioInput")).Return(errors.New("Create error"))
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
|
@ -327,11 +327,11 @@ func TestCreate(t *testing.T) {
|
|||
}
|
||||
|
||||
errCreate := errors.New("Create error")
|
||||
db.Studio.On("Create", testCtx, &studio).Run(func(args mock.Arguments) {
|
||||
s := args.Get(1).(*models.Studio)
|
||||
db.Studio.On("Create", testCtx, &models.CreateStudioInput{Studio: &studio}).Run(func(args mock.Arguments) {
|
||||
s := args.Get(1).(*models.CreateStudioInput)
|
||||
s.ID = studioID
|
||||
}).Return(nil).Once()
|
||||
db.Studio.On("Create", testCtx, &studioErr).Return(errCreate).Once()
|
||||
db.Studio.On("Create", testCtx, &models.CreateStudioInput{Studio: &studioErr}).Return(errCreate).Once()
|
||||
|
||||
id, err := i.Create(testCtx)
|
||||
assert.Equal(t, studioID, *id)
|
||||
|
|
@ -366,7 +366,7 @@ func TestUpdate(t *testing.T) {
|
|||
|
||||
// id needs to be set for the mock input
|
||||
studio.ID = studioID
|
||||
db.Studio.On("Update", testCtx, &studio).Return(nil).Once()
|
||||
db.Studio.On("Update", testCtx, &models.UpdateStudioInput{Studio: &studio}).Return(nil).Once()
|
||||
|
||||
err := i.Update(testCtx, studioID)
|
||||
assert.Nil(t, err)
|
||||
|
|
@ -375,7 +375,7 @@ func TestUpdate(t *testing.T) {
|
|||
|
||||
// need to set id separately
|
||||
studioErr.ID = errImageID
|
||||
db.Studio.On("Update", testCtx, &studioErr).Return(errUpdate).Once()
|
||||
db.Studio.On("Update", testCtx, &models.UpdateStudioInput{Studio: &studioErr}).Return(errUpdate).Once()
|
||||
|
||||
err = i.Update(testCtx, errImageID)
|
||||
assert.NotNil(t, err)
|
||||
|
|
|
|||
Loading…
Reference in a new issue