Use post commit hook for post-create plugin hooks (#2920)

This commit is contained in:
WithoutPants 2022-09-19 14:53:06 +10:00 committed by GitHub
parent 0359ce2ed8
commit 2564351265
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 109 additions and 88 deletions

View file

@ -72,17 +72,17 @@ func TestTagCreate(t *testing.T) {
}
}
tagRW.On("Query", testCtx, tagFilterForName(existingTagName), findFilter).Return([]*models.Tag{
tagRW.On("Query", mock.Anything, tagFilterForName(existingTagName), findFilter).Return([]*models.Tag{
{
ID: existingTagID,
Name: existingTagName,
},
}, 1, nil).Once()
tagRW.On("Query", testCtx, tagFilterForName(errTagName), findFilter).Return(nil, 0, nil).Once()
tagRW.On("Query", testCtx, tagFilterForAlias(errTagName), findFilter).Return(nil, 0, nil).Once()
tagRW.On("Query", mock.Anything, tagFilterForName(errTagName), findFilter).Return(nil, 0, nil).Once()
tagRW.On("Query", mock.Anything, tagFilterForAlias(errTagName), findFilter).Return(nil, 0, nil).Once()
expectedErr := errors.New("TagCreate error")
tagRW.On("Create", testCtx, mock.AnythingOfType("models.Tag")).Return(nil, expectedErr)
tagRW.On("Create", mock.Anything, mock.AnythingOfType("models.Tag")).Return(nil, expectedErr)
_, err := r.Mutation().TagCreate(testCtx, TagCreateInput{
Name: existingTagName,
@ -100,14 +100,14 @@ func TestTagCreate(t *testing.T) {
r = newResolver()
tagRW = r.repository.Tag.(*mocks.TagReaderWriter)
tagRW.On("Query", testCtx, tagFilterForName(tagName), findFilter).Return(nil, 0, nil).Once()
tagRW.On("Query", testCtx, tagFilterForAlias(tagName), findFilter).Return(nil, 0, nil).Once()
tagRW.On("Query", mock.Anything, tagFilterForName(tagName), findFilter).Return(nil, 0, nil).Once()
tagRW.On("Query", mock.Anything, tagFilterForAlias(tagName), findFilter).Return(nil, 0, nil).Once()
newTag := &models.Tag{
ID: newTagID,
Name: tagName,
}
tagRW.On("Create", testCtx, mock.AnythingOfType("models.Tag")).Return(newTag, nil)
tagRW.On("Find", testCtx, newTagID).Return(newTag, nil)
tagRW.On("Create", mock.Anything, mock.AnythingOfType("models.Tag")).Return(newTag, nil)
tagRW.On("Find", mock.Anything, newTagID).Return(newTag, nil)
tag, err := r.Mutation().TagCreate(testCtx, TagCreateInput{
Name: tagName,

View file

@ -74,10 +74,10 @@ func TestSceneIdentifier_Identify(t *testing.T) {
mockSceneReaderWriter := &mocks.SceneReaderWriter{}
mockSceneReaderWriter.On("UpdatePartial", testCtx, mock.MatchedBy(func(id int) bool {
mockSceneReaderWriter.On("UpdatePartial", mock.Anything, mock.MatchedBy(func(id int) bool {
return id == errUpdateID
}), mock.Anything).Return(nil, errors.New("update error"))
mockSceneReaderWriter.On("UpdatePartial", testCtx, mock.MatchedBy(func(id int) bool {
mockSceneReaderWriter.On("UpdatePartial", mock.Anything, mock.MatchedBy(func(id int) bool {
return id != errUpdateID
}), mock.Anything).Return(nil, nil)

View file

@ -221,7 +221,7 @@ func (h *cleanHandler) handleRelatedScenes(ctx context.Context, fileDeleter *fil
checksum := scene.Checksum
oshash := scene.OSHash
mgr.PluginCache.RegisterPostHooks(ctx, mgr.Database, scene.ID, plugin.SceneDestroyPost, plugin.SceneDestroyInput{
mgr.PluginCache.RegisterPostHooks(ctx, scene.ID, plugin.SceneDestroyPost, plugin.SceneDestroyInput{
Checksum: checksum,
OSHash: oshash,
Path: scene.Path,
@ -267,7 +267,7 @@ func (h *cleanHandler) handleRelatedGalleries(ctx context.Context, fileID file.I
return err
}
mgr.PluginCache.RegisterPostHooks(ctx, mgr.Database, g.ID, plugin.GalleryDestroyPost, plugin.GalleryDestroyInput{
mgr.PluginCache.RegisterPostHooks(ctx, g.ID, plugin.GalleryDestroyPost, plugin.GalleryDestroyInput{
Checksum: g.Checksum(),
Path: g.Path,
}, nil)
@ -306,7 +306,7 @@ func (h *cleanHandler) deleteRelatedFolderGalleries(ctx context.Context, folderI
return err
}
mgr.PluginCache.RegisterPostHooks(ctx, mgr.Database, g.ID, plugin.GalleryDestroyPost, plugin.GalleryDestroyInput{
mgr.PluginCache.RegisterPostHooks(ctx, g.ID, plugin.GalleryDestroyPost, plugin.GalleryDestroyInput{
// No checksum for folders
// Checksum: g.Checksum(),
Path: g.Path,
@ -340,7 +340,7 @@ func (h *cleanHandler) handleRelatedImages(ctx context.Context, fileDeleter *fil
return err
}
mgr.PluginCache.RegisterPostHooks(ctx, mgr.Database, i.ID, plugin.ImageDestroyPost, plugin.ImageDestroyInput{
mgr.PluginCache.RegisterPostHooks(ctx, i.ID, plugin.ImageDestroyPost, plugin.ImageDestroyInput{
Checksum: i.Checksum,
Path: i.Path,
}, nil)

View file

@ -70,12 +70,12 @@ func NewDeleter() *Deleter {
// RegisterHooks registers post-commit and post-rollback hooks.
func (d *Deleter) RegisterHooks(ctx context.Context, mgr txn.Manager) {
mgr.AddPostCommitHook(ctx, func(ctx context.Context) error {
txn.AddPostCommitHook(ctx, func(ctx context.Context) error {
d.Commit()
return nil
})
mgr.AddPostRollbackHook(ctx, func(ctx context.Context) error {
txn.AddPostRollbackHook(ctx, func(ctx context.Context) error {
d.Rollback()
return nil
})

View file

@ -68,7 +68,7 @@ func (h *ScanHandler) Handle(ctx context.Context, f file.File) error {
return fmt.Errorf("creating new gallery: %w", err)
}
h.PluginCache.ExecutePostHooks(ctx, newGallery.ID, plugin.GalleryCreatePost, nil, nil)
h.PluginCache.RegisterPostHooks(ctx, newGallery.ID, plugin.GalleryCreatePost, nil, nil)
existing = []*models.Gallery{newGallery}
}

View file

@ -125,7 +125,7 @@ func (h *ScanHandler) Handle(ctx context.Context, f file.File) error {
return fmt.Errorf("creating new image: %w", err)
}
h.PluginCache.ExecutePostHooks(ctx, newImage.ID, plugin.ImageCreatePost, nil, nil)
h.PluginCache.RegisterPostHooks(ctx, newImage.ID, plugin.ImageCreatePost, nil, nil)
existing = []*models.Image{newImage}
}

View file

@ -200,8 +200,8 @@ func (c Cache) ExecutePostHooks(ctx context.Context, id int, hookType HookTrigge
}
}
func (c Cache) RegisterPostHooks(ctx context.Context, txnMgr txn.Manager, id int, hookType HookTriggerEnum, input interface{}, inputFields []string) {
txnMgr.AddPostCommitHook(ctx, func(ctx context.Context) error {
func (c Cache) RegisterPostHooks(ctx context.Context, id int, hookType HookTriggerEnum, input interface{}, inputFields []string) {
txn.AddPostCommitHook(ctx, func(ctx context.Context) error {
c.ExecutePostHooks(ctx, id, hookType, input, inputFields)
return nil
})

View file

@ -93,7 +93,7 @@ func (h *ScanHandler) Handle(ctx context.Context, f file.File) error {
return fmt.Errorf("creating new scene: %w", err)
}
h.PluginCache.ExecutePostHooks(ctx, newScene.ID, plugin.SceneCreatePost, nil, nil)
h.PluginCache.RegisterPostHooks(ctx, newScene.ID, plugin.SceneCreatePost, nil, nil)
existing = []*models.Scene{newScene}
}

View file

@ -1,50 +0,0 @@
package sqlite
import (
"context"
"github.com/stashapp/stash/pkg/txn"
)
type hookManager struct {
postCommitHooks []txn.TxnFunc
postRollbackHooks []txn.TxnFunc
}
func (m *hookManager) register(ctx context.Context) context.Context {
return context.WithValue(ctx, hookManagerKey, m)
}
func (db *Database) hookManager(ctx context.Context) *hookManager {
m, ok := ctx.Value(hookManagerKey).(*hookManager)
if !ok {
return nil
}
return m
}
func (db *Database) executePostCommitHooks(ctx context.Context) {
m := db.hookManager(ctx)
for _, h := range m.postCommitHooks {
// ignore errors
_ = h(ctx)
}
}
func (db *Database) executePostRollbackHooks(ctx context.Context) {
m := db.hookManager(ctx)
for _, h := range m.postRollbackHooks {
// ignore errors
_ = h(ctx)
}
}
func (db *Database) AddPostCommitHook(ctx context.Context, hook txn.TxnFunc) {
m := db.hookManager(ctx)
m.postCommitHooks = append(m.postCommitHooks, hook)
}
func (db *Database) AddPostRollbackHook(ctx context.Context, hook txn.TxnFunc) {
m := db.hookManager(ctx)
m.postRollbackHooks = append(m.postRollbackHooks, hook)
}

View file

@ -17,7 +17,6 @@ type key int
const (
txnKey key = iota + 1
dbKey
hookManagerKey
)
func (db *Database) WithDatabase(ctx context.Context) (context.Context, error) {
@ -42,9 +41,6 @@ func (db *Database) Begin(ctx context.Context) (context.Context, error) {
return nil, fmt.Errorf("beginning transaction: %w", err)
}
hookMgr := &hookManager{}
ctx = hookMgr.register(ctx)
return context.WithValue(ctx, txnKey, tx), nil
}
@ -58,9 +54,6 @@ func (db *Database) Commit(ctx context.Context) error {
return err
}
// execute post-commit hooks
db.executePostCommitHooks(ctx)
return nil
}
@ -74,9 +67,6 @@ func (db *Database) Rollback(ctx context.Context) error {
return err
}
// execute post-rollback hooks
db.executePostRollbackHooks(ctx)
return nil
}

54
pkg/txn/hooks.go Normal file
View file

@ -0,0 +1,54 @@
package txn
import (
"context"
)
type key int
const (
hookManagerKey key = iota + 1
)
type hookManager struct {
postCommitHooks []TxnFunc
postRollbackHooks []TxnFunc
}
func (m *hookManager) register(ctx context.Context) context.Context {
return context.WithValue(ctx, hookManagerKey, m)
}
func hookManagerCtx(ctx context.Context) *hookManager {
m, ok := ctx.Value(hookManagerKey).(*hookManager)
if !ok {
return nil
}
return m
}
func executePostCommitHooks(ctx context.Context) {
m := hookManagerCtx(ctx)
for _, h := range m.postCommitHooks {
// ignore errors
_ = h(ctx)
}
}
func executePostRollbackHooks(ctx context.Context) {
m := hookManagerCtx(ctx)
for _, h := range m.postRollbackHooks {
// ignore errors
_ = h(ctx)
}
}
func AddPostCommitHook(ctx context.Context, hook TxnFunc) {
m := hookManagerCtx(ctx)
m.postCommitHooks = append(m.postCommitHooks, hook)
}
func AddPostRollbackHook(ctx context.Context, hook TxnFunc) {
m := hookManagerCtx(ctx)
m.postRollbackHooks = append(m.postRollbackHooks, hook)
}

View file

@ -11,9 +11,6 @@ type Manager interface {
Rollback(ctx context.Context) error
IsLocked(err error) bool
AddPostCommitHook(ctx context.Context, hook TxnFunc)
AddPostRollbackHook(ctx context.Context, hook TxnFunc)
}
type DatabaseProvider interface {
@ -26,7 +23,7 @@ type TxnFunc func(ctx context.Context) error
// the transaction is rolled back. Otherwise it is committed.
func WithTxn(ctx context.Context, m Manager, fn TxnFunc) error {
var err error
ctx, err = m.Begin(ctx)
ctx, err = begin(ctx, m)
if err != nil {
return err
}
@ -34,16 +31,16 @@ func WithTxn(ctx context.Context, m Manager, fn TxnFunc) error {
defer func() {
if p := recover(); p != nil {
// a panic occurred, rollback and repanic
_ = m.Rollback(ctx)
rollback(ctx, m)
panic(p)
}
if err != nil {
// something went wrong, rollback
_ = m.Rollback(ctx)
rollback(ctx, m)
} else {
// all good, commit
err = m.Commit(ctx)
err = commit(ctx, m)
}
}()
@ -51,6 +48,36 @@ func WithTxn(ctx context.Context, m Manager, fn TxnFunc) error {
return err
}
func begin(ctx context.Context, m Manager) (context.Context, error) {
var err error
ctx, err = m.Begin(ctx)
if err != nil {
return nil, err
}
hm := hookManager{}
ctx = hm.register(ctx)
return ctx, nil
}
func commit(ctx context.Context, m Manager) error {
if err := m.Commit(ctx); err != nil {
return err
}
executePostCommitHooks(ctx)
return nil
}
func rollback(ctx context.Context, m Manager) {
if err := m.Rollback(ctx); err != nil {
return
}
executePostRollbackHooks(ctx)
}
// WithDatabase executes fn with the context provided by p.WithDatabase.
// It does not run inside a transaction, so all database operations will be
// executed in their own transaction.