diff --git a/internal/api/resolver_mutation_studio.go b/internal/api/resolver_mutation_studio.go index c809e3e0f..f78e866b1 100644 --- a/internal/api/resolver_mutation_studio.go +++ b/internal/api/resolver_mutation_studio.go @@ -15,12 +15,39 @@ import ( // used to refetch studio after hooks run -func setChildStudios(ctx context.Context, qb models.StudioReaderWriter, parentStudioID int, childStudioIDs []int) error { +func clearRemovedChildStudios(ctx context.Context, qb models.StudioReaderWriter, parentStudioID int, childStudioIDs []int) error { currentChildren, err := qb.FindChildren(ctx, parentStudioID) if err != nil { return err } + newChildStudioIDs := make(map[int]struct{}, len(childStudioIDs)) + for _, childStudioID := range childStudioIDs { + newChildStudioIDs[childStudioID] = struct{}{} + } + + for _, currentChild := range currentChildren { + if _, keep := newChildStudioIDs[currentChild.ID]; keep { + continue + } + + clearParentPartial := models.NewStudioPartial() + clearParentPartial.ID = currentChild.ID + clearParentPartial.ParentID = models.NewOptionalIntPtr(nil) + + if _, err := qb.UpdatePartial(ctx, clearParentPartial); err != nil { + return err + } + } + + return nil +} + +func setChildStudios(ctx context.Context, qb models.StudioReaderWriter, parentStudioID int, childStudioIDs []int) error { + if err := clearRemovedChildStudios(ctx, qb, parentStudioID, childStudioIDs); err != nil { + return err + } + newChildStudioIDs := make(map[int]struct{}, len(childStudioIDs)) for _, childStudioID := range childStudioIDs { if _, found := newChildStudioIDs[childStudioID]; found { @@ -41,20 +68,6 @@ func setChildStudios(ctx context.Context, qb models.StudioReaderWriter, parentSt } } - for _, currentChild := range currentChildren { - if _, keep := newChildStudioIDs[currentChild.ID]; keep { - continue - } - - clearParentPartial := models.NewStudioPartial() - clearParentPartial.ID = currentChild.ID - clearParentPartial.ParentID = models.NewOptionalIntPtr(nil) - - if _, err := qb.UpdatePartial(ctx, clearParentPartial); err != nil { - return err - } - } - return nil } @@ -257,6 +270,12 @@ func (r *mutationResolver) StudioUpdate(ctx context.Context, input models.Studio } } + if translator.hasField("child_ids") { + if err := clearRemovedChildStudios(ctx, qb, studioID, childStudioIDs); err != nil { + return err + } + } + if err := studio.ValidateModify(ctx, updatedStudio, qb); err != nil { return err }