diff --git a/internal/api/resolver_mutation_tag_test.go b/internal/api/resolver_mutation_tag_test.go index bfd2781c3..cc0bd79a7 100644 --- a/internal/api/resolver_mutation_tag_test.go +++ b/internal/api/resolver_mutation_tag_test.go @@ -72,17 +72,17 @@ func TestTagCreate(t *testing.T) { } } - tagRW.On("Query", testCtx, tagFilterForName(existingTagName), findFilter).Return([]*models.Tag{ + tagRW.On("Query", mock.Anything, tagFilterForName(existingTagName), findFilter).Return([]*models.Tag{ { ID: existingTagID, Name: existingTagName, }, }, 1, nil).Once() - tagRW.On("Query", testCtx, tagFilterForName(errTagName), findFilter).Return(nil, 0, nil).Once() - tagRW.On("Query", testCtx, tagFilterForAlias(errTagName), findFilter).Return(nil, 0, nil).Once() + tagRW.On("Query", mock.Anything, tagFilterForName(errTagName), findFilter).Return(nil, 0, nil).Once() + tagRW.On("Query", mock.Anything, tagFilterForAlias(errTagName), findFilter).Return(nil, 0, nil).Once() expectedErr := errors.New("TagCreate error") - tagRW.On("Create", testCtx, mock.AnythingOfType("models.Tag")).Return(nil, expectedErr) + tagRW.On("Create", mock.Anything, mock.AnythingOfType("models.Tag")).Return(nil, expectedErr) _, err := r.Mutation().TagCreate(testCtx, TagCreateInput{ Name: existingTagName, @@ -100,14 +100,14 @@ func TestTagCreate(t *testing.T) { r = newResolver() tagRW = r.repository.Tag.(*mocks.TagReaderWriter) - tagRW.On("Query", testCtx, tagFilterForName(tagName), findFilter).Return(nil, 0, nil).Once() - tagRW.On("Query", testCtx, tagFilterForAlias(tagName), findFilter).Return(nil, 0, nil).Once() + tagRW.On("Query", mock.Anything, tagFilterForName(tagName), findFilter).Return(nil, 0, nil).Once() + tagRW.On("Query", mock.Anything, tagFilterForAlias(tagName), findFilter).Return(nil, 0, nil).Once() newTag := &models.Tag{ ID: newTagID, Name: tagName, } - tagRW.On("Create", testCtx, mock.AnythingOfType("models.Tag")).Return(newTag, nil) - tagRW.On("Find", testCtx, newTagID).Return(newTag, nil) + tagRW.On("Create", mock.Anything, mock.AnythingOfType("models.Tag")).Return(newTag, nil) + tagRW.On("Find", mock.Anything, newTagID).Return(newTag, nil) tag, err := r.Mutation().TagCreate(testCtx, TagCreateInput{ Name: tagName, diff --git a/internal/identify/identify_test.go b/internal/identify/identify_test.go index a9395265d..751f9bf4c 100644 --- a/internal/identify/identify_test.go +++ b/internal/identify/identify_test.go @@ -74,10 +74,10 @@ func TestSceneIdentifier_Identify(t *testing.T) { mockSceneReaderWriter := &mocks.SceneReaderWriter{} - mockSceneReaderWriter.On("UpdatePartial", testCtx, mock.MatchedBy(func(id int) bool { + mockSceneReaderWriter.On("UpdatePartial", mock.Anything, mock.MatchedBy(func(id int) bool { return id == errUpdateID }), mock.Anything).Return(nil, errors.New("update error")) - mockSceneReaderWriter.On("UpdatePartial", testCtx, mock.MatchedBy(func(id int) bool { + mockSceneReaderWriter.On("UpdatePartial", mock.Anything, mock.MatchedBy(func(id int) bool { return id != errUpdateID }), mock.Anything).Return(nil, nil) diff --git a/internal/manager/task_clean.go b/internal/manager/task_clean.go index 61076b15d..541ca3548 100644 --- a/internal/manager/task_clean.go +++ b/internal/manager/task_clean.go @@ -221,7 +221,7 @@ func (h *cleanHandler) handleRelatedScenes(ctx context.Context, fileDeleter *fil checksum := scene.Checksum oshash := scene.OSHash - mgr.PluginCache.RegisterPostHooks(ctx, mgr.Database, scene.ID, plugin.SceneDestroyPost, plugin.SceneDestroyInput{ + mgr.PluginCache.RegisterPostHooks(ctx, scene.ID, plugin.SceneDestroyPost, plugin.SceneDestroyInput{ Checksum: checksum, OSHash: oshash, Path: scene.Path, @@ -267,7 +267,7 @@ func (h *cleanHandler) handleRelatedGalleries(ctx context.Context, fileID file.I return err } - mgr.PluginCache.RegisterPostHooks(ctx, mgr.Database, g.ID, plugin.GalleryDestroyPost, plugin.GalleryDestroyInput{ + mgr.PluginCache.RegisterPostHooks(ctx, g.ID, plugin.GalleryDestroyPost, plugin.GalleryDestroyInput{ Checksum: g.Checksum(), Path: g.Path, }, nil) @@ -306,7 +306,7 @@ func (h *cleanHandler) deleteRelatedFolderGalleries(ctx context.Context, folderI return err } - mgr.PluginCache.RegisterPostHooks(ctx, mgr.Database, g.ID, plugin.GalleryDestroyPost, plugin.GalleryDestroyInput{ + mgr.PluginCache.RegisterPostHooks(ctx, g.ID, plugin.GalleryDestroyPost, plugin.GalleryDestroyInput{ // No checksum for folders // Checksum: g.Checksum(), Path: g.Path, @@ -340,7 +340,7 @@ func (h *cleanHandler) handleRelatedImages(ctx context.Context, fileDeleter *fil return err } - mgr.PluginCache.RegisterPostHooks(ctx, mgr.Database, i.ID, plugin.ImageDestroyPost, plugin.ImageDestroyInput{ + mgr.PluginCache.RegisterPostHooks(ctx, i.ID, plugin.ImageDestroyPost, plugin.ImageDestroyInput{ Checksum: i.Checksum, Path: i.Path, }, nil) diff --git a/pkg/file/delete.go b/pkg/file/delete.go index badbb5096..c71d73428 100644 --- a/pkg/file/delete.go +++ b/pkg/file/delete.go @@ -70,12 +70,12 @@ func NewDeleter() *Deleter { // RegisterHooks registers post-commit and post-rollback hooks. func (d *Deleter) RegisterHooks(ctx context.Context, mgr txn.Manager) { - mgr.AddPostCommitHook(ctx, func(ctx context.Context) error { + txn.AddPostCommitHook(ctx, func(ctx context.Context) error { d.Commit() return nil }) - mgr.AddPostRollbackHook(ctx, func(ctx context.Context) error { + txn.AddPostRollbackHook(ctx, func(ctx context.Context) error { d.Rollback() return nil }) diff --git a/pkg/gallery/scan.go b/pkg/gallery/scan.go index 3908f1cc2..7c31c2ccf 100644 --- a/pkg/gallery/scan.go +++ b/pkg/gallery/scan.go @@ -68,7 +68,7 @@ func (h *ScanHandler) Handle(ctx context.Context, f file.File) error { return fmt.Errorf("creating new gallery: %w", err) } - h.PluginCache.ExecutePostHooks(ctx, newGallery.ID, plugin.GalleryCreatePost, nil, nil) + h.PluginCache.RegisterPostHooks(ctx, newGallery.ID, plugin.GalleryCreatePost, nil, nil) existing = []*models.Gallery{newGallery} } diff --git a/pkg/image/scan.go b/pkg/image/scan.go index 4f313ccc5..61ef9e6e3 100644 --- a/pkg/image/scan.go +++ b/pkg/image/scan.go @@ -125,7 +125,7 @@ func (h *ScanHandler) Handle(ctx context.Context, f file.File) error { return fmt.Errorf("creating new image: %w", err) } - h.PluginCache.ExecutePostHooks(ctx, newImage.ID, plugin.ImageCreatePost, nil, nil) + h.PluginCache.RegisterPostHooks(ctx, newImage.ID, plugin.ImageCreatePost, nil, nil) existing = []*models.Image{newImage} } diff --git a/pkg/plugin/plugins.go b/pkg/plugin/plugins.go index ea66adcc2..5f74b1d8b 100644 --- a/pkg/plugin/plugins.go +++ b/pkg/plugin/plugins.go @@ -200,8 +200,8 @@ func (c Cache) ExecutePostHooks(ctx context.Context, id int, hookType HookTrigge } } -func (c Cache) RegisterPostHooks(ctx context.Context, txnMgr txn.Manager, id int, hookType HookTriggerEnum, input interface{}, inputFields []string) { - txnMgr.AddPostCommitHook(ctx, func(ctx context.Context) error { +func (c Cache) RegisterPostHooks(ctx context.Context, id int, hookType HookTriggerEnum, input interface{}, inputFields []string) { + txn.AddPostCommitHook(ctx, func(ctx context.Context) error { c.ExecutePostHooks(ctx, id, hookType, input, inputFields) return nil }) diff --git a/pkg/scene/scan.go b/pkg/scene/scan.go index cf9b0d6fc..41490b952 100644 --- a/pkg/scene/scan.go +++ b/pkg/scene/scan.go @@ -93,7 +93,7 @@ func (h *ScanHandler) Handle(ctx context.Context, f file.File) error { return fmt.Errorf("creating new scene: %w", err) } - h.PluginCache.ExecutePostHooks(ctx, newScene.ID, plugin.SceneCreatePost, nil, nil) + h.PluginCache.RegisterPostHooks(ctx, newScene.ID, plugin.SceneCreatePost, nil, nil) existing = []*models.Scene{newScene} } diff --git a/pkg/sqlite/hooks.go b/pkg/sqlite/hooks.go deleted file mode 100644 index 468bbbdf9..000000000 --- a/pkg/sqlite/hooks.go +++ /dev/null @@ -1,50 +0,0 @@ -package sqlite - -import ( - "context" - - "github.com/stashapp/stash/pkg/txn" -) - -type hookManager struct { - postCommitHooks []txn.TxnFunc - postRollbackHooks []txn.TxnFunc -} - -func (m *hookManager) register(ctx context.Context) context.Context { - return context.WithValue(ctx, hookManagerKey, m) -} - -func (db *Database) hookManager(ctx context.Context) *hookManager { - m, ok := ctx.Value(hookManagerKey).(*hookManager) - if !ok { - return nil - } - return m -} - -func (db *Database) executePostCommitHooks(ctx context.Context) { - m := db.hookManager(ctx) - for _, h := range m.postCommitHooks { - // ignore errors - _ = h(ctx) - } -} - -func (db *Database) executePostRollbackHooks(ctx context.Context) { - m := db.hookManager(ctx) - for _, h := range m.postRollbackHooks { - // ignore errors - _ = h(ctx) - } -} - -func (db *Database) AddPostCommitHook(ctx context.Context, hook txn.TxnFunc) { - m := db.hookManager(ctx) - m.postCommitHooks = append(m.postCommitHooks, hook) -} - -func (db *Database) AddPostRollbackHook(ctx context.Context, hook txn.TxnFunc) { - m := db.hookManager(ctx) - m.postRollbackHooks = append(m.postRollbackHooks, hook) -} diff --git a/pkg/sqlite/transaction.go b/pkg/sqlite/transaction.go index 0e3234c5c..42b65ad7b 100644 --- a/pkg/sqlite/transaction.go +++ b/pkg/sqlite/transaction.go @@ -17,7 +17,6 @@ type key int const ( txnKey key = iota + 1 dbKey - hookManagerKey ) func (db *Database) WithDatabase(ctx context.Context) (context.Context, error) { @@ -42,9 +41,6 @@ func (db *Database) Begin(ctx context.Context) (context.Context, error) { return nil, fmt.Errorf("beginning transaction: %w", err) } - hookMgr := &hookManager{} - ctx = hookMgr.register(ctx) - return context.WithValue(ctx, txnKey, tx), nil } @@ -58,9 +54,6 @@ func (db *Database) Commit(ctx context.Context) error { return err } - // execute post-commit hooks - db.executePostCommitHooks(ctx) - return nil } @@ -74,9 +67,6 @@ func (db *Database) Rollback(ctx context.Context) error { return err } - // execute post-rollback hooks - db.executePostRollbackHooks(ctx) - return nil } diff --git a/pkg/txn/hooks.go b/pkg/txn/hooks.go new file mode 100644 index 000000000..8ace7c3d5 --- /dev/null +++ b/pkg/txn/hooks.go @@ -0,0 +1,54 @@ +package txn + +import ( + "context" +) + +type key int + +const ( + hookManagerKey key = iota + 1 +) + +type hookManager struct { + postCommitHooks []TxnFunc + postRollbackHooks []TxnFunc +} + +func (m *hookManager) register(ctx context.Context) context.Context { + return context.WithValue(ctx, hookManagerKey, m) +} + +func hookManagerCtx(ctx context.Context) *hookManager { + m, ok := ctx.Value(hookManagerKey).(*hookManager) + if !ok { + return nil + } + return m +} + +func executePostCommitHooks(ctx context.Context) { + m := hookManagerCtx(ctx) + for _, h := range m.postCommitHooks { + // ignore errors + _ = h(ctx) + } +} + +func executePostRollbackHooks(ctx context.Context) { + m := hookManagerCtx(ctx) + for _, h := range m.postRollbackHooks { + // ignore errors + _ = h(ctx) + } +} + +func AddPostCommitHook(ctx context.Context, hook TxnFunc) { + m := hookManagerCtx(ctx) + m.postCommitHooks = append(m.postCommitHooks, hook) +} + +func AddPostRollbackHook(ctx context.Context, hook TxnFunc) { + m := hookManagerCtx(ctx) + m.postRollbackHooks = append(m.postRollbackHooks, hook) +} diff --git a/pkg/txn/transaction.go b/pkg/txn/transaction.go index 117e44eac..401286a47 100644 --- a/pkg/txn/transaction.go +++ b/pkg/txn/transaction.go @@ -11,9 +11,6 @@ type Manager interface { Rollback(ctx context.Context) error IsLocked(err error) bool - - AddPostCommitHook(ctx context.Context, hook TxnFunc) - AddPostRollbackHook(ctx context.Context, hook TxnFunc) } type DatabaseProvider interface { @@ -26,7 +23,7 @@ type TxnFunc func(ctx context.Context) error // the transaction is rolled back. Otherwise it is committed. func WithTxn(ctx context.Context, m Manager, fn TxnFunc) error { var err error - ctx, err = m.Begin(ctx) + ctx, err = begin(ctx, m) if err != nil { return err } @@ -34,16 +31,16 @@ func WithTxn(ctx context.Context, m Manager, fn TxnFunc) error { defer func() { if p := recover(); p != nil { // a panic occurred, rollback and repanic - _ = m.Rollback(ctx) + rollback(ctx, m) panic(p) } if err != nil { // something went wrong, rollback - _ = m.Rollback(ctx) + rollback(ctx, m) } else { // all good, commit - err = m.Commit(ctx) + err = commit(ctx, m) } }() @@ -51,6 +48,36 @@ func WithTxn(ctx context.Context, m Manager, fn TxnFunc) error { return err } +func begin(ctx context.Context, m Manager) (context.Context, error) { + var err error + ctx, err = m.Begin(ctx) + if err != nil { + return nil, err + } + + hm := hookManager{} + ctx = hm.register(ctx) + + return ctx, nil +} + +func commit(ctx context.Context, m Manager) error { + if err := m.Commit(ctx); err != nil { + return err + } + + executePostCommitHooks(ctx) + return nil +} + +func rollback(ctx context.Context, m Manager) { + if err := m.Rollback(ctx); err != nil { + return + } + + executePostRollbackHooks(ctx) +} + // WithDatabase executes fn with the context provided by p.WithDatabase. // It does not run inside a transaction, so all database operations will be // executed in their own transaction.