mirror of
https://github.com/stashapp/stash.git
synced 2025-12-06 16:34:02 +01:00
Use post commit hook for post-create plugin hooks (#2920)
This commit is contained in:
parent
0359ce2ed8
commit
2564351265
12 changed files with 109 additions and 88 deletions
|
|
@ -72,17 +72,17 @@ func TestTagCreate(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tagRW.On("Query", testCtx, tagFilterForName(existingTagName), findFilter).Return([]*models.Tag{
|
tagRW.On("Query", mock.Anything, tagFilterForName(existingTagName), findFilter).Return([]*models.Tag{
|
||||||
{
|
{
|
||||||
ID: existingTagID,
|
ID: existingTagID,
|
||||||
Name: existingTagName,
|
Name: existingTagName,
|
||||||
},
|
},
|
||||||
}, 1, nil).Once()
|
}, 1, nil).Once()
|
||||||
tagRW.On("Query", testCtx, tagFilterForName(errTagName), findFilter).Return(nil, 0, nil).Once()
|
tagRW.On("Query", mock.Anything, tagFilterForName(errTagName), findFilter).Return(nil, 0, nil).Once()
|
||||||
tagRW.On("Query", testCtx, tagFilterForAlias(errTagName), findFilter).Return(nil, 0, nil).Once()
|
tagRW.On("Query", mock.Anything, tagFilterForAlias(errTagName), findFilter).Return(nil, 0, nil).Once()
|
||||||
|
|
||||||
expectedErr := errors.New("TagCreate error")
|
expectedErr := errors.New("TagCreate error")
|
||||||
tagRW.On("Create", testCtx, mock.AnythingOfType("models.Tag")).Return(nil, expectedErr)
|
tagRW.On("Create", mock.Anything, mock.AnythingOfType("models.Tag")).Return(nil, expectedErr)
|
||||||
|
|
||||||
_, err := r.Mutation().TagCreate(testCtx, TagCreateInput{
|
_, err := r.Mutation().TagCreate(testCtx, TagCreateInput{
|
||||||
Name: existingTagName,
|
Name: existingTagName,
|
||||||
|
|
@ -100,14 +100,14 @@ func TestTagCreate(t *testing.T) {
|
||||||
r = newResolver()
|
r = newResolver()
|
||||||
tagRW = r.repository.Tag.(*mocks.TagReaderWriter)
|
tagRW = r.repository.Tag.(*mocks.TagReaderWriter)
|
||||||
|
|
||||||
tagRW.On("Query", testCtx, tagFilterForName(tagName), findFilter).Return(nil, 0, nil).Once()
|
tagRW.On("Query", mock.Anything, tagFilterForName(tagName), findFilter).Return(nil, 0, nil).Once()
|
||||||
tagRW.On("Query", testCtx, tagFilterForAlias(tagName), findFilter).Return(nil, 0, nil).Once()
|
tagRW.On("Query", mock.Anything, tagFilterForAlias(tagName), findFilter).Return(nil, 0, nil).Once()
|
||||||
newTag := &models.Tag{
|
newTag := &models.Tag{
|
||||||
ID: newTagID,
|
ID: newTagID,
|
||||||
Name: tagName,
|
Name: tagName,
|
||||||
}
|
}
|
||||||
tagRW.On("Create", testCtx, mock.AnythingOfType("models.Tag")).Return(newTag, nil)
|
tagRW.On("Create", mock.Anything, mock.AnythingOfType("models.Tag")).Return(newTag, nil)
|
||||||
tagRW.On("Find", testCtx, newTagID).Return(newTag, nil)
|
tagRW.On("Find", mock.Anything, newTagID).Return(newTag, nil)
|
||||||
|
|
||||||
tag, err := r.Mutation().TagCreate(testCtx, TagCreateInput{
|
tag, err := r.Mutation().TagCreate(testCtx, TagCreateInput{
|
||||||
Name: tagName,
|
Name: tagName,
|
||||||
|
|
|
||||||
|
|
@ -74,10 +74,10 @@ func TestSceneIdentifier_Identify(t *testing.T) {
|
||||||
|
|
||||||
mockSceneReaderWriter := &mocks.SceneReaderWriter{}
|
mockSceneReaderWriter := &mocks.SceneReaderWriter{}
|
||||||
|
|
||||||
mockSceneReaderWriter.On("UpdatePartial", testCtx, mock.MatchedBy(func(id int) bool {
|
mockSceneReaderWriter.On("UpdatePartial", mock.Anything, mock.MatchedBy(func(id int) bool {
|
||||||
return id == errUpdateID
|
return id == errUpdateID
|
||||||
}), mock.Anything).Return(nil, errors.New("update error"))
|
}), mock.Anything).Return(nil, errors.New("update error"))
|
||||||
mockSceneReaderWriter.On("UpdatePartial", testCtx, mock.MatchedBy(func(id int) bool {
|
mockSceneReaderWriter.On("UpdatePartial", mock.Anything, mock.MatchedBy(func(id int) bool {
|
||||||
return id != errUpdateID
|
return id != errUpdateID
|
||||||
}), mock.Anything).Return(nil, nil)
|
}), mock.Anything).Return(nil, nil)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -221,7 +221,7 @@ func (h *cleanHandler) handleRelatedScenes(ctx context.Context, fileDeleter *fil
|
||||||
checksum := scene.Checksum
|
checksum := scene.Checksum
|
||||||
oshash := scene.OSHash
|
oshash := scene.OSHash
|
||||||
|
|
||||||
mgr.PluginCache.RegisterPostHooks(ctx, mgr.Database, scene.ID, plugin.SceneDestroyPost, plugin.SceneDestroyInput{
|
mgr.PluginCache.RegisterPostHooks(ctx, scene.ID, plugin.SceneDestroyPost, plugin.SceneDestroyInput{
|
||||||
Checksum: checksum,
|
Checksum: checksum,
|
||||||
OSHash: oshash,
|
OSHash: oshash,
|
||||||
Path: scene.Path,
|
Path: scene.Path,
|
||||||
|
|
@ -267,7 +267,7 @@ func (h *cleanHandler) handleRelatedGalleries(ctx context.Context, fileID file.I
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
mgr.PluginCache.RegisterPostHooks(ctx, mgr.Database, g.ID, plugin.GalleryDestroyPost, plugin.GalleryDestroyInput{
|
mgr.PluginCache.RegisterPostHooks(ctx, g.ID, plugin.GalleryDestroyPost, plugin.GalleryDestroyInput{
|
||||||
Checksum: g.Checksum(),
|
Checksum: g.Checksum(),
|
||||||
Path: g.Path,
|
Path: g.Path,
|
||||||
}, nil)
|
}, nil)
|
||||||
|
|
@ -306,7 +306,7 @@ func (h *cleanHandler) deleteRelatedFolderGalleries(ctx context.Context, folderI
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
mgr.PluginCache.RegisterPostHooks(ctx, mgr.Database, g.ID, plugin.GalleryDestroyPost, plugin.GalleryDestroyInput{
|
mgr.PluginCache.RegisterPostHooks(ctx, g.ID, plugin.GalleryDestroyPost, plugin.GalleryDestroyInput{
|
||||||
// No checksum for folders
|
// No checksum for folders
|
||||||
// Checksum: g.Checksum(),
|
// Checksum: g.Checksum(),
|
||||||
Path: g.Path,
|
Path: g.Path,
|
||||||
|
|
@ -340,7 +340,7 @@ func (h *cleanHandler) handleRelatedImages(ctx context.Context, fileDeleter *fil
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
mgr.PluginCache.RegisterPostHooks(ctx, mgr.Database, i.ID, plugin.ImageDestroyPost, plugin.ImageDestroyInput{
|
mgr.PluginCache.RegisterPostHooks(ctx, i.ID, plugin.ImageDestroyPost, plugin.ImageDestroyInput{
|
||||||
Checksum: i.Checksum,
|
Checksum: i.Checksum,
|
||||||
Path: i.Path,
|
Path: i.Path,
|
||||||
}, nil)
|
}, nil)
|
||||||
|
|
|
||||||
|
|
@ -70,12 +70,12 @@ func NewDeleter() *Deleter {
|
||||||
|
|
||||||
// RegisterHooks registers post-commit and post-rollback hooks.
|
// RegisterHooks registers post-commit and post-rollback hooks.
|
||||||
func (d *Deleter) RegisterHooks(ctx context.Context, mgr txn.Manager) {
|
func (d *Deleter) RegisterHooks(ctx context.Context, mgr txn.Manager) {
|
||||||
mgr.AddPostCommitHook(ctx, func(ctx context.Context) error {
|
txn.AddPostCommitHook(ctx, func(ctx context.Context) error {
|
||||||
d.Commit()
|
d.Commit()
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
mgr.AddPostRollbackHook(ctx, func(ctx context.Context) error {
|
txn.AddPostRollbackHook(ctx, func(ctx context.Context) error {
|
||||||
d.Rollback()
|
d.Rollback()
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -68,7 +68,7 @@ func (h *ScanHandler) Handle(ctx context.Context, f file.File) error {
|
||||||
return fmt.Errorf("creating new gallery: %w", err)
|
return fmt.Errorf("creating new gallery: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
h.PluginCache.ExecutePostHooks(ctx, newGallery.ID, plugin.GalleryCreatePost, nil, nil)
|
h.PluginCache.RegisterPostHooks(ctx, newGallery.ID, plugin.GalleryCreatePost, nil, nil)
|
||||||
|
|
||||||
existing = []*models.Gallery{newGallery}
|
existing = []*models.Gallery{newGallery}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -125,7 +125,7 @@ func (h *ScanHandler) Handle(ctx context.Context, f file.File) error {
|
||||||
return fmt.Errorf("creating new image: %w", err)
|
return fmt.Errorf("creating new image: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
h.PluginCache.ExecutePostHooks(ctx, newImage.ID, plugin.ImageCreatePost, nil, nil)
|
h.PluginCache.RegisterPostHooks(ctx, newImage.ID, plugin.ImageCreatePost, nil, nil)
|
||||||
|
|
||||||
existing = []*models.Image{newImage}
|
existing = []*models.Image{newImage}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -200,8 +200,8 @@ func (c Cache) ExecutePostHooks(ctx context.Context, id int, hookType HookTrigge
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c Cache) RegisterPostHooks(ctx context.Context, txnMgr txn.Manager, id int, hookType HookTriggerEnum, input interface{}, inputFields []string) {
|
func (c Cache) RegisterPostHooks(ctx context.Context, id int, hookType HookTriggerEnum, input interface{}, inputFields []string) {
|
||||||
txnMgr.AddPostCommitHook(ctx, func(ctx context.Context) error {
|
txn.AddPostCommitHook(ctx, func(ctx context.Context) error {
|
||||||
c.ExecutePostHooks(ctx, id, hookType, input, inputFields)
|
c.ExecutePostHooks(ctx, id, hookType, input, inputFields)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -93,7 +93,7 @@ func (h *ScanHandler) Handle(ctx context.Context, f file.File) error {
|
||||||
return fmt.Errorf("creating new scene: %w", err)
|
return fmt.Errorf("creating new scene: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
h.PluginCache.ExecutePostHooks(ctx, newScene.ID, plugin.SceneCreatePost, nil, nil)
|
h.PluginCache.RegisterPostHooks(ctx, newScene.ID, plugin.SceneCreatePost, nil, nil)
|
||||||
|
|
||||||
existing = []*models.Scene{newScene}
|
existing = []*models.Scene{newScene}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,50 +0,0 @@
|
||||||
package sqlite
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/stashapp/stash/pkg/txn"
|
|
||||||
)
|
|
||||||
|
|
||||||
type hookManager struct {
|
|
||||||
postCommitHooks []txn.TxnFunc
|
|
||||||
postRollbackHooks []txn.TxnFunc
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *hookManager) register(ctx context.Context) context.Context {
|
|
||||||
return context.WithValue(ctx, hookManagerKey, m)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *Database) hookManager(ctx context.Context) *hookManager {
|
|
||||||
m, ok := ctx.Value(hookManagerKey).(*hookManager)
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *Database) executePostCommitHooks(ctx context.Context) {
|
|
||||||
m := db.hookManager(ctx)
|
|
||||||
for _, h := range m.postCommitHooks {
|
|
||||||
// ignore errors
|
|
||||||
_ = h(ctx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *Database) executePostRollbackHooks(ctx context.Context) {
|
|
||||||
m := db.hookManager(ctx)
|
|
||||||
for _, h := range m.postRollbackHooks {
|
|
||||||
// ignore errors
|
|
||||||
_ = h(ctx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *Database) AddPostCommitHook(ctx context.Context, hook txn.TxnFunc) {
|
|
||||||
m := db.hookManager(ctx)
|
|
||||||
m.postCommitHooks = append(m.postCommitHooks, hook)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *Database) AddPostRollbackHook(ctx context.Context, hook txn.TxnFunc) {
|
|
||||||
m := db.hookManager(ctx)
|
|
||||||
m.postRollbackHooks = append(m.postRollbackHooks, hook)
|
|
||||||
}
|
|
||||||
|
|
@ -17,7 +17,6 @@ type key int
|
||||||
const (
|
const (
|
||||||
txnKey key = iota + 1
|
txnKey key = iota + 1
|
||||||
dbKey
|
dbKey
|
||||||
hookManagerKey
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (db *Database) WithDatabase(ctx context.Context) (context.Context, error) {
|
func (db *Database) WithDatabase(ctx context.Context) (context.Context, error) {
|
||||||
|
|
@ -42,9 +41,6 @@ func (db *Database) Begin(ctx context.Context) (context.Context, error) {
|
||||||
return nil, fmt.Errorf("beginning transaction: %w", err)
|
return nil, fmt.Errorf("beginning transaction: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
hookMgr := &hookManager{}
|
|
||||||
ctx = hookMgr.register(ctx)
|
|
||||||
|
|
||||||
return context.WithValue(ctx, txnKey, tx), nil
|
return context.WithValue(ctx, txnKey, tx), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -58,9 +54,6 @@ func (db *Database) Commit(ctx context.Context) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// execute post-commit hooks
|
|
||||||
db.executePostCommitHooks(ctx)
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -74,9 +67,6 @@ func (db *Database) Rollback(ctx context.Context) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// execute post-rollback hooks
|
|
||||||
db.executePostRollbackHooks(ctx)
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
54
pkg/txn/hooks.go
Normal file
54
pkg/txn/hooks.go
Normal file
|
|
@ -0,0 +1,54 @@
|
||||||
|
package txn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
type key int
|
||||||
|
|
||||||
|
const (
|
||||||
|
hookManagerKey key = iota + 1
|
||||||
|
)
|
||||||
|
|
||||||
|
type hookManager struct {
|
||||||
|
postCommitHooks []TxnFunc
|
||||||
|
postRollbackHooks []TxnFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *hookManager) register(ctx context.Context) context.Context {
|
||||||
|
return context.WithValue(ctx, hookManagerKey, m)
|
||||||
|
}
|
||||||
|
|
||||||
|
func hookManagerCtx(ctx context.Context) *hookManager {
|
||||||
|
m, ok := ctx.Value(hookManagerKey).(*hookManager)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func executePostCommitHooks(ctx context.Context) {
|
||||||
|
m := hookManagerCtx(ctx)
|
||||||
|
for _, h := range m.postCommitHooks {
|
||||||
|
// ignore errors
|
||||||
|
_ = h(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func executePostRollbackHooks(ctx context.Context) {
|
||||||
|
m := hookManagerCtx(ctx)
|
||||||
|
for _, h := range m.postRollbackHooks {
|
||||||
|
// ignore errors
|
||||||
|
_ = h(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func AddPostCommitHook(ctx context.Context, hook TxnFunc) {
|
||||||
|
m := hookManagerCtx(ctx)
|
||||||
|
m.postCommitHooks = append(m.postCommitHooks, hook)
|
||||||
|
}
|
||||||
|
|
||||||
|
func AddPostRollbackHook(ctx context.Context, hook TxnFunc) {
|
||||||
|
m := hookManagerCtx(ctx)
|
||||||
|
m.postRollbackHooks = append(m.postRollbackHooks, hook)
|
||||||
|
}
|
||||||
|
|
@ -11,9 +11,6 @@ type Manager interface {
|
||||||
Rollback(ctx context.Context) error
|
Rollback(ctx context.Context) error
|
||||||
|
|
||||||
IsLocked(err error) bool
|
IsLocked(err error) bool
|
||||||
|
|
||||||
AddPostCommitHook(ctx context.Context, hook TxnFunc)
|
|
||||||
AddPostRollbackHook(ctx context.Context, hook TxnFunc)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type DatabaseProvider interface {
|
type DatabaseProvider interface {
|
||||||
|
|
@ -26,7 +23,7 @@ type TxnFunc func(ctx context.Context) error
|
||||||
// the transaction is rolled back. Otherwise it is committed.
|
// the transaction is rolled back. Otherwise it is committed.
|
||||||
func WithTxn(ctx context.Context, m Manager, fn TxnFunc) error {
|
func WithTxn(ctx context.Context, m Manager, fn TxnFunc) error {
|
||||||
var err error
|
var err error
|
||||||
ctx, err = m.Begin(ctx)
|
ctx, err = begin(ctx, m)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -34,16 +31,16 @@ func WithTxn(ctx context.Context, m Manager, fn TxnFunc) error {
|
||||||
defer func() {
|
defer func() {
|
||||||
if p := recover(); p != nil {
|
if p := recover(); p != nil {
|
||||||
// a panic occurred, rollback and repanic
|
// a panic occurred, rollback and repanic
|
||||||
_ = m.Rollback(ctx)
|
rollback(ctx, m)
|
||||||
panic(p)
|
panic(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// something went wrong, rollback
|
// something went wrong, rollback
|
||||||
_ = m.Rollback(ctx)
|
rollback(ctx, m)
|
||||||
} else {
|
} else {
|
||||||
// all good, commit
|
// all good, commit
|
||||||
err = m.Commit(ctx)
|
err = commit(ctx, m)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
|
@ -51,6 +48,36 @@ func WithTxn(ctx context.Context, m Manager, fn TxnFunc) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func begin(ctx context.Context, m Manager) (context.Context, error) {
|
||||||
|
var err error
|
||||||
|
ctx, err = m.Begin(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
hm := hookManager{}
|
||||||
|
ctx = hm.register(ctx)
|
||||||
|
|
||||||
|
return ctx, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func commit(ctx context.Context, m Manager) error {
|
||||||
|
if err := m.Commit(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
executePostCommitHooks(ctx)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func rollback(ctx context.Context, m Manager) {
|
||||||
|
if err := m.Rollback(ctx); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
executePostRollbackHooks(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
// WithDatabase executes fn with the context provided by p.WithDatabase.
|
// WithDatabase executes fn with the context provided by p.WithDatabase.
|
||||||
// It does not run inside a transaction, so all database operations will be
|
// It does not run inside a transaction, so all database operations will be
|
||||||
// executed in their own transaction.
|
// executed in their own transaction.
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue