Speed up file-based auto-tag

Replaces the per-file SQL QueryForAutoTag prefilter with an in-memory
2-rune prefix index over performers/studios/tags, preloaded once at job
start. Also:

  - runs file processing through job.TaskQueue so scenes/images/
    galleries tag in parallel instead of one file at a time
  - keyset-paginates the query loop so batch N+1 doesn't pay the
    O(offset) scan past large tables
  - bulk-loads studio/tag aliases via a new optional AllAliasLoader
    interface, avoiding N+1 GetAliases calls during preload
  - caches compiled name regexps (same candidate names repeat across
    thousands of files)
  - hoists strings.ToLower(path) and allASCII(path) out of the per-
    candidate match loop
  - opens a fresh write txn per applied match instead of holding one
    for every tagger phase

Tagger gains *AtPath methods that own the cache + txn manager, letting
the task code stay slim.
This commit is contained in:
abdusalam.dihan 2026-04-19 19:46:46 +01:00
parent 443de78260
commit cd64433dc5
17 changed files with 1465 additions and 206 deletions

View file

@ -7,6 +7,7 @@ import (
"github.com/stashapp/stash/pkg/gallery"
"github.com/stashapp/stash/pkg/match"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
)
type GalleryFinderUpdater interface {
@ -43,9 +44,11 @@ func getGalleryFileTagger(s *models.Gallery, cache *match.Cache) tagger {
}
}
// GalleryPerformers tags the provided gallery with performers whose name matches the gallery's path.
func GalleryPerformers(ctx context.Context, s *models.Gallery, rw GalleryPerformerUpdater, performerReader models.PerformerAutoTagQueryer, cache *match.Cache) error {
t := getGalleryFileTagger(s, cache)
// GalleryPerformersAtPath tags the provided gallery with performers whose
// name matches the gallery's path. A fresh write txn is opened only when a
// match is applied.
func (tagger *Tagger) GalleryPerformersAtPath(ctx context.Context, s *models.Gallery, rw GalleryPerformerUpdater, performerReader models.PerformerAutoTagQueryer) error {
t := getGalleryFileTagger(s, tagger.Cache)
return t.tagPerformers(ctx, performerReader, func(subjectID, otherID int) (bool, error) {
if err := s.LoadPerformerIDs(ctx, rw); err != nil {
@ -57,7 +60,9 @@ func GalleryPerformers(ctx context.Context, s *models.Gallery, rw GalleryPerform
return false, nil
}
if err := gallery.AddPerformer(ctx, rw, s, otherID); err != nil {
if err := txn.WithTxn(ctx, tagger.TxnManager, func(ctx context.Context) error {
return gallery.AddPerformer(ctx, rw, s, otherID)
}); err != nil {
return false, err
}
@ -65,25 +70,35 @@ func GalleryPerformers(ctx context.Context, s *models.Gallery, rw GalleryPerform
})
}
// GalleryStudios tags the provided gallery with the first studio whose name matches the gallery's path.
// GalleryStudiosAtPath tags the provided gallery with the first studio whose
// name matches the gallery's path.
//
// Gallerys will not be tagged if studio is already set.
func GalleryStudios(ctx context.Context, s *models.Gallery, rw GalleryFinderUpdater, studioReader models.StudioAutoTagQueryer, cache *match.Cache) error {
// Galleries will not be tagged if studio is already set.
func (tagger *Tagger) GalleryStudiosAtPath(ctx context.Context, s *models.Gallery, rw GalleryFinderUpdater, studioReader models.StudioAutoTagQueryer) error {
if s.StudioID != nil {
// don't modify
return nil
}
t := getGalleryFileTagger(s, cache)
t := getGalleryFileTagger(s, tagger.Cache)
return t.tagStudios(ctx, studioReader, func(subjectID, otherID int) (bool, error) {
return addGalleryStudio(ctx, rw, s, otherID)
var added bool
if err := txn.WithTxn(ctx, tagger.TxnManager, func(ctx context.Context) error {
var err error
added, err = addGalleryStudio(ctx, rw, s, otherID)
return err
}); err != nil {
return false, err
}
return added, nil
})
}
// GalleryTags tags the provided gallery with tags whose name matches the gallery's path.
func GalleryTags(ctx context.Context, s *models.Gallery, rw GalleryTagUpdater, tagReader models.TagAutoTagQueryer, cache *match.Cache) error {
t := getGalleryFileTagger(s, cache)
// GalleryTagsAtPath tags the provided gallery with tags whose name matches
// the gallery's path.
func (tagger *Tagger) GalleryTagsAtPath(ctx context.Context, s *models.Gallery, rw GalleryTagUpdater, tagReader models.TagAutoTagQueryer) error {
t := getGalleryFileTagger(s, tagger.Cache)
return t.tagTags(ctx, tagReader, func(subjectID, otherID int) (bool, error) {
if err := s.LoadTagIDs(ctx, rw); err != nil {
@ -95,7 +110,9 @@ func GalleryTags(ctx context.Context, s *models.Gallery, rw GalleryTagUpdater, t
return false, nil
}
if err := gallery.AddTag(ctx, rw, s, otherID); err != nil {
if err := txn.WithTxn(ctx, tagger.TxnManager, func(ctx context.Context) error {
return gallery.AddTag(ctx, rw, s, otherID)
}); err != nil {
return false, err
}

View file

@ -68,7 +68,7 @@ func TestGalleryPerformers(t *testing.T) {
return galleryPartialsEqual(got, expected)
})
db.Gallery.On("UpdatePartial", testCtx, galleryID, matchPartial).Return(nil, nil).Once()
db.Gallery.On("UpdatePartial", mock.Anything, galleryID, matchPartial).Return(nil, nil).Once()
}
gallery := models.Gallery{
@ -76,7 +76,8 @@ func TestGalleryPerformers(t *testing.T) {
Path: test.Path,
PerformerIDs: models.NewRelatedIDs([]int{}),
}
err := GalleryPerformers(testCtx, &gallery, db.Gallery, db.Performer, nil)
tagger := &Tagger{TxnManager: db, Cache: nil}
err := tagger.GalleryPerformersAtPath(testCtx, &gallery, db.Gallery, db.Performer)
assert.Nil(err)
db.AssertExpectations(t)
@ -114,14 +115,15 @@ func TestGalleryStudios(t *testing.T) {
return galleryPartialsEqual(got, expected)
})
db.Gallery.On("UpdatePartial", testCtx, galleryID, matchPartial).Return(nil, nil).Once()
db.Gallery.On("UpdatePartial", mock.Anything, galleryID, matchPartial).Return(nil, nil).Once()
}
gallery := models.Gallery{
ID: galleryID,
Path: test.Path,
}
err := GalleryStudios(testCtx, &gallery, db.Gallery, db.Studio, nil)
tagger := &Tagger{TxnManager: db, Cache: nil}
err := tagger.GalleryStudiosAtPath(testCtx, &gallery, db.Gallery, db.Studio)
assert.Nil(err)
db.AssertExpectations(t)
@ -189,7 +191,7 @@ func TestGalleryTags(t *testing.T) {
return galleryPartialsEqual(got, expected)
})
db.Gallery.On("UpdatePartial", testCtx, galleryID, matchPartial).Return(nil, nil).Once()
db.Gallery.On("UpdatePartial", mock.Anything, galleryID, matchPartial).Return(nil, nil).Once()
}
gallery := models.Gallery{
@ -197,7 +199,8 @@ func TestGalleryTags(t *testing.T) {
Path: test.Path,
TagIDs: models.NewRelatedIDs([]int{}),
}
err := GalleryTags(testCtx, &gallery, db.Gallery, db.Tag, nil)
tagger := &Tagger{TxnManager: db, Cache: nil}
err := tagger.GalleryTagsAtPath(testCtx, &gallery, db.Gallery, db.Tag)
assert.Nil(err)
db.AssertExpectations(t)

View file

@ -7,6 +7,7 @@ import (
"github.com/stashapp/stash/pkg/image"
"github.com/stashapp/stash/pkg/match"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
)
type ImageFinderUpdater interface {
@ -34,9 +35,11 @@ func getImageFileTagger(s *models.Image, cache *match.Cache) tagger {
}
}
// ImagePerformers tags the provided image with performers whose name matches the image's path.
func ImagePerformers(ctx context.Context, s *models.Image, rw ImagePerformerUpdater, performerReader models.PerformerAutoTagQueryer, cache *match.Cache) error {
t := getImageFileTagger(s, cache)
// ImagePerformersAtPath tags the provided image with performers whose name
// matches the image's path. A fresh write txn is opened only when a match is
// applied.
func (tagger *Tagger) ImagePerformersAtPath(ctx context.Context, s *models.Image, rw ImagePerformerUpdater, performerReader models.PerformerAutoTagQueryer) error {
t := getImageFileTagger(s, tagger.Cache)
return t.tagPerformers(ctx, performerReader, func(subjectID, otherID int) (bool, error) {
if err := s.LoadPerformerIDs(ctx, rw); err != nil {
@ -48,7 +51,9 @@ func ImagePerformers(ctx context.Context, s *models.Image, rw ImagePerformerUpda
return false, nil
}
if err := image.AddPerformer(ctx, rw, s, otherID); err != nil {
if err := txn.WithTxn(ctx, tagger.TxnManager, func(ctx context.Context) error {
return image.AddPerformer(ctx, rw, s, otherID)
}); err != nil {
return false, err
}
@ -56,25 +61,35 @@ func ImagePerformers(ctx context.Context, s *models.Image, rw ImagePerformerUpda
})
}
// ImageStudios tags the provided image with the first studio whose name matches the image's path.
// ImageStudiosAtPath tags the provided image with the first studio whose
// name matches the image's path.
//
// Images will not be tagged if studio is already set.
func ImageStudios(ctx context.Context, s *models.Image, rw ImageFinderUpdater, studioReader models.StudioAutoTagQueryer, cache *match.Cache) error {
func (tagger *Tagger) ImageStudiosAtPath(ctx context.Context, s *models.Image, rw ImageFinderUpdater, studioReader models.StudioAutoTagQueryer) error {
if s.StudioID != nil {
// don't modify
return nil
}
t := getImageFileTagger(s, cache)
t := getImageFileTagger(s, tagger.Cache)
return t.tagStudios(ctx, studioReader, func(subjectID, otherID int) (bool, error) {
return addImageStudio(ctx, rw, s, otherID)
var added bool
if err := txn.WithTxn(ctx, tagger.TxnManager, func(ctx context.Context) error {
var err error
added, err = addImageStudio(ctx, rw, s, otherID)
return err
}); err != nil {
return false, err
}
return added, nil
})
}
// ImageTags tags the provided image with tags whose name matches the image's path.
func ImageTags(ctx context.Context, s *models.Image, rw ImageTagUpdater, tagReader models.TagAutoTagQueryer, cache *match.Cache) error {
t := getImageFileTagger(s, cache)
// ImageTagsAtPath tags the provided image with tags whose name matches the
// image's path.
func (tagger *Tagger) ImageTagsAtPath(ctx context.Context, s *models.Image, rw ImageTagUpdater, tagReader models.TagAutoTagQueryer) error {
t := getImageFileTagger(s, tagger.Cache)
return t.tagTags(ctx, tagReader, func(subjectID, otherID int) (bool, error) {
if err := s.LoadTagIDs(ctx, rw); err != nil {
@ -86,7 +101,9 @@ func ImageTags(ctx context.Context, s *models.Image, rw ImageTagUpdater, tagRead
return false, nil
}
if err := image.AddTag(ctx, rw, s, otherID); err != nil {
if err := txn.WithTxn(ctx, tagger.TxnManager, func(ctx context.Context) error {
return image.AddTag(ctx, rw, s, otherID)
}); err != nil {
return false, err
}

View file

@ -65,7 +65,7 @@ func TestImagePerformers(t *testing.T) {
return imagePartialsEqual(got, expected)
})
db.Image.On("UpdatePartial", testCtx, imageID, matchPartial).Return(nil, nil).Once()
db.Image.On("UpdatePartial", mock.Anything, imageID, matchPartial).Return(nil, nil).Once()
}
image := models.Image{
@ -73,7 +73,8 @@ func TestImagePerformers(t *testing.T) {
Path: test.Path,
PerformerIDs: models.NewRelatedIDs([]int{}),
}
err := ImagePerformers(testCtx, &image, db.Image, db.Performer, nil)
tagger := &Tagger{TxnManager: db, Cache: nil}
err := tagger.ImagePerformersAtPath(testCtx, &image, db.Image, db.Performer)
assert.Nil(err)
db.AssertExpectations(t)
@ -111,14 +112,15 @@ func TestImageStudios(t *testing.T) {
return imagePartialsEqual(got, expected)
})
db.Image.On("UpdatePartial", testCtx, imageID, matchPartial).Return(nil, nil).Once()
db.Image.On("UpdatePartial", mock.Anything, imageID, matchPartial).Return(nil, nil).Once()
}
image := models.Image{
ID: imageID,
Path: test.Path,
}
err := ImageStudios(testCtx, &image, db.Image, db.Studio, nil)
tagger := &Tagger{TxnManager: db, Cache: nil}
err := tagger.ImageStudiosAtPath(testCtx, &image, db.Image, db.Studio)
assert.Nil(err)
db.AssertExpectations(t)
@ -186,7 +188,7 @@ func TestImageTags(t *testing.T) {
return imagePartialsEqual(got, expected)
})
db.Image.On("UpdatePartial", testCtx, imageID, matchPartial).Return(nil, nil).Once()
db.Image.On("UpdatePartial", mock.Anything, imageID, matchPartial).Return(nil, nil).Once()
}
image := models.Image{
@ -194,7 +196,8 @@ func TestImageTags(t *testing.T) {
Path: test.Path,
TagIDs: models.NewRelatedIDs([]int{}),
}
err := ImageTags(testCtx, &image, db.Image, db.Tag, nil)
tagger := &Tagger{TxnManager: db, Cache: nil}
err := tagger.ImageTagsAtPath(testCtx, &image, db.Image, db.Tag)
assert.Nil(err)
db.AssertExpectations(t)

View file

@ -7,6 +7,7 @@ import (
"github.com/stashapp/stash/pkg/match"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scene"
"github.com/stashapp/stash/pkg/txn"
)
type SceneFinderUpdater interface {
@ -34,9 +35,12 @@ func getSceneFileTagger(s *models.Scene, cache *match.Cache) tagger {
}
}
// ScenePerformers tags the provided scene with performers whose name matches the scene's path.
func ScenePerformers(ctx context.Context, s *models.Scene, rw ScenePerformerUpdater, performerReader models.PerformerAutoTagQueryer, cache *match.Cache) error {
t := getSceneFileTagger(s, cache)
// ScenePerformersAtPath tags the provided scene with performers whose name
// matches the scene's path. The match phase runs using the current context
// (no outer write txn needed); a fresh write txn is opened only when a match
// is applied.
func (tagger *Tagger) ScenePerformersAtPath(ctx context.Context, s *models.Scene, rw ScenePerformerUpdater, performerReader models.PerformerAutoTagQueryer) error {
t := getSceneFileTagger(s, tagger.Cache)
return t.tagPerformers(ctx, performerReader, func(subjectID, otherID int) (bool, error) {
if err := s.LoadPerformerIDs(ctx, rw); err != nil {
@ -48,7 +52,9 @@ func ScenePerformers(ctx context.Context, s *models.Scene, rw ScenePerformerUpda
return false, nil
}
if err := scene.AddPerformer(ctx, rw, s, otherID); err != nil {
if err := txn.WithTxn(ctx, tagger.TxnManager, func(ctx context.Context) error {
return scene.AddPerformer(ctx, rw, s, otherID)
}); err != nil {
return false, err
}
@ -56,25 +62,35 @@ func ScenePerformers(ctx context.Context, s *models.Scene, rw ScenePerformerUpda
})
}
// SceneStudios tags the provided scene with the first studio whose name matches the scene's path.
// SceneStudiosAtPath tags the provided scene with the first studio whose name
// matches the scene's path.
//
// Scenes will not be tagged if studio is already set.
func SceneStudios(ctx context.Context, s *models.Scene, rw SceneFinderUpdater, studioReader models.StudioAutoTagQueryer, cache *match.Cache) error {
func (tagger *Tagger) SceneStudiosAtPath(ctx context.Context, s *models.Scene, rw SceneFinderUpdater, studioReader models.StudioAutoTagQueryer) error {
if s.StudioID != nil {
// don't modify
return nil
}
t := getSceneFileTagger(s, cache)
t := getSceneFileTagger(s, tagger.Cache)
return t.tagStudios(ctx, studioReader, func(subjectID, otherID int) (bool, error) {
return addSceneStudio(ctx, rw, s, otherID)
var added bool
if err := txn.WithTxn(ctx, tagger.TxnManager, func(ctx context.Context) error {
var err error
added, err = addSceneStudio(ctx, rw, s, otherID)
return err
}); err != nil {
return false, err
}
return added, nil
})
}
// SceneTags tags the provided scene with tags whose name matches the scene's path.
func SceneTags(ctx context.Context, s *models.Scene, rw SceneTagUpdater, tagReader models.TagAutoTagQueryer, cache *match.Cache) error {
t := getSceneFileTagger(s, cache)
// SceneTagsAtPath tags the provided scene with tags whose name matches the
// scene's path.
func (tagger *Tagger) SceneTagsAtPath(ctx context.Context, s *models.Scene, rw SceneTagUpdater, tagReader models.TagAutoTagQueryer) error {
t := getSceneFileTagger(s, tagger.Cache)
return t.tagTags(ctx, tagReader, func(subjectID, otherID int) (bool, error) {
if err := s.LoadTagIDs(ctx, rw); err != nil {
@ -86,7 +102,9 @@ func SceneTags(ctx context.Context, s *models.Scene, rw SceneTagUpdater, tagRead
return false, nil
}
if err := scene.AddTag(ctx, rw, s, otherID); err != nil {
if err := txn.WithTxn(ctx, tagger.TxnManager, func(ctx context.Context) error {
return scene.AddTag(ctx, rw, s, otherID)
}); err != nil {
return false, err
}

View file

@ -204,10 +204,11 @@ func TestScenePerformers(t *testing.T) {
return scenePartialsEqual(got, expected)
})
db.Scene.On("UpdatePartial", testCtx, sceneID, matchPartial).Return(nil, nil).Once()
db.Scene.On("UpdatePartial", mock.Anything, sceneID, matchPartial).Return(nil, nil).Once()
}
err := ScenePerformers(testCtx, &scene, db.Scene, db.Performer, nil)
tagger := &Tagger{TxnManager: db, Cache: nil}
err := tagger.ScenePerformersAtPath(testCtx, &scene, db.Scene, db.Performer)
assert.Nil(err)
db.AssertExpectations(t)
@ -247,14 +248,15 @@ func TestSceneStudios(t *testing.T) {
return scenePartialsEqual(got, expected)
})
db.Scene.On("UpdatePartial", testCtx, sceneID, matchPartial).Return(nil, nil).Once()
db.Scene.On("UpdatePartial", mock.Anything, sceneID, matchPartial).Return(nil, nil).Once()
}
scene := models.Scene{
ID: sceneID,
Path: test.Path,
}
err := SceneStudios(testCtx, &scene, db.Scene, db.Studio, nil)
tagger := &Tagger{TxnManager: db, Cache: nil}
err := tagger.SceneStudiosAtPath(testCtx, &scene, db.Scene, db.Studio)
assert.Nil(err)
db.AssertExpectations(t)
@ -322,7 +324,7 @@ func TestSceneTags(t *testing.T) {
return scenePartialsEqual(got, expected)
})
db.Scene.On("UpdatePartial", testCtx, sceneID, matchPartial).Return(nil, nil).Once()
db.Scene.On("UpdatePartial", mock.Anything, sceneID, matchPartial).Return(nil, nil).Once()
}
scene := models.Scene{
@ -330,7 +332,8 @@ func TestSceneTags(t *testing.T) {
Path: test.Path,
TagIDs: models.NewRelatedIDs([]int{}),
}
err := SceneTags(testCtx, &scene, db.Scene, db.Tag, nil)
tagger := &Tagger{TxnManager: db, Cache: nil}
err := tagger.SceneTagsAtPath(testCtx, &scene, db.Scene, db.Tag)
assert.Nil(err)
db.AssertExpectations(t)

View file

@ -6,10 +6,10 @@ import (
"path/filepath"
"strconv"
"strings"
"sync"
"time"
"github.com/stashapp/stash/internal/autotag"
"github.com/stashapp/stash/internal/manager/config"
"github.com/stashapp/stash/pkg/image"
"github.com/stashapp/stash/pkg/job"
"github.com/stashapp/stash/pkg/logger"
@ -51,6 +51,35 @@ func (j *autoTagJob) isFileBasedAutoTag(input AutoTagMetadataInput) bool {
}
func (j *autoTagJob) autoTagFiles(ctx context.Context, progress *job.Progress, paths []string, performers, studios, tags bool) {
// Preload entity sets once. Each worker then matches against the
// in-memory set instead of paying a QueryForAutoTag roundtrip per file.
r := j.repository
preloadBegin := time.Now()
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
if performers {
if err := j.cache.PreloadPerformers(ctx, r.Performer); err != nil {
return fmt.Errorf("preloading performers: %w", err)
}
}
if studios {
if err := j.cache.PreloadStudios(ctx, r.Studio); err != nil {
return fmt.Errorf("preloading studios: %w", err)
}
}
if tags {
if err := j.cache.PreloadTags(ctx, r.Tag); err != nil {
return fmt.Errorf("preloading tags: %w", err)
}
}
return nil
}); err != nil {
if !job.IsCancelled(ctx) {
logger.Errorf("auto-tag preload error: %v", err)
}
return
}
logger.Infof("Preloaded auto-tag entities in %s", time.Since(preloadBegin))
t := autoTagFilesTask{
paths: paths,
performers: performers,
@ -545,25 +574,40 @@ func (t *autoTagFilesTask) processScenes(ctx context.Context) {
return
}
logger.Info("Auto-tagging scenes...")
workers := config.GetInstance().GetParallelTasksWithAutoDetection()
logger.Infof("Auto-tagging scenes (workers=%d)...", workers)
batchSize := 1000
const batchSize = 1000
const queueSize = batchSize * 4
findFilter := models.BatchFindFilter(batchSize)
findFilter := models.KeysetFindFilter(batchSize)
sceneFilter := t.makeSceneFilter()
r := t.repository
more := true
for more {
taskQueue := job.NewTaskQueue(ctx, t.progress, queueSize, workers)
defer taskQueue.Close()
var lastID, processed int
for {
filter := sceneFilter
if lastID != 0 {
filter = &models.SceneFilterType{
ID: &models.IntCriterionInput{
Value: lastID,
Modifier: models.CriterionModifierGreaterThan,
},
}
filter.And = sceneFilter
}
var scenes []*models.Scene
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
var err error
scenes, err = scene.Query(ctx, r.Scene, sceneFilter, findFilter)
scenes, err = scene.Query(ctx, r.Scene, filter, findFilter)
return err
}); err != nil {
if !job.IsCancelled(ctx) {
logger.Errorf("error querying scenes for auto-tag: %w", err)
logger.Errorf("error querying scenes for auto-tag: %v", err)
}
return
}
@ -573,32 +617,28 @@ func (t *autoTagFilesTask) processScenes(ctx context.Context) {
logger.Info("Stopping auto-tag due to user request")
return
}
tt := autoTagSceneTask{
repository: r,
scene: ss,
performers: t.performers,
studios: t.studios,
tags: t.tags,
cache: t.cache,
}
var wg sync.WaitGroup
wg.Add(1)
go tt.Start(ctx, &wg)
wg.Wait()
t.progress.Increment()
taskQueue.Add(fmt.Sprintf("Auto-tagging %s", ss.DisplayName()), func(ctx context.Context) {
tt := autoTagSceneTask{
repository: r,
scene: ss,
performers: t.performers,
studios: t.studios,
tags: t.tags,
cache: t.cache,
}
tt.Start(ctx)
t.progress.Increment()
})
}
if len(scenes) != batchSize {
more = false
} else {
*findFilter.Page++
if len(scenes) < batchSize {
return
}
if *findFilter.Page%10 == 1 {
logger.Infof("Processed %d scenes...", (*findFilter.Page-1)*batchSize)
}
lastID = scenes[len(scenes)-1].ID
processed += len(scenes)
if processed%(batchSize*10) == 0 {
logger.Infof("Processed %d scenes...", processed)
}
}
}
@ -608,25 +648,40 @@ func (t *autoTagFilesTask) processImages(ctx context.Context) {
return
}
logger.Info("Auto-tagging images...")
workers := config.GetInstance().GetParallelTasksWithAutoDetection()
logger.Infof("Auto-tagging images (workers=%d)...", workers)
batchSize := 1000
const batchSize = 1000
const queueSize = batchSize * 4
findFilter := models.BatchFindFilter(batchSize)
findFilter := models.KeysetFindFilter(batchSize)
imageFilter := t.makeImageFilter()
r := t.repository
more := true
for more {
taskQueue := job.NewTaskQueue(ctx, t.progress, queueSize, workers)
defer taskQueue.Close()
var lastID, processed int
for {
filter := imageFilter
if lastID != 0 {
filter = &models.ImageFilterType{
ID: &models.IntCriterionInput{
Value: lastID,
Modifier: models.CriterionModifierGreaterThan,
},
}
filter.And = imageFilter
}
var images []*models.Image
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
var err error
images, err = image.Query(ctx, r.Image, imageFilter, findFilter)
images, err = image.Query(ctx, r.Image, filter, findFilter)
return err
}); err != nil {
if !job.IsCancelled(ctx) {
logger.Errorf("error querying images for auto-tag: %w", err)
logger.Errorf("error querying images for auto-tag: %v", err)
}
return
}
@ -636,32 +691,28 @@ func (t *autoTagFilesTask) processImages(ctx context.Context) {
logger.Info("Stopping auto-tag due to user request")
return
}
tt := autoTagImageTask{
repository: t.repository,
image: ss,
performers: t.performers,
studios: t.studios,
tags: t.tags,
cache: t.cache,
}
var wg sync.WaitGroup
wg.Add(1)
go tt.Start(ctx, &wg)
wg.Wait()
t.progress.Increment()
taskQueue.Add(fmt.Sprintf("Auto-tagging %s", ss.DisplayName()), func(ctx context.Context) {
tt := autoTagImageTask{
repository: r,
image: ss,
performers: t.performers,
studios: t.studios,
tags: t.tags,
cache: t.cache,
}
tt.Start(ctx)
t.progress.Increment()
})
}
if len(images) != batchSize {
more = false
} else {
*findFilter.Page++
if len(images) < batchSize {
return
}
if *findFilter.Page%10 == 1 {
logger.Infof("Processed %d images...", (*findFilter.Page-1)*batchSize)
}
lastID = images[len(images)-1].ID
processed += len(images)
if processed%(batchSize*10) == 0 {
logger.Infof("Processed %d images...", processed)
}
}
}
@ -671,25 +722,40 @@ func (t *autoTagFilesTask) processGalleries(ctx context.Context) {
return
}
logger.Info("Auto-tagging galleries...")
workers := config.GetInstance().GetParallelTasksWithAutoDetection()
logger.Infof("Auto-tagging galleries (workers=%d)...", workers)
batchSize := 1000
const batchSize = 1000
const queueSize = batchSize * 4
findFilter := models.BatchFindFilter(batchSize)
findFilter := models.KeysetFindFilter(batchSize)
galleryFilter := t.makeGalleryFilter()
r := t.repository
more := true
for more {
taskQueue := job.NewTaskQueue(ctx, t.progress, queueSize, workers)
defer taskQueue.Close()
var lastID, processed int
for {
filter := galleryFilter
if lastID != 0 {
filter = &models.GalleryFilterType{
ID: &models.IntCriterionInput{
Value: lastID,
Modifier: models.CriterionModifierGreaterThan,
},
}
filter.And = galleryFilter
}
var galleries []*models.Gallery
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
var err error
galleries, _, err = r.Gallery.Query(ctx, galleryFilter, findFilter)
galleries, _, err = r.Gallery.Query(ctx, filter, findFilter)
return err
}); err != nil {
if !job.IsCancelled(ctx) {
logger.Errorf("error querying galleries for auto-tag: %w", err)
logger.Errorf("error querying galleries for auto-tag: %v", err)
}
return
}
@ -699,32 +765,28 @@ func (t *autoTagFilesTask) processGalleries(ctx context.Context) {
logger.Info("Stopping auto-tag due to user request")
return
}
tt := autoTagGalleryTask{
repository: t.repository,
gallery: ss,
performers: t.performers,
studios: t.studios,
tags: t.tags,
cache: t.cache,
}
var wg sync.WaitGroup
wg.Add(1)
go tt.Start(ctx, &wg)
wg.Wait()
t.progress.Increment()
taskQueue.Add(fmt.Sprintf("Auto-tagging %s", ss.DisplayName()), func(ctx context.Context) {
tt := autoTagGalleryTask{
repository: r,
gallery: ss,
performers: t.performers,
studios: t.studios,
tags: t.tags,
cache: t.cache,
}
tt.Start(ctx)
t.progress.Increment()
})
}
if len(galleries) != batchSize {
more = false
} else {
*findFilter.Page++
if len(galleries) < batchSize {
return
}
if *findFilter.Page%10 == 1 {
logger.Infof("Processed %d galleries...", (*findFilter.Page-1)*batchSize)
}
lastID = galleries[len(galleries)-1].ID
processed += len(galleries)
if processed%(batchSize*10) == 0 {
logger.Infof("Processed %d galleries...", processed)
}
}
}
@ -763,27 +825,28 @@ type autoTagSceneTask struct {
cache *match.Cache
}
func (t *autoTagSceneTask) Start(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done()
func (t *autoTagSceneTask) Start(ctx context.Context) {
r := t.repository
if err := r.WithTxn(ctx, func(ctx context.Context) error {
tagger := &autotag.Tagger{TxnManager: r.TxnManager, Cache: t.cache}
if err := r.WithDB(ctx, func(ctx context.Context) error {
if t.scene.Path == "" {
// nothing to do
return nil
}
if t.performers {
if err := autotag.ScenePerformers(ctx, t.scene, r.Scene, r.Performer, t.cache); err != nil {
if err := tagger.ScenePerformersAtPath(ctx, t.scene, r.Scene, r.Performer); err != nil {
return fmt.Errorf("tagging scene performers for %s: %v", t.scene.DisplayName(), err)
}
}
if t.studios {
if err := autotag.SceneStudios(ctx, t.scene, r.Scene, r.Studio, t.cache); err != nil {
if err := tagger.SceneStudiosAtPath(ctx, t.scene, r.Scene, r.Studio); err != nil {
return fmt.Errorf("tagging scene studio for %s: %v", t.scene.DisplayName(), err)
}
}
if t.tags {
if err := autotag.SceneTags(ctx, t.scene, r.Scene, r.Tag, t.cache); err != nil {
if err := tagger.SceneTagsAtPath(ctx, t.scene, r.Scene, r.Tag); err != nil {
return fmt.Errorf("tagging scene tags for %s: %v", t.scene.DisplayName(), err)
}
}
@ -807,22 +870,23 @@ type autoTagImageTask struct {
cache *match.Cache
}
func (t *autoTagImageTask) Start(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done()
func (t *autoTagImageTask) Start(ctx context.Context) {
r := t.repository
if err := r.WithTxn(ctx, func(ctx context.Context) error {
tagger := &autotag.Tagger{TxnManager: r.TxnManager, Cache: t.cache}
if err := r.WithDB(ctx, func(ctx context.Context) error {
if t.performers {
if err := autotag.ImagePerformers(ctx, t.image, r.Image, r.Performer, t.cache); err != nil {
if err := tagger.ImagePerformersAtPath(ctx, t.image, r.Image, r.Performer); err != nil {
return fmt.Errorf("tagging image performers for %s: %v", t.image.DisplayName(), err)
}
}
if t.studios {
if err := autotag.ImageStudios(ctx, t.image, r.Image, r.Studio, t.cache); err != nil {
if err := tagger.ImageStudiosAtPath(ctx, t.image, r.Image, r.Studio); err != nil {
return fmt.Errorf("tagging image studio for %s: %v", t.image.DisplayName(), err)
}
}
if t.tags {
if err := autotag.ImageTags(ctx, t.image, r.Image, r.Tag, t.cache); err != nil {
if err := tagger.ImageTagsAtPath(ctx, t.image, r.Image, r.Tag); err != nil {
return fmt.Errorf("tagging image tags for %s: %v", t.image.DisplayName(), err)
}
}
@ -846,22 +910,23 @@ type autoTagGalleryTask struct {
cache *match.Cache
}
func (t *autoTagGalleryTask) Start(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done()
func (t *autoTagGalleryTask) Start(ctx context.Context) {
r := t.repository
if err := r.WithTxn(ctx, func(ctx context.Context) error {
tagger := &autotag.Tagger{TxnManager: r.TxnManager, Cache: t.cache}
if err := r.WithDB(ctx, func(ctx context.Context) error {
if t.performers {
if err := autotag.GalleryPerformers(ctx, t.gallery, r.Gallery, r.Performer, t.cache); err != nil {
if err := tagger.GalleryPerformersAtPath(ctx, t.gallery, r.Gallery, r.Performer); err != nil {
return fmt.Errorf("tagging gallery performers for %s: %v", t.gallery.DisplayName(), err)
}
}
if t.studios {
if err := autotag.GalleryStudios(ctx, t.gallery, r.Gallery, r.Studio, t.cache); err != nil {
if err := tagger.GalleryStudiosAtPath(ctx, t.gallery, r.Gallery, r.Studio); err != nil {
return fmt.Errorf("tagging gallery studio for %s: %v", t.gallery.DisplayName(), err)
}
}
if t.tags {
if err := autotag.GalleryTags(ctx, t.gallery, r.Gallery, r.Tag, t.cache); err != nil {
if err := tagger.GalleryTagsAtPath(ctx, t.gallery, r.Gallery, r.Tag); err != nil {
return fmt.Errorf("tagging gallery tags for %s: %v", t.gallery.DisplayName(), err)
}
}

View file

@ -2,17 +2,362 @@ package match
import (
"context"
"regexp"
"strings"
"sync"
lru "github.com/hashicorp/golang-lru/v2"
"github.com/stashapp/stash/pkg/models"
)
// regexpCacheSize bounds the compiled-regexp LRU. Sized generously so that
// for realistic libraries (up to ~100 k performers/studios/tags combined,
// each optionally with a unicode and ASCII variant) the cache never evicts
// during one auto-tag job. LRU is used for consistency with
// pkg/sqlite/regex.go; eviction only kicks in for libraries far past that.
const regexpCacheSize = 200_000
const singleFirstCharacterRegex = `^[\p{L}][.\-_ ]`
// Cache is used to cache queries that should not change across an autotag process.
var singleFirstCharacterRE = regexp.MustCompile(singleFirstCharacterRegex)
// firstTwoRunesLower returns the first two runes of s, lowercased. Returns
// "" if s has fewer than two runes. Mirrors what getPathWords produces for
// path words, so the two can be compared as index keys.
//
// Only the first two runes are ever lowercased: strings.ToLower applies
// Unicode simple (1:1, rune-count-preserving) case mapping, so lowering a
// two-rune prefix is equivalent to taking the two-rune prefix of the fully
// lowered string — without allocating a lowered copy and []rune slice of
// the whole input on this hot per-name path.
func firstTwoRunesLower(s string) string {
	seen := 0
	for i := range s {
		if seen == 2 {
			// i is the byte offset of the third rune, so s[:i] is
			// exactly the first two runes.
			return strings.ToLower(s[:i])
		}
		seen++
	}
	if seen < 2 {
		return ""
	}
	// s is exactly two runes long.
	return strings.ToLower(s)
}
// performerCandidates returns the set of preloaded performers that should
// be regex-checked for the given path words. Mirrors the SQL
// `name LIKE 'xx%' OR name LIKE 'yy%'` prefilter, plus always-check
// performers whose name begins with a single-letter word (which the 2-rune
// prefix lookup can't reach).
func (c *Cache) performerCandidates(pathWords []string) []*models.Performer {
	if len(c.performerByPrefix) == 0 && len(c.performerAlwaysCheck) == 0 {
		return nil
	}

	visited := make(map[int]bool, len(pathWords)*2)
	candidates := make([]*models.Performer, 0, len(pathWords)*2)

	// appendUnique adds p once, preserving first-seen order.
	appendUnique := func(p *models.Performer) {
		if visited[p.ID] {
			return
		}
		visited[p.ID] = true
		candidates = append(candidates, p)
	}

	for _, word := range pathWords {
		for _, p := range c.performerByPrefix[strings.ToLower(word)] {
			appendUnique(p)
		}
	}
	// Always-check entries come last, after all prefix hits.
	for _, p := range c.performerAlwaysCheck {
		appendUnique(p)
	}
	return candidates
}
// studioCandidates is the studio analogue of performerCandidates: prefix
// lookup over path words plus the always-checked single-letter entries,
// deduped by studio ID.
func (c *Cache) studioCandidates(pathWords []string) []cachedStudio {
	if len(c.studioByPrefix) == 0 && len(c.studioAlwaysCheck) == 0 {
		return nil
	}

	seen := make(map[int]bool, len(pathWords)*2)
	out := make([]cachedStudio, 0, len(pathWords)*2)
	add := func(cs cachedStudio) {
		if seen[cs.Studio.ID] {
			return
		}
		seen[cs.Studio.ID] = true
		out = append(out, cs)
	}

	for _, word := range pathWords {
		for _, cs := range c.studioByPrefix[strings.ToLower(word)] {
			add(cs)
		}
	}
	for _, cs := range c.studioAlwaysCheck {
		add(cs)
	}
	return out
}
// tagCandidates is the tag analogue of performerCandidates: prefix lookup
// over path words plus the always-checked single-letter entries, deduped
// by tag ID.
func (c *Cache) tagCandidates(pathWords []string) []cachedTag {
	if len(c.tagByPrefix) == 0 && len(c.tagAlwaysCheck) == 0 {
		return nil
	}

	seen := make(map[int]bool, len(pathWords)*2)
	out := make([]cachedTag, 0, len(pathWords)*2)
	add := func(ct cachedTag) {
		if seen[ct.Tag.ID] {
			return
		}
		seen[ct.Tag.ID] = true
		out = append(out, ct)
	}

	for _, word := range pathWords {
		for _, ct := range c.tagByPrefix[strings.ToLower(word)] {
			add(ct)
		}
	}
	for _, ct := range c.tagAlwaysCheck {
		add(ct)
	}
	return out
}
// Cache is used to cache queries that should not change across an autotag
// process. Safe for concurrent use by multiple goroutines for reads; the
// PreloadX methods themselves are not synchronized, so preloading must
// finish before the cache is shared across workers. The zero value is
// usable: everything is loaded lazily or via the PreloadX methods.
type Cache struct {
	performersOnce sync.Once
	performersErr  error
	studiosOnce    sync.Once
	studiosErr     error
	tagsOnce       sync.Once
	tagsErr        error

	// Lazily-loaded lists of entities whose names start with a
	// single-character word (see getSingleLetterPerformers and friends).
	singleCharPerformers []*models.Performer
	singleCharStudios    []*models.Studio
	singleCharTags       []*models.Tag

	// Preloaded candidate sets. When populated (via PreloadX), the
	// PathTo* functions skip the per-path QueryForAutoTag DB roundtrip
	// and consult the in-memory prefix index instead. Nil means
	// "not preloaded, fall back to the old SQL-prefilter path".
	allPerformers []*models.Performer
	allStudios    []cachedStudio
	allTags       []cachedTag

	// Prefix indexes built at preload time. Map key is the first two
	// lowercased runes of name (or alias, for studios/tags). The
	// alwaysCheck slices hold entries whose first "word" is a
	// single letter — they wouldn't be reached by 2-rune path word
	// lookup, so they must always be checked (mirroring the existing
	// single-letter regex query).
	performerByPrefix    map[string][]*models.Performer
	performerAlwaysCheck []*models.Performer
	studioByPrefix       map[string][]cachedStudio
	studioAlwaysCheck    []cachedStudio
	tagByPrefix          map[string][]cachedTag
	tagAlwaysCheck       []cachedTag

	// Compiled-regexp LRU, created lazily on first nameRegexp call.
	regexpCacheOnce sync.Once
	regexpCache     *lru.Cache[regexpCacheKey, *regexp.Regexp]
}
// cachedStudio bundles a studio with its (possibly nil) alias list so
// PathToStudio can match against both without an N+1 GetAliases query
// per candidate.
type cachedStudio struct {
	Studio  *models.Studio
	Aliases []string
}
// cachedTag bundles a tag with its (possibly nil) alias list so PathToTags
// can match against both without an N+1 GetAliases query per candidate.
type cachedTag struct {
	Tag     *models.Tag
	Aliases []string
}
// PreloadPerformers loads every non-ignored performer into the cache and
// builds the 2-rune prefix index consulted by PathToPerformers, so later
// calls skip both the per-path QueryForAutoTag and the per-candidate
// regex when no prefix matches. Idempotent; must complete before the
// cache is shared across goroutines.
func (c *Cache) PreloadPerformers(ctx context.Context, reader models.PerformerAutoTagQueryer) error {
	if c.allPerformers != nil {
		// Already preloaded.
		return nil
	}

	ignoreAutoTag := false
	perPage := -1 // all rows in one page
	performerFilter := &models.PerformerFilterType{IgnoreAutoTag: &ignoreAutoTag}
	findFilter := &models.FindFilterType{PerPage: &perPage}

	perfs, _, err := reader.Query(ctx, performerFilter, findFilter)
	if err != nil {
		return err
	}
	if perfs == nil {
		// Keep the sentinel non-nil so the idempotence check above holds
		// even for an empty library.
		perfs = []*models.Performer{}
	}

	index := make(map[string][]*models.Performer, len(perfs))
	for _, p := range perfs {
		if prefix := firstTwoRunesLower(p.Name); prefix != "" {
			index[prefix] = append(index[prefix], p)
		}
		// Names starting with a single-letter word can't be reached via
		// a path word's 2-rune prefix; always regex-check those.
		if singleFirstCharacterRE.MatchString(p.Name) {
			c.performerAlwaysCheck = append(c.performerAlwaysCheck, p)
		}
	}
	c.performerByPrefix = index
	c.allPerformers = perfs
	return nil
}
// loadAllAliases fetches aliases for the given ids. It prefers the
// reader's bulk GetAllAliases method when implemented (one query instead
// of one per id); otherwise it falls back to per-id GetAliases calls,
// storing only non-empty alias lists.
func loadAllAliases(ctx context.Context, reader models.AliasLoader, ids []int) (map[int][]string, error) {
	bulk, hasBulk := reader.(models.AllAliasLoader)
	if hasBulk {
		return bulk.GetAllAliases(ctx)
	}

	ret := make(map[int][]string, len(ids))
	for _, id := range ids {
		aliases, err := reader.GetAliases(ctx, id)
		if err != nil {
			return nil, err
		}
		if len(aliases) == 0 {
			continue
		}
		ret[id] = aliases
	}
	return ret, nil
}
// PreloadStudios loads all non-ignored studios plus their aliases into the
// cache and builds a 2-rune prefix index (over names AND aliases, mirroring
// the SQL LEFT JOIN on studio_aliases). Idempotent; must complete before
// the cache is shared across goroutines.
func (c *Cache) PreloadStudios(ctx context.Context, reader models.StudioAutoTagQueryer) error {
	if c.allStudios != nil {
		// Already preloaded.
		return nil
	}
	ignoreAutoTag := false
	perPage := -1 // all rows in one page
	studios, _, err := reader.Query(ctx, &models.StudioFilterType{
		IgnoreAutoTag: &ignoreAutoTag,
	}, &models.FindFilterType{PerPage: &perPage})
	if err != nil {
		return err
	}

	// Bulk-load aliases up front to avoid an N+1 GetAliases loop.
	ids := make([]int, len(studios))
	for i, s := range studios {
		ids[i] = s.ID
	}
	aliasesByID, err := loadAllAliases(ctx, reader, ids)
	if err != nil {
		return err
	}

	// out is non-nil even when empty so the idempotence check above holds.
	out := make([]cachedStudio, len(studios))
	c.studioByPrefix = make(map[string][]cachedStudio, len(studios))
	seenPerPrefix := make(map[string]map[int]bool)
	for i, s := range studios {
		aliases := aliasesByID[s.ID]
		cs := cachedStudio{Studio: s, Aliases: aliases}
		out[i] = cs
		c.indexByPrefix(s.ID, s.Name, aliases, seenPerPrefix, func(prefix string) {
			c.studioByPrefix[prefix] = append(c.studioByPrefix[prefix], cs)
		})
		// Single-letter-first-word names/aliases can't be reached via the
		// prefix index; always regex-check those studios.
		if hasSingleFirstChar(s.Name, aliases) {
			c.studioAlwaysCheck = append(c.studioAlwaysCheck, cs)
		}
	}
	c.allStudios = out
	return nil
}
// PreloadTags loads all non-ignored tags plus their aliases into the cache
// and builds a 2-rune prefix index (over names AND aliases). Idempotent;
// must complete before the cache is shared across goroutines.
func (c *Cache) PreloadTags(ctx context.Context, reader models.TagAutoTagQueryer) error {
	if c.allTags != nil {
		// Already preloaded.
		return nil
	}
	ignoreAutoTag := false
	perPage := -1 // all rows in one page
	tags, _, err := reader.Query(ctx, &models.TagFilterType{
		IgnoreAutoTag: &ignoreAutoTag,
	}, &models.FindFilterType{PerPage: &perPage})
	if err != nil {
		return err
	}

	// Bulk-load aliases up front to avoid an N+1 GetAliases loop.
	ids := make([]int, len(tags))
	for i, t := range tags {
		ids[i] = t.ID
	}
	aliasesByID, err := loadAllAliases(ctx, reader, ids)
	if err != nil {
		return err
	}

	// out is non-nil even when empty so the idempotence check above holds.
	out := make([]cachedTag, len(tags))
	c.tagByPrefix = make(map[string][]cachedTag, len(tags))
	seenPerPrefix := make(map[string]map[int]bool)
	for i, t := range tags {
		aliases := aliasesByID[t.ID]
		ct := cachedTag{Tag: t, Aliases: aliases}
		out[i] = ct
		c.indexByPrefix(t.ID, t.Name, aliases, seenPerPrefix, func(prefix string) {
			c.tagByPrefix[prefix] = append(c.tagByPrefix[prefix], ct)
		})
		// Single-letter-first-word names/aliases can't be reached via the
		// prefix index; always regex-check those tags.
		if hasSingleFirstChar(t.Name, aliases) {
			c.tagAlwaysCheck = append(c.tagAlwaysCheck, ct)
		}
	}
	c.allTags = out
	return nil
}
// indexByPrefix records the entity (identified by id) under every distinct
// 2-rune prefix of its name and aliases. The caller-shared seen map dedupes
// so that a name and an alias sharing a prefix bucket add the entity to
// that bucket only once.
func (c *Cache) indexByPrefix(id int, name string, aliases []string, seen map[string]map[int]bool, add func(prefix string)) {
	record := func(candidate string) {
		prefix := firstTwoRunesLower(candidate)
		if prefix == "" {
			// Fewer than two runes: no usable prefix.
			return
		}
		bucket := seen[prefix]
		if bucket == nil {
			bucket = make(map[int]bool)
			seen[prefix] = bucket
		}
		if bucket[id] {
			return
		}
		bucket[id] = true
		add(prefix)
	}

	record(name)
	for _, alias := range aliases {
		record(alias)
	}
}
// hasSingleFirstChar reports whether the name, or any of its aliases,
// starts with a single-letter word (e.g. "X Man").
func hasSingleFirstChar(name string, aliases []string) bool {
	matches := func(s string) bool {
		return singleFirstCharacterRE.MatchString(s)
	}
	if matches(name) {
		return true
	}
	for _, alias := range aliases {
		if matches(alias) {
			return true
		}
	}
	return false
}
// regexpCacheKey identifies one entry in the Cache's compiled-regexp LRU:
// the candidate name plus whether the unicode-aware pattern variant was
// requested (the two variants compile to different regexps).
type regexpCacheKey struct {
	name       string
	useUnicode bool
}
// nameRegexp returns a compiled regexp for the given name, caching the
// result so repeated autotag calls across many files don't pay the
// compile cost each time. Safe to call on a nil receiver, in which case
// the pattern is compiled fresh and not cached.
func (c *Cache) nameRegexp(name string, useUnicode bool) *regexp.Regexp {
	if c == nil {
		return nameToRegexp(name, useUnicode)
	}
	c.regexpCacheOnce.Do(func() {
		// lru.New only fails for a non-positive size; regexpCacheSize is a
		// positive constant, so the error is deliberately ignored.
		c.regexpCache, _ = lru.New[regexpCacheKey, *regexp.Regexp](regexpCacheSize)
	})
	key := regexpCacheKey{name: name, useUnicode: useUnicode}
	if r, ok := c.regexpCache.Get(key); ok {
		return r
	}
	// Concurrent misses may compile the same pattern twice; Add is safe
	// and the duplicate compilations are equivalent.
	r := nameToRegexp(name, useUnicode)
	c.regexpCache.Add(key, r)
	return r
}
// getSingleLetterPerformers returns all performers with names that start with single character words.
@ -25,7 +370,7 @@ func getSingleLetterPerformers(ctx context.Context, c *Cache, reader models.Perf
c = &Cache{}
}
if c.singleCharPerformers == nil {
c.performersOnce.Do(func() {
pp := -1
performers, _, err := reader.Query(ctx, &models.PerformerFilterType{
Name: &models.StringCriterionInput{
@ -37,18 +382,18 @@ func getSingleLetterPerformers(ctx context.Context, c *Cache, reader models.Perf
})
if err != nil {
return nil, err
c.performersErr = err
return
}
if len(performers) == 0 {
// make singleWordPerformers not nil
c.singleCharPerformers = make([]*models.Performer, 0)
} else {
c.singleCharPerformers = performers
}
}
})
return c.singleCharPerformers, nil
return c.singleCharPerformers, c.performersErr
}
// getSingleLetterStudios returns all studios with names that start with single character words.
@ -58,7 +403,7 @@ func getSingleLetterStudios(ctx context.Context, c *Cache, reader models.StudioA
c = &Cache{}
}
if c.singleCharStudios == nil {
c.studiosOnce.Do(func() {
pp := -1
studios, _, err := reader.Query(ctx, &models.StudioFilterType{
Name: &models.StringCriterionInput{
@ -70,18 +415,18 @@ func getSingleLetterStudios(ctx context.Context, c *Cache, reader models.StudioA
})
if err != nil {
return nil, err
c.studiosErr = err
return
}
if len(studios) == 0 {
// make singleWordStudios not nil
c.singleCharStudios = make([]*models.Studio, 0)
} else {
c.singleCharStudios = studios
}
}
})
return c.singleCharStudios, nil
return c.singleCharStudios, c.studiosErr
}
// getSingleLetterTags returns all tags with names that start with single character words.
@ -91,7 +436,7 @@ func getSingleLetterTags(ctx context.Context, c *Cache, reader models.TagAutoTag
c = &Cache{}
}
if c.singleCharTags == nil {
c.tagsOnce.Do(func() {
pp := -1
tags, _, err := reader.Query(ctx, &models.TagFilterType{
Name: &models.StringCriterionInput{
@ -111,16 +456,16 @@ func getSingleLetterTags(ctx context.Context, c *Cache, reader models.TagAutoTag
})
if err != nil {
return nil, err
c.tagsErr = err
return
}
if len(tags) == 0 {
// make singleWordTags not nil
c.singleCharTags = make([]*models.Tag, 0)
} else {
c.singleCharTags = tags
}
}
})
return c.singleCharTags, nil
return c.singleCharTags, c.tagsErr
}

204
pkg/match/cache_test.go Normal file
View file

@ -0,0 +1,204 @@
package match
import (
"context"
"slices"
"testing"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/mocks"
)
// TestFirstTwoRunesLower covers ASCII, accented and multi-byte CJK names,
// plus the under-two-rune cases that produce no prefix.
func TestFirstTwoRunesLower(t *testing.T) {
	t.Parallel()
	tests := []struct {
		in   string
		want string
	}{
		{"alice smith", "al"},
		{"ALICE", "al"}, // lowercased
		{"Àbc", "àb"},   // accented: rune-aware, not byte-aware
		{"伏字 name", "伏字"}, // CJK: two runes, not two bytes
		{"ab", "ab"},
		{"a", ""}, // single rune -> no prefix
		{"", ""},  // empty -> no prefix
		{"X Man", "x "}, // space is preserved in 2-rune prefix
	}
	for _, tt := range tests {
		t.Run(tt.in, func(t *testing.T) {
			t.Parallel()
			if got := firstTwoRunesLower(tt.in); got != tt.want {
				t.Errorf("firstTwoRunesLower(%q) = %q, want %q", tt.in, got, tt.want)
			}
		})
	}
}
// TestCacheNameRegexpCaches verifies pointer-identical regexp reuse across
// calls, key separation by the useUnicode flag, and nil-receiver safety.
func TestCacheNameRegexpCaches(t *testing.T) {
	t.Parallel()
	c := &Cache{}
	r1 := c.nameRegexp("alice smith", true)
	r2 := c.nameRegexp("alice smith", true)
	if r1 != r2 {
		t.Error("expected cached regexp to be reused across calls")
	}
	// Different useUnicode flag -> different cached regexp.
	r3 := c.nameRegexp("alice smith", false)
	if r3 == r1 {
		t.Error("expected ASCII and unicode variants to be distinct cached entries")
	}
	// Nil cache must still return a valid regexp, just uncached.
	var nilCache *Cache
	if got := nilCache.nameRegexp("alice smith", true); got == nil {
		t.Error("nil cache should still return a regexp")
	}
}
// TestPreloadPerformersBuildsIndex checks the preload's three outputs:
// the non-ignored performer set, the 2-rune prefix buckets, and the
// single-letter alwaysCheck list — plus idempotence of a second call.
func TestPreloadPerformersBuildsIndex(t *testing.T) {
	t.Parallel()
	alice := &models.Performer{ID: 1, Name: "Alice Smith"}
	bob := &models.Performer{ID: 2, Name: "bob jones"}
	xman := &models.Performer{ID: 3, Name: "X Man"}
	ignored := &models.Performer{ID: 4, Name: "ignored", IgnoreAutoTag: true}
	performers := []*models.Performer{alice, bob, xman, ignored}
	db := mocks.NewDatabase()
	primePerformerMock(db.Performer, performers)
	c := &Cache{}
	if err := c.PreloadPerformers(context.Background(), db.Performer); err != nil {
		t.Fatalf("PreloadPerformers: %v", err)
	}
	// allPerformers excludes IgnoreAutoTag=true.
	if got := len(c.allPerformers); got != 3 {
		t.Errorf("allPerformers len = %d, want 3 (ignored must be excluded)", got)
	}
	// Prefix "al" -> alice, "bo" -> bob, "x " -> xman.
	assertBucket := func(prefix string, wantIDs []int) {
		t.Helper()
		var gotIDs []int
		for _, p := range c.performerByPrefix[prefix] {
			gotIDs = append(gotIDs, p.ID)
		}
		slices.Sort(gotIDs)
		if !slices.Equal(gotIDs, wantIDs) {
			t.Errorf("bucket %q = %v, want %v", prefix, gotIDs, wantIDs)
		}
	}
	assertBucket("al", []int{1})
	assertBucket("bo", []int{2})
	assertBucket("x ", []int{3})
	// Single-letter-first-word performer must also be in alwaysCheck.
	var alwaysIDs []int
	for _, p := range c.performerAlwaysCheck {
		alwaysIDs = append(alwaysIDs, p.ID)
	}
	if !slices.Equal(alwaysIDs, []int{3}) {
		t.Errorf("alwaysCheck IDs = %v, want [3]", alwaysIDs)
	}
	// Idempotent: second call is a no-op.
	if err := c.PreloadPerformers(context.Background(), db.Performer); err != nil {
		t.Fatalf("second PreloadPerformers: %v", err)
	}
	if got := len(c.allPerformers); got != 3 {
		t.Errorf("after idempotent call allPerformers len = %d, want 3", got)
	}
}
// TestPreloadStudiosIndexesAliasPrefixes verifies a studio is reachable
// through the 2-rune prefix of its name AND of each alias.
func TestPreloadStudiosIndexesAliasPrefixes(t *testing.T) {
	t.Parallel()
	// Name "Acme" shares no prefix with alias "Widgets" — both must be
	// reachable by their own 2-rune prefix.
	s := &models.Studio{ID: 1, Name: "Acme Corp"}
	ignored := &models.Studio{ID: 2, Name: "ignored", IgnoreAutoTag: true}
	db := mocks.NewDatabase()
	primeStudioMock(db.Studio, []*models.Studio{s, ignored}, map[int][]string{1: {"Widgets Inc"}})
	c := &Cache{}
	if err := c.PreloadStudios(context.Background(), db.Studio); err != nil {
		t.Fatalf("PreloadStudios: %v", err)
	}
	if got := len(c.allStudios); got != 1 {
		t.Errorf("allStudios len = %d, want 1 (ignored must be excluded)", got)
	}
	// "ac" bucket has the studio (via name), "wi" bucket has it (via alias).
	if len(c.studioByPrefix["ac"]) != 1 || c.studioByPrefix["ac"][0].Studio.ID != 1 {
		t.Errorf("bucket 'ac' should hold studio 1, got %+v", c.studioByPrefix["ac"])
	}
	if len(c.studioByPrefix["wi"]) != 1 || c.studioByPrefix["wi"][0].Studio.ID != 1 {
		t.Errorf("bucket 'wi' should hold studio 1, got %+v", c.studioByPrefix["wi"])
	}
}
// TestPreloadStudiosDedupsSharedPrefix verifies that a name and aliases
// landing in the same prefix bucket add the studio to it exactly once.
func TestPreloadStudiosDedupsSharedPrefix(t *testing.T) {
	t.Parallel()
	// Name and two aliases all share prefix "pr"; the bucket must contain
	// the studio exactly once.
	s := &models.Studio{ID: 1, Name: "Primary"}
	db := mocks.NewDatabase()
	primeStudioMock(db.Studio, []*models.Studio{s}, map[int][]string{1: {"Primary Nick", "Primary Alt"}})
	c := &Cache{}
	if err := c.PreloadStudios(context.Background(), db.Studio); err != nil {
		t.Fatal(err)
	}
	if got := len(c.studioByPrefix["pr"]); got != 1 {
		t.Errorf("bucket 'pr' should have 1 entry, got %d", got)
	}
}
// TestPreloadTagsIndexesAliasPrefixes verifies a tag is reachable through
// the 2-rune prefix of its name and of its alias.
func TestPreloadTagsIndexesAliasPrefixes(t *testing.T) {
	t.Parallel()
	db := mocks.NewDatabase()
	primeTagMock(db.Tag, []*models.Tag{{ID: 1, Name: "documentary"}}, map[int][]string{1: {"film"}})
	c := &Cache{}
	if err := c.PreloadTags(context.Background(), db.Tag); err != nil {
		t.Fatal(err)
	}
	if len(c.tagByPrefix["do"]) != 1 || c.tagByPrefix["do"][0].Tag.ID != 1 {
		t.Errorf("bucket 'do' should hold tag 1")
	}
	if len(c.tagByPrefix["fi"]) != 1 || c.tagByPrefix["fi"][0].Tag.ID != 1 {
		t.Errorf("bucket 'fi' should hold tag 1 (via alias)")
	}
}
// TestCandidateLookupDedupesAcrossPathWords verifies performerCandidates
// returns each candidate once even when several path words (including
// different letter cases) hit the same prefix bucket.
func TestCandidateLookupDedupesAcrossPathWords(t *testing.T) {
	t.Parallel()
	// A performer with name "alabama" falls in bucket "al". If a path has
	// two words that both map to bucket "al" (e.g., from separate tokens),
	// the candidate must appear exactly once.
	p := &models.Performer{ID: 1, Name: "alabama"}
	db := mocks.NewDatabase()
	primePerformerMock(db.Performer, []*models.Performer{p})
	c := &Cache{}
	if err := c.PreloadPerformers(context.Background(), db.Performer); err != nil {
		t.Fatal(err)
	}
	got := c.performerCandidates([]string{"al", "AL", "al"}) // same bucket three times
	if len(got) != 1 {
		t.Errorf("expected 1 candidate after dedup, got %d: %v", len(got), got)
	}
}

View file

@ -94,6 +94,36 @@ func nameMatchesPath(name, path string) int {
return regexpMatchesPath(re, path)
}
// pathMatcher holds per-path precomputed values so they aren't recomputed
// for every candidate name. `allASCII` and `strings.ToLower(path)` were
// running once per (candidate, file) pair before; under a worker pool with
// thousands of candidates per file that was the dominant allocation.
type pathMatcher struct {
	loweredPath string // path lowercased once, matched against many names
	useUnicode  bool   // true when the path contains non-ASCII runes
	cache       *Cache // may be nil; nameRegexp handles a nil receiver
}

// newPathMatcher precomputes the lowered path and the unicode flag for
// the given path. cache may be nil (regexps are then compiled uncached).
func newPathMatcher(path string, cache *Cache) pathMatcher {
	return pathMatcher{
		loweredPath: strings.ToLower(path),
		useUnicode:  !allASCII(path),
		cache:       cache,
	}
}
// match returns the right-most index at which name matches the path, or
// -1 when it does not match at all. Compiled regexps come from the cache,
// so each candidate name is compiled once per autotag run instead of once
// per file.
func (m *pathMatcher) match(name string) int {
	re := m.cache.nameRegexp(name, m.useUnicode)
	locations := re.FindAllStringIndex(m.loweredPath, -1)
	if len(locations) == 0 {
		return -1
	}
	// The last location is the right-most occurrence.
	last := locations[len(locations)-1]
	return last[0]
}
// nameToRegexp compiles a regexp pattern to match paths from the given name.
// Set useUnicode to true if this regexp is to be used on any strings with unicode characters.
func nameToRegexp(name string, useUnicode bool) *regexp.Regexp {
@ -141,30 +171,47 @@ func getPerformers(ctx context.Context, words []string, performerReader models.P
return append(performers, swPerformers...), nil
}
// PathToPerformers returns performers whose name matches the given path.
//
// When the cache has been preloaded via Cache.PreloadPerformers, the full
// non-ignored performer set is already in memory and a 2-rune prefix index
// narrows candidates before regex-matching — this is the path the bulk
// file-based auto-tag job takes. Otherwise (e.g., the built-in scraper,
// which runs on a single scene per request) falls back to a per-call SQL
// prefilter via reader.QueryForAutoTag.
func PathToPerformers(ctx context.Context, path string, reader models.PerformerAutoTagQueryer, cache *Cache, trimExt bool) ([]*models.Performer, error) {
words := getPathWords(path, trimExt)
performers, err := getPerformers(ctx, words, reader, cache)
if err != nil {
return nil, err
var performers []*models.Performer
if cache != nil && cache.allPerformers != nil {
performers = cache.performerCandidates(getPathWords(path, trimExt))
} else {
words := getPathWords(path, trimExt)
var err error
performers, err = getPerformers(ctx, words, reader, cache)
if err != nil {
return nil, err
}
}
pm := newPathMatcher(path, cache)
var ret []*models.Performer
for _, p := range performers {
matches := false
if nameMatchesPath(p.Name, path) != -1 {
if pm.match(p.Name) != -1 {
matches = true
}
// TODO - disabled alias matching until we can get finer
// control over the matching
// control over the matching. To re-enable:
// - uncomment this block (fallback path)
// - have Cache.PreloadPerformers load aliases (e.g. via
// loadAllAliases, as PreloadStudios/PreloadTags do) and
// iterate them here in the preloaded path too
// if !matches {
// if err := p.LoadAliases(ctx, reader); err != nil {
// return nil, err
// }
// for _, alias := range p.Aliases.List() {
// if nameMatchesPath(alias, path) != -1 {
// if pm.match(alias) != -1 {
// matches = true
// break
// }
@ -193,13 +240,34 @@ func getStudios(ctx context.Context, words []string, reader models.StudioAutoTag
return append(studios, swStudios...), nil
}
// PathToStudio returns the Studio that matches the given path.
// Where multiple matching studios are found, the one that matches the latest
// position in the path is returned.
// PathToStudio returns the studio whose name or alias matches the given
// path. Where multiple match, the one matching the latest position wins.
//
// See PathToPerformers for the preloaded-vs-fallback behavior.
func PathToStudio(ctx context.Context, path string, reader models.StudioAutoTagQueryer, cache *Cache, trimExt bool) (*models.Studio, error) {
pm := newPathMatcher(path, cache)
if cache != nil && cache.allStudios != nil {
candidates := cache.studioCandidates(getPathWords(path, trimExt))
var ret *models.Studio
index := -1
for _, c := range candidates {
if matchIndex := pm.match(c.Studio.Name); matchIndex != -1 && matchIndex > index {
ret = c.Studio
index = matchIndex
}
for _, alias := range c.Aliases {
if matchIndex := pm.match(alias); matchIndex != -1 && matchIndex > index {
ret = c.Studio
index = matchIndex
}
}
}
return ret, nil
}
words := getPathWords(path, trimExt)
candidates, err := getStudios(ctx, words, reader, cache)
if err != nil {
return nil, err
}
@ -207,8 +275,7 @@ func PathToStudio(ctx context.Context, path string, reader models.StudioAutoTagQ
var ret *models.Studio
index := -1
for _, c := range candidates {
matchIndex := nameMatchesPath(c.Name, path)
if matchIndex != -1 && matchIndex > index {
if matchIndex := pm.match(c.Name); matchIndex != -1 && matchIndex > index {
ret = c
index = matchIndex
}
@ -217,10 +284,8 @@ func PathToStudio(ctx context.Context, path string, reader models.StudioAutoTagQ
if err != nil {
return nil, err
}
for _, alias := range aliases {
matchIndex = nameMatchesPath(alias, path)
if matchIndex != -1 && matchIndex > index {
if matchIndex := pm.match(alias); matchIndex != -1 && matchIndex > index {
ret = c
index = matchIndex
}
@ -244,10 +309,32 @@ func getTags(ctx context.Context, words []string, reader models.TagAutoTagQuerye
return append(tags, swTags...), nil
}
// PathToTags returns tags whose name or alias matches the given path.
//
// See PathToPerformers for the preloaded-vs-fallback behavior.
func PathToTags(ctx context.Context, path string, reader models.TagAutoTagQueryer, cache *Cache, trimExt bool) ([]*models.Tag, error) {
pm := newPathMatcher(path, cache)
if cache != nil && cache.allTags != nil {
candidates := cache.tagCandidates(getPathWords(path, trimExt))
var ret []*models.Tag
for _, c := range candidates {
if pm.match(c.Tag.Name) != -1 {
ret = append(ret, c.Tag)
continue
}
for _, alias := range c.Aliases {
if pm.match(alias) != -1 {
ret = append(ret, c.Tag)
break
}
}
}
return ret, nil
}
words := getPathWords(path, trimExt)
tags, err := getTags(ctx, words, reader, cache)
if err != nil {
return nil, err
}
@ -255,23 +342,21 @@ func PathToTags(ctx context.Context, path string, reader models.TagAutoTagQuerye
var ret []*models.Tag
for _, t := range tags {
matches := false
if nameMatchesPath(t.Name, path) != -1 {
if pm.match(t.Name) != -1 {
matches = true
}
if !matches {
aliases, err := reader.GetAliases(ctx, t.ID)
if err != nil {
return nil, err
}
for _, alias := range aliases {
if nameMatchesPath(alias, path) != -1 {
if pm.match(alias) != -1 {
matches = true
break
}
}
}
if matches {
ret = append(ret, t)
}

View file

@ -0,0 +1,426 @@
package match
import (
"context"
"slices"
"testing"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/mocks"
"github.com/stretchr/testify/mock"
)
// Path-matching semantic tests that lock in the behavior of
// PathTo{Performers,Studio,Tags} via the generated testify mocks in
// pkg/models/mocks. These are the regression guard when the candidate-
// lookup strategy changes (e.g., replacing the SQL prefilter with an
// in-memory matcher): each case runs against both cache=nil and a
// preloaded cache, asserting identical output.
// --- mock setup helpers ---
// preloadFilter matches the filter PreloadX passes: IgnoreAutoTag=false.
// singleLetterFilter matches the filter the single-letter-cache path
// passes: a regex in Name. Keeping them disjoint means testify will
// route each Query call to the right stub regardless of declaration
// order.
func performerPreloadFilter() interface{} {
return mock.MatchedBy(func(f *models.PerformerFilterType) bool {
return f != nil && f.IgnoreAutoTag != nil && !*f.IgnoreAutoTag
})
}
func performerSingleLetterFilter() interface{} {
return mock.MatchedBy(func(f *models.PerformerFilterType) bool {
return f != nil && f.Name != nil
})
}
func studioPreloadFilter() interface{} {
return mock.MatchedBy(func(f *models.StudioFilterType) bool {
return f != nil && f.IgnoreAutoTag != nil && !*f.IgnoreAutoTag
})
}
func studioSingleLetterFilter() interface{} {
return mock.MatchedBy(func(f *models.StudioFilterType) bool {
return f != nil && f.Name != nil
})
}
func tagPreloadFilter() interface{} {
return mock.MatchedBy(func(f *models.TagFilterType) bool {
return f != nil && f.IgnoreAutoTag != nil && !*f.IgnoreAutoTag
})
}
func tagSingleLetterFilter() interface{} {
return mock.MatchedBy(func(f *models.TagFilterType) bool {
return f != nil && f.Name != nil
})
}
// primePerformerMock sets up a PerformerReaderWriter to serve both the
// no-preload path (QueryForAutoTag returns all non-ignored; single-letter
// Query returns nothing) and the preload path (Query with IgnoreAutoTag
// filter returns all non-ignored). All expectations are .Maybe() because
// which ones fire depends on whether the test passes a cache.
func primePerformerMock(m *mocks.PerformerReaderWriter, performers []*models.Performer) {
	// Mirror the real queries by excluding ignored performers.
	var nonIgnored []*models.Performer
	for _, p := range performers {
		if !p.IgnoreAutoTag {
			nonIgnored = append(nonIgnored, p)
		}
	}
	m.On("QueryForAutoTag", mock.Anything, mock.Anything).Return(nonIgnored, nil).Maybe()
	m.On("Query", mock.Anything, performerPreloadFilter(), mock.Anything).Return(nonIgnored, len(nonIgnored), nil).Maybe()
	m.On("Query", mock.Anything, performerSingleLetterFilter(), mock.Anything).Return(nil, 0, nil).Maybe()
}
// primeStudioMock mirrors primePerformerMock for studios. The aliases map
// (studio ID -> alias list) backs the per-id GetAliases stubs used by both
// the fallback match path and loadAllAliases' per-id fallback.
func primeStudioMock(m *mocks.StudioReaderWriter, studios []*models.Studio, aliases map[int][]string) {
	// Mirror the real queries by excluding ignored studios.
	var nonIgnored []*models.Studio
	for _, s := range studios {
		if !s.IgnoreAutoTag {
			nonIgnored = append(nonIgnored, s)
		}
	}
	m.On("QueryForAutoTag", mock.Anything, mock.Anything).Return(nonIgnored, nil).Maybe()
	m.On("Query", mock.Anything, studioPreloadFilter(), mock.Anything).Return(nonIgnored, len(nonIgnored), nil).Maybe()
	m.On("Query", mock.Anything, studioSingleLetterFilter(), mock.Anything).Return(nil, 0, nil).Maybe()
	for _, s := range studios {
		m.On("GetAliases", mock.Anything, s.ID).Return(aliases[s.ID], nil).Maybe()
	}
}
// primeTagMock mirrors primePerformerMock for tags. The aliases map
// (tag ID -> alias list) backs the per-id GetAliases stubs used by both
// the fallback match path and loadAllAliases' per-id fallback.
func primeTagMock(m *mocks.TagReaderWriter, tags []*models.Tag, aliases map[int][]string) {
	// Mirror the real queries by excluding ignored tags.
	var nonIgnored []*models.Tag
	for _, t := range tags {
		if !t.IgnoreAutoTag {
			nonIgnored = append(nonIgnored, t)
		}
	}
	m.On("QueryForAutoTag", mock.Anything, mock.Anything).Return(nonIgnored, nil).Maybe()
	m.On("Query", mock.Anything, tagPreloadFilter(), mock.Anything).Return(nonIgnored, len(nonIgnored), nil).Maybe()
	m.On("Query", mock.Anything, tagSingleLetterFilter(), mock.Anything).Return(nil, 0, nil).Maybe()
	for _, t := range tags {
		m.On("GetAliases", mock.Anything, t.ID).Return(aliases[t.ID], nil).Maybe()
	}
}
// --- helpers ---
// perfIDs collects performer IDs in ascending order so result sets can be
// compared without caring about match order.
func perfIDs(ps []*models.Performer) []int {
	ids := make([]int, 0, len(ps))
	for i := range ps {
		ids = append(ids, ps[i].ID)
	}
	slices.Sort(ids)
	return ids
}
// tagIDs collects tag IDs in ascending order so result sets can be
// compared without caring about match order.
func tagIDs(ts []*models.Tag) []int {
	ids := make([]int, 0, len(ts))
	for i := range ts {
		ids = append(ids, ts[i].ID)
	}
	slices.Sort(ids)
	return ids
}
// --- tests ---
// TestPathToPerformers_Semantics runs each path case against both the
// cache=nil SQL-prefilter path and the preloaded in-memory path,
// asserting identical results.
func TestPathToPerformers_Semantics(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	alice := &models.Performer{ID: 1, Name: "alice smith"}
	bob := &models.Performer{ID: 2, Name: "bob jones"}
	unicodeP := &models.Performer{ID: 3, Name: "伏字"}
	ignored := &models.Performer{ID: 4, Name: "ignored person", IgnoreAutoTag: true}
	substr := &models.Performer{ID: 5, Name: "ali"} // substring of "alice" - should NOT match "alice smith.jpg"
	performers := []*models.Performer{alice, bob, unicodeP, ignored, substr}
	db := mocks.NewDatabase()
	primePerformerMock(db.Performer, performers)
	tests := []struct {
		name    string
		path    string
		wantIDs []int
	}{
		{"plain name match", "/media/alice smith.jpg", []int{1}},
		{"separator variants", "/media/alice.smith.jpg", []int{1}},
		{"separator variants 2", "/media/alice_smith.jpg", []int{1}},
		{"multiple matches", "/media/alice smith and bob jones.jpg", []int{1, 2}},
		{"case insensitive", "/media/ALICE SMITH.jpg", []int{1}},
		{"unicode", "/media/伏字.jpg", []int{3}},
		{"ignore_auto_tag skipped", "/media/ignored person.jpg", nil},
		{"no substring match", "/media/alicent.jpg", nil},
		{"short name does NOT match inside longer", "/media/alice smith.jpg", []int{1}}, // 'ali' should not match
		{"short name matches exact", "/media/ali.jpg", []int{5}},
		{"no match", "/media/nobody here.jpg", nil},
	}
	for _, tt := range tests {
		t.Run(tt.name+"/no-preload", func(t *testing.T) {
			got, err := PathToPerformers(ctx, tt.path, db.Performer, nil, false)
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}
			if gotIDs := perfIDs(got); !slices.Equal(gotIDs, tt.wantIDs) {
				t.Errorf("got %v, want %v", gotIDs, tt.wantIDs)
			}
		})
		t.Run(tt.name+"/preloaded", func(t *testing.T) {
			cache := &Cache{}
			if err := cache.PreloadPerformers(ctx, db.Performer); err != nil {
				t.Fatalf("preload: %v", err)
			}
			got, err := PathToPerformers(ctx, tt.path, db.Performer, cache, false)
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}
			if gotIDs := perfIDs(got); !slices.Equal(gotIDs, tt.wantIDs) {
				t.Errorf("got %v, want %v", gotIDs, tt.wantIDs)
			}
		})
	}
}
// TestPathToStudio_Semantics runs each path case against both the
// cache=nil SQL-prefilter path and the preloaded in-memory path,
// asserting identical results (including rightmost-match-wins).
func TestPathToStudio_Semantics(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	s1 := &models.Studio{ID: 1, Name: "first studio"}
	s2 := &models.Studio{ID: 2, Name: "second"}
	s3 := &models.Studio{ID: 3, Name: "third", IgnoreAutoTag: true}
	studios := []*models.Studio{s1, s2, s3}
	aliases := map[int][]string{2: {"second alias"}}
	db := mocks.NewDatabase()
	primeStudioMock(db.Studio, studios, aliases)
	tests := []struct {
		name   string
		path   string
		wantID int // 0 == no match
	}{
		{"primary name", "/first studio/scene.mp4", 1},
		{"alias matches", "/second alias/scene.mp4", 2},
		{"ignore_auto_tag studio skipped", "/third/scene.mp4", 0},
		{"multiple matches - rightmost wins", "/first studio/second/scene.mp4", 2},
		{"no match", "/unknown/scene.mp4", 0},
	}
	runCase := func(t *testing.T, path string, wantID int, cache *Cache) {
		got, err := PathToStudio(ctx, path, db.Studio, cache, false)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		var gotID int
		if got != nil {
			gotID = got.ID
		}
		if gotID != wantID {
			t.Errorf("got %d, want %d", gotID, wantID)
		}
	}
	for _, tt := range tests {
		t.Run(tt.name+"/no-preload", func(t *testing.T) {
			runCase(t, tt.path, tt.wantID, nil)
		})
		t.Run(tt.name+"/preloaded", func(t *testing.T) {
			cache := &Cache{}
			if err := cache.PreloadStudios(ctx, db.Studio); err != nil {
				t.Fatalf("preload: %v", err)
			}
			runCase(t, tt.path, tt.wantID, cache)
		})
	}
}
// TestPathToTags_Semantics runs each path case against both the cache=nil
// SQL-prefilter path and the preloaded in-memory path, asserting
// identical results (including alias matches).
func TestPathToTags_Semantics(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	t1 := &models.Tag{ID: 1, Name: "anime"}
	t2 := &models.Tag{ID: 2, Name: "docs"}
	t3 := &models.Tag{ID: 3, Name: "skip me", IgnoreAutoTag: true}
	tags := []*models.Tag{t1, t2, t3}
	aliases := map[int][]string{2: {"documentary"}}
	db := mocks.NewDatabase()
	primeTagMock(db.Tag, tags, aliases)
	tests := []struct {
		name    string
		path    string
		wantIDs []int
	}{
		{"name match", "/media/anime/x.mp4", []int{1}},
		{"alias match", "/media/documentary/x.mp4", []int{2}},
		{"multiple matches", "/media/anime-documentary/x.mp4", []int{1, 2}},
		{"ignore_auto_tag skipped", "/media/skip me/x.mp4", nil},
		{"no match", "/media/comedy/x.mp4", nil},
	}
	runCase := func(t *testing.T, path string, wantIDs []int, cache *Cache) {
		got, err := PathToTags(ctx, path, db.Tag, cache, false)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if gotIDs := tagIDs(got); !slices.Equal(gotIDs, wantIDs) {
			t.Errorf("got %v, want %v", gotIDs, wantIDs)
		}
	}
	for _, tt := range tests {
		t.Run(tt.name+"/no-preload", func(t *testing.T) {
			runCase(t, tt.path, tt.wantIDs, nil)
		})
		t.Run(tt.name+"/preloaded", func(t *testing.T) {
			cache := &Cache{}
			if err := cache.PreloadTags(ctx, db.Tag); err != nil {
				t.Fatalf("preload: %v", err)
			}
			runCase(t, tt.path, tt.wantIDs, cache)
		})
	}
}
// Performer whose name starts with a single-letter word (e.g., "X Man")
// can't be reached via 2-rune prefix lookup (getPathWords drops 1-char
// words). The preload must put them in the alwaysCheck list so they're
// still regex-tested.
func TestPathToPerformers_SingleLetterFirstWord(t *testing.T) {
t.Parallel()
ctx := context.Background()
xman := &models.Performer{ID: 1, Name: "X Man"}
other := &models.Performer{ID: 2, Name: "alice smith"}
db := mocks.NewDatabase()
primePerformerMock(db.Performer, []*models.Performer{xman, other})
cache := &Cache{}
if err := cache.PreloadPerformers(ctx, db.Performer); err != nil {
t.Fatal(err)
}
got, err := PathToPerformers(ctx, "/media/X Man.mp4", db.Performer, cache, false)
if err != nil {
t.Fatal(err)
}
if ids := perfIDs(got); !slices.Equal(ids, []int{1}) {
t.Errorf("expected [1], got %v", ids)
}
}
// A studio whose name shares no prefix with its aliases must be reachable
// by alias prefix. "Acme Corp" with alias "Widgets Inc" must match a path
// containing "widgets inc".
func TestPathToStudio_AliasPrefixDistinctFromName(t *testing.T) {
t.Parallel()
ctx := context.Background()
s := &models.Studio{ID: 1, Name: "Acme Corp"}
db := mocks.NewDatabase()
primeStudioMock(db.Studio, []*models.Studio{s}, map[int][]string{1: {"Widgets Inc"}})
cache := &Cache{}
if err := cache.PreloadStudios(ctx, db.Studio); err != nil {
t.Fatal(err)
}
got, err := PathToStudio(ctx, "/media/Widgets Inc/scene.mp4", db.Studio, cache, false)
if err != nil {
t.Fatal(err)
}
if got == nil || got.ID != 1 {
t.Errorf("expected studio 1, got %v", got)
}
}
// Same for tags.
func TestPathToTags_AliasPrefixDistinctFromName(t *testing.T) {
	t.Parallel()
	ctx := context.Background()

	db := mocks.NewDatabase()
	primeTagMock(db.Tag,
		[]*models.Tag{{ID: 1, Name: "documentary"}},
		map[int][]string{1: {"film"}})

	cache := &Cache{}
	if err := cache.PreloadTags(ctx, db.Tag); err != nil {
		t.Fatal(err)
	}

	matched, err := PathToTags(ctx, "/media/film/x.mp4", db.Tag, cache, false)
	if err != nil {
		t.Fatal(err)
	}
	if ids := tagIDs(matched); !slices.Equal(ids, []int{1}) {
		t.Errorf("expected [1], got %v", ids)
	}
}
// Two aliases on the same studio with different prefixes should each
// reach the studio. Index bucket must dedupe inside the bucket.
func TestPathToStudio_MultipleAliasesDedup(t *testing.T) {
t.Parallel()
ctx := context.Background()
s := &models.Studio{ID: 1, Name: "Primary Name"}
db := mocks.NewDatabase()
primeStudioMock(db.Studio, []*models.Studio{s}, map[int][]string{1: {"Primary Nickname", "Primary Alt"}})
cache := &Cache{}
if err := cache.PreloadStudios(ctx, db.Studio); err != nil {
t.Fatal(err)
}
// Studio "Primary Name" and both aliases all share prefix "pr".
// The bucket should contain it exactly once.
if got := len(cache.studioByPrefix["pr"]); got != 1 {
t.Errorf("bucket 'pr' should have 1 entry, got %d", got)
}
}
// Equivalence test: the function must return the same result regardless of
// whether a match.Cache is passed in. This is the invariant that any
// caching-based optimization must preserve.
func TestPathToPerformers_CachedVsUncached(t *testing.T) {
t.Parallel()
ctx := context.Background()
perfs := []*models.Performer{
{ID: 1, Name: "alice smith"},
{ID: 2, Name: "bob jones"},
{ID: 3, Name: "charlie"},
{ID: 4, Name: "david wong"},
}
db := mocks.NewDatabase()
primePerformerMock(db.Performer, perfs)
paths := []string{
"/media/alice smith.jpg",
"/media/bob_jones.jpg",
"/media/alice smith and charlie.jpg",
"/media/nobody.jpg",
"/media/alice smith.jpg", // repeat: cached regex should not change outcome
}
var noCache, withCache [][]int
cache := &Cache{}
for _, p := range paths {
uc, err := PathToPerformers(ctx, p, db.Performer, nil, false)
if err != nil {
t.Fatal(err)
}
wc, err := PathToPerformers(ctx, p, db.Performer, cache, false)
if err != nil {
t.Fatal(err)
}
noCache = append(noCache, perfIDs(uc))
withCache = append(withCache, perfIDs(wc))
}
for i := range paths {
if !slices.Equal(noCache[i], withCache[i]) {
t.Errorf("path %q: no-cache %v vs cached %v", paths[i], noCache[i], withCache[i])
}
}
}

View file

@ -127,3 +127,17 @@ func BatchFindFilter(batchSize int) *FindFilterType {
Page: &page,
}
}
// KeysetFindFilter returns a FindFilterType suitable for id-ordered keyset
// pagination. Callers pair it with a WHERE id > lastID clause to iterate
// large tables without paying the O(offset) scan that LIMIT/OFFSET pays
// on later pages.
func KeysetFindFilter(batchSize int) *FindFilterType {
	var (
		sortField = "id"
		direction = SortDirectionEnumAsc
	)
	return &FindFilterType{
		PerPage:   &batchSize,
		Sort:      &sortField,
		Direction: &direction,
	}
}

View file

@ -63,6 +63,14 @@ type AliasLoader interface {
GetAliases(ctx context.Context, relatedID int) ([]string, error)
}
// AllAliasLoader is an optional bulk variant of AliasLoader: it returns
// aliases for every id in one query, letting callers that need aliases for
// many entities skip the N+1 per-id lookups. Implementations are free to
// add this alongside AliasLoader; callers use it via a type assertion and
// should fall back to per-id AliasLoader.GetAliases when the assertion fails.
type AllAliasLoader interface {
	// GetAllAliases returns every known alias, keyed by owning entity id.
	// Ids with no aliases may simply be absent from the map.
	GetAllAliases(ctx context.Context) (map[int][]string, error)
}
type URLLoader interface {
GetURLs(ctx context.Context, relatedID int) ([]string, error)
}

View file

@ -378,6 +378,23 @@ func (r *stringRepository) get(ctx context.Context, id int) ([]string, error) {
return ret, err
}
// getAll returns every (id, value) pair in the join table, grouped by id.
// Used to avoid N+1 lookups when callers need values for many ids at once.
// Ids with no rows are absent from the returned map. On error the map is
// nil, matching the convention that the value is meaningless when a non-nil
// error is returned (the original returned a partially-populated map).
func (r *stringRepository) getAll(ctx context.Context) (map[int][]string, error) {
	query := fmt.Sprintf("SELECT %s, %s from %s", r.idColumn, r.stringColumn, r.tableName)
	ret := make(map[int][]string)
	err := r.queryFunc(ctx, query, nil, false, func(rows *sqlx.Rows) error {
		var id int
		var out string
		if err := rows.Scan(&id, &out); err != nil {
			return err
		}
		ret[id] = append(ret[id], out)
		return nil
	})
	if err != nil {
		// don't hand back a partially-populated map alongside the error
		return nil, err
	}
	return ret, nil
}
func (r *stringRepository) insert(ctx context.Context, id int, s string) (sql.Result, error) {
stmt := fmt.Sprintf("INSERT INTO %s (%s, %s) VALUES (?, ?)", r.tableName, r.idColumn, r.stringColumn)
return dbWrapper.Exec(ctx, stmt, id, s)

View file

@ -742,6 +742,12 @@ func (qb *StudioStore) GetAliases(ctx context.Context, studioID int) ([]string,
return studiosAliasesTableMgr.get(ctx, studioID)
}
// GetAllAliases returns a map of studio id to its aliases. Lets callers that
// need aliases for many studios avoid N+1 per-id lookups. This satisfies the
// optional AllAliasLoader bulk interface alongside GetAliases.
func (qb *StudioStore) GetAllAliases(ctx context.Context) (map[int][]string, error) {
	return studiosAliasesTableMgr.getAll(ctx)
}
// GetURLs returns the URLs associated with the given studio id.
func (qb *StudioStore) GetURLs(ctx context.Context, studioID int) ([]string, error) {
	return studiosURLsTableMgr.get(ctx, studioID)
}

View file

@ -423,6 +423,28 @@ func (t *stringTable) get(ctx context.Context, id int) ([]string, error) {
return ret, nil
}
// getAll returns every (id, value) pair in the join table, grouped by id.
// Used to avoid N+1 lookups when callers need values for many ids at once.
func (t *stringTable) getAll(ctx context.Context) (map[int][]string, error) {
	q := dialect.Select(t.idColumn, t.stringColumn).From(t.table.table)

	out := make(map[int][]string)
	// single=false: we want every row, not just the first
	err := queryFunc(ctx, q, false, func(rows *sqlx.Rows) error {
		var (
			id    int
			value string
		)
		if err := rows.Scan(&id, &value); err != nil {
			return err
		}
		out[id] = append(out[id], value)
		return nil
	})
	if err != nil {
		return nil, fmt.Errorf("getting all values from %s: %w", t.table.table.GetTable(), err)
	}
	return out, nil
}
func (t *stringTable) insertJoin(ctx context.Context, id int, v string) (sql.Result, error) {
q := dialect.Insert(t.table.table).Cols(t.idColumn.GetCol(), t.stringColumn.GetCol()).Vals(
goqu.Vals{id, v},

View file

@ -940,6 +940,12 @@ func (qb *TagStore) GetAliases(ctx context.Context, tagID int) ([]string, error)
return tagRepository.aliases.get(ctx, tagID)
}
// GetAllAliases returns a map of tag id to its aliases. Lets callers that
// need aliases for many tags avoid N+1 per-id lookups. This satisfies the
// optional AllAliasLoader bulk interface alongside GetAliases.
func (qb *TagStore) GetAllAliases(ctx context.Context) (map[int][]string, error) {
	return tagRepository.aliases.getAll(ctx)
}
// UpdateAliases replaces the full alias list for the given tag id with the
// provided aliases.
func (qb *TagStore) UpdateAliases(ctx context.Context, tagID int, aliases []string) error {
	return tagRepository.aliases.replace(ctx, tagID, aliases)
}