Auto tag rewrite (#1324)

This commit is contained in:
WithoutPants 2021-04-26 12:51:31 +10:00 committed by GitHub
parent f66010a367
commit 2eb2d865dc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
26 changed files with 1469 additions and 370 deletions

View file

@ -1,6 +1,6 @@
// +build integration
package manager
package autotag
import (
"context"
@ -8,8 +8,6 @@ import (
"fmt"
"io/ioutil"
"os"
"strings"
"sync"
"testing"
"github.com/stashapp/stash/pkg/database"
@ -22,49 +20,12 @@ import (
)
const testName = "Foo's Bar"
const testExtension = ".mp4"
const existingStudioName = "ExistingStudio"
const existingStudioSceneName = testName + ".dontChangeStudio" + testExtension
const existingStudioSceneName = testName + ".dontChangeStudio.mp4"
var existingStudioID int
var testSeparators = []string{
".",
"-",
"_",
" ",
}
var testEndSeparators = []string{
"{",
"}",
"(",
")",
",",
}
func generateNamePatterns(name, separator string) []string {
var ret []string
ret = append(ret, fmt.Sprintf("%s%saaa"+testExtension, name, separator))
ret = append(ret, fmt.Sprintf("aaa%s%s"+testExtension, separator, name))
ret = append(ret, fmt.Sprintf("aaa%s%s%sbbb"+testExtension, separator, name, separator))
ret = append(ret, fmt.Sprintf("dir/%s%saaa"+testExtension, name, separator))
ret = append(ret, fmt.Sprintf("dir\\%s%saaa"+testExtension, name, separator))
ret = append(ret, fmt.Sprintf("%s%saaa/dir/bbb"+testExtension, name, separator))
ret = append(ret, fmt.Sprintf("%s%saaa\\dir\\bbb"+testExtension, name, separator))
ret = append(ret, fmt.Sprintf("dir/%s%s/aaa"+testExtension, name, separator))
ret = append(ret, fmt.Sprintf("dir\\%s%s\\aaa"+testExtension, name, separator))
return ret
}
func generateFalseNamePattern(name string, separator string) string {
splitted := strings.Split(name, " ")
return fmt.Sprintf("%s%saaa%s%s"+testExtension, splitted[0], separator, separator, splitted[1])
}
func testTeardown(databaseFile string) {
err := database.DB.Close()
@ -126,7 +87,7 @@ func createStudio(qb models.StudioWriter, name string) (*models.Studio, error) {
// create the studio
studio := models.Studio{
Checksum: name,
Name: sql.NullString{Valid: true, String: testName},
Name: sql.NullString{Valid: true, String: name},
}
return qb.Create(studio)
@ -148,23 +109,7 @@ func createTag(qb models.TagWriter) error {
func createScenes(sqb models.SceneReaderWriter) error {
// create the scenes
var scenePatterns []string
var falseScenePatterns []string
separators := append(testSeparators, testEndSeparators...)
for _, separator := range separators {
scenePatterns = append(scenePatterns, generateNamePatterns(testName, separator)...)
scenePatterns = append(scenePatterns, generateNamePatterns(strings.ToLower(testName), separator)...)
falseScenePatterns = append(falseScenePatterns, generateFalseNamePattern(testName, separator))
}
// add test cases for intra-name separators
for _, separator := range testSeparators {
if separator != " " {
scenePatterns = append(scenePatterns, generateNamePatterns(strings.Replace(testName, " ", separator, -1), separator)...)
}
}
scenePatterns, falseScenePatterns := generateScenePaths(testName)
for _, fn := range scenePatterns {
err := createScene(sqb, makeScene(fn, true))
@ -278,17 +223,14 @@ func TestParsePerformers(t *testing.T) {
return
}
task := AutoTagPerformerTask{
AutoTagTask: AutoTagTask{
txnManager: sqlite.NewTransactionManager(),
},
performer: performers[0],
for _, p := range performers {
if err := withTxn(func(r models.Repository) error {
return PerformerScenes(p, nil, r.Scene())
}); err != nil {
t.Errorf("Error auto-tagging performers: %s", err)
}
}
var wg sync.WaitGroup
wg.Add(1)
task.Start(&wg)
// verify that scenes were tagged correctly
withTxn(func(r models.Repository) error {
pqb := r.Performer()
@ -328,17 +270,14 @@ func TestParseStudios(t *testing.T) {
return
}
task := AutoTagStudioTask{
AutoTagTask: AutoTagTask{
txnManager: sqlite.NewTransactionManager(),
},
studio: studios[0],
for _, s := range studios {
if err := withTxn(func(r models.Repository) error {
return StudioScenes(s, nil, r.Scene())
}); err != nil {
t.Errorf("Error auto-tagging performers: %s", err)
}
}
var wg sync.WaitGroup
wg.Add(1)
task.Start(&wg)
// verify that scenes were tagged correctly
withTxn(func(r models.Repository) error {
scenes, err := r.Scene().All()
@ -354,9 +293,14 @@ func TestParseStudios(t *testing.T) {
}
} else {
// title is only set on scenes where we expect studio to be set
if scene.Title.String == scene.Path && scene.StudioID.Int64 != int64(studios[0].ID) {
t.Errorf("Did not set studio '%s' for path '%s'", testName, scene.Path)
} else if scene.Title.String != scene.Path && scene.StudioID.Int64 == int64(studios[0].ID) {
if scene.Title.String == scene.Path {
if !scene.StudioID.Valid {
t.Errorf("Did not set studio '%s' for path '%s'", testName, scene.Path)
} else if scene.StudioID.Int64 != int64(studios[1].ID) {
t.Errorf("Incorrect studio id %d set for path '%s'", scene.StudioID.Int64, scene.Path)
}
} else if scene.Title.String != scene.Path && scene.StudioID.Int64 == int64(studios[1].ID) {
t.Errorf("Incorrectly set studio '%s' for path '%s'", testName, scene.Path)
}
}
@ -377,17 +321,14 @@ func TestParseTags(t *testing.T) {
return
}
task := AutoTagTagTask{
AutoTagTask: AutoTagTask{
txnManager: sqlite.NewTransactionManager(),
},
tag: tags[0],
for _, s := range tags {
if err := withTxn(func(r models.Repository) error {
return TagScenes(s, nil, r.Scene())
}); err != nil {
t.Errorf("Error auto-tagging performers: %s", err)
}
}
var wg sync.WaitGroup
wg.Add(1)
task.Start(&wg)
// verify that scenes were tagged correctly
withTxn(func(r models.Repository) error {
scenes, err := r.Scene().All()

42
pkg/autotag/performer.go Normal file
View file

@ -0,0 +1,42 @@
package autotag
import (
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scene"
)
func getMatchingPerformers(path string, performerReader models.PerformerReader) ([]*models.Performer, error) {
words := getPathWords(path)
performers, err := performerReader.QueryForAutoTag(words)
if err != nil {
return nil, err
}
var ret []*models.Performer
for _, p := range performers {
// TODO - commenting out alias handling until both sides work correctly
if nameMatchesPath(p.Name.String, path) { // || nameMatchesPath(p.Aliases.String, path) {
ret = append(ret, p)
}
}
return ret, nil
}
func getPerformerTagger(p *models.Performer) tagger {
return tagger{
ID: p.ID,
Type: "performer",
Name: p.Name.String,
}
}
// PerformerScenes searches for scenes whose path matches the provided performer name and tags the scene with the performer.
func PerformerScenes(p *models.Performer, paths []string, rw models.SceneReaderWriter) error {
t := getPerformerTagger(p)
return t.tagScenes(paths, rw, func(subjectID, otherID int) (bool, error) {
return scene.AddPerformer(rw, otherID, subjectID)
})
}

View file

@ -0,0 +1,81 @@
package autotag
import (
"testing"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/mocks"
"github.com/stretchr/testify/assert"
)
func TestPerformerScenes(t *testing.T) {
type test struct {
performerName string
expectedRegex string
}
performerNames := []test{
{
"performer name",
`(?i)(?:^|_|[^\w\d])performer[.\-_ ]*name(?:$|_|[^\w\d])`,
},
{
"performer + name",
`(?i)(?:^|_|[^\w\d])performer[.\-_ ]*\+[.\-_ ]*name(?:$|_|[^\w\d])`,
},
}
for _, p := range performerNames {
testPerformerScenes(t, p.performerName, p.expectedRegex)
}
}
func testPerformerScenes(t *testing.T, performerName, expectedRegex string) {
mockSceneReader := &mocks.SceneReaderWriter{}
const performerID = 2
var scenes []*models.Scene
matchingPaths, falsePaths := generateScenePaths(performerName)
for i, p := range append(matchingPaths, falsePaths...) {
scenes = append(scenes, &models.Scene{
ID: i + 1,
Path: p,
})
}
performer := models.Performer{
ID: performerID,
Name: models.NullString(performerName),
}
organized := false
perPage := models.PerPageAll
expectedSceneFilter := &models.SceneFilterType{
Organized: &organized,
Path: &models.StringCriterionInput{
Value: expectedRegex,
Modifier: models.CriterionModifierMatchesRegex,
},
}
expectedFindFilter := &models.FindFilterType{
PerPage: &perPage,
}
mockSceneReader.On("Query", expectedSceneFilter, expectedFindFilter).Return(scenes, len(scenes), nil).Once()
for i := range matchingPaths {
sceneID := i + 1
mockSceneReader.On("GetPerformerIDs", sceneID).Return(nil, nil).Once()
mockSceneReader.On("UpdatePerformers", sceneID, []int{performerID}).Return(nil).Once()
}
err := PerformerScenes(&performer, nil, mockSceneReader)
assert := assert.New(t)
assert.Nil(err)
mockSceneReader.AssertExpectations(t)
}

117
pkg/autotag/scene.go Normal file
View file

@ -0,0 +1,117 @@
package autotag
import (
"fmt"
"path/filepath"
"strings"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scene"
)
func pathsFilter(paths []string) *models.SceneFilterType {
if paths == nil {
return nil
}
sep := string(filepath.Separator)
var ret *models.SceneFilterType
var or *models.SceneFilterType
for _, p := range paths {
newOr := &models.SceneFilterType{}
if or != nil {
or.Or = newOr
} else {
ret = newOr
}
or = newOr
if !strings.HasSuffix(p, sep) {
p = p + sep
}
or.Path = &models.StringCriterionInput{
Modifier: models.CriterionModifierEquals,
Value: p + "%",
}
}
return ret
}
func getMatchingScenes(name string, paths []string, sceneReader models.SceneReader) ([]*models.Scene, error) {
regex := getPathQueryRegex(name)
organized := false
filter := models.SceneFilterType{
Path: &models.StringCriterionInput{
Value: "(?i)" + regex,
Modifier: models.CriterionModifierMatchesRegex,
},
Organized: &organized,
}
filter.And = pathsFilter(paths)
pp := models.PerPageAll
scenes, _, err := sceneReader.Query(&filter, &models.FindFilterType{
PerPage: &pp,
})
if err != nil {
return nil, fmt.Errorf("error querying scenes with regex '%s': %s", regex, err.Error())
}
var ret []*models.Scene
for _, p := range scenes {
if nameMatchesPath(name, p.Path) {
ret = append(ret, p)
}
}
return ret, nil
}
func getSceneFileTagger(s *models.Scene) tagger {
return tagger{
ID: s.ID,
Type: "scene",
Name: s.GetTitle(),
Path: s.Path,
}
}
// ScenePerformers tags the provided scene with performers whose name matches the scene's path.
func ScenePerformers(s *models.Scene, rw models.SceneReaderWriter, performerReader models.PerformerReader) error {
t := getSceneFileTagger(s)
return t.tagPerformers(performerReader, func(subjectID, otherID int) (bool, error) {
return scene.AddPerformer(rw, subjectID, otherID)
})
}
// SceneStudios tags the provided scene with the first studio whose name matches the scene's path.
//
// Scenes will not be tagged if studio is already set.
func SceneStudios(s *models.Scene, rw models.SceneReaderWriter, studioReader models.StudioReader) error {
if s.StudioID.Valid {
// don't modify
return nil
}
t := getSceneFileTagger(s)
return t.tagStudios(studioReader, func(subjectID, otherID int) (bool, error) {
return addSceneStudio(rw, subjectID, otherID)
})
}
// SceneTags tags the provided scene with tags whose name matches the scene's path.
func SceneTags(s *models.Scene, rw models.SceneReaderWriter, tagReader models.TagReader) error {
t := getSceneFileTagger(s)
return t.tagTags(tagReader, func(subjectID, otherID int) (bool, error) {
return scene.AddTag(rw, subjectID, otherID)
})
}

276
pkg/autotag/scene_test.go Normal file
View file

@ -0,0 +1,276 @@
package autotag
import (
"fmt"
"strings"
"testing"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/mocks"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
var testSeparators = []string{
".",
"-",
"_",
" ",
}
var testEndSeparators = []string{
"{",
"}",
"(",
")",
",",
}
func generateNamePatterns(name, separator string) []string {
var ret []string
ret = append(ret, fmt.Sprintf("%s%saaa.mp4", name, separator))
ret = append(ret, fmt.Sprintf("aaa%s%s.mp4", separator, name))
ret = append(ret, fmt.Sprintf("aaa%s%s%sbbb.mp4", separator, name, separator))
ret = append(ret, fmt.Sprintf("dir/%s%saaa.mp4", name, separator))
ret = append(ret, fmt.Sprintf("dir\\%s%saaa.mp4", name, separator))
ret = append(ret, fmt.Sprintf("%s%saaa/dir/bbb.mp4", name, separator))
ret = append(ret, fmt.Sprintf("%s%saaa\\dir\\bbb.mp4", name, separator))
ret = append(ret, fmt.Sprintf("dir/%s%s/aaa.mp4", name, separator))
ret = append(ret, fmt.Sprintf("dir\\%s%s\\aaa.mp4", name, separator))
return ret
}
func generateSplitNamePatterns(name, separator string) []string {
var ret []string
splitted := strings.Split(name, " ")
// only do this for names that are split into two
if len(splitted) == 2 {
ret = append(ret, fmt.Sprintf("%s%s%s.mp4", splitted[0], separator, splitted[1]))
}
return ret
}
func generateFalseNamePatterns(name string, separator string) []string {
splitted := strings.Split(name, " ")
var ret []string
// only do this for names that are split into two
if len(splitted) == 2 {
ret = append(ret, fmt.Sprintf("%s%saaa%s%s.mp4", splitted[0], separator, separator, splitted[1]))
}
return ret
}
func generateScenePaths(testName string) (scenePatterns []string, falseScenePatterns []string) {
separators := append(testSeparators, testEndSeparators...)
for _, separator := range separators {
scenePatterns = append(scenePatterns, generateNamePatterns(testName, separator)...)
scenePatterns = append(scenePatterns, generateNamePatterns(strings.ToLower(testName), separator)...)
scenePatterns = append(scenePatterns, generateNamePatterns(strings.ReplaceAll(testName, " ", ""), separator)...)
falseScenePatterns = append(falseScenePatterns, generateFalseNamePatterns(testName, separator)...)
}
// add test cases for intra-name separators
for _, separator := range testSeparators {
if separator != " " {
scenePatterns = append(scenePatterns, generateNamePatterns(strings.Replace(testName, " ", separator, -1), separator)...)
}
}
// add basic false scenarios
falseScenePatterns = append(falseScenePatterns, fmt.Sprintf("aaa%s.mp4", testName))
falseScenePatterns = append(falseScenePatterns, fmt.Sprintf("%saaa.mp4", testName))
// add path separator false scenarios
falseScenePatterns = append(falseScenePatterns, generateFalseNamePatterns(testName, "/")...)
falseScenePatterns = append(falseScenePatterns, generateFalseNamePatterns(testName, "\\")...)
// split patterns only valid for ._- and whitespace
for _, separator := range testSeparators {
scenePatterns = append(scenePatterns, generateSplitNamePatterns(testName, separator)...)
}
// false patterns for other separators
for _, separator := range testEndSeparators {
falseScenePatterns = append(falseScenePatterns, generateSplitNamePatterns(testName, separator)...)
}
return
}
type pathTestTable struct {
ScenePath string
Matches bool
}
func generateTestTable(testName string) []pathTestTable {
var ret []pathTestTable
var scenePatterns []string
var falseScenePatterns []string
separators := append(testSeparators, testEndSeparators...)
for _, separator := range separators {
scenePatterns = append(scenePatterns, generateNamePatterns(testName, separator)...)
scenePatterns = append(scenePatterns, generateNamePatterns(strings.ToLower(testName), separator)...)
falseScenePatterns = append(falseScenePatterns, generateFalseNamePatterns(testName, separator)...)
}
for _, p := range scenePatterns {
t := pathTestTable{
ScenePath: p,
Matches: true,
}
ret = append(ret, t)
}
for _, p := range falseScenePatterns {
t := pathTestTable{
ScenePath: p,
Matches: false,
}
ret = append(ret, t)
}
return ret
}
func TestScenePerformers(t *testing.T) {
const sceneID = 1
const performerName = "performer name"
const performerID = 2
performer := models.Performer{
ID: performerID,
Name: models.NullString(performerName),
}
const reversedPerformerName = "name performer"
const reversedPerformerID = 3
reversedPerformer := models.Performer{
ID: reversedPerformerID,
Name: models.NullString(reversedPerformerName),
}
testTables := generateTestTable(performerName)
assert := assert.New(t)
for _, test := range testTables {
mockPerformerReader := &mocks.PerformerReaderWriter{}
mockSceneReader := &mocks.SceneReaderWriter{}
mockPerformerReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once()
if test.Matches {
mockSceneReader.On("GetPerformerIDs", sceneID).Return(nil, nil).Once()
mockSceneReader.On("UpdatePerformers", sceneID, []int{performerID}).Return(nil).Once()
}
scene := models.Scene{
ID: sceneID,
Path: test.ScenePath,
}
err := ScenePerformers(&scene, mockSceneReader, mockPerformerReader)
assert.Nil(err)
mockPerformerReader.AssertExpectations(t)
mockSceneReader.AssertExpectations(t)
}
}
func TestSceneStudios(t *testing.T) {
const sceneID = 1
const studioName = "studio name"
const studioID = 2
studio := models.Studio{
ID: studioID,
Name: models.NullString(studioName),
}
const reversedStudioName = "name studio"
const reversedStudioID = 3
reversedStudio := models.Studio{
ID: reversedStudioID,
Name: models.NullString(reversedStudioName),
}
testTables := generateTestTable(studioName)
assert := assert.New(t)
for _, test := range testTables {
mockStudioReader := &mocks.StudioReaderWriter{}
mockSceneReader := &mocks.SceneReaderWriter{}
mockStudioReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
if test.Matches {
mockSceneReader.On("Find", sceneID).Return(&models.Scene{}, nil).Once()
expectedStudioID := models.NullInt64(studioID)
mockSceneReader.On("Update", models.ScenePartial{
ID: sceneID,
StudioID: &expectedStudioID,
}).Return(nil, nil).Once()
}
scene := models.Scene{
ID: sceneID,
Path: test.ScenePath,
}
err := SceneStudios(&scene, mockSceneReader, mockStudioReader)
assert.Nil(err)
mockStudioReader.AssertExpectations(t)
mockSceneReader.AssertExpectations(t)
}
}
func TestSceneTags(t *testing.T) {
const sceneID = 1
const tagName = "tag name"
const tagID = 2
tag := models.Tag{
ID: tagID,
Name: tagName,
}
const reversedTagName = "name tag"
const reversedTagID = 3
reversedTag := models.Tag{
ID: reversedTagID,
Name: reversedTagName,
}
testTables := generateTestTable(tagName)
assert := assert.New(t)
for _, test := range testTables {
mockTagReader := &mocks.TagReaderWriter{}
mockSceneReader := &mocks.SceneReaderWriter{}
mockTagReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
if test.Matches {
mockSceneReader.On("GetTagIDs", sceneID).Return(nil, nil).Once()
mockSceneReader.On("UpdateTags", sceneID, []int{tagID}).Return(nil).Once()
}
scene := models.Scene{
ID: sceneID,
Path: test.ScenePath,
}
err := SceneTags(&scene, mockSceneReader, mockTagReader)
assert.Nil(err)
mockTagReader.AssertExpectations(t)
mockSceneReader.AssertExpectations(t)
}
}

66
pkg/autotag/studio.go Normal file
View file

@ -0,0 +1,66 @@
package autotag
import (
"database/sql"
"github.com/stashapp/stash/pkg/models"
)
func getMatchingStudios(path string, reader models.StudioReader) ([]*models.Studio, error) {
words := getPathWords(path)
candidates, err := reader.QueryForAutoTag(words)
if err != nil {
return nil, err
}
var ret []*models.Studio
for _, c := range candidates {
if nameMatchesPath(c.Name.String, path) {
ret = append(ret, c)
}
}
return ret, nil
}
func addSceneStudio(sceneWriter models.SceneReaderWriter, sceneID, studioID int) (bool, error) {
// don't set if already set
scene, err := sceneWriter.Find(sceneID)
if err != nil {
return false, err
}
if scene.StudioID.Valid {
return false, nil
}
// set the studio id
s := sql.NullInt64{Int64: int64(studioID), Valid: true}
scenePartial := models.ScenePartial{
ID: sceneID,
StudioID: &s,
}
if _, err := sceneWriter.Update(scenePartial); err != nil {
return false, err
}
return true, nil
}
func getStudioTagger(p *models.Studio) tagger {
return tagger{
ID: p.ID,
Type: "studio",
Name: p.Name.String,
}
}
// StudioScenes searches for scenes whose path matches the provided studio name and tags the scene with the studio, if studio is not already set on the scene.
func StudioScenes(p *models.Studio, paths []string, rw models.SceneReaderWriter) error {
t := getStudioTagger(p)
return t.tagScenes(paths, rw, func(subjectID, otherID int) (bool, error) {
return addSceneStudio(rw, otherID, subjectID)
})
}

View file

@ -0,0 +1,85 @@
package autotag
import (
"testing"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/mocks"
"github.com/stretchr/testify/assert"
)
func TestStudioScenes(t *testing.T) {
type test struct {
studioName string
expectedRegex string
}
studioNames := []test{
{
"studio name",
`(?i)(?:^|_|[^\w\d])studio[.\-_ ]*name(?:$|_|[^\w\d])`,
},
{
"studio + name",
`(?i)(?:^|_|[^\w\d])studio[.\-_ ]*\+[.\-_ ]*name(?:$|_|[^\w\d])`,
},
}
for _, p := range studioNames {
testStudioScenes(t, p.studioName, p.expectedRegex)
}
}
func testStudioScenes(t *testing.T, studioName, expectedRegex string) {
mockSceneReader := &mocks.SceneReaderWriter{}
const studioID = 2
var scenes []*models.Scene
matchingPaths, falsePaths := generateScenePaths(studioName)
for i, p := range append(matchingPaths, falsePaths...) {
scenes = append(scenes, &models.Scene{
ID: i + 1,
Path: p,
})
}
studio := models.Studio{
ID: studioID,
Name: models.NullString(studioName),
}
organized := false
perPage := models.PerPageAll
expectedSceneFilter := &models.SceneFilterType{
Organized: &organized,
Path: &models.StringCriterionInput{
Value: expectedRegex,
Modifier: models.CriterionModifierMatchesRegex,
},
}
expectedFindFilter := &models.FindFilterType{
PerPage: &perPage,
}
mockSceneReader.On("Query", expectedSceneFilter, expectedFindFilter).Return(scenes, len(scenes), nil).Once()
for i := range matchingPaths {
sceneID := i + 1
mockSceneReader.On("Find", sceneID).Return(&models.Scene{}, nil).Once()
expectedStudioID := models.NullInt64(studioID)
mockSceneReader.On("Update", models.ScenePartial{
ID: sceneID,
StudioID: &expectedStudioID,
}).Return(nil, nil).Once()
}
err := StudioScenes(&studio, nil, mockSceneReader)
assert := assert.New(t)
assert.Nil(err)
mockSceneReader.AssertExpectations(t)
}

41
pkg/autotag/tag.go Normal file
View file

@ -0,0 +1,41 @@
package autotag
import (
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scene"
)
func getMatchingTags(path string, tagReader models.TagReader) ([]*models.Tag, error) {
words := getPathWords(path)
tags, err := tagReader.QueryForAutoTag(words)
if err != nil {
return nil, err
}
var ret []*models.Tag
for _, p := range tags {
if nameMatchesPath(p.Name, path) {
ret = append(ret, p)
}
}
return ret, nil
}
func getTagTagger(p *models.Tag) tagger {
return tagger{
ID: p.ID,
Type: "tag",
Name: p.Name,
}
}
// TagScenes searches for scenes whose path matches the provided tag name and tags the scene with the tag.
func TagScenes(p *models.Tag, paths []string, rw models.SceneReaderWriter) error {
t := getTagTagger(p)
return t.tagScenes(paths, rw, func(subjectID, otherID int) (bool, error) {
return scene.AddTag(rw, otherID, subjectID)
})
}

81
pkg/autotag/tag_test.go Normal file
View file

@ -0,0 +1,81 @@
package autotag
import (
"testing"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/mocks"
"github.com/stretchr/testify/assert"
)
func TestTagScenes(t *testing.T) {
type test struct {
tagName string
expectedRegex string
}
tagNames := []test{
{
"tag name",
`(?i)(?:^|_|[^\w\d])tag[.\-_ ]*name(?:$|_|[^\w\d])`,
},
{
"tag + name",
`(?i)(?:^|_|[^\w\d])tag[.\-_ ]*\+[.\-_ ]*name(?:$|_|[^\w\d])`,
},
}
for _, p := range tagNames {
testTagScenes(t, p.tagName, p.expectedRegex)
}
}
func testTagScenes(t *testing.T, tagName, expectedRegex string) {
mockSceneReader := &mocks.SceneReaderWriter{}
const tagID = 2
var scenes []*models.Scene
matchingPaths, falsePaths := generateScenePaths(tagName)
for i, p := range append(matchingPaths, falsePaths...) {
scenes = append(scenes, &models.Scene{
ID: i + 1,
Path: p,
})
}
tag := models.Tag{
ID: tagID,
Name: tagName,
}
organized := false
perPage := models.PerPageAll
expectedSceneFilter := &models.SceneFilterType{
Organized: &organized,
Path: &models.StringCriterionInput{
Value: expectedRegex,
Modifier: models.CriterionModifierMatchesRegex,
},
}
expectedFindFilter := &models.FindFilterType{
PerPage: &perPage,
}
mockSceneReader.On("Query", expectedSceneFilter, expectedFindFilter).Return(scenes, len(scenes), nil).Once()
for i := range matchingPaths {
sceneID := i + 1
mockSceneReader.On("GetTagIDs", sceneID).Return(nil, nil).Once()
mockSceneReader.On("UpdateTags", sceneID, []int{tagID}).Return(nil).Once()
}
err := TagScenes(&tag, nil, mockSceneReader)
assert := assert.New(t)
assert.Nil(err)
mockSceneReader.AssertExpectations(t)
}

198
pkg/autotag/tagger.go Normal file
View file

@ -0,0 +1,198 @@
// Package autotag provides methods to auto-tag scenes with performers,
// studios and tags.
//
// The autotag engine tags scenes with performers/studios/tags if the scene's
// path matches the performer/studio/tag name. A scene's path is considered
// a match if it contains the performer/studio/tag's full name, ignoring any
// '.', '-', '_' characters in the path.
//
// For example, for a performer "foo bar", the following paths would be
// considered a match: "foo bar.mp4", "foobar.mp4", "foo.bar.mp4",
// "foo-bar.mp4", "aaa.foo bar.bbb.mp4".
// The following would not be considered a match:
// "aafoo bar.mp4", "foo barbb.mp4", "foo/bar.mp4"
package autotag
import (
"fmt"
"path/filepath"
"regexp"
"strings"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
)
const separatorChars = `.\-_ `
// fixes #1292
func escapePathRegex(name string) string {
ret := name
chars := `+*?()|[]{}^$`
for _, c := range chars {
cStr := string(c)
ret = strings.ReplaceAll(ret, cStr, `\`+cStr)
}
return ret
}
func getPathQueryRegex(name string) string {
// escape specific regex characters
name = escapePathRegex(name)
// handle path separators
const separator = `[` + separatorChars + `]`
ret := strings.Replace(name, " ", separator+"*", -1)
ret = `(?:^|_|[^\w\d])` + ret + `(?:$|_|[^\w\d])`
return ret
}
func nameMatchesPath(name, path string) bool {
// escape specific regex characters
name = escapePathRegex(name)
name = strings.ToLower(name)
path = strings.ToLower(path)
// handle path separators
const separator = `[` + separatorChars + `]`
reStr := strings.Replace(name, " ", separator+"*", -1)
reStr = `(?:^|_|[^\w\d])` + reStr + `(?:$|_|[^\w\d])`
re := regexp.MustCompile(reStr)
return re.MatchString(path)
}
func getPathWords(path string) []string {
retStr := path
// remove the extension
ext := filepath.Ext(retStr)
if ext != "" {
retStr = strings.TrimSuffix(retStr, ext)
}
// handle path separators
const separator = `(?:_|[^\w\d])+`
re := regexp.MustCompile(separator)
retStr = re.ReplaceAllString(retStr, " ")
words := strings.Split(retStr, " ")
// remove any single letter words
var ret []string
for _, w := range words {
if len(w) > 1 {
ret = append(ret, w)
}
}
return ret
}
type tagger struct {
ID int
Type string
Name string
Path string
}
type addLinkFunc func(subjectID, otherID int) (bool, error)
func (t *tagger) addError(otherType, otherName string, err error) error {
return fmt.Errorf("error adding %s '%s' to %s '%s': %s", otherType, otherName, t.Type, t.Name, err.Error())
}
func (t *tagger) addLog(otherType, otherName string) {
logger.Infof("Added %s '%s' to %s '%s'", otherType, otherName, t.Type, t.Name)
}
func (t *tagger) tagPerformers(performerReader models.PerformerReader, addFunc addLinkFunc) error {
others, err := getMatchingPerformers(t.Path, performerReader)
if err != nil {
return err
}
for _, p := range others {
added, err := addFunc(t.ID, p.ID)
if err != nil {
return t.addError("performer", p.Name.String, err)
}
if added {
t.addLog("performer", p.Name.String)
}
}
return nil
}
func (t *tagger) tagStudios(studioReader models.StudioReader, addFunc addLinkFunc) error {
others, err := getMatchingStudios(t.Path, studioReader)
if err != nil {
return err
}
// only add first studio
if len(others) > 0 {
studio := others[0]
added, err := addFunc(t.ID, studio.ID)
if err != nil {
return t.addError("studio", studio.Name.String, err)
}
if added {
t.addLog("studio", studio.Name.String)
}
}
return nil
}
func (t *tagger) tagTags(tagReader models.TagReader, addFunc addLinkFunc) error {
others, err := getMatchingTags(t.Path, tagReader)
if err != nil {
return err
}
for _, p := range others {
added, err := addFunc(t.ID, p.ID)
if err != nil {
return t.addError("tag", p.Name, err)
}
if added {
t.addLog("tag", p.Name)
}
}
return nil
}
func (t *tagger) tagScenes(paths []string, sceneReader models.SceneReader, addFunc addLinkFunc) error {
others, err := getMatchingScenes(t.Name, paths, sceneReader)
if err != nil {
return err
}
for _, p := range others {
added, err := addFunc(t.ID, p.ID)
if err != nil {
return t.addError("scene", p.GetTitle(), err)
}
if added {
t.addLog("scene", p.GetTitle())
}
}
return nil
}

View file

@ -5,12 +5,15 @@ import (
"errors"
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
"github.com/remeh/sizedwaitgroup"
"github.com/stashapp/stash/pkg/autotag"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/manager/config"
"github.com/stashapp/stash/pkg/models"
@ -626,6 +629,15 @@ func (s *singleton) generateScreenshot(sceneId string, at *float64) {
}()
}
func (s *singleton) isFileBasedAutoTag(input models.AutoTagMetadataInput) bool {
const wildcard = "*"
performerIds := input.Performers
studioIds := input.Studios
tagIds := input.Tags
return (len(performerIds) == 0 || performerIds[0] == wildcard) && (len(studioIds) == 0 || studioIds[0] == wildcard) && (len(tagIds) == 0 || tagIds[0] == wildcard)
}
func (s *singleton) AutoTag(input models.AutoTagMetadataInput) {
if s.Status.Status != Idle {
return
@ -636,58 +648,160 @@ func (s *singleton) AutoTag(input models.AutoTagMetadataInput) {
go func() {
defer s.returnToIdleState()
performerIds := input.Performers
studioIds := input.Studios
tagIds := input.Tags
// calculate work load
performerCount := len(performerIds)
studioCount := len(studioIds)
tagCount := len(tagIds)
if err := s.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
performerQuery := r.Performer()
studioQuery := r.Studio()
tagQuery := r.Tag()
const wildcard = "*"
var err error
if performerCount == 1 && performerIds[0] == wildcard {
performerCount, err = performerQuery.Count()
if err != nil {
return fmt.Errorf("Error getting performer count: %s", err.Error())
}
}
if studioCount == 1 && studioIds[0] == wildcard {
studioCount, err = studioQuery.Count()
if err != nil {
return fmt.Errorf("Error getting studio count: %s", err.Error())
}
}
if tagCount == 1 && tagIds[0] == wildcard {
tagCount, err = tagQuery.Count()
if err != nil {
return fmt.Errorf("Error getting tag count: %s", err.Error())
}
}
return nil
}); err != nil {
logger.Error(err.Error())
return
if s.isFileBasedAutoTag(input) {
// doing file-based auto-tag
s.autoTagScenes(input.Paths, len(input.Performers) > 0, len(input.Studios) > 0, len(input.Tags) > 0)
} else {
// doing specific performer/studio/tag auto-tag
s.autoTagSpecific(input)
}
total := performerCount + studioCount + tagCount
s.Status.setProgress(0, total)
s.autoTagPerformers(input.Paths, performerIds)
s.autoTagStudios(input.Paths, studioIds)
s.autoTagTags(input.Paths, tagIds)
}()
}
func (s *singleton) autoTagScenes(paths []string, performers, studios, tags bool) {
if err := s.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
ret := &models.SceneFilterType{}
or := ret
sep := string(filepath.Separator)
for _, p := range paths {
if !strings.HasSuffix(p, sep) {
p = p + sep
}
if ret.Path == nil {
or = ret
} else {
newOr := &models.SceneFilterType{}
or.Or = newOr
or = newOr
}
or.Path = &models.StringCriterionInput{
Modifier: models.CriterionModifierEquals,
Value: p + "%",
}
}
organized := false
ret.Organized = &organized
// batch process scenes
batchSize := 1000
page := 1
findFilter := &models.FindFilterType{
PerPage: &batchSize,
Page: &page,
}
more := true
processed := 0
for more {
scenes, total, err := r.Scene().Query(ret, findFilter)
if err != nil {
return err
}
if processed == 0 {
logger.Infof("Starting autotag of %d scenes", total)
}
for _, ss := range scenes {
if s.Status.stopping {
logger.Info("Stopping due to user request")
return nil
}
t := autoTagSceneTask{
txnManager: s.TxnManager,
scene: ss,
performers: performers,
studios: studios,
tags: tags,
}
var wg sync.WaitGroup
wg.Add(1)
go t.Start(&wg)
wg.Wait()
processed++
s.Status.setProgress(processed, total)
}
if len(scenes) != batchSize {
more = false
} else {
page++
}
}
return nil
}); err != nil {
logger.Error(err.Error())
}
logger.Info("Finished autotag")
}
func (s *singleton) autoTagSpecific(input models.AutoTagMetadataInput) {
performerIds := input.Performers
studioIds := input.Studios
tagIds := input.Tags
performerCount := len(performerIds)
studioCount := len(studioIds)
tagCount := len(tagIds)
if err := s.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
performerQuery := r.Performer()
studioQuery := r.Studio()
tagQuery := r.Tag()
const wildcard = "*"
var err error
if performerCount == 1 && performerIds[0] == wildcard {
performerCount, err = performerQuery.Count()
if err != nil {
return fmt.Errorf("error getting performer count: %s", err.Error())
}
}
if studioCount == 1 && studioIds[0] == wildcard {
studioCount, err = studioQuery.Count()
if err != nil {
return fmt.Errorf("error getting studio count: %s", err.Error())
}
}
if tagCount == 1 && tagIds[0] == wildcard {
tagCount, err = tagQuery.Count()
if err != nil {
return fmt.Errorf("error getting tag count: %s", err.Error())
}
}
return nil
}); err != nil {
logger.Error(err.Error())
return
}
total := performerCount + studioCount + tagCount
s.Status.setProgress(0, total)
logger.Infof("Starting autotag of %d performers, %d studios, %d tags", performerCount, studioCount, tagCount)
s.autoTagPerformers(input.Paths, performerIds)
s.autoTagStudios(input.Paths, studioIds)
s.autoTagTags(input.Paths, tagIds)
logger.Info("Finished autotag")
}
func (s *singleton) autoTagPerformers(paths []string, performerIds []string) {
var wg sync.WaitGroup
if s.Status.stopping {
return
}
for _, performerId := range performerIds {
var performers []*models.Performer
@ -698,46 +812,53 @@ func (s *singleton) autoTagPerformers(paths []string, performerIds []string) {
var err error
performers, err = performerQuery.All()
if err != nil {
return fmt.Errorf("Error querying performers: %s", err.Error())
return fmt.Errorf("error querying performers: %s", err.Error())
}
} else {
performerIdInt, err := strconv.Atoi(performerId)
if err != nil {
return fmt.Errorf("Error parsing performer id %s: %s", performerId, err.Error())
return fmt.Errorf("error parsing performer id %s: %s", performerId, err.Error())
}
performer, err := performerQuery.Find(performerIdInt)
if err != nil {
return fmt.Errorf("Error finding performer id %s: %s", performerId, err.Error())
return fmt.Errorf("error finding performer id %s: %s", performerId, err.Error())
}
if performer == nil {
return fmt.Errorf("performer with id %s not found", performerId)
}
performers = append(performers, performer)
}
for _, performer := range performers {
if s.Status.stopping {
logger.Info("Stopping due to user request")
return nil
}
if err := s.TxnManager.WithTxn(context.TODO(), func(r models.Repository) error {
return autotag.PerformerScenes(performer, paths, r.Scene())
}); err != nil {
return fmt.Errorf("error auto-tagging performer '%s': %s", performer.Name.String, err.Error())
}
s.Status.incrementProgress()
}
return nil
}); err != nil {
logger.Error(err.Error())
continue
}
for _, performer := range performers {
wg.Add(1)
task := AutoTagPerformerTask{
AutoTagTask: AutoTagTask{
txnManager: s.TxnManager,
paths: paths,
},
performer: performer,
}
go task.Start(&wg)
wg.Wait()
s.Status.incrementProgress()
}
}
}
func (s *singleton) autoTagStudios(paths []string, studioIds []string) {
var wg sync.WaitGroup
if s.Status.stopping {
return
}
for _, studioId := range studioIds {
var studios []*models.Studio
@ -747,46 +868,54 @@ func (s *singleton) autoTagStudios(paths []string, studioIds []string) {
var err error
studios, err = studioQuery.All()
if err != nil {
return fmt.Errorf("Error querying studios: %s", err.Error())
return fmt.Errorf("error querying studios: %s", err.Error())
}
} else {
studioIdInt, err := strconv.Atoi(studioId)
if err != nil {
return fmt.Errorf("Error parsing studio id %s: %s", studioId, err.Error())
return fmt.Errorf("error parsing studio id %s: %s", studioId, err.Error())
}
studio, err := studioQuery.Find(studioIdInt)
if err != nil {
return fmt.Errorf("Error finding studio id %s: %s", studioId, err.Error())
return fmt.Errorf("error finding studio id %s: %s", studioId, err.Error())
}
if studio == nil {
return fmt.Errorf("studio with id %s not found", studioId)
}
studios = append(studios, studio)
}
for _, studio := range studios {
if s.Status.stopping {
logger.Info("Stopping due to user request")
return nil
}
if err := s.TxnManager.WithTxn(context.TODO(), func(r models.Repository) error {
return autotag.StudioScenes(studio, paths, r.Scene())
}); err != nil {
return fmt.Errorf("error auto-tagging studio '%s': %s", studio.Name.String, err.Error())
}
s.Status.incrementProgress()
}
return nil
}); err != nil {
logger.Error(err.Error())
continue
}
for _, studio := range studios {
wg.Add(1)
task := AutoTagStudioTask{
AutoTagTask: AutoTagTask{
txnManager: s.TxnManager,
paths: paths,
},
studio: studio,
}
go task.Start(&wg)
wg.Wait()
s.Status.incrementProgress()
}
}
}
func (s *singleton) autoTagTags(paths []string, tagIds []string) {
var wg sync.WaitGroup
if s.Status.stopping {
return
}
for _, tagId := range tagIds {
var tags []*models.Tag
if err := s.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
@ -795,41 +924,41 @@ func (s *singleton) autoTagTags(paths []string, tagIds []string) {
var err error
tags, err = tagQuery.All()
if err != nil {
return fmt.Errorf("Error querying tags: %s", err.Error())
return fmt.Errorf("error querying tags: %s", err.Error())
}
} else {
tagIdInt, err := strconv.Atoi(tagId)
if err != nil {
return fmt.Errorf("Error parsing tag id %s: %s", tagId, err.Error())
return fmt.Errorf("error parsing tag id %s: %s", tagId, err.Error())
}
tag, err := tagQuery.Find(tagIdInt)
if err != nil {
return fmt.Errorf("Error finding tag id %s: %s", tagId, err.Error())
return fmt.Errorf("error finding tag id %s: %s", tagId, err.Error())
}
tags = append(tags, tag)
}
for _, tag := range tags {
if s.Status.stopping {
logger.Info("Stopping due to user request")
return nil
}
if err := s.TxnManager.WithTxn(context.TODO(), func(r models.Repository) error {
return autotag.TagScenes(tag, paths, r.Scene())
}); err != nil {
return fmt.Errorf("error auto-tagging tag '%s': %s", tag.Name, err.Error())
}
s.Status.incrementProgress()
}
return nil
}); err != nil {
logger.Error(err.Error())
continue
}
for _, tag := range tags {
wg.Add(1)
task := AutoTagTagTask{
AutoTagTask: AutoTagTask{
txnManager: s.TxnManager,
paths: paths,
},
tag: tag,
}
go task.Start(&wg)
wg.Wait()
s.Status.incrementProgress()
}
}
}

View file

@ -2,196 +2,38 @@ package manager
import (
"context"
"database/sql"
"fmt"
"path/filepath"
"strings"
"sync"
"github.com/stashapp/stash/pkg/autotag"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scene"
)
type AutoTagTask struct {
paths []string
type autoTagSceneTask struct {
txnManager models.TransactionManager
scene *models.Scene
performers bool
studios bool
tags bool
}
type AutoTagPerformerTask struct {
AutoTagTask
performer *models.Performer
}
func (t *AutoTagPerformerTask) Start(wg *sync.WaitGroup) {
func (t *autoTagSceneTask) Start(wg *sync.WaitGroup) {
defer wg.Done()
t.autoTagPerformer()
}
func (t *AutoTagTask) getQueryRegex(name string) string {
const separatorChars = `.\-_ `
// handle path separators
const separator = `[` + separatorChars + `]`
ret := strings.Replace(name, " ", separator+"*", -1)
ret = `(?:^|_|[^\w\d])` + ret + `(?:$|_|[^\w\d])`
return ret
}
func (t *AutoTagTask) getQueryFilter(regex string) *models.SceneFilterType {
organized := false
ret := &models.SceneFilterType{
Path: &models.StringCriterionInput{
Modifier: models.CriterionModifierMatchesRegex,
Value: "(?i)" + regex,
},
Organized: &organized,
}
sep := string(filepath.Separator)
var or *models.SceneFilterType
for _, p := range t.paths {
newOr := &models.SceneFilterType{}
if or == nil {
ret.And = newOr
} else {
or.Or = newOr
}
or = newOr
if !strings.HasSuffix(p, sep) {
p = p + sep
}
or.Path = &models.StringCriterionInput{
Modifier: models.CriterionModifierEquals,
Value: p + "%",
}
}
return ret
}
func (t *AutoTagTask) getFindFilter() *models.FindFilterType {
perPage := -1
return &models.FindFilterType{
PerPage: &perPage,
}
}
func (t *AutoTagPerformerTask) autoTagPerformer() {
regex := t.getQueryRegex(t.performer.Name.String)
if err := t.txnManager.WithTxn(context.TODO(), func(r models.Repository) error {
qb := r.Scene()
scenes, _, err := qb.Query(t.getQueryFilter(regex), t.getFindFilter())
if err != nil {
return fmt.Errorf("Error querying scenes with regex '%s': %s", regex, err.Error())
}
for _, s := range scenes {
added, err := scene.AddPerformer(qb, s.ID, t.performer.ID)
if err != nil {
return fmt.Errorf("Error adding performer '%s' to scene '%s': %s", t.performer.Name.String, s.GetTitle(), err.Error())
if t.performers {
if err := autotag.ScenePerformers(t.scene, r.Scene(), r.Performer()); err != nil {
return err
}
if added {
logger.Infof("Added performer '%s' to scene '%s'", t.performer.Name.String, s.GetTitle())
}
}
return nil
}); err != nil {
logger.Error(err.Error())
}
}
type AutoTagStudioTask struct {
AutoTagTask
studio *models.Studio
}
func (t *AutoTagStudioTask) Start(wg *sync.WaitGroup) {
defer wg.Done()
t.autoTagStudio()
}
func (t *AutoTagStudioTask) autoTagStudio() {
regex := t.getQueryRegex(t.studio.Name.String)
if err := t.txnManager.WithTxn(context.TODO(), func(r models.Repository) error {
qb := r.Scene()
scenes, _, err := qb.Query(t.getQueryFilter(regex), t.getFindFilter())
if err != nil {
return fmt.Errorf("Error querying scenes with regex '%s': %s", regex, err.Error())
}
for _, s := range scenes {
// #306 - don't overwrite studio if already present
if s.StudioID.Valid {
// don't modify
continue
}
logger.Infof("Adding studio '%s' to scene '%s'", t.studio.Name.String, s.GetTitle())
// set the studio id
studioID := sql.NullInt64{Int64: int64(t.studio.ID), Valid: true}
scenePartial := models.ScenePartial{
ID: s.ID,
StudioID: &studioID,
}
if _, err := qb.Update(scenePartial); err != nil {
return fmt.Errorf("Error adding studio to scene: %s", err.Error())
}
}
return nil
}); err != nil {
logger.Error(err.Error())
}
}
type AutoTagTagTask struct {
AutoTagTask
tag *models.Tag
}
func (t *AutoTagTagTask) Start(wg *sync.WaitGroup) {
defer wg.Done()
t.autoTagTag()
}
func (t *AutoTagTagTask) autoTagTag() {
regex := t.getQueryRegex(t.tag.Name)
if err := t.txnManager.WithTxn(context.TODO(), func(r models.Repository) error {
qb := r.Scene()
scenes, _, err := qb.Query(t.getQueryFilter(regex), t.getFindFilter())
if err != nil {
return fmt.Errorf("Error querying scenes with regex '%s': %s", regex, err.Error())
}
for _, s := range scenes {
added, err := scene.AddTag(qb, s.ID, t.tag.ID)
if err != nil {
return fmt.Errorf("Error adding tag '%s' to scene '%s': %s", t.tag.Name, s.GetTitle(), err.Error())
}
if added {
logger.Infof("Added tag '%s' to scene '%s'", t.tag.Name, s.GetTitle())
}
if t.studios {
if err := autotag.SceneStudios(t.scene, r.Scene(), r.Studio()); err != nil {
return err
}
}
if t.tags {
if err := autotag.SceneTags(t.scene, r.Scene(), r.Tag()); err != nil {
return err
}
}

View file

@ -1,5 +1,9 @@
package models
// PerPageAll is the value used for perPage to indicate all results should be
// returned.
const PerPageAll = -1
func (ff FindFilterType) GetSort(defaultSort string) string {
var sort string
if ff.Sort == nil {

View file

@ -388,6 +388,29 @@ func (_m *PerformerReaderWriter) Query(performerFilter *models.PerformerFilterTy
return r0, r1, r2
}
// QueryForAutoTag provides a mock function with given fields: words
func (_m *PerformerReaderWriter) QueryForAutoTag(words []string) ([]*models.Performer, error) {
ret := _m.Called(words)
var r0 []*models.Performer
if rf, ok := ret.Get(0).(func([]string) []*models.Performer); ok {
r0 = rf(words)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Performer)
}
}
var r1 error
if rf, ok := ret.Get(1).(func([]string) error); ok {
r1 = rf(words)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Update provides a mock function with given fields: updatedPerformer
func (_m *PerformerReaderWriter) Update(updatedPerformer models.PerformerPartial) (*models.Performer, error) {
ret := _m.Called(updatedPerformer)

View file

@ -296,6 +296,29 @@ func (_m *StudioReaderWriter) Query(studioFilter *models.StudioFilterType, findF
return r0, r1, r2
}
// QueryForAutoTag provides a mock function with given fields: words
func (_m *StudioReaderWriter) QueryForAutoTag(words []string) ([]*models.Studio, error) {
ret := _m.Called(words)
var r0 []*models.Studio
if rf, ok := ret.Get(0).(func([]string) []*models.Studio); ok {
r0 = rf(words)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Studio)
}
}
var r1 error
if rf, ok := ret.Get(1).(func([]string) error); ok {
r1 = rf(words)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Update provides a mock function with given fields: updatedStudio
func (_m *StudioReaderWriter) Update(updatedStudio models.StudioPartial) (*models.Studio, error) {
ret := _m.Called(updatedStudio)

View file

@ -367,6 +367,29 @@ func (_m *TagReaderWriter) Query(tagFilter *models.TagFilterType, findFilter *mo
return r0, r1, r2
}
// QueryForAutoTag provides a mock function with given fields: words
func (_m *TagReaderWriter) QueryForAutoTag(words []string) ([]*models.Tag, error) {
ret := _m.Called(words)
var r0 []*models.Tag
if rf, ok := ret.Get(0).(func([]string) []*models.Tag); ok {
r0 = rf(words)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Tag)
}
}
var r1 error
if rf, ok := ret.Get(1).(func([]string) error); ok {
r1 = rf(words)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Update provides a mock function with given fields: updatedTag
func (_m *TagReaderWriter) Update(updatedTag models.Tag) (*models.Tag, error) {
ret := _m.Called(updatedTag)

View file

@ -11,6 +11,9 @@ type PerformerReader interface {
CountByTagID(tagID int) (int, error)
Count() (int, error)
All() ([]*Performer, error)
// TODO - this interface is temporary until the filter schema can fully
// support the query needed
QueryForAutoTag(words []string) ([]*Performer, error)
Query(performerFilter *PerformerFilterType, findFilter *FindFilterType) ([]*Performer, int, error)
GetImage(performerID int) ([]byte, error)
GetStashIDs(performerID int) ([]*StashID, error)

View file

@ -7,6 +7,9 @@ type StudioReader interface {
FindByName(name string, nocase bool) (*Studio, error)
Count() (int, error)
All() ([]*Studio, error)
// TODO - this interface is temporary until the filter schema can fully
// support the query needed
QueryForAutoTag(words []string) ([]*Studio, error)
Query(studioFilter *StudioFilterType, findFilter *FindFilterType) ([]*Studio, int, error)
GetImage(studioID int) ([]byte, error)
HasImage(studioID int) (bool, error)

View file

@ -12,6 +12,9 @@ type TagReader interface {
FindByNames(names []string, nocase bool) ([]*Tag, error)
Count() (int, error)
All() ([]*Tag, error)
// TODO - this interface is temporary until the filter schema can fully
// support the query needed
QueryForAutoTag(words []string) ([]*Tag, error)
Query(tagFilter *TagFilterType, findFilter *FindFilterType) ([]*Tag, int, error)
GetImage(tagID int) ([]byte, error)
}

View file

@ -4,6 +4,7 @@ import (
"database/sql"
"fmt"
"strconv"
"strings"
"github.com/stashapp/stash/pkg/models"
)
@ -172,6 +173,25 @@ func (qb *performerQueryBuilder) All() ([]*models.Performer, error) {
return qb.queryPerformers(selectAll("performers")+qb.getPerformerSort(nil), nil)
}
func (qb *performerQueryBuilder) QueryForAutoTag(words []string) ([]*models.Performer, error) {
// TODO - Query needs to be changed to support queries of this type, and
// this method should be removed
query := selectAll(performerTable)
var whereClauses []string
var args []interface{}
for _, w := range words {
whereClauses = append(whereClauses, "name like ?")
args = append(args, "%"+w+"%")
whereClauses = append(whereClauses, "aliases like ?")
args = append(args, "%"+w+"%")
}
where := strings.Join(whereClauses, " OR ")
return qb.queryPerformers(query+" WHERE "+where, args)
}
func (qb *performerQueryBuilder) Query(performerFilter *models.PerformerFilterType, findFilter *models.FindFilterType) ([]*models.Performer, int, error) {
if performerFilter == nil {
performerFilter = &models.PerformerFilterType{}

View file

@ -100,6 +100,26 @@ func TestPerformerFindByNames(t *testing.T) {
})
}
func TestPerformerQueryForAutoTag(t *testing.T) {
withTxn(func(r models.Repository) error {
tqb := r.Performer()
name := performerNames[performerIdxWithScene] // find a performer by name
performers, err := tqb.QueryForAutoTag([]string{name})
if err != nil {
t.Errorf("Error finding performers: %s", err.Error())
}
assert.Len(t, performers, 2)
assert.Equal(t, strings.ToLower(performerNames[performerIdxWithScene]), strings.ToLower(performers[0].Name.String))
assert.Equal(t, strings.ToLower(performerNames[performerIdxWithScene]), strings.ToLower(performers[1].Name.String))
return nil
})
}
func TestPerformerUpdatePerformerImage(t *testing.T) {
if err := withTxn(func(r models.Repository) error {
qb := r.Performer()

View file

@ -3,6 +3,7 @@ package sqlite
import (
"database/sql"
"fmt"
"strings"
"github.com/stashapp/stash/pkg/models"
)
@ -121,6 +122,23 @@ func (qb *studioQueryBuilder) All() ([]*models.Studio, error) {
return qb.queryStudios(selectAll("studios")+qb.getStudioSort(nil), nil)
}
func (qb *studioQueryBuilder) QueryForAutoTag(words []string) ([]*models.Studio, error) {
// TODO - Query needs to be changed to support queries of this type, and
// this method should be removed
query := selectAll(studioTable)
var whereClauses []string
var args []interface{}
for _, w := range words {
whereClauses = append(whereClauses, "name like ?")
args = append(args, "%"+w+"%")
}
where := strings.Join(whereClauses, " OR ")
return qb.queryStudios(query+" WHERE "+where, args)
}
func (qb *studioQueryBuilder) Query(studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) ([]*models.Studio, int, error) {
if studioFilter == nil {
studioFilter = &models.StudioFilterType{}

View file

@ -45,6 +45,26 @@ func TestStudioFindByName(t *testing.T) {
})
}
func TestStudioQueryForAutoTag(t *testing.T) {
withTxn(func(r models.Repository) error {
tqb := r.Studio()
name := studioNames[studioIdxWithScene] // find a studio by name
studios, err := tqb.QueryForAutoTag([]string{name})
if err != nil {
t.Errorf("Error finding studios: %s", err.Error())
}
assert.Len(t, studios, 2)
assert.Equal(t, strings.ToLower(studioNames[studioIdxWithScene]), strings.ToLower(studios[0].Name.String))
assert.Equal(t, strings.ToLower(studioNames[studioIdxWithScene]), strings.ToLower(studios[1].Name.String))
return nil
})
}
func TestStudioQueryParent(t *testing.T) {
withTxn(func(r models.Repository) error {
sqb := r.Studio()

View file

@ -4,6 +4,7 @@ import (
"database/sql"
"errors"
"fmt"
"strings"
"github.com/stashapp/stash/pkg/models"
)
@ -192,6 +193,23 @@ func (qb *tagQueryBuilder) All() ([]*models.Tag, error) {
return qb.queryTags(selectAll("tags")+qb.getDefaultTagSort(), nil)
}
func (qb *tagQueryBuilder) QueryForAutoTag(words []string) ([]*models.Tag, error) {
// TODO - Query needs to be changed to support queries of this type, and
// this method should be removed
query := selectAll(tagTable)
var whereClauses []string
var args []interface{}
for _, w := range words {
whereClauses = append(whereClauses, "name like ?")
args = append(args, "%"+w+"%")
}
where := strings.Join(whereClauses, " OR ")
return qb.queryTags(query+" WHERE "+where, args)
}
func (qb *tagQueryBuilder) validateFilter(tagFilter *models.TagFilterType) error {
const and = "AND"
const or = "OR"

View file

@ -70,6 +70,26 @@ func TestTagFindByName(t *testing.T) {
})
}
func TestTagQueryForAutoTag(t *testing.T) {
withTxn(func(r models.Repository) error {
tqb := r.Tag()
name := tagNames[tagIdxWithScene] // find a tag by name
tags, err := tqb.QueryForAutoTag([]string{name})
if err != nil {
t.Errorf("Error finding tags: %s", err.Error())
}
assert.Len(t, tags, 2)
assert.Equal(t, strings.ToLower(tagNames[tagIdxWithScene]), strings.ToLower(tags[0].Name))
assert.Equal(t, strings.ToLower(tagNames[tagIdxWithScene]), strings.ToLower(tags[1].Name))
return nil
})
}
func TestTagFindByNames(t *testing.T) {
var names []string

View file

@ -11,6 +11,7 @@
* Added scene queue.
### 🎨 Improvements
* Improved performance of the auto-tagger.
* Clean generation artifacts after generating each scene.
* Log message at startup when cleaning the `tmp` and `downloads` generated folders takes more than one second.
* Sort movie scenes by scene number by default.
@ -27,6 +28,7 @@
* Change performer text query to search by name and alias only.
### 🐛 Bug fixes
* Fixed error when auto-tagging for performers/studios/tags with regex characters in the name.
* Fix scraped performer image not updating after clearing the current image when creating a new performer.
* Fix error preventing adding a new library path when an existing library path is missing.
* Fix whitespace in query string returning all objects.