diff --git a/graphql/documents/queries/plugins.graphql b/graphql/documents/queries/plugins.graphql index 72b8407da..d827542b6 100644 --- a/graphql/documents/queries/plugins.graphql +++ b/graphql/documents/queries/plugins.graphql @@ -10,6 +10,12 @@ query Plugins { name description } + + hooks { + name + description + hooks + } } } diff --git a/graphql/schema/types/plugin.graphql b/graphql/schema/types/plugin.graphql index 30d18c7ce..4828d7aae 100644 --- a/graphql/schema/types/plugin.graphql +++ b/graphql/schema/types/plugin.graphql @@ -7,6 +7,7 @@ type Plugin { version: String tasks: [PluginTask!] + hooks: [PluginHook!] } type PluginTask { @@ -15,6 +16,13 @@ type PluginTask { plugin: Plugin! } +type PluginHook { + name: String! + description: String + hooks: [String!] + plugin: Plugin! +} + type PluginResult { error: String result: String diff --git a/pkg/api/changeset_translator.go b/pkg/api/changeset_translator.go index c473c1c83..3864b1082 100644 --- a/pkg/api/changeset_translator.go +++ b/pkg/api/changeset_translator.go @@ -65,6 +65,15 @@ func (t changesetTranslator) hasField(field string) bool { return found } +func (t changesetTranslator) getFields() []string { + var ret []string + for k := range t.inputMap { + ret = append(ret, k) + } + + return ret +} + func (t changesetTranslator) nullString(value *string, field string) *sql.NullString { if !t.hasField(field) { return nil diff --git a/pkg/api/context_keys.go b/pkg/api/context_keys.go index 95eb0fd6a..839464af9 100644 --- a/pkg/api/context_keys.go +++ b/pkg/api/context_keys.go @@ -5,13 +5,12 @@ package api type key int const ( - galleryKey key = 0 - performerKey key = 1 - sceneKey key = 2 - studioKey key = 3 - movieKey key = 4 - ContextUser key = 5 - tagKey key = 6 - downloadKey key = 7 - imageKey key = 8 + galleryKey key = iota + performerKey + sceneKey + studioKey + movieKey + tagKey + downloadKey + imageKey ) diff --git a/pkg/api/resolver.go b/pkg/api/resolver.go index ddcedcf4a..07534fc1e 100644 --- a/pkg/api/resolver.go +++ b/pkg/api/resolver.go @@ -7,10 +7,16 @@ import ( "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/plugin" ) +type hookExecutor interface { + ExecutePostHooks(ctx context.Context, id int, hookType plugin.HookTriggerEnum, input interface{}, inputFields []string) +} + type Resolver struct { - txnManager models.TransactionManager + txnManager models.TransactionManager + hookExecutor hookExecutor } func (r *Resolver) Gallery() models.GalleryResolver { diff --git a/pkg/api/resolver_mutation_gallery.go b/pkg/api/resolver_mutation_gallery.go index 30100fa2e..8b4259782 100644 --- a/pkg/api/resolver_mutation_gallery.go +++ b/pkg/api/resolver_mutation_gallery.go @@ -10,9 +10,21 @@ import ( "github.com/stashapp/stash/pkg/manager" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/plugin" "github.com/stashapp/stash/pkg/utils" ) +func (r *mutationResolver) getGallery(ctx context.Context, id int) (ret *models.Gallery, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.Gallery().Find(id) + return err + }); err != nil { + return nil, err + } + + return ret, nil +} + func (r *mutationResolver) GalleryCreate(ctx context.Context, input models.GalleryCreateInput) (*models.Gallery, error) { // name must be provided if input.Title == "" { @@ -90,7 +102,8 @@ func (r *mutationResolver) GalleryCreate(ctx context.Context, input models.Galle return nil, err } - return gallery, nil + r.hookExecutor.ExecutePostHooks(ctx, gallery.ID, plugin.GalleryCreatePost, input, nil) + return r.getGallery(ctx, gallery.ID) } func (r *mutationResolver) updateGalleryPerformers(qb models.GalleryReaderWriter, galleryID int, performerIDs []string) error { @@ -130,7 +143,9 @@ func (r *mutationResolver) GalleryUpdate(ctx context.Context, input models.Galle return nil, err } - return ret, nil + // execute post hooks outside txn + r.hookExecutor.ExecutePostHooks(ctx, ret.ID, plugin.GalleryUpdatePost, input, translator.getFields()) + return r.getGallery(ctx, ret.ID) } func (r *mutationResolver) GalleriesUpdate(ctx context.Context, input []*models.GalleryUpdateInput) (ret []*models.Gallery, err error) { @@ -156,7 +171,23 @@ func (r *mutationResolver) GalleriesUpdate(ctx context.Context, input []*models. return nil, err } - return ret, nil + // execute post hooks outside txn + var newRet []*models.Gallery + for i, gallery := range ret { + translator := changesetTranslator{ + inputMap: inputMaps[i], + } + + r.hookExecutor.ExecutePostHooks(ctx, gallery.ID, plugin.GalleryUpdatePost, input, translator.getFields()) + gallery, err = r.getGallery(ctx, gallery.ID) + if err != nil { + return nil, err + } + + newRet = append(newRet, gallery) + } + + return newRet, nil } func (r *mutationResolver) galleryUpdate(input models.GalleryUpdateInput, translator changesetTranslator, repo models.Repository) (*models.Gallery, error) { @@ -314,7 +345,20 @@ func (r *mutationResolver) BulkGalleryUpdate(ctx context.Context, input models.B return nil, err } - return ret, nil + // execute post hooks outside of txn + var newRet []*models.Gallery + for _, gallery := range ret { + r.hookExecutor.ExecutePostHooks(ctx, gallery.ID, plugin.GalleryUpdatePost, input, translator.getFields()) + + gallery, err := r.getGallery(ctx, gallery.ID) + if err != nil { + return nil, err + } + + newRet = append(newRet, gallery) + } + + return newRet, nil } func adjustGalleryPerformerIDs(qb models.GalleryReader, galleryID int, ids models.BulkUpdateIds) (ret []int, err error) { @@ -438,6 +482,16 @@ func (r *mutationResolver) GalleryDestroy(ctx context.Context, input models.Gall } } + // call post hook after performing the other actions + for _, gallery := range galleries { + r.hookExecutor.ExecutePostHooks(ctx, gallery.ID, plugin.GalleryDestroyPost, input, nil) + } + + // call image destroy post hook as well + for _, img := range imgsToDelete { + r.hookExecutor.ExecutePostHooks(ctx, img.ID, plugin.ImageDestroyPost, nil, nil) + } + return true, nil } diff --git a/pkg/api/resolver_mutation_image.go b/pkg/api/resolver_mutation_image.go index ac5131b1e..b87bbdbdd 100644 --- a/pkg/api/resolver_mutation_image.go +++ b/pkg/api/resolver_mutation_image.go @@ -8,9 +8,21 @@ import ( "github.com/stashapp/stash/pkg/manager" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/plugin" "github.com/stashapp/stash/pkg/utils" ) +func (r *mutationResolver) getImage(ctx context.Context, id int) (ret *models.Image, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.Image().Find(id) + return err + }); err != nil { + return nil, err + } + + return ret, nil +} + func (r *mutationResolver) ImageUpdate(ctx context.Context, input models.ImageUpdateInput) (ret *models.Image, err error) { translator := changesetTranslator{ inputMap: getUpdateInputMap(ctx), @@ -24,7 +36,9 @@ func (r *mutationResolver) ImageUpdate(ctx context.Context, input models.ImageUp return nil, err } - return ret, nil + // execute post hooks outside txn + r.hookExecutor.ExecutePostHooks(ctx, ret.ID, plugin.ImageUpdatePost, input, translator.getFields()) + return r.getImage(ctx, ret.ID) } func (r *mutationResolver) ImagesUpdate(ctx context.Context, input []*models.ImageUpdateInput) (ret []*models.Image, err error) { @@ -50,7 +64,23 @@ func (r *mutationResolver) ImagesUpdate(ctx context.Context, input []*models.Ima return nil, err } - return ret, nil + // execute post hooks outside txn + var newRet []*models.Image + for i, image := range ret { + translator := changesetTranslator{ + inputMap: inputMaps[i], + } + + r.hookExecutor.ExecutePostHooks(ctx, image.ID, plugin.ImageUpdatePost, input, translator.getFields()) + image, err = r.getImage(ctx, image.ID) + if err != nil { + return nil, err + } + + newRet = append(newRet, image) + } + + return newRet, nil } func (r *mutationResolver) imageUpdate(input models.ImageUpdateInput, translator changesetTranslator, repo models.Repository) (*models.Image, error) { @@ -202,7 +232,20 @@ func (r *mutationResolver) BulkImageUpdate(ctx context.Context, input models.Bul return nil, err } - return ret, nil + // execute post hooks outside of txn + var newRet []*models.Image + for _, image := range ret { + r.hookExecutor.ExecutePostHooks(ctx, image.ID, plugin.ImageUpdatePost, input, translator.getFields()) + + image, err = r.getImage(ctx, image.ID) + if err != nil { + return nil, err + } + + newRet = append(newRet, image) + } + + return newRet, nil } func adjustImageGalleryIDs(qb models.ImageReader, imageID int, ids models.BulkUpdateIds) (ret []int, err error) { @@ -268,6 +311,9 @@ func (r *mutationResolver) ImageDestroy(ctx context.Context, input models.ImageD manager.DeleteImageFile(image) } + // call post hook after performing the other actions + r.hookExecutor.ExecutePostHooks(ctx, image.ID, plugin.ImageDestroyPost, input, nil) + return true, nil } @@ -315,6 +361,9 @@ func (r *mutationResolver) ImagesDestroy(ctx context.Context, input models.Image if input.DeleteFile != nil && *input.DeleteFile { manager.DeleteImageFile(image) } + + // call post hook after performing the other actions + r.hookExecutor.ExecutePostHooks(ctx, image.ID, plugin.ImageDestroyPost, input, nil) } return true, nil diff --git a/pkg/api/resolver_mutation_metadata.go b/pkg/api/resolver_mutation_metadata.go index e24d883cf..b43e1b224 100644 --- a/pkg/api/resolver_mutation_metadata.go +++ b/pkg/api/resolver_mutation_metadata.go @@ -17,7 +17,7 @@ import ( ) func (r *mutationResolver) MetadataScan(ctx context.Context, input models.ScanMetadataInput) (string, error) { - jobID, err := manager.GetInstance().Scan(input) + jobID, err := manager.GetInstance().Scan(ctx, input) if err != nil { return "", err @@ -27,7 +27,7 @@ func (r *mutationResolver) MetadataScan(ctx context.Context, input models.ScanMe } func (r *mutationResolver) MetadataImport(ctx context.Context) (string, error) { - jobID, err := manager.GetInstance().Import() + jobID, err := manager.GetInstance().Import(ctx) if err != nil { return "", err } @@ -41,13 +41,13 @@ func (r *mutationResolver) ImportObjects(ctx context.Context, input models.Impor return "", err } - jobID := manager.GetInstance().RunSingleTask(t) + jobID := manager.GetInstance().RunSingleTask(ctx, t) return strconv.Itoa(jobID), nil } func (r *mutationResolver) MetadataExport(ctx context.Context) (string, error) { - jobID, err := manager.GetInstance().Export() + jobID, err := manager.GetInstance().Export(ctx) if err != nil { return "", err } @@ -75,7 +75,7 @@ func (r *mutationResolver) ExportObjects(ctx context.Context, input models.Expor } func (r *mutationResolver) MetadataGenerate(ctx context.Context, input models.GenerateMetadataInput) (string, error) { - jobID, err := manager.GetInstance().Generate(input) + jobID, err := manager.GetInstance().Generate(ctx, input) if err != nil { return "", err @@ -85,17 +85,17 @@ func (r *mutationResolver) MetadataGenerate(ctx context.Context, input models.Ge } func (r *mutationResolver) MetadataAutoTag(ctx context.Context, input models.AutoTagMetadataInput) (string, error) { - jobID := manager.GetInstance().AutoTag(input) + jobID := manager.GetInstance().AutoTag(ctx, input) return strconv.Itoa(jobID), nil } func (r *mutationResolver) MetadataClean(ctx context.Context, input models.CleanMetadataInput) (string, error) { - jobID := manager.GetInstance().Clean(input) + jobID := manager.GetInstance().Clean(ctx, input) return strconv.Itoa(jobID), nil } func (r *mutationResolver) MigrateHashNaming(ctx context.Context) (string, error) { - jobID := manager.GetInstance().MigrateHash() + jobID := manager.GetInstance().MigrateHash(ctx) return strconv.Itoa(jobID), nil } diff --git a/pkg/api/resolver_mutation_movie.go b/pkg/api/resolver_mutation_movie.go index 3672fd47e..e1c63974f 100644 --- a/pkg/api/resolver_mutation_movie.go +++ b/pkg/api/resolver_mutation_movie.go @@ -7,9 +7,21 @@ import ( "time" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/plugin" "github.com/stashapp/stash/pkg/utils" ) +func (r *mutationResolver) getMovie(ctx context.Context, id int) (ret *models.Movie, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.Movie().Find(id) + return err + }); err != nil { + return nil, err + } + + return ret, nil +} + func (r *mutationResolver) MovieCreate(ctx context.Context, input models.MovieCreateInput) (*models.Movie, error) { // generate checksum from movie name rather than image checksum := utils.MD5FromString(input.Name) @@ -104,7 +116,8 @@ func (r *mutationResolver) MovieCreate(ctx context.Context, input models.MovieCr return nil, err } - return movie, nil + r.hookExecutor.ExecutePostHooks(ctx, movie.ID, plugin.MovieCreatePost, input, nil) + return r.getMovie(ctx, movie.ID) } func (r *mutationResolver) MovieUpdate(ctx context.Context, input models.MovieUpdateInput) (*models.Movie, error) { @@ -203,7 +216,8 @@ func (r *mutationResolver) MovieUpdate(ctx context.Context, input models.MovieUp return nil, err } - return movie, nil + r.hookExecutor.ExecutePostHooks(ctx, movie.ID, plugin.MovieUpdatePost, input, translator.getFields()) + return r.getMovie(ctx, movie.ID) } func (r *mutationResolver) MovieDestroy(ctx context.Context, input models.MovieDestroyInput) (bool, error) { @@ -217,6 +231,9 @@ func (r *mutationResolver) MovieDestroy(ctx context.Context, input models.MovieD }); err != nil { return false, err } + + r.hookExecutor.ExecutePostHooks(ctx, id, plugin.MovieDestroyPost, input, nil) + return true, nil } @@ -238,5 +255,10 @@ func (r *mutationResolver) MoviesDestroy(ctx context.Context, movieIDs []string) }); err != nil { return false, err } + + for _, id := range ids { + r.hookExecutor.ExecutePostHooks(ctx, id, plugin.MovieDestroyPost, movieIDs, nil) + } + return true, nil } diff --git a/pkg/api/resolver_mutation_performer.go b/pkg/api/resolver_mutation_performer.go index 60af02780..66c068aa5 100644 --- a/pkg/api/resolver_mutation_performer.go +++ b/pkg/api/resolver_mutation_performer.go @@ -9,9 +9,21 @@ import ( "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/performer" + "github.com/stashapp/stash/pkg/plugin" "github.com/stashapp/stash/pkg/utils" ) +func (r *mutationResolver) getPerformer(ctx context.Context, id int) (ret *models.Performer, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.Performer().Find(id) + return err + }); err != nil { + return nil, err + } + + return ret, nil +} + func (r *mutationResolver) PerformerCreate(ctx context.Context, input models.PerformerCreateInput) (*models.Performer, error) { // generate checksum from performer name rather than image checksum := utils.MD5FromString(input.Name) @@ -146,7 +158,8 @@ func (r *mutationResolver) PerformerCreate(ctx context.Context, input models.Per return nil, err } - return performer, nil + r.hookExecutor.ExecutePostHooks(ctx, performer.ID, plugin.PerformerCreatePost, input, nil) + return r.getPerformer(ctx, performer.ID) } func (r *mutationResolver) PerformerUpdate(ctx context.Context, input models.PerformerUpdateInput) (*models.Performer, error) { @@ -267,7 +280,8 @@ func (r *mutationResolver) PerformerUpdate(ctx context.Context, input models.Per return nil, err } - return p, nil + r.hookExecutor.ExecutePostHooks(ctx, p.ID, plugin.PerformerUpdatePost, input, translator.getFields()) + return r.getPerformer(ctx, p.ID) } func (r *mutationResolver) updatePerformerTags(qb models.PerformerReaderWriter, performerID int, tagsIDs []string) error { @@ -372,7 +386,20 @@ func (r *mutationResolver) BulkPerformerUpdate(ctx context.Context, input models return nil, err } - return ret, nil + // execute post hooks outside of txn + var newRet []*models.Performer + for _, performer := range ret { + r.hookExecutor.ExecutePostHooks(ctx, performer.ID, plugin.ImageUpdatePost, input, translator.getFields()) + + performer, err = r.getPerformer(ctx, performer.ID) + if err != nil { + return nil, err + } + + newRet = append(newRet, performer) + } + + return newRet, nil } func (r *mutationResolver) PerformerDestroy(ctx context.Context, input models.PerformerDestroyInput) (bool, error) { @@ -386,6 +413,9 @@ func (r *mutationResolver) PerformerDestroy(ctx context.Context, input models.Pe }); err != nil { return false, err } + + r.hookExecutor.ExecutePostHooks(ctx, id, plugin.PerformerDestroyPost, input, nil) + return true, nil } @@ -407,5 +437,10 @@ func (r *mutationResolver) PerformersDestroy(ctx context.Context, performerIDs [ }); err != nil { return false, err } + + for _, id := range ids { + r.hookExecutor.ExecutePostHooks(ctx, id, plugin.PerformerDestroyPost, performerIDs, nil) + } + return true, nil } diff --git a/pkg/api/resolver_mutation_plugin.go b/pkg/api/resolver_mutation_plugin.go index 4ad20c856..832f21371 100644 --- a/pkg/api/resolver_mutation_plugin.go +++ b/pkg/api/resolver_mutation_plugin.go @@ -2,40 +2,15 @@ package api import ( "context" - "net/http" "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/manager" - "github.com/stashapp/stash/pkg/manager/config" "github.com/stashapp/stash/pkg/models" - "github.com/stashapp/stash/pkg/plugin/common" ) func (r *mutationResolver) RunPluginTask(ctx context.Context, pluginID string, taskName string, args []*models.PluginArgInput) (string, error) { - currentUser := getCurrentUserID(ctx) - - var cookie *http.Cookie - var err error - if currentUser != nil { - cookie, err = createSessionCookie(*currentUser) - if err != nil { - return "", err - } - } - - config := config.GetInstance() - serverConnection := common.StashServerConnection{ - Scheme: "http", - Port: config.GetPort(), - SessionCookie: cookie, - Dir: config.GetConfigPath(), - } - - if HasTLSConfig() { - serverConnection.Scheme = "https" - } - - manager.GetInstance().RunPluginTask(pluginID, taskName, args, serverConnection) + m := manager.GetInstance() + m.RunPluginTask(ctx, pluginID, taskName, args) return "todo", nil } diff --git a/pkg/api/resolver_mutation_scene.go b/pkg/api/resolver_mutation_scene.go index b73acabed..a0e788454 100644 --- a/pkg/api/resolver_mutation_scene.go +++ b/pkg/api/resolver_mutation_scene.go @@ -10,9 +10,21 @@ import ( "github.com/stashapp/stash/pkg/manager" "github.com/stashapp/stash/pkg/manager/config" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/plugin" "github.com/stashapp/stash/pkg/utils" ) +func (r *mutationResolver) getScene(ctx context.Context, id int) (ret *models.Scene, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.Scene().Find(id) + return err + }); err != nil { + return nil, err + } + + return ret, nil +} + func (r *mutationResolver) SceneUpdate(ctx context.Context, input models.SceneUpdateInput) (ret *models.Scene, err error) { translator := changesetTranslator{ inputMap: getUpdateInputMap(ctx), @@ -26,7 +38,8 @@ func (r *mutationResolver) SceneUpdate(ctx context.Context, input models.SceneUp return nil, err } - return ret, nil + r.hookExecutor.ExecutePostHooks(ctx, ret.ID, plugin.SceneUpdatePost, input, translator.getFields()) + return r.getScene(ctx, ret.ID) } func (r *mutationResolver) ScenesUpdate(ctx context.Context, input []*models.SceneUpdateInput) (ret []*models.Scene, err error) { @@ -52,7 +65,24 @@ func (r *mutationResolver) ScenesUpdate(ctx context.Context, input []*models.Sce return nil, err } - return ret, nil + // execute post hooks outside of txn + var newRet []*models.Scene + for i, scene := range ret { + translator := changesetTranslator{ + inputMap: inputMaps[i], + } + + r.hookExecutor.ExecutePostHooks(ctx, scene.ID, plugin.SceneUpdatePost, input, translator.getFields()) + + scene, err = r.getScene(ctx, scene.ID) + if err != nil { + return nil, err + } + + newRet = append(newRet, scene) + } + + return newRet, nil } func (r *mutationResolver) sceneUpdate(input models.SceneUpdateInput, translator changesetTranslator, repo models.Repository) (*models.Scene, error) { @@ -281,7 +311,20 @@ func (r *mutationResolver) BulkSceneUpdate(ctx context.Context, input models.Bul return nil, err } - return ret, nil + // execute post hooks outside of txn + var newRet []*models.Scene + for _, scene := range ret { + r.hookExecutor.ExecutePostHooks(ctx, scene.ID, plugin.SceneUpdatePost, input, translator.getFields()) + + scene, err = r.getScene(ctx, scene.ID) + if err != nil { + return nil, err + } + + newRet = append(newRet, scene) + } + + return newRet, nil } func adjustIDs(existingIDs []int, updateIDs models.BulkUpdateIds) []int { @@ -393,6 +436,9 @@ func (r *mutationResolver) SceneDestroy(ctx context.Context, input models.SceneD manager.DeleteSceneFile(scene) } + // call post hook after performing the other actions + r.hookExecutor.ExecutePostHooks(ctx, scene.ID, plugin.SceneDestroyPost, input, nil) + return true, nil } @@ -442,11 +488,25 @@ func (r *mutationResolver) ScenesDestroy(ctx context.Context, input models.Scene if input.DeleteFile != nil && *input.DeleteFile { manager.DeleteSceneFile(scene) } + + // call post hook after performing the other actions + r.hookExecutor.ExecutePostHooks(ctx, scene.ID, plugin.SceneDestroyPost, input, nil) } return true, nil } +func (r *mutationResolver) getSceneMarker(ctx context.Context, id int) (ret *models.SceneMarker, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.SceneMarker().Find(id) + return err + }); err != nil { + return nil, err + } + + return ret, nil +} + func (r *mutationResolver) SceneMarkerCreate(ctx context.Context, input models.SceneMarkerCreateInput) (*models.SceneMarker, error) { primaryTagID, err := strconv.Atoi(input.PrimaryTagID) if err != nil { @@ -473,7 +533,13 @@ func (r *mutationResolver) SceneMarkerCreate(ctx context.Context, input models.S return nil, err } - return r.changeMarker(ctx, create, newSceneMarker, tagIDs) + ret, err := r.changeMarker(ctx, create, newSceneMarker, tagIDs) + if err != nil { + return nil, err + } + + r.hookExecutor.ExecutePostHooks(ctx, ret.ID, plugin.SceneMarkerCreatePost, input, nil) + return r.getSceneMarker(ctx, ret.ID) } func (r *mutationResolver) SceneMarkerUpdate(ctx context.Context, input models.SceneMarkerUpdateInput) (*models.SceneMarker, error) { @@ -507,7 +573,16 @@ func (r *mutationResolver) SceneMarkerUpdate(ctx context.Context, input models.S return nil, err } - return r.changeMarker(ctx, update, updatedSceneMarker, tagIDs) + ret, err := r.changeMarker(ctx, update, updatedSceneMarker, tagIDs) + if err != nil { + return nil, err + } + + translator := changesetTranslator{ + inputMap: getUpdateInputMap(ctx), + } + r.hookExecutor.ExecutePostHooks(ctx, ret.ID, plugin.SceneMarkerUpdatePost, input, translator.getFields()) + return r.getSceneMarker(ctx, ret.ID) } func (r *mutationResolver) SceneMarkerDestroy(ctx context.Context, id string) (bool, error) { @@ -544,6 +619,8 @@ func (r *mutationResolver) SceneMarkerDestroy(ctx context.Context, id string) (b postCommitFunc() + r.hookExecutor.ExecutePostHooks(ctx, markerID, plugin.SceneMarkerDestroyPost, id, nil) + return true, nil } @@ -651,9 +728,9 @@ func (r *mutationResolver) SceneResetO(ctx context.Context, id string) (ret int, func (r *mutationResolver) SceneGenerateScreenshot(ctx context.Context, id string, at *float64) (string, error) { if at != nil { - manager.GetInstance().GenerateScreenshot(id, *at) + manager.GetInstance().GenerateScreenshot(ctx, id, *at) } else { - manager.GetInstance().GenerateDefaultScreenshot(id) + manager.GetInstance().GenerateDefaultScreenshot(ctx, id) } return "todo", nil diff --git a/pkg/api/resolver_mutation_stash_box.go b/pkg/api/resolver_mutation_stash_box.go index 6e9983523..d05212deb 100644 --- a/pkg/api/resolver_mutation_stash_box.go +++ b/pkg/api/resolver_mutation_stash_box.go @@ -24,6 +24,6 @@ func (r *mutationResolver) SubmitStashBoxFingerprints(ctx context.Context, input } func (r *mutationResolver) StashBoxBatchPerformerTag(ctx context.Context, input models.StashBoxBatchPerformerTagInput) (string, error) { - jobID := manager.GetInstance().StashBoxBatchPerformerTag(input) + jobID := manager.GetInstance().StashBoxBatchPerformerTag(ctx, input) return strconv.Itoa(jobID), nil } diff --git a/pkg/api/resolver_mutation_studio.go b/pkg/api/resolver_mutation_studio.go index 7b06485b4..108d952bc 100644 --- a/pkg/api/resolver_mutation_studio.go +++ b/pkg/api/resolver_mutation_studio.go @@ -8,9 +8,21 @@ import ( "github.com/stashapp/stash/pkg/manager" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/plugin" "github.com/stashapp/stash/pkg/utils" ) +func (r *mutationResolver) getStudio(ctx context.Context, id int) (ret *models.Studio, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.Studio().Find(id) + return err + }); err != nil { + return nil, err + } + + return ret, nil +} + func (r *mutationResolver) StudioCreate(ctx context.Context, input models.StudioCreateInput) (*models.Studio, error) { // generate checksum from studio name rather than image checksum := utils.MD5FromString(input.Name) @@ -82,7 +94,8 @@ func (r *mutationResolver) StudioCreate(ctx context.Context, input models.Studio return nil, err } - return studio, nil + r.hookExecutor.ExecutePostHooks(ctx, studio.ID, plugin.StudioCreatePost, input, nil) + return r.getStudio(ctx, studio.ID) } func (r *mutationResolver) StudioUpdate(ctx context.Context, input models.StudioUpdateInput) (*models.Studio, error) { @@ -162,7 +175,8 @@ func (r *mutationResolver) StudioUpdate(ctx context.Context, input models.Studio return nil, err } - return studio, nil + r.hookExecutor.ExecutePostHooks(ctx, studio.ID, plugin.StudioUpdatePost, input, translator.getFields()) + return r.getStudio(ctx, studio.ID) } func (r *mutationResolver) StudioDestroy(ctx context.Context, input models.StudioDestroyInput) (bool, error) { @@ -176,6 +190,9 @@ func (r *mutationResolver) StudioDestroy(ctx context.Context, input models.Studi }); err != nil { return false, err } + + r.hookExecutor.ExecutePostHooks(ctx, id, plugin.StudioDestroyPost, input, nil) + return true, nil } @@ -197,5 +214,10 @@ func (r *mutationResolver) StudiosDestroy(ctx context.Context, studioIDs []strin }); err != nil { return false, err } + + for _, id := range ids { + r.hookExecutor.ExecutePostHooks(ctx, id, plugin.StudioDestroyPost, studioIDs, nil) + } + return true, nil } diff --git a/pkg/api/resolver_mutation_tag.go b/pkg/api/resolver_mutation_tag.go index 953f1ce5c..8b8682683 100644 --- a/pkg/api/resolver_mutation_tag.go +++ b/pkg/api/resolver_mutation_tag.go @@ -7,10 +7,22 @@ import ( "time" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/plugin" "github.com/stashapp/stash/pkg/tag" "github.com/stashapp/stash/pkg/utils" ) +func (r *mutationResolver) getTag(ctx context.Context, id int) (ret *models.Tag, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.Tag().Find(id) + return err + }); err != nil { + return nil, err + } + + return ret, nil +} + func (r *mutationResolver) TagCreate(ctx context.Context, input models.TagCreateInput) (*models.Tag, error) { // Populate a new tag from the input currentTime := time.Now() @@ -68,7 +80,8 @@ func (r *mutationResolver) TagCreate(ctx context.Context, input models.TagCreate return nil, err } - return t, nil + r.hookExecutor.ExecutePostHooks(ctx, t.ID, plugin.TagCreatePost, input, nil) + return r.getTag(ctx, t.ID) } func (r *mutationResolver) TagUpdate(ctx context.Context, input models.TagUpdateInput) (*models.Tag, error) { @@ -153,7 +166,8 @@ func (r *mutationResolver) TagUpdate(ctx context.Context, input models.TagUpdate return nil, err } - return t, nil + r.hookExecutor.ExecutePostHooks(ctx, t.ID, plugin.TagUpdatePost, input, translator.getFields()) + return r.getTag(ctx, t.ID) } func (r *mutationResolver) TagDestroy(ctx context.Context, input models.TagDestroyInput) (bool, error) { @@ -167,6 +181,9 @@ func (r *mutationResolver) TagDestroy(ctx context.Context, input models.TagDestr }); err != nil { return false, err } + + r.hookExecutor.ExecutePostHooks(ctx, tagID, plugin.TagDestroyPost, input, nil) + return true, nil } @@ -188,5 +205,10 @@ func (r *mutationResolver) TagsDestroy(ctx context.Context, tagIDs []string) (bo }); err != nil { return false, err } + + for _, id := range ids { + r.hookExecutor.ExecutePostHooks(ctx, id, plugin.TagDestroyPost, tagIDs, nil) + } + return true, nil } diff --git a/pkg/api/resolver_mutation_tag_test.go b/pkg/api/resolver_mutation_tag_test.go index 371b88d33..9329f6b7d 100644 --- a/pkg/api/resolver_mutation_tag_test.go +++ b/pkg/api/resolver_mutation_tag_test.go @@ -7,6 +7,7 @@ import ( "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models/mocks" + "github.com/stashapp/stash/pkg/plugin" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -15,7 +16,8 @@ import ( // TODO - move this into a common area func newResolver() *Resolver { return &Resolver{ - txnManager: mocks.NewTransactionManager(), + txnManager: mocks.NewTransactionManager(), + hookExecutor: &mockHookExecutor{}, } } @@ -26,6 +28,11 @@ const existingTagID = 1 const existingTagName = "existingTagName" const newTagID = 2 +type mockHookExecutor struct{} + +func (*mockHookExecutor) ExecutePostHooks(ctx context.Context, id int, hookType plugin.HookTriggerEnum, input interface{}, inputFields []string) { +} + func TestTagCreate(t *testing.T) { r := newResolver() @@ -84,10 +91,12 @@ func TestTagCreate(t *testing.T) { tagRW.On("Query", tagFilterForName(tagName), findFilter).Return(nil, 0, nil).Once() tagRW.On("Query", tagFilterForAlias(tagName), findFilter).Return(nil, 0, nil).Once() - tagRW.On("Create", mock.AnythingOfType("models.Tag")).Return(&models.Tag{ + newTag := &models.Tag{ ID: newTagID, Name: tagName, - }, nil) + } + tagRW.On("Create", mock.AnythingOfType("models.Tag")).Return(newTag, nil) + tagRW.On("Find", newTagID).Return(newTag, nil) tag, err := r.Mutation().TagCreate(context.TODO(), models.TagCreateInput{ Name: tagName, diff --git a/pkg/api/server.go b/pkg/api/server.go index 762cb1fc0..ee9063287 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -29,6 +29,7 @@ import ( "github.com/stashapp/stash/pkg/manager/config" "github.com/stashapp/stash/pkg/manager/paths" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/session" "github.com/stashapp/stash/pkg/utils" ) @@ -41,11 +42,6 @@ var uiBox *packr.Box //var legacyUiBox *packr.Box var loginUIBox *packr.Box -const ( - ApiKeyHeader = "ApiKey" - ApiKeyParameter = "apikey" -) - func allowUnauthenticated(r *http.Request) bool { return strings.HasPrefix(r.URL.Path, "/login") || r.URL.Path == "/css" } @@ -53,44 +49,26 @@ func allowUnauthenticated(r *http.Request) bool { func authenticateHandler() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - c := config.GetInstance() - ctx := r.Context() - - // translate api key into current user, if present - userID := "" - apiKey := r.Header.Get(ApiKeyHeader) - var err error - - // try getting the api key as a query parameter - if apiKey == "" { - apiKey = r.URL.Query().Get(ApiKeyParameter) - } - - if apiKey != "" { - // match against configured API and set userID to the - // configured username. In future, we'll want to - // get the username from the key. - if c.GetAPIKey() != apiKey { - w.Header().Add("WWW-Authenticate", `FormBased`) - w.WriteHeader(http.StatusUnauthorized) + userID, err := manager.GetInstance().SessionStore.Authenticate(w, r) + if err != nil { + if err != session.ErrUnauthorized { + w.WriteHeader(http.StatusInternalServerError) + _, err = w.Write([]byte(err.Error())) + if err != nil { + logger.Error(err) + } return } - userID = c.GetUsername() - } else { - // handle session - userID, err = getSessionUserID(w, r) - } - - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - _, err = w.Write([]byte(err.Error())) - if err != nil { - logger.Error(err) - } + // unauthorized error + w.Header().Add("WWW-Authenticate", `FormBased`) + w.WriteHeader(http.StatusUnauthorized) return } + c := config.GetInstance() + ctx := r.Context() + // handle redirect if no user and user is required if userID == "" && c.HasCredentials() && !allowUnauthenticated(r) { // if we don't have a userID, then redirect @@ -112,7 +90,7 @@ func authenticateHandler() func(http.Handler) http.Handler { return } - ctx = context.WithValue(ctx, ContextUser, userID) + ctx = session.SetCurrentUserID(ctx, userID) r = r.WithContext(ctx) @@ -121,6 +99,16 @@ func authenticateHandler() func(http.Handler) http.Handler { } } +func visitedPluginHandler() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // get the visited plugins and set them in the context + + next.ServeHTTP(w, r) + }) + } +} + const loginEndPoint = "/login" func Start() { @@ -128,13 +116,15 @@ func Start() { //legacyUiBox = packr.New("UI Box", "../../ui/v1/dist/stash-frontend") loginUIBox = packr.New("Login UI Box", "../../ui/login") - initSessionStore() initialiseImages() r := chi.NewRouter() r.Use(middleware.Heartbeat("/healthz")) r.Use(authenticateHandler()) + visitedPluginHandler := manager.GetInstance().SessionStore.VisitedPluginHandler() + r.Use(visitedPluginHandler) + r.Use(middleware.Recoverer) c := config.GetInstance() @@ -155,8 +145,10 @@ func Start() { } txnManager := manager.GetInstance().TxnManager + pluginCache := manager.GetInstance().PluginCache resolver := &Resolver{ - txnManager: txnManager, + txnManager: txnManager, + hookExecutor: pluginCache, } gqlSrv := gqlHandler.New(models.NewExecutableSchema(models.Config{Resolvers: resolver})) @@ -184,7 +176,8 @@ func Start() { } // register GQL handler with plugin cache - manager.GetInstance().PluginCache.RegisterGQLHandler(gqlHandlerFunc) + // chain the visited plugin handler + manager.GetInstance().PluginCache.RegisterGQLHandler(visitedPluginHandler(http.HandlerFunc(gqlHandlerFunc))) r.HandleFunc("/graphql", gqlHandlerFunc) r.HandleFunc("/playground", gqlPlayground.Handler("GraphQL playground", "/graphql")) @@ -358,15 +351,6 @@ func makeTLSConfig() *tls.Config { return tlsConfig } -func HasTLSConfig() bool { - ret, _ := utils.FileExists(paths.GetSSLCert()) - if ret { - ret, _ = utils.FileExists(paths.GetSSLKey()) - } - - return ret -} - type contextKey struct { name string } diff --git a/pkg/api/session.go b/pkg/api/session.go index a81d37c9e..739df916f 100644 --- a/pkg/api/session.go +++ b/pkg/api/session.go @@ -1,15 +1,13 @@ package api import ( - "context" "fmt" "html/template" "net/http" + "github.com/stashapp/stash/pkg/manager" "github.com/stashapp/stash/pkg/manager/config" - - "github.com/gorilla/securecookie" - "github.com/gorilla/sessions" + "github.com/stashapp/stash/pkg/session" ) const cookieName = "session" @@ -19,17 +17,11 @@ const userIDKey = "userID" const returnURLParam = "returnURL" -var sessionStore = sessions.NewCookieStore(config.GetInstance().GetSessionStoreKey()) - type loginTemplateData struct { URL string Error string } -func initSessionStore() { - sessionStore.MaxAge(config.GetInstance().GetMaxSessionAge()) -} - func redirectToLogin(w http.ResponseWriter, returnURL string, loginError string) { data, _ := loginUIBox.Find("login.html") templ, err := template.New("Login").Parse(string(data)) @@ -59,22 +51,13 @@ func handleLogin(w http.ResponseWriter, r *http.Request) { url = "/" } - // ignore error - we want a new session regardless - newSession, _ := sessionStore.Get(r, cookieName) - - username := r.FormValue("username") - password := r.FormValue("password") - - // authenticate the user - if !config.GetInstance().ValidateCredentials(username, password) { + err := manager.GetInstance().SessionStore.Login(w, r) + if err == session.ErrInvalidCredentials { // redirect back to the login page with an error redirectToLogin(w, url, "Username or password is invalid") return } - newSession.Values[userIDKey] = username - - err := newSession.Save(r, w) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -84,17 +67,7 @@ func handleLogin(w http.ResponseWriter, r *http.Request) { } func handleLogout(w http.ResponseWriter, r *http.Request) { - session, err := sessionStore.Get(r, cookieName) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - delete(session.Values, userIDKey) - session.Options.MaxAge = -1 - - err = session.Save(r, w) - if err != nil { + if err := manager.GetInstance().SessionStore.Logout(w, r); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } @@ -102,51 +75,3 @@ func handleLogout(w http.ResponseWriter, r *http.Request) { // redirect to the login page if credentials are required getLoginHandler(w, r) } - -func getSessionUserID(w http.ResponseWriter, r *http.Request) (string, error) { - session, err := sessionStore.Get(r, cookieName) - // ignore errors and treat as an empty user id, so that we handle expired - // cookie - if err != nil { - return "", nil - } - - if !session.IsNew { - val := session.Values[userIDKey] - - // refresh the cookie - err = session.Save(r, w) - if err != nil { - return "", err - } - - ret, _ := val.(string) - - return ret, nil - } - - return "", nil -} - -func getCurrentUserID(ctx context.Context) *string { - userCtxVal := ctx.Value(ContextUser) - if userCtxVal != nil { - currentUser := userCtxVal.(string) - return ¤tUser - } - - return nil -} - -func createSessionCookie(username string) (*http.Cookie, error) { - session := sessions.NewSession(sessionStore, cookieName) - session.Values[userIDKey] = username - - encoded, err := securecookie.EncodeMulti(session.Name(), session.Values, - sessionStore.Codecs...) - if err != nil { - return nil, err - } - - return sessions.NewCookie(session.Name(), encoded, session.Options), nil -} diff --git a/pkg/job/job.go b/pkg/job/job.go index 735f2d298..0ace099e1 100644 --- a/pkg/job/job.go +++ b/pkg/job/job.go @@ -55,6 +55,7 @@ type Job struct { EndTime *time.Time AddTime time.Time + outerCtx context.Context exec JobExec cancelFunc context.CancelFunc } diff --git a/pkg/job/manager.go b/pkg/job/manager.go index 9750edc50..233818483 100644 --- a/pkg/job/manager.go +++ b/pkg/job/manager.go @@ -4,6 +4,8 @@ import ( "context" "sync" "time" + + "github.com/stashapp/stash/pkg/utils" ) const maxGraveyardSize = 10 @@ -46,7 +48,7 @@ func (m *Manager) Stop() { } // Add queues a job. -func (m *Manager) Add(description string, e JobExec) int { +func (m *Manager) Add(ctx context.Context, description string, e JobExec) int { m.mutex.Lock() defer m.mutex.Unlock() @@ -58,6 +60,7 @@ func (m *Manager) Add(description string, e JobExec) int { Description: description, AddTime: t, exec: e, + outerCtx: ctx, } m.queue = append(m.queue, &j) @@ -74,7 +77,7 @@ func (m *Manager) Add(description string, e JobExec) int { // Start adds a job and starts it immediately, concurrently with any other // jobs. -func (m *Manager) Start(description string, e JobExec) int { +func (m *Manager) Start(ctx context.Context, description string, e JobExec) int { m.mutex.Lock() defer m.mutex.Unlock() @@ -86,6 +89,7 @@ func (m *Manager) Start(description string, e JobExec) int { Description: description, AddTime: t, exec: e, + outerCtx: ctx, } m.queue = append(m.queue, &j) @@ -173,7 +177,7 @@ func (m *Manager) dispatch(j *Job) (done chan struct{}) { j.StartTime = &t j.Status = StatusRunning - ctx, cancelFunc := context.WithCancel(context.Background()) + ctx, cancelFunc := context.WithCancel(utils.ValueOnlyContext(j.outerCtx)) j.cancelFunc = cancelFunc done = make(chan struct{}) diff --git a/pkg/job/manager_test.go b/pkg/job/manager_test.go index b76d424a5..51bb6a1f1 100644 --- a/pkg/job/manager_test.go +++ b/pkg/job/manager_test.go @@ -45,7 +45,7 @@ func TestAdd(t *testing.T) { const jobName = "test job" exec1 := newTestExec(make(chan struct{})) - jobID := m.Add(jobName, exec1) + jobID := m.Add(context.Background(), jobName, exec1) // expect jobID to be the first ID assert := assert.New(t) @@ -80,7 +80,7 @@ func TestAdd(t *testing.T) { // add another job to the queue const otherJobName = "other job name" exec2 := newTestExec(make(chan struct{})) - job2ID := m.Add(otherJobName, exec2) + job2ID := m.Add(context.Background(), otherJobName, exec2) // expect status to be ready j2 := m.GetJob(job2ID) @@ -130,11 +130,11 @@ func TestCancel(t *testing.T) { // add two jobs const jobName = "test job" exec1 := newTestExec(make(chan struct{})) - jobID := m.Add(jobName, exec1) + jobID := m.Add(context.Background(), jobName, exec1) const otherJobName = "other job" exec2 := newTestExec(make(chan struct{})) - job2ID := m.Add(otherJobName, exec2) + job2ID := m.Add(context.Background(), otherJobName, exec2) // wait a tiny bit time.Sleep(sleepTime) @@ -198,11 +198,11 @@ func TestCancelAll(t *testing.T) { // add two jobs const jobName = "test job" exec1 := newTestExec(make(chan struct{})) - jobID := m.Add(jobName, exec1) + jobID := m.Add(context.Background(), jobName, exec1) const otherJobName = "other job" exec2 := newTestExec(make(chan struct{})) - job2ID := m.Add(otherJobName, exec2) + job2ID := m.Add(context.Background(), otherJobName, exec2) // wait a tiny bit time.Sleep(sleepTime) @@ -246,7 +246,7 @@ func TestSubscribe(t *testing.T) { // add a job const jobName = "test job" exec1 := newTestExec(make(chan struct{})) - jobID := m.Add(jobName, exec1) + jobID := m.Add(context.Background(), jobName, exec1) assert := assert.New(t) @@ -326,7 +326,7 @@ func TestSubscribe(t *testing.T) { // add another job and cancel it exec2 := newTestExec(make(chan struct{})) - jobID = m.Add(jobName, exec2) + jobID = m.Add(context.Background(), jobName, exec2) m.CancelJob(jobID) diff --git a/pkg/manager/config/config.go b/pkg/manager/config/config.go index 0cf5d259b..21e934bca 100644 --- a/pkg/manager/config/config.go +++ b/pkg/manager/config/config.go @@ -13,6 +13,7 @@ import ( "github.com/spf13/viper" + "github.com/stashapp/stash/pkg/manager/paths" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/utils" ) @@ -150,6 +151,15 @@ func (e MissingConfigError) Error() string { return fmt.Sprintf("missing the following mandatory settings: %s", strings.Join(e.missingFields, ", ")) } +func HasTLSConfig() bool { + ret, _ := utils.FileExists(paths.GetSSLCert()) + if ret { + ret, _ = utils.FileExists(paths.GetSSLKey()) + } + + return ret +} + type Instance struct { cpuProfilePath string isNewSystem bool diff --git a/pkg/manager/manager.go b/pkg/manager/manager.go index 82fefc6f7..349851cc2 100644 --- a/pkg/manager/manager.go +++ b/pkg/manager/manager.go @@ -19,6 +19,7 @@ import ( "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/plugin" "github.com/stashapp/stash/pkg/scraper" + "github.com/stashapp/stash/pkg/session" "github.com/stashapp/stash/pkg/sqlite" "github.com/stashapp/stash/pkg/utils" ) @@ -31,6 +32,8 @@ type singleton struct { FFMPEGPath string FFProbePath string + SessionStore *session.Store + JobManager *job.Manager PluginCache *plugin.Cache @@ -100,6 +103,11 @@ func Initialize() *singleton { if cfgFile != "" { cfgFile = cfgFile + " " } + + // create temporary session store - this will be re-initialised + // after config is complete + instance.SessionStore = session.NewStore(cfg) + logger.Warnf("config file %snot found. Assuming new system...", cfgFile) } @@ -179,6 +187,8 @@ func (s *singleton) PostInit() error { s.Paths = paths.NewPaths(s.Config.GetGeneratedPath()) s.RefreshConfig() + s.SessionStore = session.NewStore(s.Config) + s.PluginCache.RegisterSessionStore(s.SessionStore) if err := s.PluginCache.LoadPlugins(); err != nil { logger.Errorf("Error reading plugin configs: %s", err.Error()) diff --git a/pkg/manager/manager_tasks.go b/pkg/manager/manager_tasks.go index 74d8643bf..abe8f7560 100644 --- a/pkg/manager/manager_tasks.go +++ b/pkg/manager/manager_tasks.go @@ -60,7 +60,7 @@ func (s *singleton) ScanSubscribe(ctx context.Context) <-chan bool { return s.scanSubs.subscribe(ctx) } -func (s *singleton) Scan(input models.ScanMetadataInput) (int, error) { +func (s *singleton) Scan(ctx context.Context, input models.ScanMetadataInput) (int, error) { if err := s.validateFFMPEG(); err != nil { return 0, err } @@ -71,10 +71,10 @@ func (s *singleton) Scan(input models.ScanMetadataInput) (int, error) { subscriptions: s.scanSubs, } - return s.JobManager.Add("Scanning...", &scanJob), nil + return s.JobManager.Add(ctx, "Scanning...", &scanJob), nil } -func (s *singleton) Import() (int, error) { +func (s *singleton) Import(ctx context.Context) (int, error) { config := config.GetInstance() metadataPath := config.GetMetadataPath() if metadataPath == "" { @@ -96,10 +96,10 @@ func (s *singleton) Import() (int, error) { task.Start(&wg) }) - return s.JobManager.Add("Importing...", j), nil + return s.JobManager.Add(ctx, "Importing...", j), nil } -func (s *singleton) Export() (int, error) { +func (s *singleton) Export(ctx context.Context) (int, error) { config := config.GetInstance() metadataPath := config.GetMetadataPath() if metadataPath == "" { @@ -117,10 +117,10 @@ func (s *singleton) Export() (int, error) { task.Start(&wg) }) - return s.JobManager.Add("Exporting...", j), nil + return s.JobManager.Add(ctx, "Exporting...", j), nil } -func (s *singleton) RunSingleTask(t Task) int { +func (s *singleton) RunSingleTask(ctx context.Context, t Task) int { var wg sync.WaitGroup wg.Add(1) @@ -128,7 +128,7 @@ func (s *singleton) RunSingleTask(t Task) int { t.Start(&wg) }) - return s.JobManager.Add(t.GetDescription(), j) + return s.JobManager.Add(ctx, t.GetDescription(), j) } func setGeneratePreviewOptionsInput(optionsInput *models.GeneratePreviewOptionsInput) { @@ -159,7 +159,7 @@ func setGeneratePreviewOptionsInput(optionsInput *models.GeneratePreviewOptionsI } } -func (s *singleton) Generate(input models.GenerateMetadataInput) (int, error) { +func (s *singleton) Generate(ctx context.Context, input models.GenerateMetadataInput) (int, error) { if err := s.validateFFMPEG(); err != nil { return 0, err } @@ -367,19 +367,19 @@ func (s *singleton) Generate(input models.GenerateMetadataInput) (int, error) { logger.Info(fmt.Sprintf("Generate finished (%s)", elapsed)) }) - return s.JobManager.Add("Generating...", j), nil + return s.JobManager.Add(ctx, "Generating...", j), nil } -func (s *singleton) GenerateDefaultScreenshot(sceneId string) int { - return s.generateScreenshot(sceneId, nil) +func (s *singleton) GenerateDefaultScreenshot(ctx context.Context, sceneId string) int { + return s.generateScreenshot(ctx, sceneId, nil) } -func (s *singleton) GenerateScreenshot(sceneId string, at float64) int { - return s.generateScreenshot(sceneId, &at) +func (s *singleton) GenerateScreenshot(ctx context.Context, sceneId string, at float64) int { + return s.generateScreenshot(ctx, sceneId, &at) } // generate default screenshot if at is nil -func (s *singleton) generateScreenshot(sceneId string, at *float64) int { +func (s *singleton) generateScreenshot(ctx context.Context, sceneId string, at *float64) int { instance.Paths.Generated.EnsureTmpDir() j := job.MakeJobExec(func(ctx context.Context, progress *job.Progress) { @@ -413,19 +413,19 @@ func (s *singleton) generateScreenshot(sceneId string, at *float64) int { logger.Infof("Generate screenshot finished") }) - return s.JobManager.Add(fmt.Sprintf("Generating screenshot for scene id %s", sceneId), j) + return s.JobManager.Add(ctx, fmt.Sprintf("Generating screenshot for scene id %s", sceneId), j) } -func (s *singleton) AutoTag(input models.AutoTagMetadataInput) int { +func (s *singleton) AutoTag(ctx context.Context, input models.AutoTagMetadataInput) int { j := autoTagJob{ txnManager: s.TxnManager, input: input, } - return s.JobManager.Add("Auto-tagging...", &j) + return s.JobManager.Add(ctx, "Auto-tagging...", &j) } -func (s *singleton) Clean(input models.CleanMetadataInput) int { +func (s *singleton) Clean(ctx context.Context, input models.CleanMetadataInput) int { j := job.MakeJobExec(func(ctx context.Context, progress *job.Progress) { var scenes []*models.Scene var images []*models.Image @@ -488,6 +488,7 @@ func (s *singleton) Clean(input models.CleanMetadataInput) int { wg.Add(1) task := CleanTask{ + ctx: ctx, TxnManager: s.TxnManager, Scene: scene, fileNamingAlgorithm: fileNamingAlgo, @@ -514,6 +515,7 @@ func (s *singleton) Clean(input models.CleanMetadataInput) int { wg.Add(1) task := CleanTask{ + ctx: ctx, TxnManager: s.TxnManager, Image: img, } @@ -538,6 +540,7 @@ func (s *singleton) Clean(input models.CleanMetadataInput) int { wg.Add(1) task := CleanTask{ + ctx: ctx, TxnManager: s.TxnManager, Gallery: gallery, } @@ -552,10 +555,10 @@ func (s *singleton) Clean(input models.CleanMetadataInput) int { s.scanSubs.notify() }) - return s.JobManager.Add("Cleaning...", j) + return s.JobManager.Add(ctx, "Cleaning...", j) } -func (s *singleton) MigrateHash() int { +func (s *singleton) MigrateHash(ctx context.Context) int { j := job.MakeJobExec(func(ctx context.Context, progress *job.Progress) { fileNamingAlgo := config.GetInstance().GetVideoFileNamingAlgorithm() logger.Infof("Migrating generated files for %s naming hash", fileNamingAlgo.String()) @@ -596,7 +599,7 @@ func (s *singleton) MigrateHash() int { logger.Info("Finished migrating") }) - return s.JobManager.Add("Migrating scene hashes...", j) + return s.JobManager.Add(ctx, "Migrating scene hashes...", j) } type totalsGenerate struct { @@ -702,7 +705,7 @@ func (s *singleton) neededGenerate(scenes []*models.Scene, input models.Generate return &totals } -func (s *singleton) StashBoxBatchPerformerTag(input models.StashBoxBatchPerformerTagInput) int { +func (s *singleton) StashBoxBatchPerformerTag(ctx context.Context, input models.StashBoxBatchPerformerTagInput) int { j := job.MakeJobExec(func(ctx context.Context, progress *job.Progress) { logger.Infof("Initiating stash-box batch performer tag") @@ -800,5 +803,5 @@ func (s *singleton) StashBoxBatchPerformerTag(input models.StashBoxBatchPerforme } }) - return s.JobManager.Add("Batch stash-box performer tag...", j) + return s.JobManager.Add(ctx, "Batch stash-box performer tag...", j) } diff --git a/pkg/manager/task_clean.go b/pkg/manager/task_clean.go index a36fcebb8..804e039b2 100644 --- a/pkg/manager/task_clean.go +++ b/pkg/manager/task_clean.go @@ -10,10 +10,12 @@ import ( "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/manager/config" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/plugin" "github.com/stashapp/stash/pkg/utils" ) type CleanTask struct { + ctx context.Context TxnManager models.TransactionManager Scene *models.Scene Gallery *models.Gallery @@ -158,6 +160,8 @@ func (t *CleanTask) deleteScene(sceneID int) { postCommitFunc() DeleteGeneratedSceneFiles(scene, t.fileNamingAlgorithm) + + GetInstance().PluginCache.ExecutePostHooks(t.ctx, sceneID, plugin.SceneDestroyPost, nil, nil) } func (t *CleanTask) deleteGallery(galleryID int) { @@ -168,6 +172,8 @@ func (t *CleanTask) deleteGallery(galleryID int) { logger.Errorf("Error deleting gallery from database: %s", err.Error()) return } + + GetInstance().PluginCache.ExecutePostHooks(t.ctx, galleryID, plugin.GalleryDestroyPost, nil, nil) } func (t *CleanTask) deleteImage(imageID int) { @@ -185,20 +191,8 @@ func (t *CleanTask) deleteImage(imageID int) { if pathErr != nil { logger.Errorf("Error deleting thumbnail image from cache: %s", pathErr) } -} -func (t *CleanTask) fileExists(filename string) (bool, error) { - info, err := os.Stat(filename) - if os.IsNotExist(err) { - return false, nil - } - - // handle if error is something else - if err != nil { - return false, err - } - - return !info.IsDir(), nil + GetInstance().PluginCache.ExecutePostHooks(t.ctx, imageID, plugin.ImageDestroyPost, nil, nil) } func getStashFromPath(pathToCheck string) *models.StashConfig { diff --git a/pkg/manager/task_plugin.go b/pkg/manager/task_plugin.go index f2a408b53..600584c49 100644 --- a/pkg/manager/task_plugin.go +++ b/pkg/manager/task_plugin.go @@ -7,13 +7,12 @@ import ( "github.com/stashapp/stash/pkg/job" "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" - "github.com/stashapp/stash/pkg/plugin/common" ) -func (s *singleton) RunPluginTask(pluginID string, taskName string, args []*models.PluginArgInput, serverConnection common.StashServerConnection) int { - j := job.MakeJobExec(func(ctx context.Context, progress *job.Progress) { +func (s *singleton) RunPluginTask(ctx context.Context, pluginID string, taskName string, args []*models.PluginArgInput) int { + j := job.MakeJobExec(func(jobCtx context.Context, progress *job.Progress) { pluginProgress := make(chan float64) - task, err := s.PluginCache.CreateTask(pluginID, taskName, serverConnection, args, pluginProgress) + task, err := s.PluginCache.CreateTask(ctx, pluginID, taskName, args, pluginProgress) if err != nil { logger.Errorf("Error creating plugin task: %s", err.Error()) return @@ -48,7 +47,7 @@ func (s *singleton) RunPluginTask(pluginID string, taskName string, args []*mode return case p := <-pluginProgress: progress.SetPercent(p) - case <-ctx.Done(): + case <-jobCtx.Done(): if err := task.Stop(); err != nil { logger.Errorf("Error stopping plugin operation: %s", err.Error()) } @@ -57,5 +56,5 @@ func (s *singleton) RunPluginTask(pluginID string, taskName string, args []*mode } }) - return s.JobManager.Add(fmt.Sprintf("Running plugin task: %s", taskName), j) + return s.JobManager.Add(ctx, fmt.Sprintf("Running plugin task: %s", taskName), j) } diff --git a/pkg/manager/task_scan.go b/pkg/manager/task_scan.go index 53dda3486..c90157611 100644 --- a/pkg/manager/task_scan.go +++ b/pkg/manager/task_scan.go @@ -21,6 +21,7 @@ import ( "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/manager/config" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/plugin" "github.com/stashapp/stash/pkg/scene" "github.com/stashapp/stash/pkg/utils" ) @@ -102,6 +103,7 @@ func (j *ScanJob) Execute(ctx context.Context, progress *job.Progress) { GeneratePhash: utils.IsTrue(input.ScanGeneratePhashes), progress: progress, CaseSensitiveFs: csFs, + ctx: ctx, } go func() { @@ -201,6 +203,7 @@ func (j *ScanJob) neededScan(ctx context.Context, paths []*models.StashConfig) ( } type ScanTask struct { + ctx context.Context TxnManager models.TransactionManager FilePath string UseFileMetadata bool @@ -424,6 +427,8 @@ func (t *ScanTask) scanGallery() { if err != nil { return err } + + GetInstance().PluginCache.ExecutePostHooks(t.ctx, g.ID, plugin.GalleryUpdatePost, nil, nil) } } else { currentTime := time.Now() @@ -461,6 +466,8 @@ func (t *ScanTask) scanGallery() { return err } scanImages = true + + GetInstance().PluginCache.ExecutePostHooks(t.ctx, g.ID, plugin.GalleryCreatePost, nil, nil) } } @@ -787,6 +794,8 @@ func (t *ScanTask) scanScene() *models.Scene { }); err != nil { return logError(err) } + + GetInstance().PluginCache.ExecutePostHooks(t.ctx, s.ID, plugin.SceneUpdatePost, nil, nil) } } else { logger.Infof("%s doesn't exist. Creating new item...", t.FilePath) @@ -826,6 +835,8 @@ func (t *ScanTask) scanScene() *models.Scene { }); err != nil { return logError(err) } + + GetInstance().PluginCache.ExecutePostHooks(t.ctx, retScene.ID, plugin.SceneCreatePost, nil, nil) } return retScene @@ -895,6 +906,8 @@ func (t *ScanTask) rescanScene(s *models.Scene, fileModTime time.Time) (*models. return nil, err } + GetInstance().PluginCache.ExecutePostHooks(t.ctx, ret.ID, plugin.SceneUpdatePost, nil, nil) + // leave the generated files as is - the scene file may have been moved // elsewhere @@ -1081,6 +1094,8 @@ func (t *ScanTask) scanImage() { logger.Error(err.Error()) return } + + GetInstance().PluginCache.ExecutePostHooks(t.ctx, i.ID, plugin.ImageUpdatePost, nil, nil) } } else { logger.Infof("%s doesn't exist. Creating new item...", image.PathDisplayName(t.FilePath)) @@ -1111,6 +1126,8 @@ func (t *ScanTask) scanImage() { logger.Error(err.Error()) return } + + GetInstance().PluginCache.ExecutePostHooks(t.ctx, i.ID, plugin.ImageCreatePost, nil, nil) } if t.zipGallery != nil { @@ -1186,6 +1203,8 @@ func (t *ScanTask) rescanImage(i *models.Image, fileModTime time.Time) (*models. } } + GetInstance().PluginCache.ExecutePostHooks(t.ctx, ret.ID, plugin.ImageUpdatePost, nil, nil) + return ret, nil } diff --git a/pkg/plugin/common/msg.go b/pkg/plugin/common/msg.go index 39eea860e..d7e93c6ea 100644 --- a/pkg/plugin/common/msg.go +++ b/pkg/plugin/common/msg.go @@ -2,6 +2,10 @@ package common import "net/http" +const ( + HookContextKey = "hookContext" +) + // StashServerConnection represents the connection details needed for a // plugin instance to connect to its parent stash server. type StashServerConnection struct { @@ -97,3 +101,12 @@ func (o *PluginOutput) SetError(err error) { errStr := err.Error() o.Error = &errStr } + +// HookContext is passed as a PluginArgValue and indicates what hook triggered +// this plugin task. +type HookContext struct { + ID int `json:"id,omitempty"` + Type string `json:"type"` + Input interface{} `json:"input"` + InputFields []string `json:"inputFields,omitempty"` +} diff --git a/pkg/plugin/config.go b/pkg/plugin/config.go index 3f8ffad41..a56c5520d 100644 --- a/pkg/plugin/config.go +++ b/pkg/plugin/config.go @@ -54,6 +54,9 @@ type Config struct { // The task configurations for tasks provided by this plugin. Tasks []*OperationConfig `yaml:"tasks"` + + // The hooks configurations for hooks registered by this plugin. + Hooks []*HookConfig `yaml:"hooks"` } func (c Config) getPluginTasks(includePlugin bool) []*models.PluginTask { @@ -74,6 +77,34 @@ func (c Config) getPluginTasks(includePlugin bool) []*models.PluginTask { return ret } +func (c Config) getPluginHooks(includePlugin bool) []*models.PluginHook { + var ret []*models.PluginHook + + for _, o := range c.Hooks { + hook := &models.PluginHook{ + Name: o.Name, + Description: &o.Description, + Hooks: convertHooks(o.TriggeredBy), + } + + if includePlugin { + hook.Plugin = c.toPlugin() + } + ret = append(ret, hook) + } + + return ret +} + +func convertHooks(hooks []HookTriggerEnum) []string { + var ret []string + for _, h := range hooks { + ret = append(ret, h.String()) + } + + return ret +} + func (c Config) getName() string { if c.Name != "" { return c.Name @@ -90,6 +121,7 @@ func (c Config) toPlugin() *models.Plugin { URL: c.URL, Version: c.Version, Tasks: c.getPluginTasks(false), + Hooks: c.getPluginHooks(false), } } @@ -103,6 +135,19 @@ func (c Config) getTask(name string) *OperationConfig { return nil } +func (c Config) getHooks(hookType HookTriggerEnum) []*HookConfig { + var ret []*HookConfig + for _, h := range c.Hooks { + for _, t := range h.TriggeredBy { + if hookType == t { + ret = append(ret, h) + } + } + } + + return ret +} + func (c Config) getConfigPath() string { return filepath.Dir(c.path) } @@ -194,6 +239,13 @@ type OperationConfig struct { DefaultArgs map[string]string `yaml:"defaultArgs"` } +type HookConfig struct { + OperationConfig `yaml:",inline"` + + // A list of stash operations that will be used to trigger this hook operation. + TriggeredBy []HookTriggerEnum `yaml:"triggeredBy"` +} + func loadPluginFromYAML(reader io.Reader) (*Config, error) { ret := &Config{} diff --git a/pkg/plugin/examples/js/js.js b/pkg/plugin/examples/js/js.js index 84acfe076..39ba1f5c6 100644 --- a/pkg/plugin/examples/js/js.js +++ b/pkg/plugin/examples/js/js.js @@ -11,6 +11,8 @@ function main() { doLongTask(); } else if (modeArg == "indef") { doIndefiniteTask(); + } else if (modeArg == "hook") { + doHookTask(); } } catch (err) { return { @@ -207,4 +209,9 @@ function doIndefiniteTask() { } } +function doHookTask() { + log.Info("JS Hook called!"); + log.Info(input.Args); +} + main(); \ No newline at end of file diff --git a/pkg/plugin/examples/js/js.yml b/pkg/plugin/examples/js/js.yml index 1c8dc4571..25f0d56e3 100644 --- a/pkg/plugin/examples/js/js.yml +++ b/pkg/plugin/examples/js/js.yml @@ -24,4 +24,35 @@ tasks: description: Sleeps for 100 seconds - interruptable defaultArgs: mode: long - +hooks: + - name: Log scene marker create/update + description: Logs some stuff when creating/updating scene marker. + triggeredBy: + - SceneMarker.Create.Post + - SceneMarker.Update.Post + - SceneMarker.Delete.Post + - Scene.Create.Post + - Scene.Update.Post + - Scene.Destroy.Post + - Image.Create.Post + - Image.Update.Post + - Image.Destroy.Post + - Gallery.Create.Post + - Gallery.Update.Post + - Gallery.Destroy.Post + - Movie.Create.Post + - Movie.Update.Post + - Movie.Destroy.Post + - Performer.Create.Post + - Performer.Update.Post + - Performer.Destroy.Post + - Studio.Create.Post + - Studio.Update.Post + - Studio.Destroy.Post + - Tag.Create.Post + - Tag.Update.Post + - Tag.Destroy.Post + defaultArgs: + mode: hook + + diff --git a/pkg/plugin/examples/python/stash_interface.py b/pkg/plugin/examples/python/stash_interface.py index e05d27ddc..f86149899 100644 --- a/pkg/plugin/examples/python/stash_interface.py +++ b/pkg/plugin/examples/python/stash_interface.py @@ -17,7 +17,10 @@ class StashInterface: self.url = scheme + "://localhost:" + str(self.port) + "/graphql" - # TODO - cookies + # Session cookie for authentication + self.cookies = { + 'session': conn.get('SessionCookie').get('Value') + } def __callGraphQL(self, query, variables = None): json = {} @@ -26,7 +29,7 @@ class StashInterface: json['variables'] = variables # handle cookies - response = requests.post(self.url, json=json, headers=self.headers) + response = requests.post(self.url, json=json, headers=self.headers, cookies=self.cookies) if response.status_code == 200: result = response.json() diff --git a/pkg/plugin/hooks.go b/pkg/plugin/hooks.go new file mode 100644 index 000000000..9ff83f06f --- /dev/null +++ b/pkg/plugin/hooks.go @@ -0,0 +1,125 @@ +package plugin + +import ( + "github.com/stashapp/stash/pkg/plugin/common" +) + +type HookTriggerEnum string + +// Scan-related hooks are current disabled until post-hook execution is +// integrated. + +const ( + SceneMarkerCreatePost HookTriggerEnum = "SceneMarker.Create.Post" + SceneMarkerUpdatePost HookTriggerEnum = "SceneMarker.Update.Post" + SceneMarkerDestroyPost HookTriggerEnum = "SceneMarker.Destroy.Post" + + SceneCreatePost HookTriggerEnum = "Scene.Create.Post" + SceneUpdatePost HookTriggerEnum = "Scene.Update.Post" + SceneDestroyPost HookTriggerEnum = "Scene.Destroy.Post" + + ImageCreatePost HookTriggerEnum = "Image.Create.Post" + ImageUpdatePost HookTriggerEnum = "Image.Update.Post" + ImageDestroyPost HookTriggerEnum = "Image.Destroy.Post" + + GalleryCreatePost HookTriggerEnum = "Gallery.Create.Post" + GalleryUpdatePost HookTriggerEnum = "Gallery.Update.Post" + GalleryDestroyPost HookTriggerEnum = "Gallery.Destroy.Post" + + MovieCreatePost HookTriggerEnum = "Movie.Create.Post" + MovieUpdatePost HookTriggerEnum = "Movie.Update.Post" + MovieDestroyPost HookTriggerEnum = "Movie.Destroy.Post" + + PerformerCreatePost HookTriggerEnum = "Performer.Create.Post" + PerformerUpdatePost HookTriggerEnum = "Performer.Update.Post" + PerformerDestroyPost HookTriggerEnum = "Performer.Destroy.Post" + + StudioCreatePost HookTriggerEnum = "Studio.Create.Post" + StudioUpdatePost HookTriggerEnum = "Studio.Update.Post" + StudioDestroyPost HookTriggerEnum = "Studio.Destroy.Post" + + TagCreatePost HookTriggerEnum = "Tag.Create.Post" + TagUpdatePost HookTriggerEnum = "Tag.Update.Post" + TagDestroyPost HookTriggerEnum = "Tag.Destroy.Post" +) + +var AllHookTriggerEnum = []HookTriggerEnum{ + SceneMarkerCreatePost, + SceneMarkerUpdatePost, + SceneMarkerDestroyPost, + + SceneCreatePost, + SceneUpdatePost, + SceneDestroyPost, + + ImageCreatePost, + ImageUpdatePost, + ImageDestroyPost, + + GalleryCreatePost, + GalleryUpdatePost, + GalleryDestroyPost, + + MovieCreatePost, + MovieUpdatePost, + MovieDestroyPost, + + PerformerCreatePost, + PerformerUpdatePost, + PerformerDestroyPost, + + StudioCreatePost, + StudioUpdatePost, + StudioDestroyPost, + + TagCreatePost, + TagUpdatePost, + TagDestroyPost, +} + +func (e HookTriggerEnum) IsValid() bool { + + switch e { + case SceneMarkerCreatePost, + SceneMarkerUpdatePost, + SceneMarkerDestroyPost, + + SceneCreatePost, + SceneUpdatePost, + SceneDestroyPost, + + ImageCreatePost, + ImageUpdatePost, + ImageDestroyPost, + + GalleryCreatePost, + GalleryUpdatePost, + GalleryDestroyPost, + + MovieCreatePost, + MovieUpdatePost, + MovieDestroyPost, + + PerformerCreatePost, + PerformerUpdatePost, + PerformerDestroyPost, + + StudioCreatePost, + StudioUpdatePost, + StudioDestroyPost, + + TagCreatePost, + TagUpdatePost, + TagDestroyPost: + return true + } + return false +} + +func (e HookTriggerEnum) String() string { + return string(e) +} + +func addHookContext(argsMap common.ArgsMap, hookContext common.HookContext) { + argsMap[common.HookContextKey] = hookContext +} diff --git a/pkg/plugin/js.go b/pkg/plugin/js.go index c295998bd..6b070ce46 100644 --- a/pkg/plugin/js.go +++ b/pkg/plugin/js.go @@ -28,11 +28,6 @@ type jsPluginTask struct { vm *otto.Otto } -func throw(vm *otto.Otto, str string) { - value, _ := vm.Call("new Error", nil, str) - panic(value) -} - func (t *jsPluginTask) onError(err error) { errString := err.Error() t.result = &common.PluginOutput{ @@ -76,12 +71,10 @@ func (t *jsPluginTask) Start() error { return err } - input := t.buildPluginInput() - - t.vm.Set("input", input) + t.vm.Set("input", t.input) js.AddLogAPI(t.vm, t.progress) js.AddUtilAPI(t.vm) - js.AddGQLAPI(t.vm, t.gqlHandler) + js.AddGQLAPI(t.vm, t.input.ServerConnection.SessionCookie, t.gqlHandler) t.vm.Interrupt = make(chan func(), 1) diff --git a/pkg/plugin/js/gql.go b/pkg/plugin/js/gql.go index 0258c12e9..13d5fe003 100644 --- a/pkg/plugin/js/gql.go +++ b/pkg/plugin/js/gql.go @@ -33,7 +33,7 @@ func throw(vm *otto.Otto, str string) { panic(value) } -func gqlRequestFunc(vm *otto.Otto, gqlHandler http.HandlerFunc) func(call otto.FunctionCall) otto.Value { +func gqlRequestFunc(vm *otto.Otto, cookie *http.Cookie, gqlHandler http.Handler) func(call otto.FunctionCall) otto.Value { return func(call otto.FunctionCall) otto.Value { if len(call.ArgumentList) == 0 { throw(vm, "missing argument") @@ -67,11 +67,15 @@ func gqlRequestFunc(vm *otto.Otto, gqlHandler http.HandlerFunc) func(call otto.F } r.Header.Set("Content-Type", "application/json") + if cookie != nil { + r.AddCookie(cookie) + } + w := &responseWriter{ header: make(http.Header), } - gqlHandler(w, r) + gqlHandler.ServeHTTP(w, r) if w.statusCode != http.StatusOK && w.statusCode != 0 { throw(vm, fmt.Sprintf("graphQL query failed: %d - %s. Query: %s. Variables: %v", w.statusCode, w.r.String(), in.Query, in.Variables)) @@ -99,9 +103,9 @@ func gqlRequestFunc(vm *otto.Otto, gqlHandler http.HandlerFunc) func(call otto.F } } -func AddGQLAPI(vm *otto.Otto, gqlHandler http.HandlerFunc) { +func AddGQLAPI(vm *otto.Otto, cookie *http.Cookie, gqlHandler http.Handler) { gql, _ := vm.Object("({})") - gql.Set("Do", gqlRequestFunc(vm, gqlHandler)) + gql.Set("Do", gqlRequestFunc(vm, cookie, gqlHandler)) vm.Set("gql", gql) } diff --git a/pkg/plugin/js/log.go b/pkg/plugin/js/log.go index 30fe76825..35d23a537 100644 --- a/pkg/plugin/js/log.go +++ b/pkg/plugin/js/log.go @@ -8,6 +8,8 @@ import ( "github.com/stashapp/stash/pkg/logger" ) +const pluginPrefix = "[Plugin] " + func argToString(call otto.FunctionCall) string { arg := call.Argument(0) if arg.IsObject() { @@ -20,27 +22,27 @@ func argToString(call otto.FunctionCall) string { } func logTrace(call otto.FunctionCall) otto.Value { - logger.Trace(argToString(call)) + logger.Trace(pluginPrefix + argToString(call)) return otto.UndefinedValue() } func logDebug(call otto.FunctionCall) otto.Value { - logger.Debug(argToString(call)) + logger.Debug(pluginPrefix + argToString(call)) return otto.UndefinedValue() } func logInfo(call otto.FunctionCall) otto.Value { - logger.Info(argToString(call)) + logger.Info(pluginPrefix + argToString(call)) return otto.UndefinedValue() } func logWarn(call otto.FunctionCall) otto.Value { - logger.Warn(argToString(call)) + logger.Warn(pluginPrefix + argToString(call)) return otto.UndefinedValue() } func logError(call otto.FunctionCall) otto.Value { - logger.Error(argToString(call)) + logger.Error(pluginPrefix + argToString(call)) return otto.UndefinedValue() } diff --git a/pkg/plugin/plugins.go b/pkg/plugin/plugins.go index 1eb218f1c..c38ebcc5f 100644 --- a/pkg/plugin/plugins.go +++ b/pkg/plugin/plugins.go @@ -8,6 +8,7 @@ package plugin import ( + "context" "fmt" "net/http" "os" @@ -17,13 +18,16 @@ import ( "github.com/stashapp/stash/pkg/manager/config" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/plugin/common" + "github.com/stashapp/stash/pkg/session" + "github.com/stashapp/stash/pkg/utils" ) // Cache stores plugin details. type Cache struct { - config *config.Instance - plugins []Config - gqlHandler http.HandlerFunc + config *config.Instance + plugins []Config + sessionStore *session.Store + gqlHandler http.Handler } // NewCache returns a new Cache. @@ -39,10 +43,14 @@ func NewCache(config *config.Instance) *Cache { } } -func (c *Cache) RegisterGQLHandler(handler http.HandlerFunc) { +func (c *Cache) RegisterGQLHandler(handler http.Handler) { c.gqlHandler = handler } +func (c *Cache) RegisterSessionStore(sessionStore *session.Store) { + c.sessionStore = sessionStore +} + // LoadPlugins clears the plugin cache and loads from the plugin path. // In the event of an error during loading, the cache will be left empty. func (c *Cache) LoadPlugins() error { @@ -105,10 +113,38 @@ func (c Cache) ListPluginTasks() []*models.PluginTask { return ret } +func buildPluginInput(plugin *Config, operation *OperationConfig, serverConnection common.StashServerConnection, args []*models.PluginArgInput) common.PluginInput { + args = applyDefaultArgs(args, operation.DefaultArgs) + serverConnection.PluginDir = plugin.getConfigPath() + return common.PluginInput{ + ServerConnection: serverConnection, + Args: toPluginArgs(args), + } +} + +func (c Cache) makeServerConnection(ctx context.Context) common.StashServerConnection { + cookie := c.sessionStore.MakePluginCookie(ctx) + + serverConnection := common.StashServerConnection{ + Scheme: "http", + Port: c.config.GetPort(), + SessionCookie: cookie, + Dir: c.config.GetConfigPath(), + } + + if config.HasTLSConfig() { + serverConnection.Scheme = "https" + } + + return serverConnection +} + // CreateTask runs the plugin operation for the pluginID and operation // name provided. Returns an error if the plugin or the operation could not be // resolved. -func (c Cache) CreateTask(pluginID string, operationName string, serverConnection common.StashServerConnection, args []*models.PluginArgInput, progress chan float64) (Task, error) { +func (c Cache) CreateTask(ctx context.Context, pluginID string, operationName string, args []*models.PluginArgInput, progress chan float64) (Task, error) { + serverConnection := c.makeServerConnection(ctx) + // find the plugin and operation plugin := c.getPlugin(pluginID) @@ -122,16 +158,88 @@ func (c Cache) CreateTask(pluginID string, operationName string, serverConnectio } task := pluginTask{ - plugin: plugin, - operation: operation, - serverConnection: serverConnection, - args: args, - progress: progress, - gqlHandler: c.gqlHandler, + plugin: plugin, + operation: operation, + input: buildPluginInput(plugin, operation, serverConnection, args), + progress: progress, + gqlHandler: c.gqlHandler, } return task.createTask(), nil } +func (c Cache) ExecutePostHooks(ctx context.Context, id int, hookType HookTriggerEnum, input interface{}, inputFields []string) { + if err := c.executePostHooks(ctx, hookType, common.HookContext{ + ID: id, + Type: hookType.String(), + Input: input, + InputFields: inputFields, + }); err != nil { + logger.Errorf("error executing post hooks: %s", err.Error()) + } +} + +func (c Cache) executePostHooks(ctx context.Context, hookType HookTriggerEnum, hookContext common.HookContext) error { + visitedPlugins := session.GetVisitedPlugins(ctx) + + for _, p := range c.plugins { + hooks := p.getHooks(hookType) + // don't revisit a plugin we've already visited + // only log if there's hooks that we're skipping + if len(hooks) > 0 && utils.StrInclude(visitedPlugins, p.id) { + logger.Debugf("plugin ID '%s' already triggered, not re-triggering", p.id) + continue + } + + for _, h := range hooks { + newCtx := session.AddVisitedPlugin(ctx, p.id) + serverConnection := c.makeServerConnection(newCtx) + + pluginInput := buildPluginInput(&p, &h.OperationConfig, serverConnection, nil) + addHookContext(pluginInput.Args, hookContext) + + pt := pluginTask{ + plugin: &p, + operation: &h.OperationConfig, + input: pluginInput, + gqlHandler: c.gqlHandler, + } + + task := pt.createTask() + if err := task.Start(); err != nil { + return err + } + + // handle cancel from context + c := make(chan struct{}) + go func() { + task.Wait() + close(c) + }() + + select { + case <-ctx.Done(): + task.Stop() + return fmt.Errorf("operation cancelled") + case <-c: + // task finished normally + } + + output := task.GetResult() + if output == nil { + logger.Debugf("%s [%s]: returned no result", hookType.String(), p.Name) + } else { + if output.Error != nil { + logger.Errorf("%s [%s]: returned error: %s", hookType.String(), p.Name, *output.Error) + } else if output.Output != nil { + logger.Debugf("%s [%s]: returned: %v", hookType.String(), p.Name, output.Output) + } + } + } + } + + return nil +} + func (c Cache) getPlugin(pluginID string) *Config { for _, s := range c.plugins { if s.id == pluginID { diff --git a/pkg/plugin/raw.go b/pkg/plugin/raw.go index 0169b9bb6..fe44c2f6d 100644 --- a/pkg/plugin/raw.go +++ b/pkg/plugin/raw.go @@ -50,8 +50,7 @@ func (t *rawPluginTask) Start() error { go func() { defer stdin.Close() - input := t.buildPluginInput() - inBytes, _ := json.Marshal(input) + inBytes, _ := json.Marshal(t.input) io.WriteString(stdin, string(inBytes)) }() diff --git a/pkg/plugin/rpc.go b/pkg/plugin/rpc.go index f12b17cd8..dff9774c0 100644 --- a/pkg/plugin/rpc.go +++ b/pkg/plugin/rpc.go @@ -70,12 +70,10 @@ func (t *rpcPluginTask) Start() error { Client: t.client, } - input := t.buildPluginInput() - t.done = make(chan *rpc.Call, 1) result := common.PluginOutput{} t.waitGroup.Add(1) - iface.RunAsync(input, &result, t.done) + iface.RunAsync(t.input, &result, t.done) go t.waitToFinish(&result) t.started = true diff --git a/pkg/plugin/task.go b/pkg/plugin/task.go index c3df2e2ba..4b4a9d870 100644 --- a/pkg/plugin/task.go +++ b/pkg/plugin/task.go @@ -3,7 +3,6 @@ package plugin import ( "net/http" - "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/plugin/common" ) @@ -31,11 +30,10 @@ type taskBuilder interface { } type pluginTask struct { - plugin *Config - operation *OperationConfig - serverConnection common.StashServerConnection - args []*models.PluginArgInput - gqlHandler http.HandlerFunc + plugin *Config + operation *OperationConfig + input common.PluginInput + gqlHandler http.Handler progress chan float64 result *common.PluginOutput @@ -48,12 +46,3 @@ func (t *pluginTask) GetResult() *common.PluginOutput { func (t *pluginTask) createTask() Task { return t.plugin.Interface.getTaskBuilder().build(*t) } - -func (t *pluginTask) buildPluginInput() common.PluginInput { - args := applyDefaultArgs(t.args, t.operation.DefaultArgs) - t.serverConnection.PluginDir = t.plugin.getConfigPath() - return common.PluginInput{ - ServerConnection: t.serverConnection, - Args: toPluginArgs(args), - } -} diff --git a/pkg/session/session.go b/pkg/session/session.go new file mode 100644 index 000000000..55faa4282 --- /dev/null +++ b/pkg/session/session.go @@ -0,0 +1,240 @@ +package session + +import ( + "context" + "errors" + "net/http" + + "github.com/gorilla/securecookie" + "github.com/gorilla/sessions" + "github.com/stashapp/stash/pkg/logger" + "github.com/stashapp/stash/pkg/manager/config" + "github.com/stashapp/stash/pkg/utils" +) + +type key int + +const ( + contextUser key = iota + contextVisitedPlugins +) + +const ( + userIDKey = "userID" + visitedPluginsKey = "visitedPlugins" +) + +const ( + ApiKeyHeader = "ApiKey" + ApiKeyParameter = "apikey" +) + +const ( + cookieName = "session" + usernameFormKey = "username" + passwordFormKey = "password" +) + +var ErrInvalidCredentials = errors.New("invalid username or password") +var ErrUnauthorized = errors.New("unauthorized") + +type Store struct { + sessionStore *sessions.CookieStore + config *config.Instance +} + +func NewStore(c *config.Instance) *Store { + ret := &Store{ + sessionStore: sessions.NewCookieStore(config.GetInstance().GetSessionStoreKey()), + config: c, + } + + ret.sessionStore.MaxAge(config.GetInstance().GetMaxSessionAge()) + + return ret +} + +func (s *Store) Login(w http.ResponseWriter, r *http.Request) error { + // ignore error - we want a new session regardless + newSession, _ := s.sessionStore.Get(r, cookieName) + + username := r.FormValue(usernameFormKey) + password := r.FormValue(passwordFormKey) + + // authenticate the user + if !config.GetInstance().ValidateCredentials(username, password) { + return ErrInvalidCredentials + } + + newSession.Values[userIDKey] = username + + err := newSession.Save(r, w) + if err != nil { + return err + } + + return nil +} + +func (s *Store) Logout(w http.ResponseWriter, r *http.Request) error { + session, err := s.sessionStore.Get(r, cookieName) + if err != nil { + return err + } + + delete(session.Values, userIDKey) + session.Options.MaxAge = -1 + + err = session.Save(r, w) + if err != nil { + return err + } + + return nil +} + +func (s *Store) GetSessionUserID(w http.ResponseWriter, r *http.Request) (string, error) { + session, err := s.sessionStore.Get(r, cookieName) + // ignore errors and treat as an empty user id, so that we handle expired + // cookie + if err != nil { + return "", nil + } + + if !session.IsNew { + val := session.Values[userIDKey] + + // refresh the cookie + err = session.Save(r, w) + if err != nil { + return "", err + } + + ret, _ := val.(string) + + return ret, nil + } + + return "", nil +} + +func SetCurrentUserID(ctx context.Context, userID string) context.Context { + return context.WithValue(ctx, contextUser, userID) +} + +// GetCurrentUserID gets the current user id from the provided context +func GetCurrentUserID(ctx context.Context) *string { + userCtxVal := ctx.Value(contextUser) + if userCtxVal != nil { + currentUser := userCtxVal.(string) + return ¤tUser + } + + return nil +} + +func (s *Store) VisitedPluginHandler() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // get the visited plugins from the cookie and set in the context + session, err := s.sessionStore.Get(r, cookieName) + + // ignore errors + if err == nil { + val := session.Values[visitedPluginsKey] + + visitedPlugins, _ := val.([]string) + + ctx := setVisitedPlugins(r.Context(), visitedPlugins) + r = r.WithContext(ctx) + } + + next.ServeHTTP(w, r) + }) + } +} + +func GetVisitedPlugins(ctx context.Context) []string { + ctxVal := ctx.Value(contextVisitedPlugins) + if ctxVal != nil { + return ctxVal.([]string) + } + + return nil +} + +func AddVisitedPlugin(ctx context.Context, pluginID string) context.Context { + curVal := GetVisitedPlugins(ctx) + curVal = utils.StrAppendUnique(curVal, pluginID) + return setVisitedPlugins(ctx, curVal) +} + +func setVisitedPlugins(ctx context.Context, visitedPlugins []string) context.Context { + return context.WithValue(ctx, contextVisitedPlugins, visitedPlugins) +} + +func (s *Store) createSessionCookie(username string) (*http.Cookie, error) { + session := sessions.NewSession(s.sessionStore, cookieName) + session.Values[userIDKey] = username + + encoded, err := securecookie.EncodeMulti(session.Name(), session.Values, + s.sessionStore.Codecs...) + if err != nil { + return nil, err + } + + return sessions.NewCookie(session.Name(), encoded, session.Options), nil +} + +func (s *Store) MakePluginCookie(ctx context.Context) *http.Cookie { + currentUser := GetCurrentUserID(ctx) + visitedPlugins := GetVisitedPlugins(ctx) + + session := sessions.NewSession(s.sessionStore, cookieName) + if currentUser != nil { + session.Values[userIDKey] = *currentUser + } + + session.Values[visitedPluginsKey] = visitedPlugins + + encoded, err := securecookie.EncodeMulti(session.Name(), session.Values, + s.sessionStore.Codecs...) + if err != nil { + logger.Errorf("error creating session cookie: %s", err.Error()) + return nil + } + + return sessions.NewCookie(session.Name(), encoded, session.Options) +} + +func (s *Store) Authenticate(w http.ResponseWriter, r *http.Request) (userID string, err error) { + c := s.config + + // translate api key into current user, if present + apiKey := r.Header.Get(ApiKeyHeader) + + // try getting the api key as a query parameter + if apiKey == "" { + apiKey = r.URL.Query().Get(ApiKeyParameter) + } + + if apiKey != "" { + // match against configured API and set userID to the + // configured username. In future, we'll want to + // get the username from the key. + if c.GetAPIKey() != apiKey { + return "", ErrUnauthorized + } + + userID = c.GetUsername() + } else { + // handle session + userID, err = s.GetSessionUserID(w, r) + } + + if err != nil { + return "", err + } + + return +} diff --git a/pkg/utils/context.go b/pkg/utils/context.go new file mode 100644 index 000000000..06427bb5b --- /dev/null +++ b/pkg/utils/context.go @@ -0,0 +1,28 @@ +package utils + +import ( + "context" + "time" +) + +type valueOnlyContext struct { + context.Context +} + +func (valueOnlyContext) Deadline() (deadline time.Time, ok bool) { + return +} + +func (valueOnlyContext) Done() <-chan struct{} { + return nil +} + +func (valueOnlyContext) Err() error { + return nil +} + +func ValueOnlyContext(ctx context.Context) context.Context { + return valueOnlyContext{ + ctx, + } +} diff --git a/pkg/utils/oshash_internal_test.go b/pkg/utils/oshash_internal_test.go index a1dc6e67c..0263bab8a 100644 --- a/pkg/utils/oshash_internal_test.go +++ b/pkg/utils/oshash_internal_test.go @@ -23,7 +23,7 @@ func TestOshashEmpty(t *testing.T) { func TestOshashCollisions(t *testing.T) { buf1 := []byte("this is dumb") buf2 := []byte("dumb is this") - var size = int64(len(buf1)) + size := int64(len(buf1)) head := make([]byte, chunkSize) tail1 := make([]byte, chunkSize) diff --git a/ui/v2.5/src/components/Changelog/versions/v080.md b/ui/v2.5/src/components/Changelog/versions/v080.md index 9b3404986..b79f4ca3d 100644 --- a/ui/v2.5/src/components/Changelog/versions/v080.md +++ b/ui/v2.5/src/components/Changelog/versions/v080.md @@ -1,4 +1,5 @@ ### ✨ New Features +* Added support for triggering plugin tasks during operations. ([#1452](https://github.com/stashapp/stash/pull/1452)) * Support Studio filter including child studios. ([#1397](https://github.com/stashapp/stash/pull/1397)) * Added support for tag aliases. ([#1412](https://github.com/stashapp/stash/pull/1412)) * Support embedded Javascript plugins. ([#1393](https://github.com/stashapp/stash/pull/1393)) diff --git a/ui/v2.5/src/components/Settings/SettingsPluginsPanel.tsx b/ui/v2.5/src/components/Settings/SettingsPluginsPanel.tsx index 8e92c18cd..ee3022183 100644 --- a/ui/v2.5/src/components/Settings/SettingsPluginsPanel.tsx +++ b/ui/v2.5/src/components/Settings/SettingsPluginsPanel.tsx @@ -1,9 +1,10 @@ import React from "react"; import { Button } from "react-bootstrap"; +import * as GQL from "src/core/generated-graphql"; import { mutateReloadPlugins, usePlugins } from "src/core/StashService"; import { useToast } from "src/hooks"; import { TextUtils } from "src/utils"; -import { Icon, LoadingIndicator } from "src/components/Shared"; +import { CollapseButton, Icon, LoadingIndicator } from "src/components/Shared"; export const SettingsPluginsPanel: React.FC = () => { const Toast = useToast(); @@ -33,13 +34,14 @@ export const SettingsPluginsPanel: React.FC = () => { function renderPlugins() { const elements = (data?.plugins ?? []).map((plugin) => (
{hh}
+