From ca5178f05ebd5f702e5c20e660a3c4c062ed0335 Mon Sep 17 00:00:00 2001 From: WithoutPants <53250216+WithoutPants@users.noreply.github.com> Date: Mon, 23 Feb 2026 11:53:12 +1100 Subject: [PATCH] Backend support for Group custom fields (#6596) --- graphql/schema/types/filters.graphql | 3 + graphql/schema/types/group.graphql | 7 + internal/api/loaders/dataloaders.go | 28 +- internal/api/resolver_model_movie.go | 13 + internal/api/resolver_mutation_group.go | 57 ++-- internal/manager/repository.go | 2 +- pkg/group/create.go | 24 +- pkg/group/export.go | 60 ++-- pkg/group/export_test.go | 62 +++- pkg/group/import.go | 9 + pkg/group/import_test.go | 21 +- pkg/group/service.go | 1 + pkg/models/group.go | 2 + pkg/models/jsonschema/group.go | 2 + pkg/models/mocks/GroupReaderWriter.go | 60 ++++ pkg/models/model_group.go | 10 + pkg/models/repository_group.go | 2 + pkg/sqlite/anonymise.go | 4 + pkg/sqlite/custom_fields_test.go | 8 +- pkg/sqlite/database.go | 2 +- pkg/sqlite/gallery_test.go | 73 ++++ pkg/sqlite/group.go | 9 + pkg/sqlite/group_filter.go | 7 + pkg/sqlite/group_test.go | 312 ++++++++++++++++++ .../migrations/82_group_custom_fields.up.sql | 9 + pkg/sqlite/setup_test.go | 19 ++ pkg/sqlite/tables.go | 1 + 27 files changed, 725 insertions(+), 82 deletions(-) create mode 100644 pkg/sqlite/migrations/82_group_custom_fields.up.sql diff --git a/graphql/schema/types/filters.graphql b/graphql/schema/types/filters.graphql index d683329b6..4162f0af3 100644 --- a/graphql/schema/types/filters.graphql +++ b/graphql/schema/types/filters.graphql @@ -462,6 +462,9 @@ input GroupFilterType { scenes_filter: SceneFilterType "Filter by related studios that meet this criteria" studios_filter: StudioFilterType + + "Filter by custom fields" + custom_fields: [CustomFieldCriterionInput!] } input StudioFilterType { diff --git a/graphql/schema/types/group.graphql b/graphql/schema/types/group.graphql index a46932054..a1c878923 100644 --- a/graphql/schema/types/group.graphql +++ b/graphql/schema/types/group.graphql @@ -31,6 +31,7 @@ type Group { sub_group_count(depth: Int): Int! # Resolver scenes: [Scene!]! o_counter: Int # Resolver + custom_fields: Map! } input GroupDescriptionInput { @@ -59,6 +60,8 @@ input GroupCreateInput { front_image: String "This should be a URL or a base64 encoded data URL" back_image: String + + custom_fields: Map } input GroupUpdateInput { @@ -82,6 +85,8 @@ input GroupUpdateInput { front_image: String "This should be a URL or a base64 encoded data URL" back_image: String + + custom_fields: CustomFieldsInput } input BulkUpdateGroupDescriptionsInput { @@ -101,6 +106,8 @@ input BulkGroupUpdateInput { containing_groups: BulkUpdateGroupDescriptionsInput sub_groups: BulkUpdateGroupDescriptionsInput + + custom_fields: CustomFieldsInput } input GroupDestroyInput { diff --git a/internal/api/loaders/dataloaders.go b/internal/api/loaders/dataloaders.go index e7293ad1c..ff8a87ab0 100644 --- a/internal/api/loaders/dataloaders.go +++ b/internal/api/loaders/dataloaders.go @@ -64,11 +64,12 @@ type Loaders struct { StudioByID *StudioLoader StudioCustomFields *CustomFieldsLoader - TagByID *TagLoader - TagCustomFields *CustomFieldsLoader - GroupByID *GroupLoader - FileByID *FileLoader - FolderByID *FolderLoader + TagByID *TagLoader + TagCustomFields *CustomFieldsLoader + GroupByID *GroupLoader + GroupCustomFields *CustomFieldsLoader + FileByID *FileLoader + FolderByID *FolderLoader } type Middleware struct { @@ -139,6 +140,11 @@ func (m Middleware) Middleware(next http.Handler) http.Handler { maxBatch: maxBatch, fetch: m.fetchGroups(ctx), }, + GroupCustomFields: &CustomFieldsLoader{ + wait: wait, + maxBatch: maxBatch, + fetch: m.fetchGroupCustomFields(ctx), + }, FileByID: &FileLoader{ wait: wait, maxBatch: maxBatch, @@ -325,6 +331,18 @@ func (m Middleware) fetchTagCustomFields(ctx context.Context) func(keys []int) ( } } +func (m Middleware) fetchGroupCustomFields(ctx context.Context) func(keys []int) ([]models.CustomFieldMap, []error) { + return func(keys []int) (ret []models.CustomFieldMap, errs []error) { + err := m.Repository.WithDB(ctx, func(ctx context.Context) error { + var err error + ret, err = m.Repository.Group.GetCustomFieldsBulk(ctx, keys) + return err + }) + + return ret, toErrorSlice(err) + } +} + func (m Middleware) fetchGalleryCustomFields(ctx context.Context) func(keys []int) ([]models.CustomFieldMap, []error) { return func(keys []int) (ret []models.CustomFieldMap, errs []error) { err := m.Repository.WithDB(ctx, func(ctx context.Context) error { diff --git a/internal/api/resolver_model_movie.go b/internal/api/resolver_model_movie.go index 317123c6e..287d5d51a 100644 --- a/internal/api/resolver_model_movie.go +++ b/internal/api/resolver_model_movie.go @@ -215,3 +215,16 @@ func (r *groupResolver) OCounter(ctx context.Context, obj *models.Group) (ret *i } return &count, nil } + +func (r *groupResolver) CustomFields(ctx context.Context, obj *models.Group) (map[string]interface{}, error) { + m, err := loaders.From(ctx).GroupCustomFields.Load(obj.ID) + if err != nil { + return nil, err + } + + if m == nil { + return make(map[string]interface{}), nil + } + + return m, nil +} diff --git a/internal/api/resolver_mutation_group.go b/internal/api/resolver_mutation_group.go index 14dc817b9..dff5a6c1e 100644 --- a/internal/api/resolver_mutation_group.go +++ b/internal/api/resolver_mutation_group.go @@ -14,13 +14,17 @@ import ( "github.com/stashapp/stash/pkg/utils" ) -func groupFromGroupCreateInput(ctx context.Context, input GroupCreateInput) (*models.Group, error) { +func groupFromGroupCreateInput(ctx context.Context, input GroupCreateInput) (*models.CreateGroupInput, error) { translator := changesetTranslator{ inputMap: getUpdateInputMap(ctx), } // Populate a new group from the input - newGroup := models.NewGroup() + newGroupInput := &models.CreateGroupInput{ + Group: &models.Group{}, + } + *newGroupInput.Group = models.NewGroup() + newGroup := newGroupInput.Group newGroup.Name = strings.TrimSpace(input.Name) newGroup.Aliases = translator.string(input.Aliases) @@ -59,28 +63,19 @@ func groupFromGroupCreateInput(ctx context.Context, input GroupCreateInput) (*mo newGroup.URLs = models.NewRelatedStrings(stringslice.TrimSpace(input.Urls)) } - return &newGroup, nil -} - -func (r *mutationResolver) GroupCreate(ctx context.Context, input GroupCreateInput) (*models.Group, error) { - newGroup, err := groupFromGroupCreateInput(ctx, input) - if err != nil { - return nil, err - } + newGroupInput.CustomFields = convertMapJSONNumbers(input.CustomFields) // Process the base 64 encoded image string - var frontimageData []byte if input.FrontImage != nil { - frontimageData, err = utils.ProcessImageInput(ctx, *input.FrontImage) + newGroupInput.FrontImageData, err = utils.ProcessImageInput(ctx, *input.FrontImage) if err != nil { return nil, fmt.Errorf("processing front image: %w", err) } } // Process the base 64 encoded image string - var backimageData []byte if input.BackImage != nil { - backimageData, err = utils.ProcessImageInput(ctx, *input.BackImage) + newGroupInput.BackImageData, err = utils.ProcessImageInput(ctx, *input.BackImage) if err != nil { return nil, fmt.Errorf("processing back image: %w", err) } @@ -88,13 +83,22 @@ func (r *mutationResolver) GroupCreate(ctx context.Context, input GroupCreateInp // HACK: if back image is being set, set the front image to the default. // This is because we can't have a null front image with a non-null back image. - if len(frontimageData) == 0 && len(backimageData) != 0 { - frontimageData = static.ReadAll(static.DefaultGroupImage) + if len(newGroupInput.FrontImageData) == 0 && len(newGroupInput.BackImageData) != 0 { + newGroupInput.FrontImageData = static.ReadAll(static.DefaultGroupImage) + } + + return newGroupInput, nil +} + +func (r *mutationResolver) GroupCreate(ctx context.Context, input GroupCreateInput) (*models.Group, error) { + createGroupInput, err := groupFromGroupCreateInput(ctx, input) + if err != nil { + return nil, err } // Start the transaction and save the group if err := r.withTxn(ctx, func(ctx context.Context) error { - if err = r.groupService.Create(ctx, newGroup, frontimageData, backimageData); err != nil { + if err = r.groupService.Create(ctx, createGroupInput); err != nil { return err } @@ -104,9 +108,9 @@ func (r *mutationResolver) GroupCreate(ctx context.Context, input GroupCreateInp } // for backwards compatibility - run both movie and group hooks - r.hookExecutor.ExecutePostHooks(ctx, newGroup.ID, hook.GroupCreatePost, input, nil) - r.hookExecutor.ExecutePostHooks(ctx, newGroup.ID, hook.MovieCreatePost, input, nil) - return r.getGroup(ctx, newGroup.ID) + r.hookExecutor.ExecutePostHooks(ctx, createGroupInput.Group.ID, hook.GroupCreatePost, input, nil) + r.hookExecutor.ExecutePostHooks(ctx, createGroupInput.Group.ID, hook.MovieCreatePost, input, nil) + return r.getGroup(ctx, createGroupInput.Group.ID) } func groupPartialFromGroupUpdateInput(translator changesetTranslator, input GroupUpdateInput) (ret models.GroupPartial, err error) { @@ -150,6 +154,12 @@ func groupPartialFromGroupUpdateInput(translator changesetTranslator, input Grou } updatedGroup.URLs = translator.updateStrings(input.Urls, "urls") + if input.CustomFields != nil { + updatedGroup.CustomFields = *input.CustomFields + // convert json.Numbers to int/float + updatedGroup.CustomFields.Full = convertMapJSONNumbers(updatedGroup.CustomFields.Full) + updatedGroup.CustomFields.Partial = convertMapJSONNumbers(updatedGroup.CustomFields.Partial) + } return updatedGroup, nil } @@ -246,6 +256,13 @@ func groupPartialFromBulkGroupUpdateInput(translator changesetTranslator, input updatedGroup.URLs = translator.optionalURLsBulk(input.Urls, nil) + if input.CustomFields != nil { + updatedGroup.CustomFields = *input.CustomFields + // convert json.Numbers to int/float + updatedGroup.CustomFields.Full = convertMapJSONNumbers(updatedGroup.CustomFields.Full) + updatedGroup.CustomFields.Partial = convertMapJSONNumbers(updatedGroup.CustomFields.Partial) + } + return updatedGroup, nil } diff --git a/internal/manager/repository.go b/internal/manager/repository.go index afbf0b963..65514ed1d 100644 --- a/internal/manager/repository.go +++ b/internal/manager/repository.go @@ -39,7 +39,7 @@ type GalleryService interface { } type GroupService interface { - Create(ctx context.Context, group *models.Group, frontimageData []byte, backimageData []byte) error + Create(ctx context.Context, input *models.CreateGroupInput) error UpdatePartial(ctx context.Context, id int, updatedGroup models.GroupPartial, frontImage group.ImageInput, backImage group.ImageInput) (*models.Group, error) AddSubGroups(ctx context.Context, groupID int, subGroups []models.GroupIDDescription, insertIndex *int) error diff --git a/pkg/group/create.go b/pkg/group/create.go index 56d6b7a4e..9cc578b23 100644 --- a/pkg/group/create.go +++ b/pkg/group/create.go @@ -12,27 +12,37 @@ var ( ErrHierarchyLoop = errors.New("a group cannot be contained by one of its subgroups") ) -func (s *Service) Create(ctx context.Context, group *models.Group, frontimageData []byte, backimageData []byte) error { +func (s *Service) Create(ctx context.Context, input *models.CreateGroupInput) error { r := s.Repository + group := input.Group if err := s.validateCreate(ctx, group); err != nil { return err } - err := r.Create(ctx, group) + err := r.Create(ctx, input.Group) if err != nil { return err } - // update image table - if len(frontimageData) > 0 { - if err := r.UpdateFrontImage(ctx, group.ID, frontimageData); err != nil { + // set custom fields + if len(input.CustomFields) > 0 { + if err := r.SetCustomFields(ctx, group.ID, models.CustomFieldsInput{ + Full: input.CustomFields, + }); err != nil { return err } } - if len(backimageData) > 0 { - if err := r.UpdateBackImage(ctx, group.ID, backimageData); err != nil { + // update image table + if len(input.FrontImageData) > 0 { + if err := r.UpdateFrontImage(ctx, group.ID, input.FrontImageData); err != nil { + return err + } + } + + if len(input.BackImageData) > 0 { + if err := r.UpdateBackImage(ctx, group.ID, input.BackImageData); err != nil { return err } } diff --git a/pkg/group/export.go b/pkg/group/export.go index 418ce7bed..0a56fbdbb 100644 --- a/pkg/group/export.go +++ b/pkg/group/export.go @@ -11,61 +11,67 @@ import ( "github.com/stashapp/stash/pkg/utils" ) -type ImageGetter interface { - GetFrontImage(ctx context.Context, movieID int) ([]byte, error) - GetBackImage(ctx context.Context, movieID int) ([]byte, error) +type GroupExportReader interface { + GetFrontImage(ctx context.Context, groupID int) ([]byte, error) + GetBackImage(ctx context.Context, groupID int) ([]byte, error) + GetCustomFields(ctx context.Context, groupID int) (map[string]interface{}, error) } -// ToJSON converts a Movie into its JSON equivalent. -func ToJSON(ctx context.Context, reader ImageGetter, studioReader models.StudioGetter, movie *models.Group) (*jsonschema.Group, error) { - newMovieJSON := jsonschema.Group{ - Name: movie.Name, - Aliases: movie.Aliases, - Director: movie.Director, - Synopsis: movie.Synopsis, - URLs: movie.URLs.List(), - CreatedAt: json.JSONTime{Time: movie.CreatedAt}, - UpdatedAt: json.JSONTime{Time: movie.UpdatedAt}, +// ToJSON converts a Group into its JSON equivalent. +func ToJSON(ctx context.Context, reader GroupExportReader, studioReader models.StudioGetter, group *models.Group) (*jsonschema.Group, error) { + newGroupJSON := jsonschema.Group{ + Name: group.Name, + Aliases: group.Aliases, + Director: group.Director, + Synopsis: group.Synopsis, + URLs: group.URLs.List(), + CreatedAt: json.JSONTime{Time: group.CreatedAt}, + UpdatedAt: json.JSONTime{Time: group.UpdatedAt}, } - if movie.Date != nil { - newMovieJSON.Date = movie.Date.String() + if group.Date != nil { + newGroupJSON.Date = group.Date.String() } - if movie.Rating != nil { - newMovieJSON.Rating = *movie.Rating + if group.Rating != nil { + newGroupJSON.Rating = *group.Rating } - if movie.Duration != nil { - newMovieJSON.Duration = *movie.Duration + if group.Duration != nil { + newGroupJSON.Duration = *group.Duration } - if movie.StudioID != nil { - studio, err := studioReader.Find(ctx, *movie.StudioID) + if group.StudioID != nil { + studio, err := studioReader.Find(ctx, *group.StudioID) if err != nil { return nil, fmt.Errorf("error getting movie studio: %v", err) } if studio != nil { - newMovieJSON.Studio = studio.Name + newGroupJSON.Studio = studio.Name } } - frontImage, err := reader.GetFrontImage(ctx, movie.ID) + frontImage, err := reader.GetFrontImage(ctx, group.ID) if err != nil { logger.Errorf("Error getting movie front image: %v", err) } if len(frontImage) > 0 { - newMovieJSON.FrontImage = utils.GetBase64StringFromData(frontImage) + newGroupJSON.FrontImage = utils.GetBase64StringFromData(frontImage) } - backImage, err := reader.GetBackImage(ctx, movie.ID) + backImage, err := reader.GetBackImage(ctx, group.ID) if err != nil { logger.Errorf("Error getting movie back image: %v", err) } if len(backImage) > 0 { - newMovieJSON.BackImage = utils.GetBase64StringFromData(backImage) + newGroupJSON.BackImage = utils.GetBase64StringFromData(backImage) } - return &newMovieJSON, nil + newGroupJSON.CustomFields, err = reader.GetCustomFields(ctx, group.ID) + if err != nil { + return nil, fmt.Errorf("getting group custom fields: %v", err) + } + + return &newGroupJSON, nil } diff --git a/pkg/group/export_test.go b/pkg/group/export_test.go index 5f8d9f7dc..bff50de5e 100644 --- a/pkg/group/export_test.go +++ b/pkg/group/export_test.go @@ -8,24 +8,26 @@ import ( "github.com/stashapp/stash/pkg/models/jsonschema" "github.com/stashapp/stash/pkg/models/mocks" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "testing" "time" ) const ( - movieID = 1 - emptyID = 2 - errFrontImageID = 3 - errBackImageID = 4 - errStudioMovieID = 5 - missingStudioMovieID = 6 + movieID = iota + 1 + emptyID + errFrontImageID + errBackImageID + errStudioMovieID + missingStudioMovieID + errCustomFieldsID ) const ( - studioID = 1 - missingStudioID = 2 - errStudioID = 3 + studioID = iota + 1 + missingStudioID + errStudioID ) const movieName = "testMovie" @@ -51,6 +53,11 @@ const ( var ( frontImageBytes = []byte("frontImageBytes") backImageBytes = []byte("backImageBytes") + + emptyCustomFields = make(map[string]interface{}) + customFields = map[string]interface{}{ + "customField1": "customValue1", + } ) var movieStudio models.Studio = models.Studio{ @@ -88,7 +95,7 @@ func createEmptyMovie(id int) models.Group { } } -func createFullJSONMovie(studio, frontImage, backImage string) *jsonschema.Group { +func createFullJSONMovie(studio, frontImage, backImage string, customFields map[string]interface{}) *jsonschema.Group { return &jsonschema.Group{ Name: movieName, Aliases: movieAliases, @@ -107,6 +114,7 @@ func createFullJSONMovie(studio, frontImage, backImage string) *jsonschema.Group UpdatedAt: json.JSONTime{ Time: updateTime, }, + CustomFields: customFields, } } @@ -119,13 +127,15 @@ func createEmptyJSONMovie() *jsonschema.Group { UpdatedAt: json.JSONTime{ Time: updateTime, }, + CustomFields: emptyCustomFields, } } type testScenario struct { - movie models.Group - expected *jsonschema.Group - err bool + movie models.Group + customFields map[string]interface{} + expected *jsonschema.Group + err bool } var scenarios []testScenario @@ -134,36 +144,48 @@ func initTestTable() { scenarios = []testScenario{ { createFullMovie(movieID, studioID), - createFullJSONMovie(studioName, frontImage, backImage), + customFields, + createFullJSONMovie(studioName, frontImage, backImage, customFields), false, }, { createEmptyMovie(emptyID), + emptyCustomFields, createEmptyJSONMovie(), false, }, { createFullMovie(errFrontImageID, studioID), - createFullJSONMovie(studioName, "", backImage), + emptyCustomFields, + createFullJSONMovie(studioName, "", backImage, emptyCustomFields), // failure to get front image should not cause error false, }, { createFullMovie(errBackImageID, studioID), - createFullJSONMovie(studioName, frontImage, ""), + emptyCustomFields, + createFullJSONMovie(studioName, frontImage, "", emptyCustomFields), // failure to get back image should not cause error false, }, { createFullMovie(errStudioMovieID, errStudioID), + emptyCustomFields, nil, true, }, { createFullMovie(missingStudioMovieID, missingStudioID), - createFullJSONMovie("", frontImage, backImage), + emptyCustomFields, + createFullJSONMovie("", frontImage, backImage, emptyCustomFields), false, }, + { + createFullMovie(errCustomFieldsID, studioID), + customFields, + nil, + true, + }, } } @@ -179,6 +201,7 @@ func TestToJSON(t *testing.T) { db.Group.On("GetFrontImage", testCtx, emptyID).Return(nil, nil).Once().Maybe() db.Group.On("GetFrontImage", testCtx, errFrontImageID).Return(nil, imageErr).Once() db.Group.On("GetFrontImage", testCtx, errBackImageID).Return(frontImageBytes, nil).Once() + db.Group.On("GetFrontImage", testCtx, errCustomFieldsID).Return(nil, nil).Once() db.Group.On("GetBackImage", testCtx, movieID).Return(backImageBytes, nil).Once() db.Group.On("GetBackImage", testCtx, missingStudioMovieID).Return(backImageBytes, nil).Once() @@ -186,6 +209,11 @@ func TestToJSON(t *testing.T) { db.Group.On("GetBackImage", testCtx, errBackImageID).Return(nil, imageErr).Once() db.Group.On("GetBackImage", testCtx, errFrontImageID).Return(backImageBytes, nil).Maybe() db.Group.On("GetBackImage", testCtx, errStudioMovieID).Return(backImageBytes, nil).Maybe() + db.Group.On("GetBackImage", testCtx, errCustomFieldsID).Return(nil, nil).Once() + + db.Group.On("GetCustomFields", testCtx, movieID).Return(customFields, nil).Once() + db.Group.On("GetCustomFields", testCtx, errCustomFieldsID).Return(nil, errors.New("error getting custom fields")).Once() + db.Group.On("GetCustomFields", testCtx, mock.Anything).Return(emptyCustomFields, nil).Times(4) studioErr := errors.New("error getting studio") diff --git a/pkg/group/import.go b/pkg/group/import.go index d7acad47c..1a332bac2 100644 --- a/pkg/group/import.go +++ b/pkg/group/import.go @@ -14,6 +14,7 @@ import ( type ImporterReaderWriter interface { models.GroupCreatorUpdater + models.CustomFieldsWriter FindByName(ctx context.Context, name string, nocase bool) (*models.Group, error) } @@ -233,6 +234,14 @@ func (i *Importer) PostImport(ctx context.Context, id int) error { } } + if len(i.Input.CustomFields) > 0 { + if err := i.ReaderWriter.SetCustomFields(ctx, id, models.CustomFieldsInput{ + Full: i.Input.CustomFields, + }); err != nil { + return fmt.Errorf("error setting custom fields: %v", err) + } + } + if len(i.frontImageData) > 0 { if err := i.ReaderWriter.UpdateFrontImage(ctx, id, i.frontImageData); err != nil { return fmt.Errorf("error setting group front image: %v", err) diff --git a/pkg/group/import_test.go b/pkg/group/import_test.go index 387ceb87e..006c91327 100644 --- a/pkg/group/import_test.go +++ b/pkg/group/import_test.go @@ -259,17 +259,29 @@ func TestImporterPostImport(t *testing.T) { db := mocks.NewDatabase() i := Importer{ - ReaderWriter: db.Group, - StudioWriter: db.Studio, + ReaderWriter: db.Group, + StudioWriter: db.Studio, + Input: jsonschema.Group{ + CustomFields: customFields, + }, frontImageData: frontImageBytes, backImageData: backImageBytes, } updateMovieImageErr := errors.New("UpdateImages error") + customFieldsErr := errors.New("SetCustomFields error") + + customFieldsInput := models.CustomFieldsInput{ + Full: customFields, + } db.Group.On("UpdateFrontImage", testCtx, movieID, frontImageBytes).Return(nil).Once() - db.Group.On("UpdateBackImage", testCtx, movieID, backImageBytes).Return(nil).Once() db.Group.On("UpdateFrontImage", testCtx, errImageID, frontImageBytes).Return(updateMovieImageErr).Once() + db.Group.On("UpdateBackImage", testCtx, movieID, backImageBytes).Return(nil).Once() + + db.Group.On("SetCustomFields", testCtx, movieID, customFieldsInput).Return(nil).Once() + db.Group.On("SetCustomFields", testCtx, errImageID, customFieldsInput).Return(nil).Once() + db.Group.On("SetCustomFields", testCtx, errCustomFieldsID, customFieldsInput).Return(customFieldsErr).Once() err := i.PostImport(testCtx, movieID) assert.Nil(t, err) @@ -277,6 +289,9 @@ func TestImporterPostImport(t *testing.T) { err = i.PostImport(testCtx, errImageID) assert.NotNil(t, err) + err = i.PostImport(testCtx, errCustomFieldsID) + assert.NotNil(t, err) + db.AssertExpectations(t) } diff --git a/pkg/group/service.go b/pkg/group/service.go index ff6e03541..37094665a 100644 --- a/pkg/group/service.go +++ b/pkg/group/service.go @@ -10,6 +10,7 @@ type CreatorUpdater interface { models.GroupGetter models.GroupCreator models.GroupUpdater + models.CustomFieldsWriter models.ContainingGroupLoader models.SubGroupLoader diff --git a/pkg/models/group.go b/pkg/models/group.go index ec550eea8..396384b51 100644 --- a/pkg/models/group.go +++ b/pkg/models/group.go @@ -43,4 +43,6 @@ type GroupFilterType struct { CreatedAt *TimestampCriterionInput `json:"created_at"` // Filter by updated at UpdatedAt *TimestampCriterionInput `json:"updated_at"` + // Filter by custom fields + CustomFields []CustomFieldCriterionInput `json:"custom_fields"` } diff --git a/pkg/models/jsonschema/group.go b/pkg/models/jsonschema/group.go index b284dab6e..357ac70bc 100644 --- a/pkg/models/jsonschema/group.go +++ b/pkg/models/jsonschema/group.go @@ -33,6 +33,8 @@ type Group struct { CreatedAt json.JSONTime `json:"created_at,omitempty"` UpdatedAt json.JSONTime `json:"updated_at,omitempty"` + CustomFields map[string]interface{} `json:"custom_fields,omitempty"` + // deprecated - for import only URL string `json:"url,omitempty"` } diff --git a/pkg/models/mocks/GroupReaderWriter.go b/pkg/models/mocks/GroupReaderWriter.go index dc745d094..ac9e513f4 100644 --- a/pkg/models/mocks/GroupReaderWriter.go +++ b/pkg/models/mocks/GroupReaderWriter.go @@ -312,6 +312,52 @@ func (_m *GroupReaderWriter) GetContainingGroupDescriptions(ctx context.Context, return r0, r1 } +// GetCustomFields provides a mock function with given fields: ctx, id +func (_m *GroupReaderWriter) GetCustomFields(ctx context.Context, id int) (map[string]interface{}, error) { + ret := _m.Called(ctx, id) + + var r0 map[string]interface{} + if rf, ok := ret.Get(0).(func(context.Context, int) map[string]interface{}); ok { + r0 = rf(ctx, id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]interface{}) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetCustomFieldsBulk provides a mock function with given fields: ctx, ids +func (_m *GroupReaderWriter) GetCustomFieldsBulk(ctx context.Context, ids []int) ([]models.CustomFieldMap, error) { + ret := _m.Called(ctx, ids) + + var r0 []models.CustomFieldMap + if rf, ok := ret.Get(0).(func(context.Context, []int) []models.CustomFieldMap); ok { + r0 = rf(ctx, ids) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]models.CustomFieldMap) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, []int) error); ok { + r1 = rf(ctx, ids) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetFrontImage provides a mock function with given fields: ctx, groupID func (_m *GroupReaderWriter) GetFrontImage(ctx context.Context, groupID int) ([]byte, error) { ret := _m.Called(ctx, groupID) @@ -497,6 +543,20 @@ func (_m *GroupReaderWriter) QueryCount(ctx context.Context, groupFilter *models return r0, r1 } +// SetCustomFields provides a mock function with given fields: ctx, id, fields +func (_m *GroupReaderWriter) SetCustomFields(ctx context.Context, id int, fields models.CustomFieldsInput) error { + ret := _m.Called(ctx, id, fields) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, int, models.CustomFieldsInput) error); ok { + r0 = rf(ctx, id, fields) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // Update provides a mock function with given fields: ctx, updatedGroup func (_m *GroupReaderWriter) Update(ctx context.Context, updatedGroup *models.Group) error { ret := _m.Called(ctx, updatedGroup) diff --git a/pkg/models/model_group.go b/pkg/models/model_group.go index 82c71996a..5bfb42c44 100644 --- a/pkg/models/model_group.go +++ b/pkg/models/model_group.go @@ -34,6 +34,14 @@ func NewGroup() Group { } } +type CreateGroupInput struct { + *Group + + CustomFields map[string]interface{} `json:"custom_fields"` + FrontImageData []byte + BackImageData []byte +} + func (m *Group) LoadURLs(ctx context.Context, l URLLoader) error { return m.URLs.load(func() ([]string, error) { return l.GetURLs(ctx, m.ID) @@ -74,6 +82,8 @@ type GroupPartial struct { SubGroups *UpdateGroupDescriptions CreatedAt OptionalTime UpdatedAt OptionalTime + + CustomFields CustomFieldsInput } func NewGroupPartial() GroupPartial { diff --git a/pkg/models/repository_group.go b/pkg/models/repository_group.go index 704390d77..d7f74de64 100644 --- a/pkg/models/repository_group.go +++ b/pkg/models/repository_group.go @@ -68,6 +68,7 @@ type GroupReader interface { TagIDLoader ContainingGroupLoader SubGroupLoader + CustomFieldsReader All(ctx context.Context) ([]*Group, error) GetFrontImage(ctx context.Context, groupID int) ([]byte, error) @@ -81,6 +82,7 @@ type GroupWriter interface { GroupCreator GroupUpdater GroupDestroyer + CustomFieldsWriter } // GroupReaderWriter provides all group methods. diff --git a/pkg/sqlite/anonymise.go b/pkg/sqlite/anonymise.go index 6a5cd4da5..ace306169 100644 --- a/pkg/sqlite/anonymise.go +++ b/pkg/sqlite/anonymise.go @@ -964,6 +964,10 @@ func (db *Anonymiser) anonymiseGroups(ctx context.Context) error { return err } + if err := db.anonymiseCustomFields(ctx, goqu.T(groupsCustomFieldsTable.GetTable()), "group_id"); err != nil { + return err + } + return nil } diff --git a/pkg/sqlite/custom_fields_test.go b/pkg/sqlite/custom_fields_test.go index ae4a276f7..2f7ecd7dc 100644 --- a/pkg/sqlite/custom_fields_test.go +++ b/pkg/sqlite/custom_fields_test.go @@ -242,7 +242,13 @@ func TestSceneSetCustomFields(t *testing.T) { } func TestGallerySetCustomFields(t *testing.T) { - galleryIdx := galleryIdxWithScene + galleryIdx := galleryIdxWithChapters testSetCustomFields(t, "Gallery", db.Gallery, galleryIDs[galleryIdx], getGalleryCustomFields(galleryIdx)) } + +func TestGroupSetCustomFields(t *testing.T) { + groupIdx := groupIdxWithScene + + testSetCustomFields(t, "Group", db.Group, groupIDs[groupIdx], getGroupCustomFields(groupIdx)) +} diff --git a/pkg/sqlite/database.go b/pkg/sqlite/database.go index d95832836..000b91c4d 100644 --- a/pkg/sqlite/database.go +++ b/pkg/sqlite/database.go @@ -34,7 +34,7 @@ const ( cacheSizeEnv = "STASH_SQLITE_CACHE_SIZE" ) -var appSchemaVersion uint = 81 +var appSchemaVersion uint = 82 //go:embed migrations/*.sql var migrationsBox embed.FS diff --git a/pkg/sqlite/gallery_test.go b/pkg/sqlite/gallery_test.go index 4156f129c..9bd0da47f 100644 --- a/pkg/sqlite/gallery_test.go +++ b/pkg/sqlite/gallery_test.go @@ -831,6 +831,79 @@ func Test_galleryQueryBuilder_UpdatePartialRelationships(t *testing.T) { } } +func Test_GalleryStore_UpdatePartialCustomFields(t *testing.T) { + tests := []struct { + name string + id int + partial models.GalleryPartial + expected map[string]interface{} // nil to use the partial + }{ + { + "set custom fields", + galleryIDs[galleryIdx1WithImage], + models.GalleryPartial{ + CustomFields: models.CustomFieldsInput{ + Full: testCustomFields, + }, + }, + nil, + }, + { + "clear custom fields", + galleryIDs[galleryIdx1WithImage], + models.GalleryPartial{ + CustomFields: models.CustomFieldsInput{ + Full: map[string]interface{}{}, + }, + }, + nil, + }, + { + "partial custom fields", + galleryIDs[galleryIdxWithTwoTags], + models.GalleryPartial{ + CustomFields: models.CustomFieldsInput{ + Partial: map[string]interface{}{ + "string": "bbb", + "new_field": "new", + }, + }, + }, + map[string]interface{}{ + "int": int64(2), + "real": 1.2, + "string": "bbb", + "new_field": "new", + }, + }, + } + for _, tt := range tests { + qb := db.Gallery + + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + assert := assert.New(t) + + _, err := qb.UpdatePartial(ctx, tt.id, tt.partial) + if err != nil { + t.Errorf("GalleryStore.UpdatePartial() error = %v", err) + return + } + + // ensure custom fields are correct + cf, err := qb.GetCustomFields(ctx, tt.id) + if err != nil { + t.Errorf("GalleryStore.GetCustomFields() error = %v", err) + return + } + if tt.expected == nil { + assert.Equal(tt.partial.CustomFields.Full, cf) + } else { + assert.Equal(tt.expected, cf) + } + }) + } +} + func Test_galleryQueryBuilder_Destroy(t *testing.T) { tests := []struct { name string diff --git a/pkg/sqlite/group.go b/pkg/sqlite/group.go index b216335b8..13a6905a5 100644 --- a/pkg/sqlite/group.go +++ b/pkg/sqlite/group.go @@ -131,6 +131,7 @@ var ( type GroupStore struct { blobJoinQueryBuilder + customFieldsStore tagRelationshipStore groupRelationshipStore @@ -143,6 +144,10 @@ func NewGroupStore(blobStore *BlobStore) *GroupStore { blobStore: blobStore, joinTable: groupTable, }, + customFieldsStore: customFieldsStore{ + table: groupsCustomFieldsTable, + fk: groupsCustomFieldsTable.Col(groupIDColumn), + }, tagRelationshipStore: tagRelationshipStore{ idRelationshipStore: idRelationshipStore{ joinTable: groupsTagsTableMgr, @@ -235,6 +240,10 @@ func (qb *GroupStore) UpdatePartial(ctx context.Context, id int, partial models. return nil, err } + if err := qb.SetCustomFields(ctx, id, partial.CustomFields); err != nil { + return nil, err + } + return qb.find(ctx, id) } diff --git a/pkg/sqlite/group_filter.go b/pkg/sqlite/group_filter.go index f81783374..4f3f7b41a 100644 --- a/pkg/sqlite/group_filter.go +++ b/pkg/sqlite/group_filter.go @@ -84,6 +84,13 @@ func (qb *groupFilterHandler) criterionHandler() criterionHandler { ×tampCriterionHandler{groupFilter.CreatedAt, "groups.created_at", nil}, ×tampCriterionHandler{groupFilter.UpdatedAt, "groups.updated_at", nil}, + &customFieldsFilterHandler{ + table: groupsCustomFieldsTable.GetTable(), + fkCol: groupIDColumn, + c: groupFilter.CustomFields, + idCol: "groups.id", + }, + &relatedFilterHandler{ relatedIDCol: "groups_scenes.scene_id", relatedRepo: sceneRepository.repository, diff --git a/pkg/sqlite/group_test.go b/pkg/sqlite/group_test.go index db293dd92..22b551e02 100644 --- a/pkg/sqlite/group_test.go +++ b/pkg/sqlite/group_test.go @@ -566,6 +566,79 @@ func Test_groupQueryBuilder_UpdatePartial(t *testing.T) { } } +func Test_GroupStore_UpdatePartialCustomFields(t *testing.T) { + tests := []struct { + name string + id int + partial models.GroupPartial + expected map[string]interface{} // nil to use the partial + }{ + { + "set custom fields", + groupIDs[groupIdxWithChild], + models.GroupPartial{ + CustomFields: models.CustomFieldsInput{ + Full: testCustomFields, + }, + }, + nil, + }, + { + "clear custom fields", + groupIDs[groupIdxWithChild], + models.GroupPartial{ + CustomFields: models.CustomFieldsInput{ + Full: map[string]interface{}{}, + }, + }, + nil, + }, + { + "partial custom fields", + groupIDs[groupIdxWithTwoTags], + models.GroupPartial{ + CustomFields: models.CustomFieldsInput{ + Partial: map[string]interface{}{ + "string": "bbb", + "new_field": "new", + }, + }, + }, + map[string]interface{}{ + "int": int64(3), + "real": 0.3, + "string": "bbb", + "new_field": "new", + }, + }, + } + for _, tt := range tests { + qb := db.Group + + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + assert := assert.New(t) + + _, err := qb.UpdatePartial(ctx, tt.id, tt.partial) + if err != nil { + t.Errorf("GroupStore.UpdatePartial() error = %v", err) + return + } + + // ensure custom fields are correct + cf, err := qb.GetCustomFields(ctx, tt.id) + if err != nil { + t.Errorf("GroupStore.GetCustomFields() error = %v", err) + return + } + if tt.expected == nil { + assert.Equal(tt.partial.CustomFields.Full, cf) + } else { + assert.Equal(tt.expected, cf) + } + }) + } +} + func TestGroupFindByName(t *testing.T) { withTxn(func(ctx context.Context) error { mqb := db.Group @@ -1917,6 +1990,245 @@ func TestGroupFindSubGroupIDs(t *testing.T) { } } +func TestGroupQueryCustomFields(t *testing.T) { + tests := []struct { + name string + filter *models.GroupFilterType + includeIdxs []int + excludeIdxs []int + wantErr bool + }{ + { + "equals", + &models.GroupFilterType{ + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "string", + Modifier: models.CriterionModifierEquals, + Value: []any{getGroupStringValue(groupIdxWithChild, "custom")}, + }, + }, + }, + []int{groupIdxWithChild}, + nil, + false, + }, + { + "not equals", + &models.GroupFilterType{ + Name: &models.StringCriterionInput{ + Value: getGroupStringValue(groupIdxWithChild, "Name"), + Modifier: models.CriterionModifierEquals, + }, + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "string", + Modifier: models.CriterionModifierNotEquals, + Value: []any{getGroupStringValue(groupIdxWithChild, "custom")}, + }, + }, + }, + nil, + []int{groupIdxWithChild}, + false, + }, + { + "includes", + &models.GroupFilterType{ + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "string", + Modifier: models.CriterionModifierIncludes, + Value: []any{getGroupStringValue(groupIdxWithChild, "custom")[9:]}, + }, + }, + }, + []int{groupIdxWithChild}, + nil, + false, + }, + { + "excludes", + &models.GroupFilterType{ + Name: &models.StringCriterionInput{ + Value: getGroupStringValue(groupIdxWithChild, "Name"), + Modifier: models.CriterionModifierEquals, + }, + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "string", + Modifier: models.CriterionModifierExcludes, + Value: []any{getGroupStringValue(groupIdxWithChild, "custom")[9:]}, + }, + }, + }, + nil, + []int{groupIdxWithChild}, + false, + }, + { + "regex", + &models.GroupFilterType{ + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "string", + Modifier: models.CriterionModifierMatchesRegex, + Value: []any{".*11_custom"}, + }, + }, + }, + []int{groupIdxWithChildWithScene}, + nil, + false, + }, + { + "invalid regex", + &models.GroupFilterType{ + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "string", + Modifier: models.CriterionModifierMatchesRegex, + Value: []any{"["}, + }, + }, + }, + nil, + nil, + true, + }, + { + "not matches regex", + &models.GroupFilterType{ + Name: &models.StringCriterionInput{ + Value: getGroupStringValue(groupIdxWithChildWithScene, "Name"), + Modifier: models.CriterionModifierEquals, + }, + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "string", + Modifier: models.CriterionModifierNotMatchesRegex, + Value: []any{".*11_custom"}, + }, + }, + }, + nil, + []int{groupIdxWithChildWithScene}, + false, + }, + { + "invalid not matches regex", + &models.GroupFilterType{ + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "string", + Modifier: models.CriterionModifierNotMatchesRegex, + Value: []any{"["}, + }, + }, + }, + nil, + nil, + true, + }, + { + "null", + &models.GroupFilterType{ + Name: &models.StringCriterionInput{ + Value: getGroupStringValue(groupIdxWithGrandParent, "Name"), + Modifier: models.CriterionModifierEquals, + }, + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "not existing", + Modifier: models.CriterionModifierIsNull, + }, + }, + }, + []int{groupIdxWithGrandParent}, + nil, + false, + }, + { + "not null", + &models.GroupFilterType{ + Name: &models.StringCriterionInput{ + Value: getGroupStringValue(groupIdxWithGrandParent, "Name"), + Modifier: models.CriterionModifierEquals, + }, + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "string", + Modifier: models.CriterionModifierNotNull, + }, + }, + }, + []int{groupIdxWithGrandParent}, + nil, + false, + }, + { + "between", + &models.GroupFilterType{ + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "real", + Modifier: models.CriterionModifierBetween, + Value: []any{0.15, 0.25}, + }, + }, + }, + []int{groupIdxWithTag}, + nil, + false, + }, + { + "not between", + &models.GroupFilterType{ + Name: &models.StringCriterionInput{ + Value: getGroupStringValue(groupIdxWithTag, "Name"), + Modifier: models.CriterionModifierEquals, + }, + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "real", + Modifier: models.CriterionModifierNotBetween, + Value: []any{0.15, 0.25}, + }, + }, + }, + nil, + []int{groupIdxWithTag}, + false, + }, + } + + for _, tt := range tests { + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + assert := assert.New(t) + + groups, _, err := db.Group.Query(ctx, tt.filter, nil) + if (err != nil) != tt.wantErr { + t.Errorf("GroupStore.Query() error = %v, wantErr %v", err, tt.wantErr) + } + + if err != nil { + return + } + + ids := groupsToIDs(groups) + include := indexesToIDs(groupIDs, tt.includeIdxs) + exclude := indexesToIDs(groupIDs, tt.excludeIdxs) + + for _, i := range include { + assert.Contains(ids, i) + } + for _, e := range exclude { + assert.NotContains(ids, e) + } + }) + } +} + // TODO Update // TODO Destroy - ensure image is destroyed // TODO Find diff --git a/pkg/sqlite/migrations/82_group_custom_fields.up.sql b/pkg/sqlite/migrations/82_group_custom_fields.up.sql new file mode 100644 index 000000000..c1f287fec --- /dev/null +++ b/pkg/sqlite/migrations/82_group_custom_fields.up.sql @@ -0,0 +1,9 @@ +CREATE TABLE `group_custom_fields` ( + `group_id` integer NOT NULL, + `field` varchar(64) NOT NULL, + `value` BLOB NOT NULL, + PRIMARY KEY (`group_id`, `field`), + foreign key(`group_id`) references `groups`(`id`) on delete CASCADE +); + +CREATE INDEX `index_group_custom_fields_field_value` ON `group_custom_fields` (`field`, `value`); diff --git a/pkg/sqlite/setup_test.go b/pkg/sqlite/setup_test.go index 675b3f417..0078f1a67 100644 --- a/pkg/sqlite/setup_test.go +++ b/pkg/sqlite/setup_test.go @@ -1457,6 +1457,18 @@ func getGroupEmptyString(index int, field string) string { return v.String } +func getGroupCustomFields(index int) map[string]interface{} { + if index%5 == 0 { + return nil + } + + return map[string]interface{}{ + "string": getGroupStringValue(index, "custom"), + "int": int64(index % 5), + "real": float64(index) / 10, + } +} + // createGroups creates n groups with plain Name and o groups with camel cased NaMe included func createGroups(ctx context.Context, mqb models.GroupReaderWriter, n int, o int) error { const namePlain = "Name" @@ -1489,6 +1501,13 @@ func createGroups(ctx context.Context, mqb models.GroupReaderWriter, n int, o in return fmt.Errorf("Error creating group [%d] %v+: %s", i, group, err.Error()) } + customFields := getGroupCustomFields(i) + if customFields != nil { + if err := mqb.SetCustomFields(ctx, group.ID, models.CustomFieldsInput{Full: customFields}); err != nil { + return fmt.Errorf("Error setting custom fields for group %d: %s", group.ID, err.Error()) + } + } + groupIDs = append(groupIDs, group.ID) groupNames = append(groupNames, group.Name) } diff --git a/pkg/sqlite/tables.go b/pkg/sqlite/tables.go index 7867054ba..6c898048d 100644 --- a/pkg/sqlite/tables.go +++ b/pkg/sqlite/tables.go @@ -47,6 +47,7 @@ var ( groupsURLsJoinTable = goqu.T(groupURLsTable) groupsTagsJoinTable = goqu.T(groupsTagsTable) groupRelationsJoinTable = goqu.T(groupRelationsTable) + groupsCustomFieldsTable = goqu.T("group_custom_fields") tagsAliasesJoinTable = goqu.T(tagAliasesTable) tagRelationsJoinTable = goqu.T(tagRelationsTable)