mirror of
https://github.com/stashapp/stash.git
synced 2025-12-15 21:03:22 +01:00
Auto tag rewrite (#1324)
This commit is contained in:
parent
f66010a367
commit
2eb2d865dc
26 changed files with 1469 additions and 370 deletions
|
|
@ -1,6 +1,6 @@
|
|||
// +build integration
|
||||
|
||||
package manager
|
||||
package autotag
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
|
@ -8,8 +8,6 @@ import (
|
|||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stashapp/stash/pkg/database"
|
||||
|
|
@ -22,49 +20,12 @@ import (
|
|||
)
|
||||
|
||||
const testName = "Foo's Bar"
|
||||
const testExtension = ".mp4"
|
||||
const existingStudioName = "ExistingStudio"
|
||||
|
||||
const existingStudioSceneName = testName + ".dontChangeStudio" + testExtension
|
||||
const existingStudioSceneName = testName + ".dontChangeStudio.mp4"
|
||||
|
||||
var existingStudioID int
|
||||
|
||||
var testSeparators = []string{
|
||||
".",
|
||||
"-",
|
||||
"_",
|
||||
" ",
|
||||
}
|
||||
|
||||
var testEndSeparators = []string{
|
||||
"{",
|
||||
"}",
|
||||
"(",
|
||||
")",
|
||||
",",
|
||||
}
|
||||
|
||||
func generateNamePatterns(name, separator string) []string {
|
||||
var ret []string
|
||||
ret = append(ret, fmt.Sprintf("%s%saaa"+testExtension, name, separator))
|
||||
ret = append(ret, fmt.Sprintf("aaa%s%s"+testExtension, separator, name))
|
||||
ret = append(ret, fmt.Sprintf("aaa%s%s%sbbb"+testExtension, separator, name, separator))
|
||||
ret = append(ret, fmt.Sprintf("dir/%s%saaa"+testExtension, name, separator))
|
||||
ret = append(ret, fmt.Sprintf("dir\\%s%saaa"+testExtension, name, separator))
|
||||
ret = append(ret, fmt.Sprintf("%s%saaa/dir/bbb"+testExtension, name, separator))
|
||||
ret = append(ret, fmt.Sprintf("%s%saaa\\dir\\bbb"+testExtension, name, separator))
|
||||
ret = append(ret, fmt.Sprintf("dir/%s%s/aaa"+testExtension, name, separator))
|
||||
ret = append(ret, fmt.Sprintf("dir\\%s%s\\aaa"+testExtension, name, separator))
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func generateFalseNamePattern(name string, separator string) string {
|
||||
splitted := strings.Split(name, " ")
|
||||
|
||||
return fmt.Sprintf("%s%saaa%s%s"+testExtension, splitted[0], separator, separator, splitted[1])
|
||||
}
|
||||
|
||||
func testTeardown(databaseFile string) {
|
||||
err := database.DB.Close()
|
||||
|
||||
|
|
@ -126,7 +87,7 @@ func createStudio(qb models.StudioWriter, name string) (*models.Studio, error) {
|
|||
// create the studio
|
||||
studio := models.Studio{
|
||||
Checksum: name,
|
||||
Name: sql.NullString{Valid: true, String: testName},
|
||||
Name: sql.NullString{Valid: true, String: name},
|
||||
}
|
||||
|
||||
return qb.Create(studio)
|
||||
|
|
@ -148,23 +109,7 @@ func createTag(qb models.TagWriter) error {
|
|||
|
||||
func createScenes(sqb models.SceneReaderWriter) error {
|
||||
// create the scenes
|
||||
var scenePatterns []string
|
||||
var falseScenePatterns []string
|
||||
|
||||
separators := append(testSeparators, testEndSeparators...)
|
||||
|
||||
for _, separator := range separators {
|
||||
scenePatterns = append(scenePatterns, generateNamePatterns(testName, separator)...)
|
||||
scenePatterns = append(scenePatterns, generateNamePatterns(strings.ToLower(testName), separator)...)
|
||||
falseScenePatterns = append(falseScenePatterns, generateFalseNamePattern(testName, separator))
|
||||
}
|
||||
|
||||
// add test cases for intra-name separators
|
||||
for _, separator := range testSeparators {
|
||||
if separator != " " {
|
||||
scenePatterns = append(scenePatterns, generateNamePatterns(strings.Replace(testName, " ", separator, -1), separator)...)
|
||||
}
|
||||
}
|
||||
scenePatterns, falseScenePatterns := generateScenePaths(testName)
|
||||
|
||||
for _, fn := range scenePatterns {
|
||||
err := createScene(sqb, makeScene(fn, true))
|
||||
|
|
@ -278,17 +223,14 @@ func TestParsePerformers(t *testing.T) {
|
|||
return
|
||||
}
|
||||
|
||||
task := AutoTagPerformerTask{
|
||||
AutoTagTask: AutoTagTask{
|
||||
txnManager: sqlite.NewTransactionManager(),
|
||||
},
|
||||
performer: performers[0],
|
||||
for _, p := range performers {
|
||||
if err := withTxn(func(r models.Repository) error {
|
||||
return PerformerScenes(p, nil, r.Scene())
|
||||
}); err != nil {
|
||||
t.Errorf("Error auto-tagging performers: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
task.Start(&wg)
|
||||
|
||||
// verify that scenes were tagged correctly
|
||||
withTxn(func(r models.Repository) error {
|
||||
pqb := r.Performer()
|
||||
|
|
@ -328,17 +270,14 @@ func TestParseStudios(t *testing.T) {
|
|||
return
|
||||
}
|
||||
|
||||
task := AutoTagStudioTask{
|
||||
AutoTagTask: AutoTagTask{
|
||||
txnManager: sqlite.NewTransactionManager(),
|
||||
},
|
||||
studio: studios[0],
|
||||
for _, s := range studios {
|
||||
if err := withTxn(func(r models.Repository) error {
|
||||
return StudioScenes(s, nil, r.Scene())
|
||||
}); err != nil {
|
||||
t.Errorf("Error auto-tagging performers: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
task.Start(&wg)
|
||||
|
||||
// verify that scenes were tagged correctly
|
||||
withTxn(func(r models.Repository) error {
|
||||
scenes, err := r.Scene().All()
|
||||
|
|
@ -354,9 +293,14 @@ func TestParseStudios(t *testing.T) {
|
|||
}
|
||||
} else {
|
||||
// title is only set on scenes where we expect studio to be set
|
||||
if scene.Title.String == scene.Path && scene.StudioID.Int64 != int64(studios[0].ID) {
|
||||
t.Errorf("Did not set studio '%s' for path '%s'", testName, scene.Path)
|
||||
} else if scene.Title.String != scene.Path && scene.StudioID.Int64 == int64(studios[0].ID) {
|
||||
if scene.Title.String == scene.Path {
|
||||
if !scene.StudioID.Valid {
|
||||
t.Errorf("Did not set studio '%s' for path '%s'", testName, scene.Path)
|
||||
} else if scene.StudioID.Int64 != int64(studios[1].ID) {
|
||||
t.Errorf("Incorrect studio id %d set for path '%s'", scene.StudioID.Int64, scene.Path)
|
||||
}
|
||||
|
||||
} else if scene.Title.String != scene.Path && scene.StudioID.Int64 == int64(studios[1].ID) {
|
||||
t.Errorf("Incorrectly set studio '%s' for path '%s'", testName, scene.Path)
|
||||
}
|
||||
}
|
||||
|
|
@ -377,17 +321,14 @@ func TestParseTags(t *testing.T) {
|
|||
return
|
||||
}
|
||||
|
||||
task := AutoTagTagTask{
|
||||
AutoTagTask: AutoTagTask{
|
||||
txnManager: sqlite.NewTransactionManager(),
|
||||
},
|
||||
tag: tags[0],
|
||||
for _, s := range tags {
|
||||
if err := withTxn(func(r models.Repository) error {
|
||||
return TagScenes(s, nil, r.Scene())
|
||||
}); err != nil {
|
||||
t.Errorf("Error auto-tagging performers: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
task.Start(&wg)
|
||||
|
||||
// verify that scenes were tagged correctly
|
||||
withTxn(func(r models.Repository) error {
|
||||
scenes, err := r.Scene().All()
|
||||
42
pkg/autotag/performer.go
Normal file
42
pkg/autotag/performer.go
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
package autotag
|
||||
|
||||
import (
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/scene"
|
||||
)
|
||||
|
||||
func getMatchingPerformers(path string, performerReader models.PerformerReader) ([]*models.Performer, error) {
|
||||
words := getPathWords(path)
|
||||
performers, err := performerReader.QueryForAutoTag(words)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var ret []*models.Performer
|
||||
for _, p := range performers {
|
||||
// TODO - commenting out alias handling until both sides work correctly
|
||||
if nameMatchesPath(p.Name.String, path) { // || nameMatchesPath(p.Aliases.String, path) {
|
||||
ret = append(ret, p)
|
||||
}
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func getPerformerTagger(p *models.Performer) tagger {
|
||||
return tagger{
|
||||
ID: p.ID,
|
||||
Type: "performer",
|
||||
Name: p.Name.String,
|
||||
}
|
||||
}
|
||||
|
||||
// PerformerScenes searches for scenes whose path matches the provided performer name and tags the scene with the performer.
|
||||
func PerformerScenes(p *models.Performer, paths []string, rw models.SceneReaderWriter) error {
|
||||
t := getPerformerTagger(p)
|
||||
|
||||
return t.tagScenes(paths, rw, func(subjectID, otherID int) (bool, error) {
|
||||
return scene.AddPerformer(rw, otherID, subjectID)
|
||||
})
|
||||
}
|
||||
81
pkg/autotag/performer_test.go
Normal file
81
pkg/autotag/performer_test.go
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
package autotag
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/models/mocks"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestPerformerScenes(t *testing.T) {
|
||||
type test struct {
|
||||
performerName string
|
||||
expectedRegex string
|
||||
}
|
||||
|
||||
performerNames := []test{
|
||||
{
|
||||
"performer name",
|
||||
`(?i)(?:^|_|[^\w\d])performer[.\-_ ]*name(?:$|_|[^\w\d])`,
|
||||
},
|
||||
{
|
||||
"performer + name",
|
||||
`(?i)(?:^|_|[^\w\d])performer[.\-_ ]*\+[.\-_ ]*name(?:$|_|[^\w\d])`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, p := range performerNames {
|
||||
testPerformerScenes(t, p.performerName, p.expectedRegex)
|
||||
}
|
||||
}
|
||||
|
||||
func testPerformerScenes(t *testing.T, performerName, expectedRegex string) {
|
||||
mockSceneReader := &mocks.SceneReaderWriter{}
|
||||
|
||||
const performerID = 2
|
||||
|
||||
var scenes []*models.Scene
|
||||
matchingPaths, falsePaths := generateScenePaths(performerName)
|
||||
for i, p := range append(matchingPaths, falsePaths...) {
|
||||
scenes = append(scenes, &models.Scene{
|
||||
ID: i + 1,
|
||||
Path: p,
|
||||
})
|
||||
}
|
||||
|
||||
performer := models.Performer{
|
||||
ID: performerID,
|
||||
Name: models.NullString(performerName),
|
||||
}
|
||||
|
||||
organized := false
|
||||
perPage := models.PerPageAll
|
||||
|
||||
expectedSceneFilter := &models.SceneFilterType{
|
||||
Organized: &organized,
|
||||
Path: &models.StringCriterionInput{
|
||||
Value: expectedRegex,
|
||||
Modifier: models.CriterionModifierMatchesRegex,
|
||||
},
|
||||
}
|
||||
|
||||
expectedFindFilter := &models.FindFilterType{
|
||||
PerPage: &perPage,
|
||||
}
|
||||
|
||||
mockSceneReader.On("Query", expectedSceneFilter, expectedFindFilter).Return(scenes, len(scenes), nil).Once()
|
||||
|
||||
for i := range matchingPaths {
|
||||
sceneID := i + 1
|
||||
mockSceneReader.On("GetPerformerIDs", sceneID).Return(nil, nil).Once()
|
||||
mockSceneReader.On("UpdatePerformers", sceneID, []int{performerID}).Return(nil).Once()
|
||||
}
|
||||
|
||||
err := PerformerScenes(&performer, nil, mockSceneReader)
|
||||
|
||||
assert := assert.New(t)
|
||||
|
||||
assert.Nil(err)
|
||||
mockSceneReader.AssertExpectations(t)
|
||||
}
|
||||
117
pkg/autotag/scene.go
Normal file
117
pkg/autotag/scene.go
Normal file
|
|
@ -0,0 +1,117 @@
|
|||
package autotag
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/scene"
|
||||
)
|
||||
|
||||
func pathsFilter(paths []string) *models.SceneFilterType {
|
||||
if paths == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sep := string(filepath.Separator)
|
||||
|
||||
var ret *models.SceneFilterType
|
||||
var or *models.SceneFilterType
|
||||
for _, p := range paths {
|
||||
newOr := &models.SceneFilterType{}
|
||||
if or != nil {
|
||||
or.Or = newOr
|
||||
} else {
|
||||
ret = newOr
|
||||
}
|
||||
|
||||
or = newOr
|
||||
|
||||
if !strings.HasSuffix(p, sep) {
|
||||
p = p + sep
|
||||
}
|
||||
|
||||
or.Path = &models.StringCriterionInput{
|
||||
Modifier: models.CriterionModifierEquals,
|
||||
Value: p + "%",
|
||||
}
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func getMatchingScenes(name string, paths []string, sceneReader models.SceneReader) ([]*models.Scene, error) {
|
||||
regex := getPathQueryRegex(name)
|
||||
organized := false
|
||||
filter := models.SceneFilterType{
|
||||
Path: &models.StringCriterionInput{
|
||||
Value: "(?i)" + regex,
|
||||
Modifier: models.CriterionModifierMatchesRegex,
|
||||
},
|
||||
Organized: &organized,
|
||||
}
|
||||
|
||||
filter.And = pathsFilter(paths)
|
||||
|
||||
pp := models.PerPageAll
|
||||
scenes, _, err := sceneReader.Query(&filter, &models.FindFilterType{
|
||||
PerPage: &pp,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error querying scenes with regex '%s': %s", regex, err.Error())
|
||||
}
|
||||
|
||||
var ret []*models.Scene
|
||||
for _, p := range scenes {
|
||||
if nameMatchesPath(name, p.Path) {
|
||||
ret = append(ret, p)
|
||||
}
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func getSceneFileTagger(s *models.Scene) tagger {
|
||||
return tagger{
|
||||
ID: s.ID,
|
||||
Type: "scene",
|
||||
Name: s.GetTitle(),
|
||||
Path: s.Path,
|
||||
}
|
||||
}
|
||||
|
||||
// ScenePerformers tags the provided scene with performers whose name matches the scene's path.
|
||||
func ScenePerformers(s *models.Scene, rw models.SceneReaderWriter, performerReader models.PerformerReader) error {
|
||||
t := getSceneFileTagger(s)
|
||||
|
||||
return t.tagPerformers(performerReader, func(subjectID, otherID int) (bool, error) {
|
||||
return scene.AddPerformer(rw, subjectID, otherID)
|
||||
})
|
||||
}
|
||||
|
||||
// SceneStudios tags the provided scene with the first studio whose name matches the scene's path.
|
||||
//
|
||||
// Scenes will not be tagged if studio is already set.
|
||||
func SceneStudios(s *models.Scene, rw models.SceneReaderWriter, studioReader models.StudioReader) error {
|
||||
if s.StudioID.Valid {
|
||||
// don't modify
|
||||
return nil
|
||||
}
|
||||
|
||||
t := getSceneFileTagger(s)
|
||||
|
||||
return t.tagStudios(studioReader, func(subjectID, otherID int) (bool, error) {
|
||||
return addSceneStudio(rw, subjectID, otherID)
|
||||
})
|
||||
}
|
||||
|
||||
// SceneTags tags the provided scene with tags whose name matches the scene's path.
|
||||
func SceneTags(s *models.Scene, rw models.SceneReaderWriter, tagReader models.TagReader) error {
|
||||
t := getSceneFileTagger(s)
|
||||
|
||||
return t.tagTags(tagReader, func(subjectID, otherID int) (bool, error) {
|
||||
return scene.AddTag(rw, subjectID, otherID)
|
||||
})
|
||||
}
|
||||
276
pkg/autotag/scene_test.go
Normal file
276
pkg/autotag/scene_test.go
Normal file
|
|
@ -0,0 +1,276 @@
|
|||
package autotag
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/models/mocks"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
var testSeparators = []string{
|
||||
".",
|
||||
"-",
|
||||
"_",
|
||||
" ",
|
||||
}
|
||||
|
||||
var testEndSeparators = []string{
|
||||
"{",
|
||||
"}",
|
||||
"(",
|
||||
")",
|
||||
",",
|
||||
}
|
||||
|
||||
func generateNamePatterns(name, separator string) []string {
|
||||
var ret []string
|
||||
ret = append(ret, fmt.Sprintf("%s%saaa.mp4", name, separator))
|
||||
ret = append(ret, fmt.Sprintf("aaa%s%s.mp4", separator, name))
|
||||
ret = append(ret, fmt.Sprintf("aaa%s%s%sbbb.mp4", separator, name, separator))
|
||||
ret = append(ret, fmt.Sprintf("dir/%s%saaa.mp4", name, separator))
|
||||
ret = append(ret, fmt.Sprintf("dir\\%s%saaa.mp4", name, separator))
|
||||
ret = append(ret, fmt.Sprintf("%s%saaa/dir/bbb.mp4", name, separator))
|
||||
ret = append(ret, fmt.Sprintf("%s%saaa\\dir\\bbb.mp4", name, separator))
|
||||
ret = append(ret, fmt.Sprintf("dir/%s%s/aaa.mp4", name, separator))
|
||||
ret = append(ret, fmt.Sprintf("dir\\%s%s\\aaa.mp4", name, separator))
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func generateSplitNamePatterns(name, separator string) []string {
|
||||
var ret []string
|
||||
splitted := strings.Split(name, " ")
|
||||
// only do this for names that are split into two
|
||||
if len(splitted) == 2 {
|
||||
ret = append(ret, fmt.Sprintf("%s%s%s.mp4", splitted[0], separator, splitted[1]))
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func generateFalseNamePatterns(name string, separator string) []string {
|
||||
splitted := strings.Split(name, " ")
|
||||
|
||||
var ret []string
|
||||
// only do this for names that are split into two
|
||||
if len(splitted) == 2 {
|
||||
ret = append(ret, fmt.Sprintf("%s%saaa%s%s.mp4", splitted[0], separator, separator, splitted[1]))
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func generateScenePaths(testName string) (scenePatterns []string, falseScenePatterns []string) {
|
||||
separators := append(testSeparators, testEndSeparators...)
|
||||
|
||||
for _, separator := range separators {
|
||||
scenePatterns = append(scenePatterns, generateNamePatterns(testName, separator)...)
|
||||
scenePatterns = append(scenePatterns, generateNamePatterns(strings.ToLower(testName), separator)...)
|
||||
scenePatterns = append(scenePatterns, generateNamePatterns(strings.ReplaceAll(testName, " ", ""), separator)...)
|
||||
falseScenePatterns = append(falseScenePatterns, generateFalseNamePatterns(testName, separator)...)
|
||||
}
|
||||
|
||||
// add test cases for intra-name separators
|
||||
for _, separator := range testSeparators {
|
||||
if separator != " " {
|
||||
scenePatterns = append(scenePatterns, generateNamePatterns(strings.Replace(testName, " ", separator, -1), separator)...)
|
||||
}
|
||||
}
|
||||
|
||||
// add basic false scenarios
|
||||
falseScenePatterns = append(falseScenePatterns, fmt.Sprintf("aaa%s.mp4", testName))
|
||||
falseScenePatterns = append(falseScenePatterns, fmt.Sprintf("%saaa.mp4", testName))
|
||||
|
||||
// add path separator false scenarios
|
||||
falseScenePatterns = append(falseScenePatterns, generateFalseNamePatterns(testName, "/")...)
|
||||
falseScenePatterns = append(falseScenePatterns, generateFalseNamePatterns(testName, "\\")...)
|
||||
|
||||
// split patterns only valid for ._- and whitespace
|
||||
for _, separator := range testSeparators {
|
||||
scenePatterns = append(scenePatterns, generateSplitNamePatterns(testName, separator)...)
|
||||
}
|
||||
|
||||
// false patterns for other separators
|
||||
for _, separator := range testEndSeparators {
|
||||
falseScenePatterns = append(falseScenePatterns, generateSplitNamePatterns(testName, separator)...)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
type pathTestTable struct {
|
||||
ScenePath string
|
||||
Matches bool
|
||||
}
|
||||
|
||||
func generateTestTable(testName string) []pathTestTable {
|
||||
var ret []pathTestTable
|
||||
|
||||
var scenePatterns []string
|
||||
var falseScenePatterns []string
|
||||
|
||||
separators := append(testSeparators, testEndSeparators...)
|
||||
|
||||
for _, separator := range separators {
|
||||
scenePatterns = append(scenePatterns, generateNamePatterns(testName, separator)...)
|
||||
scenePatterns = append(scenePatterns, generateNamePatterns(strings.ToLower(testName), separator)...)
|
||||
falseScenePatterns = append(falseScenePatterns, generateFalseNamePatterns(testName, separator)...)
|
||||
}
|
||||
|
||||
for _, p := range scenePatterns {
|
||||
t := pathTestTable{
|
||||
ScenePath: p,
|
||||
Matches: true,
|
||||
}
|
||||
|
||||
ret = append(ret, t)
|
||||
}
|
||||
|
||||
for _, p := range falseScenePatterns {
|
||||
t := pathTestTable{
|
||||
ScenePath: p,
|
||||
Matches: false,
|
||||
}
|
||||
|
||||
ret = append(ret, t)
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func TestScenePerformers(t *testing.T) {
|
||||
const sceneID = 1
|
||||
const performerName = "performer name"
|
||||
const performerID = 2
|
||||
performer := models.Performer{
|
||||
ID: performerID,
|
||||
Name: models.NullString(performerName),
|
||||
}
|
||||
|
||||
const reversedPerformerName = "name performer"
|
||||
const reversedPerformerID = 3
|
||||
reversedPerformer := models.Performer{
|
||||
ID: reversedPerformerID,
|
||||
Name: models.NullString(reversedPerformerName),
|
||||
}
|
||||
|
||||
testTables := generateTestTable(performerName)
|
||||
|
||||
assert := assert.New(t)
|
||||
|
||||
for _, test := range testTables {
|
||||
mockPerformerReader := &mocks.PerformerReaderWriter{}
|
||||
mockSceneReader := &mocks.SceneReaderWriter{}
|
||||
|
||||
mockPerformerReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once()
|
||||
|
||||
if test.Matches {
|
||||
mockSceneReader.On("GetPerformerIDs", sceneID).Return(nil, nil).Once()
|
||||
mockSceneReader.On("UpdatePerformers", sceneID, []int{performerID}).Return(nil).Once()
|
||||
}
|
||||
|
||||
scene := models.Scene{
|
||||
ID: sceneID,
|
||||
Path: test.ScenePath,
|
||||
}
|
||||
err := ScenePerformers(&scene, mockSceneReader, mockPerformerReader)
|
||||
|
||||
assert.Nil(err)
|
||||
mockPerformerReader.AssertExpectations(t)
|
||||
mockSceneReader.AssertExpectations(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSceneStudios(t *testing.T) {
|
||||
const sceneID = 1
|
||||
const studioName = "studio name"
|
||||
const studioID = 2
|
||||
studio := models.Studio{
|
||||
ID: studioID,
|
||||
Name: models.NullString(studioName),
|
||||
}
|
||||
|
||||
const reversedStudioName = "name studio"
|
||||
const reversedStudioID = 3
|
||||
reversedStudio := models.Studio{
|
||||
ID: reversedStudioID,
|
||||
Name: models.NullString(reversedStudioName),
|
||||
}
|
||||
|
||||
testTables := generateTestTable(studioName)
|
||||
|
||||
assert := assert.New(t)
|
||||
|
||||
for _, test := range testTables {
|
||||
mockStudioReader := &mocks.StudioReaderWriter{}
|
||||
mockSceneReader := &mocks.SceneReaderWriter{}
|
||||
|
||||
mockStudioReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
|
||||
|
||||
if test.Matches {
|
||||
mockSceneReader.On("Find", sceneID).Return(&models.Scene{}, nil).Once()
|
||||
expectedStudioID := models.NullInt64(studioID)
|
||||
mockSceneReader.On("Update", models.ScenePartial{
|
||||
ID: sceneID,
|
||||
StudioID: &expectedStudioID,
|
||||
}).Return(nil, nil).Once()
|
||||
}
|
||||
|
||||
scene := models.Scene{
|
||||
ID: sceneID,
|
||||
Path: test.ScenePath,
|
||||
}
|
||||
err := SceneStudios(&scene, mockSceneReader, mockStudioReader)
|
||||
|
||||
assert.Nil(err)
|
||||
mockStudioReader.AssertExpectations(t)
|
||||
mockSceneReader.AssertExpectations(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSceneTags(t *testing.T) {
|
||||
const sceneID = 1
|
||||
const tagName = "tag name"
|
||||
const tagID = 2
|
||||
tag := models.Tag{
|
||||
ID: tagID,
|
||||
Name: tagName,
|
||||
}
|
||||
|
||||
const reversedTagName = "name tag"
|
||||
const reversedTagID = 3
|
||||
reversedTag := models.Tag{
|
||||
ID: reversedTagID,
|
||||
Name: reversedTagName,
|
||||
}
|
||||
|
||||
testTables := generateTestTable(tagName)
|
||||
|
||||
assert := assert.New(t)
|
||||
|
||||
for _, test := range testTables {
|
||||
mockTagReader := &mocks.TagReaderWriter{}
|
||||
mockSceneReader := &mocks.SceneReaderWriter{}
|
||||
|
||||
mockTagReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
|
||||
|
||||
if test.Matches {
|
||||
mockSceneReader.On("GetTagIDs", sceneID).Return(nil, nil).Once()
|
||||
mockSceneReader.On("UpdateTags", sceneID, []int{tagID}).Return(nil).Once()
|
||||
}
|
||||
|
||||
scene := models.Scene{
|
||||
ID: sceneID,
|
||||
Path: test.ScenePath,
|
||||
}
|
||||
err := SceneTags(&scene, mockSceneReader, mockTagReader)
|
||||
|
||||
assert.Nil(err)
|
||||
mockTagReader.AssertExpectations(t)
|
||||
mockSceneReader.AssertExpectations(t)
|
||||
}
|
||||
}
|
||||
66
pkg/autotag/studio.go
Normal file
66
pkg/autotag/studio.go
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
package autotag
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
)
|
||||
|
||||
func getMatchingStudios(path string, reader models.StudioReader) ([]*models.Studio, error) {
|
||||
words := getPathWords(path)
|
||||
candidates, err := reader.QueryForAutoTag(words)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var ret []*models.Studio
|
||||
for _, c := range candidates {
|
||||
if nameMatchesPath(c.Name.String, path) {
|
||||
ret = append(ret, c)
|
||||
}
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func addSceneStudio(sceneWriter models.SceneReaderWriter, sceneID, studioID int) (bool, error) {
|
||||
// don't set if already set
|
||||
scene, err := sceneWriter.Find(sceneID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if scene.StudioID.Valid {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// set the studio id
|
||||
s := sql.NullInt64{Int64: int64(studioID), Valid: true}
|
||||
scenePartial := models.ScenePartial{
|
||||
ID: sceneID,
|
||||
StudioID: &s,
|
||||
}
|
||||
|
||||
if _, err := sceneWriter.Update(scenePartial); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func getStudioTagger(p *models.Studio) tagger {
|
||||
return tagger{
|
||||
ID: p.ID,
|
||||
Type: "studio",
|
||||
Name: p.Name.String,
|
||||
}
|
||||
}
|
||||
|
||||
// StudioScenes searches for scenes whose path matches the provided studio name and tags the scene with the studio, if studio is not already set on the scene.
|
||||
func StudioScenes(p *models.Studio, paths []string, rw models.SceneReaderWriter) error {
|
||||
t := getStudioTagger(p)
|
||||
|
||||
return t.tagScenes(paths, rw, func(subjectID, otherID int) (bool, error) {
|
||||
return addSceneStudio(rw, otherID, subjectID)
|
||||
})
|
||||
}
|
||||
85
pkg/autotag/studio_test.go
Normal file
85
pkg/autotag/studio_test.go
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
package autotag
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/models/mocks"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestStudioScenes(t *testing.T) {
|
||||
type test struct {
|
||||
studioName string
|
||||
expectedRegex string
|
||||
}
|
||||
|
||||
studioNames := []test{
|
||||
{
|
||||
"studio name",
|
||||
`(?i)(?:^|_|[^\w\d])studio[.\-_ ]*name(?:$|_|[^\w\d])`,
|
||||
},
|
||||
{
|
||||
"studio + name",
|
||||
`(?i)(?:^|_|[^\w\d])studio[.\-_ ]*\+[.\-_ ]*name(?:$|_|[^\w\d])`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, p := range studioNames {
|
||||
testStudioScenes(t, p.studioName, p.expectedRegex)
|
||||
}
|
||||
}
|
||||
|
||||
func testStudioScenes(t *testing.T, studioName, expectedRegex string) {
|
||||
mockSceneReader := &mocks.SceneReaderWriter{}
|
||||
|
||||
const studioID = 2
|
||||
|
||||
var scenes []*models.Scene
|
||||
matchingPaths, falsePaths := generateScenePaths(studioName)
|
||||
for i, p := range append(matchingPaths, falsePaths...) {
|
||||
scenes = append(scenes, &models.Scene{
|
||||
ID: i + 1,
|
||||
Path: p,
|
||||
})
|
||||
}
|
||||
|
||||
studio := models.Studio{
|
||||
ID: studioID,
|
||||
Name: models.NullString(studioName),
|
||||
}
|
||||
|
||||
organized := false
|
||||
perPage := models.PerPageAll
|
||||
|
||||
expectedSceneFilter := &models.SceneFilterType{
|
||||
Organized: &organized,
|
||||
Path: &models.StringCriterionInput{
|
||||
Value: expectedRegex,
|
||||
Modifier: models.CriterionModifierMatchesRegex,
|
||||
},
|
||||
}
|
||||
|
||||
expectedFindFilter := &models.FindFilterType{
|
||||
PerPage: &perPage,
|
||||
}
|
||||
|
||||
mockSceneReader.On("Query", expectedSceneFilter, expectedFindFilter).Return(scenes, len(scenes), nil).Once()
|
||||
|
||||
for i := range matchingPaths {
|
||||
sceneID := i + 1
|
||||
mockSceneReader.On("Find", sceneID).Return(&models.Scene{}, nil).Once()
|
||||
expectedStudioID := models.NullInt64(studioID)
|
||||
mockSceneReader.On("Update", models.ScenePartial{
|
||||
ID: sceneID,
|
||||
StudioID: &expectedStudioID,
|
||||
}).Return(nil, nil).Once()
|
||||
}
|
||||
|
||||
err := StudioScenes(&studio, nil, mockSceneReader)
|
||||
|
||||
assert := assert.New(t)
|
||||
|
||||
assert.Nil(err)
|
||||
mockSceneReader.AssertExpectations(t)
|
||||
}
|
||||
41
pkg/autotag/tag.go
Normal file
41
pkg/autotag/tag.go
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
package autotag
|
||||
|
||||
import (
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/scene"
|
||||
)
|
||||
|
||||
func getMatchingTags(path string, tagReader models.TagReader) ([]*models.Tag, error) {
|
||||
words := getPathWords(path)
|
||||
tags, err := tagReader.QueryForAutoTag(words)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var ret []*models.Tag
|
||||
for _, p := range tags {
|
||||
if nameMatchesPath(p.Name, path) {
|
||||
ret = append(ret, p)
|
||||
}
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func getTagTagger(p *models.Tag) tagger {
|
||||
return tagger{
|
||||
ID: p.ID,
|
||||
Type: "tag",
|
||||
Name: p.Name,
|
||||
}
|
||||
}
|
||||
|
||||
// TagScenes searches for scenes whose path matches the provided tag name and tags the scene with the tag.
|
||||
func TagScenes(p *models.Tag, paths []string, rw models.SceneReaderWriter) error {
|
||||
t := getTagTagger(p)
|
||||
|
||||
return t.tagScenes(paths, rw, func(subjectID, otherID int) (bool, error) {
|
||||
return scene.AddTag(rw, otherID, subjectID)
|
||||
})
|
||||
}
|
||||
81
pkg/autotag/tag_test.go
Normal file
81
pkg/autotag/tag_test.go
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
package autotag
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/models/mocks"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestTagScenes(t *testing.T) {
|
||||
type test struct {
|
||||
tagName string
|
||||
expectedRegex string
|
||||
}
|
||||
|
||||
tagNames := []test{
|
||||
{
|
||||
"tag name",
|
||||
`(?i)(?:^|_|[^\w\d])tag[.\-_ ]*name(?:$|_|[^\w\d])`,
|
||||
},
|
||||
{
|
||||
"tag + name",
|
||||
`(?i)(?:^|_|[^\w\d])tag[.\-_ ]*\+[.\-_ ]*name(?:$|_|[^\w\d])`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, p := range tagNames {
|
||||
testTagScenes(t, p.tagName, p.expectedRegex)
|
||||
}
|
||||
}
|
||||
|
||||
func testTagScenes(t *testing.T, tagName, expectedRegex string) {
|
||||
mockSceneReader := &mocks.SceneReaderWriter{}
|
||||
|
||||
const tagID = 2
|
||||
|
||||
var scenes []*models.Scene
|
||||
matchingPaths, falsePaths := generateScenePaths(tagName)
|
||||
for i, p := range append(matchingPaths, falsePaths...) {
|
||||
scenes = append(scenes, &models.Scene{
|
||||
ID: i + 1,
|
||||
Path: p,
|
||||
})
|
||||
}
|
||||
|
||||
tag := models.Tag{
|
||||
ID: tagID,
|
||||
Name: tagName,
|
||||
}
|
||||
|
||||
organized := false
|
||||
perPage := models.PerPageAll
|
||||
|
||||
expectedSceneFilter := &models.SceneFilterType{
|
||||
Organized: &organized,
|
||||
Path: &models.StringCriterionInput{
|
||||
Value: expectedRegex,
|
||||
Modifier: models.CriterionModifierMatchesRegex,
|
||||
},
|
||||
}
|
||||
|
||||
expectedFindFilter := &models.FindFilterType{
|
||||
PerPage: &perPage,
|
||||
}
|
||||
|
||||
mockSceneReader.On("Query", expectedSceneFilter, expectedFindFilter).Return(scenes, len(scenes), nil).Once()
|
||||
|
||||
for i := range matchingPaths {
|
||||
sceneID := i + 1
|
||||
mockSceneReader.On("GetTagIDs", sceneID).Return(nil, nil).Once()
|
||||
mockSceneReader.On("UpdateTags", sceneID, []int{tagID}).Return(nil).Once()
|
||||
}
|
||||
|
||||
err := TagScenes(&tag, nil, mockSceneReader)
|
||||
|
||||
assert := assert.New(t)
|
||||
|
||||
assert.Nil(err)
|
||||
mockSceneReader.AssertExpectations(t)
|
||||
}
|
||||
198
pkg/autotag/tagger.go
Normal file
198
pkg/autotag/tagger.go
Normal file
|
|
@ -0,0 +1,198 @@
|
|||
// Package autotag provides methods to auto-tag scenes with performers,
|
||||
// studios and tags.
|
||||
//
|
||||
// The autotag engine tags scenes with performers/studios/tags if the scene's
|
||||
// path matches the performer/studio/tag name. A scene's path is considered
|
||||
// a match if it contains the performer/studio/tag's full name, ignoring any
|
||||
// '.', '-', '_' characters in the path.
|
||||
//
|
||||
// For example, for a performer "foo bar", the following paths would be
|
||||
// considered a match: "foo bar.mp4", "foobar.mp4", "foo.bar.mp4",
|
||||
// "foo-bar.mp4", "aaa.foo bar.bbb.mp4".
|
||||
// The following would not be considered a match:
|
||||
// "aafoo bar.mp4", "foo barbb.mp4", "foo/bar.mp4"
|
||||
package autotag
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/stashapp/stash/pkg/logger"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
)
|
||||
|
||||
const separatorChars = `.\-_ `
|
||||
|
||||
// fixes #1292
|
||||
func escapePathRegex(name string) string {
|
||||
ret := name
|
||||
|
||||
chars := `+*?()|[]{}^$`
|
||||
for _, c := range chars {
|
||||
cStr := string(c)
|
||||
ret = strings.ReplaceAll(ret, cStr, `\`+cStr)
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func getPathQueryRegex(name string) string {
|
||||
// escape specific regex characters
|
||||
name = escapePathRegex(name)
|
||||
|
||||
// handle path separators
|
||||
const separator = `[` + separatorChars + `]`
|
||||
|
||||
ret := strings.Replace(name, " ", separator+"*", -1)
|
||||
ret = `(?:^|_|[^\w\d])` + ret + `(?:$|_|[^\w\d])`
|
||||
return ret
|
||||
}
|
||||
|
||||
func nameMatchesPath(name, path string) bool {
|
||||
// escape specific regex characters
|
||||
name = escapePathRegex(name)
|
||||
|
||||
name = strings.ToLower(name)
|
||||
path = strings.ToLower(path)
|
||||
|
||||
// handle path separators
|
||||
const separator = `[` + separatorChars + `]`
|
||||
|
||||
reStr := strings.Replace(name, " ", separator+"*", -1)
|
||||
reStr = `(?:^|_|[^\w\d])` + reStr + `(?:$|_|[^\w\d])`
|
||||
|
||||
re := regexp.MustCompile(reStr)
|
||||
return re.MatchString(path)
|
||||
}
|
||||
|
||||
func getPathWords(path string) []string {
|
||||
retStr := path
|
||||
|
||||
// remove the extension
|
||||
ext := filepath.Ext(retStr)
|
||||
if ext != "" {
|
||||
retStr = strings.TrimSuffix(retStr, ext)
|
||||
}
|
||||
|
||||
// handle path separators
|
||||
const separator = `(?:_|[^\w\d])+`
|
||||
re := regexp.MustCompile(separator)
|
||||
retStr = re.ReplaceAllString(retStr, " ")
|
||||
|
||||
words := strings.Split(retStr, " ")
|
||||
|
||||
// remove any single letter words
|
||||
var ret []string
|
||||
for _, w := range words {
|
||||
if len(w) > 1 {
|
||||
ret = append(ret, w)
|
||||
}
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
type tagger struct {
|
||||
ID int
|
||||
Type string
|
||||
Name string
|
||||
Path string
|
||||
}
|
||||
|
||||
type addLinkFunc func(subjectID, otherID int) (bool, error)
|
||||
|
||||
func (t *tagger) addError(otherType, otherName string, err error) error {
|
||||
return fmt.Errorf("error adding %s '%s' to %s '%s': %s", otherType, otherName, t.Type, t.Name, err.Error())
|
||||
}
|
||||
|
||||
func (t *tagger) addLog(otherType, otherName string) {
|
||||
logger.Infof("Added %s '%s' to %s '%s'", otherType, otherName, t.Type, t.Name)
|
||||
}
|
||||
|
||||
func (t *tagger) tagPerformers(performerReader models.PerformerReader, addFunc addLinkFunc) error {
|
||||
others, err := getMatchingPerformers(t.Path, performerReader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, p := range others {
|
||||
added, err := addFunc(t.ID, p.ID)
|
||||
|
||||
if err != nil {
|
||||
return t.addError("performer", p.Name.String, err)
|
||||
}
|
||||
|
||||
if added {
|
||||
t.addLog("performer", p.Name.String)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tagger) tagStudios(studioReader models.StudioReader, addFunc addLinkFunc) error {
|
||||
others, err := getMatchingStudios(t.Path, studioReader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// only add first studio
|
||||
if len(others) > 0 {
|
||||
studio := others[0]
|
||||
added, err := addFunc(t.ID, studio.ID)
|
||||
|
||||
if err != nil {
|
||||
return t.addError("studio", studio.Name.String, err)
|
||||
}
|
||||
|
||||
if added {
|
||||
t.addLog("studio", studio.Name.String)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tagger) tagTags(tagReader models.TagReader, addFunc addLinkFunc) error {
|
||||
others, err := getMatchingTags(t.Path, tagReader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, p := range others {
|
||||
added, err := addFunc(t.ID, p.ID)
|
||||
|
||||
if err != nil {
|
||||
return t.addError("tag", p.Name, err)
|
||||
}
|
||||
|
||||
if added {
|
||||
t.addLog("tag", p.Name)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tagger) tagScenes(paths []string, sceneReader models.SceneReader, addFunc addLinkFunc) error {
|
||||
others, err := getMatchingScenes(t.Name, paths, sceneReader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, p := range others {
|
||||
added, err := addFunc(t.ID, p.ID)
|
||||
|
||||
if err != nil {
|
||||
return t.addError("scene", p.GetTitle(), err)
|
||||
}
|
||||
|
||||
if added {
|
||||
t.addLog("scene", p.GetTitle())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -5,12 +5,15 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/remeh/sizedwaitgroup"
|
||||
|
||||
"github.com/stashapp/stash/pkg/autotag"
|
||||
"github.com/stashapp/stash/pkg/logger"
|
||||
"github.com/stashapp/stash/pkg/manager/config"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
|
|
@ -626,6 +629,15 @@ func (s *singleton) generateScreenshot(sceneId string, at *float64) {
|
|||
}()
|
||||
}
|
||||
|
||||
func (s *singleton) isFileBasedAutoTag(input models.AutoTagMetadataInput) bool {
|
||||
const wildcard = "*"
|
||||
performerIds := input.Performers
|
||||
studioIds := input.Studios
|
||||
tagIds := input.Tags
|
||||
|
||||
return (len(performerIds) == 0 || performerIds[0] == wildcard) && (len(studioIds) == 0 || studioIds[0] == wildcard) && (len(tagIds) == 0 || tagIds[0] == wildcard)
|
||||
}
|
||||
|
||||
func (s *singleton) AutoTag(input models.AutoTagMetadataInput) {
|
||||
if s.Status.Status != Idle {
|
||||
return
|
||||
|
|
@ -636,58 +648,160 @@ func (s *singleton) AutoTag(input models.AutoTagMetadataInput) {
|
|||
go func() {
|
||||
defer s.returnToIdleState()
|
||||
|
||||
performerIds := input.Performers
|
||||
studioIds := input.Studios
|
||||
tagIds := input.Tags
|
||||
|
||||
// calculate work load
|
||||
performerCount := len(performerIds)
|
||||
studioCount := len(studioIds)
|
||||
tagCount := len(tagIds)
|
||||
|
||||
if err := s.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
|
||||
performerQuery := r.Performer()
|
||||
studioQuery := r.Studio()
|
||||
tagQuery := r.Tag()
|
||||
|
||||
const wildcard = "*"
|
||||
var err error
|
||||
if performerCount == 1 && performerIds[0] == wildcard {
|
||||
performerCount, err = performerQuery.Count()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error getting performer count: %s", err.Error())
|
||||
}
|
||||
}
|
||||
if studioCount == 1 && studioIds[0] == wildcard {
|
||||
studioCount, err = studioQuery.Count()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error getting studio count: %s", err.Error())
|
||||
}
|
||||
}
|
||||
if tagCount == 1 && tagIds[0] == wildcard {
|
||||
tagCount, err = tagQuery.Count()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error getting tag count: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
logger.Error(err.Error())
|
||||
return
|
||||
if s.isFileBasedAutoTag(input) {
|
||||
// doing file-based auto-tag
|
||||
s.autoTagScenes(input.Paths, len(input.Performers) > 0, len(input.Studios) > 0, len(input.Tags) > 0)
|
||||
} else {
|
||||
// doing specific performer/studio/tag auto-tag
|
||||
s.autoTagSpecific(input)
|
||||
}
|
||||
|
||||
total := performerCount + studioCount + tagCount
|
||||
s.Status.setProgress(0, total)
|
||||
|
||||
s.autoTagPerformers(input.Paths, performerIds)
|
||||
s.autoTagStudios(input.Paths, studioIds)
|
||||
s.autoTagTags(input.Paths, tagIds)
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *singleton) autoTagScenes(paths []string, performers, studios, tags bool) {
|
||||
if err := s.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
|
||||
ret := &models.SceneFilterType{}
|
||||
or := ret
|
||||
sep := string(filepath.Separator)
|
||||
|
||||
for _, p := range paths {
|
||||
if !strings.HasSuffix(p, sep) {
|
||||
p = p + sep
|
||||
}
|
||||
|
||||
if ret.Path == nil {
|
||||
or = ret
|
||||
} else {
|
||||
newOr := &models.SceneFilterType{}
|
||||
or.Or = newOr
|
||||
or = newOr
|
||||
}
|
||||
|
||||
or.Path = &models.StringCriterionInput{
|
||||
Modifier: models.CriterionModifierEquals,
|
||||
Value: p + "%",
|
||||
}
|
||||
}
|
||||
|
||||
organized := false
|
||||
ret.Organized = &organized
|
||||
|
||||
// batch process scenes
|
||||
batchSize := 1000
|
||||
page := 1
|
||||
findFilter := &models.FindFilterType{
|
||||
PerPage: &batchSize,
|
||||
Page: &page,
|
||||
}
|
||||
|
||||
more := true
|
||||
processed := 0
|
||||
for more {
|
||||
scenes, total, err := r.Scene().Query(ret, findFilter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if processed == 0 {
|
||||
logger.Infof("Starting autotag of %d scenes", total)
|
||||
}
|
||||
|
||||
for _, ss := range scenes {
|
||||
if s.Status.stopping {
|
||||
logger.Info("Stopping due to user request")
|
||||
return nil
|
||||
}
|
||||
|
||||
t := autoTagSceneTask{
|
||||
txnManager: s.TxnManager,
|
||||
scene: ss,
|
||||
performers: performers,
|
||||
studios: studios,
|
||||
tags: tags,
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go t.Start(&wg)
|
||||
wg.Wait()
|
||||
|
||||
processed++
|
||||
s.Status.setProgress(processed, total)
|
||||
}
|
||||
|
||||
if len(scenes) != batchSize {
|
||||
more = false
|
||||
} else {
|
||||
page++
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
logger.Error(err.Error())
|
||||
}
|
||||
|
||||
logger.Info("Finished autotag")
|
||||
}
|
||||
|
||||
func (s *singleton) autoTagSpecific(input models.AutoTagMetadataInput) {
|
||||
performerIds := input.Performers
|
||||
studioIds := input.Studios
|
||||
tagIds := input.Tags
|
||||
|
||||
performerCount := len(performerIds)
|
||||
studioCount := len(studioIds)
|
||||
tagCount := len(tagIds)
|
||||
|
||||
if err := s.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
|
||||
performerQuery := r.Performer()
|
||||
studioQuery := r.Studio()
|
||||
tagQuery := r.Tag()
|
||||
|
||||
const wildcard = "*"
|
||||
var err error
|
||||
if performerCount == 1 && performerIds[0] == wildcard {
|
||||
performerCount, err = performerQuery.Count()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting performer count: %s", err.Error())
|
||||
}
|
||||
}
|
||||
if studioCount == 1 && studioIds[0] == wildcard {
|
||||
studioCount, err = studioQuery.Count()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting studio count: %s", err.Error())
|
||||
}
|
||||
}
|
||||
if tagCount == 1 && tagIds[0] == wildcard {
|
||||
tagCount, err = tagQuery.Count()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting tag count: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
logger.Error(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
total := performerCount + studioCount + tagCount
|
||||
s.Status.setProgress(0, total)
|
||||
|
||||
logger.Infof("Starting autotag of %d performers, %d studios, %d tags", performerCount, studioCount, tagCount)
|
||||
|
||||
s.autoTagPerformers(input.Paths, performerIds)
|
||||
s.autoTagStudios(input.Paths, studioIds)
|
||||
s.autoTagTags(input.Paths, tagIds)
|
||||
|
||||
logger.Info("Finished autotag")
|
||||
}
|
||||
|
||||
func (s *singleton) autoTagPerformers(paths []string, performerIds []string) {
|
||||
var wg sync.WaitGroup
|
||||
if s.Status.stopping {
|
||||
return
|
||||
}
|
||||
|
||||
for _, performerId := range performerIds {
|
||||
var performers []*models.Performer
|
||||
|
||||
|
|
@ -698,46 +812,53 @@ func (s *singleton) autoTagPerformers(paths []string, performerIds []string) {
|
|||
var err error
|
||||
performers, err = performerQuery.All()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error querying performers: %s", err.Error())
|
||||
return fmt.Errorf("error querying performers: %s", err.Error())
|
||||
}
|
||||
} else {
|
||||
performerIdInt, err := strconv.Atoi(performerId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error parsing performer id %s: %s", performerId, err.Error())
|
||||
return fmt.Errorf("error parsing performer id %s: %s", performerId, err.Error())
|
||||
}
|
||||
|
||||
performer, err := performerQuery.Find(performerIdInt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error finding performer id %s: %s", performerId, err.Error())
|
||||
return fmt.Errorf("error finding performer id %s: %s", performerId, err.Error())
|
||||
}
|
||||
|
||||
if performer == nil {
|
||||
return fmt.Errorf("performer with id %s not found", performerId)
|
||||
}
|
||||
performers = append(performers, performer)
|
||||
}
|
||||
|
||||
for _, performer := range performers {
|
||||
if s.Status.stopping {
|
||||
logger.Info("Stopping due to user request")
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.TxnManager.WithTxn(context.TODO(), func(r models.Repository) error {
|
||||
return autotag.PerformerScenes(performer, paths, r.Scene())
|
||||
}); err != nil {
|
||||
return fmt.Errorf("error auto-tagging performer '%s': %s", performer.Name.String, err.Error())
|
||||
}
|
||||
|
||||
s.Status.incrementProgress()
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
logger.Error(err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
for _, performer := range performers {
|
||||
wg.Add(1)
|
||||
task := AutoTagPerformerTask{
|
||||
AutoTagTask: AutoTagTask{
|
||||
txnManager: s.TxnManager,
|
||||
paths: paths,
|
||||
},
|
||||
performer: performer,
|
||||
}
|
||||
go task.Start(&wg)
|
||||
wg.Wait()
|
||||
|
||||
s.Status.incrementProgress()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *singleton) autoTagStudios(paths []string, studioIds []string) {
|
||||
var wg sync.WaitGroup
|
||||
if s.Status.stopping {
|
||||
return
|
||||
}
|
||||
|
||||
for _, studioId := range studioIds {
|
||||
var studios []*models.Studio
|
||||
|
||||
|
|
@ -747,46 +868,54 @@ func (s *singleton) autoTagStudios(paths []string, studioIds []string) {
|
|||
var err error
|
||||
studios, err = studioQuery.All()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error querying studios: %s", err.Error())
|
||||
return fmt.Errorf("error querying studios: %s", err.Error())
|
||||
}
|
||||
} else {
|
||||
studioIdInt, err := strconv.Atoi(studioId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error parsing studio id %s: %s", studioId, err.Error())
|
||||
return fmt.Errorf("error parsing studio id %s: %s", studioId, err.Error())
|
||||
}
|
||||
|
||||
studio, err := studioQuery.Find(studioIdInt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error finding studio id %s: %s", studioId, err.Error())
|
||||
return fmt.Errorf("error finding studio id %s: %s", studioId, err.Error())
|
||||
}
|
||||
|
||||
if studio == nil {
|
||||
return fmt.Errorf("studio with id %s not found", studioId)
|
||||
}
|
||||
|
||||
studios = append(studios, studio)
|
||||
}
|
||||
|
||||
for _, studio := range studios {
|
||||
if s.Status.stopping {
|
||||
logger.Info("Stopping due to user request")
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.TxnManager.WithTxn(context.TODO(), func(r models.Repository) error {
|
||||
return autotag.StudioScenes(studio, paths, r.Scene())
|
||||
}); err != nil {
|
||||
return fmt.Errorf("error auto-tagging studio '%s': %s", studio.Name.String, err.Error())
|
||||
}
|
||||
|
||||
s.Status.incrementProgress()
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
logger.Error(err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
for _, studio := range studios {
|
||||
wg.Add(1)
|
||||
task := AutoTagStudioTask{
|
||||
AutoTagTask: AutoTagTask{
|
||||
txnManager: s.TxnManager,
|
||||
paths: paths,
|
||||
},
|
||||
studio: studio,
|
||||
}
|
||||
go task.Start(&wg)
|
||||
wg.Wait()
|
||||
|
||||
s.Status.incrementProgress()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *singleton) autoTagTags(paths []string, tagIds []string) {
|
||||
var wg sync.WaitGroup
|
||||
if s.Status.stopping {
|
||||
return
|
||||
}
|
||||
|
||||
for _, tagId := range tagIds {
|
||||
var tags []*models.Tag
|
||||
if err := s.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
|
||||
|
|
@ -795,41 +924,41 @@ func (s *singleton) autoTagTags(paths []string, tagIds []string) {
|
|||
var err error
|
||||
tags, err = tagQuery.All()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error querying tags: %s", err.Error())
|
||||
return fmt.Errorf("error querying tags: %s", err.Error())
|
||||
}
|
||||
} else {
|
||||
tagIdInt, err := strconv.Atoi(tagId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error parsing tag id %s: %s", tagId, err.Error())
|
||||
return fmt.Errorf("error parsing tag id %s: %s", tagId, err.Error())
|
||||
}
|
||||
|
||||
tag, err := tagQuery.Find(tagIdInt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error finding tag id %s: %s", tagId, err.Error())
|
||||
return fmt.Errorf("error finding tag id %s: %s", tagId, err.Error())
|
||||
}
|
||||
tags = append(tags, tag)
|
||||
}
|
||||
|
||||
for _, tag := range tags {
|
||||
if s.Status.stopping {
|
||||
logger.Info("Stopping due to user request")
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.TxnManager.WithTxn(context.TODO(), func(r models.Repository) error {
|
||||
return autotag.TagScenes(tag, paths, r.Scene())
|
||||
}); err != nil {
|
||||
return fmt.Errorf("error auto-tagging tag '%s': %s", tag.Name, err.Error())
|
||||
}
|
||||
|
||||
s.Status.incrementProgress()
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
logger.Error(err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
for _, tag := range tags {
|
||||
wg.Add(1)
|
||||
task := AutoTagTagTask{
|
||||
AutoTagTask: AutoTagTask{
|
||||
txnManager: s.TxnManager,
|
||||
paths: paths,
|
||||
},
|
||||
tag: tag,
|
||||
}
|
||||
go task.Start(&wg)
|
||||
wg.Wait()
|
||||
|
||||
s.Status.incrementProgress()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -2,196 +2,38 @@ package manager
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/stashapp/stash/pkg/autotag"
|
||||
"github.com/stashapp/stash/pkg/logger"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/scene"
|
||||
)
|
||||
|
||||
type AutoTagTask struct {
|
||||
paths []string
|
||||
type autoTagSceneTask struct {
|
||||
txnManager models.TransactionManager
|
||||
scene *models.Scene
|
||||
|
||||
performers bool
|
||||
studios bool
|
||||
tags bool
|
||||
}
|
||||
|
||||
type AutoTagPerformerTask struct {
|
||||
AutoTagTask
|
||||
performer *models.Performer
|
||||
}
|
||||
|
||||
func (t *AutoTagPerformerTask) Start(wg *sync.WaitGroup) {
|
||||
func (t *autoTagSceneTask) Start(wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
|
||||
t.autoTagPerformer()
|
||||
}
|
||||
|
||||
func (t *AutoTagTask) getQueryRegex(name string) string {
|
||||
const separatorChars = `.\-_ `
|
||||
// handle path separators
|
||||
const separator = `[` + separatorChars + `]`
|
||||
|
||||
ret := strings.Replace(name, " ", separator+"*", -1)
|
||||
ret = `(?:^|_|[^\w\d])` + ret + `(?:$|_|[^\w\d])`
|
||||
return ret
|
||||
}
|
||||
|
||||
func (t *AutoTagTask) getQueryFilter(regex string) *models.SceneFilterType {
|
||||
organized := false
|
||||
ret := &models.SceneFilterType{
|
||||
Path: &models.StringCriterionInput{
|
||||
Modifier: models.CriterionModifierMatchesRegex,
|
||||
Value: "(?i)" + regex,
|
||||
},
|
||||
Organized: &organized,
|
||||
}
|
||||
|
||||
sep := string(filepath.Separator)
|
||||
|
||||
var or *models.SceneFilterType
|
||||
for _, p := range t.paths {
|
||||
newOr := &models.SceneFilterType{}
|
||||
if or == nil {
|
||||
ret.And = newOr
|
||||
} else {
|
||||
or.Or = newOr
|
||||
}
|
||||
|
||||
or = newOr
|
||||
|
||||
if !strings.HasSuffix(p, sep) {
|
||||
p = p + sep
|
||||
}
|
||||
|
||||
or.Path = &models.StringCriterionInput{
|
||||
Modifier: models.CriterionModifierEquals,
|
||||
Value: p + "%",
|
||||
}
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func (t *AutoTagTask) getFindFilter() *models.FindFilterType {
|
||||
perPage := -1
|
||||
return &models.FindFilterType{
|
||||
PerPage: &perPage,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *AutoTagPerformerTask) autoTagPerformer() {
|
||||
regex := t.getQueryRegex(t.performer.Name.String)
|
||||
|
||||
if err := t.txnManager.WithTxn(context.TODO(), func(r models.Repository) error {
|
||||
qb := r.Scene()
|
||||
|
||||
scenes, _, err := qb.Query(t.getQueryFilter(regex), t.getFindFilter())
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error querying scenes with regex '%s': %s", regex, err.Error())
|
||||
}
|
||||
|
||||
for _, s := range scenes {
|
||||
added, err := scene.AddPerformer(qb, s.ID, t.performer.ID)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error adding performer '%s' to scene '%s': %s", t.performer.Name.String, s.GetTitle(), err.Error())
|
||||
if t.performers {
|
||||
if err := autotag.ScenePerformers(t.scene, r.Scene(), r.Performer()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if added {
|
||||
logger.Infof("Added performer '%s' to scene '%s'", t.performer.Name.String, s.GetTitle())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
logger.Error(err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
type AutoTagStudioTask struct {
|
||||
AutoTagTask
|
||||
studio *models.Studio
|
||||
}
|
||||
|
||||
func (t *AutoTagStudioTask) Start(wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
|
||||
t.autoTagStudio()
|
||||
}
|
||||
|
||||
func (t *AutoTagStudioTask) autoTagStudio() {
|
||||
regex := t.getQueryRegex(t.studio.Name.String)
|
||||
|
||||
if err := t.txnManager.WithTxn(context.TODO(), func(r models.Repository) error {
|
||||
qb := r.Scene()
|
||||
scenes, _, err := qb.Query(t.getQueryFilter(regex), t.getFindFilter())
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error querying scenes with regex '%s': %s", regex, err.Error())
|
||||
}
|
||||
|
||||
for _, s := range scenes {
|
||||
// #306 - don't overwrite studio if already present
|
||||
if s.StudioID.Valid {
|
||||
// don't modify
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Infof("Adding studio '%s' to scene '%s'", t.studio.Name.String, s.GetTitle())
|
||||
|
||||
// set the studio id
|
||||
studioID := sql.NullInt64{Int64: int64(t.studio.ID), Valid: true}
|
||||
scenePartial := models.ScenePartial{
|
||||
ID: s.ID,
|
||||
StudioID: &studioID,
|
||||
}
|
||||
|
||||
if _, err := qb.Update(scenePartial); err != nil {
|
||||
return fmt.Errorf("Error adding studio to scene: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
logger.Error(err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
type AutoTagTagTask struct {
|
||||
AutoTagTask
|
||||
tag *models.Tag
|
||||
}
|
||||
|
||||
func (t *AutoTagTagTask) Start(wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
|
||||
t.autoTagTag()
|
||||
}
|
||||
|
||||
func (t *AutoTagTagTask) autoTagTag() {
|
||||
regex := t.getQueryRegex(t.tag.Name)
|
||||
|
||||
if err := t.txnManager.WithTxn(context.TODO(), func(r models.Repository) error {
|
||||
qb := r.Scene()
|
||||
scenes, _, err := qb.Query(t.getQueryFilter(regex), t.getFindFilter())
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error querying scenes with regex '%s': %s", regex, err.Error())
|
||||
}
|
||||
|
||||
for _, s := range scenes {
|
||||
added, err := scene.AddTag(qb, s.ID, t.tag.ID)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error adding tag '%s' to scene '%s': %s", t.tag.Name, s.GetTitle(), err.Error())
|
||||
}
|
||||
|
||||
if added {
|
||||
logger.Infof("Added tag '%s' to scene '%s'", t.tag.Name, s.GetTitle())
|
||||
}
|
||||
if t.studios {
|
||||
if err := autotag.SceneStudios(t.scene, r.Scene(), r.Studio()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if t.tags {
|
||||
if err := autotag.SceneTags(t.scene, r.Scene(), r.Tag()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,9 @@
|
|||
package models
|
||||
|
||||
// PerPageAll is the value used for perPage to indicate all results should be
|
||||
// returned.
|
||||
const PerPageAll = -1
|
||||
|
||||
func (ff FindFilterType) GetSort(defaultSort string) string {
|
||||
var sort string
|
||||
if ff.Sort == nil {
|
||||
|
|
|
|||
|
|
@ -388,6 +388,29 @@ func (_m *PerformerReaderWriter) Query(performerFilter *models.PerformerFilterTy
|
|||
return r0, r1, r2
|
||||
}
|
||||
|
||||
// QueryForAutoTag provides a mock function with given fields: words
|
||||
func (_m *PerformerReaderWriter) QueryForAutoTag(words []string) ([]*models.Performer, error) {
|
||||
ret := _m.Called(words)
|
||||
|
||||
var r0 []*models.Performer
|
||||
if rf, ok := ret.Get(0).(func([]string) []*models.Performer); ok {
|
||||
r0 = rf(words)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]*models.Performer)
|
||||
}
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func([]string) error); ok {
|
||||
r1 = rf(words)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Update provides a mock function with given fields: updatedPerformer
|
||||
func (_m *PerformerReaderWriter) Update(updatedPerformer models.PerformerPartial) (*models.Performer, error) {
|
||||
ret := _m.Called(updatedPerformer)
|
||||
|
|
|
|||
|
|
@ -296,6 +296,29 @@ func (_m *StudioReaderWriter) Query(studioFilter *models.StudioFilterType, findF
|
|||
return r0, r1, r2
|
||||
}
|
||||
|
||||
// QueryForAutoTag provides a mock function with given fields: words
|
||||
func (_m *StudioReaderWriter) QueryForAutoTag(words []string) ([]*models.Studio, error) {
|
||||
ret := _m.Called(words)
|
||||
|
||||
var r0 []*models.Studio
|
||||
if rf, ok := ret.Get(0).(func([]string) []*models.Studio); ok {
|
||||
r0 = rf(words)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]*models.Studio)
|
||||
}
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func([]string) error); ok {
|
||||
r1 = rf(words)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Update provides a mock function with given fields: updatedStudio
|
||||
func (_m *StudioReaderWriter) Update(updatedStudio models.StudioPartial) (*models.Studio, error) {
|
||||
ret := _m.Called(updatedStudio)
|
||||
|
|
|
|||
|
|
@ -367,6 +367,29 @@ func (_m *TagReaderWriter) Query(tagFilter *models.TagFilterType, findFilter *mo
|
|||
return r0, r1, r2
|
||||
}
|
||||
|
||||
// QueryForAutoTag provides a mock function with given fields: words
|
||||
func (_m *TagReaderWriter) QueryForAutoTag(words []string) ([]*models.Tag, error) {
|
||||
ret := _m.Called(words)
|
||||
|
||||
var r0 []*models.Tag
|
||||
if rf, ok := ret.Get(0).(func([]string) []*models.Tag); ok {
|
||||
r0 = rf(words)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]*models.Tag)
|
||||
}
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func([]string) error); ok {
|
||||
r1 = rf(words)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Update provides a mock function with given fields: updatedTag
|
||||
func (_m *TagReaderWriter) Update(updatedTag models.Tag) (*models.Tag, error) {
|
||||
ret := _m.Called(updatedTag)
|
||||
|
|
|
|||
|
|
@ -11,6 +11,9 @@ type PerformerReader interface {
|
|||
CountByTagID(tagID int) (int, error)
|
||||
Count() (int, error)
|
||||
All() ([]*Performer, error)
|
||||
// TODO - this interface is temporary until the filter schema can fully
|
||||
// support the query needed
|
||||
QueryForAutoTag(words []string) ([]*Performer, error)
|
||||
Query(performerFilter *PerformerFilterType, findFilter *FindFilterType) ([]*Performer, int, error)
|
||||
GetImage(performerID int) ([]byte, error)
|
||||
GetStashIDs(performerID int) ([]*StashID, error)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,9 @@ type StudioReader interface {
|
|||
FindByName(name string, nocase bool) (*Studio, error)
|
||||
Count() (int, error)
|
||||
All() ([]*Studio, error)
|
||||
// TODO - this interface is temporary until the filter schema can fully
|
||||
// support the query needed
|
||||
QueryForAutoTag(words []string) ([]*Studio, error)
|
||||
Query(studioFilter *StudioFilterType, findFilter *FindFilterType) ([]*Studio, int, error)
|
||||
GetImage(studioID int) ([]byte, error)
|
||||
HasImage(studioID int) (bool, error)
|
||||
|
|
|
|||
|
|
@ -12,6 +12,9 @@ type TagReader interface {
|
|||
FindByNames(names []string, nocase bool) ([]*Tag, error)
|
||||
Count() (int, error)
|
||||
All() ([]*Tag, error)
|
||||
// TODO - this interface is temporary until the filter schema can fully
|
||||
// support the query needed
|
||||
QueryForAutoTag(words []string) ([]*Tag, error)
|
||||
Query(tagFilter *TagFilterType, findFilter *FindFilterType) ([]*Tag, int, error)
|
||||
GetImage(tagID int) ([]byte, error)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import (
|
|||
"database/sql"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
)
|
||||
|
|
@ -172,6 +173,25 @@ func (qb *performerQueryBuilder) All() ([]*models.Performer, error) {
|
|||
return qb.queryPerformers(selectAll("performers")+qb.getPerformerSort(nil), nil)
|
||||
}
|
||||
|
||||
func (qb *performerQueryBuilder) QueryForAutoTag(words []string) ([]*models.Performer, error) {
|
||||
// TODO - Query needs to be changed to support queries of this type, and
|
||||
// this method should be removed
|
||||
query := selectAll(performerTable)
|
||||
|
||||
var whereClauses []string
|
||||
var args []interface{}
|
||||
|
||||
for _, w := range words {
|
||||
whereClauses = append(whereClauses, "name like ?")
|
||||
args = append(args, "%"+w+"%")
|
||||
whereClauses = append(whereClauses, "aliases like ?")
|
||||
args = append(args, "%"+w+"%")
|
||||
}
|
||||
|
||||
where := strings.Join(whereClauses, " OR ")
|
||||
return qb.queryPerformers(query+" WHERE "+where, args)
|
||||
}
|
||||
|
||||
func (qb *performerQueryBuilder) Query(performerFilter *models.PerformerFilterType, findFilter *models.FindFilterType) ([]*models.Performer, int, error) {
|
||||
if performerFilter == nil {
|
||||
performerFilter = &models.PerformerFilterType{}
|
||||
|
|
|
|||
|
|
@ -100,6 +100,26 @@ func TestPerformerFindByNames(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestPerformerQueryForAutoTag(t *testing.T) {
|
||||
withTxn(func(r models.Repository) error {
|
||||
tqb := r.Performer()
|
||||
|
||||
name := performerNames[performerIdxWithScene] // find a performer by name
|
||||
|
||||
performers, err := tqb.QueryForAutoTag([]string{name})
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Error finding performers: %s", err.Error())
|
||||
}
|
||||
|
||||
assert.Len(t, performers, 2)
|
||||
assert.Equal(t, strings.ToLower(performerNames[performerIdxWithScene]), strings.ToLower(performers[0].Name.String))
|
||||
assert.Equal(t, strings.ToLower(performerNames[performerIdxWithScene]), strings.ToLower(performers[1].Name.String))
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func TestPerformerUpdatePerformerImage(t *testing.T) {
|
||||
if err := withTxn(func(r models.Repository) error {
|
||||
qb := r.Performer()
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package sqlite
|
|||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
)
|
||||
|
|
@ -121,6 +122,23 @@ func (qb *studioQueryBuilder) All() ([]*models.Studio, error) {
|
|||
return qb.queryStudios(selectAll("studios")+qb.getStudioSort(nil), nil)
|
||||
}
|
||||
|
||||
func (qb *studioQueryBuilder) QueryForAutoTag(words []string) ([]*models.Studio, error) {
|
||||
// TODO - Query needs to be changed to support queries of this type, and
|
||||
// this method should be removed
|
||||
query := selectAll(studioTable)
|
||||
|
||||
var whereClauses []string
|
||||
var args []interface{}
|
||||
|
||||
for _, w := range words {
|
||||
whereClauses = append(whereClauses, "name like ?")
|
||||
args = append(args, "%"+w+"%")
|
||||
}
|
||||
|
||||
where := strings.Join(whereClauses, " OR ")
|
||||
return qb.queryStudios(query+" WHERE "+where, args)
|
||||
}
|
||||
|
||||
func (qb *studioQueryBuilder) Query(studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) ([]*models.Studio, int, error) {
|
||||
if studioFilter == nil {
|
||||
studioFilter = &models.StudioFilterType{}
|
||||
|
|
|
|||
|
|
@ -45,6 +45,26 @@ func TestStudioFindByName(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestStudioQueryForAutoTag(t *testing.T) {
|
||||
withTxn(func(r models.Repository) error {
|
||||
tqb := r.Studio()
|
||||
|
||||
name := studioNames[studioIdxWithScene] // find a studio by name
|
||||
|
||||
studios, err := tqb.QueryForAutoTag([]string{name})
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Error finding studios: %s", err.Error())
|
||||
}
|
||||
|
||||
assert.Len(t, studios, 2)
|
||||
assert.Equal(t, strings.ToLower(studioNames[studioIdxWithScene]), strings.ToLower(studios[0].Name.String))
|
||||
assert.Equal(t, strings.ToLower(studioNames[studioIdxWithScene]), strings.ToLower(studios[1].Name.String))
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func TestStudioQueryParent(t *testing.T) {
|
||||
withTxn(func(r models.Repository) error {
|
||||
sqb := r.Studio()
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import (
|
|||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
)
|
||||
|
|
@ -192,6 +193,23 @@ func (qb *tagQueryBuilder) All() ([]*models.Tag, error) {
|
|||
return qb.queryTags(selectAll("tags")+qb.getDefaultTagSort(), nil)
|
||||
}
|
||||
|
||||
func (qb *tagQueryBuilder) QueryForAutoTag(words []string) ([]*models.Tag, error) {
|
||||
// TODO - Query needs to be changed to support queries of this type, and
|
||||
// this method should be removed
|
||||
query := selectAll(tagTable)
|
||||
|
||||
var whereClauses []string
|
||||
var args []interface{}
|
||||
|
||||
for _, w := range words {
|
||||
whereClauses = append(whereClauses, "name like ?")
|
||||
args = append(args, "%"+w+"%")
|
||||
}
|
||||
|
||||
where := strings.Join(whereClauses, " OR ")
|
||||
return qb.queryTags(query+" WHERE "+where, args)
|
||||
}
|
||||
|
||||
func (qb *tagQueryBuilder) validateFilter(tagFilter *models.TagFilterType) error {
|
||||
const and = "AND"
|
||||
const or = "OR"
|
||||
|
|
|
|||
|
|
@ -70,6 +70,26 @@ func TestTagFindByName(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestTagQueryForAutoTag(t *testing.T) {
|
||||
withTxn(func(r models.Repository) error {
|
||||
tqb := r.Tag()
|
||||
|
||||
name := tagNames[tagIdxWithScene] // find a tag by name
|
||||
|
||||
tags, err := tqb.QueryForAutoTag([]string{name})
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Error finding tags: %s", err.Error())
|
||||
}
|
||||
|
||||
assert.Len(t, tags, 2)
|
||||
assert.Equal(t, strings.ToLower(tagNames[tagIdxWithScene]), strings.ToLower(tags[0].Name))
|
||||
assert.Equal(t, strings.ToLower(tagNames[tagIdxWithScene]), strings.ToLower(tags[1].Name))
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func TestTagFindByNames(t *testing.T) {
|
||||
var names []string
|
||||
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@
|
|||
* Added scene queue.
|
||||
|
||||
### 🎨 Improvements
|
||||
* Improved performance of the auto-tagger.
|
||||
* Clean generation artifacts after generating each scene.
|
||||
* Log message at startup when cleaning the `tmp` and `downloads` generated folders takes more than one second.
|
||||
* Sort movie scenes by scene number by default.
|
||||
|
|
@ -27,6 +28,7 @@
|
|||
* Change performer text query to search by name and alias only.
|
||||
|
||||
### 🐛 Bug fixes
|
||||
* Fixed error when auto-tagging for performers/studios/tags with regex characters in the name.
|
||||
* Fix scraped performer image not updating after clearing the current image when creating a new performer.
|
||||
* Fix error preventing adding a new library path when an existing library path is missing.
|
||||
* Fix whitespace in query string returning all objects.
|
||||
|
|
|
|||
Loading…
Reference in a new issue