mirror of
https://github.com/stashapp/stash.git
synced 2025-12-06 08:26:00 +01:00
Model refactor, part 3 (#4152)
* Remove manager.Repository * Refactor other repositories * Fix tests and add database mock * Add AssertExpectations method * Refactor routes * Move default movie image to internal/static and add convenience methods * Refactor default performer image boxes
This commit is contained in:
parent
40bcb4baa5
commit
33f2ebf2a3
87 changed files with 1843 additions and 1651 deletions
|
|
@ -1,12 +1,13 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/stashapp/stash/internal/manager/config"
|
||||
"github.com/stashapp/stash/internal/static"
|
||||
"github.com/stashapp/stash/pkg/hash"
|
||||
"github.com/stashapp/stash/pkg/logger"
|
||||
|
|
@ -18,7 +19,7 @@ type imageBox struct {
|
|||
files []string
|
||||
}
|
||||
|
||||
var imageExtensions = []string{
|
||||
var imageBoxExts = []string{
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".png",
|
||||
|
|
@ -42,7 +43,7 @@ func newImageBox(box fs.FS) (*imageBox, error) {
|
|||
}
|
||||
|
||||
baseName := strings.ToLower(d.Name())
|
||||
for _, ext := range imageExtensions {
|
||||
for _, ext := range imageBoxExts {
|
||||
if strings.HasSuffix(baseName, ext) {
|
||||
ret.files = append(ret.files, path)
|
||||
break
|
||||
|
|
@ -55,65 +56,14 @@ func newImageBox(box fs.FS) (*imageBox, error) {
|
|||
return ret, err
|
||||
}
|
||||
|
||||
var performerBox *imageBox
|
||||
var performerBoxMale *imageBox
|
||||
var performerBoxCustom *imageBox
|
||||
|
||||
func initialiseImages() {
|
||||
var err error
|
||||
performerBox, err = newImageBox(&static.Performer)
|
||||
if err != nil {
|
||||
logger.Warnf("error loading performer images: %v", err)
|
||||
}
|
||||
performerBoxMale, err = newImageBox(&static.PerformerMale)
|
||||
if err != nil {
|
||||
logger.Warnf("error loading male performer images: %v", err)
|
||||
}
|
||||
initialiseCustomImages()
|
||||
}
|
||||
|
||||
func initialiseCustomImages() {
|
||||
customPath := config.GetInstance().GetCustomPerformerImageLocation()
|
||||
if customPath != "" {
|
||||
logger.Debugf("Loading custom performer images from %s", customPath)
|
||||
// We need to set performerBoxCustom at runtime, as this is a custom path, and store it in a pointer.
|
||||
var err error
|
||||
performerBoxCustom, err = newImageBox(os.DirFS(customPath))
|
||||
if err != nil {
|
||||
logger.Warnf("error loading custom performer from %s: %v", customPath, err)
|
||||
}
|
||||
} else {
|
||||
performerBoxCustom = nil
|
||||
}
|
||||
}
|
||||
|
||||
func getRandomPerformerImageUsingName(name string, gender *models.GenderEnum, customPath string) ([]byte, error) {
|
||||
var box *imageBox
|
||||
|
||||
// If we have a custom path, we should return a new box in the given path.
|
||||
if performerBoxCustom != nil && len(performerBoxCustom.files) > 0 {
|
||||
box = performerBoxCustom
|
||||
func (box *imageBox) GetRandomImageByName(name string) ([]byte, error) {
|
||||
files := box.files
|
||||
if len(files) == 0 {
|
||||
return nil, errors.New("box is empty")
|
||||
}
|
||||
|
||||
var g models.GenderEnum
|
||||
if gender != nil {
|
||||
g = *gender
|
||||
}
|
||||
|
||||
if box == nil {
|
||||
switch g {
|
||||
case models.GenderEnumFemale, models.GenderEnumTransgenderFemale:
|
||||
box = performerBox
|
||||
case models.GenderEnumMale, models.GenderEnumTransgenderMale:
|
||||
box = performerBoxMale
|
||||
default:
|
||||
box = performerBox
|
||||
}
|
||||
}
|
||||
|
||||
imageFiles := box.files
|
||||
index := hash.IntFromString(name) % uint64(len(imageFiles))
|
||||
img, err := box.box.Open(imageFiles[index])
|
||||
index := hash.IntFromString(name) % uint64(len(files))
|
||||
img, err := box.box.Open(files[index])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -121,3 +71,64 @@ func getRandomPerformerImageUsingName(name string, gender *models.GenderEnum, cu
|
|||
|
||||
return io.ReadAll(img)
|
||||
}
|
||||
|
||||
var performerBox *imageBox
|
||||
var performerBoxMale *imageBox
|
||||
var performerBoxCustom *imageBox
|
||||
|
||||
func init() {
|
||||
var err error
|
||||
performerBox, err = newImageBox(static.Sub(static.Performer))
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("loading performer images: %v", err))
|
||||
}
|
||||
performerBoxMale, err = newImageBox(static.Sub(static.PerformerMale))
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("loading male performer images: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
func initCustomPerformerImages(customPath string) {
|
||||
if customPath != "" {
|
||||
logger.Debugf("Loading custom performer images from %s", customPath)
|
||||
var err error
|
||||
performerBoxCustom, err = newImageBox(os.DirFS(customPath))
|
||||
if err != nil {
|
||||
logger.Warnf("error loading custom performer images from %s: %v", customPath, err)
|
||||
}
|
||||
} else {
|
||||
performerBoxCustom = nil
|
||||
}
|
||||
}
|
||||
|
||||
func getDefaultPerformerImage(name string, gender *models.GenderEnum) []byte {
|
||||
// try the custom box first if we have one
|
||||
if performerBoxCustom != nil {
|
||||
ret, err := performerBoxCustom.GetRandomImageByName(name)
|
||||
if err == nil {
|
||||
return ret
|
||||
}
|
||||
logger.Warnf("error loading custom default performer image: %v", err)
|
||||
}
|
||||
|
||||
var g models.GenderEnum
|
||||
if gender != nil {
|
||||
g = *gender
|
||||
}
|
||||
|
||||
var box *imageBox
|
||||
switch g {
|
||||
case models.GenderEnumFemale, models.GenderEnumTransgenderFemale:
|
||||
box = performerBox
|
||||
case models.GenderEnumMale, models.GenderEnumTransgenderMale:
|
||||
box = performerBoxMale
|
||||
default:
|
||||
box = performerBox
|
||||
}
|
||||
|
||||
ret, err := box.GetRandomImageByName(name)
|
||||
if err != nil {
|
||||
logger.Warnf("error loading default performer image: %v", err)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,9 +17,7 @@ import (
|
|||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/stashapp/stash/internal/manager"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/txn"
|
||||
)
|
||||
|
||||
type contextKey struct{ name string }
|
||||
|
|
@ -49,8 +47,7 @@ type Loaders struct {
|
|||
}
|
||||
|
||||
type Middleware struct {
|
||||
DatabaseProvider txn.DatabaseProvider
|
||||
Repository manager.Repository
|
||||
Repository models.Repository
|
||||
}
|
||||
|
||||
func (m Middleware) Middleware(next http.Handler) http.Handler {
|
||||
|
|
@ -131,13 +128,9 @@ func toErrorSlice(err error) []error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (m Middleware) withTxn(ctx context.Context, fn func(ctx context.Context) error) error {
|
||||
return txn.WithDatabase(ctx, m.DatabaseProvider, fn)
|
||||
}
|
||||
|
||||
func (m Middleware) fetchScenes(ctx context.Context) func(keys []int) ([]*models.Scene, []error) {
|
||||
return func(keys []int) (ret []*models.Scene, errs []error) {
|
||||
err := m.withTxn(ctx, func(ctx context.Context) error {
|
||||
err := m.Repository.WithDB(ctx, func(ctx context.Context) error {
|
||||
var err error
|
||||
ret, err = m.Repository.Scene.FindMany(ctx, keys)
|
||||
return err
|
||||
|
|
@ -148,7 +141,7 @@ func (m Middleware) fetchScenes(ctx context.Context) func(keys []int) ([]*models
|
|||
|
||||
func (m Middleware) fetchImages(ctx context.Context) func(keys []int) ([]*models.Image, []error) {
|
||||
return func(keys []int) (ret []*models.Image, errs []error) {
|
||||
err := m.withTxn(ctx, func(ctx context.Context) error {
|
||||
err := m.Repository.WithDB(ctx, func(ctx context.Context) error {
|
||||
var err error
|
||||
ret, err = m.Repository.Image.FindMany(ctx, keys)
|
||||
return err
|
||||
|
|
@ -160,7 +153,7 @@ func (m Middleware) fetchImages(ctx context.Context) func(keys []int) ([]*models
|
|||
|
||||
func (m Middleware) fetchGalleries(ctx context.Context) func(keys []int) ([]*models.Gallery, []error) {
|
||||
return func(keys []int) (ret []*models.Gallery, errs []error) {
|
||||
err := m.withTxn(ctx, func(ctx context.Context) error {
|
||||
err := m.Repository.WithDB(ctx, func(ctx context.Context) error {
|
||||
var err error
|
||||
ret, err = m.Repository.Gallery.FindMany(ctx, keys)
|
||||
return err
|
||||
|
|
@ -172,7 +165,7 @@ func (m Middleware) fetchGalleries(ctx context.Context) func(keys []int) ([]*mod
|
|||
|
||||
func (m Middleware) fetchPerformers(ctx context.Context) func(keys []int) ([]*models.Performer, []error) {
|
||||
return func(keys []int) (ret []*models.Performer, errs []error) {
|
||||
err := m.withTxn(ctx, func(ctx context.Context) error {
|
||||
err := m.Repository.WithDB(ctx, func(ctx context.Context) error {
|
||||
var err error
|
||||
ret, err = m.Repository.Performer.FindMany(ctx, keys)
|
||||
return err
|
||||
|
|
@ -184,7 +177,7 @@ func (m Middleware) fetchPerformers(ctx context.Context) func(keys []int) ([]*mo
|
|||
|
||||
func (m Middleware) fetchStudios(ctx context.Context) func(keys []int) ([]*models.Studio, []error) {
|
||||
return func(keys []int) (ret []*models.Studio, errs []error) {
|
||||
err := m.withTxn(ctx, func(ctx context.Context) error {
|
||||
err := m.Repository.WithDB(ctx, func(ctx context.Context) error {
|
||||
var err error
|
||||
ret, err = m.Repository.Studio.FindMany(ctx, keys)
|
||||
return err
|
||||
|
|
@ -195,7 +188,7 @@ func (m Middleware) fetchStudios(ctx context.Context) func(keys []int) ([]*model
|
|||
|
||||
func (m Middleware) fetchTags(ctx context.Context) func(keys []int) ([]*models.Tag, []error) {
|
||||
return func(keys []int) (ret []*models.Tag, errs []error) {
|
||||
err := m.withTxn(ctx, func(ctx context.Context) error {
|
||||
err := m.Repository.WithDB(ctx, func(ctx context.Context) error {
|
||||
var err error
|
||||
ret, err = m.Repository.Tag.FindMany(ctx, keys)
|
||||
return err
|
||||
|
|
@ -206,7 +199,7 @@ func (m Middleware) fetchTags(ctx context.Context) func(keys []int) ([]*models.T
|
|||
|
||||
func (m Middleware) fetchMovies(ctx context.Context) func(keys []int) ([]*models.Movie, []error) {
|
||||
return func(keys []int) (ret []*models.Movie, errs []error) {
|
||||
err := m.withTxn(ctx, func(ctx context.Context) error {
|
||||
err := m.Repository.WithDB(ctx, func(ctx context.Context) error {
|
||||
var err error
|
||||
ret, err = m.Repository.Movie.FindMany(ctx, keys)
|
||||
return err
|
||||
|
|
@ -217,7 +210,7 @@ func (m Middleware) fetchMovies(ctx context.Context) func(keys []int) ([]*models
|
|||
|
||||
func (m Middleware) fetchFiles(ctx context.Context) func(keys []models.FileID) ([]models.File, []error) {
|
||||
return func(keys []models.FileID) (ret []models.File, errs []error) {
|
||||
err := m.withTxn(ctx, func(ctx context.Context) error {
|
||||
err := m.Repository.WithDB(ctx, func(ctx context.Context) error {
|
||||
var err error
|
||||
ret, err = m.Repository.File.Find(ctx, keys...)
|
||||
return err
|
||||
|
|
@ -228,7 +221,7 @@ func (m Middleware) fetchFiles(ctx context.Context) func(keys []models.FileID) (
|
|||
|
||||
func (m Middleware) fetchScenesFileIDs(ctx context.Context) func(keys []int) ([][]models.FileID, []error) {
|
||||
return func(keys []int) (ret [][]models.FileID, errs []error) {
|
||||
err := m.withTxn(ctx, func(ctx context.Context) error {
|
||||
err := m.Repository.WithDB(ctx, func(ctx context.Context) error {
|
||||
var err error
|
||||
ret, err = m.Repository.Scene.GetManyFileIDs(ctx, keys)
|
||||
return err
|
||||
|
|
@ -239,7 +232,7 @@ func (m Middleware) fetchScenesFileIDs(ctx context.Context) func(keys []int) ([]
|
|||
|
||||
func (m Middleware) fetchImagesFileIDs(ctx context.Context) func(keys []int) ([][]models.FileID, []error) {
|
||||
return func(keys []int) (ret [][]models.FileID, errs []error) {
|
||||
err := m.withTxn(ctx, func(ctx context.Context) error {
|
||||
err := m.Repository.WithDB(ctx, func(ctx context.Context) error {
|
||||
var err error
|
||||
ret, err = m.Repository.Image.GetManyFileIDs(ctx, keys)
|
||||
return err
|
||||
|
|
@ -250,7 +243,7 @@ func (m Middleware) fetchImagesFileIDs(ctx context.Context) func(keys []int) ([]
|
|||
|
||||
func (m Middleware) fetchGalleriesFileIDs(ctx context.Context) func(keys []int) ([][]models.FileID, []error) {
|
||||
return func(keys []int) (ret [][]models.FileID, errs []error) {
|
||||
err := m.withTxn(ctx, func(ctx context.Context) error {
|
||||
err := m.Repository.WithDB(ctx, func(ctx context.Context) error {
|
||||
var err error
|
||||
ret, err = m.Repository.Gallery.GetManyFileIDs(ctx, keys)
|
||||
return err
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ import (
|
|||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/plugin"
|
||||
"github.com/stashapp/stash/pkg/scraper"
|
||||
"github.com/stashapp/stash/pkg/txn"
|
||||
"github.com/stashapp/stash/pkg/scraper/stashbox"
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
@ -33,8 +33,7 @@ type hookExecutor interface {
|
|||
}
|
||||
|
||||
type Resolver struct {
|
||||
txnManager txn.Manager
|
||||
repository manager.Repository
|
||||
repository models.Repository
|
||||
sceneService manager.SceneService
|
||||
imageService manager.ImageService
|
||||
galleryService manager.GalleryService
|
||||
|
|
@ -102,11 +101,15 @@ type tagResolver struct{ *Resolver }
|
|||
type savedFilterResolver struct{ *Resolver }
|
||||
|
||||
func (r *Resolver) withTxn(ctx context.Context, fn func(ctx context.Context) error) error {
|
||||
return txn.WithTxn(ctx, r.txnManager, fn)
|
||||
return r.repository.WithTxn(ctx, fn)
|
||||
}
|
||||
|
||||
func (r *Resolver) withReadTxn(ctx context.Context, fn func(ctx context.Context) error) error {
|
||||
return txn.WithReadTxn(ctx, r.txnManager, fn)
|
||||
return r.repository.WithReadTxn(ctx, fn)
|
||||
}
|
||||
|
||||
func (r *Resolver) stashboxRepository() stashbox.Repository {
|
||||
return stashbox.NewRepository(r.repository)
|
||||
}
|
||||
|
||||
func (r *queryResolver) MarkerWall(ctx context.Context, q *string) (ret []*models.SceneMarker, err error) {
|
||||
|
|
|
|||
|
|
@ -316,7 +316,7 @@ func (r *mutationResolver) ConfigureGeneral(ctx context.Context, input ConfigGen
|
|||
|
||||
if input.CustomPerformerImageLocation != nil {
|
||||
c.Set(config.CustomPerformerImageLocation, *input.CustomPerformerImageLocation)
|
||||
initialiseCustomImages()
|
||||
initCustomPerformerImages(*input.CustomPerformerImageLocation)
|
||||
}
|
||||
|
||||
if input.ScraperUserAgent != nil {
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ func (r *mutationResolver) MoveFiles(ctx context.Context, input MoveFilesInput)
|
|||
fileStore := r.repository.File
|
||||
folderStore := r.repository.Folder
|
||||
mover := file.NewMover(fileStore, folderStore)
|
||||
mover.RegisterHooks(ctx, r.txnManager)
|
||||
mover.RegisterHooks(ctx)
|
||||
|
||||
var (
|
||||
folder *models.Folder
|
||||
|
|
|
|||
|
|
@ -11,30 +11,30 @@ import (
|
|||
)
|
||||
|
||||
func (r *mutationResolver) MigrateSceneScreenshots(ctx context.Context, input MigrateSceneScreenshotsInput) (string, error) {
|
||||
db := manager.GetInstance().Database
|
||||
mgr := manager.GetInstance()
|
||||
t := &task.MigrateSceneScreenshotsJob{
|
||||
ScreenshotsPath: manager.GetInstance().Paths.Generated.Screenshots,
|
||||
Input: scene.MigrateSceneScreenshotsInput{
|
||||
DeleteFiles: utils.IsTrue(input.DeleteFiles),
|
||||
OverwriteExisting: utils.IsTrue(input.OverwriteExisting),
|
||||
},
|
||||
SceneRepo: db.Scene,
|
||||
TxnManager: db,
|
||||
SceneRepo: mgr.Repository.Scene,
|
||||
TxnManager: mgr.Repository.TxnManager,
|
||||
}
|
||||
jobID := manager.GetInstance().JobManager.Add(ctx, "Migrating scene screenshots to blobs...", t)
|
||||
jobID := mgr.JobManager.Add(ctx, "Migrating scene screenshots to blobs...", t)
|
||||
|
||||
return strconv.Itoa(jobID), nil
|
||||
}
|
||||
|
||||
func (r *mutationResolver) MigrateBlobs(ctx context.Context, input MigrateBlobsInput) (string, error) {
|
||||
db := manager.GetInstance().Database
|
||||
mgr := manager.GetInstance()
|
||||
t := &task.MigrateBlobsJob{
|
||||
TxnManager: db,
|
||||
BlobStore: db.Blobs,
|
||||
Vacuumer: db,
|
||||
TxnManager: mgr.Database,
|
||||
BlobStore: mgr.Database.Blobs,
|
||||
Vacuumer: mgr.Database,
|
||||
DeleteOld: utils.IsTrue(input.DeleteOld),
|
||||
}
|
||||
jobID := manager.GetInstance().JobManager.Add(ctx, "Migrating blobs...", t)
|
||||
jobID := mgr.JobManager.Add(ctx, "Migrating blobs...", t)
|
||||
|
||||
return strconv.Itoa(jobID), nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import (
|
|||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/stashapp/stash/internal/static"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/plugin"
|
||||
"github.com/stashapp/stash/pkg/sliceutil/stringslice"
|
||||
|
|
@ -50,12 +51,6 @@ func (r *mutationResolver) MovieCreate(ctx context.Context, input MovieCreateInp
|
|||
return nil, fmt.Errorf("converting studio id: %w", err)
|
||||
}
|
||||
|
||||
// HACK: if back image is being set, set the front image to the default.
|
||||
// This is because we can't have a null front image with a non-null back image.
|
||||
if input.FrontImage == nil && input.BackImage != nil {
|
||||
input.FrontImage = &models.DefaultMovieImage
|
||||
}
|
||||
|
||||
// Process the base 64 encoded image string
|
||||
var frontimageData []byte
|
||||
if input.FrontImage != nil {
|
||||
|
|
@ -74,6 +69,12 @@ func (r *mutationResolver) MovieCreate(ctx context.Context, input MovieCreateInp
|
|||
}
|
||||
}
|
||||
|
||||
// HACK: if back image is being set, set the front image to the default.
|
||||
// This is because we can't have a null front image with a non-null back image.
|
||||
if len(frontimageData) == 0 && len(backimageData) != 0 {
|
||||
frontimageData = static.ReadAll(static.DefaultMovieImage)
|
||||
}
|
||||
|
||||
// Start the transaction and save the movie
|
||||
if err := r.withTxn(ctx, func(ctx context.Context) error {
|
||||
qb := r.repository.Movie
|
||||
|
|
|
|||
|
|
@ -11,15 +11,6 @@ import (
|
|||
"github.com/stashapp/stash/pkg/scraper/stashbox"
|
||||
)
|
||||
|
||||
func (r *Resolver) stashboxRepository() stashbox.Repository {
|
||||
return stashbox.Repository{
|
||||
Scene: r.repository.Scene,
|
||||
Performer: r.repository.Performer,
|
||||
Tag: r.repository.Tag,
|
||||
Studio: r.repository.Studio,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *mutationResolver) SubmitStashBoxFingerprints(ctx context.Context, input StashBoxFingerprintSubmissionInput) (bool, error) {
|
||||
boxes := config.GetInstance().GetStashBoxes()
|
||||
|
||||
|
|
@ -27,7 +18,7 @@ func (r *mutationResolver) SubmitStashBoxFingerprints(ctx context.Context, input
|
|||
return false, fmt.Errorf("invalid stash_box_index %d", input.StashBoxIndex)
|
||||
}
|
||||
|
||||
client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager, r.stashboxRepository())
|
||||
client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.stashboxRepository())
|
||||
|
||||
return client.SubmitStashBoxFingerprints(ctx, input.SceneIds, boxes[input.StashBoxIndex].Endpoint)
|
||||
}
|
||||
|
|
@ -49,7 +40,7 @@ func (r *mutationResolver) SubmitStashBoxSceneDraft(ctx context.Context, input S
|
|||
return nil, fmt.Errorf("invalid stash_box_index %d", input.StashBoxIndex)
|
||||
}
|
||||
|
||||
client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager, r.stashboxRepository())
|
||||
client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.stashboxRepository())
|
||||
|
||||
id, err := strconv.Atoi(input.ID)
|
||||
if err != nil {
|
||||
|
|
@ -91,7 +82,7 @@ func (r *mutationResolver) SubmitStashBoxPerformerDraft(ctx context.Context, inp
|
|||
return nil, fmt.Errorf("invalid stash_box_index %d", input.StashBoxIndex)
|
||||
}
|
||||
|
||||
client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager, r.stashboxRepository())
|
||||
client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.stashboxRepository())
|
||||
|
||||
id, err := strconv.Atoi(input.ID)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ import (
|
|||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stashapp/stash/internal/manager"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/models/mocks"
|
||||
"github.com/stashapp/stash/pkg/plugin"
|
||||
|
|
@ -15,14 +14,9 @@ import (
|
|||
)
|
||||
|
||||
// TODO - move this into a common area
|
||||
func newResolver() *Resolver {
|
||||
txnMgr := &mocks.TxnManager{}
|
||||
func newResolver(db *mocks.Database) *Resolver {
|
||||
return &Resolver{
|
||||
txnManager: txnMgr,
|
||||
repository: manager.Repository{
|
||||
TxnManager: txnMgr,
|
||||
Tag: &mocks.TagReaderWriter{},
|
||||
},
|
||||
repository: db.Repository(),
|
||||
hookExecutor: &mockHookExecutor{},
|
||||
}
|
||||
}
|
||||
|
|
@ -45,9 +39,8 @@ func (*mockHookExecutor) ExecutePostHooks(ctx context.Context, id int, hookType
|
|||
}
|
||||
|
||||
func TestTagCreate(t *testing.T) {
|
||||
r := newResolver()
|
||||
|
||||
tagRW := r.repository.Tag.(*mocks.TagReaderWriter)
|
||||
db := mocks.NewDatabase()
|
||||
r := newResolver(db)
|
||||
|
||||
pp := 1
|
||||
findFilter := &models.FindFilterType{
|
||||
|
|
@ -72,17 +65,17 @@ func TestTagCreate(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
tagRW.On("Query", mock.Anything, tagFilterForName(existingTagName), findFilter).Return([]*models.Tag{
|
||||
db.Tag.On("Query", mock.Anything, tagFilterForName(existingTagName), findFilter).Return([]*models.Tag{
|
||||
{
|
||||
ID: existingTagID,
|
||||
Name: existingTagName,
|
||||
},
|
||||
}, 1, nil).Once()
|
||||
tagRW.On("Query", mock.Anything, tagFilterForName(errTagName), findFilter).Return(nil, 0, nil).Once()
|
||||
tagRW.On("Query", mock.Anything, tagFilterForAlias(errTagName), findFilter).Return(nil, 0, nil).Once()
|
||||
db.Tag.On("Query", mock.Anything, tagFilterForName(errTagName), findFilter).Return(nil, 0, nil).Once()
|
||||
db.Tag.On("Query", mock.Anything, tagFilterForAlias(errTagName), findFilter).Return(nil, 0, nil).Once()
|
||||
|
||||
expectedErr := errors.New("TagCreate error")
|
||||
tagRW.On("Create", mock.Anything, mock.AnythingOfType("*models.Tag")).Return(expectedErr)
|
||||
db.Tag.On("Create", mock.Anything, mock.AnythingOfType("*models.Tag")).Return(expectedErr)
|
||||
|
||||
// fails here because testCtx is empty
|
||||
// TODO: Fix this
|
||||
|
|
@ -101,22 +94,22 @@ func TestTagCreate(t *testing.T) {
|
|||
})
|
||||
|
||||
assert.Equal(t, expectedErr, err)
|
||||
tagRW.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
|
||||
r = newResolver()
|
||||
tagRW = r.repository.Tag.(*mocks.TagReaderWriter)
|
||||
db = mocks.NewDatabase()
|
||||
r = newResolver(db)
|
||||
|
||||
tagRW.On("Query", mock.Anything, tagFilterForName(tagName), findFilter).Return(nil, 0, nil).Once()
|
||||
tagRW.On("Query", mock.Anything, tagFilterForAlias(tagName), findFilter).Return(nil, 0, nil).Once()
|
||||
db.Tag.On("Query", mock.Anything, tagFilterForName(tagName), findFilter).Return(nil, 0, nil).Once()
|
||||
db.Tag.On("Query", mock.Anything, tagFilterForAlias(tagName), findFilter).Return(nil, 0, nil).Once()
|
||||
newTag := &models.Tag{
|
||||
ID: newTagID,
|
||||
Name: tagName,
|
||||
}
|
||||
tagRW.On("Create", mock.Anything, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
|
||||
db.Tag.On("Create", mock.Anything, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
|
||||
arg := args.Get(1).(*models.Tag)
|
||||
arg.ID = newTagID
|
||||
}).Return(nil)
|
||||
tagRW.On("Find", mock.Anything, newTagID).Return(newTag, nil)
|
||||
db.Tag.On("Find", mock.Anything, newTagID).Return(newTag, nil)
|
||||
|
||||
tag, err := r.Mutation().TagCreate(testCtx, TagCreateInput{
|
||||
Name: tagName,
|
||||
|
|
@ -124,4 +117,5 @@ func TestTagCreate(t *testing.T) {
|
|||
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, tag)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -243,7 +243,9 @@ func makeConfigUIResult() map[string]interface{} {
|
|||
}
|
||||
|
||||
func (r *queryResolver) ValidateStashBoxCredentials(ctx context.Context, input config.StashBoxInput) (*StashBoxValidationResult, error) {
|
||||
client := stashbox.NewClient(models.StashBox{Endpoint: input.Endpoint, APIKey: input.APIKey}, r.txnManager, r.stashboxRepository())
|
||||
box := models.StashBox{Endpoint: input.Endpoint, APIKey: input.APIKey}
|
||||
client := stashbox.NewClient(box, r.stashboxRepository())
|
||||
|
||||
user, err := client.GetUser(ctx)
|
||||
|
||||
valid := user != nil && user.Me != nil
|
||||
|
|
|
|||
|
|
@ -191,16 +191,11 @@ func (r *queryResolver) FindScenesByPathRegex(ctx context.Context, filter *model
|
|||
}
|
||||
|
||||
func (r *queryResolver) ParseSceneFilenames(ctx context.Context, filter *models.FindFilterType, config models.SceneParserInput) (ret *SceneParserResultType, err error) {
|
||||
parser := scene.NewFilenameParser(filter, config)
|
||||
repo := scene.NewFilenameParserRepository(r.repository)
|
||||
parser := scene.NewFilenameParser(filter, config, repo)
|
||||
|
||||
if err := r.withReadTxn(ctx, func(ctx context.Context) error {
|
||||
result, count, err := parser.Parse(ctx, scene.FilenameParserRepository{
|
||||
Scene: r.repository.Scene,
|
||||
Performer: r.repository.Performer,
|
||||
Studio: r.repository.Studio,
|
||||
Movie: r.repository.Movie,
|
||||
Tag: r.repository.Tag,
|
||||
})
|
||||
result, count, err := parser.Parse(ctx)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
|||
|
|
@ -238,7 +238,7 @@ func (r *queryResolver) getStashBoxClient(index int) (*stashbox.Client, error) {
|
|||
return nil, fmt.Errorf("%w: invalid stash_box_index %d", ErrInput, index)
|
||||
}
|
||||
|
||||
return stashbox.NewClient(*boxes[index], r.txnManager, r.stashboxRepository()), nil
|
||||
return stashbox.NewClient(*boxes[index], r.stashboxRepository()), nil
|
||||
}
|
||||
|
||||
func (r *queryResolver) ScrapeSingleScene(ctx context.Context, source scraper.Source, input ScrapeSingleSceneInput) ([]*scraper.ScrapedScene, error) {
|
||||
|
|
|
|||
15
internal/api/routes.go
Normal file
15
internal/api/routes.go
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/stashapp/stash/pkg/txn"
|
||||
)
|
||||
|
||||
type routes struct {
|
||||
txnManager txn.Manager
|
||||
}
|
||||
|
||||
func (rs routes) withReadTxn(r *http.Request, fn txn.TxnFunc) error {
|
||||
return txn.WithReadTxn(r.Context(), rs.txnManager, fn)
|
||||
}
|
||||
|
|
@ -12,6 +12,10 @@ type customRoutes struct {
|
|||
servedFolders config.URLMap
|
||||
}
|
||||
|
||||
func getCustomRoutes(servedFolders config.URLMap) chi.Router {
|
||||
return customRoutes{servedFolders: servedFolders}.Routes()
|
||||
}
|
||||
|
||||
func (rs customRoutes) Routes() chi.Router {
|
||||
r := chi.NewRouter()
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,10 @@ import (
|
|||
|
||||
type downloadsRoutes struct{}
|
||||
|
||||
func getDownloadsRoutes() chi.Router {
|
||||
return downloadsRoutes{}.Routes()
|
||||
}
|
||||
|
||||
func (rs downloadsRoutes) Routes() chi.Router {
|
||||
r := chi.NewRouter()
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ package api
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"os/exec"
|
||||
|
|
@ -17,7 +16,6 @@ import (
|
|||
"github.com/stashapp/stash/pkg/image"
|
||||
"github.com/stashapp/stash/pkg/logger"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/txn"
|
||||
"github.com/stashapp/stash/pkg/utils"
|
||||
)
|
||||
|
||||
|
|
@ -27,11 +25,19 @@ type ImageFinder interface {
|
|||
}
|
||||
|
||||
type imageRoutes struct {
|
||||
txnManager txn.Manager
|
||||
routes
|
||||
imageFinder ImageFinder
|
||||
fileGetter models.FileGetter
|
||||
}
|
||||
|
||||
func getImageRoutes(repo models.Repository) chi.Router {
|
||||
return imageRoutes{
|
||||
routes: routes{txnManager: repo.TxnManager},
|
||||
imageFinder: repo.Image,
|
||||
fileGetter: repo.File,
|
||||
}.Routes()
|
||||
}
|
||||
|
||||
func (rs imageRoutes) Routes() chi.Router {
|
||||
r := chi.NewRouter()
|
||||
|
||||
|
|
@ -46,8 +52,6 @@ func (rs imageRoutes) Routes() chi.Router {
|
|||
return r
|
||||
}
|
||||
|
||||
// region Handlers
|
||||
|
||||
func (rs imageRoutes) Thumbnail(w http.ResponseWriter, r *http.Request) {
|
||||
img := r.Context().Value(imageKey).(*models.Image)
|
||||
filepath := manager.GetInstance().Paths.Generated.GetThumbnailPath(img.Checksum, models.DefaultGthumbWidth)
|
||||
|
|
@ -119,8 +123,6 @@ func (rs imageRoutes) Image(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func (rs imageRoutes) serveImage(w http.ResponseWriter, r *http.Request, i *models.Image, useDefault bool) {
|
||||
const defaultImageImage = "image/image.svg"
|
||||
|
||||
if i.Files.Primary() != nil {
|
||||
err := i.Files.Primary().Base().Serve(&file.OsFS{}, w, r)
|
||||
if err == nil {
|
||||
|
|
@ -141,22 +143,18 @@ func (rs imageRoutes) serveImage(w http.ResponseWriter, r *http.Request, i *mode
|
|||
return
|
||||
}
|
||||
|
||||
// fall back to static image
|
||||
f, _ := static.Image.Open(defaultImageImage)
|
||||
defer f.Close()
|
||||
image, _ := io.ReadAll(f)
|
||||
// fallback to default image
|
||||
image := static.ReadAll(static.DefaultImageImage)
|
||||
utils.ServeImage(w, r, image)
|
||||
}
|
||||
|
||||
// endregion
|
||||
|
||||
func (rs imageRoutes) ImageCtx(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
imageIdentifierQueryParam := chi.URLParam(r, "imageId")
|
||||
imageID, _ := strconv.Atoi(imageIdentifierQueryParam)
|
||||
|
||||
var image *models.Image
|
||||
_ = txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
|
||||
_ = rs.withReadTxn(r, func(ctx context.Context) error {
|
||||
qb := rs.imageFinder
|
||||
if imageID == 0 {
|
||||
images, _ := qb.FindByChecksum(ctx, imageIdentifierQueryParam)
|
||||
|
|
|
|||
|
|
@ -7,9 +7,10 @@ import (
|
|||
"strconv"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
|
||||
"github.com/stashapp/stash/internal/static"
|
||||
"github.com/stashapp/stash/pkg/logger"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/txn"
|
||||
"github.com/stashapp/stash/pkg/utils"
|
||||
)
|
||||
|
||||
|
|
@ -20,10 +21,17 @@ type MovieFinder interface {
|
|||
}
|
||||
|
||||
type movieRoutes struct {
|
||||
txnManager txn.Manager
|
||||
routes
|
||||
movieFinder MovieFinder
|
||||
}
|
||||
|
||||
func getMovieRoutes(repo models.Repository) chi.Router {
|
||||
return movieRoutes{
|
||||
routes: routes{txnManager: repo.TxnManager},
|
||||
movieFinder: repo.Movie,
|
||||
}.Routes()
|
||||
}
|
||||
|
||||
func (rs movieRoutes) Routes() chi.Router {
|
||||
r := chi.NewRouter()
|
||||
|
||||
|
|
@ -41,7 +49,7 @@ func (rs movieRoutes) FrontImage(w http.ResponseWriter, r *http.Request) {
|
|||
defaultParam := r.URL.Query().Get("default")
|
||||
var image []byte
|
||||
if defaultParam != "true" {
|
||||
readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
|
||||
readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error {
|
||||
var err error
|
||||
image, err = rs.movieFinder.GetFrontImage(ctx, movie.ID)
|
||||
return err
|
||||
|
|
@ -54,8 +62,9 @@ func (rs movieRoutes) FrontImage(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
// fallback to default image
|
||||
if len(image) == 0 {
|
||||
image, _ = utils.ProcessBase64Image(models.DefaultMovieImage)
|
||||
image = static.ReadAll(static.DefaultMovieImage)
|
||||
}
|
||||
|
||||
utils.ServeImage(w, r, image)
|
||||
|
|
@ -66,7 +75,7 @@ func (rs movieRoutes) BackImage(w http.ResponseWriter, r *http.Request) {
|
|||
defaultParam := r.URL.Query().Get("default")
|
||||
var image []byte
|
||||
if defaultParam != "true" {
|
||||
readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
|
||||
readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error {
|
||||
var err error
|
||||
image, err = rs.movieFinder.GetBackImage(ctx, movie.ID)
|
||||
return err
|
||||
|
|
@ -79,8 +88,9 @@ func (rs movieRoutes) BackImage(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
// fallback to default image
|
||||
if len(image) == 0 {
|
||||
image, _ = utils.ProcessBase64Image(models.DefaultMovieImage)
|
||||
image = static.ReadAll(static.DefaultMovieImage)
|
||||
}
|
||||
|
||||
utils.ServeImage(w, r, image)
|
||||
|
|
@ -95,7 +105,7 @@ func (rs movieRoutes) MovieCtx(next http.Handler) http.Handler {
|
|||
}
|
||||
|
||||
var movie *models.Movie
|
||||
_ = txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
|
||||
_ = rs.withReadTxn(r, func(ctx context.Context) error {
|
||||
movie, _ = rs.movieFinder.Find(ctx, movieID)
|
||||
return nil
|
||||
})
|
||||
|
|
|
|||
|
|
@ -7,10 +7,8 @@ import (
|
|||
"strconv"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/stashapp/stash/internal/manager/config"
|
||||
"github.com/stashapp/stash/pkg/logger"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/txn"
|
||||
"github.com/stashapp/stash/pkg/utils"
|
||||
)
|
||||
|
||||
|
|
@ -20,10 +18,17 @@ type PerformerFinder interface {
|
|||
}
|
||||
|
||||
type performerRoutes struct {
|
||||
txnManager txn.Manager
|
||||
routes
|
||||
performerFinder PerformerFinder
|
||||
}
|
||||
|
||||
func getPerformerRoutes(repo models.Repository) chi.Router {
|
||||
return performerRoutes{
|
||||
routes: routes{txnManager: repo.TxnManager},
|
||||
performerFinder: repo.Performer,
|
||||
}.Routes()
|
||||
}
|
||||
|
||||
func (rs performerRoutes) Routes() chi.Router {
|
||||
r := chi.NewRouter()
|
||||
|
||||
|
|
@ -41,7 +46,7 @@ func (rs performerRoutes) Image(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
var image []byte
|
||||
if defaultParam != "true" {
|
||||
readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
|
||||
readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error {
|
||||
var err error
|
||||
image, err = rs.performerFinder.GetImage(ctx, performer.ID)
|
||||
return err
|
||||
|
|
@ -55,7 +60,7 @@ func (rs performerRoutes) Image(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
if len(image) == 0 {
|
||||
image, _ = getRandomPerformerImageUsingName(performer.Name, performer.Gender, config.GetInstance().GetCustomPerformerImageLocation())
|
||||
image = getDefaultPerformerImage(performer.Name, performer.Gender)
|
||||
}
|
||||
|
||||
utils.ServeImage(w, r, image)
|
||||
|
|
@ -70,7 +75,7 @@ func (rs performerRoutes) PerformerCtx(next http.Handler) http.Handler {
|
|||
}
|
||||
|
||||
var performer *models.Performer
|
||||
_ = txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
|
||||
_ = rs.withReadTxn(r, func(ctx context.Context) error {
|
||||
var err error
|
||||
performer, err = rs.performerFinder.Find(ctx, performerID)
|
||||
return err
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ import (
|
|||
"github.com/stashapp/stash/pkg/fsutil"
|
||||
"github.com/stashapp/stash/pkg/logger"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/txn"
|
||||
"github.com/stashapp/stash/pkg/utils"
|
||||
)
|
||||
|
||||
|
|
@ -43,7 +42,7 @@ type CaptionFinder interface {
|
|||
}
|
||||
|
||||
type sceneRoutes struct {
|
||||
txnManager txn.Manager
|
||||
routes
|
||||
sceneFinder SceneFinder
|
||||
fileGetter models.FileGetter
|
||||
captionFinder CaptionFinder
|
||||
|
|
@ -51,6 +50,17 @@ type sceneRoutes struct {
|
|||
tagFinder SceneMarkerTagFinder
|
||||
}
|
||||
|
||||
func getSceneRoutes(repo models.Repository) chi.Router {
|
||||
return sceneRoutes{
|
||||
routes: routes{txnManager: repo.TxnManager},
|
||||
sceneFinder: repo.Scene,
|
||||
fileGetter: repo.File,
|
||||
captionFinder: repo.File,
|
||||
sceneMarkerFinder: repo.SceneMarker,
|
||||
tagFinder: repo.Tag,
|
||||
}.Routes()
|
||||
}
|
||||
|
||||
func (rs sceneRoutes) Routes() chi.Router {
|
||||
r := chi.NewRouter()
|
||||
|
||||
|
|
@ -89,8 +99,6 @@ func (rs sceneRoutes) Routes() chi.Router {
|
|||
return r
|
||||
}
|
||||
|
||||
// region Handlers
|
||||
|
||||
func (rs sceneRoutes) StreamDirect(w http.ResponseWriter, r *http.Request) {
|
||||
scene := r.Context().Value(sceneKey).(*models.Scene)
|
||||
ss := manager.SceneServer{
|
||||
|
|
@ -270,13 +278,13 @@ func (rs sceneRoutes) Webp(w http.ResponseWriter, r *http.Request) {
|
|||
utils.ServeStaticFile(w, r, filepath)
|
||||
}
|
||||
|
||||
func (rs sceneRoutes) getChapterVttTitle(ctx context.Context, marker *models.SceneMarker) (*string, error) {
|
||||
func (rs sceneRoutes) getChapterVttTitle(r *http.Request, marker *models.SceneMarker) (*string, error) {
|
||||
if marker.Title != "" {
|
||||
return &marker.Title, nil
|
||||
}
|
||||
|
||||
var title string
|
||||
if err := txn.WithReadTxn(ctx, rs.txnManager, func(ctx context.Context) error {
|
||||
if err := rs.withReadTxn(r, func(ctx context.Context) error {
|
||||
qb := rs.tagFinder
|
||||
primaryTag, err := qb.Find(ctx, marker.PrimaryTagID)
|
||||
if err != nil {
|
||||
|
|
@ -305,7 +313,7 @@ func (rs sceneRoutes) getChapterVttTitle(ctx context.Context, marker *models.Sce
|
|||
func (rs sceneRoutes) VttChapter(w http.ResponseWriter, r *http.Request) {
|
||||
scene := r.Context().Value(sceneKey).(*models.Scene)
|
||||
var sceneMarkers []*models.SceneMarker
|
||||
readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
|
||||
readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error {
|
||||
var err error
|
||||
sceneMarkers, err = rs.sceneMarkerFinder.FindBySceneID(ctx, scene.ID)
|
||||
return err
|
||||
|
|
@ -325,7 +333,7 @@ func (rs sceneRoutes) VttChapter(w http.ResponseWriter, r *http.Request) {
|
|||
time := utils.GetVTTTime(marker.Seconds)
|
||||
vttLines = append(vttLines, time+" --> "+time)
|
||||
|
||||
vttTitle, err := rs.getChapterVttTitle(r.Context(), marker)
|
||||
vttTitle, err := rs.getChapterVttTitle(r, marker)
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return
|
||||
}
|
||||
|
|
@ -404,7 +412,7 @@ func (rs sceneRoutes) Caption(w http.ResponseWriter, r *http.Request, lang strin
|
|||
s := r.Context().Value(sceneKey).(*models.Scene)
|
||||
|
||||
var captions []*models.VideoCaption
|
||||
readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
|
||||
readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error {
|
||||
var err error
|
||||
primaryFile := s.Files.Primary()
|
||||
if primaryFile == nil {
|
||||
|
|
@ -466,7 +474,7 @@ func (rs sceneRoutes) SceneMarkerStream(w http.ResponseWriter, r *http.Request)
|
|||
sceneHash := scene.GetHash(config.GetInstance().GetVideoFileNamingAlgorithm())
|
||||
sceneMarkerID, _ := strconv.Atoi(chi.URLParam(r, "sceneMarkerId"))
|
||||
var sceneMarker *models.SceneMarker
|
||||
readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
|
||||
readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error {
|
||||
var err error
|
||||
sceneMarker, err = rs.sceneMarkerFinder.Find(ctx, sceneMarkerID)
|
||||
return err
|
||||
|
|
@ -494,7 +502,7 @@ func (rs sceneRoutes) SceneMarkerPreview(w http.ResponseWriter, r *http.Request)
|
|||
sceneHash := scene.GetHash(config.GetInstance().GetVideoFileNamingAlgorithm())
|
||||
sceneMarkerID, _ := strconv.Atoi(chi.URLParam(r, "sceneMarkerId"))
|
||||
var sceneMarker *models.SceneMarker
|
||||
readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
|
||||
readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error {
|
||||
var err error
|
||||
sceneMarker, err = rs.sceneMarkerFinder.Find(ctx, sceneMarkerID)
|
||||
return err
|
||||
|
|
@ -530,7 +538,7 @@ func (rs sceneRoutes) SceneMarkerScreenshot(w http.ResponseWriter, r *http.Reque
|
|||
sceneHash := scene.GetHash(config.GetInstance().GetVideoFileNamingAlgorithm())
|
||||
sceneMarkerID, _ := strconv.Atoi(chi.URLParam(r, "sceneMarkerId"))
|
||||
var sceneMarker *models.SceneMarker
|
||||
readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
|
||||
readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error {
|
||||
var err error
|
||||
sceneMarker, err = rs.sceneMarkerFinder.Find(ctx, sceneMarkerID)
|
||||
return err
|
||||
|
|
@ -561,8 +569,6 @@ func (rs sceneRoutes) SceneMarkerScreenshot(w http.ResponseWriter, r *http.Reque
|
|||
}
|
||||
}
|
||||
|
||||
// endregion
|
||||
|
||||
func (rs sceneRoutes) SceneCtx(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
sceneID, err := strconv.Atoi(chi.URLParam(r, "sceneId"))
|
||||
|
|
@ -572,7 +578,7 @@ func (rs sceneRoutes) SceneCtx(next http.Handler) http.Handler {
|
|||
}
|
||||
|
||||
var scene *models.Scene
|
||||
_ = txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
|
||||
_ = rs.withReadTxn(r, func(ctx context.Context) error {
|
||||
qb := rs.sceneFinder
|
||||
scene, _ = qb.Find(ctx, sceneID)
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ package api
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
|
|
@ -11,7 +10,6 @@ import (
|
|||
"github.com/stashapp/stash/internal/static"
|
||||
"github.com/stashapp/stash/pkg/logger"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/txn"
|
||||
"github.com/stashapp/stash/pkg/utils"
|
||||
)
|
||||
|
||||
|
|
@ -21,10 +19,17 @@ type StudioFinder interface {
|
|||
}
|
||||
|
||||
type studioRoutes struct {
|
||||
txnManager txn.Manager
|
||||
routes
|
||||
studioFinder StudioFinder
|
||||
}
|
||||
|
||||
func getStudioRoutes(repo models.Repository) chi.Router {
|
||||
return studioRoutes{
|
||||
routes: routes{txnManager: repo.TxnManager},
|
||||
studioFinder: repo.Studio,
|
||||
}.Routes()
|
||||
}
|
||||
|
||||
func (rs studioRoutes) Routes() chi.Router {
|
||||
r := chi.NewRouter()
|
||||
|
||||
|
|
@ -42,7 +47,7 @@ func (rs studioRoutes) Image(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
var image []byte
|
||||
if defaultParam != "true" {
|
||||
readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
|
||||
readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error {
|
||||
var err error
|
||||
image, err = rs.studioFinder.GetImage(ctx, studio.ID)
|
||||
return err
|
||||
|
|
@ -55,15 +60,9 @@ func (rs studioRoutes) Image(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
// fallback to default image
|
||||
if len(image) == 0 {
|
||||
const defaultStudioImage = "studio/studio.svg"
|
||||
|
||||
// fall back to static image
|
||||
f, _ := static.Studio.Open(defaultStudioImage)
|
||||
defer f.Close()
|
||||
stat, _ := f.Stat()
|
||||
http.ServeContent(w, r, "studio.svg", stat.ModTime(), f.(io.ReadSeeker))
|
||||
return
|
||||
image = static.ReadAll(static.DefaultStudioImage)
|
||||
}
|
||||
|
||||
utils.ServeImage(w, r, image)
|
||||
|
|
@ -78,7 +77,7 @@ func (rs studioRoutes) StudioCtx(next http.Handler) http.Handler {
|
|||
}
|
||||
|
||||
var studio *models.Studio
|
||||
_ = txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
|
||||
_ = rs.withReadTxn(r, func(ctx context.Context) error {
|
||||
var err error
|
||||
studio, err = rs.studioFinder.Find(ctx, studioID)
|
||||
return err
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ package api
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
|
|
@ -11,7 +10,6 @@ import (
|
|||
"github.com/stashapp/stash/internal/static"
|
||||
"github.com/stashapp/stash/pkg/logger"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/txn"
|
||||
"github.com/stashapp/stash/pkg/utils"
|
||||
)
|
||||
|
||||
|
|
@ -21,8 +19,15 @@ type TagFinder interface {
|
|||
}
|
||||
|
||||
type tagRoutes struct {
|
||||
txnManager txn.Manager
|
||||
tagFinder TagFinder
|
||||
routes
|
||||
tagFinder TagFinder
|
||||
}
|
||||
|
||||
func getTagRoutes(repo models.Repository) chi.Router {
|
||||
return tagRoutes{
|
||||
routes: routes{txnManager: repo.TxnManager},
|
||||
tagFinder: repo.Tag,
|
||||
}.Routes()
|
||||
}
|
||||
|
||||
func (rs tagRoutes) Routes() chi.Router {
|
||||
|
|
@ -42,7 +47,7 @@ func (rs tagRoutes) Image(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
var image []byte
|
||||
if defaultParam != "true" {
|
||||
readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
|
||||
readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error {
|
||||
var err error
|
||||
image, err = rs.tagFinder.GetImage(ctx, tag.ID)
|
||||
return err
|
||||
|
|
@ -55,15 +60,9 @@ func (rs tagRoutes) Image(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
// fallback to default image
|
||||
if len(image) == 0 {
|
||||
const defaultTagImage = "tag/tag.svg"
|
||||
|
||||
// fall back to static image
|
||||
f, _ := static.Tag.Open(defaultTagImage)
|
||||
defer f.Close()
|
||||
stat, _ := f.Stat()
|
||||
http.ServeContent(w, r, "tag.svg", stat.ModTime(), f.(io.ReadSeeker))
|
||||
return
|
||||
image = static.ReadAll(static.DefaultTagImage)
|
||||
}
|
||||
|
||||
utils.ServeImage(w, r, image)
|
||||
|
|
@ -78,7 +77,7 @@ func (rs tagRoutes) TagCtx(next http.Handler) http.Handler {
|
|||
}
|
||||
|
||||
var tag *models.Tag
|
||||
_ = txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
|
||||
_ = rs.withReadTxn(r, func(ctx context.Context) error {
|
||||
var err error
|
||||
tag, err = rs.tagFinder.Find(ctx, tagID)
|
||||
return err
|
||||
|
|
|
|||
|
|
@ -50,7 +50,9 @@ var uiBox = ui.UIBox
|
|||
var loginUIBox = ui.LoginUIBox
|
||||
|
||||
func Start() error {
|
||||
initialiseImages()
|
||||
c := config.GetInstance()
|
||||
|
||||
initCustomPerformerImages(c.GetCustomPerformerImageLocation())
|
||||
|
||||
r := chi.NewRouter()
|
||||
|
||||
|
|
@ -62,7 +64,6 @@ func Start() error {
|
|||
|
||||
r.Use(middleware.Recoverer)
|
||||
|
||||
c := config.GetInstance()
|
||||
if c.GetLogAccess() {
|
||||
httpLogger := httplog.NewLogger("Stash", httplog.Options{
|
||||
Concise: true,
|
||||
|
|
@ -82,11 +83,10 @@ func Start() error {
|
|||
return errors.New(message)
|
||||
}
|
||||
|
||||
txnManager := manager.GetInstance().Repository
|
||||
repo := manager.GetInstance().Repository
|
||||
|
||||
dataloaders := loaders.Middleware{
|
||||
DatabaseProvider: txnManager,
|
||||
Repository: txnManager,
|
||||
Repository: repo,
|
||||
}
|
||||
|
||||
r.Use(dataloaders.Middleware)
|
||||
|
|
@ -96,8 +96,7 @@ func Start() error {
|
|||
imageService := manager.GetInstance().ImageService
|
||||
galleryService := manager.GetInstance().GalleryService
|
||||
resolver := &Resolver{
|
||||
txnManager: txnManager,
|
||||
repository: txnManager,
|
||||
repository: repo,
|
||||
sceneService: sceneService,
|
||||
imageService: imageService,
|
||||
galleryService: galleryService,
|
||||
|
|
@ -144,36 +143,13 @@ func Start() error {
|
|||
gqlPlayground.Handler("GraphQL playground", endpoint)(w, r)
|
||||
})
|
||||
|
||||
r.Mount("/performer", performerRoutes{
|
||||
txnManager: txnManager,
|
||||
performerFinder: txnManager.Performer,
|
||||
}.Routes())
|
||||
r.Mount("/scene", sceneRoutes{
|
||||
txnManager: txnManager,
|
||||
sceneFinder: txnManager.Scene,
|
||||
fileGetter: txnManager.File,
|
||||
captionFinder: txnManager.File,
|
||||
sceneMarkerFinder: txnManager.SceneMarker,
|
||||
tagFinder: txnManager.Tag,
|
||||
}.Routes())
|
||||
r.Mount("/image", imageRoutes{
|
||||
txnManager: txnManager,
|
||||
imageFinder: txnManager.Image,
|
||||
fileGetter: txnManager.File,
|
||||
}.Routes())
|
||||
r.Mount("/studio", studioRoutes{
|
||||
txnManager: txnManager,
|
||||
studioFinder: txnManager.Studio,
|
||||
}.Routes())
|
||||
r.Mount("/movie", movieRoutes{
|
||||
txnManager: txnManager,
|
||||
movieFinder: txnManager.Movie,
|
||||
}.Routes())
|
||||
r.Mount("/tag", tagRoutes{
|
||||
txnManager: txnManager,
|
||||
tagFinder: txnManager.Tag,
|
||||
}.Routes())
|
||||
r.Mount("/downloads", downloadsRoutes{}.Routes())
|
||||
r.Mount("/performer", getPerformerRoutes(repo))
|
||||
r.Mount("/scene", getSceneRoutes(repo))
|
||||
r.Mount("/image", getImageRoutes(repo))
|
||||
r.Mount("/studio", getStudioRoutes(repo))
|
||||
r.Mount("/movie", getMovieRoutes(repo))
|
||||
r.Mount("/tag", getTagRoutes(repo))
|
||||
r.Mount("/downloads", getDownloadsRoutes())
|
||||
|
||||
r.HandleFunc("/css", cssHandler(c, pluginCache))
|
||||
r.HandleFunc("/javascript", javascriptHandler(c, pluginCache))
|
||||
|
|
@ -193,9 +169,7 @@ func Start() error {
|
|||
// Serve static folders
|
||||
customServedFolders := c.GetCustomServedFolders()
|
||||
if customServedFolders != nil {
|
||||
r.Mount("/custom", customRoutes{
|
||||
servedFolders: customServedFolders,
|
||||
}.Routes())
|
||||
r.Mount("/custom", getCustomRoutes(customServedFolders))
|
||||
}
|
||||
|
||||
customUILocation := c.GetCustomUILocation()
|
||||
|
|
|
|||
|
|
@ -52,11 +52,10 @@ func TestGalleryPerformers(t *testing.T) {
|
|||
assert := assert.New(t)
|
||||
|
||||
for _, test := range testTables {
|
||||
mockPerformerReader := &mocks.PerformerReaderWriter{}
|
||||
mockGalleryReader := &mocks.GalleryReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
mockPerformerReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
|
||||
mockPerformerReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once()
|
||||
db.Performer.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
|
||||
db.Performer.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once()
|
||||
|
||||
if test.Matches {
|
||||
matchPartial := mock.MatchedBy(func(got models.GalleryPartial) bool {
|
||||
|
|
@ -69,7 +68,7 @@ func TestGalleryPerformers(t *testing.T) {
|
|||
|
||||
return galleryPartialsEqual(got, expected)
|
||||
})
|
||||
mockGalleryReader.On("UpdatePartial", testCtx, galleryID, matchPartial).Return(nil, nil).Once()
|
||||
db.Gallery.On("UpdatePartial", testCtx, galleryID, matchPartial).Return(nil, nil).Once()
|
||||
}
|
||||
|
||||
gallery := models.Gallery{
|
||||
|
|
@ -77,11 +76,10 @@ func TestGalleryPerformers(t *testing.T) {
|
|||
Path: test.Path,
|
||||
PerformerIDs: models.NewRelatedIDs([]int{}),
|
||||
}
|
||||
err := GalleryPerformers(testCtx, &gallery, mockGalleryReader, mockPerformerReader, nil)
|
||||
err := GalleryPerformers(testCtx, &gallery, db.Gallery, db.Performer, nil)
|
||||
|
||||
assert.Nil(err)
|
||||
mockPerformerReader.AssertExpectations(t)
|
||||
mockGalleryReader.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -107,7 +105,7 @@ func TestGalleryStudios(t *testing.T) {
|
|||
|
||||
assert := assert.New(t)
|
||||
|
||||
doTest := func(mockStudioReader *mocks.StudioReaderWriter, mockGalleryReader *mocks.GalleryReaderWriter, test pathTestTable) {
|
||||
doTest := func(db *mocks.Database, test pathTestTable) {
|
||||
if test.Matches {
|
||||
matchPartial := mock.MatchedBy(func(got models.GalleryPartial) bool {
|
||||
expected := models.GalleryPartial{
|
||||
|
|
@ -116,29 +114,27 @@ func TestGalleryStudios(t *testing.T) {
|
|||
|
||||
return galleryPartialsEqual(got, expected)
|
||||
})
|
||||
mockGalleryReader.On("UpdatePartial", testCtx, galleryID, matchPartial).Return(nil, nil).Once()
|
||||
db.Gallery.On("UpdatePartial", testCtx, galleryID, matchPartial).Return(nil, nil).Once()
|
||||
}
|
||||
|
||||
gallery := models.Gallery{
|
||||
ID: galleryID,
|
||||
Path: test.Path,
|
||||
}
|
||||
err := GalleryStudios(testCtx, &gallery, mockGalleryReader, mockStudioReader, nil)
|
||||
err := GalleryStudios(testCtx, &gallery, db.Gallery, db.Studio, nil)
|
||||
|
||||
assert.Nil(err)
|
||||
mockStudioReader.AssertExpectations(t)
|
||||
mockGalleryReader.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
for _, test := range testTables {
|
||||
mockStudioReader := &mocks.StudioReaderWriter{}
|
||||
mockGalleryReader := &mocks.GalleryReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
|
||||
mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
|
||||
mockStudioReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
|
||||
db.Studio.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
|
||||
db.Studio.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
|
||||
db.Studio.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
|
||||
|
||||
doTest(mockStudioReader, mockGalleryReader, test)
|
||||
doTest(db, test)
|
||||
}
|
||||
|
||||
// test against aliases
|
||||
|
|
@ -146,17 +142,16 @@ func TestGalleryStudios(t *testing.T) {
|
|||
studio.Name = unmatchedName
|
||||
|
||||
for _, test := range testTables {
|
||||
mockStudioReader := &mocks.StudioReaderWriter{}
|
||||
mockGalleryReader := &mocks.GalleryReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
|
||||
mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
|
||||
mockStudioReader.On("GetAliases", testCtx, studioID).Return([]string{
|
||||
db.Studio.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
|
||||
db.Studio.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
|
||||
db.Studio.On("GetAliases", testCtx, studioID).Return([]string{
|
||||
studioName,
|
||||
}, nil).Once()
|
||||
mockStudioReader.On("GetAliases", testCtx, reversedStudioID).Return([]string{}, nil).Once()
|
||||
db.Studio.On("GetAliases", testCtx, reversedStudioID).Return([]string{}, nil).Once()
|
||||
|
||||
doTest(mockStudioReader, mockGalleryReader, test)
|
||||
doTest(db, test)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -182,7 +177,7 @@ func TestGalleryTags(t *testing.T) {
|
|||
|
||||
assert := assert.New(t)
|
||||
|
||||
doTest := func(mockTagReader *mocks.TagReaderWriter, mockGalleryReader *mocks.GalleryReaderWriter, test pathTestTable) {
|
||||
doTest := func(db *mocks.Database, test pathTestTable) {
|
||||
if test.Matches {
|
||||
matchPartial := mock.MatchedBy(func(got models.GalleryPartial) bool {
|
||||
expected := models.GalleryPartial{
|
||||
|
|
@ -194,7 +189,7 @@ func TestGalleryTags(t *testing.T) {
|
|||
|
||||
return galleryPartialsEqual(got, expected)
|
||||
})
|
||||
mockGalleryReader.On("UpdatePartial", testCtx, galleryID, matchPartial).Return(nil, nil).Once()
|
||||
db.Gallery.On("UpdatePartial", testCtx, galleryID, matchPartial).Return(nil, nil).Once()
|
||||
}
|
||||
|
||||
gallery := models.Gallery{
|
||||
|
|
@ -202,38 +197,35 @@ func TestGalleryTags(t *testing.T) {
|
|||
Path: test.Path,
|
||||
TagIDs: models.NewRelatedIDs([]int{}),
|
||||
}
|
||||
err := GalleryTags(testCtx, &gallery, mockGalleryReader, mockTagReader, nil)
|
||||
err := GalleryTags(testCtx, &gallery, db.Gallery, db.Tag, nil)
|
||||
|
||||
assert.Nil(err)
|
||||
mockTagReader.AssertExpectations(t)
|
||||
mockGalleryReader.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
for _, test := range testTables {
|
||||
mockTagReader := &mocks.TagReaderWriter{}
|
||||
mockGalleryReader := &mocks.GalleryReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
|
||||
mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
|
||||
mockTagReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
|
||||
db.Tag.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
|
||||
db.Tag.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
|
||||
db.Tag.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
|
||||
|
||||
doTest(mockTagReader, mockGalleryReader, test)
|
||||
doTest(db, test)
|
||||
}
|
||||
|
||||
const unmatchedName = "unmatched"
|
||||
tag.Name = unmatchedName
|
||||
|
||||
for _, test := range testTables {
|
||||
mockTagReader := &mocks.TagReaderWriter{}
|
||||
mockGalleryReader := &mocks.GalleryReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
|
||||
mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
|
||||
mockTagReader.On("GetAliases", testCtx, tagID).Return([]string{
|
||||
db.Tag.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
|
||||
db.Tag.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
|
||||
db.Tag.On("GetAliases", testCtx, tagID).Return([]string{
|
||||
tagName,
|
||||
}, nil).Once()
|
||||
mockTagReader.On("GetAliases", testCtx, reversedTagID).Return([]string{}, nil).Once()
|
||||
db.Tag.On("GetAliases", testCtx, reversedTagID).Return([]string{}, nil).Once()
|
||||
|
||||
doTest(mockTagReader, mockGalleryReader, test)
|
||||
doTest(db, test)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -49,11 +49,10 @@ func TestImagePerformers(t *testing.T) {
|
|||
assert := assert.New(t)
|
||||
|
||||
for _, test := range testTables {
|
||||
mockPerformerReader := &mocks.PerformerReaderWriter{}
|
||||
mockImageReader := &mocks.ImageReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
mockPerformerReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
|
||||
mockPerformerReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once()
|
||||
db.Performer.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
|
||||
db.Performer.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once()
|
||||
|
||||
if test.Matches {
|
||||
matchPartial := mock.MatchedBy(func(got models.ImagePartial) bool {
|
||||
|
|
@ -66,7 +65,7 @@ func TestImagePerformers(t *testing.T) {
|
|||
|
||||
return imagePartialsEqual(got, expected)
|
||||
})
|
||||
mockImageReader.On("UpdatePartial", testCtx, imageID, matchPartial).Return(nil, nil).Once()
|
||||
db.Image.On("UpdatePartial", testCtx, imageID, matchPartial).Return(nil, nil).Once()
|
||||
}
|
||||
|
||||
image := models.Image{
|
||||
|
|
@ -74,11 +73,10 @@ func TestImagePerformers(t *testing.T) {
|
|||
Path: test.Path,
|
||||
PerformerIDs: models.NewRelatedIDs([]int{}),
|
||||
}
|
||||
err := ImagePerformers(testCtx, &image, mockImageReader, mockPerformerReader, nil)
|
||||
err := ImagePerformers(testCtx, &image, db.Image, db.Performer, nil)
|
||||
|
||||
assert.Nil(err)
|
||||
mockPerformerReader.AssertExpectations(t)
|
||||
mockImageReader.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -104,7 +102,7 @@ func TestImageStudios(t *testing.T) {
|
|||
|
||||
assert := assert.New(t)
|
||||
|
||||
doTest := func(mockStudioReader *mocks.StudioReaderWriter, mockImageReader *mocks.ImageReaderWriter, test pathTestTable) {
|
||||
doTest := func(db *mocks.Database, test pathTestTable) {
|
||||
if test.Matches {
|
||||
matchPartial := mock.MatchedBy(func(got models.ImagePartial) bool {
|
||||
expected := models.ImagePartial{
|
||||
|
|
@ -113,29 +111,27 @@ func TestImageStudios(t *testing.T) {
|
|||
|
||||
return imagePartialsEqual(got, expected)
|
||||
})
|
||||
mockImageReader.On("UpdatePartial", testCtx, imageID, matchPartial).Return(nil, nil).Once()
|
||||
db.Image.On("UpdatePartial", testCtx, imageID, matchPartial).Return(nil, nil).Once()
|
||||
}
|
||||
|
||||
image := models.Image{
|
||||
ID: imageID,
|
||||
Path: test.Path,
|
||||
}
|
||||
err := ImageStudios(testCtx, &image, mockImageReader, mockStudioReader, nil)
|
||||
err := ImageStudios(testCtx, &image, db.Image, db.Studio, nil)
|
||||
|
||||
assert.Nil(err)
|
||||
mockStudioReader.AssertExpectations(t)
|
||||
mockImageReader.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
for _, test := range testTables {
|
||||
mockStudioReader := &mocks.StudioReaderWriter{}
|
||||
mockImageReader := &mocks.ImageReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
|
||||
mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
|
||||
mockStudioReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
|
||||
db.Studio.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
|
||||
db.Studio.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
|
||||
db.Studio.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
|
||||
|
||||
doTest(mockStudioReader, mockImageReader, test)
|
||||
doTest(db, test)
|
||||
}
|
||||
|
||||
// test against aliases
|
||||
|
|
@ -143,17 +139,16 @@ func TestImageStudios(t *testing.T) {
|
|||
studio.Name = unmatchedName
|
||||
|
||||
for _, test := range testTables {
|
||||
mockStudioReader := &mocks.StudioReaderWriter{}
|
||||
mockImageReader := &mocks.ImageReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
|
||||
mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
|
||||
mockStudioReader.On("GetAliases", testCtx, studioID).Return([]string{
|
||||
db.Studio.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
|
||||
db.Studio.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
|
||||
db.Studio.On("GetAliases", testCtx, studioID).Return([]string{
|
||||
studioName,
|
||||
}, nil).Once()
|
||||
mockStudioReader.On("GetAliases", testCtx, reversedStudioID).Return([]string{}, nil).Once()
|
||||
db.Studio.On("GetAliases", testCtx, reversedStudioID).Return([]string{}, nil).Once()
|
||||
|
||||
doTest(mockStudioReader, mockImageReader, test)
|
||||
doTest(db, test)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -179,7 +174,7 @@ func TestImageTags(t *testing.T) {
|
|||
|
||||
assert := assert.New(t)
|
||||
|
||||
doTest := func(mockTagReader *mocks.TagReaderWriter, mockImageReader *mocks.ImageReaderWriter, test pathTestTable) {
|
||||
doTest := func(db *mocks.Database, test pathTestTable) {
|
||||
if test.Matches {
|
||||
matchPartial := mock.MatchedBy(func(got models.ImagePartial) bool {
|
||||
expected := models.ImagePartial{
|
||||
|
|
@ -191,7 +186,7 @@ func TestImageTags(t *testing.T) {
|
|||
|
||||
return imagePartialsEqual(got, expected)
|
||||
})
|
||||
mockImageReader.On("UpdatePartial", testCtx, imageID, matchPartial).Return(nil, nil).Once()
|
||||
db.Image.On("UpdatePartial", testCtx, imageID, matchPartial).Return(nil, nil).Once()
|
||||
}
|
||||
|
||||
image := models.Image{
|
||||
|
|
@ -199,22 +194,20 @@ func TestImageTags(t *testing.T) {
|
|||
Path: test.Path,
|
||||
TagIDs: models.NewRelatedIDs([]int{}),
|
||||
}
|
||||
err := ImageTags(testCtx, &image, mockImageReader, mockTagReader, nil)
|
||||
err := ImageTags(testCtx, &image, db.Image, db.Tag, nil)
|
||||
|
||||
assert.Nil(err)
|
||||
mockTagReader.AssertExpectations(t)
|
||||
mockImageReader.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
for _, test := range testTables {
|
||||
mockTagReader := &mocks.TagReaderWriter{}
|
||||
mockImageReader := &mocks.ImageReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
|
||||
mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
|
||||
mockTagReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
|
||||
db.Tag.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
|
||||
db.Tag.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
|
||||
db.Tag.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
|
||||
|
||||
doTest(mockTagReader, mockImageReader, test)
|
||||
doTest(db, test)
|
||||
}
|
||||
|
||||
// test against aliases
|
||||
|
|
@ -222,16 +215,15 @@ func TestImageTags(t *testing.T) {
|
|||
tag.Name = unmatchedName
|
||||
|
||||
for _, test := range testTables {
|
||||
mockTagReader := &mocks.TagReaderWriter{}
|
||||
mockImageReader := &mocks.ImageReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
|
||||
mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
|
||||
mockTagReader.On("GetAliases", testCtx, tagID).Return([]string{
|
||||
db.Tag.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
|
||||
db.Tag.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
|
||||
db.Tag.On("GetAliases", testCtx, tagID).Return([]string{
|
||||
tagName,
|
||||
}, nil).Once()
|
||||
mockTagReader.On("GetAliases", testCtx, reversedTagID).Return([]string{}, nil).Once()
|
||||
db.Tag.On("GetAliases", testCtx, reversedTagID).Return([]string{}, nil).Once()
|
||||
|
||||
doTest(mockTagReader, mockImageReader, test)
|
||||
doTest(db, test)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ func runTests(m *testing.M) int {
|
|||
panic(fmt.Sprintf("Could not initialize database: %s", err.Error()))
|
||||
}
|
||||
|
||||
r = db.TxnRepository()
|
||||
r = db.Repository()
|
||||
|
||||
// defer close and delete the database
|
||||
defer testTeardown(databaseFile)
|
||||
|
|
@ -474,11 +474,11 @@ func createGallery(ctx context.Context, w models.GalleryWriter, o *models.Galler
|
|||
}
|
||||
|
||||
func withTxn(f func(ctx context.Context) error) error {
|
||||
return txn.WithTxn(context.TODO(), db, f)
|
||||
return txn.WithTxn(testCtx, db, f)
|
||||
}
|
||||
|
||||
func withDB(f func(ctx context.Context) error) error {
|
||||
return txn.WithDatabase(context.TODO(), db, f)
|
||||
return txn.WithDatabase(testCtx, db, f)
|
||||
}
|
||||
|
||||
func populateDB() error {
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ func TestPerformerScenes(t *testing.T) {
|
|||
}
|
||||
|
||||
func testPerformerScenes(t *testing.T, performerName, expectedRegex string) {
|
||||
mockSceneReader := &mocks.SceneReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
const performerID = 2
|
||||
|
||||
|
|
@ -84,7 +84,7 @@ func testPerformerScenes(t *testing.T, performerName, expectedRegex string) {
|
|||
Direction: &direction,
|
||||
}
|
||||
|
||||
mockSceneReader.On("Query", mock.Anything, scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false)).
|
||||
db.Scene.On("Query", mock.Anything, scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false)).
|
||||
Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once()
|
||||
|
||||
for i := range matchingPaths {
|
||||
|
|
@ -100,19 +100,19 @@ func testPerformerScenes(t *testing.T, performerName, expectedRegex string) {
|
|||
|
||||
return scenePartialsEqual(got, expected)
|
||||
})
|
||||
mockSceneReader.On("UpdatePartial", mock.Anything, sceneID, matchPartial).Return(nil, nil).Once()
|
||||
db.Scene.On("UpdatePartial", mock.Anything, sceneID, matchPartial).Return(nil, nil).Once()
|
||||
}
|
||||
|
||||
tagger := Tagger{
|
||||
TxnManager: &mocks.TxnManager{},
|
||||
TxnManager: db,
|
||||
}
|
||||
|
||||
err := tagger.PerformerScenes(testCtx, &performer, nil, mockSceneReader)
|
||||
err := tagger.PerformerScenes(testCtx, &performer, nil, db.Scene)
|
||||
|
||||
assert := assert.New(t)
|
||||
|
||||
assert.Nil(err)
|
||||
mockSceneReader.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestPerformerImages(t *testing.T) {
|
||||
|
|
@ -140,7 +140,7 @@ func TestPerformerImages(t *testing.T) {
|
|||
}
|
||||
|
||||
func testPerformerImages(t *testing.T, performerName, expectedRegex string) {
|
||||
mockImageReader := &mocks.ImageReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
const performerID = 2
|
||||
|
||||
|
|
@ -179,7 +179,7 @@ func testPerformerImages(t *testing.T, performerName, expectedRegex string) {
|
|||
Direction: &direction,
|
||||
}
|
||||
|
||||
mockImageReader.On("Query", mock.Anything, image.QueryOptions(expectedImageFilter, expectedFindFilter, false)).
|
||||
db.Image.On("Query", mock.Anything, image.QueryOptions(expectedImageFilter, expectedFindFilter, false)).
|
||||
Return(mocks.ImageQueryResult(images, len(images)), nil).Once()
|
||||
|
||||
for i := range matchingPaths {
|
||||
|
|
@ -195,19 +195,19 @@ func testPerformerImages(t *testing.T, performerName, expectedRegex string) {
|
|||
|
||||
return imagePartialsEqual(got, expected)
|
||||
})
|
||||
mockImageReader.On("UpdatePartial", mock.Anything, imageID, matchPartial).Return(nil, nil).Once()
|
||||
db.Image.On("UpdatePartial", mock.Anything, imageID, matchPartial).Return(nil, nil).Once()
|
||||
}
|
||||
|
||||
tagger := Tagger{
|
||||
TxnManager: &mocks.TxnManager{},
|
||||
TxnManager: db,
|
||||
}
|
||||
|
||||
err := tagger.PerformerImages(testCtx, &performer, nil, mockImageReader)
|
||||
err := tagger.PerformerImages(testCtx, &performer, nil, db.Image)
|
||||
|
||||
assert := assert.New(t)
|
||||
|
||||
assert.Nil(err)
|
||||
mockImageReader.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestPerformerGalleries(t *testing.T) {
|
||||
|
|
@ -235,7 +235,7 @@ func TestPerformerGalleries(t *testing.T) {
|
|||
}
|
||||
|
||||
func testPerformerGalleries(t *testing.T, performerName, expectedRegex string) {
|
||||
mockGalleryReader := &mocks.GalleryReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
const performerID = 2
|
||||
|
||||
|
|
@ -275,7 +275,7 @@ func testPerformerGalleries(t *testing.T, performerName, expectedRegex string) {
|
|||
Direction: &direction,
|
||||
}
|
||||
|
||||
mockGalleryReader.On("Query", mock.Anything, expectedGalleryFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once()
|
||||
db.Gallery.On("Query", mock.Anything, expectedGalleryFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once()
|
||||
|
||||
for i := range matchingPaths {
|
||||
galleryID := i + 1
|
||||
|
|
@ -290,17 +290,17 @@ func testPerformerGalleries(t *testing.T, performerName, expectedRegex string) {
|
|||
|
||||
return galleryPartialsEqual(got, expected)
|
||||
})
|
||||
mockGalleryReader.On("UpdatePartial", mock.Anything, galleryID, matchPartial).Return(nil, nil).Once()
|
||||
db.Gallery.On("UpdatePartial", mock.Anything, galleryID, matchPartial).Return(nil, nil).Once()
|
||||
}
|
||||
|
||||
tagger := Tagger{
|
||||
TxnManager: &mocks.TxnManager{},
|
||||
TxnManager: db,
|
||||
}
|
||||
|
||||
err := tagger.PerformerGalleries(testCtx, &performer, nil, mockGalleryReader)
|
||||
err := tagger.PerformerGalleries(testCtx, &performer, nil, db.Gallery)
|
||||
|
||||
assert := assert.New(t)
|
||||
|
||||
assert.Nil(err)
|
||||
mockGalleryReader.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -182,11 +182,10 @@ func TestScenePerformers(t *testing.T) {
|
|||
assert := assert.New(t)
|
||||
|
||||
for _, test := range testTables {
|
||||
mockPerformerReader := &mocks.PerformerReaderWriter{}
|
||||
mockSceneReader := &mocks.SceneReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
mockPerformerReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
|
||||
mockPerformerReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once()
|
||||
db.Performer.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
|
||||
db.Performer.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once()
|
||||
|
||||
scene := models.Scene{
|
||||
ID: sceneID,
|
||||
|
|
@ -205,14 +204,13 @@ func TestScenePerformers(t *testing.T) {
|
|||
|
||||
return scenePartialsEqual(got, expected)
|
||||
})
|
||||
mockSceneReader.On("UpdatePartial", testCtx, sceneID, matchPartial).Return(nil, nil).Once()
|
||||
db.Scene.On("UpdatePartial", testCtx, sceneID, matchPartial).Return(nil, nil).Once()
|
||||
}
|
||||
|
||||
err := ScenePerformers(testCtx, &scene, mockSceneReader, mockPerformerReader, nil)
|
||||
err := ScenePerformers(testCtx, &scene, db.Scene, db.Performer, nil)
|
||||
|
||||
assert.Nil(err)
|
||||
mockPerformerReader.AssertExpectations(t)
|
||||
mockSceneReader.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -240,7 +238,7 @@ func TestSceneStudios(t *testing.T) {
|
|||
|
||||
assert := assert.New(t)
|
||||
|
||||
doTest := func(mockStudioReader *mocks.StudioReaderWriter, mockSceneReader *mocks.SceneReaderWriter, test pathTestTable) {
|
||||
doTest := func(db *mocks.Database, test pathTestTable) {
|
||||
if test.Matches {
|
||||
matchPartial := mock.MatchedBy(func(got models.ScenePartial) bool {
|
||||
expected := models.ScenePartial{
|
||||
|
|
@ -249,29 +247,27 @@ func TestSceneStudios(t *testing.T) {
|
|||
|
||||
return scenePartialsEqual(got, expected)
|
||||
})
|
||||
mockSceneReader.On("UpdatePartial", testCtx, sceneID, matchPartial).Return(nil, nil).Once()
|
||||
db.Scene.On("UpdatePartial", testCtx, sceneID, matchPartial).Return(nil, nil).Once()
|
||||
}
|
||||
|
||||
scene := models.Scene{
|
||||
ID: sceneID,
|
||||
Path: test.Path,
|
||||
}
|
||||
err := SceneStudios(testCtx, &scene, mockSceneReader, mockStudioReader, nil)
|
||||
err := SceneStudios(testCtx, &scene, db.Scene, db.Studio, nil)
|
||||
|
||||
assert.Nil(err)
|
||||
mockStudioReader.AssertExpectations(t)
|
||||
mockSceneReader.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
for _, test := range testTables {
|
||||
mockStudioReader := &mocks.StudioReaderWriter{}
|
||||
mockSceneReader := &mocks.SceneReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
|
||||
mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
|
||||
mockStudioReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
|
||||
db.Studio.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
|
||||
db.Studio.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
|
||||
db.Studio.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
|
||||
|
||||
doTest(mockStudioReader, mockSceneReader, test)
|
||||
doTest(db, test)
|
||||
}
|
||||
|
||||
const unmatchedName = "unmatched"
|
||||
|
|
@ -279,17 +275,16 @@ func TestSceneStudios(t *testing.T) {
|
|||
|
||||
// test against aliases
|
||||
for _, test := range testTables {
|
||||
mockStudioReader := &mocks.StudioReaderWriter{}
|
||||
mockSceneReader := &mocks.SceneReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
|
||||
mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
|
||||
mockStudioReader.On("GetAliases", testCtx, studioID).Return([]string{
|
||||
db.Studio.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
|
||||
db.Studio.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
|
||||
db.Studio.On("GetAliases", testCtx, studioID).Return([]string{
|
||||
studioName,
|
||||
}, nil).Once()
|
||||
mockStudioReader.On("GetAliases", testCtx, reversedStudioID).Return([]string{}, nil).Once()
|
||||
db.Studio.On("GetAliases", testCtx, reversedStudioID).Return([]string{}, nil).Once()
|
||||
|
||||
doTest(mockStudioReader, mockSceneReader, test)
|
||||
doTest(db, test)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -315,7 +310,7 @@ func TestSceneTags(t *testing.T) {
|
|||
|
||||
assert := assert.New(t)
|
||||
|
||||
doTest := func(mockTagReader *mocks.TagReaderWriter, mockSceneReader *mocks.SceneReaderWriter, test pathTestTable) {
|
||||
doTest := func(db *mocks.Database, test pathTestTable) {
|
||||
if test.Matches {
|
||||
matchPartial := mock.MatchedBy(func(got models.ScenePartial) bool {
|
||||
expected := models.ScenePartial{
|
||||
|
|
@ -327,7 +322,7 @@ func TestSceneTags(t *testing.T) {
|
|||
|
||||
return scenePartialsEqual(got, expected)
|
||||
})
|
||||
mockSceneReader.On("UpdatePartial", testCtx, sceneID, matchPartial).Return(nil, nil).Once()
|
||||
db.Scene.On("UpdatePartial", testCtx, sceneID, matchPartial).Return(nil, nil).Once()
|
||||
}
|
||||
|
||||
scene := models.Scene{
|
||||
|
|
@ -335,22 +330,20 @@ func TestSceneTags(t *testing.T) {
|
|||
Path: test.Path,
|
||||
TagIDs: models.NewRelatedIDs([]int{}),
|
||||
}
|
||||
err := SceneTags(testCtx, &scene, mockSceneReader, mockTagReader, nil)
|
||||
err := SceneTags(testCtx, &scene, db.Scene, db.Tag, nil)
|
||||
|
||||
assert.Nil(err)
|
||||
mockTagReader.AssertExpectations(t)
|
||||
mockSceneReader.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
for _, test := range testTables {
|
||||
mockTagReader := &mocks.TagReaderWriter{}
|
||||
mockSceneReader := &mocks.SceneReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
|
||||
mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
|
||||
mockTagReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
|
||||
db.Tag.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
|
||||
db.Tag.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
|
||||
db.Tag.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
|
||||
|
||||
doTest(mockTagReader, mockSceneReader, test)
|
||||
doTest(db, test)
|
||||
}
|
||||
|
||||
const unmatchedName = "unmatched"
|
||||
|
|
@ -358,16 +351,15 @@ func TestSceneTags(t *testing.T) {
|
|||
|
||||
// test against aliases
|
||||
for _, test := range testTables {
|
||||
mockTagReader := &mocks.TagReaderWriter{}
|
||||
mockSceneReader := &mocks.SceneReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
|
||||
mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
|
||||
mockTagReader.On("GetAliases", testCtx, tagID).Return([]string{
|
||||
db.Tag.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
|
||||
db.Tag.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
|
||||
db.Tag.On("GetAliases", testCtx, tagID).Return([]string{
|
||||
tagName,
|
||||
}, nil).Once()
|
||||
mockTagReader.On("GetAliases", testCtx, reversedTagID).Return([]string{}, nil).Once()
|
||||
db.Tag.On("GetAliases", testCtx, reversedTagID).Return([]string{}, nil).Once()
|
||||
|
||||
doTest(mockTagReader, mockSceneReader, test)
|
||||
doTest(db, test)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -83,7 +83,7 @@ func testStudioScenes(t *testing.T, tc testStudioCase) {
|
|||
aliasName := tc.aliasName
|
||||
aliasRegex := tc.aliasRegex
|
||||
|
||||
mockSceneReader := &mocks.SceneReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
var studioID = 2
|
||||
|
||||
|
|
@ -130,7 +130,7 @@ func testStudioScenes(t *testing.T, tc testStudioCase) {
|
|||
}
|
||||
|
||||
// if alias provided, then don't find by name
|
||||
onNameQuery := mockSceneReader.On("Query", testCtx, scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false))
|
||||
onNameQuery := db.Scene.On("Query", testCtx, scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false))
|
||||
|
||||
if aliasName == "" {
|
||||
onNameQuery.Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once()
|
||||
|
|
@ -145,7 +145,7 @@ func testStudioScenes(t *testing.T, tc testStudioCase) {
|
|||
},
|
||||
}
|
||||
|
||||
mockSceneReader.On("Query", mock.Anything, scene.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
|
||||
db.Scene.On("Query", mock.Anything, scene.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
|
||||
Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once()
|
||||
}
|
||||
|
||||
|
|
@ -159,19 +159,19 @@ func testStudioScenes(t *testing.T, tc testStudioCase) {
|
|||
|
||||
return scenePartialsEqual(got, expected)
|
||||
})
|
||||
mockSceneReader.On("UpdatePartial", mock.Anything, sceneID, matchPartial).Return(nil, nil).Once()
|
||||
db.Scene.On("UpdatePartial", mock.Anything, sceneID, matchPartial).Return(nil, nil).Once()
|
||||
}
|
||||
|
||||
tagger := Tagger{
|
||||
TxnManager: &mocks.TxnManager{},
|
||||
TxnManager: db,
|
||||
}
|
||||
|
||||
err := tagger.StudioScenes(testCtx, &studio, nil, aliases, mockSceneReader)
|
||||
err := tagger.StudioScenes(testCtx, &studio, nil, aliases, db.Scene)
|
||||
|
||||
assert := assert.New(t)
|
||||
|
||||
assert.Nil(err)
|
||||
mockSceneReader.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestStudioImages(t *testing.T) {
|
||||
|
|
@ -188,7 +188,7 @@ func testStudioImages(t *testing.T, tc testStudioCase) {
|
|||
aliasName := tc.aliasName
|
||||
aliasRegex := tc.aliasRegex
|
||||
|
||||
mockImageReader := &mocks.ImageReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
var studioID = 2
|
||||
|
||||
|
|
@ -234,7 +234,7 @@ func testStudioImages(t *testing.T, tc testStudioCase) {
|
|||
}
|
||||
|
||||
// if alias provided, then don't find by name
|
||||
onNameQuery := mockImageReader.On("Query", mock.Anything, image.QueryOptions(expectedImageFilter, expectedFindFilter, false))
|
||||
onNameQuery := db.Image.On("Query", mock.Anything, image.QueryOptions(expectedImageFilter, expectedFindFilter, false))
|
||||
if aliasName == "" {
|
||||
onNameQuery.Return(mocks.ImageQueryResult(images, len(images)), nil).Once()
|
||||
} else {
|
||||
|
|
@ -248,7 +248,7 @@ func testStudioImages(t *testing.T, tc testStudioCase) {
|
|||
},
|
||||
}
|
||||
|
||||
mockImageReader.On("Query", mock.Anything, image.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
|
||||
db.Image.On("Query", mock.Anything, image.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
|
||||
Return(mocks.ImageQueryResult(images, len(images)), nil).Once()
|
||||
}
|
||||
|
||||
|
|
@ -262,19 +262,19 @@ func testStudioImages(t *testing.T, tc testStudioCase) {
|
|||
|
||||
return imagePartialsEqual(got, expected)
|
||||
})
|
||||
mockImageReader.On("UpdatePartial", mock.Anything, imageID, matchPartial).Return(nil, nil).Once()
|
||||
db.Image.On("UpdatePartial", mock.Anything, imageID, matchPartial).Return(nil, nil).Once()
|
||||
}
|
||||
|
||||
tagger := Tagger{
|
||||
TxnManager: &mocks.TxnManager{},
|
||||
TxnManager: db,
|
||||
}
|
||||
|
||||
err := tagger.StudioImages(testCtx, &studio, nil, aliases, mockImageReader)
|
||||
err := tagger.StudioImages(testCtx, &studio, nil, aliases, db.Image)
|
||||
|
||||
assert := assert.New(t)
|
||||
|
||||
assert.Nil(err)
|
||||
mockImageReader.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestStudioGalleries(t *testing.T) {
|
||||
|
|
@ -290,7 +290,8 @@ func testStudioGalleries(t *testing.T, tc testStudioCase) {
|
|||
expectedRegex := tc.expectedRegex
|
||||
aliasName := tc.aliasName
|
||||
aliasRegex := tc.aliasRegex
|
||||
mockGalleryReader := &mocks.GalleryReaderWriter{}
|
||||
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
var studioID = 2
|
||||
|
||||
|
|
@ -337,7 +338,7 @@ func testStudioGalleries(t *testing.T, tc testStudioCase) {
|
|||
}
|
||||
|
||||
// if alias provided, then don't find by name
|
||||
onNameQuery := mockGalleryReader.On("Query", mock.Anything, expectedGalleryFilter, expectedFindFilter)
|
||||
onNameQuery := db.Gallery.On("Query", mock.Anything, expectedGalleryFilter, expectedFindFilter)
|
||||
if aliasName == "" {
|
||||
onNameQuery.Return(galleries, len(galleries), nil).Once()
|
||||
} else {
|
||||
|
|
@ -351,7 +352,7 @@ func testStudioGalleries(t *testing.T, tc testStudioCase) {
|
|||
},
|
||||
}
|
||||
|
||||
mockGalleryReader.On("Query", mock.Anything, expectedAliasFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once()
|
||||
db.Gallery.On("Query", mock.Anything, expectedAliasFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once()
|
||||
}
|
||||
|
||||
for i := range matchingPaths {
|
||||
|
|
@ -364,17 +365,17 @@ func testStudioGalleries(t *testing.T, tc testStudioCase) {
|
|||
|
||||
return galleryPartialsEqual(got, expected)
|
||||
})
|
||||
mockGalleryReader.On("UpdatePartial", mock.Anything, galleryID, matchPartial).Return(nil, nil).Once()
|
||||
db.Gallery.On("UpdatePartial", mock.Anything, galleryID, matchPartial).Return(nil, nil).Once()
|
||||
}
|
||||
|
||||
tagger := Tagger{
|
||||
TxnManager: &mocks.TxnManager{},
|
||||
TxnManager: db,
|
||||
}
|
||||
|
||||
err := tagger.StudioGalleries(testCtx, &studio, nil, aliases, mockGalleryReader)
|
||||
err := tagger.StudioGalleries(testCtx, &studio, nil, aliases, db.Gallery)
|
||||
|
||||
assert := assert.New(t)
|
||||
|
||||
assert.Nil(err)
|
||||
mockGalleryReader.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -83,7 +83,7 @@ func testTagScenes(t *testing.T, tc testTagCase) {
|
|||
aliasName := tc.aliasName
|
||||
aliasRegex := tc.aliasRegex
|
||||
|
||||
mockSceneReader := &mocks.SceneReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
const tagID = 2
|
||||
|
||||
|
|
@ -131,7 +131,7 @@ func testTagScenes(t *testing.T, tc testTagCase) {
|
|||
}
|
||||
|
||||
// if alias provided, then don't find by name
|
||||
onNameQuery := mockSceneReader.On("Query", testCtx, scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false))
|
||||
onNameQuery := db.Scene.On("Query", testCtx, scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false))
|
||||
if aliasName == "" {
|
||||
onNameQuery.Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once()
|
||||
} else {
|
||||
|
|
@ -145,7 +145,7 @@ func testTagScenes(t *testing.T, tc testTagCase) {
|
|||
},
|
||||
}
|
||||
|
||||
mockSceneReader.On("Query", mock.Anything, scene.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
|
||||
db.Scene.On("Query", mock.Anything, scene.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
|
||||
Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once()
|
||||
}
|
||||
|
||||
|
|
@ -162,19 +162,19 @@ func testTagScenes(t *testing.T, tc testTagCase) {
|
|||
|
||||
return scenePartialsEqual(got, expected)
|
||||
})
|
||||
mockSceneReader.On("UpdatePartial", mock.Anything, sceneID, matchPartial).Return(nil, nil).Once()
|
||||
db.Scene.On("UpdatePartial", mock.Anything, sceneID, matchPartial).Return(nil, nil).Once()
|
||||
}
|
||||
|
||||
tagger := Tagger{
|
||||
TxnManager: &mocks.TxnManager{},
|
||||
TxnManager: db,
|
||||
}
|
||||
|
||||
err := tagger.TagScenes(testCtx, &tag, nil, aliases, mockSceneReader)
|
||||
err := tagger.TagScenes(testCtx, &tag, nil, aliases, db.Scene)
|
||||
|
||||
assert := assert.New(t)
|
||||
|
||||
assert.Nil(err)
|
||||
mockSceneReader.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestTagImages(t *testing.T) {
|
||||
|
|
@ -191,7 +191,7 @@ func testTagImages(t *testing.T, tc testTagCase) {
|
|||
aliasName := tc.aliasName
|
||||
aliasRegex := tc.aliasRegex
|
||||
|
||||
mockImageReader := &mocks.ImageReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
const tagID = 2
|
||||
|
||||
|
|
@ -238,7 +238,7 @@ func testTagImages(t *testing.T, tc testTagCase) {
|
|||
}
|
||||
|
||||
// if alias provided, then don't find by name
|
||||
onNameQuery := mockImageReader.On("Query", testCtx, image.QueryOptions(expectedImageFilter, expectedFindFilter, false))
|
||||
onNameQuery := db.Image.On("Query", testCtx, image.QueryOptions(expectedImageFilter, expectedFindFilter, false))
|
||||
if aliasName == "" {
|
||||
onNameQuery.Return(mocks.ImageQueryResult(images, len(images)), nil).Once()
|
||||
} else {
|
||||
|
|
@ -252,7 +252,7 @@ func testTagImages(t *testing.T, tc testTagCase) {
|
|||
},
|
||||
}
|
||||
|
||||
mockImageReader.On("Query", mock.Anything, image.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
|
||||
db.Image.On("Query", mock.Anything, image.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
|
||||
Return(mocks.ImageQueryResult(images, len(images)), nil).Once()
|
||||
}
|
||||
|
||||
|
|
@ -269,19 +269,19 @@ func testTagImages(t *testing.T, tc testTagCase) {
|
|||
|
||||
return imagePartialsEqual(got, expected)
|
||||
})
|
||||
mockImageReader.On("UpdatePartial", mock.Anything, imageID, matchPartial).Return(nil, nil).Once()
|
||||
db.Image.On("UpdatePartial", mock.Anything, imageID, matchPartial).Return(nil, nil).Once()
|
||||
}
|
||||
|
||||
tagger := Tagger{
|
||||
TxnManager: &mocks.TxnManager{},
|
||||
TxnManager: db,
|
||||
}
|
||||
|
||||
err := tagger.TagImages(testCtx, &tag, nil, aliases, mockImageReader)
|
||||
err := tagger.TagImages(testCtx, &tag, nil, aliases, db.Image)
|
||||
|
||||
assert := assert.New(t)
|
||||
|
||||
assert.Nil(err)
|
||||
mockImageReader.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestTagGalleries(t *testing.T) {
|
||||
|
|
@ -298,7 +298,7 @@ func testTagGalleries(t *testing.T, tc testTagCase) {
|
|||
aliasName := tc.aliasName
|
||||
aliasRegex := tc.aliasRegex
|
||||
|
||||
mockGalleryReader := &mocks.GalleryReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
const tagID = 2
|
||||
|
||||
|
|
@ -346,7 +346,7 @@ func testTagGalleries(t *testing.T, tc testTagCase) {
|
|||
}
|
||||
|
||||
// if alias provided, then don't find by name
|
||||
onNameQuery := mockGalleryReader.On("Query", testCtx, expectedGalleryFilter, expectedFindFilter)
|
||||
onNameQuery := db.Gallery.On("Query", testCtx, expectedGalleryFilter, expectedFindFilter)
|
||||
if aliasName == "" {
|
||||
onNameQuery.Return(galleries, len(galleries), nil).Once()
|
||||
} else {
|
||||
|
|
@ -360,7 +360,7 @@ func testTagGalleries(t *testing.T, tc testTagCase) {
|
|||
},
|
||||
}
|
||||
|
||||
mockGalleryReader.On("Query", mock.Anything, expectedAliasFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once()
|
||||
db.Gallery.On("Query", mock.Anything, expectedAliasFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once()
|
||||
}
|
||||
|
||||
for i := range matchingPaths {
|
||||
|
|
@ -376,18 +376,18 @@ func testTagGalleries(t *testing.T, tc testTagCase) {
|
|||
|
||||
return galleryPartialsEqual(got, expected)
|
||||
})
|
||||
mockGalleryReader.On("UpdatePartial", mock.Anything, galleryID, matchPartial).Return(nil, nil).Once()
|
||||
db.Gallery.On("UpdatePartial", mock.Anything, galleryID, matchPartial).Return(nil, nil).Once()
|
||||
|
||||
}
|
||||
|
||||
tagger := Tagger{
|
||||
TxnManager: &mocks.TxnManager{},
|
||||
TxnManager: db,
|
||||
}
|
||||
|
||||
err := tagger.TagGalleries(testCtx, &tag, nil, aliases, mockGalleryReader)
|
||||
err := tagger.TagGalleries(testCtx, &tag, nil, aliases, db.Gallery)
|
||||
|
||||
assert := assert.New(t)
|
||||
|
||||
assert.Nil(err)
|
||||
mockGalleryReader.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -41,7 +41,6 @@ import (
|
|||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/scene"
|
||||
"github.com/stashapp/stash/pkg/sliceutil/stringslice"
|
||||
"github.com/stashapp/stash/pkg/txn"
|
||||
)
|
||||
|
||||
var pageSize = 100
|
||||
|
|
@ -360,10 +359,11 @@ func (me *contentDirectoryService) handleBrowseMetadata(obj object, host string)
|
|||
} else {
|
||||
var scene *models.Scene
|
||||
|
||||
if err := txn.WithReadTxn(context.TODO(), me.txnManager, func(ctx context.Context) error {
|
||||
scene, err = me.repository.SceneFinder.Find(ctx, sceneID)
|
||||
r := me.repository
|
||||
if err := r.WithReadTxn(context.TODO(), func(ctx context.Context) error {
|
||||
scene, err = r.SceneFinder.Find(ctx, sceneID)
|
||||
if scene != nil {
|
||||
err = scene.LoadPrimaryFile(ctx, me.repository.FileGetter)
|
||||
err = scene.LoadPrimaryFile(ctx, r.FileGetter)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
|
@ -452,7 +452,8 @@ func getSortDirection(sceneFilter *models.SceneFilterType, sort string) models.S
|
|||
func (me *contentDirectoryService) getVideos(sceneFilter *models.SceneFilterType, parentID string, host string) []interface{} {
|
||||
var objs []interface{}
|
||||
|
||||
if err := txn.WithReadTxn(context.TODO(), me.txnManager, func(ctx context.Context) error {
|
||||
r := me.repository
|
||||
if err := r.WithReadTxn(context.TODO(), func(ctx context.Context) error {
|
||||
sort := me.VideoSortOrder
|
||||
direction := getSortDirection(sceneFilter, sort)
|
||||
findFilter := &models.FindFilterType{
|
||||
|
|
@ -461,7 +462,7 @@ func (me *contentDirectoryService) getVideos(sceneFilter *models.SceneFilterType
|
|||
Direction: &direction,
|
||||
}
|
||||
|
||||
scenes, total, err := scene.QueryWithCount(ctx, me.repository.SceneFinder, sceneFilter, findFilter)
|
||||
scenes, total, err := scene.QueryWithCount(ctx, r.SceneFinder, sceneFilter, findFilter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -472,13 +473,13 @@ func (me *contentDirectoryService) getVideos(sceneFilter *models.SceneFilterType
|
|||
parentID: parentID,
|
||||
}
|
||||
|
||||
objs, err = pager.getPages(ctx, me.repository.SceneFinder, total)
|
||||
objs, err = pager.getPages(ctx, r.SceneFinder, total)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
for _, s := range scenes {
|
||||
if err := s.LoadPrimaryFile(ctx, me.repository.FileGetter); err != nil {
|
||||
if err := s.LoadPrimaryFile(ctx, r.FileGetter); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
@ -497,7 +498,8 @@ func (me *contentDirectoryService) getVideos(sceneFilter *models.SceneFilterType
|
|||
func (me *contentDirectoryService) getPageVideos(sceneFilter *models.SceneFilterType, parentID string, page int, host string) []interface{} {
|
||||
var objs []interface{}
|
||||
|
||||
if err := txn.WithReadTxn(context.TODO(), me.txnManager, func(ctx context.Context) error {
|
||||
r := me.repository
|
||||
if err := r.WithReadTxn(context.TODO(), func(ctx context.Context) error {
|
||||
pager := scenePager{
|
||||
sceneFilter: sceneFilter,
|
||||
parentID: parentID,
|
||||
|
|
@ -506,7 +508,7 @@ func (me *contentDirectoryService) getPageVideos(sceneFilter *models.SceneFilter
|
|||
sort := me.VideoSortOrder
|
||||
direction := getSortDirection(sceneFilter, sort)
|
||||
var err error
|
||||
objs, err = pager.getPageVideos(ctx, me.repository.SceneFinder, me.repository.FileGetter, page, host, sort, direction)
|
||||
objs, err = pager.getPageVideos(ctx, r.SceneFinder, r.FileGetter, page, host, sort, direction)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -540,8 +542,9 @@ func (me *contentDirectoryService) getAllScenes(host string) []interface{} {
|
|||
func (me *contentDirectoryService) getStudios() []interface{} {
|
||||
var objs []interface{}
|
||||
|
||||
if err := txn.WithReadTxn(context.TODO(), me.txnManager, func(ctx context.Context) error {
|
||||
studios, err := me.repository.StudioFinder.All(ctx)
|
||||
r := me.repository
|
||||
if err := r.WithReadTxn(context.TODO(), func(ctx context.Context) error {
|
||||
studios, err := r.StudioFinder.All(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -579,8 +582,9 @@ func (me *contentDirectoryService) getStudioScenes(paths []string, host string)
|
|||
func (me *contentDirectoryService) getTags() []interface{} {
|
||||
var objs []interface{}
|
||||
|
||||
if err := txn.WithReadTxn(context.TODO(), me.txnManager, func(ctx context.Context) error {
|
||||
tags, err := me.repository.TagFinder.All(ctx)
|
||||
r := me.repository
|
||||
if err := r.WithReadTxn(context.TODO(), func(ctx context.Context) error {
|
||||
tags, err := r.TagFinder.All(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -618,8 +622,9 @@ func (me *contentDirectoryService) getTagScenes(paths []string, host string) []i
|
|||
func (me *contentDirectoryService) getPerformers() []interface{} {
|
||||
var objs []interface{}
|
||||
|
||||
if err := txn.WithReadTxn(context.TODO(), me.txnManager, func(ctx context.Context) error {
|
||||
performers, err := me.repository.PerformerFinder.All(ctx)
|
||||
r := me.repository
|
||||
if err := r.WithReadTxn(context.TODO(), func(ctx context.Context) error {
|
||||
performers, err := r.PerformerFinder.All(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -657,8 +662,9 @@ func (me *contentDirectoryService) getPerformerScenes(paths []string, host strin
|
|||
func (me *contentDirectoryService) getMovies() []interface{} {
|
||||
var objs []interface{}
|
||||
|
||||
if err := txn.WithReadTxn(context.TODO(), me.txnManager, func(ctx context.Context) error {
|
||||
movies, err := me.repository.MovieFinder.All(ctx)
|
||||
r := me.repository
|
||||
if err := r.WithReadTxn(context.TODO(), func(ctx context.Context) error {
|
||||
movies, err := r.MovieFinder.All(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -48,7 +48,6 @@ import (
|
|||
|
||||
"github.com/stashapp/stash/pkg/logger"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/txn"
|
||||
)
|
||||
|
||||
type SceneFinder interface {
|
||||
|
|
@ -271,7 +270,6 @@ type Server struct {
|
|||
// Time interval between SSPD announces
|
||||
NotifyInterval time.Duration
|
||||
|
||||
txnManager txn.Manager
|
||||
repository Repository
|
||||
sceneServer sceneServer
|
||||
ipWhitelistManager *ipWhitelistManager
|
||||
|
|
@ -439,12 +437,13 @@ func (me *Server) serveIcon(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
var scene *models.Scene
|
||||
err := txn.WithReadTxn(r.Context(), me.txnManager, func(ctx context.Context) error {
|
||||
repo := me.repository
|
||||
err := repo.WithReadTxn(r.Context(), func(ctx context.Context) error {
|
||||
idInt, err := strconv.Atoi(sceneId)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
scene, _ = me.repository.SceneFinder.Find(ctx, idInt)
|
||||
scene, _ = repo.SceneFinder.Find(ctx, idInt)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
|
|
@ -579,12 +578,13 @@ func (me *Server) initMux(mux *http.ServeMux) {
|
|||
mux.HandleFunc(resPath, func(w http.ResponseWriter, r *http.Request) {
|
||||
sceneId := r.URL.Query().Get("scene")
|
||||
var scene *models.Scene
|
||||
err := txn.WithReadTxn(r.Context(), me.txnManager, func(ctx context.Context) error {
|
||||
repo := me.repository
|
||||
err := repo.WithReadTxn(r.Context(), func(ctx context.Context) error {
|
||||
sceneIdInt, err := strconv.Atoi(sceneId)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
scene, _ = me.repository.SceneFinder.Find(ctx, sceneIdInt)
|
||||
scene, _ = repo.SceneFinder.Find(ctx, sceneIdInt)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package dlna
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
|
|
@ -14,6 +15,8 @@ import (
|
|||
)
|
||||
|
||||
type Repository struct {
|
||||
TxnManager models.TxnManager
|
||||
|
||||
SceneFinder SceneFinder
|
||||
FileGetter models.FileGetter
|
||||
StudioFinder StudioFinder
|
||||
|
|
@ -22,6 +25,22 @@ type Repository struct {
|
|||
MovieFinder MovieFinder
|
||||
}
|
||||
|
||||
func NewRepository(repo models.Repository) Repository {
|
||||
return Repository{
|
||||
TxnManager: repo.TxnManager,
|
||||
FileGetter: repo.File,
|
||||
SceneFinder: repo.Scene,
|
||||
StudioFinder: repo.Studio,
|
||||
TagFinder: repo.Tag,
|
||||
PerformerFinder: repo.Performer,
|
||||
MovieFinder: repo.Movie,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Repository) WithReadTxn(ctx context.Context, fn txn.TxnFunc) error {
|
||||
return txn.WithReadTxn(ctx, r.TxnManager, fn)
|
||||
}
|
||||
|
||||
type Status struct {
|
||||
Running bool `json:"running"`
|
||||
// If not currently running, time until it will be started. If running, time until it will be stopped
|
||||
|
|
@ -60,7 +79,6 @@ type Config interface {
|
|||
}
|
||||
|
||||
type Service struct {
|
||||
txnManager txn.Manager
|
||||
repository Repository
|
||||
config Config
|
||||
sceneServer sceneServer
|
||||
|
|
@ -133,9 +151,8 @@ func (s *Service) init() error {
|
|||
}
|
||||
|
||||
s.server = &Server{
|
||||
txnManager: s.txnManager,
|
||||
sceneServer: s.sceneServer,
|
||||
repository: s.repository,
|
||||
sceneServer: s.sceneServer,
|
||||
ipWhitelistManager: s.ipWhitelistMgr,
|
||||
Interfaces: interfaces,
|
||||
HTTPConn: func() net.Listener {
|
||||
|
|
@ -197,9 +214,8 @@ func (s *Service) init() error {
|
|||
// }
|
||||
|
||||
// NewService initialises and returns a new DLNA service.
|
||||
func NewService(txnManager txn.Manager, repo Repository, cfg Config, sceneServer sceneServer) *Service {
|
||||
func NewService(repo Repository, cfg Config, sceneServer sceneServer) *Service {
|
||||
ret := &Service{
|
||||
txnManager: txnManager,
|
||||
repository: repo,
|
||||
sceneServer: sceneServer,
|
||||
config: cfg,
|
||||
|
|
|
|||
|
|
@ -43,6 +43,7 @@ type ScraperSource struct {
|
|||
}
|
||||
|
||||
type SceneIdentifier struct {
|
||||
TxnManager txn.Manager
|
||||
SceneReaderUpdater SceneReaderUpdater
|
||||
StudioReaderWriter models.StudioReaderWriter
|
||||
PerformerCreator PerformerCreator
|
||||
|
|
@ -53,8 +54,8 @@ type SceneIdentifier struct {
|
|||
SceneUpdatePostHookExecutor SceneUpdatePostHookExecutor
|
||||
}
|
||||
|
||||
func (t *SceneIdentifier) Identify(ctx context.Context, txnManager txn.Manager, scene *models.Scene) error {
|
||||
result, err := t.scrapeScene(ctx, txnManager, scene)
|
||||
func (t *SceneIdentifier) Identify(ctx context.Context, scene *models.Scene) error {
|
||||
result, err := t.scrapeScene(ctx, scene)
|
||||
var multipleMatchErr *MultipleMatchesFoundError
|
||||
if err != nil {
|
||||
if !errors.As(err, &multipleMatchErr) {
|
||||
|
|
@ -70,7 +71,7 @@ func (t *SceneIdentifier) Identify(ctx context.Context, txnManager txn.Manager,
|
|||
options := t.getOptions(multipleMatchErr.Source)
|
||||
if options.SkipMultipleMatchTag != nil && len(*options.SkipMultipleMatchTag) > 0 {
|
||||
// Tag it with the multiple results tag
|
||||
err := t.addTagToScene(ctx, txnManager, scene, *options.SkipMultipleMatchTag)
|
||||
err := t.addTagToScene(ctx, scene, *options.SkipMultipleMatchTag)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -83,7 +84,7 @@ func (t *SceneIdentifier) Identify(ctx context.Context, txnManager txn.Manager,
|
|||
}
|
||||
|
||||
// results were found, modify the scene
|
||||
if err := t.modifyScene(ctx, txnManager, scene, result); err != nil {
|
||||
if err := t.modifyScene(ctx, scene, result); err != nil {
|
||||
return fmt.Errorf("error modifying scene: %v", err)
|
||||
}
|
||||
|
||||
|
|
@ -95,7 +96,7 @@ type scrapeResult struct {
|
|||
source ScraperSource
|
||||
}
|
||||
|
||||
func (t *SceneIdentifier) scrapeScene(ctx context.Context, txnManager txn.Manager, scene *models.Scene) (*scrapeResult, error) {
|
||||
func (t *SceneIdentifier) scrapeScene(ctx context.Context, scene *models.Scene) (*scrapeResult, error) {
|
||||
// iterate through the input sources
|
||||
for _, source := range t.Sources {
|
||||
// scrape using the source
|
||||
|
|
@ -261,9 +262,9 @@ func (t *SceneIdentifier) getSceneUpdater(ctx context.Context, s *models.Scene,
|
|||
return ret, nil
|
||||
}
|
||||
|
||||
func (t *SceneIdentifier) modifyScene(ctx context.Context, txnManager txn.Manager, s *models.Scene, result *scrapeResult) error {
|
||||
func (t *SceneIdentifier) modifyScene(ctx context.Context, s *models.Scene, result *scrapeResult) error {
|
||||
var updater *scene.UpdateSet
|
||||
if err := txn.WithTxn(ctx, txnManager, func(ctx context.Context) error {
|
||||
if err := txn.WithTxn(ctx, t.TxnManager, func(ctx context.Context) error {
|
||||
// load scene relationships
|
||||
if err := s.LoadURLs(ctx, t.SceneReaderUpdater); err != nil {
|
||||
return err
|
||||
|
|
@ -316,8 +317,8 @@ func (t *SceneIdentifier) modifyScene(ctx context.Context, txnManager txn.Manage
|
|||
return nil
|
||||
}
|
||||
|
||||
func (t *SceneIdentifier) addTagToScene(ctx context.Context, txnManager txn.Manager, s *models.Scene, tagToAdd string) error {
|
||||
if err := txn.WithTxn(ctx, txnManager, func(ctx context.Context) error {
|
||||
func (t *SceneIdentifier) addTagToScene(ctx context.Context, s *models.Scene, tagToAdd string) error {
|
||||
if err := txn.WithTxn(ctx, t.TxnManager, func(ctx context.Context) error {
|
||||
tagID, err := strconv.Atoi(tagToAdd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error converting tag ID %s: %w", tagToAdd, err)
|
||||
|
|
|
|||
|
|
@ -108,17 +108,17 @@ func TestSceneIdentifier_Identify(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
mockSceneReaderWriter := &mocks.SceneReaderWriter{}
|
||||
mockSceneReaderWriter.On("GetURLs", mock.Anything, mock.Anything).Return(nil, nil)
|
||||
mockSceneReaderWriter.On("UpdatePartial", mock.Anything, mock.MatchedBy(func(id int) bool {
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
db.Scene.On("GetURLs", mock.Anything, mock.Anything).Return(nil, nil)
|
||||
db.Scene.On("UpdatePartial", mock.Anything, mock.MatchedBy(func(id int) bool {
|
||||
return id == errUpdateID
|
||||
}), mock.Anything).Return(nil, errors.New("update error"))
|
||||
mockSceneReaderWriter.On("UpdatePartial", mock.Anything, mock.MatchedBy(func(id int) bool {
|
||||
db.Scene.On("UpdatePartial", mock.Anything, mock.MatchedBy(func(id int) bool {
|
||||
return id != errUpdateID
|
||||
}), mock.Anything).Return(nil, nil)
|
||||
|
||||
mockTagFinderCreator := &mocks.TagReaderWriter{}
|
||||
mockTagFinderCreator.On("Find", mock.Anything, skipMultipleTagID).Return(&models.Tag{
|
||||
db.Tag.On("Find", mock.Anything, skipMultipleTagID).Return(&models.Tag{
|
||||
ID: skipMultipleTagID,
|
||||
Name: skipMultipleTagIDStr,
|
||||
}, nil)
|
||||
|
|
@ -185,8 +185,11 @@ func TestSceneIdentifier_Identify(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
identifier := SceneIdentifier{
|
||||
SceneReaderUpdater: mockSceneReaderWriter,
|
||||
TagFinderCreator: mockTagFinderCreator,
|
||||
TxnManager: db,
|
||||
SceneReaderUpdater: db.Scene,
|
||||
StudioReaderWriter: db.Studio,
|
||||
PerformerCreator: db.Performer,
|
||||
TagFinderCreator: db.Tag,
|
||||
DefaultOptions: defaultOptions,
|
||||
Sources: sources,
|
||||
SceneUpdatePostHookExecutor: mockHookExecutor{},
|
||||
|
|
@ -202,7 +205,7 @@ func TestSceneIdentifier_Identify(t *testing.T) {
|
|||
TagIDs: models.NewRelatedIDs([]int{}),
|
||||
StashIDs: models.NewRelatedStashIDs([]models.StashID{}),
|
||||
}
|
||||
if err := identifier.Identify(testCtx, &mocks.TxnManager{}, scene); (err != nil) != tt.wantErr {
|
||||
if err := identifier.Identify(testCtx, scene); (err != nil) != tt.wantErr {
|
||||
t.Errorf("SceneIdentifier.Identify() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
|
|
@ -210,9 +213,8 @@ func TestSceneIdentifier_Identify(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSceneIdentifier_modifyScene(t *testing.T) {
|
||||
repo := models.Repository{
|
||||
TxnManager: &mocks.TxnManager{},
|
||||
}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
boolFalse := false
|
||||
defaultOptions := &MetadataOptions{
|
||||
SetOrganized: &boolFalse,
|
||||
|
|
@ -221,7 +223,12 @@ func TestSceneIdentifier_modifyScene(t *testing.T) {
|
|||
SkipSingleNamePerformers: &boolFalse,
|
||||
}
|
||||
tr := &SceneIdentifier{
|
||||
DefaultOptions: defaultOptions,
|
||||
TxnManager: db,
|
||||
SceneReaderUpdater: db.Scene,
|
||||
StudioReaderWriter: db.Studio,
|
||||
PerformerCreator: db.Performer,
|
||||
TagFinderCreator: db.Tag,
|
||||
DefaultOptions: defaultOptions,
|
||||
}
|
||||
|
||||
type args struct {
|
||||
|
|
@ -254,7 +261,7 @@ func TestSceneIdentifier_modifyScene(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tr.modifyScene(testCtx, repo, tt.args.scene, tt.args.result); (err != nil) != tt.wantErr {
|
||||
if err := tr.modifyScene(testCtx, tt.args.scene, tt.args.result); (err != nil) != tt.wantErr {
|
||||
t.Errorf("SceneIdentifier.modifyScene() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
|
|
|
|||
|
|
@ -22,8 +22,9 @@ func Test_getPerformerID(t *testing.T) {
|
|||
remoteSiteID := "2"
|
||||
name := "name"
|
||||
|
||||
mockPerformerReaderWriter := mocks.PerformerReaderWriter{}
|
||||
mockPerformerReaderWriter.On("Create", testCtx, mock.Anything).Run(func(args mock.Arguments) {
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
db.Performer.On("Create", testCtx, mock.Anything).Run(func(args mock.Arguments) {
|
||||
p := args.Get(1).(*models.Performer)
|
||||
p.ID = validStoredID
|
||||
}).Return(nil)
|
||||
|
|
@ -131,7 +132,7 @@ func Test_getPerformerID(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := getPerformerID(testCtx, tt.args.endpoint, &mockPerformerReaderWriter, tt.args.p, tt.args.createMissing, tt.args.skipSingleName)
|
||||
got, err := getPerformerID(testCtx, tt.args.endpoint, db.Performer, tt.args.p, tt.args.createMissing, tt.args.skipSingleName)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("getPerformerID() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
|
|
@ -151,15 +152,16 @@ func Test_createMissingPerformer(t *testing.T) {
|
|||
invalidName := "invalidName"
|
||||
performerID := 1
|
||||
|
||||
mockPerformerReaderWriter := mocks.PerformerReaderWriter{}
|
||||
mockPerformerReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p *models.Performer) bool {
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
db.Performer.On("Create", testCtx, mock.MatchedBy(func(p *models.Performer) bool {
|
||||
return p.Name == validName
|
||||
})).Run(func(args mock.Arguments) {
|
||||
p := args.Get(1).(*models.Performer)
|
||||
p.ID = performerID
|
||||
}).Return(nil)
|
||||
|
||||
mockPerformerReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p *models.Performer) bool {
|
||||
db.Performer.On("Create", testCtx, mock.MatchedBy(func(p *models.Performer) bool {
|
||||
return p.Name == invalidName
|
||||
})).Return(errors.New("error creating performer"))
|
||||
|
||||
|
|
@ -212,7 +214,7 @@ func Test_createMissingPerformer(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := createMissingPerformer(testCtx, tt.args.endpoint, &mockPerformerReaderWriter, tt.args.p)
|
||||
got, err := createMissingPerformer(testCtx, tt.args.endpoint, db.Performer, tt.args.p)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("createMissingPerformer() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
|
|
|
|||
|
|
@ -24,14 +24,15 @@ func Test_sceneRelationships_studio(t *testing.T) {
|
|||
Strategy: FieldStrategyMerge,
|
||||
}
|
||||
|
||||
mockStudioReaderWriter := &mocks.StudioReaderWriter{}
|
||||
mockStudioReaderWriter.On("Create", testCtx, mock.Anything).Run(func(args mock.Arguments) {
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
db.Studio.On("Create", testCtx, mock.Anything).Run(func(args mock.Arguments) {
|
||||
s := args.Get(1).(*models.Studio)
|
||||
s.ID = validStoredIDInt
|
||||
}).Return(nil)
|
||||
|
||||
tr := sceneRelationships{
|
||||
studioReaderWriter: mockStudioReaderWriter,
|
||||
studioReaderWriter: db.Studio,
|
||||
fieldOptions: make(map[string]*FieldOptions),
|
||||
}
|
||||
|
||||
|
|
@ -174,8 +175,10 @@ func Test_sceneRelationships_performers(t *testing.T) {
|
|||
}),
|
||||
}
|
||||
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
tr := sceneRelationships{
|
||||
sceneReader: &mocks.SceneReaderWriter{},
|
||||
sceneReader: db.Scene,
|
||||
fieldOptions: make(map[string]*FieldOptions),
|
||||
}
|
||||
|
||||
|
|
@ -363,22 +366,21 @@ func Test_sceneRelationships_tags(t *testing.T) {
|
|||
StashIDs: models.NewRelatedStashIDs([]models.StashID{}),
|
||||
}
|
||||
|
||||
mockSceneReaderWriter := &mocks.SceneReaderWriter{}
|
||||
mockTagReaderWriter := &mocks.TagReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
mockTagReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p *models.Tag) bool {
|
||||
db.Tag.On("Create", testCtx, mock.MatchedBy(func(p *models.Tag) bool {
|
||||
return p.Name == validName
|
||||
})).Run(func(args mock.Arguments) {
|
||||
t := args.Get(1).(*models.Tag)
|
||||
t.ID = validStoredIDInt
|
||||
}).Return(nil)
|
||||
mockTagReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p *models.Tag) bool {
|
||||
db.Tag.On("Create", testCtx, mock.MatchedBy(func(p *models.Tag) bool {
|
||||
return p.Name == invalidName
|
||||
})).Return(errors.New("error creating tag"))
|
||||
|
||||
tr := sceneRelationships{
|
||||
sceneReader: mockSceneReaderWriter,
|
||||
tagCreator: mockTagReaderWriter,
|
||||
sceneReader: db.Scene,
|
||||
tagCreator: db.Tag,
|
||||
fieldOptions: make(map[string]*FieldOptions),
|
||||
}
|
||||
|
||||
|
|
@ -552,10 +554,10 @@ func Test_sceneRelationships_stashIDs(t *testing.T) {
|
|||
}),
|
||||
}
|
||||
|
||||
mockSceneReaderWriter := &mocks.SceneReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
tr := sceneRelationships{
|
||||
sceneReader: mockSceneReaderWriter,
|
||||
sceneReader: db.Scene,
|
||||
fieldOptions: make(map[string]*FieldOptions),
|
||||
}
|
||||
|
||||
|
|
@ -706,12 +708,13 @@ func Test_sceneRelationships_cover(t *testing.T) {
|
|||
newDataEncoded := base64Prefix + utils.GetBase64StringFromData(newData)
|
||||
invalidData := newDataEncoded + "!!!"
|
||||
|
||||
mockSceneReaderWriter := &mocks.SceneReaderWriter{}
|
||||
mockSceneReaderWriter.On("GetCover", testCtx, sceneID).Return(existingData, nil)
|
||||
mockSceneReaderWriter.On("GetCover", testCtx, errSceneID).Return(nil, errors.New("error getting cover"))
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
db.Scene.On("GetCover", testCtx, sceneID).Return(existingData, nil)
|
||||
db.Scene.On("GetCover", testCtx, errSceneID).Return(nil, errors.New("error getting cover"))
|
||||
|
||||
tr := sceneRelationships{
|
||||
sceneReader: mockSceneReaderWriter,
|
||||
sceneReader: db.Scene,
|
||||
fieldOptions: make(map[string]*FieldOptions),
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -19,18 +19,19 @@ func Test_createMissingStudio(t *testing.T) {
|
|||
invalidName := "invalidName"
|
||||
createdID := 1
|
||||
|
||||
mockStudioReaderWriter := &mocks.StudioReaderWriter{}
|
||||
mockStudioReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p *models.Studio) bool {
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
db.Studio.On("Create", testCtx, mock.MatchedBy(func(p *models.Studio) bool {
|
||||
return p.Name == validName
|
||||
})).Run(func(args mock.Arguments) {
|
||||
s := args.Get(1).(*models.Studio)
|
||||
s.ID = createdID
|
||||
}).Return(nil)
|
||||
mockStudioReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p *models.Studio) bool {
|
||||
db.Studio.On("Create", testCtx, mock.MatchedBy(func(p *models.Studio) bool {
|
||||
return p.Name == invalidName
|
||||
})).Return(errors.New("error creating studio"))
|
||||
|
||||
mockStudioReaderWriter.On("UpdatePartial", testCtx, models.StudioPartial{
|
||||
db.Studio.On("UpdatePartial", testCtx, models.StudioPartial{
|
||||
ID: createdID,
|
||||
StashIDs: &models.UpdateStashIDs{
|
||||
StashIDs: []models.StashID{
|
||||
|
|
@ -42,7 +43,7 @@ func Test_createMissingStudio(t *testing.T) {
|
|||
Mode: models.RelationshipUpdateModeSet,
|
||||
},
|
||||
}).Return(nil, errors.New("error updating stash ids"))
|
||||
mockStudioReaderWriter.On("UpdatePartial", testCtx, models.StudioPartial{
|
||||
db.Studio.On("UpdatePartial", testCtx, models.StudioPartial{
|
||||
ID: createdID,
|
||||
StashIDs: &models.UpdateStashIDs{
|
||||
StashIDs: []models.StashID{
|
||||
|
|
@ -106,7 +107,7 @@ func Test_createMissingStudio(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := createMissingStudio(testCtx, tt.args.endpoint, mockStudioReaderWriter, tt.args.studio)
|
||||
got, err := createMissingStudio(testCtx, tt.args.endpoint, db.Studio, tt.args.studio)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("createMissingStudio() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
|
|
|
|||
|
|
@ -131,7 +131,7 @@ type Manager struct {
|
|||
DLNAService *dlna.Service
|
||||
|
||||
Database *sqlite.Database
|
||||
Repository Repository
|
||||
Repository models.Repository
|
||||
|
||||
SceneService SceneService
|
||||
ImageService ImageService
|
||||
|
|
@ -174,6 +174,7 @@ func initialize() error {
|
|||
initProfiling(cfg.GetCPUProfilePath())
|
||||
|
||||
db := sqlite.NewDatabase()
|
||||
repo := db.Repository()
|
||||
|
||||
// start with empty paths
|
||||
emptyPaths := paths.Paths{}
|
||||
|
|
@ -186,49 +187,43 @@ func initialize() error {
|
|||
PluginCache: plugin.NewCache(cfg),
|
||||
|
||||
Database: db,
|
||||
Repository: sqliteRepository(db),
|
||||
Repository: repo,
|
||||
Paths: &emptyPaths,
|
||||
|
||||
scanSubs: &subscriptionManager{},
|
||||
}
|
||||
|
||||
instance.SceneService = &scene.Service{
|
||||
File: db.File,
|
||||
Repository: db.Scene,
|
||||
MarkerRepository: db.SceneMarker,
|
||||
File: repo.File,
|
||||
Repository: repo.Scene,
|
||||
MarkerRepository: repo.SceneMarker,
|
||||
PluginCache: instance.PluginCache,
|
||||
Paths: instance.Paths,
|
||||
Config: cfg,
|
||||
}
|
||||
|
||||
instance.ImageService = &image.Service{
|
||||
File: db.File,
|
||||
Repository: db.Image,
|
||||
File: repo.File,
|
||||
Repository: repo.Image,
|
||||
}
|
||||
|
||||
instance.GalleryService = &gallery.Service{
|
||||
Repository: db.Gallery,
|
||||
ImageFinder: db.Image,
|
||||
Repository: repo.Gallery,
|
||||
ImageFinder: repo.Image,
|
||||
ImageService: instance.ImageService,
|
||||
File: db.File,
|
||||
Folder: db.Folder,
|
||||
File: repo.File,
|
||||
Folder: repo.Folder,
|
||||
}
|
||||
|
||||
instance.JobManager = initJobManager()
|
||||
|
||||
sceneServer := SceneServer{
|
||||
TxnManager: instance.Repository,
|
||||
SceneCoverGetter: instance.Repository.Scene,
|
||||
TxnManager: repo.TxnManager,
|
||||
SceneCoverGetter: repo.Scene,
|
||||
}
|
||||
|
||||
instance.DLNAService = dlna.NewService(instance.Repository, dlna.Repository{
|
||||
SceneFinder: instance.Repository.Scene,
|
||||
FileGetter: instance.Repository.File,
|
||||
StudioFinder: instance.Repository.Studio,
|
||||
TagFinder: instance.Repository.Tag,
|
||||
PerformerFinder: instance.Repository.Performer,
|
||||
MovieFinder: instance.Repository.Movie,
|
||||
}, instance.Config, &sceneServer)
|
||||
dlnaRepository := dlna.NewRepository(repo)
|
||||
instance.DLNAService = dlna.NewService(dlnaRepository, cfg, &sceneServer)
|
||||
|
||||
if !cfg.IsNewSystem() {
|
||||
logger.Infof("using config file: %s", cfg.GetConfigFile())
|
||||
|
|
@ -268,8 +263,8 @@ func initialize() error {
|
|||
logger.Warnf("could not initialize FFMPEG subsystem: %v", err)
|
||||
}
|
||||
|
||||
instance.Scanner = makeScanner(db, instance.PluginCache)
|
||||
instance.Cleaner = makeCleaner(db, instance.PluginCache)
|
||||
instance.Scanner = makeScanner(repo, instance.PluginCache)
|
||||
instance.Cleaner = makeCleaner(repo, instance.PluginCache)
|
||||
|
||||
// if DLNA is enabled, start it now
|
||||
if instance.Config.GetDLNADefaultEnabled() {
|
||||
|
|
@ -293,14 +288,9 @@ func galleryFileFilter(ctx context.Context, f models.File) bool {
|
|||
return isZip(f.Base().Basename)
|
||||
}
|
||||
|
||||
func makeScanner(db *sqlite.Database, pluginCache *plugin.Cache) *file.Scanner {
|
||||
func makeScanner(repo models.Repository, pluginCache *plugin.Cache) *file.Scanner {
|
||||
return &file.Scanner{
|
||||
Repository: file.Repository{
|
||||
Manager: db,
|
||||
DatabaseProvider: db,
|
||||
FileStore: db.File,
|
||||
FolderStore: db.Folder,
|
||||
},
|
||||
Repository: file.NewRepository(repo),
|
||||
FileDecorators: []file.Decorator{
|
||||
&file.FilteredDecorator{
|
||||
Decorator: &video.Decorator{
|
||||
|
|
@ -320,15 +310,10 @@ func makeScanner(db *sqlite.Database, pluginCache *plugin.Cache) *file.Scanner {
|
|||
}
|
||||
}
|
||||
|
||||
func makeCleaner(db *sqlite.Database, pluginCache *plugin.Cache) *file.Cleaner {
|
||||
func makeCleaner(repo models.Repository, pluginCache *plugin.Cache) *file.Cleaner {
|
||||
return &file.Cleaner{
|
||||
FS: &file.OsFS{},
|
||||
Repository: file.Repository{
|
||||
Manager: db,
|
||||
DatabaseProvider: db,
|
||||
FileStore: db.File,
|
||||
FolderStore: db.Folder,
|
||||
},
|
||||
FS: &file.OsFS{},
|
||||
Repository: file.NewRepository(repo),
|
||||
Handlers: []file.CleanHandler{
|
||||
&cleanHandler{},
|
||||
},
|
||||
|
|
@ -523,14 +508,8 @@ func writeStashIcon() {
|
|||
|
||||
// initScraperCache initializes a new scraper cache and returns it.
|
||||
func (s *Manager) initScraperCache() *scraper.Cache {
|
||||
ret, err := scraper.NewCache(config.GetInstance(), s.Repository, scraper.Repository{
|
||||
SceneFinder: s.Repository.Scene,
|
||||
GalleryFinder: s.Repository.Gallery,
|
||||
TagFinder: s.Repository.Tag,
|
||||
PerformerFinder: s.Repository.Performer,
|
||||
MovieFinder: s.Repository.Movie,
|
||||
StudioFinder: s.Repository.Studio,
|
||||
})
|
||||
scraperRepository := scraper.NewRepository(s.Repository)
|
||||
ret, err := scraper.NewCache(s.Config, scraperRepository)
|
||||
|
||||
if err != nil {
|
||||
logger.Errorf("Error reading scraper configs: %s", err.Error())
|
||||
|
|
@ -697,7 +676,7 @@ func (s *Manager) Setup(ctx context.Context, input SetupInput) error {
|
|||
return fmt.Errorf("error initializing FFMPEG subsystem: %v", err)
|
||||
}
|
||||
|
||||
instance.Scanner = makeScanner(instance.Database, instance.PluginCache)
|
||||
instance.Scanner = makeScanner(instance.Repository, instance.PluginCache)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -112,7 +112,8 @@ func (s *Manager) Import(ctx context.Context) (int, error) {
|
|||
|
||||
j := job.MakeJobExec(func(ctx context.Context, progress *job.Progress) {
|
||||
task := ImportTask{
|
||||
txnManager: s.Repository,
|
||||
repository: s.Repository,
|
||||
resetter: s.Database,
|
||||
BaseDir: metadataPath,
|
||||
Reset: true,
|
||||
DuplicateBehaviour: ImportDuplicateEnumFail,
|
||||
|
|
@ -136,7 +137,7 @@ func (s *Manager) Export(ctx context.Context) (int, error) {
|
|||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
task := ExportTask{
|
||||
txnManager: s.Repository,
|
||||
repository: s.Repository,
|
||||
full: true,
|
||||
fileNamingAlgorithm: config.GetVideoFileNamingAlgorithm(),
|
||||
}
|
||||
|
|
@ -167,7 +168,7 @@ func (s *Manager) Generate(ctx context.Context, input GenerateMetadataInput) (in
|
|||
}
|
||||
|
||||
j := &GenerateJob{
|
||||
txnManager: s.Repository,
|
||||
repository: s.Repository,
|
||||
input: input,
|
||||
}
|
||||
|
||||
|
|
@ -212,7 +213,7 @@ func (s *Manager) generateScreenshot(ctx context.Context, sceneId string, at *fl
|
|||
}
|
||||
|
||||
task := GenerateCoverTask{
|
||||
txnManager: s.Repository,
|
||||
repository: s.Repository,
|
||||
Scene: *scene,
|
||||
ScreenshotAt: at,
|
||||
Overwrite: true,
|
||||
|
|
@ -239,7 +240,7 @@ type AutoTagMetadataInput struct {
|
|||
|
||||
func (s *Manager) AutoTag(ctx context.Context, input AutoTagMetadataInput) int {
|
||||
j := autoTagJob{
|
||||
txnManager: s.Repository,
|
||||
repository: s.Repository,
|
||||
input: input,
|
||||
}
|
||||
|
||||
|
|
@ -255,7 +256,7 @@ type CleanMetadataInput struct {
|
|||
func (s *Manager) Clean(ctx context.Context, input CleanMetadataInput) int {
|
||||
j := cleanJob{
|
||||
cleaner: s.Cleaner,
|
||||
txnManager: s.Repository,
|
||||
repository: s.Repository,
|
||||
sceneService: s.SceneService,
|
||||
imageService: s.ImageService,
|
||||
input: input,
|
||||
|
|
|
|||
|
|
@ -6,59 +6,8 @@ import (
|
|||
"github.com/stashapp/stash/pkg/image"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/scene"
|
||||
"github.com/stashapp/stash/pkg/sqlite"
|
||||
"github.com/stashapp/stash/pkg/txn"
|
||||
)
|
||||
|
||||
type Repository struct {
|
||||
models.TxnManager
|
||||
|
||||
File models.FileReaderWriter
|
||||
Folder models.FolderReaderWriter
|
||||
Gallery models.GalleryReaderWriter
|
||||
GalleryChapter models.GalleryChapterReaderWriter
|
||||
Image models.ImageReaderWriter
|
||||
Movie models.MovieReaderWriter
|
||||
Performer models.PerformerReaderWriter
|
||||
Scene models.SceneReaderWriter
|
||||
SceneMarker models.SceneMarkerReaderWriter
|
||||
Studio models.StudioReaderWriter
|
||||
Tag models.TagReaderWriter
|
||||
SavedFilter models.SavedFilterReaderWriter
|
||||
}
|
||||
|
||||
func (r *Repository) WithTxn(ctx context.Context, fn txn.TxnFunc) error {
|
||||
return txn.WithTxn(ctx, r, fn)
|
||||
}
|
||||
|
||||
func (r *Repository) WithReadTxn(ctx context.Context, fn txn.TxnFunc) error {
|
||||
return txn.WithReadTxn(ctx, r, fn)
|
||||
}
|
||||
|
||||
func (r *Repository) WithDB(ctx context.Context, fn txn.TxnFunc) error {
|
||||
return txn.WithDatabase(ctx, r, fn)
|
||||
}
|
||||
|
||||
func sqliteRepository(d *sqlite.Database) Repository {
|
||||
txnRepo := d.TxnRepository()
|
||||
|
||||
return Repository{
|
||||
TxnManager: txnRepo,
|
||||
File: d.File,
|
||||
Folder: d.Folder,
|
||||
Gallery: d.Gallery,
|
||||
GalleryChapter: txnRepo.GalleryChapter,
|
||||
Image: d.Image,
|
||||
Movie: txnRepo.Movie,
|
||||
Performer: txnRepo.Performer,
|
||||
Scene: d.Scene,
|
||||
SceneMarker: txnRepo.SceneMarker,
|
||||
Studio: txnRepo.Studio,
|
||||
Tag: txnRepo.Tag,
|
||||
SavedFilter: txnRepo.SavedFilter,
|
||||
}
|
||||
}
|
||||
|
||||
type SceneService interface {
|
||||
Create(ctx context.Context, input *models.Scene, fileIDs []models.FileID, coverImage []byte) (*models.Scene, error)
|
||||
AssignFile(ctx context.Context, sceneID int, fileID models.FileID) error
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ package manager
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/stashapp/stash/internal/manager/config"
|
||||
|
|
@ -58,8 +57,6 @@ func (s *SceneServer) StreamSceneDirect(scene *models.Scene, w http.ResponseWrit
|
|||
}
|
||||
|
||||
func (s *SceneServer) ServeScreenshot(scene *models.Scene, w http.ResponseWriter, r *http.Request) {
|
||||
const defaultSceneImage = "scene/scene.svg"
|
||||
|
||||
var cover []byte
|
||||
readTxnErr := txn.WithReadTxn(r.Context(), s.TxnManager, func(ctx context.Context) error {
|
||||
cover, _ = s.SceneCoverGetter.GetCover(ctx, scene.ID)
|
||||
|
|
@ -92,10 +89,7 @@ func (s *SceneServer) ServeScreenshot(scene *models.Scene, w http.ResponseWriter
|
|||
}
|
||||
|
||||
// fallback to default cover if none found
|
||||
// should always be there
|
||||
f, _ := static.Scene.Open(defaultSceneImage)
|
||||
defer f.Close()
|
||||
cover, _ = io.ReadAll(f)
|
||||
cover = static.ReadAll(static.DefaultSceneImage)
|
||||
}
|
||||
|
||||
utils.ServeImage(w, r, cover)
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ import (
|
|||
)
|
||||
|
||||
type autoTagJob struct {
|
||||
txnManager Repository
|
||||
repository models.Repository
|
||||
input AutoTagMetadataInput
|
||||
|
||||
cache match.Cache
|
||||
|
|
@ -56,7 +56,7 @@ func (j *autoTagJob) autoTagFiles(ctx context.Context, progress *job.Progress, p
|
|||
studios: studios,
|
||||
tags: tags,
|
||||
progress: progress,
|
||||
txnManager: j.txnManager,
|
||||
repository: j.repository,
|
||||
cache: &j.cache,
|
||||
}
|
||||
|
||||
|
|
@ -73,8 +73,8 @@ func (j *autoTagJob) autoTagSpecific(ctx context.Context, progress *job.Progress
|
|||
studioCount := len(studioIds)
|
||||
tagCount := len(tagIds)
|
||||
|
||||
if err := j.txnManager.WithReadTxn(ctx, func(ctx context.Context) error {
|
||||
r := j.txnManager
|
||||
r := j.repository
|
||||
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
|
||||
performerQuery := r.Performer
|
||||
studioQuery := r.Studio
|
||||
tagQuery := r.Tag
|
||||
|
|
@ -123,16 +123,17 @@ func (j *autoTagJob) autoTagPerformers(ctx context.Context, progress *job.Progre
|
|||
return
|
||||
}
|
||||
|
||||
r := j.repository
|
||||
tagger := autotag.Tagger{
|
||||
TxnManager: j.txnManager,
|
||||
TxnManager: r.TxnManager,
|
||||
Cache: &j.cache,
|
||||
}
|
||||
|
||||
for _, performerId := range performerIds {
|
||||
var performers []*models.Performer
|
||||
|
||||
if err := j.txnManager.WithDB(ctx, func(ctx context.Context) error {
|
||||
performerQuery := j.txnManager.Performer
|
||||
if err := r.WithDB(ctx, func(ctx context.Context) error {
|
||||
performerQuery := r.Performer
|
||||
ignoreAutoTag := false
|
||||
perPage := -1
|
||||
|
||||
|
|
@ -161,7 +162,7 @@ func (j *autoTagJob) autoTagPerformers(ctx context.Context, progress *job.Progre
|
|||
return fmt.Errorf("performer with id %s not found", performerId)
|
||||
}
|
||||
|
||||
if err := performer.LoadAliases(ctx, j.txnManager.Performer); err != nil {
|
||||
if err := performer.LoadAliases(ctx, r.Performer); err != nil {
|
||||
return fmt.Errorf("loading aliases for performer %d: %w", performer.ID, err)
|
||||
}
|
||||
performers = append(performers, performer)
|
||||
|
|
@ -173,7 +174,6 @@ func (j *autoTagJob) autoTagPerformers(ctx context.Context, progress *job.Progre
|
|||
}
|
||||
|
||||
err := func() error {
|
||||
r := j.txnManager
|
||||
if err := tagger.PerformerScenes(ctx, performer, paths, r.Scene); err != nil {
|
||||
return fmt.Errorf("processing scenes: %w", err)
|
||||
}
|
||||
|
|
@ -215,9 +215,9 @@ func (j *autoTagJob) autoTagStudios(ctx context.Context, progress *job.Progress,
|
|||
return
|
||||
}
|
||||
|
||||
r := j.txnManager
|
||||
r := j.repository
|
||||
tagger := autotag.Tagger{
|
||||
TxnManager: j.txnManager,
|
||||
TxnManager: r.TxnManager,
|
||||
Cache: &j.cache,
|
||||
}
|
||||
|
||||
|
|
@ -308,15 +308,15 @@ func (j *autoTagJob) autoTagTags(ctx context.Context, progress *job.Progress, pa
|
|||
return
|
||||
}
|
||||
|
||||
r := j.txnManager
|
||||
r := j.repository
|
||||
tagger := autotag.Tagger{
|
||||
TxnManager: j.txnManager,
|
||||
TxnManager: r.TxnManager,
|
||||
Cache: &j.cache,
|
||||
}
|
||||
|
||||
for _, tagId := range tagIds {
|
||||
var tags []*models.Tag
|
||||
if err := j.txnManager.WithDB(ctx, func(ctx context.Context) error {
|
||||
if err := r.WithDB(ctx, func(ctx context.Context) error {
|
||||
tagQuery := r.Tag
|
||||
ignoreAutoTag := false
|
||||
perPage := -1
|
||||
|
|
@ -402,7 +402,7 @@ type autoTagFilesTask struct {
|
|||
tags bool
|
||||
|
||||
progress *job.Progress
|
||||
txnManager Repository
|
||||
repository models.Repository
|
||||
cache *match.Cache
|
||||
}
|
||||
|
||||
|
|
@ -482,7 +482,9 @@ func (t *autoTagFilesTask) makeGalleryFilter() *models.GalleryFilterType {
|
|||
return ret
|
||||
}
|
||||
|
||||
func (t *autoTagFilesTask) getCount(ctx context.Context, r Repository) (int, error) {
|
||||
func (t *autoTagFilesTask) getCount(ctx context.Context) (int, error) {
|
||||
r := t.repository
|
||||
|
||||
pp := 0
|
||||
findFilter := &models.FindFilterType{
|
||||
PerPage: &pp,
|
||||
|
|
@ -522,7 +524,7 @@ func (t *autoTagFilesTask) getCount(ctx context.Context, r Repository) (int, err
|
|||
return sceneCount + imageCount + galleryCount, nil
|
||||
}
|
||||
|
||||
func (t *autoTagFilesTask) processScenes(ctx context.Context, r Repository) {
|
||||
func (t *autoTagFilesTask) processScenes(ctx context.Context) {
|
||||
if job.IsCancelled(ctx) {
|
||||
return
|
||||
}
|
||||
|
|
@ -534,10 +536,12 @@ func (t *autoTagFilesTask) processScenes(ctx context.Context, r Repository) {
|
|||
findFilter := models.BatchFindFilter(batchSize)
|
||||
sceneFilter := t.makeSceneFilter()
|
||||
|
||||
r := t.repository
|
||||
|
||||
more := true
|
||||
for more {
|
||||
var scenes []*models.Scene
|
||||
if err := t.txnManager.WithReadTxn(ctx, func(ctx context.Context) error {
|
||||
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
|
||||
var err error
|
||||
scenes, err = scene.Query(ctx, r.Scene, sceneFilter, findFilter)
|
||||
return err
|
||||
|
|
@ -555,7 +559,7 @@ func (t *autoTagFilesTask) processScenes(ctx context.Context, r Repository) {
|
|||
}
|
||||
|
||||
tt := autoTagSceneTask{
|
||||
txnManager: t.txnManager,
|
||||
repository: r,
|
||||
scene: ss,
|
||||
performers: t.performers,
|
||||
studios: t.studios,
|
||||
|
|
@ -583,7 +587,7 @@ func (t *autoTagFilesTask) processScenes(ctx context.Context, r Repository) {
|
|||
}
|
||||
}
|
||||
|
||||
func (t *autoTagFilesTask) processImages(ctx context.Context, r Repository) {
|
||||
func (t *autoTagFilesTask) processImages(ctx context.Context) {
|
||||
if job.IsCancelled(ctx) {
|
||||
return
|
||||
}
|
||||
|
|
@ -595,10 +599,12 @@ func (t *autoTagFilesTask) processImages(ctx context.Context, r Repository) {
|
|||
findFilter := models.BatchFindFilter(batchSize)
|
||||
imageFilter := t.makeImageFilter()
|
||||
|
||||
r := t.repository
|
||||
|
||||
more := true
|
||||
for more {
|
||||
var images []*models.Image
|
||||
if err := t.txnManager.WithReadTxn(ctx, func(ctx context.Context) error {
|
||||
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
|
||||
var err error
|
||||
images, err = image.Query(ctx, r.Image, imageFilter, findFilter)
|
||||
return err
|
||||
|
|
@ -616,7 +622,7 @@ func (t *autoTagFilesTask) processImages(ctx context.Context, r Repository) {
|
|||
}
|
||||
|
||||
tt := autoTagImageTask{
|
||||
txnManager: t.txnManager,
|
||||
repository: t.repository,
|
||||
image: ss,
|
||||
performers: t.performers,
|
||||
studios: t.studios,
|
||||
|
|
@ -644,7 +650,7 @@ func (t *autoTagFilesTask) processImages(ctx context.Context, r Repository) {
|
|||
}
|
||||
}
|
||||
|
||||
func (t *autoTagFilesTask) processGalleries(ctx context.Context, r Repository) {
|
||||
func (t *autoTagFilesTask) processGalleries(ctx context.Context) {
|
||||
if job.IsCancelled(ctx) {
|
||||
return
|
||||
}
|
||||
|
|
@ -656,10 +662,12 @@ func (t *autoTagFilesTask) processGalleries(ctx context.Context, r Repository) {
|
|||
findFilter := models.BatchFindFilter(batchSize)
|
||||
galleryFilter := t.makeGalleryFilter()
|
||||
|
||||
r := t.repository
|
||||
|
||||
more := true
|
||||
for more {
|
||||
var galleries []*models.Gallery
|
||||
if err := t.txnManager.WithReadTxn(ctx, func(ctx context.Context) error {
|
||||
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
|
||||
var err error
|
||||
galleries, _, err = r.Gallery.Query(ctx, galleryFilter, findFilter)
|
||||
return err
|
||||
|
|
@ -677,7 +685,7 @@ func (t *autoTagFilesTask) processGalleries(ctx context.Context, r Repository) {
|
|||
}
|
||||
|
||||
tt := autoTagGalleryTask{
|
||||
txnManager: t.txnManager,
|
||||
repository: t.repository,
|
||||
gallery: ss,
|
||||
performers: t.performers,
|
||||
studios: t.studios,
|
||||
|
|
@ -706,9 +714,8 @@ func (t *autoTagFilesTask) processGalleries(ctx context.Context, r Repository) {
|
|||
}
|
||||
|
||||
func (t *autoTagFilesTask) process(ctx context.Context) {
|
||||
r := t.txnManager
|
||||
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
|
||||
total, err := t.getCount(ctx, t.txnManager)
|
||||
if err := t.repository.WithReadTxn(ctx, func(ctx context.Context) error {
|
||||
total, err := t.getCount(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -724,13 +731,13 @@ func (t *autoTagFilesTask) process(ctx context.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
t.processScenes(ctx, r)
|
||||
t.processImages(ctx, r)
|
||||
t.processGalleries(ctx, r)
|
||||
t.processScenes(ctx)
|
||||
t.processImages(ctx)
|
||||
t.processGalleries(ctx)
|
||||
}
|
||||
|
||||
type autoTagSceneTask struct {
|
||||
txnManager Repository
|
||||
repository models.Repository
|
||||
scene *models.Scene
|
||||
|
||||
performers bool
|
||||
|
|
@ -742,8 +749,8 @@ type autoTagSceneTask struct {
|
|||
|
||||
func (t *autoTagSceneTask) Start(ctx context.Context, wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
r := t.txnManager
|
||||
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
|
||||
r := t.repository
|
||||
if err := r.WithTxn(ctx, func(ctx context.Context) error {
|
||||
if t.scene.Path == "" {
|
||||
// nothing to do
|
||||
return nil
|
||||
|
|
@ -774,7 +781,7 @@ func (t *autoTagSceneTask) Start(ctx context.Context, wg *sync.WaitGroup) {
|
|||
}
|
||||
|
||||
type autoTagImageTask struct {
|
||||
txnManager Repository
|
||||
repository models.Repository
|
||||
image *models.Image
|
||||
|
||||
performers bool
|
||||
|
|
@ -786,8 +793,8 @@ type autoTagImageTask struct {
|
|||
|
||||
func (t *autoTagImageTask) Start(ctx context.Context, wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
r := t.txnManager
|
||||
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
|
||||
r := t.repository
|
||||
if err := r.WithTxn(ctx, func(ctx context.Context) error {
|
||||
if t.performers {
|
||||
if err := autotag.ImagePerformers(ctx, t.image, r.Image, r.Performer, t.cache); err != nil {
|
||||
return fmt.Errorf("tagging image performers for %s: %v", t.image.DisplayName(), err)
|
||||
|
|
@ -813,7 +820,7 @@ func (t *autoTagImageTask) Start(ctx context.Context, wg *sync.WaitGroup) {
|
|||
}
|
||||
|
||||
type autoTagGalleryTask struct {
|
||||
txnManager Repository
|
||||
repository models.Repository
|
||||
gallery *models.Gallery
|
||||
|
||||
performers bool
|
||||
|
|
@ -825,8 +832,8 @@ type autoTagGalleryTask struct {
|
|||
|
||||
func (t *autoTagGalleryTask) Start(ctx context.Context, wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
r := t.txnManager
|
||||
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
|
||||
r := t.repository
|
||||
if err := r.WithTxn(ctx, func(ctx context.Context) error {
|
||||
if t.performers {
|
||||
if err := autotag.GalleryPerformers(ctx, t.gallery, r.Gallery, r.Performer, t.cache); err != nil {
|
||||
return fmt.Errorf("tagging gallery performers for %s: %v", t.gallery.DisplayName(), err)
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ import (
|
|||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/plugin"
|
||||
"github.com/stashapp/stash/pkg/scene"
|
||||
"github.com/stashapp/stash/pkg/txn"
|
||||
)
|
||||
|
||||
type cleaner interface {
|
||||
|
|
@ -25,7 +24,7 @@ type cleaner interface {
|
|||
|
||||
type cleanJob struct {
|
||||
cleaner cleaner
|
||||
txnManager Repository
|
||||
repository models.Repository
|
||||
input CleanMetadataInput
|
||||
sceneService SceneService
|
||||
imageService ImageService
|
||||
|
|
@ -61,10 +60,11 @@ func (j *cleanJob) cleanEmptyGalleries(ctx context.Context) {
|
|||
const batchSize = 1000
|
||||
var toClean []int
|
||||
findFilter := models.BatchFindFilter(batchSize)
|
||||
if err := txn.WithTxn(ctx, j.txnManager, func(ctx context.Context) error {
|
||||
r := j.repository
|
||||
if err := r.WithTxn(ctx, func(ctx context.Context) error {
|
||||
found := true
|
||||
for found {
|
||||
emptyGalleries, _, err := j.txnManager.Gallery.Query(ctx, &models.GalleryFilterType{
|
||||
emptyGalleries, _, err := r.Gallery.Query(ctx, &models.GalleryFilterType{
|
||||
ImageCount: &models.IntCriterionInput{
|
||||
Value: 0,
|
||||
Modifier: models.CriterionModifierEquals,
|
||||
|
|
@ -108,9 +108,10 @@ func (j *cleanJob) cleanEmptyGalleries(ctx context.Context) {
|
|||
|
||||
func (j *cleanJob) deleteGallery(ctx context.Context, id int) {
|
||||
pluginCache := GetInstance().PluginCache
|
||||
qb := j.txnManager.Gallery
|
||||
|
||||
if err := txn.WithTxn(ctx, j.txnManager, func(ctx context.Context) error {
|
||||
r := j.repository
|
||||
if err := r.WithTxn(ctx, func(ctx context.Context) error {
|
||||
qb := r.Gallery
|
||||
g, err := qb.Find(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -120,7 +121,7 @@ func (j *cleanJob) deleteGallery(ctx context.Context, id int) {
|
|||
return fmt.Errorf("gallery with id %d not found", id)
|
||||
}
|
||||
|
||||
if err := g.LoadPrimaryFile(ctx, j.txnManager.File); err != nil {
|
||||
if err := g.LoadPrimaryFile(ctx, r.File); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
@ -253,9 +254,7 @@ func (f *cleanFilter) shouldCleanImage(path string, stash *config.StashConfig) b
|
|||
return false
|
||||
}
|
||||
|
||||
type cleanHandler struct {
|
||||
PluginCache *plugin.Cache
|
||||
}
|
||||
type cleanHandler struct{}
|
||||
|
||||
func (h *cleanHandler) HandleFile(ctx context.Context, fileDeleter *file.Deleter, fileID models.FileID) error {
|
||||
if err := h.handleRelatedScenes(ctx, fileDeleter, fileID); err != nil {
|
||||
|
|
@ -277,7 +276,7 @@ func (h *cleanHandler) HandleFolder(ctx context.Context, fileDeleter *file.Delet
|
|||
|
||||
func (h *cleanHandler) handleRelatedScenes(ctx context.Context, fileDeleter *file.Deleter, fileID models.FileID) error {
|
||||
mgr := GetInstance()
|
||||
sceneQB := mgr.Database.Scene
|
||||
sceneQB := mgr.Repository.Scene
|
||||
scenes, err := sceneQB.FindByFileID(ctx, fileID)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -303,12 +302,9 @@ func (h *cleanHandler) handleRelatedScenes(ctx context.Context, fileDeleter *fil
|
|||
return err
|
||||
}
|
||||
|
||||
checksum := scene.Checksum
|
||||
oshash := scene.OSHash
|
||||
|
||||
mgr.PluginCache.RegisterPostHooks(ctx, scene.ID, plugin.SceneDestroyPost, plugin.SceneDestroyInput{
|
||||
Checksum: checksum,
|
||||
OSHash: oshash,
|
||||
Checksum: scene.Checksum,
|
||||
OSHash: scene.OSHash,
|
||||
Path: scene.Path,
|
||||
}, nil)
|
||||
} else {
|
||||
|
|
@ -335,7 +331,7 @@ func (h *cleanHandler) handleRelatedScenes(ctx context.Context, fileDeleter *fil
|
|||
|
||||
func (h *cleanHandler) handleRelatedGalleries(ctx context.Context, fileID models.FileID) error {
|
||||
mgr := GetInstance()
|
||||
qb := mgr.Database.Gallery
|
||||
qb := mgr.Repository.Gallery
|
||||
galleries, err := qb.FindByFileID(ctx, fileID)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -381,7 +377,7 @@ func (h *cleanHandler) handleRelatedGalleries(ctx context.Context, fileID models
|
|||
|
||||
func (h *cleanHandler) deleteRelatedFolderGalleries(ctx context.Context, folderID models.FolderID) error {
|
||||
mgr := GetInstance()
|
||||
qb := mgr.Database.Gallery
|
||||
qb := mgr.Repository.Gallery
|
||||
galleries, err := qb.FindByFolderID(ctx, folderID)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -405,7 +401,7 @@ func (h *cleanHandler) deleteRelatedFolderGalleries(ctx context.Context, folderI
|
|||
|
||||
func (h *cleanHandler) handleRelatedImages(ctx context.Context, fileDeleter *file.Deleter, fileID models.FileID) error {
|
||||
mgr := GetInstance()
|
||||
imageQB := mgr.Database.Image
|
||||
imageQB := mgr.Repository.Image
|
||||
images, err := imageQB.FindByFileID(ctx, fileID)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -413,7 +409,7 @@ func (h *cleanHandler) handleRelatedImages(ctx context.Context, fileDeleter *fil
|
|||
|
||||
imageFileDeleter := &image.FileDeleter{
|
||||
Deleter: fileDeleter,
|
||||
Paths: GetInstance().Paths,
|
||||
Paths: mgr.Paths,
|
||||
}
|
||||
|
||||
for _, i := range images {
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ import (
|
|||
)
|
||||
|
||||
type ExportTask struct {
|
||||
txnManager Repository
|
||||
repository models.Repository
|
||||
full bool
|
||||
|
||||
baseDir string
|
||||
|
|
@ -98,7 +98,7 @@ func CreateExportTask(a models.HashAlgorithm, input ExportObjectsInput) *ExportT
|
|||
}
|
||||
|
||||
return &ExportTask{
|
||||
txnManager: GetInstance().Repository,
|
||||
repository: GetInstance().Repository,
|
||||
fileNamingAlgorithm: a,
|
||||
scenes: newExportSpec(input.Scenes),
|
||||
images: newExportSpec(input.Images),
|
||||
|
|
@ -148,29 +148,27 @@ func (t *ExportTask) Start(ctx context.Context, wg *sync.WaitGroup) {
|
|||
paths.EmptyJSONDirs(t.baseDir)
|
||||
paths.EnsureJSONDirs(t.baseDir)
|
||||
|
||||
txnErr := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
|
||||
r := t.txnManager
|
||||
|
||||
txnErr := t.repository.WithTxn(ctx, func(ctx context.Context) error {
|
||||
// include movie scenes and gallery images
|
||||
if !t.full {
|
||||
// only include movie scenes if includeDependencies is also set
|
||||
if !t.scenes.all && t.includeDependencies {
|
||||
t.populateMovieScenes(ctx, r)
|
||||
t.populateMovieScenes(ctx)
|
||||
}
|
||||
|
||||
// always export gallery images
|
||||
if !t.images.all {
|
||||
t.populateGalleryImages(ctx, r)
|
||||
t.populateGalleryImages(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
t.ExportScenes(ctx, workerCount, r)
|
||||
t.ExportImages(ctx, workerCount, r)
|
||||
t.ExportGalleries(ctx, workerCount, r)
|
||||
t.ExportMovies(ctx, workerCount, r)
|
||||
t.ExportPerformers(ctx, workerCount, r)
|
||||
t.ExportStudios(ctx, workerCount, r)
|
||||
t.ExportTags(ctx, workerCount, r)
|
||||
t.ExportScenes(ctx, workerCount)
|
||||
t.ExportImages(ctx, workerCount)
|
||||
t.ExportGalleries(ctx, workerCount)
|
||||
t.ExportMovies(ctx, workerCount)
|
||||
t.ExportPerformers(ctx, workerCount)
|
||||
t.ExportStudios(ctx, workerCount)
|
||||
t.ExportTags(ctx, workerCount)
|
||||
|
||||
return nil
|
||||
})
|
||||
|
|
@ -277,9 +275,10 @@ func (t *ExportTask) zipFile(fn, outDir string, z *zip.Writer) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (t *ExportTask) populateMovieScenes(ctx context.Context, repo Repository) {
|
||||
reader := repo.Movie
|
||||
sceneReader := repo.Scene
|
||||
func (t *ExportTask) populateMovieScenes(ctx context.Context) {
|
||||
r := t.repository
|
||||
reader := r.Movie
|
||||
sceneReader := r.Scene
|
||||
|
||||
var movies []*models.Movie
|
||||
var err error
|
||||
|
|
@ -307,9 +306,10 @@ func (t *ExportTask) populateMovieScenes(ctx context.Context, repo Repository) {
|
|||
}
|
||||
}
|
||||
|
||||
func (t *ExportTask) populateGalleryImages(ctx context.Context, repo Repository) {
|
||||
reader := repo.Gallery
|
||||
imageReader := repo.Image
|
||||
func (t *ExportTask) populateGalleryImages(ctx context.Context) {
|
||||
r := t.repository
|
||||
reader := r.Gallery
|
||||
imageReader := r.Image
|
||||
|
||||
var galleries []*models.Gallery
|
||||
var err error
|
||||
|
|
@ -342,10 +342,10 @@ func (t *ExportTask) populateGalleryImages(ctx context.Context, repo Repository)
|
|||
}
|
||||
}
|
||||
|
||||
func (t *ExportTask) ExportScenes(ctx context.Context, workers int, repo Repository) {
|
||||
func (t *ExportTask) ExportScenes(ctx context.Context, workers int) {
|
||||
var scenesWg sync.WaitGroup
|
||||
|
||||
sceneReader := repo.Scene
|
||||
sceneReader := t.repository.Scene
|
||||
|
||||
var scenes []*models.Scene
|
||||
var err error
|
||||
|
|
@ -367,7 +367,7 @@ func (t *ExportTask) ExportScenes(ctx context.Context, workers int, repo Reposit
|
|||
|
||||
for w := 0; w < workers; w++ { // create export Scene workers
|
||||
scenesWg.Add(1)
|
||||
go exportScene(ctx, &scenesWg, jobCh, repo, t)
|
||||
go t.exportScene(ctx, &scenesWg, jobCh)
|
||||
}
|
||||
|
||||
for i, scene := range scenes {
|
||||
|
|
@ -385,7 +385,7 @@ func (t *ExportTask) ExportScenes(ctx context.Context, workers int, repo Reposit
|
|||
logger.Infof("[scenes] export complete in %s. %d workers used.", time.Since(startTime), workers)
|
||||
}
|
||||
|
||||
func exportFile(f models.File, t *ExportTask) {
|
||||
func (t *ExportTask) exportFile(f models.File) {
|
||||
newFileJSON := fileToJSON(f)
|
||||
|
||||
fn := newFileJSON.Filename()
|
||||
|
|
@ -449,7 +449,7 @@ func fileToJSON(f models.File) jsonschema.DirEntry {
|
|||
return &base
|
||||
}
|
||||
|
||||
func exportFolder(f models.Folder, t *ExportTask) {
|
||||
func (t *ExportTask) exportFolder(f models.Folder) {
|
||||
newFileJSON := folderToJSON(f)
|
||||
|
||||
fn := newFileJSON.Filename()
|
||||
|
|
@ -475,15 +475,17 @@ func folderToJSON(f models.Folder) jsonschema.DirEntry {
|
|||
return &base
|
||||
}
|
||||
|
||||
func exportScene(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Scene, repo Repository, t *ExportTask) {
|
||||
func (t *ExportTask) exportScene(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Scene) {
|
||||
defer wg.Done()
|
||||
sceneReader := repo.Scene
|
||||
studioReader := repo.Studio
|
||||
movieReader := repo.Movie
|
||||
galleryReader := repo.Gallery
|
||||
performerReader := repo.Performer
|
||||
tagReader := repo.Tag
|
||||
sceneMarkerReader := repo.SceneMarker
|
||||
|
||||
r := t.repository
|
||||
sceneReader := r.Scene
|
||||
studioReader := r.Studio
|
||||
movieReader := r.Movie
|
||||
galleryReader := r.Gallery
|
||||
performerReader := r.Performer
|
||||
tagReader := r.Tag
|
||||
sceneMarkerReader := r.SceneMarker
|
||||
|
||||
for s := range jobChan {
|
||||
sceneHash := s.GetHash(t.fileNamingAlgorithm)
|
||||
|
|
@ -500,7 +502,7 @@ func exportScene(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models
|
|||
|
||||
// export files
|
||||
for _, f := range s.Files.List() {
|
||||
exportFile(f, t)
|
||||
t.exportFile(f)
|
||||
}
|
||||
|
||||
newSceneJSON.Studio, err = scene.GetStudioName(ctx, studioReader, s)
|
||||
|
|
@ -589,10 +591,11 @@ func exportScene(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models
|
|||
}
|
||||
}
|
||||
|
||||
func (t *ExportTask) ExportImages(ctx context.Context, workers int, repo Repository) {
|
||||
func (t *ExportTask) ExportImages(ctx context.Context, workers int) {
|
||||
var imagesWg sync.WaitGroup
|
||||
|
||||
imageReader := repo.Image
|
||||
r := t.repository
|
||||
imageReader := r.Image
|
||||
|
||||
var images []*models.Image
|
||||
var err error
|
||||
|
|
@ -614,7 +617,7 @@ func (t *ExportTask) ExportImages(ctx context.Context, workers int, repo Reposit
|
|||
|
||||
for w := 0; w < workers; w++ { // create export Image workers
|
||||
imagesWg.Add(1)
|
||||
go exportImage(ctx, &imagesWg, jobCh, repo, t)
|
||||
go t.exportImage(ctx, &imagesWg, jobCh)
|
||||
}
|
||||
|
||||
for i, image := range images {
|
||||
|
|
@ -632,22 +635,24 @@ func (t *ExportTask) ExportImages(ctx context.Context, workers int, repo Reposit
|
|||
logger.Infof("[images] export complete in %s. %d workers used.", time.Since(startTime), workers)
|
||||
}
|
||||
|
||||
func exportImage(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Image, repo Repository, t *ExportTask) {
|
||||
func (t *ExportTask) exportImage(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Image) {
|
||||
defer wg.Done()
|
||||
studioReader := repo.Studio
|
||||
galleryReader := repo.Gallery
|
||||
performerReader := repo.Performer
|
||||
tagReader := repo.Tag
|
||||
|
||||
r := t.repository
|
||||
studioReader := r.Studio
|
||||
galleryReader := r.Gallery
|
||||
performerReader := r.Performer
|
||||
tagReader := r.Tag
|
||||
|
||||
for s := range jobChan {
|
||||
imageHash := s.Checksum
|
||||
|
||||
if err := s.LoadFiles(ctx, repo.Image); err != nil {
|
||||
if err := s.LoadFiles(ctx, r.Image); err != nil {
|
||||
logger.Errorf("[images] <%s> error getting image files: %s", imageHash, err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
if err := s.LoadURLs(ctx, repo.Image); err != nil {
|
||||
if err := s.LoadURLs(ctx, r.Image); err != nil {
|
||||
logger.Errorf("[images] <%s> error getting image urls: %s", imageHash, err.Error())
|
||||
continue
|
||||
}
|
||||
|
|
@ -656,7 +661,7 @@ func exportImage(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models
|
|||
|
||||
// export files
|
||||
for _, f := range s.Files.List() {
|
||||
exportFile(f, t)
|
||||
t.exportFile(f)
|
||||
}
|
||||
|
||||
var err error
|
||||
|
|
@ -715,10 +720,10 @@ func exportImage(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models
|
|||
}
|
||||
}
|
||||
|
||||
func (t *ExportTask) ExportGalleries(ctx context.Context, workers int, repo Repository) {
|
||||
func (t *ExportTask) ExportGalleries(ctx context.Context, workers int) {
|
||||
var galleriesWg sync.WaitGroup
|
||||
|
||||
reader := repo.Gallery
|
||||
reader := t.repository.Gallery
|
||||
|
||||
var galleries []*models.Gallery
|
||||
var err error
|
||||
|
|
@ -740,7 +745,7 @@ func (t *ExportTask) ExportGalleries(ctx context.Context, workers int, repo Repo
|
|||
|
||||
for w := 0; w < workers; w++ { // create export Scene workers
|
||||
galleriesWg.Add(1)
|
||||
go exportGallery(ctx, &galleriesWg, jobCh, repo, t)
|
||||
go t.exportGallery(ctx, &galleriesWg, jobCh)
|
||||
}
|
||||
|
||||
for i, gallery := range galleries {
|
||||
|
|
@ -759,15 +764,17 @@ func (t *ExportTask) ExportGalleries(ctx context.Context, workers int, repo Repo
|
|||
logger.Infof("[galleries] export complete in %s. %d workers used.", time.Since(startTime), workers)
|
||||
}
|
||||
|
||||
func exportGallery(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Gallery, repo Repository, t *ExportTask) {
|
||||
func (t *ExportTask) exportGallery(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Gallery) {
|
||||
defer wg.Done()
|
||||
studioReader := repo.Studio
|
||||
performerReader := repo.Performer
|
||||
tagReader := repo.Tag
|
||||
galleryChapterReader := repo.GalleryChapter
|
||||
|
||||
r := t.repository
|
||||
studioReader := r.Studio
|
||||
performerReader := r.Performer
|
||||
tagReader := r.Tag
|
||||
galleryChapterReader := r.GalleryChapter
|
||||
|
||||
for g := range jobChan {
|
||||
if err := g.LoadFiles(ctx, repo.Gallery); err != nil {
|
||||
if err := g.LoadFiles(ctx, r.Gallery); err != nil {
|
||||
logger.Errorf("[galleries] <%s> failed to fetch files for gallery: %s", g.DisplayName(), err.Error())
|
||||
continue
|
||||
}
|
||||
|
|
@ -782,12 +789,12 @@ func exportGallery(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *mode
|
|||
|
||||
// export files
|
||||
for _, f := range g.Files.List() {
|
||||
exportFile(f, t)
|
||||
t.exportFile(f)
|
||||
}
|
||||
|
||||
// export folder if necessary
|
||||
if g.FolderID != nil {
|
||||
folder, err := repo.Folder.Find(ctx, *g.FolderID)
|
||||
folder, err := r.Folder.Find(ctx, *g.FolderID)
|
||||
if err != nil {
|
||||
logger.Errorf("[galleries] <%s> error getting gallery folder: %v", galleryHash, err)
|
||||
continue
|
||||
|
|
@ -798,7 +805,7 @@ func exportGallery(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *mode
|
|||
continue
|
||||
}
|
||||
|
||||
exportFolder(*folder, t)
|
||||
t.exportFolder(*folder)
|
||||
}
|
||||
|
||||
newGalleryJSON.Studio, err = gallery.GetStudioName(ctx, studioReader, g)
|
||||
|
|
@ -857,10 +864,10 @@ func exportGallery(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *mode
|
|||
}
|
||||
}
|
||||
|
||||
func (t *ExportTask) ExportPerformers(ctx context.Context, workers int, repo Repository) {
|
||||
func (t *ExportTask) ExportPerformers(ctx context.Context, workers int) {
|
||||
var performersWg sync.WaitGroup
|
||||
|
||||
reader := repo.Performer
|
||||
reader := t.repository.Performer
|
||||
var performers []*models.Performer
|
||||
var err error
|
||||
all := t.full || (t.performers != nil && t.performers.all)
|
||||
|
|
@ -880,7 +887,7 @@ func (t *ExportTask) ExportPerformers(ctx context.Context, workers int, repo Rep
|
|||
|
||||
for w := 0; w < workers; w++ { // create export Performer workers
|
||||
performersWg.Add(1)
|
||||
go t.exportPerformer(ctx, &performersWg, jobCh, repo)
|
||||
go t.exportPerformer(ctx, &performersWg, jobCh)
|
||||
}
|
||||
|
||||
for i, performer := range performers {
|
||||
|
|
@ -896,10 +903,11 @@ func (t *ExportTask) ExportPerformers(ctx context.Context, workers int, repo Rep
|
|||
logger.Infof("[performers] export complete in %s. %d workers used.", time.Since(startTime), workers)
|
||||
}
|
||||
|
||||
func (t *ExportTask) exportPerformer(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Performer, repo Repository) {
|
||||
func (t *ExportTask) exportPerformer(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Performer) {
|
||||
defer wg.Done()
|
||||
|
||||
performerReader := repo.Performer
|
||||
r := t.repository
|
||||
performerReader := r.Performer
|
||||
|
||||
for p := range jobChan {
|
||||
newPerformerJSON, err := performer.ToJSON(ctx, performerReader, p)
|
||||
|
|
@ -909,7 +917,7 @@ func (t *ExportTask) exportPerformer(ctx context.Context, wg *sync.WaitGroup, jo
|
|||
continue
|
||||
}
|
||||
|
||||
tags, err := repo.Tag.FindByPerformerID(ctx, p.ID)
|
||||
tags, err := r.Tag.FindByPerformerID(ctx, p.ID)
|
||||
if err != nil {
|
||||
logger.Errorf("[performers] <%s> error getting performer tags: %s", p.Name, err.Error())
|
||||
continue
|
||||
|
|
@ -929,10 +937,10 @@ func (t *ExportTask) exportPerformer(ctx context.Context, wg *sync.WaitGroup, jo
|
|||
}
|
||||
}
|
||||
|
||||
func (t *ExportTask) ExportStudios(ctx context.Context, workers int, repo Repository) {
|
||||
func (t *ExportTask) ExportStudios(ctx context.Context, workers int) {
|
||||
var studiosWg sync.WaitGroup
|
||||
|
||||
reader := repo.Studio
|
||||
reader := t.repository.Studio
|
||||
var studios []*models.Studio
|
||||
var err error
|
||||
all := t.full || (t.studios != nil && t.studios.all)
|
||||
|
|
@ -953,7 +961,7 @@ func (t *ExportTask) ExportStudios(ctx context.Context, workers int, repo Reposi
|
|||
|
||||
for w := 0; w < workers; w++ { // create export Studio workers
|
||||
studiosWg.Add(1)
|
||||
go t.exportStudio(ctx, &studiosWg, jobCh, repo)
|
||||
go t.exportStudio(ctx, &studiosWg, jobCh)
|
||||
}
|
||||
|
||||
for i, studio := range studios {
|
||||
|
|
@ -969,10 +977,10 @@ func (t *ExportTask) ExportStudios(ctx context.Context, workers int, repo Reposi
|
|||
logger.Infof("[studios] export complete in %s. %d workers used.", time.Since(startTime), workers)
|
||||
}
|
||||
|
||||
func (t *ExportTask) exportStudio(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Studio, repo Repository) {
|
||||
func (t *ExportTask) exportStudio(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Studio) {
|
||||
defer wg.Done()
|
||||
|
||||
studioReader := repo.Studio
|
||||
studioReader := t.repository.Studio
|
||||
|
||||
for s := range jobChan {
|
||||
newStudioJSON, err := studio.ToJSON(ctx, studioReader, s)
|
||||
|
|
@ -990,10 +998,10 @@ func (t *ExportTask) exportStudio(ctx context.Context, wg *sync.WaitGroup, jobCh
|
|||
}
|
||||
}
|
||||
|
||||
func (t *ExportTask) ExportTags(ctx context.Context, workers int, repo Repository) {
|
||||
func (t *ExportTask) ExportTags(ctx context.Context, workers int) {
|
||||
var tagsWg sync.WaitGroup
|
||||
|
||||
reader := repo.Tag
|
||||
reader := t.repository.Tag
|
||||
var tags []*models.Tag
|
||||
var err error
|
||||
all := t.full || (t.tags != nil && t.tags.all)
|
||||
|
|
@ -1014,7 +1022,7 @@ func (t *ExportTask) ExportTags(ctx context.Context, workers int, repo Repositor
|
|||
|
||||
for w := 0; w < workers; w++ { // create export Tag workers
|
||||
tagsWg.Add(1)
|
||||
go t.exportTag(ctx, &tagsWg, jobCh, repo)
|
||||
go t.exportTag(ctx, &tagsWg, jobCh)
|
||||
}
|
||||
|
||||
for i, tag := range tags {
|
||||
|
|
@ -1030,10 +1038,10 @@ func (t *ExportTask) ExportTags(ctx context.Context, workers int, repo Repositor
|
|||
logger.Infof("[tags] export complete in %s. %d workers used.", time.Since(startTime), workers)
|
||||
}
|
||||
|
||||
func (t *ExportTask) exportTag(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Tag, repo Repository) {
|
||||
func (t *ExportTask) exportTag(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Tag) {
|
||||
defer wg.Done()
|
||||
|
||||
tagReader := repo.Tag
|
||||
tagReader := t.repository.Tag
|
||||
|
||||
for thisTag := range jobChan {
|
||||
newTagJSON, err := tag.ToJSON(ctx, tagReader, thisTag)
|
||||
|
|
@ -1051,10 +1059,10 @@ func (t *ExportTask) exportTag(ctx context.Context, wg *sync.WaitGroup, jobChan
|
|||
}
|
||||
}
|
||||
|
||||
func (t *ExportTask) ExportMovies(ctx context.Context, workers int, repo Repository) {
|
||||
func (t *ExportTask) ExportMovies(ctx context.Context, workers int) {
|
||||
var moviesWg sync.WaitGroup
|
||||
|
||||
reader := repo.Movie
|
||||
reader := t.repository.Movie
|
||||
var movies []*models.Movie
|
||||
var err error
|
||||
all := t.full || (t.movies != nil && t.movies.all)
|
||||
|
|
@ -1075,7 +1083,7 @@ func (t *ExportTask) ExportMovies(ctx context.Context, workers int, repo Reposit
|
|||
|
||||
for w := 0; w < workers; w++ { // create export Studio workers
|
||||
moviesWg.Add(1)
|
||||
go t.exportMovie(ctx, &moviesWg, jobCh, repo)
|
||||
go t.exportMovie(ctx, &moviesWg, jobCh)
|
||||
}
|
||||
|
||||
for i, movie := range movies {
|
||||
|
|
@ -1091,11 +1099,12 @@ func (t *ExportTask) ExportMovies(ctx context.Context, workers int, repo Reposit
|
|||
logger.Infof("[movies] export complete in %s. %d workers used.", time.Since(startTime), workers)
|
||||
|
||||
}
|
||||
func (t *ExportTask) exportMovie(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Movie, repo Repository) {
|
||||
func (t *ExportTask) exportMovie(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Movie) {
|
||||
defer wg.Done()
|
||||
|
||||
movieReader := repo.Movie
|
||||
studioReader := repo.Studio
|
||||
r := t.repository
|
||||
movieReader := r.Movie
|
||||
studioReader := r.Studio
|
||||
|
||||
for m := range jobChan {
|
||||
newMovieJSON, err := movie.ToJSON(ctx, movieReader, studioReader, m)
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ type GeneratePreviewOptionsInput struct {
|
|||
const generateQueueSize = 200000
|
||||
|
||||
type GenerateJob struct {
|
||||
txnManager Repository
|
||||
repository models.Repository
|
||||
input GenerateMetadataInput
|
||||
|
||||
overwrite bool
|
||||
|
|
@ -112,8 +112,9 @@ func (j *GenerateJob) Execute(ctx context.Context, progress *job.Progress) {
|
|||
Overwrite: j.overwrite,
|
||||
}
|
||||
|
||||
if err := j.txnManager.WithReadTxn(ctx, func(ctx context.Context) error {
|
||||
qb := j.txnManager.Scene
|
||||
r := j.repository
|
||||
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
|
||||
qb := r.Scene
|
||||
if len(j.input.SceneIDs) == 0 && len(j.input.MarkerIDs) == 0 {
|
||||
totals = j.queueTasks(ctx, g, queue)
|
||||
} else {
|
||||
|
|
@ -129,7 +130,7 @@ func (j *GenerateJob) Execute(ctx context.Context, progress *job.Progress) {
|
|||
}
|
||||
|
||||
if len(j.input.MarkerIDs) > 0 {
|
||||
markers, err = j.txnManager.SceneMarker.FindMany(ctx, markerIDs)
|
||||
markers, err = r.SceneMarker.FindMany(ctx, markerIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -229,12 +230,14 @@ func (j *GenerateJob) queueTasks(ctx context.Context, g *generate.Generator, que
|
|||
|
||||
findFilter := models.BatchFindFilter(batchSize)
|
||||
|
||||
r := j.repository
|
||||
|
||||
for more := true; more; {
|
||||
if job.IsCancelled(ctx) {
|
||||
return totals
|
||||
}
|
||||
|
||||
scenes, err := scene.Query(ctx, j.txnManager.Scene, nil, findFilter)
|
||||
scenes, err := scene.Query(ctx, r.Scene, nil, findFilter)
|
||||
if err != nil {
|
||||
logger.Errorf("Error encountered queuing files to scan: %s", err.Error())
|
||||
return totals
|
||||
|
|
@ -245,7 +248,7 @@ func (j *GenerateJob) queueTasks(ctx context.Context, g *generate.Generator, que
|
|||
return totals
|
||||
}
|
||||
|
||||
if err := ss.LoadFiles(ctx, j.txnManager.Scene); err != nil {
|
||||
if err := ss.LoadFiles(ctx, r.Scene); err != nil {
|
||||
logger.Errorf("Error encountered queuing files to scan: %s", err.Error())
|
||||
return totals
|
||||
}
|
||||
|
|
@ -266,7 +269,7 @@ func (j *GenerateJob) queueTasks(ctx context.Context, g *generate.Generator, que
|
|||
return totals
|
||||
}
|
||||
|
||||
images, err := image.Query(ctx, j.txnManager.Image, nil, findFilter)
|
||||
images, err := image.Query(ctx, r.Image, nil, findFilter)
|
||||
if err != nil {
|
||||
logger.Errorf("Error encountered queuing files to scan: %s", err.Error())
|
||||
return totals
|
||||
|
|
@ -277,7 +280,7 @@ func (j *GenerateJob) queueTasks(ctx context.Context, g *generate.Generator, que
|
|||
return totals
|
||||
}
|
||||
|
||||
if err := ss.LoadFiles(ctx, j.txnManager.Image); err != nil {
|
||||
if err := ss.LoadFiles(ctx, r.Image); err != nil {
|
||||
logger.Errorf("Error encountered queuing files to scan: %s", err.Error())
|
||||
return totals
|
||||
}
|
||||
|
|
@ -331,9 +334,11 @@ func getGeneratePreviewOptions(optionsInput GeneratePreviewOptionsInput) generat
|
|||
}
|
||||
|
||||
func (j *GenerateJob) queueSceneJobs(ctx context.Context, g *generate.Generator, scene *models.Scene, queue chan<- Task, totals *totalsGenerate) {
|
||||
r := j.repository
|
||||
|
||||
if j.input.Covers {
|
||||
task := &GenerateCoverTask{
|
||||
txnManager: j.txnManager,
|
||||
repository: r,
|
||||
Scene: *scene,
|
||||
Overwrite: j.overwrite,
|
||||
}
|
||||
|
|
@ -390,7 +395,7 @@ func (j *GenerateJob) queueSceneJobs(ctx context.Context, g *generate.Generator,
|
|||
|
||||
if j.input.Markers {
|
||||
task := &GenerateMarkersTask{
|
||||
TxnManager: j.txnManager,
|
||||
repository: r,
|
||||
Scene: scene,
|
||||
Overwrite: j.overwrite,
|
||||
fileNamingAlgorithm: j.fileNamingAlgo,
|
||||
|
|
@ -429,10 +434,9 @@ func (j *GenerateJob) queueSceneJobs(ctx context.Context, g *generate.Generator,
|
|||
// generate for all files in scene
|
||||
for _, f := range scene.Files.List() {
|
||||
task := &GeneratePhashTask{
|
||||
repository: r,
|
||||
File: f,
|
||||
fileNamingAlgorithm: j.fileNamingAlgo,
|
||||
txnManager: j.txnManager,
|
||||
fileUpdater: j.txnManager.File,
|
||||
Overwrite: j.overwrite,
|
||||
}
|
||||
|
||||
|
|
@ -446,10 +450,10 @@ func (j *GenerateJob) queueSceneJobs(ctx context.Context, g *generate.Generator,
|
|||
|
||||
if j.input.InteractiveHeatmapsSpeeds {
|
||||
task := &GenerateInteractiveHeatmapSpeedTask{
|
||||
repository: r,
|
||||
Scene: *scene,
|
||||
Overwrite: j.overwrite,
|
||||
fileNamingAlgorithm: j.fileNamingAlgo,
|
||||
TxnManager: j.txnManager,
|
||||
}
|
||||
|
||||
if task.required() {
|
||||
|
|
@ -462,7 +466,7 @@ func (j *GenerateJob) queueSceneJobs(ctx context.Context, g *generate.Generator,
|
|||
|
||||
func (j *GenerateJob) queueMarkerJob(g *generate.Generator, marker *models.SceneMarker, queue chan<- Task, totals *totalsGenerate) {
|
||||
task := &GenerateMarkersTask{
|
||||
TxnManager: j.txnManager,
|
||||
repository: j.repository,
|
||||
Marker: marker,
|
||||
Overwrite: j.overwrite,
|
||||
fileNamingAlgorithm: j.fileNamingAlgo,
|
||||
|
|
|
|||
|
|
@ -11,10 +11,10 @@ import (
|
|||
)
|
||||
|
||||
type GenerateInteractiveHeatmapSpeedTask struct {
|
||||
repository models.Repository
|
||||
Scene models.Scene
|
||||
Overwrite bool
|
||||
fileNamingAlgorithm models.HashAlgorithm
|
||||
TxnManager Repository
|
||||
}
|
||||
|
||||
func (t *GenerateInteractiveHeatmapSpeedTask) GetDescription() string {
|
||||
|
|
@ -42,10 +42,11 @@ func (t *GenerateInteractiveHeatmapSpeedTask) Start(ctx context.Context) {
|
|||
|
||||
median := generator.InteractiveSpeed
|
||||
|
||||
if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
|
||||
r := t.repository
|
||||
if err := r.WithTxn(ctx, func(ctx context.Context) error {
|
||||
primaryFile := t.Scene.Files.Primary()
|
||||
primaryFile.InteractiveSpeed = &median
|
||||
qb := t.TxnManager.File
|
||||
qb := r.File
|
||||
return qb.Update(ctx, primaryFile)
|
||||
}); err != nil && ctx.Err() == nil {
|
||||
logger.Error(err.Error())
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ import (
|
|||
)
|
||||
|
||||
type GenerateMarkersTask struct {
|
||||
TxnManager Repository
|
||||
repository models.Repository
|
||||
Scene *models.Scene
|
||||
Marker *models.SceneMarker
|
||||
Overwrite bool
|
||||
|
|
@ -41,9 +41,10 @@ func (t *GenerateMarkersTask) Start(ctx context.Context) {
|
|||
|
||||
if t.Marker != nil {
|
||||
var scene *models.Scene
|
||||
if err := t.TxnManager.WithReadTxn(ctx, func(ctx context.Context) error {
|
||||
r := t.repository
|
||||
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
|
||||
var err error
|
||||
scene, err = t.TxnManager.Scene.Find(ctx, t.Marker.SceneID)
|
||||
scene, err = r.Scene.Find(ctx, t.Marker.SceneID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -51,7 +52,7 @@ func (t *GenerateMarkersTask) Start(ctx context.Context) {
|
|||
return fmt.Errorf("scene with id %d not found", t.Marker.SceneID)
|
||||
}
|
||||
|
||||
return scene.LoadPrimaryFile(ctx, t.TxnManager.File)
|
||||
return scene.LoadPrimaryFile(ctx, r.File)
|
||||
}); err != nil {
|
||||
logger.Errorf("error finding scene for marker generation: %v", err)
|
||||
return
|
||||
|
|
@ -70,9 +71,10 @@ func (t *GenerateMarkersTask) Start(ctx context.Context) {
|
|||
|
||||
func (t *GenerateMarkersTask) generateSceneMarkers(ctx context.Context) {
|
||||
var sceneMarkers []*models.SceneMarker
|
||||
if err := t.TxnManager.WithReadTxn(ctx, func(ctx context.Context) error {
|
||||
r := t.repository
|
||||
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
|
||||
var err error
|
||||
sceneMarkers, err = t.TxnManager.SceneMarker.FindBySceneID(ctx, t.Scene.ID)
|
||||
sceneMarkers, err = r.SceneMarker.FindBySceneID(ctx, t.Scene.ID)
|
||||
return err
|
||||
}); err != nil {
|
||||
logger.Errorf("error getting scene markers: %s", err.Error())
|
||||
|
|
@ -129,7 +131,7 @@ func (t *GenerateMarkersTask) generateMarker(videoFile *models.VideoFile, scene
|
|||
|
||||
func (t *GenerateMarkersTask) markersNeeded(ctx context.Context) int {
|
||||
markers := 0
|
||||
sceneMarkers, err := t.TxnManager.SceneMarker.FindBySceneID(ctx, t.Scene.ID)
|
||||
sceneMarkers, err := t.repository.SceneMarker.FindBySceneID(ctx, t.Scene.ID)
|
||||
if err != nil {
|
||||
logger.Errorf("error finding scene markers: %s", err.Error())
|
||||
return 0
|
||||
|
|
|
|||
|
|
@ -7,15 +7,13 @@ import (
|
|||
"github.com/stashapp/stash/pkg/hash/videophash"
|
||||
"github.com/stashapp/stash/pkg/logger"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/txn"
|
||||
)
|
||||
|
||||
type GeneratePhashTask struct {
|
||||
repository models.Repository
|
||||
File *models.VideoFile
|
||||
Overwrite bool
|
||||
fileNamingAlgorithm models.HashAlgorithm
|
||||
txnManager txn.Manager
|
||||
fileUpdater models.FileUpdater
|
||||
}
|
||||
|
||||
func (t *GeneratePhashTask) GetDescription() string {
|
||||
|
|
@ -34,15 +32,15 @@ func (t *GeneratePhashTask) Start(ctx context.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
if err := txn.WithTxn(ctx, t.txnManager, func(ctx context.Context) error {
|
||||
qb := t.fileUpdater
|
||||
r := t.repository
|
||||
if err := r.WithTxn(ctx, func(ctx context.Context) error {
|
||||
hashValue := int64(*hash)
|
||||
t.File.Fingerprints = t.File.Fingerprints.AppendUnique(models.Fingerprint{
|
||||
Type: models.FingerprintTypePhash,
|
||||
Fingerprint: hashValue,
|
||||
})
|
||||
|
||||
return qb.Update(ctx, t.File)
|
||||
return r.File.Update(ctx, t.File)
|
||||
}); err != nil && ctx.Err() == nil {
|
||||
logger.Errorf("Error setting phash: %v", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,9 +10,9 @@ import (
|
|||
)
|
||||
|
||||
type GenerateCoverTask struct {
|
||||
repository models.Repository
|
||||
Scene models.Scene
|
||||
ScreenshotAt *float64
|
||||
txnManager Repository
|
||||
Overwrite bool
|
||||
}
|
||||
|
||||
|
|
@ -23,11 +23,13 @@ func (t *GenerateCoverTask) GetDescription() string {
|
|||
func (t *GenerateCoverTask) Start(ctx context.Context) {
|
||||
scenePath := t.Scene.Path
|
||||
|
||||
r := t.repository
|
||||
|
||||
var required bool
|
||||
if err := t.txnManager.WithReadTxn(ctx, func(ctx context.Context) error {
|
||||
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
|
||||
required = t.required(ctx)
|
||||
|
||||
return t.Scene.LoadPrimaryFile(ctx, t.txnManager.File)
|
||||
return t.Scene.LoadPrimaryFile(ctx, r.File)
|
||||
}); err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
|
|
@ -70,8 +72,8 @@ func (t *GenerateCoverTask) Start(ctx context.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
|
||||
qb := t.txnManager.Scene
|
||||
if err := r.WithTxn(ctx, func(ctx context.Context) error {
|
||||
qb := r.Scene
|
||||
scenePartial := models.NewScenePartial()
|
||||
|
||||
// update the scene cover table
|
||||
|
|
@ -103,7 +105,7 @@ func (t *GenerateCoverTask) required(ctx context.Context) bool {
|
|||
}
|
||||
|
||||
// if the scene has a cover, then we don't need to generate it
|
||||
hasCover, err := t.txnManager.Scene.HasCover(ctx, t.Scene.ID)
|
||||
hasCover, err := t.repository.Scene.HasCover(ctx, t.Scene.ID)
|
||||
if err != nil {
|
||||
logger.Errorf("Error getting cover: %v", err)
|
||||
return false
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ import (
|
|||
"github.com/stashapp/stash/pkg/scraper"
|
||||
"github.com/stashapp/stash/pkg/scraper/stashbox"
|
||||
"github.com/stashapp/stash/pkg/sliceutil/stringslice"
|
||||
"github.com/stashapp/stash/pkg/txn"
|
||||
)
|
||||
|
||||
var ErrInput = errors.New("invalid request input")
|
||||
|
|
@ -52,7 +51,8 @@ func (j *IdentifyJob) Execute(ctx context.Context, progress *job.Progress) {
|
|||
// if scene ids provided, use those
|
||||
// otherwise, batch query for all scenes - ordering by path
|
||||
// don't use a transaction to query scenes
|
||||
if err := txn.WithDatabase(ctx, instance.Repository, func(ctx context.Context) error {
|
||||
r := instance.Repository
|
||||
if err := r.WithDB(ctx, func(ctx context.Context) error {
|
||||
if len(j.input.SceneIDs) == 0 {
|
||||
return j.identifyAllScenes(ctx, sources)
|
||||
}
|
||||
|
|
@ -70,7 +70,7 @@ func (j *IdentifyJob) Execute(ctx context.Context, progress *job.Progress) {
|
|||
|
||||
// find the scene
|
||||
var err error
|
||||
scene, err := instance.Repository.Scene.Find(ctx, id)
|
||||
scene, err := r.Scene.Find(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("finding scene id %d: %w", id, err)
|
||||
}
|
||||
|
|
@ -89,6 +89,8 @@ func (j *IdentifyJob) Execute(ctx context.Context, progress *job.Progress) {
|
|||
}
|
||||
|
||||
func (j *IdentifyJob) identifyAllScenes(ctx context.Context, sources []identify.ScraperSource) error {
|
||||
r := instance.Repository
|
||||
|
||||
// exclude organised
|
||||
organised := false
|
||||
sceneFilter := scene.FilterFromPaths(j.input.Paths)
|
||||
|
|
@ -102,7 +104,7 @@ func (j *IdentifyJob) identifyAllScenes(ctx context.Context, sources []identify.
|
|||
// get the count
|
||||
pp := 0
|
||||
findFilter.PerPage = &pp
|
||||
countResult, err := instance.Repository.Scene.Query(ctx, models.SceneQueryOptions{
|
||||
countResult, err := r.Scene.Query(ctx, models.SceneQueryOptions{
|
||||
QueryOptions: models.QueryOptions{
|
||||
FindFilter: findFilter,
|
||||
Count: true,
|
||||
|
|
@ -115,7 +117,7 @@ func (j *IdentifyJob) identifyAllScenes(ctx context.Context, sources []identify.
|
|||
|
||||
j.progress.SetTotal(countResult.Count)
|
||||
|
||||
return scene.BatchProcess(ctx, instance.Repository.Scene, sceneFilter, findFilter, func(scene *models.Scene) error {
|
||||
return scene.BatchProcess(ctx, r.Scene, sceneFilter, findFilter, func(scene *models.Scene) error {
|
||||
if job.IsCancelled(ctx) {
|
||||
return nil
|
||||
}
|
||||
|
|
@ -132,18 +134,20 @@ func (j *IdentifyJob) identifyScene(ctx context.Context, s *models.Scene, source
|
|||
|
||||
var taskError error
|
||||
j.progress.ExecuteTask("Identifying "+s.Path, func() {
|
||||
r := instance.Repository
|
||||
task := identify.SceneIdentifier{
|
||||
SceneReaderUpdater: instance.Repository.Scene,
|
||||
StudioReaderWriter: instance.Repository.Studio,
|
||||
PerformerCreator: instance.Repository.Performer,
|
||||
TagFinderCreator: instance.Repository.Tag,
|
||||
TxnManager: r.TxnManager,
|
||||
SceneReaderUpdater: r.Scene,
|
||||
StudioReaderWriter: r.Studio,
|
||||
PerformerCreator: r.Performer,
|
||||
TagFinderCreator: r.Tag,
|
||||
|
||||
DefaultOptions: j.input.Options,
|
||||
Sources: sources,
|
||||
SceneUpdatePostHookExecutor: j.postHookExecutor,
|
||||
}
|
||||
|
||||
taskError = task.Identify(ctx, instance.Repository, s)
|
||||
taskError = task.Identify(ctx, s)
|
||||
})
|
||||
|
||||
if taskError != nil {
|
||||
|
|
@ -164,15 +168,11 @@ func (j *IdentifyJob) getSources() ([]identify.ScraperSource, error) {
|
|||
|
||||
var src identify.ScraperSource
|
||||
if stashBox != nil {
|
||||
stashboxRepository := stashbox.NewRepository(instance.Repository)
|
||||
src = identify.ScraperSource{
|
||||
Name: "stash-box: " + stashBox.Endpoint,
|
||||
Scraper: stashboxSource{
|
||||
stashbox.NewClient(*stashBox, instance.Repository, stashbox.Repository{
|
||||
Scene: instance.Repository.Scene,
|
||||
Performer: instance.Repository.Performer,
|
||||
Tag: instance.Repository.Tag,
|
||||
Studio: instance.Repository.Studio,
|
||||
}),
|
||||
stashbox.NewClient(*stashBox, stashboxRepository),
|
||||
stashBox.Endpoint,
|
||||
},
|
||||
RemoteSite: stashBox.Endpoint,
|
||||
|
|
|
|||
|
|
@ -25,8 +25,13 @@ import (
|
|||
"github.com/stashapp/stash/pkg/tag"
|
||||
)
|
||||
|
||||
type Resetter interface {
|
||||
Reset() error
|
||||
}
|
||||
|
||||
type ImportTask struct {
|
||||
txnManager Repository
|
||||
repository models.Repository
|
||||
resetter Resetter
|
||||
json jsonUtils
|
||||
|
||||
BaseDir string
|
||||
|
|
@ -66,8 +71,10 @@ func CreateImportTask(a models.HashAlgorithm, input ImportObjectsInput) (*Import
|
|||
}
|
||||
}
|
||||
|
||||
mgr := GetInstance()
|
||||
return &ImportTask{
|
||||
txnManager: GetInstance().Repository,
|
||||
repository: mgr.Repository,
|
||||
resetter: mgr.Database,
|
||||
BaseDir: baseDir,
|
||||
TmpZip: tmpZip,
|
||||
Reset: false,
|
||||
|
|
@ -109,7 +116,7 @@ func (t *ImportTask) Start(ctx context.Context) {
|
|||
}
|
||||
|
||||
if t.Reset {
|
||||
err := t.txnManager.Reset()
|
||||
err := t.resetter.Reset()
|
||||
|
||||
if err != nil {
|
||||
logger.Errorf("Error resetting database: %s", err.Error())
|
||||
|
|
@ -194,6 +201,8 @@ func (t *ImportTask) ImportPerformers(ctx context.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
r := t.repository
|
||||
|
||||
for i, fi := range files {
|
||||
index := i + 1
|
||||
performerJSON, err := jsonschema.LoadPerformerFile(filepath.Join(path, fi.Name()))
|
||||
|
|
@ -204,11 +213,9 @@ func (t *ImportTask) ImportPerformers(ctx context.Context) {
|
|||
|
||||
logger.Progressf("[performers] %d of %d", index, len(files))
|
||||
|
||||
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
|
||||
r := t.txnManager
|
||||
readerWriter := r.Performer
|
||||
if err := r.WithTxn(ctx, func(ctx context.Context) error {
|
||||
importer := &performer.Importer{
|
||||
ReaderWriter: readerWriter,
|
||||
ReaderWriter: r.Performer,
|
||||
TagWriter: r.Tag,
|
||||
Input: *performerJSON,
|
||||
}
|
||||
|
|
@ -237,6 +244,8 @@ func (t *ImportTask) ImportStudios(ctx context.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
r := t.repository
|
||||
|
||||
for i, fi := range files {
|
||||
index := i + 1
|
||||
studioJSON, err := jsonschema.LoadStudioFile(filepath.Join(path, fi.Name()))
|
||||
|
|
@ -247,8 +256,8 @@ func (t *ImportTask) ImportStudios(ctx context.Context) {
|
|||
|
||||
logger.Progressf("[studios] %d of %d", index, len(files))
|
||||
|
||||
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
|
||||
return t.ImportStudio(ctx, studioJSON, pendingParent, t.txnManager.Studio)
|
||||
if err := r.WithTxn(ctx, func(ctx context.Context) error {
|
||||
return t.importStudio(ctx, studioJSON, pendingParent)
|
||||
}); err != nil {
|
||||
if errors.Is(err, studio.ErrParentStudioNotExist) {
|
||||
// add to the pending parent list so that it is created after the parent
|
||||
|
|
@ -269,8 +278,8 @@ func (t *ImportTask) ImportStudios(ctx context.Context) {
|
|||
|
||||
for _, s := range pendingParent {
|
||||
for _, orphanStudioJSON := range s {
|
||||
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
|
||||
return t.ImportStudio(ctx, orphanStudioJSON, nil, t.txnManager.Studio)
|
||||
if err := r.WithTxn(ctx, func(ctx context.Context) error {
|
||||
return t.importStudio(ctx, orphanStudioJSON, nil)
|
||||
}); err != nil {
|
||||
logger.Errorf("[studios] <%s> failed to create: %s", orphanStudioJSON.Name, err.Error())
|
||||
continue
|
||||
|
|
@ -282,9 +291,9 @@ func (t *ImportTask) ImportStudios(ctx context.Context) {
|
|||
logger.Info("[studios] import complete")
|
||||
}
|
||||
|
||||
func (t *ImportTask) ImportStudio(ctx context.Context, studioJSON *jsonschema.Studio, pendingParent map[string][]*jsonschema.Studio, readerWriter studio.ImporterReaderWriter) error {
|
||||
func (t *ImportTask) importStudio(ctx context.Context, studioJSON *jsonschema.Studio, pendingParent map[string][]*jsonschema.Studio) error {
|
||||
importer := &studio.Importer{
|
||||
ReaderWriter: readerWriter,
|
||||
ReaderWriter: t.repository.Studio,
|
||||
Input: *studioJSON,
|
||||
MissingRefBehaviour: t.MissingRefBehaviour,
|
||||
}
|
||||
|
|
@ -302,7 +311,7 @@ func (t *ImportTask) ImportStudio(ctx context.Context, studioJSON *jsonschema.St
|
|||
s := pendingParent[studioJSON.Name]
|
||||
for _, childStudioJSON := range s {
|
||||
// map is nil since we're not checking parent studios at this point
|
||||
if err := t.ImportStudio(ctx, childStudioJSON, nil, readerWriter); err != nil {
|
||||
if err := t.importStudio(ctx, childStudioJSON, nil); err != nil {
|
||||
return fmt.Errorf("failed to create child studio <%s>: %s", childStudioJSON.Name, err.Error())
|
||||
}
|
||||
}
|
||||
|
|
@ -326,6 +335,8 @@ func (t *ImportTask) ImportMovies(ctx context.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
r := t.repository
|
||||
|
||||
for i, fi := range files {
|
||||
index := i + 1
|
||||
movieJSON, err := jsonschema.LoadMovieFile(filepath.Join(path, fi.Name()))
|
||||
|
|
@ -336,14 +347,10 @@ func (t *ImportTask) ImportMovies(ctx context.Context) {
|
|||
|
||||
logger.Progressf("[movies] %d of %d", index, len(files))
|
||||
|
||||
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
|
||||
r := t.txnManager
|
||||
readerWriter := r.Movie
|
||||
studioReaderWriter := r.Studio
|
||||
|
||||
if err := r.WithTxn(ctx, func(ctx context.Context) error {
|
||||
movieImporter := &movie.Importer{
|
||||
ReaderWriter: readerWriter,
|
||||
StudioWriter: studioReaderWriter,
|
||||
ReaderWriter: r.Movie,
|
||||
StudioWriter: r.Studio,
|
||||
Input: *movieJSON,
|
||||
MissingRefBehaviour: t.MissingRefBehaviour,
|
||||
}
|
||||
|
|
@ -371,6 +378,8 @@ func (t *ImportTask) ImportFiles(ctx context.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
r := t.repository
|
||||
|
||||
pendingParent := make(map[string][]jsonschema.DirEntry)
|
||||
|
||||
for i, fi := range files {
|
||||
|
|
@ -383,8 +392,8 @@ func (t *ImportTask) ImportFiles(ctx context.Context) {
|
|||
|
||||
logger.Progressf("[files] %d of %d", index, len(files))
|
||||
|
||||
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
|
||||
return t.ImportFile(ctx, fileJSON, pendingParent)
|
||||
if err := r.WithTxn(ctx, func(ctx context.Context) error {
|
||||
return t.importFile(ctx, fileJSON, pendingParent)
|
||||
}); err != nil {
|
||||
if errors.Is(err, file.ErrZipFileNotExist) {
|
||||
// add to the pending parent list so that it is created after the parent
|
||||
|
|
@ -405,8 +414,8 @@ func (t *ImportTask) ImportFiles(ctx context.Context) {
|
|||
|
||||
for _, s := range pendingParent {
|
||||
for _, orphanFileJSON := range s {
|
||||
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
|
||||
return t.ImportFile(ctx, orphanFileJSON, nil)
|
||||
if err := r.WithTxn(ctx, func(ctx context.Context) error {
|
||||
return t.importFile(ctx, orphanFileJSON, nil)
|
||||
}); err != nil {
|
||||
logger.Errorf("[files] <%s> failed to create: %s", orphanFileJSON.DirEntry().Path, err.Error())
|
||||
continue
|
||||
|
|
@ -418,12 +427,11 @@ func (t *ImportTask) ImportFiles(ctx context.Context) {
|
|||
logger.Info("[files] import complete")
|
||||
}
|
||||
|
||||
func (t *ImportTask) ImportFile(ctx context.Context, fileJSON jsonschema.DirEntry, pendingParent map[string][]jsonschema.DirEntry) error {
|
||||
r := t.txnManager
|
||||
readerWriter := r.File
|
||||
func (t *ImportTask) importFile(ctx context.Context, fileJSON jsonschema.DirEntry, pendingParent map[string][]jsonschema.DirEntry) error {
|
||||
r := t.repository
|
||||
|
||||
fileImporter := &file.Importer{
|
||||
ReaderWriter: readerWriter,
|
||||
ReaderWriter: r.File,
|
||||
FolderStore: r.Folder,
|
||||
Input: fileJSON,
|
||||
}
|
||||
|
|
@ -437,7 +445,7 @@ func (t *ImportTask) ImportFile(ctx context.Context, fileJSON jsonschema.DirEntr
|
|||
s := pendingParent[fileJSON.DirEntry().Path]
|
||||
for _, childFileJSON := range s {
|
||||
// map is nil since we're not checking parent studios at this point
|
||||
if err := t.ImportFile(ctx, childFileJSON, nil); err != nil {
|
||||
if err := t.importFile(ctx, childFileJSON, nil); err != nil {
|
||||
return fmt.Errorf("failed to create child file <%s>: %s", childFileJSON.DirEntry().Path, err.Error())
|
||||
}
|
||||
}
|
||||
|
|
@ -461,6 +469,8 @@ func (t *ImportTask) ImportGalleries(ctx context.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
r := t.repository
|
||||
|
||||
for i, fi := range files {
|
||||
index := i + 1
|
||||
galleryJSON, err := jsonschema.LoadGalleryFile(filepath.Join(path, fi.Name()))
|
||||
|
|
@ -471,21 +481,14 @@ func (t *ImportTask) ImportGalleries(ctx context.Context) {
|
|||
|
||||
logger.Progressf("[galleries] %d of %d", index, len(files))
|
||||
|
||||
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
|
||||
r := t.txnManager
|
||||
readerWriter := r.Gallery
|
||||
tagWriter := r.Tag
|
||||
performerWriter := r.Performer
|
||||
studioWriter := r.Studio
|
||||
chapterWriter := r.GalleryChapter
|
||||
|
||||
if err := r.WithTxn(ctx, func(ctx context.Context) error {
|
||||
galleryImporter := &gallery.Importer{
|
||||
ReaderWriter: readerWriter,
|
||||
ReaderWriter: r.Gallery,
|
||||
FolderFinder: r.Folder,
|
||||
FileFinder: r.File,
|
||||
PerformerWriter: performerWriter,
|
||||
StudioWriter: studioWriter,
|
||||
TagWriter: tagWriter,
|
||||
PerformerWriter: r.Performer,
|
||||
StudioWriter: r.Studio,
|
||||
TagWriter: r.Tag,
|
||||
Input: *galleryJSON,
|
||||
MissingRefBehaviour: t.MissingRefBehaviour,
|
||||
}
|
||||
|
|
@ -500,7 +503,7 @@ func (t *ImportTask) ImportGalleries(ctx context.Context) {
|
|||
GalleryID: galleryImporter.ID,
|
||||
Input: m,
|
||||
MissingRefBehaviour: t.MissingRefBehaviour,
|
||||
ReaderWriter: chapterWriter,
|
||||
ReaderWriter: r.GalleryChapter,
|
||||
}
|
||||
|
||||
if err := performImport(ctx, chapterImporter, t.DuplicateBehaviour); err != nil {
|
||||
|
|
@ -532,6 +535,8 @@ func (t *ImportTask) ImportTags(ctx context.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
r := t.repository
|
||||
|
||||
for i, fi := range files {
|
||||
index := i + 1
|
||||
tagJSON, err := jsonschema.LoadTagFile(filepath.Join(path, fi.Name()))
|
||||
|
|
@ -542,8 +547,8 @@ func (t *ImportTask) ImportTags(ctx context.Context) {
|
|||
|
||||
logger.Progressf("[tags] %d of %d", index, len(files))
|
||||
|
||||
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
|
||||
return t.ImportTag(ctx, tagJSON, pendingParent, false, t.txnManager.Tag)
|
||||
if err := r.WithTxn(ctx, func(ctx context.Context) error {
|
||||
return t.importTag(ctx, tagJSON, pendingParent, false)
|
||||
}); err != nil {
|
||||
var parentError tag.ParentTagNotExistError
|
||||
if errors.As(err, &parentError) {
|
||||
|
|
@ -558,8 +563,8 @@ func (t *ImportTask) ImportTags(ctx context.Context) {
|
|||
|
||||
for _, s := range pendingParent {
|
||||
for _, orphanTagJSON := range s {
|
||||
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
|
||||
return t.ImportTag(ctx, orphanTagJSON, nil, true, t.txnManager.Tag)
|
||||
if err := r.WithTxn(ctx, func(ctx context.Context) error {
|
||||
return t.importTag(ctx, orphanTagJSON, nil, true)
|
||||
}); err != nil {
|
||||
logger.Errorf("[tags] <%s> failed to create: %s", orphanTagJSON.Name, err.Error())
|
||||
continue
|
||||
|
|
@ -570,9 +575,9 @@ func (t *ImportTask) ImportTags(ctx context.Context) {
|
|||
logger.Info("[tags] import complete")
|
||||
}
|
||||
|
||||
func (t *ImportTask) ImportTag(ctx context.Context, tagJSON *jsonschema.Tag, pendingParent map[string][]*jsonschema.Tag, fail bool, readerWriter tag.ImporterReaderWriter) error {
|
||||
func (t *ImportTask) importTag(ctx context.Context, tagJSON *jsonschema.Tag, pendingParent map[string][]*jsonschema.Tag, fail bool) error {
|
||||
importer := &tag.Importer{
|
||||
ReaderWriter: readerWriter,
|
||||
ReaderWriter: t.repository.Tag,
|
||||
Input: *tagJSON,
|
||||
MissingRefBehaviour: t.MissingRefBehaviour,
|
||||
}
|
||||
|
|
@ -587,7 +592,7 @@ func (t *ImportTask) ImportTag(ctx context.Context, tagJSON *jsonschema.Tag, pen
|
|||
}
|
||||
|
||||
for _, childTagJSON := range pendingParent[tagJSON.Name] {
|
||||
if err := t.ImportTag(ctx, childTagJSON, pendingParent, fail, readerWriter); err != nil {
|
||||
if err := t.importTag(ctx, childTagJSON, pendingParent, fail); err != nil {
|
||||
var parentError tag.ParentTagNotExistError
|
||||
if errors.As(err, &parentError) {
|
||||
pendingParent[parentError.MissingParent()] = append(pendingParent[parentError.MissingParent()], childTagJSON)
|
||||
|
|
@ -616,6 +621,8 @@ func (t *ImportTask) ImportScenes(ctx context.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
r := t.repository
|
||||
|
||||
for i, fi := range files {
|
||||
index := i + 1
|
||||
|
||||
|
|
@ -627,29 +634,20 @@ func (t *ImportTask) ImportScenes(ctx context.Context) {
|
|||
continue
|
||||
}
|
||||
|
||||
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
|
||||
r := t.txnManager
|
||||
readerWriter := r.Scene
|
||||
tagWriter := r.Tag
|
||||
galleryWriter := r.Gallery
|
||||
movieWriter := r.Movie
|
||||
performerWriter := r.Performer
|
||||
studioWriter := r.Studio
|
||||
markerWriter := r.SceneMarker
|
||||
|
||||
if err := r.WithTxn(ctx, func(ctx context.Context) error {
|
||||
sceneImporter := &scene.Importer{
|
||||
ReaderWriter: readerWriter,
|
||||
ReaderWriter: r.Scene,
|
||||
Input: *sceneJSON,
|
||||
FileFinder: r.File,
|
||||
|
||||
FileNamingAlgorithm: t.fileNamingAlgorithm,
|
||||
MissingRefBehaviour: t.MissingRefBehaviour,
|
||||
|
||||
GalleryFinder: galleryWriter,
|
||||
MovieWriter: movieWriter,
|
||||
PerformerWriter: performerWriter,
|
||||
StudioWriter: studioWriter,
|
||||
TagWriter: tagWriter,
|
||||
GalleryFinder: r.Gallery,
|
||||
MovieWriter: r.Movie,
|
||||
PerformerWriter: r.Performer,
|
||||
StudioWriter: r.Studio,
|
||||
TagWriter: r.Tag,
|
||||
}
|
||||
|
||||
if err := performImport(ctx, sceneImporter, t.DuplicateBehaviour); err != nil {
|
||||
|
|
@ -662,8 +660,8 @@ func (t *ImportTask) ImportScenes(ctx context.Context) {
|
|||
SceneID: sceneImporter.ID,
|
||||
Input: m,
|
||||
MissingRefBehaviour: t.MissingRefBehaviour,
|
||||
ReaderWriter: markerWriter,
|
||||
TagWriter: tagWriter,
|
||||
ReaderWriter: r.SceneMarker,
|
||||
TagWriter: r.Tag,
|
||||
}
|
||||
|
||||
if err := performImport(ctx, markerImporter, t.DuplicateBehaviour); err != nil {
|
||||
|
|
@ -693,6 +691,8 @@ func (t *ImportTask) ImportImages(ctx context.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
r := t.repository
|
||||
|
||||
for i, fi := range files {
|
||||
index := i + 1
|
||||
|
||||
|
|
@ -704,25 +704,18 @@ func (t *ImportTask) ImportImages(ctx context.Context) {
|
|||
continue
|
||||
}
|
||||
|
||||
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
|
||||
r := t.txnManager
|
||||
readerWriter := r.Image
|
||||
tagWriter := r.Tag
|
||||
galleryWriter := r.Gallery
|
||||
performerWriter := r.Performer
|
||||
studioWriter := r.Studio
|
||||
|
||||
if err := r.WithTxn(ctx, func(ctx context.Context) error {
|
||||
imageImporter := &image.Importer{
|
||||
ReaderWriter: readerWriter,
|
||||
ReaderWriter: r.Image,
|
||||
FileFinder: r.File,
|
||||
Input: *imageJSON,
|
||||
|
||||
MissingRefBehaviour: t.MissingRefBehaviour,
|
||||
|
||||
GalleryFinder: galleryWriter,
|
||||
PerformerWriter: performerWriter,
|
||||
StudioWriter: studioWriter,
|
||||
TagWriter: tagWriter,
|
||||
GalleryFinder: r.Gallery,
|
||||
PerformerWriter: r.Performer,
|
||||
StudioWriter: r.Studio,
|
||||
TagWriter: r.Tag,
|
||||
}
|
||||
|
||||
return performImport(ctx, imageImporter, t.DuplicateBehaviour)
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ import (
|
|||
"github.com/stashapp/stash/pkg/job"
|
||||
"github.com/stashapp/stash/pkg/logger"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/models/paths"
|
||||
"github.com/stashapp/stash/pkg/scene"
|
||||
"github.com/stashapp/stash/pkg/scene/generate"
|
||||
"github.com/stashapp/stash/pkg/txn"
|
||||
|
|
@ -48,10 +49,14 @@ func (j *ScanJob) Execute(ctx context.Context, progress *job.Progress) {
|
|||
paths[i] = p.Path
|
||||
}
|
||||
|
||||
mgr := GetInstance()
|
||||
c := mgr.Config
|
||||
repo := mgr.Repository
|
||||
|
||||
start := time.Now()
|
||||
|
||||
const taskQueueSize = 200000
|
||||
taskQueue := job.NewTaskQueue(ctx, progress, taskQueueSize, instance.Config.GetParallelTasksWithAutoDetection())
|
||||
taskQueue := job.NewTaskQueue(ctx, progress, taskQueueSize, c.GetParallelTasksWithAutoDetection())
|
||||
|
||||
var minModTime time.Time
|
||||
if j.input.Filter != nil && j.input.Filter.MinModTime != nil {
|
||||
|
|
@ -59,13 +64,11 @@ func (j *ScanJob) Execute(ctx context.Context, progress *job.Progress) {
|
|||
}
|
||||
|
||||
j.scanner.Scan(ctx, getScanHandlers(j.input, taskQueue, progress), file.ScanOptions{
|
||||
Paths: paths,
|
||||
ScanFilters: []file.PathFilter{newScanFilter(instance.Config, minModTime)},
|
||||
ZipFileExtensions: instance.Config.GetGalleryExtensions(),
|
||||
ParallelTasks: instance.Config.GetParallelTasksWithAutoDetection(),
|
||||
HandlerRequiredFilters: []file.Filter{
|
||||
newHandlerRequiredFilter(instance.Config),
|
||||
},
|
||||
Paths: paths,
|
||||
ScanFilters: []file.PathFilter{newScanFilter(c, repo, minModTime)},
|
||||
ZipFileExtensions: c.GetGalleryExtensions(),
|
||||
ParallelTasks: c.GetParallelTasksWithAutoDetection(),
|
||||
HandlerRequiredFilters: []file.Filter{newHandlerRequiredFilter(c, repo)},
|
||||
}, progress)
|
||||
|
||||
taskQueue.Close()
|
||||
|
|
@ -123,17 +126,16 @@ type handlerRequiredFilter struct {
|
|||
videoFileNamingAlgorithm models.HashAlgorithm
|
||||
}
|
||||
|
||||
func newHandlerRequiredFilter(c *config.Instance) *handlerRequiredFilter {
|
||||
db := instance.Database
|
||||
func newHandlerRequiredFilter(c *config.Instance, repo models.Repository) *handlerRequiredFilter {
|
||||
processes := c.GetParallelTasksWithAutoDetection()
|
||||
|
||||
return &handlerRequiredFilter{
|
||||
extensionConfig: newExtensionConfig(c),
|
||||
txnManager: db,
|
||||
SceneFinder: db.Scene,
|
||||
ImageFinder: db.Image,
|
||||
GalleryFinder: db.Gallery,
|
||||
CaptionUpdater: db.File,
|
||||
txnManager: repo.TxnManager,
|
||||
SceneFinder: repo.Scene,
|
||||
ImageFinder: repo.Image,
|
||||
GalleryFinder: repo.Gallery,
|
||||
CaptionUpdater: repo.File,
|
||||
FolderCache: lru.New(processes * 2),
|
||||
videoFileNamingAlgorithm: c.GetVideoFileNamingAlgorithm(),
|
||||
}
|
||||
|
|
@ -226,6 +228,10 @@ func (f *handlerRequiredFilter) Accept(ctx context.Context, ff models.File) bool
|
|||
|
||||
type scanFilter struct {
|
||||
extensionConfig
|
||||
txnManager txn.Manager
|
||||
FileFinder models.FileFinder
|
||||
CaptionUpdater video.CaptionUpdater
|
||||
|
||||
stashPaths config.StashConfigs
|
||||
generatedPath string
|
||||
videoExcludeRegex []*regexp.Regexp
|
||||
|
|
@ -233,9 +239,12 @@ type scanFilter struct {
|
|||
minModTime time.Time
|
||||
}
|
||||
|
||||
func newScanFilter(c *config.Instance, minModTime time.Time) *scanFilter {
|
||||
func newScanFilter(c *config.Instance, repo models.Repository, minModTime time.Time) *scanFilter {
|
||||
return &scanFilter{
|
||||
extensionConfig: newExtensionConfig(c),
|
||||
txnManager: repo.TxnManager,
|
||||
FileFinder: repo.File,
|
||||
CaptionUpdater: repo.File,
|
||||
stashPaths: c.GetStashPaths(),
|
||||
generatedPath: c.GetGeneratedPath(),
|
||||
videoExcludeRegex: generateRegexps(c.GetExcludes()),
|
||||
|
|
@ -263,7 +272,7 @@ func (f *scanFilter) Accept(ctx context.Context, path string, info fs.FileInfo)
|
|||
if fsutil.MatchExtension(path, video.CaptionExts) {
|
||||
// we don't include caption files in the file scan, but we do need
|
||||
// to handle them
|
||||
video.AssociateCaptions(ctx, path, instance.Repository, instance.Database.File, instance.Database.File)
|
||||
video.AssociateCaptions(ctx, path, f.txnManager, f.FileFinder, f.CaptionUpdater)
|
||||
|
||||
return false
|
||||
}
|
||||
|
|
@ -308,30 +317,37 @@ func (f *scanFilter) Accept(ctx context.Context, path string, info fs.FileInfo)
|
|||
type scanConfig struct {
|
||||
isGenerateThumbnails bool
|
||||
isGenerateClipPreviews bool
|
||||
|
||||
createGalleriesFromFolders bool
|
||||
}
|
||||
|
||||
func (c *scanConfig) GetCreateGalleriesFromFolders() bool {
|
||||
return instance.Config.GetCreateGalleriesFromFolders()
|
||||
return c.createGalleriesFromFolders
|
||||
}
|
||||
|
||||
func getScanHandlers(options ScanMetadataInput, taskQueue *job.TaskQueue, progress *job.Progress) []file.Handler {
|
||||
db := instance.Database
|
||||
pluginCache := instance.PluginCache
|
||||
mgr := GetInstance()
|
||||
c := mgr.Config
|
||||
r := mgr.Repository
|
||||
pluginCache := mgr.PluginCache
|
||||
|
||||
return []file.Handler{
|
||||
&file.FilteredHandler{
|
||||
Filter: file.FilterFunc(imageFileFilter),
|
||||
Handler: &image.ScanHandler{
|
||||
CreatorUpdater: db.Image,
|
||||
GalleryFinder: db.Gallery,
|
||||
CreatorUpdater: r.Image,
|
||||
GalleryFinder: r.Gallery,
|
||||
ScanGenerator: &imageGenerators{
|
||||
input: options,
|
||||
taskQueue: taskQueue,
|
||||
progress: progress,
|
||||
input: options,
|
||||
taskQueue: taskQueue,
|
||||
progress: progress,
|
||||
paths: mgr.Paths,
|
||||
sequentialScanning: c.GetSequentialScanning(),
|
||||
},
|
||||
ScanConfig: &scanConfig{
|
||||
isGenerateThumbnails: options.ScanGenerateThumbnails,
|
||||
isGenerateClipPreviews: options.ScanGenerateClipPreviews,
|
||||
isGenerateThumbnails: options.ScanGenerateThumbnails,
|
||||
isGenerateClipPreviews: options.ScanGenerateClipPreviews,
|
||||
createGalleriesFromFolders: c.GetCreateGalleriesFromFolders(),
|
||||
},
|
||||
PluginCache: pluginCache,
|
||||
Paths: instance.Paths,
|
||||
|
|
@ -340,25 +356,28 @@ func getScanHandlers(options ScanMetadataInput, taskQueue *job.TaskQueue, progre
|
|||
&file.FilteredHandler{
|
||||
Filter: file.FilterFunc(galleryFileFilter),
|
||||
Handler: &gallery.ScanHandler{
|
||||
CreatorUpdater: db.Gallery,
|
||||
SceneFinderUpdater: db.Scene,
|
||||
ImageFinderUpdater: db.Image,
|
||||
CreatorUpdater: r.Gallery,
|
||||
SceneFinderUpdater: r.Scene,
|
||||
ImageFinderUpdater: r.Image,
|
||||
PluginCache: pluginCache,
|
||||
},
|
||||
},
|
||||
&file.FilteredHandler{
|
||||
Filter: file.FilterFunc(videoFileFilter),
|
||||
Handler: &scene.ScanHandler{
|
||||
CreatorUpdater: db.Scene,
|
||||
CreatorUpdater: r.Scene,
|
||||
CaptionUpdater: r.File,
|
||||
PluginCache: pluginCache,
|
||||
CaptionUpdater: db.File,
|
||||
ScanGenerator: &sceneGenerators{
|
||||
input: options,
|
||||
taskQueue: taskQueue,
|
||||
progress: progress,
|
||||
input: options,
|
||||
taskQueue: taskQueue,
|
||||
progress: progress,
|
||||
paths: mgr.Paths,
|
||||
fileNamingAlgorithm: c.GetVideoFileNamingAlgorithm(),
|
||||
sequentialScanning: c.GetSequentialScanning(),
|
||||
},
|
||||
FileNamingAlgorithm: instance.Config.GetVideoFileNamingAlgorithm(),
|
||||
Paths: instance.Paths,
|
||||
FileNamingAlgorithm: c.GetVideoFileNamingAlgorithm(),
|
||||
Paths: mgr.Paths,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
@ -368,6 +387,9 @@ type imageGenerators struct {
|
|||
input ScanMetadataInput
|
||||
taskQueue *job.TaskQueue
|
||||
progress *job.Progress
|
||||
|
||||
paths *paths.Paths
|
||||
sequentialScanning bool
|
||||
}
|
||||
|
||||
func (g *imageGenerators) Generate(ctx context.Context, i *models.Image, f models.File) error {
|
||||
|
|
@ -376,8 +398,6 @@ func (g *imageGenerators) Generate(ctx context.Context, i *models.Image, f model
|
|||
progress := g.progress
|
||||
t := g.input
|
||||
path := f.Base().Path
|
||||
config := instance.Config
|
||||
sequentialScanning := config.GetSequentialScanning()
|
||||
|
||||
if t.ScanGenerateThumbnails {
|
||||
// this should be quick, so always generate sequentially
|
||||
|
|
@ -405,7 +425,7 @@ func (g *imageGenerators) Generate(ctx context.Context, i *models.Image, f model
|
|||
progress.Increment()
|
||||
}
|
||||
|
||||
if sequentialScanning {
|
||||
if g.sequentialScanning {
|
||||
previewsFn(ctx)
|
||||
} else {
|
||||
g.taskQueue.Add(fmt.Sprintf("Generating preview for %s", path), previewsFn)
|
||||
|
|
@ -416,7 +436,7 @@ func (g *imageGenerators) Generate(ctx context.Context, i *models.Image, f model
|
|||
}
|
||||
|
||||
func (g *imageGenerators) generateThumbnail(ctx context.Context, i *models.Image, f models.File) error {
|
||||
thumbPath := GetInstance().Paths.Generated.GetThumbnailPath(i.Checksum, models.DefaultGthumbWidth)
|
||||
thumbPath := g.paths.Generated.GetThumbnailPath(i.Checksum, models.DefaultGthumbWidth)
|
||||
exists, _ := fsutil.FileExists(thumbPath)
|
||||
if exists {
|
||||
return nil
|
||||
|
|
@ -435,13 +455,16 @@ func (g *imageGenerators) generateThumbnail(ctx context.Context, i *models.Image
|
|||
|
||||
logger.Debugf("Generating thumbnail for %s", path)
|
||||
|
||||
mgr := GetInstance()
|
||||
c := mgr.Config
|
||||
|
||||
clipPreviewOptions := image.ClipPreviewOptions{
|
||||
InputArgs: instance.Config.GetTranscodeInputArgs(),
|
||||
OutputArgs: instance.Config.GetTranscodeOutputArgs(),
|
||||
Preset: instance.Config.GetPreviewPreset().String(),
|
||||
InputArgs: c.GetTranscodeInputArgs(),
|
||||
OutputArgs: c.GetTranscodeOutputArgs(),
|
||||
Preset: c.GetPreviewPreset().String(),
|
||||
}
|
||||
|
||||
encoder := image.NewThumbnailEncoder(instance.FFMPEG, instance.FFProbe, clipPreviewOptions)
|
||||
encoder := image.NewThumbnailEncoder(mgr.FFMPEG, mgr.FFProbe, clipPreviewOptions)
|
||||
data, err := encoder.GetThumbnail(f, models.DefaultGthumbWidth)
|
||||
|
||||
if err != nil {
|
||||
|
|
@ -464,6 +487,10 @@ type sceneGenerators struct {
|
|||
input ScanMetadataInput
|
||||
taskQueue *job.TaskQueue
|
||||
progress *job.Progress
|
||||
|
||||
paths *paths.Paths
|
||||
fileNamingAlgorithm models.HashAlgorithm
|
||||
sequentialScanning bool
|
||||
}
|
||||
|
||||
func (g *sceneGenerators) Generate(ctx context.Context, s *models.Scene, f *models.VideoFile) error {
|
||||
|
|
@ -472,9 +499,8 @@ func (g *sceneGenerators) Generate(ctx context.Context, s *models.Scene, f *mode
|
|||
progress := g.progress
|
||||
t := g.input
|
||||
path := f.Path
|
||||
config := instance.Config
|
||||
fileNamingAlgorithm := config.GetVideoFileNamingAlgorithm()
|
||||
sequentialScanning := config.GetSequentialScanning()
|
||||
|
||||
mgr := GetInstance()
|
||||
|
||||
if t.ScanGenerateSprites {
|
||||
progress.AddTotal(1)
|
||||
|
|
@ -482,13 +508,13 @@ func (g *sceneGenerators) Generate(ctx context.Context, s *models.Scene, f *mode
|
|||
taskSprite := GenerateSpriteTask{
|
||||
Scene: *s,
|
||||
Overwrite: overwrite,
|
||||
fileNamingAlgorithm: fileNamingAlgorithm,
|
||||
fileNamingAlgorithm: g.fileNamingAlgorithm,
|
||||
}
|
||||
taskSprite.Start(ctx)
|
||||
progress.Increment()
|
||||
}
|
||||
|
||||
if sequentialScanning {
|
||||
if g.sequentialScanning {
|
||||
spriteFn(ctx)
|
||||
} else {
|
||||
g.taskQueue.Add(fmt.Sprintf("Generating sprites for %s", path), spriteFn)
|
||||
|
|
@ -499,17 +525,16 @@ func (g *sceneGenerators) Generate(ctx context.Context, s *models.Scene, f *mode
|
|||
progress.AddTotal(1)
|
||||
phashFn := func(ctx context.Context) {
|
||||
taskPhash := GeneratePhashTask{
|
||||
repository: mgr.Repository,
|
||||
File: f,
|
||||
fileNamingAlgorithm: fileNamingAlgorithm,
|
||||
txnManager: instance.Database,
|
||||
fileUpdater: instance.Database.File,
|
||||
Overwrite: overwrite,
|
||||
fileNamingAlgorithm: g.fileNamingAlgorithm,
|
||||
}
|
||||
taskPhash.Start(ctx)
|
||||
progress.Increment()
|
||||
}
|
||||
|
||||
if sequentialScanning {
|
||||
if g.sequentialScanning {
|
||||
phashFn(ctx)
|
||||
} else {
|
||||
g.taskQueue.Add(fmt.Sprintf("Generating phash for %s", path), phashFn)
|
||||
|
|
@ -521,12 +546,12 @@ func (g *sceneGenerators) Generate(ctx context.Context, s *models.Scene, f *mode
|
|||
previewsFn := func(ctx context.Context) {
|
||||
options := getGeneratePreviewOptions(GeneratePreviewOptionsInput{})
|
||||
|
||||
g := &generate.Generator{
|
||||
Encoder: instance.FFMPEG,
|
||||
FFMpegConfig: instance.Config,
|
||||
LockManager: instance.ReadLockManager,
|
||||
MarkerPaths: instance.Paths.SceneMarkers,
|
||||
ScenePaths: instance.Paths.Scene,
|
||||
generator := &generate.Generator{
|
||||
Encoder: mgr.FFMPEG,
|
||||
FFMpegConfig: mgr.Config,
|
||||
LockManager: mgr.ReadLockManager,
|
||||
MarkerPaths: g.paths.SceneMarkers,
|
||||
ScenePaths: g.paths.Scene,
|
||||
Overwrite: overwrite,
|
||||
}
|
||||
|
||||
|
|
@ -535,14 +560,14 @@ func (g *sceneGenerators) Generate(ctx context.Context, s *models.Scene, f *mode
|
|||
ImagePreview: t.ScanGenerateImagePreviews,
|
||||
Options: options,
|
||||
Overwrite: overwrite,
|
||||
fileNamingAlgorithm: fileNamingAlgorithm,
|
||||
generator: g,
|
||||
fileNamingAlgorithm: g.fileNamingAlgorithm,
|
||||
generator: generator,
|
||||
}
|
||||
taskPreview.Start(ctx)
|
||||
progress.Increment()
|
||||
}
|
||||
|
||||
if sequentialScanning {
|
||||
if g.sequentialScanning {
|
||||
previewsFn(ctx)
|
||||
} else {
|
||||
g.taskQueue.Add(fmt.Sprintf("Generating preview for %s", path), previewsFn)
|
||||
|
|
@ -553,8 +578,8 @@ func (g *sceneGenerators) Generate(ctx context.Context, s *models.Scene, f *mode
|
|||
progress.AddTotal(1)
|
||||
g.taskQueue.Add(fmt.Sprintf("Generating cover for %s", path), func(ctx context.Context) {
|
||||
taskCover := GenerateCoverTask{
|
||||
repository: mgr.Repository,
|
||||
Scene: *s,
|
||||
txnManager: instance.Repository,
|
||||
Overwrite: overwrite,
|
||||
}
|
||||
taskCover.Start(ctx)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ import (
|
|||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/scraper/stashbox"
|
||||
"github.com/stashapp/stash/pkg/studio"
|
||||
"github.com/stashapp/stash/pkg/txn"
|
||||
)
|
||||
|
||||
type StashBoxTagTaskType int
|
||||
|
|
@ -92,18 +91,18 @@ func (t *StashBoxBatchTagTask) findStashBoxPerformer(ctx context.Context) (*mode
|
|||
var performer *models.ScrapedPerformer
|
||||
var err error
|
||||
|
||||
client := stashbox.NewClient(*t.box, instance.Repository, stashbox.Repository{
|
||||
Scene: instance.Repository.Scene,
|
||||
Performer: instance.Repository.Performer,
|
||||
Tag: instance.Repository.Tag,
|
||||
Studio: instance.Repository.Studio,
|
||||
})
|
||||
r := instance.Repository
|
||||
|
||||
stashboxRepository := stashbox.NewRepository(r)
|
||||
client := stashbox.NewClient(*t.box, stashboxRepository)
|
||||
|
||||
if t.refresh {
|
||||
var remoteID string
|
||||
if err := txn.WithReadTxn(ctx, instance.Repository, func(ctx context.Context) error {
|
||||
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
|
||||
qb := r.Performer
|
||||
|
||||
if !t.performer.StashIDs.Loaded() {
|
||||
err = t.performer.LoadStashIDs(ctx, instance.Repository.Performer)
|
||||
err = t.performer.LoadStashIDs(ctx, qb)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -145,8 +144,9 @@ func (t *StashBoxBatchTagTask) processMatchedPerformer(ctx context.Context, p *m
|
|||
}
|
||||
|
||||
// Start the transaction and update the performer
|
||||
err = txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error {
|
||||
qb := instance.Repository.Performer
|
||||
r := instance.Repository
|
||||
err = r.WithTxn(ctx, func(ctx context.Context) error {
|
||||
qb := r.Performer
|
||||
|
||||
existingStashIDs, err := qb.GetStashIDs(ctx, storedID)
|
||||
if err != nil {
|
||||
|
|
@ -181,8 +181,10 @@ func (t *StashBoxBatchTagTask) processMatchedPerformer(ctx context.Context, p *m
|
|||
return
|
||||
}
|
||||
|
||||
err = txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error {
|
||||
qb := instance.Repository.Performer
|
||||
r := instance.Repository
|
||||
err = r.WithTxn(ctx, func(ctx context.Context) error {
|
||||
qb := r.Performer
|
||||
|
||||
if err := qb.Create(ctx, newPerformer); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -233,18 +235,16 @@ func (t *StashBoxBatchTagTask) findStashBoxStudio(ctx context.Context) (*models.
|
|||
var studio *models.ScrapedStudio
|
||||
var err error
|
||||
|
||||
client := stashbox.NewClient(*t.box, instance.Repository, stashbox.Repository{
|
||||
Scene: instance.Repository.Scene,
|
||||
Performer: instance.Repository.Performer,
|
||||
Tag: instance.Repository.Tag,
|
||||
Studio: instance.Repository.Studio,
|
||||
})
|
||||
r := instance.Repository
|
||||
|
||||
stashboxRepository := stashbox.NewRepository(r)
|
||||
client := stashbox.NewClient(*t.box, stashboxRepository)
|
||||
|
||||
if t.refresh {
|
||||
var remoteID string
|
||||
if err := txn.WithReadTxn(ctx, instance.Repository, func(ctx context.Context) error {
|
||||
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
|
||||
if !t.studio.StashIDs.Loaded() {
|
||||
err = t.studio.LoadStashIDs(ctx, instance.Repository.Studio)
|
||||
err = t.studio.LoadStashIDs(ctx, r.Studio)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -293,8 +293,9 @@ func (t *StashBoxBatchTagTask) processMatchedStudio(ctx context.Context, s *mode
|
|||
}
|
||||
|
||||
// Start the transaction and update the studio
|
||||
err = txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error {
|
||||
qb := instance.Repository.Studio
|
||||
r := instance.Repository
|
||||
err = r.WithTxn(ctx, func(ctx context.Context) error {
|
||||
qb := r.Studio
|
||||
|
||||
existingStashIDs, err := qb.GetStashIDs(ctx, storedID)
|
||||
if err != nil {
|
||||
|
|
@ -341,8 +342,10 @@ func (t *StashBoxBatchTagTask) processMatchedStudio(ctx context.Context, s *mode
|
|||
}
|
||||
|
||||
// Start the transaction and save the studio
|
||||
err = txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error {
|
||||
qb := instance.Repository.Studio
|
||||
r := instance.Repository
|
||||
err = r.WithTxn(ctx, func(ctx context.Context) error {
|
||||
qb := r.Studio
|
||||
|
||||
if err := qb.Create(ctx, newStudio); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -375,8 +378,10 @@ func (t *StashBoxBatchTagTask) processParentStudio(ctx context.Context, parent *
|
|||
}
|
||||
|
||||
// Start the transaction and save the studio
|
||||
err = txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error {
|
||||
qb := instance.Repository.Studio
|
||||
r := instance.Repository
|
||||
err = r.WithTxn(ctx, func(ctx context.Context) error {
|
||||
qb := r.Studio
|
||||
|
||||
if err := qb.Create(ctx, newParentStudio); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -408,8 +413,9 @@ func (t *StashBoxBatchTagTask) processParentStudio(ctx context.Context, parent *
|
|||
}
|
||||
|
||||
// Start the transaction and update the studio
|
||||
err = txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error {
|
||||
qb := instance.Repository.Studio
|
||||
r := instance.Repository
|
||||
err = r.WithTxn(ctx, func(ctx context.Context) error {
|
||||
qb := r.Studio
|
||||
|
||||
existingStashIDs, err := qb.GetStashIDs(ctx, storedID)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -1,21 +1,62 @@
|
|||
package static
|
||||
|
||||
import "embed"
|
||||
import (
|
||||
"embed"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
)
|
||||
|
||||
//go:embed performer
|
||||
var Performer embed.FS
|
||||
//go:embed performer performer_male scene image tag studio movie
|
||||
var data embed.FS
|
||||
|
||||
//go:embed performer_male
|
||||
var PerformerMale embed.FS
|
||||
const (
|
||||
Performer = "performer"
|
||||
PerformerMale = "performer_male"
|
||||
|
||||
//go:embed scene
|
||||
var Scene embed.FS
|
||||
Scene = "scene"
|
||||
DefaultSceneImage = "scene/scene.svg"
|
||||
|
||||
//go:embed image
|
||||
var Image embed.FS
|
||||
Image = "image"
|
||||
DefaultImageImage = "image/image.svg"
|
||||
|
||||
//go:embed tag
|
||||
var Tag embed.FS
|
||||
Tag = "tag"
|
||||
DefaultTagImage = "tag/tag.svg"
|
||||
|
||||
//go:embed studio
|
||||
var Studio embed.FS
|
||||
Studio = "studio"
|
||||
DefaultStudioImage = "studio/studio.svg"
|
||||
|
||||
Movie = "movie"
|
||||
DefaultMovieImage = "movie/movie.png"
|
||||
)
|
||||
|
||||
// Sub returns an FS rooted at path, using fs.Sub.
|
||||
// It will panic if an error occurs.
|
||||
func Sub(path string) fs.FS {
|
||||
ret, err := fs.Sub(data, path)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("creating static SubFS: %v", err))
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// Open opens the file at path for reading.
|
||||
// It will panic if an error occurs.
|
||||
func Open(path string) fs.File {
|
||||
f, err := data.Open(path)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("opening static file: %v", err))
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
// ReadAll returns the contents of the file at path.
|
||||
// It will panic if an error occurs.
|
||||
func ReadAll(path string) []byte {
|
||||
f := Open(path)
|
||||
ret, err := io.ReadAll(f)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("reading static file: %v", err))
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
|
|
|||
BIN
internal/static/movie/movie.png
Normal file
BIN
internal/static/movie/movie.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 405 B |
|
|
@ -11,7 +11,6 @@ import (
|
|||
"github.com/stashapp/stash/pkg/job"
|
||||
"github.com/stashapp/stash/pkg/logger"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/txn"
|
||||
)
|
||||
|
||||
// Cleaner scans through stored file and folder instances and removes those that are no longer present on disk.
|
||||
|
|
@ -112,14 +111,15 @@ func (j *cleanJob) execute(ctx context.Context) error {
|
|||
folderCount int
|
||||
)
|
||||
|
||||
if err := txn.WithReadTxn(ctx, j.Repository, func(ctx context.Context) error {
|
||||
r := j.Repository
|
||||
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
|
||||
var err error
|
||||
fileCount, err = j.Repository.FileStore.CountAllInPaths(ctx, j.options.Paths)
|
||||
fileCount, err = r.File.CountAllInPaths(ctx, j.options.Paths)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
folderCount, err = j.Repository.FolderStore.CountAllInPaths(ctx, j.options.Paths)
|
||||
folderCount, err = r.Folder.CountAllInPaths(ctx, j.options.Paths)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -172,13 +172,14 @@ func (j *cleanJob) assessFiles(ctx context.Context, toDelete *deleteSet) error {
|
|||
progress := j.progress
|
||||
|
||||
more := true
|
||||
if err := txn.WithReadTxn(ctx, j.Repository, func(ctx context.Context) error {
|
||||
r := j.Repository
|
||||
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
|
||||
for more {
|
||||
if job.IsCancelled(ctx) {
|
||||
return nil
|
||||
}
|
||||
|
||||
files, err := j.Repository.FileStore.FindAllInPaths(ctx, j.options.Paths, batchSize, offset)
|
||||
files, err := r.File.FindAllInPaths(ctx, j.options.Paths, batchSize, offset)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error querying for files: %w", err)
|
||||
}
|
||||
|
|
@ -223,8 +224,9 @@ func (j *cleanJob) assessFiles(ctx context.Context, toDelete *deleteSet) error {
|
|||
|
||||
// flagFolderForDelete adds folders to the toDelete set, with the leaf folders added first
|
||||
func (j *cleanJob) flagFileForDelete(ctx context.Context, toDelete *deleteSet, f models.File) error {
|
||||
r := j.Repository
|
||||
// add contained files first
|
||||
containedFiles, err := j.Repository.FileStore.FindByZipFileID(ctx, f.Base().ID)
|
||||
containedFiles, err := r.File.FindByZipFileID(ctx, f.Base().ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error finding contained files for %q: %w", f.Base().Path, err)
|
||||
}
|
||||
|
|
@ -235,7 +237,7 @@ func (j *cleanJob) flagFileForDelete(ctx context.Context, toDelete *deleteSet, f
|
|||
}
|
||||
|
||||
// add contained folders as well
|
||||
containedFolders, err := j.Repository.FolderStore.FindByZipFileID(ctx, f.Base().ID)
|
||||
containedFolders, err := r.Folder.FindByZipFileID(ctx, f.Base().ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error finding contained folders for %q: %w", f.Base().Path, err)
|
||||
}
|
||||
|
|
@ -256,13 +258,14 @@ func (j *cleanJob) assessFolders(ctx context.Context, toDelete *deleteSet) error
|
|||
progress := j.progress
|
||||
|
||||
more := true
|
||||
if err := txn.WithReadTxn(ctx, j.Repository, func(ctx context.Context) error {
|
||||
r := j.Repository
|
||||
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
|
||||
for more {
|
||||
if job.IsCancelled(ctx) {
|
||||
return nil
|
||||
}
|
||||
|
||||
folders, err := j.Repository.FolderStore.FindAllInPaths(ctx, j.options.Paths, batchSize, offset)
|
||||
folders, err := r.Folder.FindAllInPaths(ctx, j.options.Paths, batchSize, offset)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error querying for folders: %w", err)
|
||||
}
|
||||
|
|
@ -380,14 +383,15 @@ func (j *cleanJob) shouldCleanFolder(ctx context.Context, f *models.Folder) bool
|
|||
func (j *cleanJob) deleteFile(ctx context.Context, fileID models.FileID, fn string) {
|
||||
// delete associated objects
|
||||
fileDeleter := NewDeleter()
|
||||
if err := txn.WithTxn(ctx, j.Repository, func(ctx context.Context) error {
|
||||
r := j.Repository
|
||||
if err := r.WithTxn(ctx, func(ctx context.Context) error {
|
||||
fileDeleter.RegisterHooks(ctx)
|
||||
|
||||
if err := j.fireHandlers(ctx, fileDeleter, fileID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return j.Repository.FileStore.Destroy(ctx, fileID)
|
||||
return r.File.Destroy(ctx, fileID)
|
||||
}); err != nil {
|
||||
logger.Errorf("Error deleting file %q from database: %s", fn, err.Error())
|
||||
return
|
||||
|
|
@ -397,14 +401,15 @@ func (j *cleanJob) deleteFile(ctx context.Context, fileID models.FileID, fn stri
|
|||
func (j *cleanJob) deleteFolder(ctx context.Context, folderID models.FolderID, fn string) {
|
||||
// delete associated objects
|
||||
fileDeleter := NewDeleter()
|
||||
if err := txn.WithTxn(ctx, j.Repository, func(ctx context.Context) error {
|
||||
r := j.Repository
|
||||
if err := r.WithTxn(ctx, func(ctx context.Context) error {
|
||||
fileDeleter.RegisterHooks(ctx)
|
||||
|
||||
if err := j.fireFolderHandlers(ctx, fileDeleter, folderID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return j.Repository.FolderStore.Destroy(ctx, folderID)
|
||||
return r.Folder.Destroy(ctx, folderID)
|
||||
}); err != nil {
|
||||
logger.Errorf("Error deleting folder %q from database: %s", fn, err.Error())
|
||||
return
|
||||
|
|
|
|||
|
|
@ -1,15 +1,36 @@
|
|||
package file
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/txn"
|
||||
)
|
||||
|
||||
// Repository provides access to storage methods for files and folders.
|
||||
type Repository struct {
|
||||
txn.Manager
|
||||
txn.DatabaseProvider
|
||||
TxnManager models.TxnManager
|
||||
|
||||
FileStore models.FileReaderWriter
|
||||
FolderStore models.FolderReaderWriter
|
||||
File models.FileReaderWriter
|
||||
Folder models.FolderReaderWriter
|
||||
}
|
||||
|
||||
func NewRepository(repo models.Repository) Repository {
|
||||
return Repository{
|
||||
TxnManager: repo.TxnManager,
|
||||
File: repo.File,
|
||||
Folder: repo.Folder,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Repository) WithTxn(ctx context.Context, fn txn.TxnFunc) error {
|
||||
return txn.WithTxn(ctx, r.TxnManager, fn)
|
||||
}
|
||||
|
||||
func (r *Repository) WithReadTxn(ctx context.Context, fn txn.TxnFunc) error {
|
||||
return txn.WithReadTxn(ctx, r.TxnManager, fn)
|
||||
}
|
||||
|
||||
func (r *Repository) WithDB(ctx context.Context, fn txn.TxnFunc) error {
|
||||
return txn.WithDatabase(ctx, r.TxnManager, fn)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -86,6 +86,8 @@ func (s *scanJob) detectFolderMove(ctx context.Context, file scanFile) (*models.
|
|||
}
|
||||
// rejects is a set of folder ids which were found to still exist
|
||||
|
||||
r := s.Repository
|
||||
|
||||
if err := symWalk(file.fs, file.Path, func(path string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
// don't let errors prevent scanning
|
||||
|
|
@ -118,7 +120,7 @@ func (s *scanJob) detectFolderMove(ctx context.Context, file scanFile) (*models.
|
|||
}
|
||||
|
||||
// check if the file exists in the database based on basename, size and mod time
|
||||
existing, err := s.Repository.FileStore.FindByFileInfo(ctx, info, size)
|
||||
existing, err := r.File.FindByFileInfo(ctx, info, size)
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking for existing file %q: %w", path, err)
|
||||
}
|
||||
|
|
@ -140,7 +142,7 @@ func (s *scanJob) detectFolderMove(ctx context.Context, file scanFile) (*models.
|
|||
|
||||
if c == nil {
|
||||
// need to check if the folder exists in the filesystem
|
||||
pf, err := s.Repository.FolderStore.Find(ctx, e.Base().ParentFolderID)
|
||||
pf, err := r.Folder.Find(ctx, e.Base().ParentFolderID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting parent folder %d: %w", e.Base().ParentFolderID, err)
|
||||
}
|
||||
|
|
@ -164,7 +166,7 @@ func (s *scanJob) detectFolderMove(ctx context.Context, file scanFile) (*models.
|
|||
|
||||
// parent folder is missing, possible candidate
|
||||
// count the total number of files in the existing folder
|
||||
count, err := s.Repository.FileStore.CountByFolderID(ctx, parentFolderID)
|
||||
count, err := r.File.CountByFolderID(ctx, parentFolderID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("counting files in folder %d: %w", parentFolderID, err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -181,7 +181,7 @@ func (m *Mover) moveFile(oldPath, newPath string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (m *Mover) RegisterHooks(ctx context.Context, mgr txn.Manager) {
|
||||
func (m *Mover) RegisterHooks(ctx context.Context) {
|
||||
txn.AddPostCommitHook(ctx, func(ctx context.Context) {
|
||||
m.commit()
|
||||
})
|
||||
|
|
|
|||
|
|
@ -144,7 +144,7 @@ func (s *Scanner) Scan(ctx context.Context, handlers []Handler, options ScanOpti
|
|||
ProgressReports: progressReporter,
|
||||
options: options,
|
||||
txnRetryer: txn.Retryer{
|
||||
Manager: s.Repository,
|
||||
Manager: s.Repository.TxnManager,
|
||||
Retries: maxRetries,
|
||||
},
|
||||
}
|
||||
|
|
@ -163,7 +163,7 @@ func (s *scanJob) withTxn(ctx context.Context, fn func(ctx context.Context) erro
|
|||
}
|
||||
|
||||
func (s *scanJob) withDB(ctx context.Context, fn func(ctx context.Context) error) error {
|
||||
return txn.WithDatabase(ctx, s.Repository, fn)
|
||||
return s.Repository.WithDB(ctx, fn)
|
||||
}
|
||||
|
||||
func (s *scanJob) execute(ctx context.Context) {
|
||||
|
|
@ -439,7 +439,7 @@ func (s *scanJob) getFolderID(ctx context.Context, path string) (*models.FolderI
|
|||
return &v, nil
|
||||
}
|
||||
|
||||
ret, err := s.Repository.FolderStore.FindByPath(ctx, path)
|
||||
ret, err := s.Repository.Folder.FindByPath(ctx, path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -469,7 +469,7 @@ func (s *scanJob) getZipFileID(ctx context.Context, zipFile *scanFile) (*models.
|
|||
return &v, nil
|
||||
}
|
||||
|
||||
ret, err := s.Repository.FileStore.FindByPath(ctx, path)
|
||||
ret, err := s.Repository.File.FindByPath(ctx, path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting zip file ID for %q: %w", path, err)
|
||||
}
|
||||
|
|
@ -489,7 +489,7 @@ func (s *scanJob) handleFolder(ctx context.Context, file scanFile) error {
|
|||
defer s.incrementProgress(file)
|
||||
|
||||
// determine if folder already exists in data store (by path)
|
||||
f, err := s.Repository.FolderStore.FindByPath(ctx, path)
|
||||
f, err := s.Repository.Folder.FindByPath(ctx, path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking for existing folder %q: %w", path, err)
|
||||
}
|
||||
|
|
@ -553,7 +553,7 @@ func (s *scanJob) onNewFolder(ctx context.Context, file scanFile) (*models.Folde
|
|||
logger.Infof("%s doesn't exist. Creating new folder entry...", file.Path)
|
||||
})
|
||||
|
||||
if err := s.Repository.FolderStore.Create(ctx, toCreate); err != nil {
|
||||
if err := s.Repository.Folder.Create(ctx, toCreate); err != nil {
|
||||
return nil, fmt.Errorf("creating folder %q: %w", file.Path, err)
|
||||
}
|
||||
|
||||
|
|
@ -589,12 +589,12 @@ func (s *scanJob) handleFolderRename(ctx context.Context, file scanFile) (*model
|
|||
|
||||
renamedFrom.ParentFolderID = parentFolderID
|
||||
|
||||
if err := s.Repository.FolderStore.Update(ctx, renamedFrom); err != nil {
|
||||
if err := s.Repository.Folder.Update(ctx, renamedFrom); err != nil {
|
||||
return nil, fmt.Errorf("updating folder for rename %q: %w", renamedFrom.Path, err)
|
||||
}
|
||||
|
||||
// #4146 - correct sub-folders to have the correct path
|
||||
if err := correctSubFolderHierarchy(ctx, s.Repository.FolderStore, renamedFrom); err != nil {
|
||||
if err := correctSubFolderHierarchy(ctx, s.Repository.Folder, renamedFrom); err != nil {
|
||||
return nil, fmt.Errorf("correcting sub folder hierarchy for %q: %w", renamedFrom.Path, err)
|
||||
}
|
||||
|
||||
|
|
@ -626,7 +626,7 @@ func (s *scanJob) onExistingFolder(ctx context.Context, f scanFile, existing *mo
|
|||
|
||||
if update {
|
||||
var err error
|
||||
if err = s.Repository.FolderStore.Update(ctx, existing); err != nil {
|
||||
if err = s.Repository.Folder.Update(ctx, existing); err != nil {
|
||||
return nil, fmt.Errorf("updating folder %q: %w", f.Path, err)
|
||||
}
|
||||
}
|
||||
|
|
@ -647,7 +647,7 @@ func (s *scanJob) handleFile(ctx context.Context, f scanFile) error {
|
|||
if err := s.withDB(ctx, func(ctx context.Context) error {
|
||||
// determine if file already exists in data store
|
||||
var err error
|
||||
ff, err = s.Repository.FileStore.FindByPath(ctx, f.Path)
|
||||
ff, err = s.Repository.File.FindByPath(ctx, f.Path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking for existing file %q: %w", f.Path, err)
|
||||
}
|
||||
|
|
@ -745,7 +745,7 @@ func (s *scanJob) onNewFile(ctx context.Context, f scanFile) (models.File, error
|
|||
|
||||
// if not renamed, queue file for creation
|
||||
if err := s.withTxn(ctx, func(ctx context.Context) error {
|
||||
if err := s.Repository.FileStore.Create(ctx, file); err != nil {
|
||||
if err := s.Repository.File.Create(ctx, file); err != nil {
|
||||
return fmt.Errorf("creating file %q: %w", path, err)
|
||||
}
|
||||
|
||||
|
|
@ -838,7 +838,7 @@ func (s *scanJob) handleRename(ctx context.Context, f models.File, fp []models.F
|
|||
var others []models.File
|
||||
|
||||
for _, tfp := range fp {
|
||||
thisOthers, err := s.Repository.FileStore.FindByFingerprint(ctx, tfp)
|
||||
thisOthers, err := s.Repository.File.FindByFingerprint(ctx, tfp)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting files by fingerprint %v: %w", tfp, err)
|
||||
}
|
||||
|
|
@ -896,12 +896,12 @@ func (s *scanJob) handleRename(ctx context.Context, f models.File, fp []models.F
|
|||
fBase.Fingerprints = otherBase.Fingerprints
|
||||
|
||||
if err := s.withTxn(ctx, func(ctx context.Context) error {
|
||||
if err := s.Repository.FileStore.Update(ctx, f); err != nil {
|
||||
if err := s.Repository.File.Update(ctx, f); err != nil {
|
||||
return fmt.Errorf("updating file for rename %q: %w", fBase.Path, err)
|
||||
}
|
||||
|
||||
if s.isZipFile(fBase.Basename) {
|
||||
if err := TransferZipFolderHierarchy(ctx, s.Repository.FolderStore, fBase.ID, otherBase.Path, fBase.Path); err != nil {
|
||||
if err := TransferZipFolderHierarchy(ctx, s.Repository.Folder, fBase.ID, otherBase.Path, fBase.Path); err != nil {
|
||||
return fmt.Errorf("moving folder hierarchy for renamed zip file %q: %w", fBase.Path, err)
|
||||
}
|
||||
}
|
||||
|
|
@ -963,7 +963,7 @@ func (s *scanJob) setMissingMetadata(ctx context.Context, f scanFile, existing m
|
|||
|
||||
// queue file for update
|
||||
if err := s.withTxn(ctx, func(ctx context.Context) error {
|
||||
if err := s.Repository.FileStore.Update(ctx, existing); err != nil {
|
||||
if err := s.Repository.File.Update(ctx, existing); err != nil {
|
||||
return fmt.Errorf("updating file %q: %w", path, err)
|
||||
}
|
||||
|
||||
|
|
@ -986,7 +986,7 @@ func (s *scanJob) setMissingFingerprints(ctx context.Context, f scanFile, existi
|
|||
existing.SetFingerprints(fp)
|
||||
|
||||
if err := s.withTxn(ctx, func(ctx context.Context) error {
|
||||
if err := s.Repository.FileStore.Update(ctx, existing); err != nil {
|
||||
if err := s.Repository.File.Update(ctx, existing); err != nil {
|
||||
return fmt.Errorf("updating file %q: %w", f.Path, err)
|
||||
}
|
||||
|
||||
|
|
@ -1035,7 +1035,7 @@ func (s *scanJob) onExistingFile(ctx context.Context, f scanFile, existing model
|
|||
|
||||
// queue file for update
|
||||
if err := s.withTxn(ctx, func(ctx context.Context) error {
|
||||
if err := s.Repository.FileStore.Update(ctx, existing); err != nil {
|
||||
if err := s.Repository.File.Update(ctx, existing); err != nil {
|
||||
return fmt.Errorf("updating file %q: %w", path, err)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -157,19 +157,19 @@ var getStudioScenarios = []stringTestScenario{
|
|||
}
|
||||
|
||||
func TestGetStudioName(t *testing.T) {
|
||||
mockStudioReader := &mocks.StudioReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
studioErr := errors.New("error getting image")
|
||||
|
||||
mockStudioReader.On("Find", testCtx, studioID).Return(&models.Studio{
|
||||
db.Studio.On("Find", testCtx, studioID).Return(&models.Studio{
|
||||
Name: studioName,
|
||||
}, nil).Once()
|
||||
mockStudioReader.On("Find", testCtx, missingStudioID).Return(nil, nil).Once()
|
||||
mockStudioReader.On("Find", testCtx, errStudioID).Return(nil, studioErr).Once()
|
||||
db.Studio.On("Find", testCtx, missingStudioID).Return(nil, nil).Once()
|
||||
db.Studio.On("Find", testCtx, errStudioID).Return(nil, studioErr).Once()
|
||||
|
||||
for i, s := range getStudioScenarios {
|
||||
gallery := s.input
|
||||
json, err := GetStudioName(testCtx, mockStudioReader, &gallery)
|
||||
json, err := GetStudioName(testCtx, db.Studio, &gallery)
|
||||
|
||||
switch {
|
||||
case !s.err && err != nil:
|
||||
|
|
@ -181,7 +181,7 @@ func TestGetStudioName(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
mockStudioReader.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
const (
|
||||
|
|
@ -258,17 +258,17 @@ var validChapters = []*models.GalleryChapter{
|
|||
}
|
||||
|
||||
func TestGetGalleryChaptersJSON(t *testing.T) {
|
||||
mockChapterReader := &mocks.GalleryChapterReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
chaptersErr := errors.New("error getting gallery chapters")
|
||||
|
||||
mockChapterReader.On("FindByGalleryID", testCtx, galleryID).Return(validChapters, nil).Once()
|
||||
mockChapterReader.On("FindByGalleryID", testCtx, noChaptersID).Return(nil, nil).Once()
|
||||
mockChapterReader.On("FindByGalleryID", testCtx, errChaptersID).Return(nil, chaptersErr).Once()
|
||||
db.GalleryChapter.On("FindByGalleryID", testCtx, galleryID).Return(validChapters, nil).Once()
|
||||
db.GalleryChapter.On("FindByGalleryID", testCtx, noChaptersID).Return(nil, nil).Once()
|
||||
db.GalleryChapter.On("FindByGalleryID", testCtx, errChaptersID).Return(nil, chaptersErr).Once()
|
||||
|
||||
for i, s := range getGalleryChaptersJSONScenarios {
|
||||
gallery := s.input
|
||||
json, err := GetGalleryChaptersJSON(testCtx, mockChapterReader, &gallery)
|
||||
json, err := GetGalleryChaptersJSON(testCtx, db.GalleryChapter, &gallery)
|
||||
|
||||
switch {
|
||||
case !s.err && err != nil:
|
||||
|
|
@ -280,4 +280,5 @@ func TestGetGalleryChaptersJSON(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -78,19 +78,19 @@ func TestImporterPreImport(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestImporterPreImportWithStudio(t *testing.T) {
|
||||
studioReaderWriter := &mocks.StudioReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
StudioWriter: studioReaderWriter,
|
||||
StudioWriter: db.Studio,
|
||||
Input: jsonschema.Gallery{
|
||||
Studio: existingStudioName,
|
||||
},
|
||||
}
|
||||
|
||||
studioReaderWriter.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{
|
||||
db.Studio.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{
|
||||
ID: existingStudioID,
|
||||
}, nil).Once()
|
||||
studioReaderWriter.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
|
||||
db.Studio.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
assert.Nil(t, err)
|
||||
|
|
@ -100,22 +100,22 @@ func TestImporterPreImportWithStudio(t *testing.T) {
|
|||
err = i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
studioReaderWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterPreImportWithMissingStudio(t *testing.T) {
|
||||
studioReaderWriter := &mocks.StudioReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
StudioWriter: studioReaderWriter,
|
||||
StudioWriter: db.Studio,
|
||||
Input: jsonschema.Gallery{
|
||||
Studio: missingStudioName,
|
||||
},
|
||||
MissingRefBehaviour: models.ImportMissingRefEnumFail,
|
||||
}
|
||||
|
||||
studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
|
||||
studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
|
||||
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
|
||||
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
|
||||
s := args.Get(1).(*models.Studio)
|
||||
s.ID = existingStudioID
|
||||
}).Return(nil)
|
||||
|
|
@ -132,32 +132,34 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) {
|
|||
assert.Nil(t, err)
|
||||
assert.Equal(t, existingStudioID, *i.gallery.StudioID)
|
||||
|
||||
studioReaderWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
|
||||
studioReaderWriter := &mocks.StudioReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
StudioWriter: studioReaderWriter,
|
||||
StudioWriter: db.Studio,
|
||||
Input: jsonschema.Gallery{
|
||||
Studio: missingStudioName,
|
||||
},
|
||||
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
|
||||
}
|
||||
|
||||
studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
|
||||
studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
|
||||
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
|
||||
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterPreImportWithPerformer(t *testing.T) {
|
||||
performerReaderWriter := &mocks.PerformerReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
PerformerWriter: performerReaderWriter,
|
||||
PerformerWriter: db.Performer,
|
||||
MissingRefBehaviour: models.ImportMissingRefEnumFail,
|
||||
Input: jsonschema.Gallery{
|
||||
Performers: []string{
|
||||
|
|
@ -166,13 +168,13 @@ func TestImporterPreImportWithPerformer(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{
|
||||
db.Performer.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{
|
||||
{
|
||||
ID: existingPerformerID,
|
||||
Name: existingPerformerName,
|
||||
},
|
||||
}, nil).Once()
|
||||
performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once()
|
||||
db.Performer.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once()
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
assert.Nil(t, err)
|
||||
|
|
@ -182,14 +184,14 @@ func TestImporterPreImportWithPerformer(t *testing.T) {
|
|||
err = i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
performerReaderWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterPreImportWithMissingPerformer(t *testing.T) {
|
||||
performerReaderWriter := &mocks.PerformerReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
PerformerWriter: performerReaderWriter,
|
||||
PerformerWriter: db.Performer,
|
||||
Input: jsonschema.Gallery{
|
||||
Performers: []string{
|
||||
missingPerformerName,
|
||||
|
|
@ -198,8 +200,8 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) {
|
|||
MissingRefBehaviour: models.ImportMissingRefEnumFail,
|
||||
}
|
||||
|
||||
performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3)
|
||||
performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) {
|
||||
db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3)
|
||||
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) {
|
||||
performer := args.Get(1).(*models.Performer)
|
||||
performer.ID = existingPerformerID
|
||||
}).Return(nil)
|
||||
|
|
@ -216,14 +218,14 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) {
|
|||
assert.Nil(t, err)
|
||||
assert.Equal(t, []int{existingPerformerID}, i.gallery.PerformerIDs.List())
|
||||
|
||||
performerReaderWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) {
|
||||
performerReaderWriter := &mocks.PerformerReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
PerformerWriter: performerReaderWriter,
|
||||
PerformerWriter: db.Performer,
|
||||
Input: jsonschema.Gallery{
|
||||
Performers: []string{
|
||||
missingPerformerName,
|
||||
|
|
@ -232,18 +234,20 @@ func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) {
|
|||
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
|
||||
}
|
||||
|
||||
performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once()
|
||||
performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error"))
|
||||
db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once()
|
||||
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error"))
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterPreImportWithTag(t *testing.T) {
|
||||
tagReaderWriter := &mocks.TagReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
TagWriter: tagReaderWriter,
|
||||
TagWriter: db.Tag,
|
||||
MissingRefBehaviour: models.ImportMissingRefEnumFail,
|
||||
Input: jsonschema.Gallery{
|
||||
Tags: []string{
|
||||
|
|
@ -252,13 +256,13 @@ func TestImporterPreImportWithTag(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
tagReaderWriter.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{
|
||||
db.Tag.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{
|
||||
{
|
||||
ID: existingTagID,
|
||||
Name: existingTagName,
|
||||
},
|
||||
}, nil).Once()
|
||||
tagReaderWriter.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
|
||||
db.Tag.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
assert.Nil(t, err)
|
||||
|
|
@ -268,14 +272,14 @@ func TestImporterPreImportWithTag(t *testing.T) {
|
|||
err = i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
tagReaderWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterPreImportWithMissingTag(t *testing.T) {
|
||||
tagReaderWriter := &mocks.TagReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
TagWriter: tagReaderWriter,
|
||||
TagWriter: db.Tag,
|
||||
Input: jsonschema.Gallery{
|
||||
Tags: []string{
|
||||
missingTagName,
|
||||
|
|
@ -284,8 +288,8 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
|
|||
MissingRefBehaviour: models.ImportMissingRefEnumFail,
|
||||
}
|
||||
|
||||
tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
|
||||
tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
|
||||
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
|
||||
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
|
||||
t := args.Get(1).(*models.Tag)
|
||||
t.ID = existingTagID
|
||||
}).Return(nil)
|
||||
|
|
@ -302,14 +306,14 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
|
|||
assert.Nil(t, err)
|
||||
assert.Equal(t, []int{existingTagID}, i.gallery.TagIDs.List())
|
||||
|
||||
tagReaderWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
|
||||
tagReaderWriter := &mocks.TagReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
TagWriter: tagReaderWriter,
|
||||
TagWriter: db.Tag,
|
||||
Input: jsonschema.Gallery{
|
||||
Tags: []string{
|
||||
missingTagName,
|
||||
|
|
@ -318,9 +322,11 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
|
|||
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
|
||||
}
|
||||
|
||||
tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
|
||||
tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error"))
|
||||
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
|
||||
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error"))
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -130,19 +130,19 @@ var getStudioScenarios = []stringTestScenario{
|
|||
}
|
||||
|
||||
func TestGetStudioName(t *testing.T) {
|
||||
mockStudioReader := &mocks.StudioReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
studioErr := errors.New("error getting image")
|
||||
|
||||
mockStudioReader.On("Find", testCtx, studioID).Return(&models.Studio{
|
||||
db.Studio.On("Find", testCtx, studioID).Return(&models.Studio{
|
||||
Name: studioName,
|
||||
}, nil).Once()
|
||||
mockStudioReader.On("Find", testCtx, missingStudioID).Return(nil, nil).Once()
|
||||
mockStudioReader.On("Find", testCtx, errStudioID).Return(nil, studioErr).Once()
|
||||
db.Studio.On("Find", testCtx, missingStudioID).Return(nil, nil).Once()
|
||||
db.Studio.On("Find", testCtx, errStudioID).Return(nil, studioErr).Once()
|
||||
|
||||
for i, s := range getStudioScenarios {
|
||||
image := s.input
|
||||
json, err := GetStudioName(testCtx, mockStudioReader, &image)
|
||||
json, err := GetStudioName(testCtx, db.Studio, &image)
|
||||
|
||||
switch {
|
||||
case !s.err && err != nil:
|
||||
|
|
@ -154,5 +154,5 @@ func TestGetStudioName(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
mockStudioReader.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -40,19 +40,19 @@ func TestImporterPreImport(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestImporterPreImportWithStudio(t *testing.T) {
|
||||
studioReaderWriter := &mocks.StudioReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
StudioWriter: studioReaderWriter,
|
||||
StudioWriter: db.Studio,
|
||||
Input: jsonschema.Image{
|
||||
Studio: existingStudioName,
|
||||
},
|
||||
}
|
||||
|
||||
studioReaderWriter.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{
|
||||
db.Studio.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{
|
||||
ID: existingStudioID,
|
||||
}, nil).Once()
|
||||
studioReaderWriter.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
|
||||
db.Studio.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
assert.Nil(t, err)
|
||||
|
|
@ -62,22 +62,22 @@ func TestImporterPreImportWithStudio(t *testing.T) {
|
|||
err = i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
studioReaderWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterPreImportWithMissingStudio(t *testing.T) {
|
||||
studioReaderWriter := &mocks.StudioReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
StudioWriter: studioReaderWriter,
|
||||
StudioWriter: db.Studio,
|
||||
Input: jsonschema.Image{
|
||||
Studio: missingStudioName,
|
||||
},
|
||||
MissingRefBehaviour: models.ImportMissingRefEnumFail,
|
||||
}
|
||||
|
||||
studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
|
||||
studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
|
||||
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
|
||||
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
|
||||
s := args.Get(1).(*models.Studio)
|
||||
s.ID = existingStudioID
|
||||
}).Return(nil)
|
||||
|
|
@ -94,32 +94,34 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) {
|
|||
assert.Nil(t, err)
|
||||
assert.Equal(t, existingStudioID, *i.image.StudioID)
|
||||
|
||||
studioReaderWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
|
||||
studioReaderWriter := &mocks.StudioReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
StudioWriter: studioReaderWriter,
|
||||
StudioWriter: db.Studio,
|
||||
Input: jsonschema.Image{
|
||||
Studio: missingStudioName,
|
||||
},
|
||||
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
|
||||
}
|
||||
|
||||
studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
|
||||
studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
|
||||
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
|
||||
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterPreImportWithPerformer(t *testing.T) {
|
||||
performerReaderWriter := &mocks.PerformerReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
PerformerWriter: performerReaderWriter,
|
||||
PerformerWriter: db.Performer,
|
||||
MissingRefBehaviour: models.ImportMissingRefEnumFail,
|
||||
Input: jsonschema.Image{
|
||||
Performers: []string{
|
||||
|
|
@ -128,13 +130,13 @@ func TestImporterPreImportWithPerformer(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{
|
||||
db.Performer.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{
|
||||
{
|
||||
ID: existingPerformerID,
|
||||
Name: existingPerformerName,
|
||||
},
|
||||
}, nil).Once()
|
||||
performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once()
|
||||
db.Performer.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once()
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
assert.Nil(t, err)
|
||||
|
|
@ -144,14 +146,14 @@ func TestImporterPreImportWithPerformer(t *testing.T) {
|
|||
err = i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
performerReaderWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterPreImportWithMissingPerformer(t *testing.T) {
|
||||
performerReaderWriter := &mocks.PerformerReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
PerformerWriter: performerReaderWriter,
|
||||
PerformerWriter: db.Performer,
|
||||
Input: jsonschema.Image{
|
||||
Performers: []string{
|
||||
missingPerformerName,
|
||||
|
|
@ -160,8 +162,8 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) {
|
|||
MissingRefBehaviour: models.ImportMissingRefEnumFail,
|
||||
}
|
||||
|
||||
performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3)
|
||||
performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) {
|
||||
db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3)
|
||||
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) {
|
||||
performer := args.Get(1).(*models.Performer)
|
||||
performer.ID = existingPerformerID
|
||||
}).Return(nil)
|
||||
|
|
@ -178,14 +180,14 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) {
|
|||
assert.Nil(t, err)
|
||||
assert.Equal(t, []int{existingPerformerID}, i.image.PerformerIDs.List())
|
||||
|
||||
performerReaderWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) {
|
||||
performerReaderWriter := &mocks.PerformerReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
PerformerWriter: performerReaderWriter,
|
||||
PerformerWriter: db.Performer,
|
||||
Input: jsonschema.Image{
|
||||
Performers: []string{
|
||||
missingPerformerName,
|
||||
|
|
@ -194,18 +196,20 @@ func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) {
|
|||
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
|
||||
}
|
||||
|
||||
performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once()
|
||||
performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error"))
|
||||
db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once()
|
||||
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error"))
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterPreImportWithTag(t *testing.T) {
|
||||
tagReaderWriter := &mocks.TagReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
TagWriter: tagReaderWriter,
|
||||
TagWriter: db.Tag,
|
||||
MissingRefBehaviour: models.ImportMissingRefEnumFail,
|
||||
Input: jsonschema.Image{
|
||||
Tags: []string{
|
||||
|
|
@ -214,13 +218,13 @@ func TestImporterPreImportWithTag(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
tagReaderWriter.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{
|
||||
db.Tag.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{
|
||||
{
|
||||
ID: existingTagID,
|
||||
Name: existingTagName,
|
||||
},
|
||||
}, nil).Once()
|
||||
tagReaderWriter.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
|
||||
db.Tag.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
assert.Nil(t, err)
|
||||
|
|
@ -230,14 +234,14 @@ func TestImporterPreImportWithTag(t *testing.T) {
|
|||
err = i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
tagReaderWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterPreImportWithMissingTag(t *testing.T) {
|
||||
tagReaderWriter := &mocks.TagReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
TagWriter: tagReaderWriter,
|
||||
TagWriter: db.Tag,
|
||||
Input: jsonschema.Image{
|
||||
Tags: []string{
|
||||
missingTagName,
|
||||
|
|
@ -246,8 +250,8 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
|
|||
MissingRefBehaviour: models.ImportMissingRefEnumFail,
|
||||
}
|
||||
|
||||
tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
|
||||
tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
|
||||
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
|
||||
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
|
||||
t := args.Get(1).(*models.Tag)
|
||||
t.ID = existingTagID
|
||||
}).Return(nil)
|
||||
|
|
@ -264,14 +268,14 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
|
|||
assert.Nil(t, err)
|
||||
assert.Equal(t, []int{existingTagID}, i.image.TagIDs.List())
|
||||
|
||||
tagReaderWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
|
||||
tagReaderWriter := &mocks.TagReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
TagWriter: tagReaderWriter,
|
||||
TagWriter: db.Tag,
|
||||
Input: jsonschema.Image{
|
||||
Tags: []string{
|
||||
missingTagName,
|
||||
|
|
@ -280,9 +284,11 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
|
|||
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
|
||||
}
|
||||
|
||||
tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
|
||||
tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error"))
|
||||
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
|
||||
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error"))
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
|
|
|||
107
pkg/models/mocks/database.go
Normal file
107
pkg/models/mocks/database.go
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
package mocks
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/txn"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
type Database struct {
|
||||
File *FileReaderWriter
|
||||
Folder *FolderReaderWriter
|
||||
Gallery *GalleryReaderWriter
|
||||
GalleryChapter *GalleryChapterReaderWriter
|
||||
Image *ImageReaderWriter
|
||||
Movie *MovieReaderWriter
|
||||
Performer *PerformerReaderWriter
|
||||
Scene *SceneReaderWriter
|
||||
SceneMarker *SceneMarkerReaderWriter
|
||||
Studio *StudioReaderWriter
|
||||
Tag *TagReaderWriter
|
||||
SavedFilter *SavedFilterReaderWriter
|
||||
}
|
||||
|
||||
func (*Database) Begin(ctx context.Context, exclusive bool) (context.Context, error) {
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
func (*Database) WithDatabase(ctx context.Context) (context.Context, error) {
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
func (*Database) Commit(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*Database) Rollback(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*Database) Complete(ctx context.Context) {
|
||||
}
|
||||
|
||||
func (*Database) AddPostCommitHook(ctx context.Context, hook txn.TxnFunc) {
|
||||
}
|
||||
|
||||
func (*Database) AddPostRollbackHook(ctx context.Context, hook txn.TxnFunc) {
|
||||
}
|
||||
|
||||
func (*Database) IsLocked(err error) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (*Database) Reset() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewDatabase() *Database {
|
||||
return &Database{
|
||||
File: &FileReaderWriter{},
|
||||
Folder: &FolderReaderWriter{},
|
||||
Gallery: &GalleryReaderWriter{},
|
||||
GalleryChapter: &GalleryChapterReaderWriter{},
|
||||
Image: &ImageReaderWriter{},
|
||||
Movie: &MovieReaderWriter{},
|
||||
Performer: &PerformerReaderWriter{},
|
||||
Scene: &SceneReaderWriter{},
|
||||
SceneMarker: &SceneMarkerReaderWriter{},
|
||||
Studio: &StudioReaderWriter{},
|
||||
Tag: &TagReaderWriter{},
|
||||
SavedFilter: &SavedFilterReaderWriter{},
|
||||
}
|
||||
}
|
||||
|
||||
func (db *Database) AssertExpectations(t mock.TestingT) {
|
||||
db.File.AssertExpectations(t)
|
||||
db.Folder.AssertExpectations(t)
|
||||
db.Gallery.AssertExpectations(t)
|
||||
db.GalleryChapter.AssertExpectations(t)
|
||||
db.Image.AssertExpectations(t)
|
||||
db.Movie.AssertExpectations(t)
|
||||
db.Performer.AssertExpectations(t)
|
||||
db.Scene.AssertExpectations(t)
|
||||
db.SceneMarker.AssertExpectations(t)
|
||||
db.Studio.AssertExpectations(t)
|
||||
db.Tag.AssertExpectations(t)
|
||||
db.SavedFilter.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func (db *Database) Repository() models.Repository {
|
||||
return models.Repository{
|
||||
TxnManager: db,
|
||||
File: db.File,
|
||||
Folder: db.Folder,
|
||||
Gallery: db.Gallery,
|
||||
GalleryChapter: db.GalleryChapter,
|
||||
Image: db.Image,
|
||||
Movie: db.Movie,
|
||||
Performer: db.Performer,
|
||||
Scene: db.Scene,
|
||||
SceneMarker: db.SceneMarker,
|
||||
Studio: db.Studio,
|
||||
Tag: db.Tag,
|
||||
SavedFilter: db.SavedFilter,
|
||||
}
|
||||
}
|
||||
|
|
@ -1,59 +0,0 @@
|
|||
package mocks
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/txn"
|
||||
)
|
||||
|
||||
type TxnManager struct{}
|
||||
|
||||
func (*TxnManager) Begin(ctx context.Context, exclusive bool) (context.Context, error) {
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
func (*TxnManager) WithDatabase(ctx context.Context) (context.Context, error) {
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
func (*TxnManager) Commit(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*TxnManager) Rollback(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*TxnManager) Complete(ctx context.Context) {
|
||||
}
|
||||
|
||||
func (*TxnManager) AddPostCommitHook(ctx context.Context, hook txn.TxnFunc) {
|
||||
}
|
||||
|
||||
func (*TxnManager) AddPostRollbackHook(ctx context.Context, hook txn.TxnFunc) {
|
||||
}
|
||||
|
||||
func (*TxnManager) IsLocked(err error) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (*TxnManager) Reset() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewTxnRepository() models.Repository {
|
||||
return models.Repository{
|
||||
TxnManager: &TxnManager{},
|
||||
Gallery: &GalleryReaderWriter{},
|
||||
GalleryChapter: &GalleryChapterReaderWriter{},
|
||||
Image: &ImageReaderWriter{},
|
||||
Movie: &MovieReaderWriter{},
|
||||
Performer: &PerformerReaderWriter{},
|
||||
Scene: &SceneReaderWriter{},
|
||||
SceneMarker: &SceneMarkerReaderWriter{},
|
||||
Studio: &StudioReaderWriter{},
|
||||
Tag: &TagReaderWriter{},
|
||||
SavedFilter: &SavedFilterReaderWriter{},
|
||||
}
|
||||
}
|
||||
|
|
@ -49,5 +49,3 @@ func NewMoviePartial() MoviePartial {
|
|||
UpdatedAt: NewOptionalTime(currentTime),
|
||||
}
|
||||
}
|
||||
|
||||
var DefaultMovieImage = ""
|
||||
|
|
|
|||
|
|
@ -1,17 +1,18 @@
|
|||
package models
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/stashapp/stash/pkg/txn"
|
||||
)
|
||||
|
||||
type TxnManager interface {
|
||||
txn.Manager
|
||||
txn.DatabaseProvider
|
||||
Reset() error
|
||||
}
|
||||
|
||||
type Repository struct {
|
||||
TxnManager
|
||||
TxnManager TxnManager
|
||||
|
||||
File FileReaderWriter
|
||||
Folder FolderReaderWriter
|
||||
|
|
@ -26,3 +27,15 @@ type Repository struct {
|
|||
Tag TagReaderWriter
|
||||
SavedFilter SavedFilterReaderWriter
|
||||
}
|
||||
|
||||
func (r *Repository) WithTxn(ctx context.Context, fn txn.TxnFunc) error {
|
||||
return txn.WithTxn(ctx, r.TxnManager, fn)
|
||||
}
|
||||
|
||||
func (r *Repository) WithReadTxn(ctx context.Context, fn txn.TxnFunc) error {
|
||||
return txn.WithReadTxn(ctx, r.TxnManager, fn)
|
||||
}
|
||||
|
||||
func (r *Repository) WithDB(ctx context.Context, fn txn.TxnFunc) error {
|
||||
return txn.WithDatabase(ctx, r.TxnManager, fn)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -168,34 +168,32 @@ func initTestTable() {
|
|||
func TestToJSON(t *testing.T) {
|
||||
initTestTable()
|
||||
|
||||
mockMovieReader := &mocks.MovieReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
imageErr := errors.New("error getting image")
|
||||
|
||||
mockMovieReader.On("GetFrontImage", testCtx, movieID).Return(frontImageBytes, nil).Once()
|
||||
mockMovieReader.On("GetFrontImage", testCtx, missingStudioMovieID).Return(frontImageBytes, nil).Once()
|
||||
mockMovieReader.On("GetFrontImage", testCtx, emptyID).Return(nil, nil).Once().Maybe()
|
||||
mockMovieReader.On("GetFrontImage", testCtx, errFrontImageID).Return(nil, imageErr).Once()
|
||||
mockMovieReader.On("GetFrontImage", testCtx, errBackImageID).Return(frontImageBytes, nil).Once()
|
||||
db.Movie.On("GetFrontImage", testCtx, movieID).Return(frontImageBytes, nil).Once()
|
||||
db.Movie.On("GetFrontImage", testCtx, missingStudioMovieID).Return(frontImageBytes, nil).Once()
|
||||
db.Movie.On("GetFrontImage", testCtx, emptyID).Return(nil, nil).Once().Maybe()
|
||||
db.Movie.On("GetFrontImage", testCtx, errFrontImageID).Return(nil, imageErr).Once()
|
||||
db.Movie.On("GetFrontImage", testCtx, errBackImageID).Return(frontImageBytes, nil).Once()
|
||||
|
||||
mockMovieReader.On("GetBackImage", testCtx, movieID).Return(backImageBytes, nil).Once()
|
||||
mockMovieReader.On("GetBackImage", testCtx, missingStudioMovieID).Return(backImageBytes, nil).Once()
|
||||
mockMovieReader.On("GetBackImage", testCtx, emptyID).Return(nil, nil).Once()
|
||||
mockMovieReader.On("GetBackImage", testCtx, errBackImageID).Return(nil, imageErr).Once()
|
||||
mockMovieReader.On("GetBackImage", testCtx, errFrontImageID).Return(backImageBytes, nil).Maybe()
|
||||
mockMovieReader.On("GetBackImage", testCtx, errStudioMovieID).Return(backImageBytes, nil).Maybe()
|
||||
|
||||
mockStudioReader := &mocks.StudioReaderWriter{}
|
||||
db.Movie.On("GetBackImage", testCtx, movieID).Return(backImageBytes, nil).Once()
|
||||
db.Movie.On("GetBackImage", testCtx, missingStudioMovieID).Return(backImageBytes, nil).Once()
|
||||
db.Movie.On("GetBackImage", testCtx, emptyID).Return(nil, nil).Once()
|
||||
db.Movie.On("GetBackImage", testCtx, errBackImageID).Return(nil, imageErr).Once()
|
||||
db.Movie.On("GetBackImage", testCtx, errFrontImageID).Return(backImageBytes, nil).Maybe()
|
||||
db.Movie.On("GetBackImage", testCtx, errStudioMovieID).Return(backImageBytes, nil).Maybe()
|
||||
|
||||
studioErr := errors.New("error getting studio")
|
||||
|
||||
mockStudioReader.On("Find", testCtx, studioID).Return(&movieStudio, nil)
|
||||
mockStudioReader.On("Find", testCtx, missingStudioID).Return(nil, nil)
|
||||
mockStudioReader.On("Find", testCtx, errStudioID).Return(nil, studioErr)
|
||||
db.Studio.On("Find", testCtx, studioID).Return(&movieStudio, nil)
|
||||
db.Studio.On("Find", testCtx, missingStudioID).Return(nil, nil)
|
||||
db.Studio.On("Find", testCtx, errStudioID).Return(nil, studioErr)
|
||||
|
||||
for i, s := range scenarios {
|
||||
movie := s.movie
|
||||
json, err := ToJSON(testCtx, mockMovieReader, mockStudioReader, &movie)
|
||||
json, err := ToJSON(testCtx, db.Movie, db.Studio, &movie)
|
||||
|
||||
switch {
|
||||
case !s.err && err != nil:
|
||||
|
|
@ -207,6 +205,5 @@ func TestToJSON(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
mockMovieReader.AssertExpectations(t)
|
||||
mockStudioReader.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -69,10 +69,11 @@ func TestImporterPreImport(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestImporterPreImportWithStudio(t *testing.T) {
|
||||
studioReaderWriter := &mocks.StudioReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
StudioWriter: studioReaderWriter,
|
||||
ReaderWriter: db.Movie,
|
||||
StudioWriter: db.Studio,
|
||||
Input: jsonschema.Movie{
|
||||
Name: movieName,
|
||||
FrontImage: frontImage,
|
||||
|
|
@ -82,10 +83,10 @@ func TestImporterPreImportWithStudio(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
studioReaderWriter.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{
|
||||
db.Studio.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{
|
||||
ID: existingStudioID,
|
||||
}, nil).Once()
|
||||
studioReaderWriter.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
|
||||
db.Studio.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
assert.Nil(t, err)
|
||||
|
|
@ -95,14 +96,15 @@ func TestImporterPreImportWithStudio(t *testing.T) {
|
|||
err = i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
studioReaderWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterPreImportWithMissingStudio(t *testing.T) {
|
||||
studioReaderWriter := &mocks.StudioReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
StudioWriter: studioReaderWriter,
|
||||
ReaderWriter: db.Movie,
|
||||
StudioWriter: db.Studio,
|
||||
Input: jsonschema.Movie{
|
||||
Name: movieName,
|
||||
FrontImage: frontImage,
|
||||
|
|
@ -111,8 +113,8 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) {
|
|||
MissingRefBehaviour: models.ImportMissingRefEnumFail,
|
||||
}
|
||||
|
||||
studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
|
||||
studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
|
||||
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
|
||||
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
|
||||
s := args.Get(1).(*models.Studio)
|
||||
s.ID = existingStudioID
|
||||
}).Return(nil)
|
||||
|
|
@ -129,14 +131,15 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) {
|
|||
assert.Nil(t, err)
|
||||
assert.Equal(t, existingStudioID, *i.movie.StudioID)
|
||||
|
||||
studioReaderWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
|
||||
studioReaderWriter := &mocks.StudioReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
StudioWriter: studioReaderWriter,
|
||||
ReaderWriter: db.Movie,
|
||||
StudioWriter: db.Studio,
|
||||
Input: jsonschema.Movie{
|
||||
Name: movieName,
|
||||
FrontImage: frontImage,
|
||||
|
|
@ -145,27 +148,30 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
|
|||
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
|
||||
}
|
||||
|
||||
studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
|
||||
studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
|
||||
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
|
||||
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterPostImport(t *testing.T) {
|
||||
readerWriter := &mocks.MovieReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
ReaderWriter: readerWriter,
|
||||
ReaderWriter: db.Movie,
|
||||
StudioWriter: db.Studio,
|
||||
frontImageData: frontImageBytes,
|
||||
backImageData: backImageBytes,
|
||||
}
|
||||
|
||||
updateMovieImageErr := errors.New("UpdateImages error")
|
||||
|
||||
readerWriter.On("UpdateFrontImage", testCtx, movieID, frontImageBytes).Return(nil).Once()
|
||||
readerWriter.On("UpdateBackImage", testCtx, movieID, backImageBytes).Return(nil).Once()
|
||||
readerWriter.On("UpdateFrontImage", testCtx, errImageID, frontImageBytes).Return(updateMovieImageErr).Once()
|
||||
db.Movie.On("UpdateFrontImage", testCtx, movieID, frontImageBytes).Return(nil).Once()
|
||||
db.Movie.On("UpdateBackImage", testCtx, movieID, backImageBytes).Return(nil).Once()
|
||||
db.Movie.On("UpdateFrontImage", testCtx, errImageID, frontImageBytes).Return(updateMovieImageErr).Once()
|
||||
|
||||
err := i.PostImport(testCtx, movieID)
|
||||
assert.Nil(t, err)
|
||||
|
|
@ -173,25 +179,26 @@ func TestImporterPostImport(t *testing.T) {
|
|||
err = i.PostImport(testCtx, errImageID)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
readerWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterFindExistingID(t *testing.T) {
|
||||
readerWriter := &mocks.MovieReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
ReaderWriter: readerWriter,
|
||||
ReaderWriter: db.Movie,
|
||||
StudioWriter: db.Studio,
|
||||
Input: jsonschema.Movie{
|
||||
Name: movieName,
|
||||
},
|
||||
}
|
||||
|
||||
errFindByName := errors.New("FindByName error")
|
||||
readerWriter.On("FindByName", testCtx, movieName, false).Return(nil, nil).Once()
|
||||
readerWriter.On("FindByName", testCtx, existingMovieName, false).Return(&models.Movie{
|
||||
db.Movie.On("FindByName", testCtx, movieName, false).Return(nil, nil).Once()
|
||||
db.Movie.On("FindByName", testCtx, existingMovieName, false).Return(&models.Movie{
|
||||
ID: existingMovieID,
|
||||
}, nil).Once()
|
||||
readerWriter.On("FindByName", testCtx, movieNameErr, false).Return(nil, errFindByName).Once()
|
||||
db.Movie.On("FindByName", testCtx, movieNameErr, false).Return(nil, errFindByName).Once()
|
||||
|
||||
id, err := i.FindExistingID(testCtx)
|
||||
assert.Nil(t, id)
|
||||
|
|
@ -207,11 +214,11 @@ func TestImporterFindExistingID(t *testing.T) {
|
|||
assert.Nil(t, id)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
readerWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestCreate(t *testing.T) {
|
||||
readerWriter := &mocks.MovieReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
movie := models.Movie{
|
||||
Name: movieName,
|
||||
|
|
@ -222,16 +229,17 @@ func TestCreate(t *testing.T) {
|
|||
}
|
||||
|
||||
i := Importer{
|
||||
ReaderWriter: readerWriter,
|
||||
ReaderWriter: db.Movie,
|
||||
StudioWriter: db.Studio,
|
||||
movie: movie,
|
||||
}
|
||||
|
||||
errCreate := errors.New("Create error")
|
||||
readerWriter.On("Create", testCtx, &movie).Run(func(args mock.Arguments) {
|
||||
db.Movie.On("Create", testCtx, &movie).Run(func(args mock.Arguments) {
|
||||
m := args.Get(1).(*models.Movie)
|
||||
m.ID = movieID
|
||||
}).Return(nil).Once()
|
||||
readerWriter.On("Create", testCtx, &movieErr).Return(errCreate).Once()
|
||||
db.Movie.On("Create", testCtx, &movieErr).Return(errCreate).Once()
|
||||
|
||||
id, err := i.Create(testCtx)
|
||||
assert.Equal(t, movieID, *id)
|
||||
|
|
@ -242,11 +250,11 @@ func TestCreate(t *testing.T) {
|
|||
assert.Nil(t, id)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
readerWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestUpdate(t *testing.T) {
|
||||
readerWriter := &mocks.MovieReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
movie := models.Movie{
|
||||
Name: movieName,
|
||||
|
|
@ -257,7 +265,8 @@ func TestUpdate(t *testing.T) {
|
|||
}
|
||||
|
||||
i := Importer{
|
||||
ReaderWriter: readerWriter,
|
||||
ReaderWriter: db.Movie,
|
||||
StudioWriter: db.Studio,
|
||||
movie: movie,
|
||||
}
|
||||
|
||||
|
|
@ -265,7 +274,7 @@ func TestUpdate(t *testing.T) {
|
|||
|
||||
// id needs to be set for the mock input
|
||||
movie.ID = movieID
|
||||
readerWriter.On("Update", testCtx, &movie).Return(nil).Once()
|
||||
db.Movie.On("Update", testCtx, &movie).Return(nil).Once()
|
||||
|
||||
err := i.Update(testCtx, movieID)
|
||||
assert.Nil(t, err)
|
||||
|
|
@ -274,10 +283,10 @@ func TestUpdate(t *testing.T) {
|
|||
|
||||
// need to set id separately
|
||||
movieErr.ID = errImageID
|
||||
readerWriter.On("Update", testCtx, &movieErr).Return(errUpdate).Once()
|
||||
db.Movie.On("Update", testCtx, &movieErr).Return(errUpdate).Once()
|
||||
|
||||
err = i.Update(testCtx, errImageID)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
readerWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -203,17 +203,17 @@ func initTestTable() {
|
|||
func TestToJSON(t *testing.T) {
|
||||
initTestTable()
|
||||
|
||||
mockPerformerReader := &mocks.PerformerReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
imageErr := errors.New("error getting image")
|
||||
|
||||
mockPerformerReader.On("GetImage", testCtx, performerID).Return(imageBytes, nil).Once()
|
||||
mockPerformerReader.On("GetImage", testCtx, noImageID).Return(nil, nil).Once()
|
||||
mockPerformerReader.On("GetImage", testCtx, errImageID).Return(nil, imageErr).Once()
|
||||
db.Performer.On("GetImage", testCtx, performerID).Return(imageBytes, nil).Once()
|
||||
db.Performer.On("GetImage", testCtx, noImageID).Return(nil, nil).Once()
|
||||
db.Performer.On("GetImage", testCtx, errImageID).Return(nil, imageErr).Once()
|
||||
|
||||
for i, s := range scenarios {
|
||||
tag := s.input
|
||||
json, err := ToJSON(testCtx, mockPerformerReader, &tag)
|
||||
json, err := ToJSON(testCtx, db.Performer, &tag)
|
||||
|
||||
switch {
|
||||
case !s.err && err != nil:
|
||||
|
|
@ -225,5 +225,5 @@ func TestToJSON(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
mockPerformerReader.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -63,10 +63,11 @@ func TestImporterPreImport(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestImporterPreImportWithTag(t *testing.T) {
|
||||
tagReaderWriter := &mocks.TagReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
TagWriter: tagReaderWriter,
|
||||
ReaderWriter: db.Performer,
|
||||
TagWriter: db.Tag,
|
||||
MissingRefBehaviour: models.ImportMissingRefEnumFail,
|
||||
Input: jsonschema.Performer{
|
||||
Tags: []string{
|
||||
|
|
@ -75,13 +76,13 @@ func TestImporterPreImportWithTag(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
tagReaderWriter.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{
|
||||
db.Tag.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{
|
||||
{
|
||||
ID: existingTagID,
|
||||
Name: existingTagName,
|
||||
},
|
||||
}, nil).Once()
|
||||
tagReaderWriter.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
|
||||
db.Tag.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
assert.Nil(t, err)
|
||||
|
|
@ -91,14 +92,15 @@ func TestImporterPreImportWithTag(t *testing.T) {
|
|||
err = i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
tagReaderWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterPreImportWithMissingTag(t *testing.T) {
|
||||
tagReaderWriter := &mocks.TagReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
TagWriter: tagReaderWriter,
|
||||
ReaderWriter: db.Performer,
|
||||
TagWriter: db.Tag,
|
||||
Input: jsonschema.Performer{
|
||||
Tags: []string{
|
||||
missingTagName,
|
||||
|
|
@ -107,8 +109,8 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
|
|||
MissingRefBehaviour: models.ImportMissingRefEnumFail,
|
||||
}
|
||||
|
||||
tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
|
||||
tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
|
||||
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
|
||||
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
|
||||
t := args.Get(1).(*models.Tag)
|
||||
t.ID = existingTagID
|
||||
}).Return(nil)
|
||||
|
|
@ -125,14 +127,15 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
|
|||
assert.Nil(t, err)
|
||||
assert.Equal(t, existingTagID, i.performer.TagIDs.List()[0])
|
||||
|
||||
tagReaderWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
|
||||
tagReaderWriter := &mocks.TagReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
TagWriter: tagReaderWriter,
|
||||
ReaderWriter: db.Performer,
|
||||
TagWriter: db.Tag,
|
||||
Input: jsonschema.Performer{
|
||||
Tags: []string{
|
||||
missingTagName,
|
||||
|
|
@ -141,25 +144,28 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
|
|||
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
|
||||
}
|
||||
|
||||
tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
|
||||
tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error"))
|
||||
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
|
||||
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error"))
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterPostImport(t *testing.T) {
|
||||
readerWriter := &mocks.PerformerReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
ReaderWriter: readerWriter,
|
||||
ReaderWriter: db.Performer,
|
||||
TagWriter: db.Tag,
|
||||
imageData: imageBytes,
|
||||
}
|
||||
|
||||
updatePerformerImageErr := errors.New("UpdateImage error")
|
||||
|
||||
readerWriter.On("UpdateImage", testCtx, performerID, imageBytes).Return(nil).Once()
|
||||
readerWriter.On("UpdateImage", testCtx, errImageID, imageBytes).Return(updatePerformerImageErr).Once()
|
||||
db.Performer.On("UpdateImage", testCtx, performerID, imageBytes).Return(nil).Once()
|
||||
db.Performer.On("UpdateImage", testCtx, errImageID, imageBytes).Return(updatePerformerImageErr).Once()
|
||||
|
||||
err := i.PostImport(testCtx, performerID)
|
||||
assert.Nil(t, err)
|
||||
|
|
@ -167,14 +173,15 @@ func TestImporterPostImport(t *testing.T) {
|
|||
err = i.PostImport(testCtx, errImageID)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
readerWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterFindExistingID(t *testing.T) {
|
||||
readerWriter := &mocks.PerformerReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
ReaderWriter: readerWriter,
|
||||
ReaderWriter: db.Performer,
|
||||
TagWriter: db.Tag,
|
||||
Input: jsonschema.Performer{
|
||||
Name: performerName,
|
||||
},
|
||||
|
|
@ -195,13 +202,13 @@ func TestImporterFindExistingID(t *testing.T) {
|
|||
}
|
||||
|
||||
errFindByNames := errors.New("FindByNames error")
|
||||
readerWriter.On("Query", testCtx, performerFilter(performerName), findFilter).Return(nil, 0, nil).Once()
|
||||
readerWriter.On("Query", testCtx, performerFilter(existingPerformerName), findFilter).Return([]*models.Performer{
|
||||
db.Performer.On("Query", testCtx, performerFilter(performerName), findFilter).Return(nil, 0, nil).Once()
|
||||
db.Performer.On("Query", testCtx, performerFilter(existingPerformerName), findFilter).Return([]*models.Performer{
|
||||
{
|
||||
ID: existingPerformerID,
|
||||
},
|
||||
}, 1, nil).Once()
|
||||
readerWriter.On("Query", testCtx, performerFilter(performerNameErr), findFilter).Return(nil, 0, errFindByNames).Once()
|
||||
db.Performer.On("Query", testCtx, performerFilter(performerNameErr), findFilter).Return(nil, 0, errFindByNames).Once()
|
||||
|
||||
id, err := i.FindExistingID(testCtx)
|
||||
assert.Nil(t, id)
|
||||
|
|
@ -217,11 +224,11 @@ func TestImporterFindExistingID(t *testing.T) {
|
|||
assert.Nil(t, id)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
readerWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestCreate(t *testing.T) {
|
||||
readerWriter := &mocks.PerformerReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
performer := models.Performer{
|
||||
Name: performerName,
|
||||
|
|
@ -232,16 +239,17 @@ func TestCreate(t *testing.T) {
|
|||
}
|
||||
|
||||
i := Importer{
|
||||
ReaderWriter: readerWriter,
|
||||
ReaderWriter: db.Performer,
|
||||
TagWriter: db.Tag,
|
||||
performer: performer,
|
||||
}
|
||||
|
||||
errCreate := errors.New("Create error")
|
||||
readerWriter.On("Create", testCtx, &performer).Run(func(args mock.Arguments) {
|
||||
db.Performer.On("Create", testCtx, &performer).Run(func(args mock.Arguments) {
|
||||
arg := args.Get(1).(*models.Performer)
|
||||
arg.ID = performerID
|
||||
}).Return(nil).Once()
|
||||
readerWriter.On("Create", testCtx, &performerErr).Return(errCreate).Once()
|
||||
db.Performer.On("Create", testCtx, &performerErr).Return(errCreate).Once()
|
||||
|
||||
id, err := i.Create(testCtx)
|
||||
assert.Equal(t, performerID, *id)
|
||||
|
|
@ -252,11 +260,11 @@ func TestCreate(t *testing.T) {
|
|||
assert.Nil(t, id)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
readerWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestUpdate(t *testing.T) {
|
||||
readerWriter := &mocks.PerformerReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
performer := models.Performer{
|
||||
Name: performerName,
|
||||
|
|
@ -267,7 +275,8 @@ func TestUpdate(t *testing.T) {
|
|||
}
|
||||
|
||||
i := Importer{
|
||||
ReaderWriter: readerWriter,
|
||||
ReaderWriter: db.Performer,
|
||||
TagWriter: db.Tag,
|
||||
performer: performer,
|
||||
}
|
||||
|
||||
|
|
@ -275,7 +284,7 @@ func TestUpdate(t *testing.T) {
|
|||
|
||||
// id needs to be set for the mock input
|
||||
performer.ID = performerID
|
||||
readerWriter.On("Update", testCtx, &performer).Return(nil).Once()
|
||||
db.Performer.On("Update", testCtx, &performer).Return(nil).Once()
|
||||
|
||||
err := i.Update(testCtx, performerID)
|
||||
assert.Nil(t, err)
|
||||
|
|
@ -284,10 +293,10 @@ func TestUpdate(t *testing.T) {
|
|||
|
||||
// need to set id separately
|
||||
performerErr.ID = errImageID
|
||||
readerWriter.On("Update", testCtx, &performerErr).Return(errUpdate).Once()
|
||||
db.Performer.On("Update", testCtx, &performerErr).Return(errUpdate).Once()
|
||||
|
||||
err = i.Update(testCtx, errImageID)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
readerWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -186,17 +186,17 @@ var scenarios = []basicTestScenario{
|
|||
}
|
||||
|
||||
func TestToJSON(t *testing.T) {
|
||||
mockSceneReader := &mocks.SceneReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
imageErr := errors.New("error getting image")
|
||||
|
||||
mockSceneReader.On("GetCover", testCtx, sceneID).Return(imageBytes, nil).Once()
|
||||
mockSceneReader.On("GetCover", testCtx, noImageID).Return(nil, nil).Once()
|
||||
mockSceneReader.On("GetCover", testCtx, errImageID).Return(nil, imageErr).Once()
|
||||
db.Scene.On("GetCover", testCtx, sceneID).Return(imageBytes, nil).Once()
|
||||
db.Scene.On("GetCover", testCtx, noImageID).Return(nil, nil).Once()
|
||||
db.Scene.On("GetCover", testCtx, errImageID).Return(nil, imageErr).Once()
|
||||
|
||||
for i, s := range scenarios {
|
||||
scene := s.input
|
||||
json, err := ToBasicJSON(testCtx, mockSceneReader, &scene)
|
||||
json, err := ToBasicJSON(testCtx, db.Scene, &scene)
|
||||
|
||||
switch {
|
||||
case !s.err && err != nil:
|
||||
|
|
@ -208,7 +208,7 @@ func TestToJSON(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
mockSceneReader.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func createStudioScene(studioID int) models.Scene {
|
||||
|
|
@ -242,19 +242,19 @@ var getStudioScenarios = []stringTestScenario{
|
|||
}
|
||||
|
||||
func TestGetStudioName(t *testing.T) {
|
||||
mockStudioReader := &mocks.StudioReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
studioErr := errors.New("error getting image")
|
||||
|
||||
mockStudioReader.On("Find", testCtx, studioID).Return(&models.Studio{
|
||||
db.Studio.On("Find", testCtx, studioID).Return(&models.Studio{
|
||||
Name: studioName,
|
||||
}, nil).Once()
|
||||
mockStudioReader.On("Find", testCtx, missingStudioID).Return(nil, nil).Once()
|
||||
mockStudioReader.On("Find", testCtx, errStudioID).Return(nil, studioErr).Once()
|
||||
db.Studio.On("Find", testCtx, missingStudioID).Return(nil, nil).Once()
|
||||
db.Studio.On("Find", testCtx, errStudioID).Return(nil, studioErr).Once()
|
||||
|
||||
for i, s := range getStudioScenarios {
|
||||
scene := s.input
|
||||
json, err := GetStudioName(testCtx, mockStudioReader, &scene)
|
||||
json, err := GetStudioName(testCtx, db.Studio, &scene)
|
||||
|
||||
switch {
|
||||
case !s.err && err != nil:
|
||||
|
|
@ -266,7 +266,7 @@ func TestGetStudioName(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
mockStudioReader.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
type stringSliceTestScenario struct {
|
||||
|
|
@ -305,17 +305,17 @@ func getTags(names []string) []*models.Tag {
|
|||
}
|
||||
|
||||
func TestGetTagNames(t *testing.T) {
|
||||
mockTagReader := &mocks.TagReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
tagErr := errors.New("error getting tag")
|
||||
|
||||
mockTagReader.On("FindBySceneID", testCtx, sceneID).Return(getTags(names), nil).Once()
|
||||
mockTagReader.On("FindBySceneID", testCtx, noTagsID).Return(nil, nil).Once()
|
||||
mockTagReader.On("FindBySceneID", testCtx, errTagsID).Return(nil, tagErr).Once()
|
||||
db.Tag.On("FindBySceneID", testCtx, sceneID).Return(getTags(names), nil).Once()
|
||||
db.Tag.On("FindBySceneID", testCtx, noTagsID).Return(nil, nil).Once()
|
||||
db.Tag.On("FindBySceneID", testCtx, errTagsID).Return(nil, tagErr).Once()
|
||||
|
||||
for i, s := range getTagNamesScenarios {
|
||||
scene := s.input
|
||||
json, err := GetTagNames(testCtx, mockTagReader, &scene)
|
||||
json, err := GetTagNames(testCtx, db.Tag, &scene)
|
||||
|
||||
switch {
|
||||
case !s.err && err != nil:
|
||||
|
|
@ -327,7 +327,7 @@ func TestGetTagNames(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
mockTagReader.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
type sceneMoviesTestScenario struct {
|
||||
|
|
@ -391,20 +391,21 @@ var getSceneMoviesJSONScenarios = []sceneMoviesTestScenario{
|
|||
}
|
||||
|
||||
func TestGetSceneMoviesJSON(t *testing.T) {
|
||||
mockMovieReader := &mocks.MovieReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
movieErr := errors.New("error getting movie")
|
||||
|
||||
mockMovieReader.On("Find", testCtx, validMovie1).Return(&models.Movie{
|
||||
db.Movie.On("Find", testCtx, validMovie1).Return(&models.Movie{
|
||||
Name: movie1Name,
|
||||
}, nil).Once()
|
||||
mockMovieReader.On("Find", testCtx, validMovie2).Return(&models.Movie{
|
||||
db.Movie.On("Find", testCtx, validMovie2).Return(&models.Movie{
|
||||
Name: movie2Name,
|
||||
}, nil).Once()
|
||||
mockMovieReader.On("Find", testCtx, invalidMovie).Return(nil, movieErr).Once()
|
||||
db.Movie.On("Find", testCtx, invalidMovie).Return(nil, movieErr).Once()
|
||||
|
||||
for i, s := range getSceneMoviesJSONScenarios {
|
||||
scene := s.input
|
||||
json, err := GetSceneMoviesJSON(testCtx, mockMovieReader, &scene)
|
||||
json, err := GetSceneMoviesJSON(testCtx, db.Movie, &scene)
|
||||
|
||||
switch {
|
||||
case !s.err && err != nil:
|
||||
|
|
@ -416,7 +417,7 @@ func TestGetSceneMoviesJSON(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
mockMovieReader.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
const (
|
||||
|
|
@ -542,27 +543,26 @@ var invalidMarkers2 = []*models.SceneMarker{
|
|||
}
|
||||
|
||||
func TestGetSceneMarkersJSON(t *testing.T) {
|
||||
mockTagReader := &mocks.TagReaderWriter{}
|
||||
mockMarkerReader := &mocks.SceneMarkerReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
markersErr := errors.New("error getting scene markers")
|
||||
tagErr := errors.New("error getting tags")
|
||||
|
||||
mockMarkerReader.On("FindBySceneID", testCtx, sceneID).Return(validMarkers, nil).Once()
|
||||
mockMarkerReader.On("FindBySceneID", testCtx, noMarkersID).Return(nil, nil).Once()
|
||||
mockMarkerReader.On("FindBySceneID", testCtx, errMarkersID).Return(nil, markersErr).Once()
|
||||
mockMarkerReader.On("FindBySceneID", testCtx, errFindPrimaryTagID).Return(invalidMarkers1, nil).Once()
|
||||
mockMarkerReader.On("FindBySceneID", testCtx, errFindByMarkerID).Return(invalidMarkers2, nil).Once()
|
||||
db.SceneMarker.On("FindBySceneID", testCtx, sceneID).Return(validMarkers, nil).Once()
|
||||
db.SceneMarker.On("FindBySceneID", testCtx, noMarkersID).Return(nil, nil).Once()
|
||||
db.SceneMarker.On("FindBySceneID", testCtx, errMarkersID).Return(nil, markersErr).Once()
|
||||
db.SceneMarker.On("FindBySceneID", testCtx, errFindPrimaryTagID).Return(invalidMarkers1, nil).Once()
|
||||
db.SceneMarker.On("FindBySceneID", testCtx, errFindByMarkerID).Return(invalidMarkers2, nil).Once()
|
||||
|
||||
mockTagReader.On("Find", testCtx, validTagID1).Return(&models.Tag{
|
||||
db.Tag.On("Find", testCtx, validTagID1).Return(&models.Tag{
|
||||
Name: validTagName1,
|
||||
}, nil)
|
||||
mockTagReader.On("Find", testCtx, validTagID2).Return(&models.Tag{
|
||||
db.Tag.On("Find", testCtx, validTagID2).Return(&models.Tag{
|
||||
Name: validTagName2,
|
||||
}, nil)
|
||||
mockTagReader.On("Find", testCtx, invalidTagID).Return(nil, tagErr)
|
||||
db.Tag.On("Find", testCtx, invalidTagID).Return(nil, tagErr)
|
||||
|
||||
mockTagReader.On("FindBySceneMarkerID", testCtx, validMarkerID1).Return([]*models.Tag{
|
||||
db.Tag.On("FindBySceneMarkerID", testCtx, validMarkerID1).Return([]*models.Tag{
|
||||
{
|
||||
Name: validTagName1,
|
||||
},
|
||||
|
|
@ -570,16 +570,16 @@ func TestGetSceneMarkersJSON(t *testing.T) {
|
|||
Name: validTagName2,
|
||||
},
|
||||
}, nil)
|
||||
mockTagReader.On("FindBySceneMarkerID", testCtx, validMarkerID2).Return([]*models.Tag{
|
||||
db.Tag.On("FindBySceneMarkerID", testCtx, validMarkerID2).Return([]*models.Tag{
|
||||
{
|
||||
Name: validTagName2,
|
||||
},
|
||||
}, nil)
|
||||
mockTagReader.On("FindBySceneMarkerID", testCtx, invalidMarkerID2).Return(nil, tagErr).Once()
|
||||
db.Tag.On("FindBySceneMarkerID", testCtx, invalidMarkerID2).Return(nil, tagErr).Once()
|
||||
|
||||
for i, s := range getSceneMarkersJSONScenarios {
|
||||
scene := s.input
|
||||
json, err := GetSceneMarkersJSON(testCtx, mockMarkerReader, mockTagReader, &scene)
|
||||
json, err := GetSceneMarkersJSON(testCtx, db.SceneMarker, db.Tag, &scene)
|
||||
|
||||
switch {
|
||||
case !s.err && err != nil:
|
||||
|
|
@ -591,5 +591,5 @@ func TestGetSceneMarkersJSON(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
mockTagReader.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -410,17 +410,19 @@ type FilenameParser struct {
|
|||
ParserInput models.SceneParserInput
|
||||
Filter *models.FindFilterType
|
||||
whitespaceRE *regexp.Regexp
|
||||
repository FilenameParserRepository
|
||||
performerCache map[string]*models.Performer
|
||||
studioCache map[string]*models.Studio
|
||||
movieCache map[string]*models.Movie
|
||||
tagCache map[string]*models.Tag
|
||||
}
|
||||
|
||||
func NewFilenameParser(filter *models.FindFilterType, config models.SceneParserInput) *FilenameParser {
|
||||
func NewFilenameParser(filter *models.FindFilterType, config models.SceneParserInput, repo FilenameParserRepository) *FilenameParser {
|
||||
p := &FilenameParser{
|
||||
Pattern: *filter.Q,
|
||||
ParserInput: config,
|
||||
Filter: filter,
|
||||
repository: repo,
|
||||
}
|
||||
|
||||
p.performerCache = make(map[string]*models.Performer)
|
||||
|
|
@ -457,7 +459,17 @@ type FilenameParserRepository struct {
|
|||
Tag models.TagQueryer
|
||||
}
|
||||
|
||||
func (p *FilenameParser) Parse(ctx context.Context, repo FilenameParserRepository) ([]*models.SceneParserResult, int, error) {
|
||||
func NewFilenameParserRepository(repo models.Repository) FilenameParserRepository {
|
||||
return FilenameParserRepository{
|
||||
Scene: repo.Scene,
|
||||
Performer: repo.Performer,
|
||||
Studio: repo.Studio,
|
||||
Movie: repo.Movie,
|
||||
Tag: repo.Tag,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *FilenameParser) Parse(ctx context.Context) ([]*models.SceneParserResult, int, error) {
|
||||
// perform the query to find the scenes
|
||||
mapper, err := newParseMapper(p.Pattern, p.ParserInput.IgnoreWords)
|
||||
|
||||
|
|
@ -479,17 +491,17 @@ func (p *FilenameParser) Parse(ctx context.Context, repo FilenameParserRepositor
|
|||
|
||||
p.Filter.Q = nil
|
||||
|
||||
scenes, total, err := QueryWithCount(ctx, repo.Scene, sceneFilter, p.Filter)
|
||||
scenes, total, err := QueryWithCount(ctx, p.repository.Scene, sceneFilter, p.Filter)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
ret := p.parseScenes(ctx, repo, scenes, mapper)
|
||||
ret := p.parseScenes(ctx, scenes, mapper)
|
||||
|
||||
return ret, total, nil
|
||||
}
|
||||
|
||||
func (p *FilenameParser) parseScenes(ctx context.Context, repo FilenameParserRepository, scenes []*models.Scene, mapper *parseMapper) []*models.SceneParserResult {
|
||||
func (p *FilenameParser) parseScenes(ctx context.Context, scenes []*models.Scene, mapper *parseMapper) []*models.SceneParserResult {
|
||||
var ret []*models.SceneParserResult
|
||||
for _, scene := range scenes {
|
||||
sceneHolder := mapper.parse(scene)
|
||||
|
|
@ -498,7 +510,7 @@ func (p *FilenameParser) parseScenes(ctx context.Context, repo FilenameParserRep
|
|||
r := &models.SceneParserResult{
|
||||
Scene: scene,
|
||||
}
|
||||
p.setParserResult(ctx, repo, *sceneHolder, r)
|
||||
p.setParserResult(ctx, *sceneHolder, r)
|
||||
|
||||
ret = append(ret, r)
|
||||
}
|
||||
|
|
@ -671,7 +683,7 @@ func (p *FilenameParser) setMovies(ctx context.Context, qb MovieNameFinder, h sc
|
|||
}
|
||||
}
|
||||
|
||||
func (p *FilenameParser) setParserResult(ctx context.Context, repo FilenameParserRepository, h sceneHolder, result *models.SceneParserResult) {
|
||||
func (p *FilenameParser) setParserResult(ctx context.Context, h sceneHolder, result *models.SceneParserResult) {
|
||||
if h.result.Title != "" {
|
||||
title := h.result.Title
|
||||
title = p.replaceWhitespaceCharacters(title)
|
||||
|
|
@ -692,15 +704,17 @@ func (p *FilenameParser) setParserResult(ctx context.Context, repo FilenameParse
|
|||
result.Rating = h.result.Rating
|
||||
}
|
||||
|
||||
r := p.repository
|
||||
|
||||
if len(h.performers) > 0 {
|
||||
p.setPerformers(ctx, repo.Performer, h, result)
|
||||
p.setPerformers(ctx, r.Performer, h, result)
|
||||
}
|
||||
if len(h.tags) > 0 {
|
||||
p.setTags(ctx, repo.Tag, h, result)
|
||||
p.setTags(ctx, r.Tag, h, result)
|
||||
}
|
||||
p.setStudio(ctx, repo.Studio, h, result)
|
||||
p.setStudio(ctx, r.Studio, h, result)
|
||||
|
||||
if len(h.movies) > 0 {
|
||||
p.setMovies(ctx, repo.Movie, h, result)
|
||||
p.setMovies(ctx, r.Movie, h, result)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -56,20 +56,19 @@ func TestImporterPreImport(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestImporterPreImportWithStudio(t *testing.T) {
|
||||
studioReaderWriter := &mocks.StudioReaderWriter{}
|
||||
testCtx := context.Background()
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
StudioWriter: studioReaderWriter,
|
||||
StudioWriter: db.Studio,
|
||||
Input: jsonschema.Scene{
|
||||
Studio: existingStudioName,
|
||||
},
|
||||
}
|
||||
|
||||
studioReaderWriter.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{
|
||||
db.Studio.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{
|
||||
ID: existingStudioID,
|
||||
}, nil).Once()
|
||||
studioReaderWriter.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
|
||||
db.Studio.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
assert.Nil(t, err)
|
||||
|
|
@ -79,22 +78,22 @@ func TestImporterPreImportWithStudio(t *testing.T) {
|
|||
err = i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
studioReaderWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterPreImportWithMissingStudio(t *testing.T) {
|
||||
studioReaderWriter := &mocks.StudioReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
StudioWriter: studioReaderWriter,
|
||||
StudioWriter: db.Studio,
|
||||
Input: jsonschema.Scene{
|
||||
Studio: missingStudioName,
|
||||
},
|
||||
MissingRefBehaviour: models.ImportMissingRefEnumFail,
|
||||
}
|
||||
|
||||
studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
|
||||
studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
|
||||
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
|
||||
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
|
||||
s := args.Get(1).(*models.Studio)
|
||||
s.ID = existingStudioID
|
||||
}).Return(nil)
|
||||
|
|
@ -111,32 +110,34 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) {
|
|||
assert.Nil(t, err)
|
||||
assert.Equal(t, existingStudioID, *i.scene.StudioID)
|
||||
|
||||
studioReaderWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
|
||||
studioReaderWriter := &mocks.StudioReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
StudioWriter: studioReaderWriter,
|
||||
StudioWriter: db.Studio,
|
||||
Input: jsonschema.Scene{
|
||||
Studio: missingStudioName,
|
||||
},
|
||||
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
|
||||
}
|
||||
|
||||
studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
|
||||
studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
|
||||
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
|
||||
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterPreImportWithPerformer(t *testing.T) {
|
||||
performerReaderWriter := &mocks.PerformerReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
PerformerWriter: performerReaderWriter,
|
||||
PerformerWriter: db.Performer,
|
||||
MissingRefBehaviour: models.ImportMissingRefEnumFail,
|
||||
Input: jsonschema.Scene{
|
||||
Performers: []string{
|
||||
|
|
@ -145,13 +146,13 @@ func TestImporterPreImportWithPerformer(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{
|
||||
db.Performer.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{
|
||||
{
|
||||
ID: existingPerformerID,
|
||||
Name: existingPerformerName,
|
||||
},
|
||||
}, nil).Once()
|
||||
performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once()
|
||||
db.Performer.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once()
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
assert.Nil(t, err)
|
||||
|
|
@ -161,14 +162,14 @@ func TestImporterPreImportWithPerformer(t *testing.T) {
|
|||
err = i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
performerReaderWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterPreImportWithMissingPerformer(t *testing.T) {
|
||||
performerReaderWriter := &mocks.PerformerReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
PerformerWriter: performerReaderWriter,
|
||||
PerformerWriter: db.Performer,
|
||||
Input: jsonschema.Scene{
|
||||
Performers: []string{
|
||||
missingPerformerName,
|
||||
|
|
@ -177,8 +178,8 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) {
|
|||
MissingRefBehaviour: models.ImportMissingRefEnumFail,
|
||||
}
|
||||
|
||||
performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3)
|
||||
performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) {
|
||||
db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3)
|
||||
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) {
|
||||
p := args.Get(1).(*models.Performer)
|
||||
p.ID = existingPerformerID
|
||||
}).Return(nil)
|
||||
|
|
@ -195,14 +196,14 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) {
|
|||
assert.Nil(t, err)
|
||||
assert.Equal(t, []int{existingPerformerID}, i.scene.PerformerIDs.List())
|
||||
|
||||
performerReaderWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) {
|
||||
performerReaderWriter := &mocks.PerformerReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
PerformerWriter: performerReaderWriter,
|
||||
PerformerWriter: db.Performer,
|
||||
Input: jsonschema.Scene{
|
||||
Performers: []string{
|
||||
missingPerformerName,
|
||||
|
|
@ -211,19 +212,20 @@ func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) {
|
|||
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
|
||||
}
|
||||
|
||||
performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once()
|
||||
performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error"))
|
||||
db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once()
|
||||
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error"))
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterPreImportWithMovie(t *testing.T) {
|
||||
movieReaderWriter := &mocks.MovieReaderWriter{}
|
||||
testCtx := context.Background()
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
MovieWriter: movieReaderWriter,
|
||||
MovieWriter: db.Movie,
|
||||
MissingRefBehaviour: models.ImportMissingRefEnumFail,
|
||||
Input: jsonschema.Scene{
|
||||
Movies: []jsonschema.SceneMovie{
|
||||
|
|
@ -235,11 +237,11 @@ func TestImporterPreImportWithMovie(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
movieReaderWriter.On("FindByName", testCtx, existingMovieName, false).Return(&models.Movie{
|
||||
db.Movie.On("FindByName", testCtx, existingMovieName, false).Return(&models.Movie{
|
||||
ID: existingMovieID,
|
||||
Name: existingMovieName,
|
||||
}, nil).Once()
|
||||
movieReaderWriter.On("FindByName", testCtx, existingMovieErr, false).Return(nil, errors.New("FindByName error")).Once()
|
||||
db.Movie.On("FindByName", testCtx, existingMovieErr, false).Return(nil, errors.New("FindByName error")).Once()
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
assert.Nil(t, err)
|
||||
|
|
@ -249,15 +251,14 @@ func TestImporterPreImportWithMovie(t *testing.T) {
|
|||
err = i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
movieReaderWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterPreImportWithMissingMovie(t *testing.T) {
|
||||
movieReaderWriter := &mocks.MovieReaderWriter{}
|
||||
testCtx := context.Background()
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
MovieWriter: movieReaderWriter,
|
||||
MovieWriter: db.Movie,
|
||||
Input: jsonschema.Scene{
|
||||
Movies: []jsonschema.SceneMovie{
|
||||
{
|
||||
|
|
@ -268,8 +269,8 @@ func TestImporterPreImportWithMissingMovie(t *testing.T) {
|
|||
MissingRefBehaviour: models.ImportMissingRefEnumFail,
|
||||
}
|
||||
|
||||
movieReaderWriter.On("FindByName", testCtx, missingMovieName, false).Return(nil, nil).Times(3)
|
||||
movieReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Movie")).Run(func(args mock.Arguments) {
|
||||
db.Movie.On("FindByName", testCtx, missingMovieName, false).Return(nil, nil).Times(3)
|
||||
db.Movie.On("Create", testCtx, mock.AnythingOfType("*models.Movie")).Run(func(args mock.Arguments) {
|
||||
m := args.Get(1).(*models.Movie)
|
||||
m.ID = existingMovieID
|
||||
}).Return(nil)
|
||||
|
|
@ -286,14 +287,14 @@ func TestImporterPreImportWithMissingMovie(t *testing.T) {
|
|||
assert.Nil(t, err)
|
||||
assert.Equal(t, existingMovieID, i.scene.Movies.List()[0].MovieID)
|
||||
|
||||
movieReaderWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterPreImportWithMissingMovieCreateErr(t *testing.T) {
|
||||
movieReaderWriter := &mocks.MovieReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
MovieWriter: movieReaderWriter,
|
||||
MovieWriter: db.Movie,
|
||||
Input: jsonschema.Scene{
|
||||
Movies: []jsonschema.SceneMovie{
|
||||
{
|
||||
|
|
@ -304,18 +305,20 @@ func TestImporterPreImportWithMissingMovieCreateErr(t *testing.T) {
|
|||
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
|
||||
}
|
||||
|
||||
movieReaderWriter.On("FindByName", testCtx, missingMovieName, false).Return(nil, nil).Once()
|
||||
movieReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Movie")).Return(errors.New("Create error"))
|
||||
db.Movie.On("FindByName", testCtx, missingMovieName, false).Return(nil, nil).Once()
|
||||
db.Movie.On("Create", testCtx, mock.AnythingOfType("*models.Movie")).Return(errors.New("Create error"))
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterPreImportWithTag(t *testing.T) {
|
||||
tagReaderWriter := &mocks.TagReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
TagWriter: tagReaderWriter,
|
||||
TagWriter: db.Tag,
|
||||
MissingRefBehaviour: models.ImportMissingRefEnumFail,
|
||||
Input: jsonschema.Scene{
|
||||
Tags: []string{
|
||||
|
|
@ -324,13 +327,13 @@ func TestImporterPreImportWithTag(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
tagReaderWriter.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{
|
||||
db.Tag.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{
|
||||
{
|
||||
ID: existingTagID,
|
||||
Name: existingTagName,
|
||||
},
|
||||
}, nil).Once()
|
||||
tagReaderWriter.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
|
||||
db.Tag.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
assert.Nil(t, err)
|
||||
|
|
@ -340,14 +343,14 @@ func TestImporterPreImportWithTag(t *testing.T) {
|
|||
err = i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
tagReaderWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterPreImportWithMissingTag(t *testing.T) {
|
||||
tagReaderWriter := &mocks.TagReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
TagWriter: tagReaderWriter,
|
||||
TagWriter: db.Tag,
|
||||
Input: jsonschema.Scene{
|
||||
Tags: []string{
|
||||
missingTagName,
|
||||
|
|
@ -356,8 +359,8 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
|
|||
MissingRefBehaviour: models.ImportMissingRefEnumFail,
|
||||
}
|
||||
|
||||
tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
|
||||
tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
|
||||
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
|
||||
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
|
||||
t := args.Get(1).(*models.Tag)
|
||||
t.ID = existingTagID
|
||||
}).Return(nil)
|
||||
|
|
@ -374,14 +377,14 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
|
|||
assert.Nil(t, err)
|
||||
assert.Equal(t, []int{existingTagID}, i.scene.TagIDs.List())
|
||||
|
||||
tagReaderWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
|
||||
tagReaderWriter := &mocks.TagReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
TagWriter: tagReaderWriter,
|
||||
TagWriter: db.Tag,
|
||||
Input: jsonschema.Scene{
|
||||
Tags: []string{
|
||||
missingTagName,
|
||||
|
|
@ -390,9 +393,11 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
|
|||
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
|
||||
}
|
||||
|
||||
tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
|
||||
tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error"))
|
||||
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
|
||||
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error"))
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
package scene
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
|
@ -105,8 +104,6 @@ func TestUpdater_Update(t *testing.T) {
|
|||
tagID
|
||||
)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
performerIDs := []int{performerID}
|
||||
tagIDs := []int{tagID}
|
||||
stashID := "stashID"
|
||||
|
|
@ -119,14 +116,15 @@ func TestUpdater_Update(t *testing.T) {
|
|||
|
||||
updateErr := errors.New("error updating")
|
||||
|
||||
qb := mocks.SceneReaderWriter{}
|
||||
qb.On("UpdatePartial", ctx, mock.MatchedBy(func(id int) bool {
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
db.Scene.On("UpdatePartial", testCtx, mock.MatchedBy(func(id int) bool {
|
||||
return id != badUpdateID
|
||||
}), mock.Anything).Return(validScene, nil)
|
||||
qb.On("UpdatePartial", ctx, badUpdateID, mock.Anything).Return(nil, updateErr)
|
||||
db.Scene.On("UpdatePartial", testCtx, badUpdateID, mock.Anything).Return(nil, updateErr)
|
||||
|
||||
qb.On("UpdateCover", ctx, sceneID, cover).Return(nil).Once()
|
||||
qb.On("UpdateCover", ctx, badCoverID, cover).Return(updateErr).Once()
|
||||
db.Scene.On("UpdateCover", testCtx, sceneID, cover).Return(nil).Once()
|
||||
db.Scene.On("UpdateCover", testCtx, badCoverID, cover).Return(updateErr).Once()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
|
|
@ -204,7 +202,7 @@ func TestUpdater_Update(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.u.Update(ctx, &qb)
|
||||
got, err := tt.u.Update(testCtx, db.Scene)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Updater.Update() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
|
|
@ -215,7 +213,7 @@ func TestUpdater_Update(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
qb.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestUpdateSet_UpdateInput(t *testing.T) {
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ const (
|
|||
)
|
||||
|
||||
type autotagScraper struct {
|
||||
// repository models.Repository
|
||||
txnManager txn.Manager
|
||||
performerReader models.PerformerAutoTagQueryer
|
||||
studioReader models.StudioAutoTagQueryer
|
||||
|
|
@ -208,9 +207,9 @@ func (s autotagScraper) spec() Scraper {
|
|||
}
|
||||
}
|
||||
|
||||
func getAutoTagScraper(txnManager txn.Manager, repo Repository, globalConfig GlobalConfig) scraper {
|
||||
func getAutoTagScraper(repo Repository, globalConfig GlobalConfig) scraper {
|
||||
base := autotagScraper{
|
||||
txnManager: txnManager,
|
||||
txnManager: repo.TxnManager,
|
||||
performerReader: repo.PerformerFinder,
|
||||
studioReader: repo.StudioFinder,
|
||||
tagReader: repo.TagFinder,
|
||||
|
|
|
|||
|
|
@ -77,6 +77,8 @@ type GalleryFinder interface {
|
|||
}
|
||||
|
||||
type Repository struct {
|
||||
TxnManager models.TxnManager
|
||||
|
||||
SceneFinder SceneFinder
|
||||
GalleryFinder GalleryFinder
|
||||
TagFinder TagFinder
|
||||
|
|
@ -85,12 +87,27 @@ type Repository struct {
|
|||
StudioFinder StudioFinder
|
||||
}
|
||||
|
||||
func NewRepository(repo models.Repository) Repository {
|
||||
return Repository{
|
||||
TxnManager: repo.TxnManager,
|
||||
SceneFinder: repo.Scene,
|
||||
GalleryFinder: repo.Gallery,
|
||||
TagFinder: repo.Tag,
|
||||
PerformerFinder: repo.Performer,
|
||||
MovieFinder: repo.Movie,
|
||||
StudioFinder: repo.Studio,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Repository) WithReadTxn(ctx context.Context, fn txn.TxnFunc) error {
|
||||
return txn.WithReadTxn(ctx, r.TxnManager, fn)
|
||||
}
|
||||
|
||||
// Cache stores the database of scrapers
|
||||
type Cache struct {
|
||||
client *http.Client
|
||||
scrapers map[string]scraper // Scraper ID -> Scraper
|
||||
globalConfig GlobalConfig
|
||||
txnManager txn.Manager
|
||||
|
||||
repository Repository
|
||||
}
|
||||
|
|
@ -122,14 +139,13 @@ func newClient(gc GlobalConfig) *http.Client {
|
|||
//
|
||||
// Scraper configurations are loaded from yml files in the provided scrapers
|
||||
// directory and any subdirectories.
|
||||
func NewCache(globalConfig GlobalConfig, txnManager txn.Manager, repo Repository) (*Cache, error) {
|
||||
func NewCache(globalConfig GlobalConfig, repo Repository) (*Cache, error) {
|
||||
// HTTP Client setup
|
||||
client := newClient(globalConfig)
|
||||
|
||||
ret := &Cache{
|
||||
client: client,
|
||||
globalConfig: globalConfig,
|
||||
txnManager: txnManager,
|
||||
repository: repo,
|
||||
}
|
||||
|
||||
|
|
@ -148,7 +164,7 @@ func (c *Cache) loadScrapers() (map[string]scraper, error) {
|
|||
|
||||
// Add built-in scrapers
|
||||
freeOnes := getFreeonesScraper(c.globalConfig)
|
||||
autoTag := getAutoTagScraper(c.txnManager, c.repository, c.globalConfig)
|
||||
autoTag := getAutoTagScraper(c.repository, c.globalConfig)
|
||||
scrapers[freeOnes.spec().ID] = freeOnes
|
||||
scrapers[autoTag.spec().ID] = autoTag
|
||||
|
||||
|
|
@ -369,9 +385,12 @@ func (c Cache) ScrapeID(ctx context.Context, scraperID string, id int, ty Scrape
|
|||
|
||||
func (c Cache) getScene(ctx context.Context, sceneID int) (*models.Scene, error) {
|
||||
var ret *models.Scene
|
||||
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
|
||||
r := c.repository
|
||||
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
|
||||
qb := r.SceneFinder
|
||||
|
||||
var err error
|
||||
ret, err = c.repository.SceneFinder.Find(ctx, sceneID)
|
||||
ret, err = qb.Find(ctx, sceneID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -380,7 +399,7 @@ func (c Cache) getScene(ctx context.Context, sceneID int) (*models.Scene, error)
|
|||
return fmt.Errorf("scene with id %d not found", sceneID)
|
||||
}
|
||||
|
||||
return ret.LoadURLs(ctx, c.repository.SceneFinder)
|
||||
return ret.LoadURLs(ctx, qb)
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -389,9 +408,12 @@ func (c Cache) getScene(ctx context.Context, sceneID int) (*models.Scene, error)
|
|||
|
||||
func (c Cache) getGallery(ctx context.Context, galleryID int) (*models.Gallery, error) {
|
||||
var ret *models.Gallery
|
||||
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
|
||||
r := c.repository
|
||||
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
|
||||
qb := r.GalleryFinder
|
||||
|
||||
var err error
|
||||
ret, err = c.repository.GalleryFinder.Find(ctx, galleryID)
|
||||
ret, err = qb.Find(ctx, galleryID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -400,12 +422,12 @@ func (c Cache) getGallery(ctx context.Context, galleryID int) (*models.Gallery,
|
|||
return fmt.Errorf("gallery with id %d not found", galleryID)
|
||||
}
|
||||
|
||||
err = ret.LoadFiles(ctx, c.repository.GalleryFinder)
|
||||
err = ret.LoadFiles(ctx, qb)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return ret.LoadURLs(ctx, c.repository.GalleryFinder)
|
||||
return ret.LoadURLs(ctx, qb)
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ import (
|
|||
"github.com/stashapp/stash/pkg/logger"
|
||||
"github.com/stashapp/stash/pkg/match"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/txn"
|
||||
)
|
||||
|
||||
// postScrape handles post-processing of scraped content. If the content
|
||||
|
|
@ -46,8 +45,9 @@ func (c Cache) postScrape(ctx context.Context, content ScrapedContent) (ScrapedC
|
|||
}
|
||||
|
||||
func (c Cache) postScrapePerformer(ctx context.Context, p models.ScrapedPerformer) (ScrapedContent, error) {
|
||||
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
|
||||
tqb := c.repository.TagFinder
|
||||
r := c.repository
|
||||
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
|
||||
tqb := r.TagFinder
|
||||
|
||||
tags, err := postProcessTags(ctx, tqb, p.Tags)
|
||||
if err != nil {
|
||||
|
|
@ -72,8 +72,9 @@ func (c Cache) postScrapePerformer(ctx context.Context, p models.ScrapedPerforme
|
|||
|
||||
func (c Cache) postScrapeMovie(ctx context.Context, m models.ScrapedMovie) (ScrapedContent, error) {
|
||||
if m.Studio != nil {
|
||||
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
|
||||
return match.ScrapedStudio(ctx, c.repository.StudioFinder, m.Studio, nil)
|
||||
r := c.repository
|
||||
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
|
||||
return match.ScrapedStudio(ctx, r.StudioFinder, m.Studio, nil)
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -113,11 +114,12 @@ func (c Cache) postScrapeScene(ctx context.Context, scene ScrapedScene) (Scraped
|
|||
scene.URLs = []string{*scene.URL}
|
||||
}
|
||||
|
||||
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
|
||||
pqb := c.repository.PerformerFinder
|
||||
mqb := c.repository.MovieFinder
|
||||
tqb := c.repository.TagFinder
|
||||
sqb := c.repository.StudioFinder
|
||||
r := c.repository
|
||||
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
|
||||
pqb := r.PerformerFinder
|
||||
mqb := r.MovieFinder
|
||||
tqb := r.TagFinder
|
||||
sqb := r.StudioFinder
|
||||
|
||||
for _, p := range scene.Performers {
|
||||
if p == nil {
|
||||
|
|
@ -175,10 +177,11 @@ func (c Cache) postScrapeGallery(ctx context.Context, g ScrapedGallery) (Scraped
|
|||
g.URLs = []string{*g.URL}
|
||||
}
|
||||
|
||||
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
|
||||
pqb := c.repository.PerformerFinder
|
||||
tqb := c.repository.TagFinder
|
||||
sqb := c.repository.StudioFinder
|
||||
r := c.repository
|
||||
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
|
||||
pqb := r.PerformerFinder
|
||||
tqb := r.TagFinder
|
||||
sqb := r.StudioFinder
|
||||
|
||||
for _, p := range g.Performers {
|
||||
err := match.ScrapedPerformer(ctx, pqb, p, nil)
|
||||
|
|
|
|||
|
|
@ -56,22 +56,37 @@ type TagFinder interface {
|
|||
}
|
||||
|
||||
type Repository struct {
|
||||
TxnManager models.TxnManager
|
||||
|
||||
Scene SceneReader
|
||||
Performer PerformerReader
|
||||
Tag TagFinder
|
||||
Studio StudioReader
|
||||
}
|
||||
|
||||
func NewRepository(repo models.Repository) Repository {
|
||||
return Repository{
|
||||
TxnManager: repo.TxnManager,
|
||||
Scene: repo.Scene,
|
||||
Performer: repo.Performer,
|
||||
Tag: repo.Tag,
|
||||
Studio: repo.Studio,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Repository) WithReadTxn(ctx context.Context, fn txn.TxnFunc) error {
|
||||
return txn.WithReadTxn(ctx, r.TxnManager, fn)
|
||||
}
|
||||
|
||||
// Client represents the client interface to a stash-box server instance.
|
||||
type Client struct {
|
||||
client *graphql.Client
|
||||
txnManager txn.Manager
|
||||
repository Repository
|
||||
box models.StashBox
|
||||
}
|
||||
|
||||
// NewClient returns a new instance of a stash-box client.
|
||||
func NewClient(box models.StashBox, txnManager txn.Manager, repo Repository) *Client {
|
||||
func NewClient(box models.StashBox, repo Repository) *Client {
|
||||
authHeader := func(req *http.Request) {
|
||||
req.Header.Set("ApiKey", box.APIKey)
|
||||
}
|
||||
|
|
@ -82,7 +97,6 @@ func NewClient(box models.StashBox, txnManager txn.Manager, repo Repository) *Cl
|
|||
|
||||
return &Client{
|
||||
client: client,
|
||||
txnManager: txnManager,
|
||||
repository: repo,
|
||||
box: box,
|
||||
}
|
||||
|
|
@ -129,8 +143,9 @@ func (c Client) FindStashBoxSceneByFingerprints(ctx context.Context, sceneID int
|
|||
func (c Client) FindStashBoxScenesByFingerprints(ctx context.Context, ids []int) ([][]*scraper.ScrapedScene, error) {
|
||||
var fingerprints [][]*graphql.FingerprintQueryInput
|
||||
|
||||
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
|
||||
qb := c.repository.Scene
|
||||
r := c.repository
|
||||
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
|
||||
qb := r.Scene
|
||||
|
||||
for _, sceneID := range ids {
|
||||
scene, err := qb.Find(ctx, sceneID)
|
||||
|
|
@ -142,7 +157,7 @@ func (c Client) FindStashBoxScenesByFingerprints(ctx context.Context, ids []int)
|
|||
return fmt.Errorf("scene with id %d not found", sceneID)
|
||||
}
|
||||
|
||||
if err := scene.LoadFiles(ctx, c.repository.Scene); err != nil {
|
||||
if err := scene.LoadFiles(ctx, r.Scene); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
@ -243,8 +258,9 @@ func (c Client) SubmitStashBoxFingerprints(ctx context.Context, sceneIDs []strin
|
|||
|
||||
var fingerprints []graphql.FingerprintSubmission
|
||||
|
||||
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
|
||||
qb := c.repository.Scene
|
||||
r := c.repository
|
||||
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
|
||||
qb := r.Scene
|
||||
|
||||
for _, sceneID := range ids {
|
||||
scene, err := qb.Find(ctx, sceneID)
|
||||
|
|
@ -382,9 +398,9 @@ func (c Client) FindStashBoxPerformersByNames(ctx context.Context, performerIDs
|
|||
}
|
||||
|
||||
var performers []*models.Performer
|
||||
|
||||
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
|
||||
qb := c.repository.Performer
|
||||
r := c.repository
|
||||
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
|
||||
qb := r.Performer
|
||||
|
||||
for _, performerID := range ids {
|
||||
performer, err := qb.Find(ctx, performerID)
|
||||
|
|
@ -417,8 +433,9 @@ func (c Client) FindStashBoxPerformersByPerformerNames(ctx context.Context, perf
|
|||
|
||||
var performers []*models.Performer
|
||||
|
||||
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
|
||||
qb := c.repository.Performer
|
||||
r := c.repository
|
||||
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
|
||||
qb := r.Performer
|
||||
|
||||
for _, performerID := range ids {
|
||||
performer, err := qb.Find(ctx, performerID)
|
||||
|
|
@ -739,14 +756,15 @@ func (c Client) sceneFragmentToScrapedScene(ctx context.Context, s *graphql.Scen
|
|||
ss.URL = &s.Urls[0].URL
|
||||
}
|
||||
|
||||
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
|
||||
pqb := c.repository.Performer
|
||||
tqb := c.repository.Tag
|
||||
r := c.repository
|
||||
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
|
||||
pqb := r.Performer
|
||||
tqb := r.Tag
|
||||
|
||||
if s.Studio != nil {
|
||||
ss.Studio = studioFragmentToScrapedStudio(*s.Studio)
|
||||
|
||||
err := match.ScrapedStudio(ctx, c.repository.Studio, ss.Studio, &c.box.Endpoint)
|
||||
err := match.ScrapedStudio(ctx, r.Studio, ss.Studio, &c.box.Endpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -761,7 +779,7 @@ func (c Client) sceneFragmentToScrapedScene(ctx context.Context, s *graphql.Scen
|
|||
if parentStudio.FindStudio != nil {
|
||||
ss.Studio.Parent = studioFragmentToScrapedStudio(*parentStudio.FindStudio)
|
||||
|
||||
err = match.ScrapedStudio(ctx, c.repository.Studio, ss.Studio.Parent, &c.box.Endpoint)
|
||||
err = match.ScrapedStudio(ctx, r.Studio, ss.Studio.Parent, &c.box.Endpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -809,8 +827,9 @@ func (c Client) FindStashBoxPerformerByID(ctx context.Context, id string) (*mode
|
|||
|
||||
ret := performerFragmentToScrapedPerformer(*performer.FindPerformer)
|
||||
|
||||
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
|
||||
err := match.ScrapedPerformer(ctx, c.repository.Performer, ret, &c.box.Endpoint)
|
||||
r := c.repository
|
||||
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
|
||||
err := match.ScrapedPerformer(ctx, r.Performer, ret, &c.box.Endpoint)
|
||||
return err
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
|
|
@ -836,8 +855,9 @@ func (c Client) FindStashBoxPerformerByName(ctx context.Context, name string) (*
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
|
||||
err := match.ScrapedPerformer(ctx, c.repository.Performer, ret, &c.box.Endpoint)
|
||||
r := c.repository
|
||||
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
|
||||
err := match.ScrapedPerformer(ctx, r.Performer, ret, &c.box.Endpoint)
|
||||
return err
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
|
|
@ -864,10 +884,11 @@ func (c Client) FindStashBoxStudio(ctx context.Context, query string) (*models.S
|
|||
|
||||
var ret *models.ScrapedStudio
|
||||
if studio.FindStudio != nil {
|
||||
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
|
||||
r := c.repository
|
||||
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
|
||||
ret = studioFragmentToScrapedStudio(*studio.FindStudio)
|
||||
|
||||
err = match.ScrapedStudio(ctx, c.repository.Studio, ret, &c.box.Endpoint)
|
||||
err = match.ScrapedStudio(ctx, r.Studio, ret, &c.box.Endpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -881,7 +902,7 @@ func (c Client) FindStashBoxStudio(ctx context.Context, query string) (*models.S
|
|||
if parentStudio.FindStudio != nil {
|
||||
ret.Parent = studioFragmentToScrapedStudio(*parentStudio.FindStudio)
|
||||
|
||||
err = match.ScrapedStudio(ctx, c.repository.Studio, ret.Parent, &c.box.Endpoint)
|
||||
err = match.ScrapedStudio(ctx, r.Studio, ret.Parent, &c.box.Endpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1176,7 +1176,7 @@ func makeImage(i int) *models.Image {
|
|||
}
|
||||
|
||||
func createImages(ctx context.Context, n int) error {
|
||||
qb := db.TxnRepository().Image
|
||||
qb := db.Image
|
||||
fqb := db.File
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
|
|
@ -1273,7 +1273,7 @@ func makeGallery(i int, includeScenes bool) *models.Gallery {
|
|||
}
|
||||
|
||||
func createGalleries(ctx context.Context, n int) error {
|
||||
gqb := db.TxnRepository().Gallery
|
||||
gqb := db.Gallery
|
||||
fqb := db.File
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
|
|
|
|||
|
|
@ -123,7 +123,7 @@ func (db *Database) IsLocked(err error) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func (db *Database) TxnRepository() models.Repository {
|
||||
func (db *Database) Repository() models.Repository {
|
||||
return models.Repository{
|
||||
TxnManager: db,
|
||||
File: db.File,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
package studio
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
|
|
@ -162,27 +161,26 @@ func initTestTable() {
|
|||
|
||||
func TestToJSON(t *testing.T) {
|
||||
initTestTable()
|
||||
ctx := context.Background()
|
||||
|
||||
mockStudioReader := &mocks.StudioReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
imageErr := errors.New("error getting image")
|
||||
|
||||
mockStudioReader.On("GetImage", ctx, studioID).Return(imageBytes, nil).Once()
|
||||
mockStudioReader.On("GetImage", ctx, noImageID).Return(nil, nil).Once()
|
||||
mockStudioReader.On("GetImage", ctx, errImageID).Return(nil, imageErr).Once()
|
||||
mockStudioReader.On("GetImage", ctx, missingParentStudioID).Return(imageBytes, nil).Maybe()
|
||||
mockStudioReader.On("GetImage", ctx, errStudioID).Return(imageBytes, nil).Maybe()
|
||||
db.Studio.On("GetImage", testCtx, studioID).Return(imageBytes, nil).Once()
|
||||
db.Studio.On("GetImage", testCtx, noImageID).Return(nil, nil).Once()
|
||||
db.Studio.On("GetImage", testCtx, errImageID).Return(nil, imageErr).Once()
|
||||
db.Studio.On("GetImage", testCtx, missingParentStudioID).Return(imageBytes, nil).Maybe()
|
||||
db.Studio.On("GetImage", testCtx, errStudioID).Return(imageBytes, nil).Maybe()
|
||||
|
||||
parentStudioErr := errors.New("error getting parent studio")
|
||||
|
||||
mockStudioReader.On("Find", ctx, parentStudioID).Return(&parentStudio, nil)
|
||||
mockStudioReader.On("Find", ctx, missingStudioID).Return(nil, nil)
|
||||
mockStudioReader.On("Find", ctx, errParentStudioID).Return(nil, parentStudioErr)
|
||||
db.Studio.On("Find", testCtx, parentStudioID).Return(&parentStudio, nil)
|
||||
db.Studio.On("Find", testCtx, missingStudioID).Return(nil, nil)
|
||||
db.Studio.On("Find", testCtx, errParentStudioID).Return(nil, parentStudioErr)
|
||||
|
||||
for i, s := range scenarios {
|
||||
studio := s.input
|
||||
json, err := ToJSON(ctx, mockStudioReader, &studio)
|
||||
json, err := ToJSON(testCtx, db.Studio, &studio)
|
||||
|
||||
switch {
|
||||
case !s.err && err != nil:
|
||||
|
|
@ -194,5 +192,5 @@ func TestToJSON(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
mockStudioReader.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,6 +25,8 @@ const (
|
|||
missingParentStudioName = "existingParentStudioName"
|
||||
)
|
||||
|
||||
var testCtx = context.Background()
|
||||
|
||||
func TestImporterName(t *testing.T) {
|
||||
i := Importer{
|
||||
Input: jsonschema.Studio{
|
||||
|
|
@ -43,22 +45,21 @@ func TestImporterPreImport(t *testing.T) {
|
|||
IgnoreAutoTag: autoTagIgnored,
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
err := i.PreImport(ctx)
|
||||
err := i.PreImport(testCtx)
|
||||
|
||||
assert.NotNil(t, err)
|
||||
|
||||
i.Input.Image = image
|
||||
|
||||
err = i.PreImport(ctx)
|
||||
err = i.PreImport(testCtx)
|
||||
|
||||
assert.Nil(t, err)
|
||||
|
||||
i.Input = *createFullJSONStudio(studioName, image, []string{"alias"})
|
||||
i.Input.ParentStudio = ""
|
||||
|
||||
err = i.PreImport(ctx)
|
||||
err = i.PreImport(testCtx)
|
||||
|
||||
assert.Nil(t, err)
|
||||
expectedStudio := createFullStudio(0, 0)
|
||||
|
|
@ -67,11 +68,10 @@ func TestImporterPreImport(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestImporterPreImportWithParent(t *testing.T) {
|
||||
readerWriter := &mocks.StudioReaderWriter{}
|
||||
ctx := context.Background()
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
ReaderWriter: readerWriter,
|
||||
ReaderWriter: db.Studio,
|
||||
Input: jsonschema.Studio{
|
||||
Name: studioName,
|
||||
Image: image,
|
||||
|
|
@ -79,28 +79,27 @@ func TestImporterPreImportWithParent(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
readerWriter.On("FindByName", ctx, existingParentStudioName, false).Return(&models.Studio{
|
||||
db.Studio.On("FindByName", testCtx, existingParentStudioName, false).Return(&models.Studio{
|
||||
ID: existingStudioID,
|
||||
}, nil).Once()
|
||||
readerWriter.On("FindByName", ctx, existingParentStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
|
||||
db.Studio.On("FindByName", testCtx, existingParentStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
|
||||
|
||||
err := i.PreImport(ctx)
|
||||
err := i.PreImport(testCtx)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, existingStudioID, *i.studio.ParentID)
|
||||
|
||||
i.Input.ParentStudio = existingParentStudioErr
|
||||
err = i.PreImport(ctx)
|
||||
err = i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
readerWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterPreImportWithMissingParent(t *testing.T) {
|
||||
readerWriter := &mocks.StudioReaderWriter{}
|
||||
ctx := context.Background()
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
ReaderWriter: readerWriter,
|
||||
ReaderWriter: db.Studio,
|
||||
Input: jsonschema.Studio{
|
||||
Name: studioName,
|
||||
Image: image,
|
||||
|
|
@ -109,33 +108,32 @@ func TestImporterPreImportWithMissingParent(t *testing.T) {
|
|||
MissingRefBehaviour: models.ImportMissingRefEnumFail,
|
||||
}
|
||||
|
||||
readerWriter.On("FindByName", ctx, missingParentStudioName, false).Return(nil, nil).Times(3)
|
||||
readerWriter.On("Create", ctx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
|
||||
db.Studio.On("FindByName", testCtx, missingParentStudioName, false).Return(nil, nil).Times(3)
|
||||
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
|
||||
s := args.Get(1).(*models.Studio)
|
||||
s.ID = existingStudioID
|
||||
}).Return(nil)
|
||||
|
||||
err := i.PreImport(ctx)
|
||||
err := i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore
|
||||
err = i.PreImport(ctx)
|
||||
err = i.PreImport(testCtx)
|
||||
assert.Nil(t, err)
|
||||
|
||||
i.MissingRefBehaviour = models.ImportMissingRefEnumCreate
|
||||
err = i.PreImport(ctx)
|
||||
err = i.PreImport(testCtx)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, existingStudioID, *i.studio.ParentID)
|
||||
|
||||
readerWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterPreImportWithMissingParentCreateErr(t *testing.T) {
|
||||
readerWriter := &mocks.StudioReaderWriter{}
|
||||
ctx := context.Background()
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
ReaderWriter: readerWriter,
|
||||
ReaderWriter: db.Studio,
|
||||
Input: jsonschema.Studio{
|
||||
Name: studioName,
|
||||
Image: image,
|
||||
|
|
@ -144,19 +142,20 @@ func TestImporterPreImportWithMissingParentCreateErr(t *testing.T) {
|
|||
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
|
||||
}
|
||||
|
||||
readerWriter.On("FindByName", ctx, missingParentStudioName, false).Return(nil, nil).Once()
|
||||
readerWriter.On("Create", ctx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
|
||||
db.Studio.On("FindByName", testCtx, missingParentStudioName, false).Return(nil, nil).Once()
|
||||
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
|
||||
|
||||
err := i.PreImport(ctx)
|
||||
err := i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterPostImport(t *testing.T) {
|
||||
readerWriter := &mocks.StudioReaderWriter{}
|
||||
ctx := context.Background()
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
ReaderWriter: readerWriter,
|
||||
ReaderWriter: db.Studio,
|
||||
Input: jsonschema.Studio{
|
||||
Aliases: []string{"alias"},
|
||||
},
|
||||
|
|
@ -165,56 +164,54 @@ func TestImporterPostImport(t *testing.T) {
|
|||
|
||||
updateStudioImageErr := errors.New("UpdateImage error")
|
||||
|
||||
readerWriter.On("UpdateImage", ctx, studioID, imageBytes).Return(nil).Once()
|
||||
readerWriter.On("UpdateImage", ctx, errImageID, imageBytes).Return(updateStudioImageErr).Once()
|
||||
db.Studio.On("UpdateImage", testCtx, studioID, imageBytes).Return(nil).Once()
|
||||
db.Studio.On("UpdateImage", testCtx, errImageID, imageBytes).Return(updateStudioImageErr).Once()
|
||||
|
||||
err := i.PostImport(ctx, studioID)
|
||||
err := i.PostImport(testCtx, studioID)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = i.PostImport(ctx, errImageID)
|
||||
err = i.PostImport(testCtx, errImageID)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
readerWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterFindExistingID(t *testing.T) {
|
||||
readerWriter := &mocks.StudioReaderWriter{}
|
||||
ctx := context.Background()
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
ReaderWriter: readerWriter,
|
||||
ReaderWriter: db.Studio,
|
||||
Input: jsonschema.Studio{
|
||||
Name: studioName,
|
||||
},
|
||||
}
|
||||
|
||||
errFindByName := errors.New("FindByName error")
|
||||
readerWriter.On("FindByName", ctx, studioName, false).Return(nil, nil).Once()
|
||||
readerWriter.On("FindByName", ctx, existingStudioName, false).Return(&models.Studio{
|
||||
db.Studio.On("FindByName", testCtx, studioName, false).Return(nil, nil).Once()
|
||||
db.Studio.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{
|
||||
ID: existingStudioID,
|
||||
}, nil).Once()
|
||||
readerWriter.On("FindByName", ctx, studioNameErr, false).Return(nil, errFindByName).Once()
|
||||
db.Studio.On("FindByName", testCtx, studioNameErr, false).Return(nil, errFindByName).Once()
|
||||
|
||||
id, err := i.FindExistingID(ctx)
|
||||
id, err := i.FindExistingID(testCtx)
|
||||
assert.Nil(t, id)
|
||||
assert.Nil(t, err)
|
||||
|
||||
i.Input.Name = existingStudioName
|
||||
id, err = i.FindExistingID(ctx)
|
||||
id, err = i.FindExistingID(testCtx)
|
||||
assert.Equal(t, existingStudioID, *id)
|
||||
assert.Nil(t, err)
|
||||
|
||||
i.Input.Name = studioNameErr
|
||||
id, err = i.FindExistingID(ctx)
|
||||
id, err = i.FindExistingID(testCtx)
|
||||
assert.Nil(t, id)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
readerWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestCreate(t *testing.T) {
|
||||
readerWriter := &mocks.StudioReaderWriter{}
|
||||
ctx := context.Background()
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
studio := models.Studio{
|
||||
Name: studioName,
|
||||
|
|
@ -225,32 +222,31 @@ func TestCreate(t *testing.T) {
|
|||
}
|
||||
|
||||
i := Importer{
|
||||
ReaderWriter: readerWriter,
|
||||
ReaderWriter: db.Studio,
|
||||
studio: studio,
|
||||
}
|
||||
|
||||
errCreate := errors.New("Create error")
|
||||
readerWriter.On("Create", ctx, &studio).Run(func(args mock.Arguments) {
|
||||
db.Studio.On("Create", testCtx, &studio).Run(func(args mock.Arguments) {
|
||||
s := args.Get(1).(*models.Studio)
|
||||
s.ID = studioID
|
||||
}).Return(nil).Once()
|
||||
readerWriter.On("Create", ctx, &studioErr).Return(errCreate).Once()
|
||||
db.Studio.On("Create", testCtx, &studioErr).Return(errCreate).Once()
|
||||
|
||||
id, err := i.Create(ctx)
|
||||
id, err := i.Create(testCtx)
|
||||
assert.Equal(t, studioID, *id)
|
||||
assert.Nil(t, err)
|
||||
|
||||
i.studio = studioErr
|
||||
id, err = i.Create(ctx)
|
||||
id, err = i.Create(testCtx)
|
||||
assert.Nil(t, id)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
readerWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestUpdate(t *testing.T) {
|
||||
readerWriter := &mocks.StudioReaderWriter{}
|
||||
ctx := context.Background()
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
studio := models.Studio{
|
||||
Name: studioName,
|
||||
|
|
@ -261,7 +257,7 @@ func TestUpdate(t *testing.T) {
|
|||
}
|
||||
|
||||
i := Importer{
|
||||
ReaderWriter: readerWriter,
|
||||
ReaderWriter: db.Studio,
|
||||
studio: studio,
|
||||
}
|
||||
|
||||
|
|
@ -269,19 +265,19 @@ func TestUpdate(t *testing.T) {
|
|||
|
||||
// id needs to be set for the mock input
|
||||
studio.ID = studioID
|
||||
readerWriter.On("Update", ctx, &studio).Return(nil).Once()
|
||||
db.Studio.On("Update", testCtx, &studio).Return(nil).Once()
|
||||
|
||||
err := i.Update(ctx, studioID)
|
||||
err := i.Update(testCtx, studioID)
|
||||
assert.Nil(t, err)
|
||||
|
||||
i.studio = studioErr
|
||||
|
||||
// need to set id separately
|
||||
studioErr.ID = errImageID
|
||||
readerWriter.On("Update", ctx, &studioErr).Return(errUpdate).Once()
|
||||
db.Studio.On("Update", testCtx, &studioErr).Return(errUpdate).Once()
|
||||
|
||||
err = i.Update(ctx, errImageID)
|
||||
err = i.Update(testCtx, errImageID)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
readerWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
package tag
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
|
|
@ -109,35 +108,34 @@ func initTestTable() {
|
|||
func TestToJSON(t *testing.T) {
|
||||
initTestTable()
|
||||
|
||||
ctx := context.Background()
|
||||
mockTagReader := &mocks.TagReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
imageErr := errors.New("error getting image")
|
||||
aliasErr := errors.New("error getting aliases")
|
||||
parentsErr := errors.New("error getting parents")
|
||||
|
||||
mockTagReader.On("GetAliases", ctx, tagID).Return([]string{"alias"}, nil).Once()
|
||||
mockTagReader.On("GetAliases", ctx, noImageID).Return(nil, nil).Once()
|
||||
mockTagReader.On("GetAliases", ctx, errImageID).Return(nil, nil).Once()
|
||||
mockTagReader.On("GetAliases", ctx, errAliasID).Return(nil, aliasErr).Once()
|
||||
mockTagReader.On("GetAliases", ctx, withParentsID).Return(nil, nil).Once()
|
||||
mockTagReader.On("GetAliases", ctx, errParentsID).Return(nil, nil).Once()
|
||||
db.Tag.On("GetAliases", testCtx, tagID).Return([]string{"alias"}, nil).Once()
|
||||
db.Tag.On("GetAliases", testCtx, noImageID).Return(nil, nil).Once()
|
||||
db.Tag.On("GetAliases", testCtx, errImageID).Return(nil, nil).Once()
|
||||
db.Tag.On("GetAliases", testCtx, errAliasID).Return(nil, aliasErr).Once()
|
||||
db.Tag.On("GetAliases", testCtx, withParentsID).Return(nil, nil).Once()
|
||||
db.Tag.On("GetAliases", testCtx, errParentsID).Return(nil, nil).Once()
|
||||
|
||||
mockTagReader.On("GetImage", ctx, tagID).Return(imageBytes, nil).Once()
|
||||
mockTagReader.On("GetImage", ctx, noImageID).Return(nil, nil).Once()
|
||||
mockTagReader.On("GetImage", ctx, errImageID).Return(nil, imageErr).Once()
|
||||
mockTagReader.On("GetImage", ctx, withParentsID).Return(imageBytes, nil).Once()
|
||||
mockTagReader.On("GetImage", ctx, errParentsID).Return(nil, nil).Once()
|
||||
db.Tag.On("GetImage", testCtx, tagID).Return(imageBytes, nil).Once()
|
||||
db.Tag.On("GetImage", testCtx, noImageID).Return(nil, nil).Once()
|
||||
db.Tag.On("GetImage", testCtx, errImageID).Return(nil, imageErr).Once()
|
||||
db.Tag.On("GetImage", testCtx, withParentsID).Return(imageBytes, nil).Once()
|
||||
db.Tag.On("GetImage", testCtx, errParentsID).Return(nil, nil).Once()
|
||||
|
||||
mockTagReader.On("FindByChildTagID", ctx, tagID).Return(nil, nil).Once()
|
||||
mockTagReader.On("FindByChildTagID", ctx, noImageID).Return(nil, nil).Once()
|
||||
mockTagReader.On("FindByChildTagID", ctx, withParentsID).Return([]*models.Tag{{Name: "parent"}}, nil).Once()
|
||||
mockTagReader.On("FindByChildTagID", ctx, errParentsID).Return(nil, parentsErr).Once()
|
||||
mockTagReader.On("FindByChildTagID", ctx, errImageID).Return(nil, nil).Once()
|
||||
db.Tag.On("FindByChildTagID", testCtx, tagID).Return(nil, nil).Once()
|
||||
db.Tag.On("FindByChildTagID", testCtx, noImageID).Return(nil, nil).Once()
|
||||
db.Tag.On("FindByChildTagID", testCtx, withParentsID).Return([]*models.Tag{{Name: "parent"}}, nil).Once()
|
||||
db.Tag.On("FindByChildTagID", testCtx, errParentsID).Return(nil, parentsErr).Once()
|
||||
db.Tag.On("FindByChildTagID", testCtx, errImageID).Return(nil, nil).Once()
|
||||
|
||||
for i, s := range scenarios {
|
||||
tag := s.tag
|
||||
json, err := ToJSON(ctx, mockTagReader, &tag)
|
||||
json, err := ToJSON(testCtx, db.Tag, &tag)
|
||||
|
||||
switch {
|
||||
case !s.err && err != nil:
|
||||
|
|
@ -149,5 +147,5 @@ func TestToJSON(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
mockTagReader.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -58,10 +58,10 @@ func TestImporterPreImport(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestImporterPostImport(t *testing.T) {
|
||||
readerWriter := &mocks.TagReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
ReaderWriter: readerWriter,
|
||||
ReaderWriter: db.Tag,
|
||||
Input: jsonschema.Tag{
|
||||
Aliases: []string{"alias"},
|
||||
},
|
||||
|
|
@ -72,23 +72,23 @@ func TestImporterPostImport(t *testing.T) {
|
|||
updateTagAliasErr := errors.New("UpdateAlias error")
|
||||
updateTagParentsErr := errors.New("UpdateParentTags error")
|
||||
|
||||
readerWriter.On("UpdateAliases", testCtx, tagID, i.Input.Aliases).Return(nil).Once()
|
||||
readerWriter.On("UpdateAliases", testCtx, errAliasID, i.Input.Aliases).Return(updateTagAliasErr).Once()
|
||||
readerWriter.On("UpdateAliases", testCtx, withParentsID, i.Input.Aliases).Return(nil).Once()
|
||||
readerWriter.On("UpdateAliases", testCtx, errParentsID, i.Input.Aliases).Return(nil).Once()
|
||||
db.Tag.On("UpdateAliases", testCtx, tagID, i.Input.Aliases).Return(nil).Once()
|
||||
db.Tag.On("UpdateAliases", testCtx, errAliasID, i.Input.Aliases).Return(updateTagAliasErr).Once()
|
||||
db.Tag.On("UpdateAliases", testCtx, withParentsID, i.Input.Aliases).Return(nil).Once()
|
||||
db.Tag.On("UpdateAliases", testCtx, errParentsID, i.Input.Aliases).Return(nil).Once()
|
||||
|
||||
readerWriter.On("UpdateImage", testCtx, tagID, imageBytes).Return(nil).Once()
|
||||
readerWriter.On("UpdateImage", testCtx, errAliasID, imageBytes).Return(nil).Once()
|
||||
readerWriter.On("UpdateImage", testCtx, errImageID, imageBytes).Return(updateTagImageErr).Once()
|
||||
readerWriter.On("UpdateImage", testCtx, withParentsID, imageBytes).Return(nil).Once()
|
||||
readerWriter.On("UpdateImage", testCtx, errParentsID, imageBytes).Return(nil).Once()
|
||||
db.Tag.On("UpdateImage", testCtx, tagID, imageBytes).Return(nil).Once()
|
||||
db.Tag.On("UpdateImage", testCtx, errAliasID, imageBytes).Return(nil).Once()
|
||||
db.Tag.On("UpdateImage", testCtx, errImageID, imageBytes).Return(updateTagImageErr).Once()
|
||||
db.Tag.On("UpdateImage", testCtx, withParentsID, imageBytes).Return(nil).Once()
|
||||
db.Tag.On("UpdateImage", testCtx, errParentsID, imageBytes).Return(nil).Once()
|
||||
|
||||
var parentTags []int
|
||||
readerWriter.On("UpdateParentTags", testCtx, tagID, parentTags).Return(nil).Once()
|
||||
readerWriter.On("UpdateParentTags", testCtx, withParentsID, []int{100}).Return(nil).Once()
|
||||
readerWriter.On("UpdateParentTags", testCtx, errParentsID, []int{100}).Return(updateTagParentsErr).Once()
|
||||
db.Tag.On("UpdateParentTags", testCtx, tagID, parentTags).Return(nil).Once()
|
||||
db.Tag.On("UpdateParentTags", testCtx, withParentsID, []int{100}).Return(nil).Once()
|
||||
db.Tag.On("UpdateParentTags", testCtx, errParentsID, []int{100}).Return(updateTagParentsErr).Once()
|
||||
|
||||
readerWriter.On("FindByName", testCtx, "Parent", false).Return(&models.Tag{ID: 100}, nil)
|
||||
db.Tag.On("FindByName", testCtx, "Parent", false).Return(&models.Tag{ID: 100}, nil)
|
||||
|
||||
err := i.PostImport(testCtx, tagID)
|
||||
assert.Nil(t, err)
|
||||
|
|
@ -106,14 +106,14 @@ func TestImporterPostImport(t *testing.T) {
|
|||
err = i.PostImport(testCtx, errParentsID)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
readerWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterPostImportParentMissing(t *testing.T) {
|
||||
readerWriter := &mocks.TagReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
ReaderWriter: readerWriter,
|
||||
ReaderWriter: db.Tag,
|
||||
Input: jsonschema.Tag{},
|
||||
imageData: imageBytes,
|
||||
}
|
||||
|
|
@ -133,33 +133,33 @@ func TestImporterPostImportParentMissing(t *testing.T) {
|
|||
|
||||
var emptyParents []int
|
||||
|
||||
readerWriter.On("UpdateImage", testCtx, mock.Anything, mock.Anything).Return(nil)
|
||||
readerWriter.On("UpdateAliases", testCtx, mock.Anything, mock.Anything).Return(nil)
|
||||
db.Tag.On("UpdateImage", testCtx, mock.Anything, mock.Anything).Return(nil)
|
||||
db.Tag.On("UpdateAliases", testCtx, mock.Anything, mock.Anything).Return(nil)
|
||||
|
||||
readerWriter.On("FindByName", testCtx, "Create", false).Return(nil, nil).Once()
|
||||
readerWriter.On("FindByName", testCtx, "CreateError", false).Return(nil, nil).Once()
|
||||
readerWriter.On("FindByName", testCtx, "CreateFindError", false).Return(nil, findError).Once()
|
||||
readerWriter.On("FindByName", testCtx, "CreateFound", false).Return(&models.Tag{ID: 101}, nil).Once()
|
||||
readerWriter.On("FindByName", testCtx, "Fail", false).Return(nil, nil).Once()
|
||||
readerWriter.On("FindByName", testCtx, "FailFindError", false).Return(nil, findError)
|
||||
readerWriter.On("FindByName", testCtx, "FailFound", false).Return(&models.Tag{ID: 102}, nil).Once()
|
||||
readerWriter.On("FindByName", testCtx, "Ignore", false).Return(nil, nil).Once()
|
||||
readerWriter.On("FindByName", testCtx, "IgnoreFindError", false).Return(nil, findError)
|
||||
readerWriter.On("FindByName", testCtx, "IgnoreFound", false).Return(&models.Tag{ID: 103}, nil).Once()
|
||||
db.Tag.On("FindByName", testCtx, "Create", false).Return(nil, nil).Once()
|
||||
db.Tag.On("FindByName", testCtx, "CreateError", false).Return(nil, nil).Once()
|
||||
db.Tag.On("FindByName", testCtx, "CreateFindError", false).Return(nil, findError).Once()
|
||||
db.Tag.On("FindByName", testCtx, "CreateFound", false).Return(&models.Tag{ID: 101}, nil).Once()
|
||||
db.Tag.On("FindByName", testCtx, "Fail", false).Return(nil, nil).Once()
|
||||
db.Tag.On("FindByName", testCtx, "FailFindError", false).Return(nil, findError)
|
||||
db.Tag.On("FindByName", testCtx, "FailFound", false).Return(&models.Tag{ID: 102}, nil).Once()
|
||||
db.Tag.On("FindByName", testCtx, "Ignore", false).Return(nil, nil).Once()
|
||||
db.Tag.On("FindByName", testCtx, "IgnoreFindError", false).Return(nil, findError)
|
||||
db.Tag.On("FindByName", testCtx, "IgnoreFound", false).Return(&models.Tag{ID: 103}, nil).Once()
|
||||
|
||||
readerWriter.On("UpdateParentTags", testCtx, createID, []int{100}).Return(nil).Once()
|
||||
readerWriter.On("UpdateParentTags", testCtx, createFoundID, []int{101}).Return(nil).Once()
|
||||
readerWriter.On("UpdateParentTags", testCtx, failFoundID, []int{102}).Return(nil).Once()
|
||||
readerWriter.On("UpdateParentTags", testCtx, ignoreID, emptyParents).Return(nil).Once()
|
||||
readerWriter.On("UpdateParentTags", testCtx, ignoreFoundID, []int{103}).Return(nil).Once()
|
||||
db.Tag.On("UpdateParentTags", testCtx, createID, []int{100}).Return(nil).Once()
|
||||
db.Tag.On("UpdateParentTags", testCtx, createFoundID, []int{101}).Return(nil).Once()
|
||||
db.Tag.On("UpdateParentTags", testCtx, failFoundID, []int{102}).Return(nil).Once()
|
||||
db.Tag.On("UpdateParentTags", testCtx, ignoreID, emptyParents).Return(nil).Once()
|
||||
db.Tag.On("UpdateParentTags", testCtx, ignoreFoundID, []int{103}).Return(nil).Once()
|
||||
|
||||
readerWriter.On("Create", testCtx, mock.MatchedBy(func(t *models.Tag) bool {
|
||||
db.Tag.On("Create", testCtx, mock.MatchedBy(func(t *models.Tag) bool {
|
||||
return t.Name == "Create"
|
||||
})).Run(func(args mock.Arguments) {
|
||||
t := args.Get(1).(*models.Tag)
|
||||
t.ID = 100
|
||||
}).Return(nil).Once()
|
||||
readerWriter.On("Create", testCtx, mock.MatchedBy(func(t *models.Tag) bool {
|
||||
db.Tag.On("Create", testCtx, mock.MatchedBy(func(t *models.Tag) bool {
|
||||
return t.Name == "CreateError"
|
||||
})).Return(errors.New("failed creating parent")).Once()
|
||||
|
||||
|
|
@ -206,25 +206,25 @@ func TestImporterPostImportParentMissing(t *testing.T) {
|
|||
err = i.PostImport(testCtx, ignoreFoundID)
|
||||
assert.Nil(t, err)
|
||||
|
||||
readerWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestImporterFindExistingID(t *testing.T) {
|
||||
readerWriter := &mocks.TagReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
i := Importer{
|
||||
ReaderWriter: readerWriter,
|
||||
ReaderWriter: db.Tag,
|
||||
Input: jsonschema.Tag{
|
||||
Name: tagName,
|
||||
},
|
||||
}
|
||||
|
||||
errFindByName := errors.New("FindByName error")
|
||||
readerWriter.On("FindByName", testCtx, tagName, false).Return(nil, nil).Once()
|
||||
readerWriter.On("FindByName", testCtx, existingTagName, false).Return(&models.Tag{
|
||||
db.Tag.On("FindByName", testCtx, tagName, false).Return(nil, nil).Once()
|
||||
db.Tag.On("FindByName", testCtx, existingTagName, false).Return(&models.Tag{
|
||||
ID: existingTagID,
|
||||
}, nil).Once()
|
||||
readerWriter.On("FindByName", testCtx, tagNameErr, false).Return(nil, errFindByName).Once()
|
||||
db.Tag.On("FindByName", testCtx, tagNameErr, false).Return(nil, errFindByName).Once()
|
||||
|
||||
id, err := i.FindExistingID(testCtx)
|
||||
assert.Nil(t, id)
|
||||
|
|
@ -240,11 +240,11 @@ func TestImporterFindExistingID(t *testing.T) {
|
|||
assert.Nil(t, id)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
readerWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestCreate(t *testing.T) {
|
||||
readerWriter := &mocks.TagReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
tag := models.Tag{
|
||||
Name: tagName,
|
||||
|
|
@ -255,16 +255,16 @@ func TestCreate(t *testing.T) {
|
|||
}
|
||||
|
||||
i := Importer{
|
||||
ReaderWriter: readerWriter,
|
||||
ReaderWriter: db.Tag,
|
||||
tag: tag,
|
||||
}
|
||||
|
||||
errCreate := errors.New("Create error")
|
||||
readerWriter.On("Create", testCtx, &tag).Run(func(args mock.Arguments) {
|
||||
db.Tag.On("Create", testCtx, &tag).Run(func(args mock.Arguments) {
|
||||
t := args.Get(1).(*models.Tag)
|
||||
t.ID = tagID
|
||||
}).Return(nil).Once()
|
||||
readerWriter.On("Create", testCtx, &tagErr).Return(errCreate).Once()
|
||||
db.Tag.On("Create", testCtx, &tagErr).Return(errCreate).Once()
|
||||
|
||||
id, err := i.Create(testCtx)
|
||||
assert.Equal(t, tagID, *id)
|
||||
|
|
@ -275,11 +275,11 @@ func TestCreate(t *testing.T) {
|
|||
assert.Nil(t, id)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
readerWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestUpdate(t *testing.T) {
|
||||
readerWriter := &mocks.TagReaderWriter{}
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
tag := models.Tag{
|
||||
Name: tagName,
|
||||
|
|
@ -290,7 +290,7 @@ func TestUpdate(t *testing.T) {
|
|||
}
|
||||
|
||||
i := Importer{
|
||||
ReaderWriter: readerWriter,
|
||||
ReaderWriter: db.Tag,
|
||||
tag: tag,
|
||||
}
|
||||
|
||||
|
|
@ -298,7 +298,7 @@ func TestUpdate(t *testing.T) {
|
|||
|
||||
// id needs to be set for the mock input
|
||||
tag.ID = tagID
|
||||
readerWriter.On("Update", testCtx, &tag).Return(nil).Once()
|
||||
db.Tag.On("Update", testCtx, &tag).Return(nil).Once()
|
||||
|
||||
err := i.Update(testCtx, tagID)
|
||||
assert.Nil(t, err)
|
||||
|
|
@ -307,10 +307,10 @@ func TestUpdate(t *testing.T) {
|
|||
|
||||
// need to set id separately
|
||||
tagErr.ID = errImageID
|
||||
readerWriter.On("Update", testCtx, &tagErr).Return(errUpdate).Once()
|
||||
db.Tag.On("Update", testCtx, &tagErr).Return(errUpdate).Once()
|
||||
|
||||
err = i.Update(testCtx, errImageID)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
readerWriter.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -219,8 +219,7 @@ func TestEnsureHierarchy(t *testing.T) {
|
|||
}
|
||||
|
||||
func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents, queryChildren bool) {
|
||||
mockTagReader := &mocks.TagReaderWriter{}
|
||||
ctx := context.Background()
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
var parentIDs, childIDs []int
|
||||
find := make(map[int]*models.Tag)
|
||||
|
|
@ -247,15 +246,15 @@ func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents,
|
|||
|
||||
if queryParents {
|
||||
parentIDs = nil
|
||||
mockTagReader.On("FindByChildTagID", ctx, tc.id).Return(tc.parents, nil).Once()
|
||||
db.Tag.On("FindByChildTagID", testCtx, tc.id).Return(tc.parents, nil).Once()
|
||||
}
|
||||
|
||||
if queryChildren {
|
||||
childIDs = nil
|
||||
mockTagReader.On("FindByParentTagID", ctx, tc.id).Return(tc.children, nil).Once()
|
||||
db.Tag.On("FindByParentTagID", testCtx, tc.id).Return(tc.children, nil).Once()
|
||||
}
|
||||
|
||||
mockTagReader.On("FindAllAncestors", ctx, mock.AnythingOfType("int"), []int(nil)).Return(func(ctx context.Context, tagID int, excludeIDs []int) []*models.TagPath {
|
||||
db.Tag.On("FindAllAncestors", testCtx, mock.AnythingOfType("int"), []int(nil)).Return(func(ctx context.Context, tagID int, excludeIDs []int) []*models.TagPath {
|
||||
return tc.onFindAllAncestors
|
||||
}, func(ctx context.Context, tagID int, excludeIDs []int) error {
|
||||
if tc.onFindAllAncestors != nil {
|
||||
|
|
@ -264,7 +263,7 @@ func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents,
|
|||
return fmt.Errorf("undefined ancestors for: %d", tagID)
|
||||
}).Maybe()
|
||||
|
||||
mockTagReader.On("FindAllDescendants", ctx, mock.AnythingOfType("int"), []int(nil)).Return(func(ctx context.Context, tagID int, excludeIDs []int) []*models.TagPath {
|
||||
db.Tag.On("FindAllDescendants", testCtx, mock.AnythingOfType("int"), []int(nil)).Return(func(ctx context.Context, tagID int, excludeIDs []int) []*models.TagPath {
|
||||
return tc.onFindAllDescendants
|
||||
}, func(ctx context.Context, tagID int, excludeIDs []int) error {
|
||||
if tc.onFindAllDescendants != nil {
|
||||
|
|
@ -273,7 +272,7 @@ func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents,
|
|||
return fmt.Errorf("undefined descendants for: %d", tagID)
|
||||
}).Maybe()
|
||||
|
||||
res := ValidateHierarchy(ctx, testUniqueHierarchyTags[tc.id], parentIDs, childIDs, mockTagReader)
|
||||
res := ValidateHierarchy(testCtx, testUniqueHierarchyTags[tc.id], parentIDs, childIDs, db.Tag)
|
||||
|
||||
assert := assert.New(t)
|
||||
|
||||
|
|
@ -285,5 +284,5 @@ func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents,
|
|||
assert.Nil(res)
|
||||
}
|
||||
|
||||
mockTagReader.AssertExpectations(t)
|
||||
db.AssertExpectations(t)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue