diff --git a/internal/api/resolver_mutation_performer.go b/internal/api/resolver_mutation_performer.go index 858739e6b..82a7df118 100644 --- a/internal/api/resolver_mutation_performer.go +++ b/internal/api/resolver_mutation_performer.go @@ -154,7 +154,7 @@ func validateNoLegacyURLs(translator changesetTranslator) error { return nil } -func (r *mutationResolver) handleLegacyURLs(ctx context.Context, performerID int, legacyURLs *LegacyURLs, updatedPerformer *models.PerformerPartial) error { +func (r *mutationResolver) handleLegacyURLs(ctx context.Context, performerID int, legacyURLs legacyPerformerURLs, updatedPerformer *models.PerformerPartial) error { qb := r.repository.Performer // we need to be careful with URL/Twitter/Instagram @@ -229,17 +229,25 @@ func (r *mutationResolver) handleLegacyURLs(ctx context.Context, performerID int return nil } -type LegacyURLs struct { +type legacyPerformerURLs struct { URL models.OptionalString Twitter models.OptionalString Instagram models.OptionalString } -func (u *LegacyURLs) AnySet() bool { +func (u *legacyPerformerURLs) AnySet() bool { return u.URL.Set || u.Twitter.Set || u.Instagram.Set } -func performerPartialFromInput(input models.PerformerUpdateInput, translator changesetTranslator) (*models.PerformerPartial, *LegacyURLs, error) { +func legacyPerformerURLsFromInput(input models.PerformerUpdateInput, translator changesetTranslator) legacyPerformerURLs { + return legacyPerformerURLs{ + URL: translator.optionalString(input.URL, "url"), + Twitter: translator.optionalString(input.Twitter, "twitter"), + Instagram: translator.optionalString(input.Instagram, "instagram"), + } +} + +func performerPartialFromInput(input models.PerformerUpdateInput, translator changesetTranslator) (*models.PerformerPartial, error) { // Populate performer from the input updatedPerformer := models.NewPerformerPartial() @@ -269,25 +277,19 @@ func performerPartialFromInput(input models.PerformerUpdateInput, translator cha if translator.hasField("urls") { // ensure url/twitter/instagram are not included in the input if err := validateNoLegacyURLs(translator); err != nil { - return nil, nil, err + return nil, err } updatedPerformer.URLs = translator.updateStrings(input.Urls, "urls") } - var legacyURLs = LegacyURLs{ - URL: translator.optionalString(input.URL, "url"), - Twitter: translator.optionalString(input.Twitter, "twitter"), - Instagram: translator.optionalString(input.Instagram, "instagram"), - } - updatedPerformer.Birthdate, err = translator.optionalDate(input.Birthdate, "birthdate") if err != nil { - return nil, nil, fmt.Errorf("converting birthdate: %w", err) + return nil, fmt.Errorf("converting birthdate: %w", err) } updatedPerformer.DeathDate, err = translator.optionalDate(input.DeathDate, "death_date") if err != nil { - return nil, nil, fmt.Errorf("converting death date: %w", err) + return nil, fmt.Errorf("converting death date: %w", err) } // prefer height_cm over height @@ -302,7 +304,7 @@ func performerPartialFromInput(input models.PerformerUpdateInput, translator cha updatedPerformer.TagIDs, err = translator.updateIds(input.TagIds, "tag_ids") if err != nil { - return nil, nil, fmt.Errorf("converting tag ids: %w", err) + return nil, fmt.Errorf("converting tag ids: %w", err) } updatedPerformer.CustomFields = input.CustomFields @@ -310,7 +312,7 @@ func performerPartialFromInput(input models.PerformerUpdateInput, translator cha updatedPerformer.CustomFields.Full = convertMapJSONNumbers(updatedPerformer.CustomFields.Full) updatedPerformer.CustomFields.Partial = convertMapJSONNumbers(updatedPerformer.CustomFields.Partial) - return &updatedPerformer, &legacyURLs, nil + return &updatedPerformer, nil } func (r *mutationResolver) PerformerUpdate(ctx context.Context, input models.PerformerUpdateInput) (*models.Performer, error) { @@ -323,11 +325,13 @@ func (r *mutationResolver) PerformerUpdate(ctx context.Context, input models.Per inputMap: getUpdateInputMap(ctx), } - updatedPerformer, legacyURLs, err := performerPartialFromInput(input, translator) + updatedPerformer, err := performerPartialFromInput(input, translator) if err != nil { return nil, err } + legacyURLs := legacyPerformerURLsFromInput(input, translator) + var imageData []byte imageIncluded := translator.hasField("image") if input.Image != nil { @@ -347,11 +351,11 @@ func (r *mutationResolver) PerformerUpdate(ctx context.Context, input models.Per } } - if err := performer.ValidateUpdate(ctx, performerID, updatedPerformer, qb); err != nil { + if err := performer.ValidateUpdate(ctx, performerID, *updatedPerformer, qb); err != nil { return err } - _, err = qb.UpdatePartial(ctx, performerID, updatedPerformer) + _, err = qb.UpdatePartial(ctx, performerID, *updatedPerformer) if err != nil { return err } @@ -415,7 +419,7 @@ func (r *mutationResolver) BulkPerformerUpdate(ctx context.Context, input BulkPe updatedPerformer.URLs = translator.updateStringsBulk(input.Urls, "urls") } - var legacyURLs = LegacyURLs{ + legacyURLs := legacyPerformerURLs{ URL: translator.optionalString(input.URL, "url"), Twitter: translator.optionalString(input.Twitter, "twitter"), Instagram: translator.optionalString(input.Instagram, "instagram"), @@ -453,16 +457,16 @@ func (r *mutationResolver) BulkPerformerUpdate(ctx context.Context, input BulkPe for _, performerID := range performerIDs { if legacyURLs.AnySet() { - if err := r.handleLegacyURLs(ctx, performerID, &legacyURLs, &updatedPerformer); err != nil { + if err := r.handleLegacyURLs(ctx, performerID, legacyURLs, &updatedPerformer); err != nil { return err } } - if err := performer.ValidateUpdate(ctx, performerID, &updatedPerformer, qb); err != nil { + if err := performer.ValidateUpdate(ctx, performerID, updatedPerformer, qb); err != nil { return err } - performer, err := qb.UpdatePartial(ctx, performerID, &updatedPerformer) + performer, err := qb.UpdatePartial(ctx, performerID, updatedPerformer) if err != nil { return err } @@ -555,18 +559,18 @@ func (r *mutationResolver) PerformerMerge(ctx context.Context, input PerformerMe var values *models.PerformerPartial var imageData []byte - var legacyURLs *LegacyURLs if input.Values != nil { translator := changesetTranslator{ inputMap: getNamedUpdateInputMap(ctx, "input.values"), } - values, legacyURLs, err = performerPartialFromInput(*input.Values, translator) + values, err = performerPartialFromInput(*input.Values, translator) if err != nil { return nil, err } - if legacyURLs != nil && legacyURLs.AnySet() { + legacyURLs := legacyPerformerURLsFromInput(*input.Values, translator) + if legacyURLs.AnySet() { return nil, errors.New("Merging legacy performer URLs is not supported") } @@ -588,7 +592,7 @@ func (r *mutationResolver) PerformerMerge(ctx context.Context, input PerformerMe dest, err = qb.Find(ctx, destID) if err != nil { - return fmt.Errorf("finding destination scene ID %d: %w", destID, err) + return fmt.Errorf("finding destination performer ID %d: %w", destID, err) } sources, err := qb.FindMany(ctx, srcIDs) @@ -602,7 +606,7 @@ func (r *mutationResolver) PerformerMerge(ctx context.Context, input PerformerMe } } - if _, err := qb.UpdatePartial(ctx, destID, values); err != nil { + if _, err := qb.UpdatePartial(ctx, destID, *values); err != nil { return fmt.Errorf("updating performer: %w", err) } diff --git a/internal/api/resolver_mutation_studio.go b/internal/api/resolver_mutation_studio.go index 4b3316111..da3aa1983 100644 --- a/internal/api/resolver_mutation_studio.go +++ b/internal/api/resolver_mutation_studio.go @@ -134,7 +134,7 @@ func (r *mutationResolver) StudioUpdate(ctx context.Context, input models.Studio if translator.hasField("urls") { // ensure url not included in the input - if err := r.validateNoLegacyURLs(translator); err != nil { + if err := validateNoLegacyURLs(translator); err != nil { return nil, err } @@ -211,7 +211,7 @@ func (r *mutationResolver) BulkStudioUpdate(ctx context.Context, input BulkStudi if translator.hasField("urls") { // ensure url/twitter/instagram are not included in the input - if err := r.validateNoLegacyURLs(translator); err != nil { + if err := validateNoLegacyURLs(translator); err != nil { return nil, err } diff --git a/internal/manager/task_stash_box_tag.go b/internal/manager/task_stash_box_tag.go index 28fee1c5b..d7d987a6d 100644 --- a/internal/manager/task_stash_box_tag.go +++ b/internal/manager/task_stash_box_tag.go @@ -186,11 +186,11 @@ func (t *stashBoxBatchPerformerTagTask) processMatchedPerformer(ctx context.Cont } } - if err := performer.ValidateUpdate(ctx, t.performer.ID, &partial, qb); err != nil { + if err := performer.ValidateUpdate(ctx, t.performer.ID, partial, qb); err != nil { return err } - if _, err := qb.UpdatePartial(ctx, t.performer.ID, &partial); err != nil { + if _, err := qb.UpdatePartial(ctx, t.performer.ID, partial); err != nil { return err } diff --git a/pkg/models/mocks/PerformerReaderWriter.go b/pkg/models/mocks/PerformerReaderWriter.go index a73011330..6487bc5a5 100644 --- a/pkg/models/mocks/PerformerReaderWriter.go +++ b/pkg/models/mocks/PerformerReaderWriter.go @@ -590,11 +590,11 @@ func (_m *PerformerReaderWriter) UpdateImage(ctx context.Context, performerID in } // UpdatePartial provides a mock function with given fields: ctx, id, updatedPerformer -func (_m *PerformerReaderWriter) UpdatePartial(ctx context.Context, id int, updatedPerformer *models.PerformerPartial) (*models.Performer, error) { +func (_m *PerformerReaderWriter) UpdatePartial(ctx context.Context, id int, updatedPerformer models.PerformerPartial) (*models.Performer, error) { ret := _m.Called(ctx, id, updatedPerformer) var r0 *models.Performer - if rf, ok := ret.Get(0).(func(context.Context, int, *models.PerformerPartial) *models.Performer); ok { + if rf, ok := ret.Get(0).(func(context.Context, int, models.PerformerPartial) *models.Performer); ok { r0 = rf(ctx, id, updatedPerformer) } else { if ret.Get(0) != nil { @@ -603,7 +603,7 @@ func (_m *PerformerReaderWriter) UpdatePartial(ctx context.Context, id int, upda } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, int, *models.PerformerPartial) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, int, models.PerformerPartial) error); ok { r1 = rf(ctx, id, updatedPerformer) } else { r1 = ret.Error(1) diff --git a/pkg/models/repository_performer.go b/pkg/models/repository_performer.go index cf769f8bc..175208c9d 100644 --- a/pkg/models/repository_performer.go +++ b/pkg/models/repository_performer.go @@ -49,7 +49,7 @@ type PerformerCreator interface { // PerformerUpdater provides methods to update performers. type PerformerUpdater interface { Update(ctx context.Context, updatedPerformer *UpdatePerformerInput) error - UpdatePartial(ctx context.Context, id int, updatedPerformer *PerformerPartial) (*Performer, error) + UpdatePartial(ctx context.Context, id int, updatedPerformer PerformerPartial) (*Performer, error) UpdateImage(ctx context.Context, performerID int, image []byte) error } diff --git a/pkg/performer/validate.go b/pkg/performer/validate.go index 89dd290f9..68f7a8ef5 100644 --- a/pkg/performer/validate.go +++ b/pkg/performer/validate.go @@ -66,7 +66,7 @@ func ValidateCreate(ctx context.Context, performer models.Performer, qb models.P return nil } -func ValidateUpdate(ctx context.Context, id int, partial *models.PerformerPartial, qb models.PerformerReader) error { +func ValidateUpdate(ctx context.Context, id int, partial models.PerformerPartial, qb models.PerformerReader) error { existing, err := qb.Find(ctx, id) if err != nil { return err diff --git a/pkg/sqlite/performer.go b/pkg/sqlite/performer.go index 74808f424..ab46426e4 100644 --- a/pkg/sqlite/performer.go +++ b/pkg/sqlite/performer.go @@ -138,7 +138,7 @@ type performerRowRecord struct { updateRecord } -func (r *performerRowRecord) fromPartial(o *models.PerformerPartial) { +func (r *performerRowRecord) fromPartial(o models.PerformerPartial) { r.setString("name", o.Name) r.setNullString("disambiguation", o.Disambiguation) r.setNullString("gender", o.Gender) @@ -302,7 +302,7 @@ func (qb *PerformerStore) Create(ctx context.Context, newObject *models.CreatePe return nil } -func (qb *PerformerStore) UpdatePartial(ctx context.Context, id int, partial *models.PerformerPartial) (*models.Performer, error) { +func (qb *PerformerStore) UpdatePartial(ctx context.Context, id int, partial models.PerformerPartial) (*models.Performer, error) { r := performerRowRecord{ updateRecord{ Record: make(exp.Record), @@ -916,6 +916,8 @@ func (qb *PerformerStore) Merge(ctx context.Context, source []int, destination i } args = append(args, destination) + + // for each table, update source performer ids to destination performer id, ignoring duplicates for table, idColumn := range performerTables { _, err := dbWrapper.Exec(ctx, `UPDATE OR IGNORE `+table+` SET performer_id = ? diff --git a/pkg/sqlite/performer_test.go b/pkg/sqlite/performer_test.go index 0bf0aaeef..d5d8ce2fa 100644 --- a/pkg/sqlite/performer_test.go +++ b/pkg/sqlite/performer_test.go @@ -611,7 +611,7 @@ func Test_PerformerStore_UpdatePartial(t *testing.T) { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { assert := assert.New(t) - got, err := qb.UpdatePartial(ctx, tt.id, &tt.partial) + got, err := qb.UpdatePartial(ctx, tt.id, tt.partial) if (err != nil) != tt.wantErr { t.Errorf("PerformerStore.UpdatePartial() error = %v, wantErr %v", err, tt.wantErr) return @@ -696,7 +696,7 @@ func Test_PerformerStore_UpdatePartialCustomFields(t *testing.T) { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { assert := assert.New(t) - _, err := qb.UpdatePartial(ctx, tt.id, &tt.partial) + _, err := qb.UpdatePartial(ctx, tt.id, tt.partial) if err != nil { t.Errorf("PerformerStore.UpdatePartial() error = %v", err) return @@ -2092,7 +2092,7 @@ func testPerformerStashIDs(ctx context.Context, t *testing.T, s *models.Performe // update stash ids and ensure was updated var err error - s, err = qb.UpdatePartial(ctx, s.ID, &models.PerformerPartial{ + s, err = qb.UpdatePartial(ctx, s.ID, models.PerformerPartial{ StashIDs: &models.UpdateStashIDs{ StashIDs: []models.StashID{stashID}, Mode: models.RelationshipUpdateModeSet, @@ -2110,7 +2110,7 @@ func testPerformerStashIDs(ctx context.Context, t *testing.T, s *models.Performe assert.Equal(t, []models.StashID{stashID}, s.StashIDs.List()) // remove stash ids and ensure was updated - s, err = qb.UpdatePartial(ctx, s.ID, &models.PerformerPartial{ + s, err = qb.UpdatePartial(ctx, s.ID, models.PerformerPartial{ StashIDs: &models.UpdateStashIDs{ StashIDs: []models.StashID{stashID}, Mode: models.RelationshipUpdateModeRemove, diff --git a/pkg/sqlite/tag.go b/pkg/sqlite/tag.go index 977ac0433..dd730c62c 100644 --- a/pkg/sqlite/tag.go +++ b/pkg/sqlite/tag.go @@ -859,6 +859,8 @@ func (qb *TagStore) Merge(ctx context.Context, source []int, destination int) er } args = append(args, destination) + + // for each table, update source tag ids to destination tag id, ignoring duplicates for table, idColumn := range tagTables { _, err := dbWrapper.Exec(ctx, `UPDATE OR IGNORE `+table+` SET tag_id = ?