diff --git a/graphql/schema/types/filters.graphql b/graphql/schema/types/filters.graphql index 52eec6785..f89cee3e2 100644 --- a/graphql/schema/types/filters.graphql +++ b/graphql/schema/types/filters.graphql @@ -650,6 +650,8 @@ input TagFilterType { "Filter by last update time" updated_at: TimestampCriterionInput + + custom_fields: [CustomFieldCriterionInput!] } input ImageFilterType { diff --git a/graphql/schema/types/tag.graphql b/graphql/schema/types/tag.graphql index a69b83548..2210c900e 100644 --- a/graphql/schema/types/tag.graphql +++ b/graphql/schema/types/tag.graphql @@ -24,6 +24,7 @@ type Tag { parent_count: Int! # Resolver child_count: Int! # Resolver + custom_fields: Map! } input TagCreateInput { @@ -41,6 +42,8 @@ input TagCreateInput { parent_ids: [ID!] child_ids: [ID!] + + custom_fields: Map } input TagUpdateInput { @@ -59,6 +62,8 @@ input TagUpdateInput { parent_ids: [ID!] child_ids: [ID!] + + custom_fields: CustomFieldsInput } input TagDestroyInput { diff --git a/internal/api/loaders/dataloaders.go b/internal/api/loaders/dataloaders.go index 4676966c9..ecb0bbac2 100644 --- a/internal/api/loaders/dataloaders.go +++ b/internal/api/loaders/dataloaders.go @@ -62,10 +62,11 @@ type Loaders struct { StudioByID *StudioLoader StudioCustomFields *CustomFieldsLoader - TagByID *TagLoader - GroupByID *GroupLoader - FileByID *FileLoader - FolderByID *FolderLoader + TagByID *TagLoader + TagCustomFields *CustomFieldsLoader + GroupByID *GroupLoader + FileByID *FileLoader + FolderByID *FolderLoader } type Middleware struct { @@ -116,6 +117,11 @@ func (m Middleware) Middleware(next http.Handler) http.Handler { maxBatch: maxBatch, fetch: m.fetchTags(ctx), }, + TagCustomFields: &CustomFieldsLoader{ + wait: wait, + maxBatch: maxBatch, + fetch: m.fetchTagCustomFields(ctx), + }, GroupByID: &GroupLoader{ wait: wait, maxBatch: maxBatch, @@ -283,6 +289,18 @@ func (m Middleware) fetchTags(ctx context.Context) func(keys []int) ([]*models.T } } +func (m Middleware) fetchTagCustomFields(ctx context.Context) func(keys []int) ([]models.CustomFieldMap, []error) { + return func(keys []int) (ret []models.CustomFieldMap, errs []error) { + err := m.Repository.WithDB(ctx, func(ctx context.Context) error { + var err error + ret, err = m.Repository.Tag.GetCustomFieldsBulk(ctx, keys) + return err + }) + + return ret, toErrorSlice(err) + } +} + func (m Middleware) fetchGroups(ctx context.Context) func(keys []int) ([]*models.Group, []error) { return func(keys []int) (ret []*models.Group, errs []error) { err := m.Repository.WithDB(ctx, func(ctx context.Context) error { diff --git a/internal/api/resolver_model_tag.go b/internal/api/resolver_model_tag.go index deae41f21..7518036b0 100644 --- a/internal/api/resolver_model_tag.go +++ b/internal/api/resolver_model_tag.go @@ -181,3 +181,16 @@ func (r *tagResolver) ChildCount(ctx context.Context, obj *models.Tag) (ret int, return ret, nil } + +func (r *tagResolver) CustomFields(ctx context.Context, obj *models.Tag) (map[string]interface{}, error) { + m, err := loaders.From(ctx).TagCustomFields.Load(obj.ID) + if err != nil { + return nil, err + } + + if m == nil { + return make(map[string]interface{}), nil + } + + return m, nil +} diff --git a/internal/api/resolver_mutation_tag.go b/internal/api/resolver_mutation_tag.go index 8fb295d40..31c7980f6 100644 --- a/internal/api/resolver_mutation_tag.go +++ b/internal/api/resolver_mutation_tag.go @@ -31,7 +31,10 @@ func (r *mutationResolver) TagCreate(ctx context.Context, input TagCreateInput) } // Populate a new tag from the input - newTag := models.NewTag() + newTag := models.CreateTagInput{ + Tag: &models.Tag{}, + } + *newTag.Tag = models.NewTag() newTag.Name = strings.TrimSpace(input.Name) newTag.SortName = translator.string(input.SortName) @@ -60,6 +63,8 @@ func (r *mutationResolver) TagCreate(ctx context.Context, input TagCreateInput) return nil, fmt.Errorf("converting child tag ids: %w", err) } + newTag.CustomFields = convertMapJSONNumbers(input.CustomFields) + // Process the base 64 encoded image string var imageData []byte if input.Image != nil { @@ -73,7 +78,7 @@ func (r *mutationResolver) TagCreate(ctx context.Context, input TagCreateInput) if err := r.withTxn(ctx, func(ctx context.Context) error { qb := r.repository.Tag - if err := tag.ValidateCreate(ctx, newTag, qb); err != nil { + if err := tag.ValidateCreate(ctx, *newTag.Tag, qb); err != nil { return err } @@ -137,6 +142,13 @@ func (r *mutationResolver) TagUpdate(ctx context.Context, input TagUpdateInput) return nil, fmt.Errorf("converting child tag ids: %w", err) } + if input.CustomFields != nil { + updatedTag.CustomFields = *input.CustomFields + // convert json.Numbers to int/float + updatedTag.CustomFields.Full = convertMapJSONNumbers(updatedTag.CustomFields.Full) + updatedTag.CustomFields.Partial = convertMapJSONNumbers(updatedTag.CustomFields.Partial) + } + var imageData []byte imageIncluded := translator.hasField("image") if input.Image != nil { diff --git a/internal/autotag/integration_test.go b/internal/autotag/integration_test.go index 605082b98..27cce014e 100644 --- a/internal/autotag/integration_test.go +++ b/internal/autotag/integration_test.go @@ -118,7 +118,7 @@ func createTag(ctx context.Context, qb models.TagWriter) error { Name: testName, } - err := qb.Create(ctx, &tag) + err := qb.Create(ctx, &models.CreateTagInput{Tag: &tag}) if err != nil { return err } diff --git a/internal/identify/scene.go b/internal/identify/scene.go index 789674693..b82a04301 100644 --- a/internal/identify/scene.go +++ b/internal/identify/scene.go @@ -167,7 +167,9 @@ func (g sceneRelationships) tags(ctx context.Context) ([]int, error) { } else if createMissing { newTag := t.ToTag(endpoint, nil) - err := g.tagCreator.Create(ctx, newTag) + err := g.tagCreator.Create(ctx, &models.CreateTagInput{ + Tag: newTag, + }) if err != nil { return nil, fmt.Errorf("error creating tag: %w", err) } diff --git a/internal/identify/scene_test.go b/internal/identify/scene_test.go index 0eec61c4e..862bbbff8 100644 --- a/internal/identify/scene_test.go +++ b/internal/identify/scene_test.go @@ -368,14 +368,14 @@ func Test_sceneRelationships_tags(t *testing.T) { db := mocks.NewDatabase() - db.Tag.On("Create", testCtx, mock.MatchedBy(func(p *models.Tag) bool { - return p.Name == validName + db.Tag.On("Create", testCtx, mock.MatchedBy(func(p *models.CreateTagInput) bool { + return p.Tag.Name == validName })).Run(func(args mock.Arguments) { - t := args.Get(1).(*models.Tag) - t.ID = validStoredIDInt + t := args.Get(1).(*models.CreateTagInput) + t.Tag.ID = validStoredIDInt }).Return(nil) - db.Tag.On("Create", testCtx, mock.MatchedBy(func(p *models.Tag) bool { - return p.Name == invalidName + db.Tag.On("Create", testCtx, mock.MatchedBy(func(p *models.CreateTagInput) bool { + return p.Tag.Name == invalidName })).Return(errors.New("error creating tag")) tr := sceneRelationships{ diff --git a/pkg/gallery/import.go b/pkg/gallery/import.go index 543d4cf48..22f3e6c44 100644 --- a/pkg/gallery/import.go +++ b/pkg/gallery/import.go @@ -249,7 +249,9 @@ func (i *Importer) createTags(ctx context.Context, names []string) ([]*models.Ta newTag := models.NewTag() newTag.Name = name - err := i.TagWriter.Create(ctx, &newTag) + err := i.TagWriter.Create(ctx, &models.CreateTagInput{ + Tag: &newTag, + }) if err != nil { return nil, err } diff --git a/pkg/gallery/import_test.go b/pkg/gallery/import_test.go index 4248f51bc..932f84d48 100644 --- a/pkg/gallery/import_test.go +++ b/pkg/gallery/import_test.go @@ -289,9 +289,9 @@ func TestImporterPreImportWithMissingTag(t *testing.T) { } db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3) - db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) { - t := args.Get(1).(*models.Tag) - t.ID = existingTagID + db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.CreateTagInput")).Run(func(args mock.Arguments) { + t := args.Get(1).(*models.CreateTagInput) + t.Tag.ID = existingTagID }).Return(nil) err := i.PreImport(testCtx) @@ -323,7 +323,7 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) { } db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once() - db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error")) + db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.CreateTagInput")).Return(errors.New("Create error")) err := i.PreImport(testCtx) assert.NotNil(t, err) diff --git a/pkg/group/import.go b/pkg/group/import.go index a73c3998e..d7acad47c 100644 --- a/pkg/group/import.go +++ b/pkg/group/import.go @@ -126,7 +126,9 @@ func createTags(ctx context.Context, tagWriter models.TagFinderCreator, names [] newTag := models.NewTag() newTag.Name = name - err := tagWriter.Create(ctx, &newTag) + err := tagWriter.Create(ctx, &models.CreateTagInput{ + Tag: &newTag, + }) if err != nil { return nil, err } diff --git a/pkg/group/import_test.go b/pkg/group/import_test.go index 50b8b2dd1..387ceb87e 100644 --- a/pkg/group/import_test.go +++ b/pkg/group/import_test.go @@ -212,9 +212,9 @@ func TestImporterPreImportWithMissingTag(t *testing.T) { } db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3) - db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) { - t := args.Get(1).(*models.Tag) - t.ID = existingTagID + db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.CreateTagInput")).Run(func(args mock.Arguments) { + t := args.Get(1).(*models.CreateTagInput) + t.Tag.ID = existingTagID }).Return(nil) err := i.PreImport(testCtx) @@ -247,7 +247,7 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) { } db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once() - db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error")) + db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.CreateTagInput")).Return(errors.New("Create error")) err := i.PreImport(testCtx) assert.NotNil(t, err) diff --git a/pkg/image/import.go b/pkg/image/import.go index 77b6d7477..c7ef7f00c 100644 --- a/pkg/image/import.go +++ b/pkg/image/import.go @@ -407,7 +407,9 @@ func createTags(ctx context.Context, tagWriter models.TagCreator, names []string newTag := models.NewTag() newTag.Name = name - err := tagWriter.Create(ctx, &newTag) + err := tagWriter.Create(ctx, &models.CreateTagInput{ + Tag: &newTag, + }) if err != nil { return nil, err } diff --git a/pkg/image/import_test.go b/pkg/image/import_test.go index 98b3972b9..5d01d4b97 100644 --- a/pkg/image/import_test.go +++ b/pkg/image/import_test.go @@ -251,9 +251,9 @@ func TestImporterPreImportWithMissingTag(t *testing.T) { } db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3) - db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) { - t := args.Get(1).(*models.Tag) - t.ID = existingTagID + db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.CreateTagInput")).Run(func(args mock.Arguments) { + t := args.Get(1).(*models.CreateTagInput) + t.Tag.ID = existingTagID }).Return(nil) err := i.PreImport(testCtx) @@ -285,7 +285,7 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) { } db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once() - db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error")) + db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.CreateTagInput")).Return(errors.New("Create error")) err := i.PreImport(testCtx) assert.NotNil(t, err) diff --git a/pkg/models/custom_fields.go b/pkg/models/custom_fields.go index 5c3acd18b..3212d676f 100644 --- a/pkg/models/custom_fields.go +++ b/pkg/models/custom_fields.go @@ -17,3 +17,7 @@ type CustomFieldsReader interface { GetCustomFields(ctx context.Context, id int) (map[string]interface{}, error) GetCustomFieldsBulk(ctx context.Context, ids []int) ([]CustomFieldMap, error) } + +type CustomFieldsWriter interface { + SetCustomFields(ctx context.Context, id int, fields CustomFieldsInput) error +} diff --git a/pkg/models/jsonschema/studio.go b/pkg/models/jsonschema/studio.go index a3706df66..7684b4317 100644 --- a/pkg/models/jsonschema/studio.go +++ b/pkg/models/jsonschema/studio.go @@ -25,6 +25,8 @@ type Studio struct { Tags []string `json:"tags,omitempty"` IgnoreAutoTag bool `json:"ignore_auto_tag,omitempty"` + CustomFields map[string]interface{} `json:"custom_fields,omitempty"` + // deprecated - for import only URL string `json:"url,omitempty"` } diff --git a/pkg/models/jsonschema/tag.go b/pkg/models/jsonschema/tag.go index faab1bfb2..e7b16b13f 100644 --- a/pkg/models/jsonschema/tag.go +++ b/pkg/models/jsonschema/tag.go @@ -11,17 +11,18 @@ import ( ) type Tag struct { - Name string `json:"name,omitempty"` - SortName string `json:"sort_name,omitempty"` - Description string `json:"description,omitempty"` - Favorite bool `json:"favorite,omitempty"` - Aliases []string `json:"aliases,omitempty"` - Image string `json:"image,omitempty"` - Parents []string `json:"parents,omitempty"` - IgnoreAutoTag bool `json:"ignore_auto_tag,omitempty"` - StashIDs []models.StashID `json:"stash_ids,omitempty"` - CreatedAt json.JSONTime `json:"created_at,omitempty"` - UpdatedAt json.JSONTime `json:"updated_at,omitempty"` + Name string `json:"name,omitempty"` + SortName string `json:"sort_name,omitempty"` + Description string `json:"description,omitempty"` + Favorite bool `json:"favorite,omitempty"` + Aliases []string `json:"aliases,omitempty"` + Image string `json:"image,omitempty"` + Parents []string `json:"parents,omitempty"` + IgnoreAutoTag bool `json:"ignore_auto_tag,omitempty"` + StashIDs []models.StashID `json:"stash_ids,omitempty"` + CreatedAt json.JSONTime `json:"created_at,omitempty"` + UpdatedAt json.JSONTime `json:"updated_at,omitempty"` + CustomFields map[string]interface{} `json:"custom_fields,omitempty"` } func (s Tag) Filename() string { diff --git a/pkg/models/mocks/TagReaderWriter.go b/pkg/models/mocks/TagReaderWriter.go index ac6b10584..95a3b7a87 100644 --- a/pkg/models/mocks/TagReaderWriter.go +++ b/pkg/models/mocks/TagReaderWriter.go @@ -101,11 +101,11 @@ func (_m *TagReaderWriter) CountByParentTagID(ctx context.Context, parentID int) } // Create provides a mock function with given fields: ctx, newTag -func (_m *TagReaderWriter) Create(ctx context.Context, newTag *models.Tag) error { +func (_m *TagReaderWriter) Create(ctx context.Context, newTag *models.CreateTagInput) error { ret := _m.Called(ctx, newTag) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, *models.Tag) error); ok { + if rf, ok := ret.Get(0).(func(context.Context, *models.CreateTagInput) error); ok { r0 = rf(ctx, newTag) } else { r0 = ret.Error(0) @@ -542,6 +542,52 @@ func (_m *TagReaderWriter) GetChildIDs(ctx context.Context, relatedID int) ([]in return r0, r1 } +// GetCustomFields provides a mock function with given fields: ctx, id +func (_m *TagReaderWriter) GetCustomFields(ctx context.Context, id int) (map[string]interface{}, error) { + ret := _m.Called(ctx, id) + + var r0 map[string]interface{} + if rf, ok := ret.Get(0).(func(context.Context, int) map[string]interface{}); ok { + r0 = rf(ctx, id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]interface{}) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetCustomFieldsBulk provides a mock function with given fields: ctx, ids +func (_m *TagReaderWriter) GetCustomFieldsBulk(ctx context.Context, ids []int) ([]models.CustomFieldMap, error) { + ret := _m.Called(ctx, ids) + + var r0 []models.CustomFieldMap + if rf, ok := ret.Get(0).(func(context.Context, []int) []models.CustomFieldMap); ok { + r0 = rf(ctx, ids) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]models.CustomFieldMap) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, []int) error); ok { + r1 = rf(ctx, ids) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetImage provides a mock function with given fields: ctx, tagID func (_m *TagReaderWriter) GetImage(ctx context.Context, tagID int) ([]byte, error) { ret := _m.Called(ctx, tagID) @@ -699,12 +745,26 @@ func (_m *TagReaderWriter) QueryForAutoTag(ctx context.Context, words []string) return r0, r1 } +// SetCustomFields provides a mock function with given fields: ctx, id, fields +func (_m *TagReaderWriter) SetCustomFields(ctx context.Context, id int, fields models.CustomFieldsInput) error { + ret := _m.Called(ctx, id, fields) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, int, models.CustomFieldsInput) error); ok { + r0 = rf(ctx, id, fields) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // Update provides a mock function with given fields: ctx, updatedTag -func (_m *TagReaderWriter) Update(ctx context.Context, updatedTag *models.Tag) error { +func (_m *TagReaderWriter) Update(ctx context.Context, updatedTag *models.UpdateTagInput) error { ret := _m.Called(ctx, updatedTag) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, *models.Tag) error); ok { + if rf, ok := ret.Get(0).(func(context.Context, *models.UpdateTagInput) error); ok { r0 = rf(ctx, updatedTag) } else { r0 = ret.Error(0) diff --git a/pkg/models/model_tag.go b/pkg/models/model_tag.go index 4cd038f7e..aee468639 100644 --- a/pkg/models/model_tag.go +++ b/pkg/models/model_tag.go @@ -29,6 +29,18 @@ func NewTag() Tag { } } +type CreateTagInput struct { + *Tag + + CustomFields map[string]interface{} `json:"custom_fields"` +} + +type UpdateTagInput struct { + *Tag + + CustomFields CustomFieldsInput `json:"custom_fields"` +} + func (s *Tag) LoadAliases(ctx context.Context, l AliasLoader) error { return s.Aliases.load(func() ([]string, error) { return l.GetAliases(ctx, s.ID) @@ -66,6 +78,8 @@ type TagPartial struct { ParentIDs *UpdateIDs ChildIDs *UpdateIDs StashIDs *UpdateStashIDs + + CustomFields CustomFieldsInput } func NewTagPartial() TagPartial { diff --git a/pkg/models/repository_tag.go b/pkg/models/repository_tag.go index a7f828f0b..ba403cf2d 100644 --- a/pkg/models/repository_tag.go +++ b/pkg/models/repository_tag.go @@ -51,12 +51,12 @@ type TagCounter interface { // TagCreator provides methods to create tags. type TagCreator interface { - Create(ctx context.Context, newTag *Tag) error + Create(ctx context.Context, newTag *CreateTagInput) error } // TagUpdater provides methods to update tags. type TagUpdater interface { - Update(ctx context.Context, updatedTag *Tag) error + Update(ctx context.Context, updatedTag *UpdateTagInput) error UpdatePartial(ctx context.Context, id int, updateTag TagPartial) (*Tag, error) UpdateAliases(ctx context.Context, tagID int, aliases []string) error UpdateImage(ctx context.Context, tagID int, image []byte) error @@ -77,6 +77,7 @@ type TagFinderCreator interface { type TagCreatorUpdater interface { TagCreator TagUpdater + CustomFieldsWriter } // TagReader provides all methods to read tags. @@ -89,6 +90,7 @@ type TagReader interface { AliasLoader TagRelationLoader StashIDLoader + CustomFieldsReader All(ctx context.Context) ([]*Tag, error) GetImage(ctx context.Context, tagID int) ([]byte, error) @@ -100,6 +102,7 @@ type TagWriter interface { TagCreator TagUpdater TagDestroyer + CustomFieldsWriter Merge(ctx context.Context, source []int, destination int) error } diff --git a/pkg/models/tag.go b/pkg/models/tag.go index 5ff2df6ad..0f39d8861 100644 --- a/pkg/models/tag.go +++ b/pkg/models/tag.go @@ -56,4 +56,7 @@ type TagFilterType struct { CreatedAt *TimestampCriterionInput `json:"created_at"` // Filter by updated at UpdatedAt *TimestampCriterionInput `json:"updated_at"` + + // Filter by custom fields + CustomFields []CustomFieldCriterionInput `json:"custom_fields"` } diff --git a/pkg/performer/import.go b/pkg/performer/import.go index 622af2b1a..a8e3f7a7a 100644 --- a/pkg/performer/import.go +++ b/pkg/performer/import.go @@ -107,7 +107,9 @@ func createTags(ctx context.Context, tagWriter models.TagFinderCreator, names [] newTag := models.NewTag() newTag.Name = name - err := tagWriter.Create(ctx, &newTag) + err := tagWriter.Create(ctx, &models.CreateTagInput{ + Tag: &newTag, + }) if err != nil { return nil, err } diff --git a/pkg/performer/import_test.go b/pkg/performer/import_test.go index 0a3f86291..455a6e7a3 100644 --- a/pkg/performer/import_test.go +++ b/pkg/performer/import_test.go @@ -111,9 +111,9 @@ func TestImporterPreImportWithMissingTag(t *testing.T) { } db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3) - db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) { - t := args.Get(1).(*models.Tag) - t.ID = existingTagID + db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.CreateTagInput")).Run(func(args mock.Arguments) { + t := args.Get(1).(*models.CreateTagInput) + t.Tag.ID = existingTagID }).Return(nil) err := i.PreImport(testCtx) @@ -146,7 +146,7 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) { } db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once() - db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error")) + db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.CreateTagInput")).Return(errors.New("Create error")) err := i.PreImport(testCtx) assert.NotNil(t, err) diff --git a/pkg/scene/import.go b/pkg/scene/import.go index b3f0f1ff1..58604e1a5 100644 --- a/pkg/scene/import.go +++ b/pkg/scene/import.go @@ -549,7 +549,9 @@ func createTags(ctx context.Context, tagWriter models.TagCreator, names []string newTag := models.NewTag() newTag.Name = name - err := tagWriter.Create(ctx, &newTag) + err := tagWriter.Create(ctx, &models.CreateTagInput{ + Tag: &newTag, + }) if err != nil { return nil, err } diff --git a/pkg/scene/import_test.go b/pkg/scene/import_test.go index 558b72ba2..4936ec2bb 100644 --- a/pkg/scene/import_test.go +++ b/pkg/scene/import_test.go @@ -508,9 +508,9 @@ func TestImporterPreImportWithMissingTag(t *testing.T) { } db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3) - db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) { - t := args.Get(1).(*models.Tag) - t.ID = existingTagID + db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.CreateTagInput")).Run(func(args mock.Arguments) { + t := args.Get(1).(*models.CreateTagInput) + t.Tag.ID = existingTagID }).Return(nil) err := i.PreImport(testCtx) @@ -542,7 +542,7 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) { } db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once() - db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error")) + db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.CreateTagInput")).Return(errors.New("Create error")) err := i.PreImport(testCtx) assert.NotNil(t, err) diff --git a/pkg/sqlite/anonymise.go b/pkg/sqlite/anonymise.go index 764f569c0..e3b7492cc 100644 --- a/pkg/sqlite/anonymise.go +++ b/pkg/sqlite/anonymise.go @@ -678,6 +678,10 @@ func (db *Anonymiser) anonymiseStudios(ctx context.Context) error { return err } + if err := db.anonymiseCustomFields(ctx, goqu.T(studiosCustomFieldsTable.GetTable()), "studio_id"); err != nil { + return err + } + return nil } @@ -873,6 +877,10 @@ func (db *Anonymiser) anonymiseTags(ctx context.Context) error { return err } + if err := db.anonymiseCustomFields(ctx, goqu.T(tagsCustomFieldsTable.GetTable()), "tag_id"); err != nil { + return err + } + return nil } diff --git a/pkg/sqlite/database.go b/pkg/sqlite/database.go index a87f6706f..51889ff20 100644 --- a/pkg/sqlite/database.go +++ b/pkg/sqlite/database.go @@ -34,7 +34,7 @@ const ( cacheSizeEnv = "STASH_SQLITE_CACHE_SIZE" ) -var appSchemaVersion uint = 76 +var appSchemaVersion uint = 77 //go:embed migrations/*.sql var migrationsBox embed.FS diff --git a/pkg/sqlite/migrations/77_tag_custom_fields.up.sql b/pkg/sqlite/migrations/77_tag_custom_fields.up.sql new file mode 100644 index 000000000..b34a5f794 --- /dev/null +++ b/pkg/sqlite/migrations/77_tag_custom_fields.up.sql @@ -0,0 +1,9 @@ +CREATE TABLE `tag_custom_fields` ( + `tag_id` integer NOT NULL, + `field` varchar(64) NOT NULL, + `value` BLOB NOT NULL, + PRIMARY KEY (`tag_id`, `field`), + foreign key(`tag_id`) references `tags`(`id`) on delete CASCADE +); + +CREATE INDEX `index_tag_custom_fields_field_value` ON `tag_custom_fields` (`field`, `value`); \ No newline at end of file diff --git a/pkg/sqlite/setup_test.go b/pkg/sqlite/setup_test.go index 361b5cb79..bdb83b1df 100644 --- a/pkg/sqlite/setup_test.go +++ b/pkg/sqlite/setup_test.go @@ -1709,6 +1709,18 @@ func tagStashID(i int) models.StashID { } } +func getTagCustomFields(index int) map[string]interface{} { + if index%5 == 0 { + return nil + } + + return map[string]interface{}{ + "string": getTagStringValue(index, "custom"), + "int": int64(index % 5), + "real": float64(index) / 10, + } +} + // createTags creates n tags with plain Name and o tags with camel cased NaMe included func createTags(ctx context.Context, tqb models.TagReaderWriter, n int, o int) error { const namePlain = "Name" @@ -1736,7 +1748,10 @@ func createTags(ctx context.Context, tqb models.TagReaderWriter, n int, o int) e }) } - err := tqb.Create(ctx, &tag) + err := tqb.Create(ctx, &models.CreateTagInput{ + Tag: &tag, + CustomFields: getTagCustomFields(i), + }) if err != nil { return fmt.Errorf("Error creating tag %v+: %s", tag, err.Error()) diff --git a/pkg/sqlite/studio_test.go b/pkg/sqlite/studio_test.go index 074c77d6f..968f43413 100644 --- a/pkg/sqlite/studio_test.go +++ b/pkg/sqlite/studio_test.go @@ -1694,6 +1694,251 @@ func TestStudioQueryFast(t *testing.T) { }) } +func studiesToIDs(i []*models.Studio) []int { + ret := make([]int, len(i)) + for i, v := range i { + ret[i] = v.ID + } + + return ret +} + +func TestStudioQueryCustomFields(t *testing.T) { + tests := []struct { + name string + filter *models.StudioFilterType + includeIdxs []int + excludeIdxs []int + wantErr bool + }{ + { + "equals", + &models.StudioFilterType{ + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "string", + Modifier: models.CriterionModifierEquals, + Value: []any{getStudioStringValue(studioIdxWithTwoScenes, "custom")}, + }, + }, + }, + []int{studioIdxWithTwoScenes}, + nil, + false, + }, + { + "not equals", + &models.StudioFilterType{ + Name: &models.StringCriterionInput{ + Value: getStudioStringValue(studioIdxWithTwoScenes, "Name"), + Modifier: models.CriterionModifierEquals, + }, + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "string", + Modifier: models.CriterionModifierNotEquals, + Value: []any{getStudioStringValue(studioIdxWithTwoScenes, "custom")}, + }, + }, + }, + nil, + []int{studioIdxWithTwoScenes}, + false, + }, + { + "includes", + &models.StudioFilterType{ + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "string", + Modifier: models.CriterionModifierIncludes, + Value: []any{getStudioStringValue(studioIdxWithTwoScenes, "custom")[9:]}, + }, + }, + }, + []int{studioIdxWithTwoScenes}, + nil, + false, + }, + { + "excludes", + &models.StudioFilterType{ + Name: &models.StringCriterionInput{ + Value: getStudioStringValue(studioIdxWithTwoScenes, "Name"), + Modifier: models.CriterionModifierEquals, + }, + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "string", + Modifier: models.CriterionModifierExcludes, + Value: []any{getStudioStringValue(studioIdxWithTwoScenes, "custom")[9:]}, + }, + }, + }, + nil, + []int{studioIdxWithTwoScenes}, + false, + }, + { + "regex", + &models.StudioFilterType{ + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "string", + Modifier: models.CriterionModifierMatchesRegex, + Value: []any{".*1_custom"}, + }, + }, + }, + []int{studioIdxWithTwoScenes}, + nil, + false, + }, + { + "invalid regex", + &models.StudioFilterType{ + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "string", + Modifier: models.CriterionModifierMatchesRegex, + Value: []any{"["}, + }, + }, + }, + nil, + nil, + true, + }, + { + "not matches regex", + &models.StudioFilterType{ + Name: &models.StringCriterionInput{ + Value: getStudioStringValue(studioIdxWithTwoScenes, "Name"), + Modifier: models.CriterionModifierEquals, + }, + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "string", + Modifier: models.CriterionModifierNotMatchesRegex, + Value: []any{".*1_custom"}, + }, + }, + }, + nil, + []int{studioIdxWithTwoScenes}, + false, + }, + { + "invalid not matches regex", + &models.StudioFilterType{ + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "string", + Modifier: models.CriterionModifierNotMatchesRegex, + Value: []any{"["}, + }, + }, + }, + nil, + nil, + true, + }, + { + "null", + &models.StudioFilterType{ + Name: &models.StringCriterionInput{ + Value: getStudioStringValue(studioIdxWithTwoScenes, "Name"), + Modifier: models.CriterionModifierEquals, + }, + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "not existing", + Modifier: models.CriterionModifierIsNull, + }, + }, + }, + []int{studioIdxWithTwoScenes}, + nil, + false, + }, + { + "not null", + &models.StudioFilterType{ + Name: &models.StringCriterionInput{ + Value: getStudioStringValue(studioIdxWithTwoScenes, "Name"), + Modifier: models.CriterionModifierEquals, + }, + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "string", + Modifier: models.CriterionModifierNotNull, + }, + }, + }, + []int{studioIdxWithTwoScenes}, + nil, + false, + }, + { + "between", + &models.StudioFilterType{ + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "real", + Modifier: models.CriterionModifierBetween, + Value: []any{0.15, 0.25}, + }, + }, + }, + []int{studioIdxWithGroup}, + nil, + false, + }, + { + "not between", + &models.StudioFilterType{ + Name: &models.StringCriterionInput{ + Value: getStudioStringValue(studioIdxWithGroup, "Name"), + Modifier: models.CriterionModifierEquals, + }, + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "real", + Modifier: models.CriterionModifierNotBetween, + Value: []any{0.15, 0.25}, + }, + }, + }, + nil, + []int{studioIdxWithGroup}, + false, + }, + } + + for _, tt := range tests { + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + assert := assert.New(t) + + studios, _, err := db.Studio.Query(ctx, tt.filter, nil) + if (err != nil) != tt.wantErr { + t.Errorf("StudioStore.Query() error = %v, wantErr %v", err, tt.wantErr) + return + } + + ids := studiesToIDs(studios) + include := indexesToIDs(studioIDs, tt.includeIdxs) + exclude := indexesToIDs(studioIDs, tt.excludeIdxs) + + for _, i := range include { + assert.Contains(ids, i) + } + for _, e := range exclude { + assert.NotContains(ids, e) + } + }) + } +} + // TODO Create // TODO Update // TODO Destroy diff --git a/pkg/sqlite/tables.go b/pkg/sqlite/tables.go index bfc5199fe..f46190a30 100644 --- a/pkg/sqlite/tables.go +++ b/pkg/sqlite/tables.go @@ -49,6 +49,7 @@ var ( tagsAliasesJoinTable = goqu.T(tagAliasesTable) tagRelationsJoinTable = goqu.T(tagRelationsTable) tagsStashIDsJoinTable = goqu.T("tag_stash_ids") + tagsCustomFieldsTable = goqu.T("tag_custom_fields") ) var ( diff --git a/pkg/sqlite/tag.go b/pkg/sqlite/tag.go index b1d773290..ea18664d9 100644 --- a/pkg/sqlite/tag.go +++ b/pkg/sqlite/tag.go @@ -166,6 +166,7 @@ var ( type TagStore struct { blobJoinQueryBuilder + customFieldsStore tableMgr *table } @@ -176,6 +177,10 @@ func NewTagStore(blobStore *BlobStore) *TagStore { blobStore: blobStore, joinTable: tagTable, }, + customFieldsStore: customFieldsStore{ + table: tagsCustomFieldsTable, + fk: tagsCustomFieldsTable.Col(tagIDColumn), + }, tableMgr: tagTableMgr, } } @@ -188,9 +193,9 @@ func (qb *TagStore) selectDataset() *goqu.SelectDataset { return dialect.From(qb.table()).Select(qb.table().All()) } -func (qb *TagStore) Create(ctx context.Context, newObject *models.Tag) error { +func (qb *TagStore) Create(ctx context.Context, newObject *models.CreateTagInput) error { var r tagRow - r.fromTag(*newObject) + r.fromTag(*newObject.Tag) id, err := qb.tableMgr.insertID(ctx, r) if err != nil { @@ -221,12 +226,17 @@ func (qb *TagStore) Create(ctx context.Context, newObject *models.Tag) error { } } + const partial = false + if err := qb.setCustomFields(ctx, id, newObject.CustomFields, partial); err != nil { + return err + } + updated, err := qb.find(ctx, id) if err != nil { return fmt.Errorf("finding after create: %w", err) } - *newObject = *updated + *newObject.Tag = *updated return nil } @@ -270,12 +280,16 @@ func (qb *TagStore) UpdatePartial(ctx context.Context, id int, partial models.Ta } } + if err := qb.SetCustomFields(ctx, id, partial.CustomFields); err != nil { + return nil, err + } + return qb.find(ctx, id) } -func (qb *TagStore) Update(ctx context.Context, updatedObject *models.Tag) error { +func (qb *TagStore) Update(ctx context.Context, updatedObject *models.UpdateTagInput) error { var r tagRow - r.fromTag(*updatedObject) + r.fromTag(*updatedObject.Tag) if err := qb.tableMgr.updateByID(ctx, updatedObject.ID, r); err != nil { return err @@ -305,6 +319,10 @@ func (qb *TagStore) Update(ctx context.Context, updatedObject *models.Tag) error } } + if err := qb.SetCustomFields(ctx, updatedObject.ID, updatedObject.CustomFields); err != nil { + return err + } + return nil } diff --git a/pkg/sqlite/tag_filter.go b/pkg/sqlite/tag_filter.go index dadc351ee..2f4e79149 100644 --- a/pkg/sqlite/tag_filter.go +++ b/pkg/sqlite/tag_filter.go @@ -101,6 +101,13 @@ func (qb *tagFilterHandler) criterionHandler() criterionHandler { ×tampCriterionHandler{tagFilter.CreatedAt, "tags.created_at", nil}, ×tampCriterionHandler{tagFilter.UpdatedAt, "tags.updated_at", nil}, + &customFieldsFilterHandler{ + table: tagsCustomFieldsTable.GetTable(), + fkCol: tagIDColumn, + c: tagFilter.CustomFields, + idCol: "tags.id", + }, + &relatedFilterHandler{ relatedIDCol: "scenes_tags.scene_id", relatedRepo: sceneRepository.repository, diff --git a/pkg/sqlite/tag_test.go b/pkg/sqlite/tag_test.go index f1bac19b2..b673de3f9 100644 --- a/pkg/sqlite/tag_test.go +++ b/pkg/sqlite/tag_test.go @@ -1012,8 +1012,10 @@ func TestTagUpdateTagImage(t *testing.T) { // create tag to test against const name = "TestTagUpdateTagImage" - tag := models.Tag{ - Name: name, + tag := models.CreateTagInput{ + Tag: &models.Tag{ + Name: name, + }, } err := qb.Create(ctx, &tag) if err != nil { @@ -1032,15 +1034,17 @@ func TestTagUpdateAlias(t *testing.T) { // create tag to test against const name = "TestTagUpdateAlias" - tag := models.Tag{ - Name: name, + tag := models.CreateTagInput{ + Tag: &models.Tag{ + Name: name, + }, } err := qb.Create(ctx, &tag) if err != nil { return fmt.Errorf("Error creating tag: %s", err.Error()) } - aliases := []string{"alias1", "alias2"} + aliases := []string{"updatedAlias1", "updatedAlias2"} err = qb.UpdateAliases(ctx, tag.ID, aliases) if err != nil { return fmt.Errorf("Error updating tag aliases: %s", err.Error()) @@ -1065,8 +1069,10 @@ func TestTagStashIDs(t *testing.T) { // create tag to test against const name = "TestTagStashIDs" - tag := models.Tag{ - Name: name, + tag := models.CreateTagInput{ + Tag: &models.Tag{ + Name: name, + }, } err := qb.Create(ctx, &tag) if err != nil { @@ -1089,9 +1095,11 @@ func TestTagFindByStashID(t *testing.T) { const name = "TestTagFindByStashID" const stashID = "stashid" const endpoint = "endpoint" - tag := models.Tag{ - Name: name, - StashIDs: models.NewRelatedStashIDs([]models.StashID{{StashID: stashID, Endpoint: endpoint}}), + tag := models.CreateTagInput{ + Tag: &models.Tag{ + Name: name, + StashIDs: models.NewRelatedStashIDs([]models.StashID{{StashID: stashID, Endpoint: endpoint}}), + }, } err := qb.Create(ctx, &tag) if err != nil { @@ -1263,8 +1271,626 @@ func TestTagMerge(t *testing.T) { } } -// TODO Create -// TODO Update +func loadTagRelationships(ctx context.Context, expected models.Tag, actual *models.Tag) error { + if expected.Aliases.Loaded() { + if err := actual.LoadAliases(ctx, db.Tag); err != nil { + return err + } + } + if expected.ParentIDs.Loaded() { + if err := actual.LoadParentIDs(ctx, db.Tag); err != nil { + return err + } + } + if expected.ChildIDs.Loaded() { + if err := actual.LoadChildIDs(ctx, db.Tag); err != nil { + return err + } + } + if expected.StashIDs.Loaded() { + if err := actual.LoadStashIDs(ctx, db.Tag); err != nil { + return err + } + } + + return nil +} + +func Test_TagStore_Create(t *testing.T) { + var ( + name = "name" + sortName = "sortName" + description = "description" + favorite = true + ignoreAutoTag = true + aliases = []string{"alias1", "alias2"} + endpoint1 = "endpoint1" + endpoint2 = "endpoint2" + stashID1 = "stashid1" + stashID2 = "stashid2" + createdAt = epochTime + updatedAt = epochTime + ) + + tests := []struct { + name string + newObject models.CreateTagInput + wantErr bool + }{ + { + "full", + models.CreateTagInput{ + Tag: &models.Tag{ + Name: name, + SortName: sortName, + Description: description, + Favorite: favorite, + IgnoreAutoTag: ignoreAutoTag, + Aliases: models.NewRelatedStrings(aliases), + ParentIDs: models.NewRelatedIDs([]int{tagIDs[tagIdxWithScene]}), + ChildIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithScene]}), + StashIDs: models.NewRelatedStashIDs([]models.StashID{ + { + StashID: stashID1, + Endpoint: endpoint1, + UpdatedAt: epochTime, + }, + { + StashID: stashID2, + Endpoint: endpoint2, + UpdatedAt: epochTime, + }, + }), + CreatedAt: createdAt, + UpdatedAt: updatedAt, + }, + CustomFields: testCustomFields, + }, + false, + }, + { + "invalid parent id", + models.CreateTagInput{ + Tag: &models.Tag{ + Name: name, + ParentIDs: models.NewRelatedIDs([]int{invalidID}), + }, + }, + true, + }, + { + "invalid child id", + models.CreateTagInput{ + Tag: &models.Tag{ + Name: name, + ChildIDs: models.NewRelatedIDs([]int{invalidID}), + }, + }, + true, + }, + } + + qb := db.Tag + + for _, tt := range tests { + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + assert := assert.New(t) + + p := tt.newObject + if err := qb.Create(ctx, &p); (err != nil) != tt.wantErr { + t.Errorf("TagStore.Create() error = %v, wantErr = %v", err, tt.wantErr) + } + + if tt.wantErr { + assert.Zero(p.ID) + return + } + + assert.NotZero(p.ID) + + copy := *tt.newObject.Tag + copy.ID = p.ID + + // load relationships + if err := loadTagRelationships(ctx, copy, p.Tag); err != nil { + t.Errorf("loadTagRelationships() error = %v", err) + return + } + + assert.Equal(copy, *p.Tag) + + // ensure can find the tag + found, err := qb.Find(ctx, p.ID) + if err != nil { + t.Errorf("TagStore.Find() error = %v", err) + } + + if !assert.NotNil(found) { + return + } + + // load relationships + if err := loadTagRelationships(ctx, copy, found); err != nil { + t.Errorf("loadTagRelationships() error = %v", err) + return + } + assert.Equal(copy, *found) + + // ensure custom fields are set + cf, err := qb.GetCustomFields(ctx, p.ID) + if err != nil { + t.Errorf("TagStore.GetCustomFields() error = %v", err) + return + } + + assert.Equal(tt.newObject.CustomFields, cf) + + return + }) + } +} + +func Test_TagStore_Update(t *testing.T) { + var ( + name = "name" + sortName = "sortName" + description = "description" + favorite = true + ignoreAutoTag = true + aliases = []string{"alias1", "alias2"} + endpoint1 = "endpoint1" + endpoint2 = "endpoint2" + stashID1 = "stashid1" + stashID2 = "stashid2" + createdAt = epochTime + updatedAt = epochTime + ) + + tests := []struct { + name string + updatedObject models.UpdateTagInput + wantErr bool + }{ + { + "full", + models.UpdateTagInput{ + Tag: &models.Tag{ + ID: tagIDs[tagIdxWithGallery], + Name: name, + SortName: sortName, + Description: description, + Favorite: favorite, + IgnoreAutoTag: ignoreAutoTag, + Aliases: models.NewRelatedStrings(aliases), + ParentIDs: models.NewRelatedIDs([]int{tagIDs[tagIdxWithScene]}), + ChildIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithScene]}), + StashIDs: models.NewRelatedStashIDs([]models.StashID{ + { + StashID: stashID1, + Endpoint: endpoint1, + UpdatedAt: epochTime, + }, + { + StashID: stashID2, + Endpoint: endpoint2, + UpdatedAt: epochTime, + }, + }), + CreatedAt: createdAt, + UpdatedAt: updatedAt, + }, + CustomFields: models.CustomFieldsInput{ + Full: map[string]interface{}{ + "string": "updated", + "int": int64(999), + "real": 9.99, + }, + }, + }, + false, + }, + { + "set custom fields", + models.UpdateTagInput{ + Tag: &models.Tag{ + ID: tagIDs[tagIdxWithGallery], + Name: tagNames[tagIdxWithGallery], + }, + CustomFields: models.CustomFieldsInput{ + Full: testCustomFields, + }, + }, + false, + }, + { + "clear custom fields", + models.UpdateTagInput{ + Tag: &models.Tag{ + ID: tagIDs[tagIdxWithGallery], + Name: tagNames[tagIdxWithGallery], + }, + CustomFields: models.CustomFieldsInput{ + Full: map[string]interface{}{}, + }, + }, + false, + }, + { + "invalid parent id", + models.UpdateTagInput{ + Tag: &models.Tag{ + ID: tagIDs[tagIdxWithGallery], + Name: tagNames[tagIdxWithGallery], + ParentIDs: models.NewRelatedIDs([]int{invalidID}), + }, + }, + true, + }, + { + "invalid child id", + models.UpdateTagInput{ + Tag: &models.Tag{ + ID: tagIDs[tagIdxWithGallery], + Name: tagNames[tagIdxWithGallery], + ChildIDs: models.NewRelatedIDs([]int{invalidID}), + }, + }, + true, + }, + } + + qb := db.Tag + + for _, tt := range tests { + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + assert := assert.New(t) + + p := tt.updatedObject + if err := qb.Update(ctx, &p); (err != nil) != tt.wantErr { + t.Errorf("TagStore.Update() error = %v, wantErr = %v", err, tt.wantErr) + } + + if tt.wantErr { + return + } + + s, err := qb.Find(ctx, tt.updatedObject.ID) + if err != nil { + t.Errorf("TagStore.Find() error = %v", err) + return + } + + // load relationships + if err := loadTagRelationships(ctx, *tt.updatedObject.Tag, s); err != nil { + t.Errorf("loadTagRelationships() error = %v", err) + return + } + + assert.Equal(*tt.updatedObject.Tag, *s) + + // ensure custom fields are correct + if tt.updatedObject.CustomFields.Full != nil { + cf, err := qb.GetCustomFields(ctx, tt.updatedObject.ID) + if err != nil { + t.Errorf("TagStore.GetCustomFields() error = %v", err) + return + } + + assert.Equal(tt.updatedObject.CustomFields.Full, cf) + } + }) + } +} + +func Test_TagStore_UpdatePartialCustomFields(t *testing.T) { + tests := []struct { + name string + id int + partial models.TagPartial + expected map[string]interface{} // nil to use the partial + }{ + { + "set custom fields", + tagIDs[tagIdxWithGallery], + models.TagPartial{ + CustomFields: models.CustomFieldsInput{ + Full: testCustomFields, + }, + }, + nil, + }, + { + "clear custom fields", + tagIDs[tagIdxWithGallery], + models.TagPartial{ + CustomFields: models.CustomFieldsInput{ + Full: map[string]interface{}{}, + }, + }, + nil, + }, + { + "partial custom fields", + tagIDs[tagIdxWithGallery], + models.TagPartial{ + CustomFields: models.CustomFieldsInput{ + Partial: map[string]interface{}{ + "string": "bbb", + "new_field": "new", + }, + }, + }, + map[string]interface{}{ + "int": int64(2), + "real": float64(1.7), + "string": "bbb", + "new_field": "new", + }, + }, + } + for _, tt := range tests { + qb := db.Tag + + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + assert := assert.New(t) + + _, err := qb.UpdatePartial(ctx, tt.id, tt.partial) + if err != nil { + t.Errorf("TagStore.UpdatePartial() error = %v", err) + return + } + + // ensure custom fields are correct + cf, err := qb.GetCustomFields(ctx, tt.id) + if err != nil { + t.Errorf("TagStore.GetCustomFields() error = %v", err) + return + } + if tt.expected == nil { + assert.Equal(tt.partial.CustomFields.Full, cf) + } else { + assert.Equal(tt.expected, cf) + } + }) + } +} + +func TestTagQueryCustomFields(t *testing.T) { + tests := []struct { + name string + filter *models.TagFilterType + includeIdxs []int + excludeIdxs []int + wantErr bool + }{ + { + "equals", + &models.TagFilterType{ + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "string", + Modifier: models.CriterionModifierEquals, + Value: []any{getTagStringValue(tagIdxWithGallery, "custom")}, + }, + }, + }, + []int{tagIdxWithGallery}, + nil, + false, + }, + { + "not equals", + &models.TagFilterType{ + Name: &models.StringCriterionInput{ + Value: getTagStringValue(tagIdxWithGallery, "Name"), + Modifier: models.CriterionModifierEquals, + }, + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "string", + Modifier: models.CriterionModifierNotEquals, + Value: []any{getTagStringValue(tagIdxWithGallery, "custom")}, + }, + }, + }, + nil, + []int{tagIdxWithGallery}, + false, + }, + { + "includes", + &models.TagFilterType{ + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "string", + Modifier: models.CriterionModifierIncludes, + Value: []any{getTagStringValue(tagIdxWithGallery, "custom")[9:]}, + }, + }, + }, + []int{tagIdxWithGallery}, + nil, + false, + }, + { + "excludes", + &models.TagFilterType{ + Name: &models.StringCriterionInput{ + Value: getTagStringValue(tagIdxWithGallery, "Name"), + Modifier: models.CriterionModifierEquals, + }, + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "string", + Modifier: models.CriterionModifierExcludes, + Value: []any{getTagStringValue(tagIdxWithGallery, "custom")[9:]}, + }, + }, + }, + nil, + []int{tagIdxWithGallery}, + false, + }, + { + "regex", + &models.TagFilterType{ + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "string", + Modifier: models.CriterionModifierMatchesRegex, + Value: []any{".*17_custom"}, + }, + }, + }, + []int{tagIdxWithGallery}, + nil, + false, + }, + { + "invalid regex", + &models.TagFilterType{ + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "string", + Modifier: models.CriterionModifierMatchesRegex, + Value: []any{"["}, + }, + }, + }, + nil, + nil, + true, + }, + { + "not matches regex", + &models.TagFilterType{ + Name: &models.StringCriterionInput{ + Value: getTagStringValue(tagIdxWithGallery, "Name"), + Modifier: models.CriterionModifierEquals, + }, + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "string", + Modifier: models.CriterionModifierNotMatchesRegex, + Value: []any{".*17_custom"}, + }, + }, + }, + nil, + []int{tagIdxWithGallery}, + false, + }, + { + "invalid not matches regex", + &models.TagFilterType{ + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "string", + Modifier: models.CriterionModifierNotMatchesRegex, + Value: []any{"["}, + }, + }, + }, + nil, + nil, + true, + }, + { + "null", + &models.TagFilterType{ + Name: &models.StringCriterionInput{ + Value: getTagStringValue(tagIdxWithGallery, "Name"), + Modifier: models.CriterionModifierEquals, + }, + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "not existing", + Modifier: models.CriterionModifierIsNull, + }, + }, + }, + []int{tagIdxWithGallery}, + nil, + false, + }, + { + "not null", + &models.TagFilterType{ + Name: &models.StringCriterionInput{ + Value: getTagStringValue(tagIdxWithGallery, "Name"), + Modifier: models.CriterionModifierEquals, + }, + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "string", + Modifier: models.CriterionModifierNotNull, + }, + }, + }, + []int{tagIdxWithGallery}, + nil, + false, + }, + { + "between", + &models.TagFilterType{ + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "real", + Modifier: models.CriterionModifierBetween, + Value: []any{0.15, 0.25}, + }, + }, + }, + []int{tagIdx2WithScene}, + nil, + false, + }, + { + "not between", + &models.TagFilterType{ + Name: &models.StringCriterionInput{ + Value: getTagStringValue(tagIdx2WithScene, "Name"), + Modifier: models.CriterionModifierEquals, + }, + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "real", + Modifier: models.CriterionModifierNotBetween, + Value: []any{0.15, 0.25}, + }, + }, + }, + nil, + []int{tagIdx2WithScene}, + false, + }, + } + + for _, tt := range tests { + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + assert := assert.New(t) + + tags, _, err := db.Tag.Query(ctx, tt.filter, nil) + if (err != nil) != tt.wantErr { + t.Errorf("TagStore.Query() error = %v, wantErr %v", err, tt.wantErr) + return + } + + ids := tagsToIDs(tags) + include := indexesToIDs(tagIDs, tt.includeIdxs) + exclude := indexesToIDs(tagIDs, tt.excludeIdxs) + + for _, i := range include { + assert.Contains(ids, i) + } + for _, e := range exclude { + assert.NotContains(ids, e) + } + }) + } +} + // TODO Destroy // TODO Find // TODO FindBySceneID diff --git a/pkg/studio/export.go b/pkg/studio/export.go index 1440c3cdd..c3a50668f 100644 --- a/pkg/studio/export.go +++ b/pkg/studio/export.go @@ -17,6 +17,7 @@ type FinderImageStashIDGetter interface { models.URLLoader models.StashIDLoader GetImage(ctx context.Context, studioID int) ([]byte, error) + models.CustomFieldsReader } // ToJSON converts a Studio object into its JSON equivalent. @@ -60,6 +61,12 @@ func ToJSON(ctx context.Context, reader FinderImageStashIDGetter, studio *models } newStudioJSON.StashIDs = studio.StashIDs.List() + var err error + newStudioJSON.CustomFields, err = reader.GetCustomFields(ctx, studio.ID) + if err != nil { + return nil, fmt.Errorf("getting studio custom fields: %v", err) + } + image, err := reader.GetImage(ctx, studio.ID) if err != nil { logger.Errorf("Error getting studio image: %v", err) diff --git a/pkg/studio/export_test.go b/pkg/studio/export_test.go index c333c0ad5..e41e6f36c 100644 --- a/pkg/studio/export_test.go +++ b/pkg/studio/export_test.go @@ -18,18 +18,24 @@ const ( errImageID = 3 missingParentStudioID = 4 errStudioID = 5 + customFieldsID = 6 parentStudioID = 10 missingStudioID = 11 errParentStudioID = 12 + errCustomFieldsID = 13 ) var ( - studioName = "testStudio" - url = "url" - details = "details" - parentStudioName = "parentStudio" - autoTagIgnored = true + studioName = "testStudio" + url = "url" + details = "details" + parentStudioName = "parentStudio" + autoTagIgnored = true + emptyCustomFields = make(map[string]interface{}) + customFields = map[string]interface{}{ + "customField1": "customValue1", + } ) var studioID = 1 @@ -91,7 +97,7 @@ func createEmptyStudio(id int) models.Studio { } } -func createFullJSONStudio(parentStudio, image string, aliases []string) *jsonschema.Studio { +func createFullJSONStudio(parentStudio, image string, aliases []string, customFields map[string]interface{}) *jsonschema.Studio { return &jsonschema.Studio{ Name: studioName, URLs: []string{url}, @@ -109,6 +115,7 @@ func createFullJSONStudio(parentStudio, image string, aliases []string) *jsonsch Aliases: aliases, StashIDs: stashIDs, IgnoreAutoTag: autoTagIgnored, + CustomFields: customFields, } } @@ -120,16 +127,18 @@ func createEmptyJSONStudio() *jsonschema.Studio { UpdatedAt: json.JSONTime{ Time: updateTime, }, - Aliases: []string{}, - URLs: []string{}, - StashIDs: []models.StashID{}, + Aliases: []string{}, + URLs: []string{}, + StashIDs: []models.StashID{}, + CustomFields: emptyCustomFields, } } type testScenario struct { - input models.Studio - expected *jsonschema.Studio - err bool + input models.Studio + customFields map[string]interface{} + expected *jsonschema.Studio + err bool } var scenarios []testScenario @@ -138,30 +147,48 @@ func initTestTable() { scenarios = []testScenario{ { createFullStudio(studioID, parentStudioID), - createFullJSONStudio(parentStudioName, image, []string{"alias"}), + emptyCustomFields, + createFullJSONStudio(parentStudioName, image, []string{"alias"}, emptyCustomFields), + false, + }, + { + createFullStudio(customFieldsID, parentStudioID), + customFields, + createFullJSONStudio(parentStudioName, image, []string{"alias"}, customFields), false, }, { createEmptyStudio(noImageID), + emptyCustomFields, createEmptyJSONStudio(), false, }, { createFullStudio(errImageID, parentStudioID), - createFullJSONStudio(parentStudioName, "", []string{"alias"}), + emptyCustomFields, + createFullJSONStudio(parentStudioName, "", []string{"alias"}, emptyCustomFields), // failure to get image is not an error false, }, { createFullStudio(missingParentStudioID, missingStudioID), - createFullJSONStudio("", image, []string{"alias"}), + emptyCustomFields, + createFullJSONStudio("", image, []string{"alias"}, emptyCustomFields), false, }, { createFullStudio(errStudioID, errParentStudioID), + emptyCustomFields, nil, true, }, + { + createFullStudio(errCustomFieldsID, parentStudioID), + customFields, + nil, + // failure to get custom fields should cause an error + true, + }, } } @@ -177,6 +204,7 @@ func TestToJSON(t *testing.T) { db.Studio.On("GetImage", testCtx, errImageID).Return(nil, imageErr).Once() db.Studio.On("GetImage", testCtx, missingParentStudioID).Return(imageBytes, nil).Maybe() db.Studio.On("GetImage", testCtx, errStudioID).Return(imageBytes, nil).Maybe() + db.Studio.On("GetImage", testCtx, customFieldsID).Return(imageBytes, nil).Once() parentStudioErr := errors.New("error getting parent studio") @@ -184,6 +212,15 @@ func TestToJSON(t *testing.T) { db.Studio.On("Find", testCtx, missingStudioID).Return(nil, nil) db.Studio.On("Find", testCtx, errParentStudioID).Return(nil, parentStudioErr) + customFieldsErr := errors.New("error getting custom fields") + + db.Studio.On("GetCustomFields", testCtx, studioID).Return(emptyCustomFields, nil).Once() + db.Studio.On("GetCustomFields", testCtx, customFieldsID).Return(customFields, nil).Once() + db.Studio.On("GetCustomFields", testCtx, missingParentStudioID).Return(emptyCustomFields, nil).Once() + db.Studio.On("GetCustomFields", testCtx, noImageID).Return(emptyCustomFields, nil).Once() + db.Studio.On("GetCustomFields", testCtx, errImageID).Return(emptyCustomFields, nil).Once() + db.Studio.On("GetCustomFields", testCtx, errCustomFieldsID).Return(nil, customFieldsErr).Once() + for i, s := range scenarios { studio := s.input json, err := ToJSON(testCtx, db.Studio, &studio) diff --git a/pkg/studio/import.go b/pkg/studio/import.go index d5284ce02..d9e52100c 100644 --- a/pkg/studio/import.go +++ b/pkg/studio/import.go @@ -26,13 +26,15 @@ type Importer struct { Input jsonschema.Studio MissingRefBehaviour models.ImportMissingRefEnum - ID int - studio models.Studio - imageData []byte + ID int + studio models.Studio + customFields models.CustomFieldMap + imageData []byte } func (i *Importer) PreImport(ctx context.Context) error { i.studio = studioJSONtoStudio(i.Input) + i.customFields = i.Input.CustomFields if err := i.populateParentStudio(ctx); err != nil { return err @@ -110,7 +112,9 @@ func createTags(ctx context.Context, tagWriter models.TagFinderCreator, names [] newTag := models.NewTag() newTag.Name = name - err := tagWriter.Create(ctx, &newTag) + err := tagWriter.Create(ctx, &models.CreateTagInput{ + Tag: &newTag, + }) if err != nil { return nil, err } @@ -194,7 +198,10 @@ func (i *Importer) FindExistingID(ctx context.Context) (*int, error) { } func (i *Importer) Create(ctx context.Context) (*int, error) { - err := i.ReaderWriter.Create(ctx, &models.CreateStudioInput{Studio: &i.studio}) + err := i.ReaderWriter.Create(ctx, &models.CreateStudioInput{ + Studio: &i.studio, + CustomFields: i.customFields, + }) if err != nil { return nil, fmt.Errorf("error creating studio: %v", err) } @@ -206,7 +213,12 @@ func (i *Importer) Create(ctx context.Context) (*int, error) { func (i *Importer) Update(ctx context.Context, id int) error { studio := i.studio studio.ID = id - err := i.ReaderWriter.Update(ctx, &models.UpdateStudioInput{Studio: &studio}) + err := i.ReaderWriter.Update(ctx, &models.UpdateStudioInput{ + Studio: &studio, + CustomFields: models.CustomFieldsInput{ + Full: i.customFields, + }, + }) if err != nil { return fmt.Errorf("error updating existing studio: %v", err) } diff --git a/pkg/studio/import_test.go b/pkg/studio/import_test.go index 6648ebe0d..4eb757293 100644 --- a/pkg/studio/import_test.go +++ b/pkg/studio/import_test.go @@ -62,7 +62,7 @@ func TestImporterPreImport(t *testing.T) { assert.Nil(t, err) - i.Input = *createFullJSONStudio(studioName, image, []string{"alias"}) + i.Input = *createFullJSONStudio(studioName, image, []string{"alias"}, customFields) i.Input.ParentStudio = "" err = i.PreImport(testCtx) @@ -71,6 +71,7 @@ func TestImporterPreImport(t *testing.T) { expectedStudio := createFullStudio(0, 0) expectedStudio.ParentID = nil assert.Equal(t, expectedStudio, i.studio) + assert.Equal(t, models.CustomFieldMap(customFields), i.customFields) } func TestImporterPreImportWithTag(t *testing.T) { @@ -121,9 +122,9 @@ func TestImporterPreImportWithMissingTag(t *testing.T) { } db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3) - db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) { - t := args.Get(1).(*models.Tag) - t.ID = existingTagID + db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.CreateTagInput")).Run(func(args mock.Arguments) { + t := args.Get(1).(*models.CreateTagInput) + t.Tag.ID = existingTagID }).Return(nil) err := i.PreImport(testCtx) @@ -156,7 +157,7 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) { } db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once() - db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error")) + db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.CreateTagInput")).Return(errors.New("Create error")) err := i.PreImport(testCtx) assert.NotNil(t, err) diff --git a/pkg/tag/export.go b/pkg/tag/export.go index b07418667..fc7115209 100644 --- a/pkg/tag/export.go +++ b/pkg/tag/export.go @@ -16,6 +16,7 @@ type FinderAliasImageGetter interface { GetAliases(ctx context.Context, studioID int) ([]string, error) GetImage(ctx context.Context, tagID int) ([]byte, error) FindByChildTagID(ctx context.Context, childID int) ([]*models.Tag, error) + GetCustomFields(ctx context.Context, id int) (map[string]interface{}, error) models.StashIDLoader } @@ -63,6 +64,11 @@ func ToJSON(ctx context.Context, reader FinderAliasImageGetter, tag *models.Tag) newTagJSON.Parents = GetNames(parents) + newTagJSON.CustomFields, err = reader.GetCustomFields(ctx, tag.ID) + if err != nil { + return nil, fmt.Errorf("getting tag custom fields: %v", err) + } + return &newTagJSON, nil } diff --git a/pkg/tag/export_test.go b/pkg/tag/export_test.go index 84e082f30..cba2d4ebf 100644 --- a/pkg/tag/export_test.go +++ b/pkg/tag/export_test.go @@ -14,12 +14,14 @@ import ( ) const ( - tagID = 1 - noImageID = 2 - errImageID = 3 - errAliasID = 4 - withParentsID = 5 - errParentsID = 6 + tagID = iota + 1 + customFieldsID + noImageID + errImageID + errAliasID + withParentsID + errParentsID + errCustomFieldsID ) const ( @@ -32,6 +34,11 @@ var ( autoTagIgnored = true createTime = time.Date(2001, 01, 01, 0, 0, 0, 0, time.UTC) updateTime = time.Date(2002, 01, 01, 0, 0, 0, 0, time.UTC) + + emptyCustomFields = make(map[string]interface{}) + customFields = map[string]interface{}{ + "customField1": "customValue1", + } ) func createTag(id int) models.Tag { @@ -47,8 +54,8 @@ func createTag(id int) models.Tag { } } -func createJSONTag(aliases []string, image string, parents []string) *jsonschema.Tag { - return &jsonschema.Tag{ +func createJSONTag(aliases []string, image string, parents []string, withCustomFields bool) *jsonschema.Tag { + ret := &jsonschema.Tag{ Name: tagName, SortName: sortName, Favorite: true, @@ -61,15 +68,23 @@ func createJSONTag(aliases []string, image string, parents []string) *jsonschema UpdatedAt: json.JSONTime{ Time: updateTime, }, - Image: image, - Parents: parents, + Image: image, + Parents: parents, + CustomFields: emptyCustomFields, } + + if withCustomFields { + ret.CustomFields = customFields + } + + return ret } type testScenario struct { - tag models.Tag - expected *jsonschema.Tag - err bool + tag models.Tag + customFields map[string]interface{} + expected *jsonschema.Tag + err bool } var scenarios []testScenario @@ -78,32 +93,50 @@ func initTestTable() { scenarios = []testScenario{ { createTag(tagID), - createJSONTag([]string{"alias"}, image, nil), + emptyCustomFields, + createJSONTag([]string{"alias"}, image, nil, false), + false, + }, + { + createTag(customFieldsID), + customFields, + createJSONTag([]string{"alias"}, image, nil, true), false, }, { createTag(noImageID), - createJSONTag(nil, "", nil), + emptyCustomFields, + createJSONTag(nil, "", nil, false), false, }, { createTag(errImageID), - createJSONTag(nil, "", nil), + emptyCustomFields, + createJSONTag(nil, "", nil, false), // getting the image should not cause an error false, }, { createTag(errAliasID), + emptyCustomFields, nil, true, }, { createTag(withParentsID), - createJSONTag(nil, image, []string{"parent"}), + emptyCustomFields, + createJSONTag(nil, image, []string{"parent"}, false), false, }, { createTag(errParentsID), + emptyCustomFields, + nil, + true, + }, + { + createTag(errCustomFieldsID), + customFields, nil, true, }, @@ -118,32 +151,48 @@ func TestToJSON(t *testing.T) { imageErr := errors.New("error getting image") aliasErr := errors.New("error getting aliases") parentsErr := errors.New("error getting parents") + customFieldsErr := errors.New("error getting custom fields") db.Tag.On("GetAliases", testCtx, tagID).Return([]string{"alias"}, nil).Once() + db.Tag.On("GetAliases", testCtx, customFieldsID).Return([]string{"alias"}, nil).Once() db.Tag.On("GetAliases", testCtx, noImageID).Return(nil, nil).Once() db.Tag.On("GetAliases", testCtx, errImageID).Return(nil, nil).Once() db.Tag.On("GetAliases", testCtx, errAliasID).Return(nil, aliasErr).Once() db.Tag.On("GetAliases", testCtx, withParentsID).Return(nil, nil).Once() db.Tag.On("GetAliases", testCtx, errParentsID).Return(nil, nil).Once() + db.Tag.On("GetAliases", testCtx, errCustomFieldsID).Return(nil, nil).Once() db.Tag.On("GetStashIDs", testCtx, tagID).Return(nil, nil).Once() + db.Tag.On("GetStashIDs", testCtx, customFieldsID).Return(nil, nil).Once() db.Tag.On("GetStashIDs", testCtx, noImageID).Return(nil, nil).Once() db.Tag.On("GetStashIDs", testCtx, errImageID).Return(nil, nil).Once() // errAliasID test fails before GetStashIDs is called, so no mock needed db.Tag.On("GetStashIDs", testCtx, withParentsID).Return(nil, nil).Once() db.Tag.On("GetStashIDs", testCtx, errParentsID).Return(nil, nil).Once() + db.Tag.On("GetStashIDs", testCtx, errCustomFieldsID).Return(nil, nil).Once() db.Tag.On("GetImage", testCtx, tagID).Return(imageBytes, nil).Once() + db.Tag.On("GetImage", testCtx, customFieldsID).Return(imageBytes, nil).Once() db.Tag.On("GetImage", testCtx, noImageID).Return(nil, nil).Once() db.Tag.On("GetImage", testCtx, errImageID).Return(nil, imageErr).Once() db.Tag.On("GetImage", testCtx, withParentsID).Return(imageBytes, nil).Once() db.Tag.On("GetImage", testCtx, errParentsID).Return(nil, nil).Once() + db.Tag.On("GetImage", testCtx, errCustomFieldsID).Return(nil, nil).Once() db.Tag.On("FindByChildTagID", testCtx, tagID).Return(nil, nil).Once() + db.Tag.On("FindByChildTagID", testCtx, customFieldsID).Return(nil, nil).Once() db.Tag.On("FindByChildTagID", testCtx, noImageID).Return(nil, nil).Once() db.Tag.On("FindByChildTagID", testCtx, withParentsID).Return([]*models.Tag{{Name: "parent"}}, nil).Once() db.Tag.On("FindByChildTagID", testCtx, errParentsID).Return(nil, parentsErr).Once() db.Tag.On("FindByChildTagID", testCtx, errImageID).Return(nil, nil).Once() + db.Tag.On("FindByChildTagID", testCtx, errCustomFieldsID).Return(nil, nil).Once() + + db.Tag.On("GetCustomFields", testCtx, tagID).Return(emptyCustomFields, nil).Once() + db.Tag.On("GetCustomFields", testCtx, customFieldsID).Return(customFields, nil).Once() + db.Tag.On("GetCustomFields", testCtx, noImageID).Return(emptyCustomFields, nil).Once() + db.Tag.On("GetCustomFields", testCtx, errImageID).Return(emptyCustomFields, nil).Once() + db.Tag.On("GetCustomFields", testCtx, withParentsID).Return(emptyCustomFields, nil).Once() + db.Tag.On("GetCustomFields", testCtx, errCustomFieldsID).Return(nil, customFieldsErr).Once() for i, s := range scenarios { tag := s.tag diff --git a/pkg/tag/import.go b/pkg/tag/import.go index 53b741886..501dc6795 100644 --- a/pkg/tag/import.go +++ b/pkg/tag/import.go @@ -31,8 +31,9 @@ type Importer struct { Input jsonschema.Tag MissingRefBehaviour models.ImportMissingRefEnum - tag models.Tag - imageData []byte + tag models.Tag + imageData []byte + customFields map[string]interface{} } func (i *Importer) PreImport(ctx context.Context) error { @@ -55,6 +56,8 @@ func (i *Importer) PreImport(ctx context.Context) error { } } + i.customFields = i.Input.CustomFields + return nil } @@ -78,6 +81,14 @@ func (i *Importer) PostImport(ctx context.Context, id int) error { return fmt.Errorf("error setting parents: %v", err) } + if len(i.customFields) > 0 { + if err := i.ReaderWriter.SetCustomFields(ctx, id, models.CustomFieldsInput{ + Full: i.customFields, + }); err != nil { + return fmt.Errorf("error setting tag custom fields: %v", err) + } + } + return nil } @@ -101,7 +112,10 @@ func (i *Importer) FindExistingID(ctx context.Context) (*int, error) { } func (i *Importer) Create(ctx context.Context) (*int, error) { - err := i.ReaderWriter.Create(ctx, &i.tag) + err := i.ReaderWriter.Create(ctx, &models.CreateTagInput{ + Tag: &i.tag, + CustomFields: i.customFields, + }) if err != nil { return nil, fmt.Errorf("error creating tag: %v", err) } @@ -113,7 +127,12 @@ func (i *Importer) Create(ctx context.Context) (*int, error) { func (i *Importer) Update(ctx context.Context, id int) error { tag := i.tag tag.ID = id - err := i.ReaderWriter.Update(ctx, &tag) + err := i.ReaderWriter.Update(ctx, &models.UpdateTagInput{ + Tag: &tag, + CustomFields: models.CustomFieldsInput{ + Full: i.customFields, + }, + }) if err != nil { return fmt.Errorf("error updating existing tag: %v", err) } @@ -157,7 +176,9 @@ func (i *Importer) createParent(ctx context.Context, name string) (int, error) { newTag := models.NewTag() newTag.Name = name - err := i.ReaderWriter.Create(ctx, &newTag) + err := i.ReaderWriter.Create(ctx, &models.CreateTagInput{ + Tag: &newTag, + }) if err != nil { return 0, err } diff --git a/pkg/tag/import_test.go b/pkg/tag/import_test.go index b706c4937..f6eaec88a 100644 --- a/pkg/tag/import_test.go +++ b/pkg/tag/import_test.go @@ -154,14 +154,14 @@ func TestImporterPostImportParentMissing(t *testing.T) { db.Tag.On("UpdateParentTags", testCtx, ignoreID, emptyParents).Return(nil).Once() db.Tag.On("UpdateParentTags", testCtx, ignoreFoundID, []int{103}).Return(nil).Once() - db.Tag.On("Create", testCtx, mock.MatchedBy(func(t *models.Tag) bool { - return t.Name == "Create" + db.Tag.On("Create", testCtx, mock.MatchedBy(func(input *models.CreateTagInput) bool { + return input.Tag.Name == "Create" })).Run(func(args mock.Arguments) { - t := args.Get(1).(*models.Tag) - t.ID = 100 + input := args.Get(1).(*models.CreateTagInput) + input.Tag.ID = 100 }).Return(nil).Once() - db.Tag.On("Create", testCtx, mock.MatchedBy(func(t *models.Tag) bool { - return t.Name == "CreateError" + db.Tag.On("Create", testCtx, mock.MatchedBy(func(input *models.CreateTagInput) bool { + return input.Tag.Name == "CreateError" })).Return(errors.New("failed creating parent")).Once() i.MissingRefBehaviour = models.ImportMissingRefEnumCreate @@ -261,11 +261,15 @@ func TestCreate(t *testing.T) { } errCreate := errors.New("Create error") - db.Tag.On("Create", testCtx, &tag).Run(func(args mock.Arguments) { - t := args.Get(1).(*models.Tag) - t.ID = tagID + db.Tag.On("Create", testCtx, mock.MatchedBy(func(input *models.CreateTagInput) bool { + return input.Tag.Name == tag.Name + })).Run(func(args mock.Arguments) { + input := args.Get(1).(*models.CreateTagInput) + input.Tag.ID = tagID }).Return(nil).Once() - db.Tag.On("Create", testCtx, &tagErr).Return(errCreate).Once() + db.Tag.On("Create", testCtx, mock.MatchedBy(func(input *models.CreateTagInput) bool { + return input.Tag.Name == tagErr.Name + })).Return(errCreate).Once() id, err := i.Create(testCtx) assert.Equal(t, tagID, *id) @@ -299,7 +303,10 @@ func TestUpdate(t *testing.T) { // id needs to be set for the mock input tag.ID = tagID - db.Tag.On("Update", testCtx, &tag).Return(nil).Once() + tagInput := models.UpdateTagInput{ + Tag: &tag, + } + db.Tag.On("Update", testCtx, &tagInput).Return(nil).Once() err := i.Update(testCtx, tagID) assert.Nil(t, err) @@ -308,7 +315,10 @@ func TestUpdate(t *testing.T) { // need to set id separately tagErr.ID = errImageID - db.Tag.On("Update", testCtx, &tagErr).Return(errUpdate).Once() + errInput := models.UpdateTagInput{ + Tag: &tagErr, + } + db.Tag.On("Update", testCtx, &errInput).Return(errUpdate).Once() err = i.Update(testCtx, errImageID) assert.NotNil(t, err)