Model refactor, part 3 (#4152)

* Remove manager.Repository
* Refactor other repositories
* Fix tests and add database mock
* Add AssertExpectations method
* Refactor routes
* Move default movie image to internal/static and add convenience methods
* Refactor default performer image boxes
This commit is contained in:
DingDongSoLong4 2023-10-16 05:26:34 +02:00 committed by GitHub
parent 40bcb4baa5
commit 33f2ebf2a3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
87 changed files with 1843 additions and 1651 deletions

View file

@ -1,12 +1,13 @@
package api
import (
"errors"
"fmt"
"io"
"io/fs"
"os"
"strings"
"github.com/stashapp/stash/internal/manager/config"
"github.com/stashapp/stash/internal/static"
"github.com/stashapp/stash/pkg/hash"
"github.com/stashapp/stash/pkg/logger"
@ -18,7 +19,7 @@ type imageBox struct {
files []string
}
var imageExtensions = []string{
var imageBoxExts = []string{
".jpg",
".jpeg",
".png",
@ -42,7 +43,7 @@ func newImageBox(box fs.FS) (*imageBox, error) {
}
baseName := strings.ToLower(d.Name())
for _, ext := range imageExtensions {
for _, ext := range imageBoxExts {
if strings.HasSuffix(baseName, ext) {
ret.files = append(ret.files, path)
break
@ -55,44 +56,59 @@ func newImageBox(box fs.FS) (*imageBox, error) {
return ret, err
}
// GetRandomImageByName returns the contents of one image file from the box,
// chosen deterministically by hashing name — the same name always yields the
// same image. It returns an error if the box is empty or the file cannot be
// opened.
func (box *imageBox) GetRandomImageByName(name string) ([]byte, error) {
	n := len(box.files)
	if n == 0 {
		return nil, errors.New("box is empty")
	}
	// Deterministic selection: hash of the name modulo file count.
	chosen := box.files[hash.IntFromString(name)%uint64(n)]
	f, err := box.box.Open(chosen)
	if err != nil {
		return nil, err
	}
	defer f.Close()
	return io.ReadAll(f)
}
var performerBox *imageBox
var performerBoxMale *imageBox
var performerBoxCustom *imageBox
func initialiseImages() {
func init() {
var err error
performerBox, err = newImageBox(&static.Performer)
performerBox, err = newImageBox(static.Sub(static.Performer))
if err != nil {
logger.Warnf("error loading performer images: %v", err)
panic(fmt.Sprintf("loading performer images: %v", err))
}
performerBoxMale, err = newImageBox(&static.PerformerMale)
performerBoxMale, err = newImageBox(static.Sub(static.PerformerMale))
if err != nil {
logger.Warnf("error loading male performer images: %v", err)
panic(fmt.Sprintf("loading male performer images: %v", err))
}
initialiseCustomImages()
}
func initialiseCustomImages() {
customPath := config.GetInstance().GetCustomPerformerImageLocation()
func initCustomPerformerImages(customPath string) {
if customPath != "" {
logger.Debugf("Loading custom performer images from %s", customPath)
// We need to set performerBoxCustom at runtime, as this is a custom path, and store it in a pointer.
var err error
performerBoxCustom, err = newImageBox(os.DirFS(customPath))
if err != nil {
logger.Warnf("error loading custom performer from %s: %v", customPath, err)
logger.Warnf("error loading custom performer images from %s: %v", customPath, err)
}
} else {
performerBoxCustom = nil
}
}
func getRandomPerformerImageUsingName(name string, gender *models.GenderEnum, customPath string) ([]byte, error) {
var box *imageBox
// If we have a custom path, we should return a new box in the given path.
if performerBoxCustom != nil && len(performerBoxCustom.files) > 0 {
box = performerBoxCustom
func getDefaultPerformerImage(name string, gender *models.GenderEnum) []byte {
// try the custom box first if we have one
if performerBoxCustom != nil {
ret, err := performerBoxCustom.GetRandomImageByName(name)
if err == nil {
return ret
}
logger.Warnf("error loading custom default performer image: %v", err)
}
var g models.GenderEnum
@ -100,7 +116,7 @@ func getRandomPerformerImageUsingName(name string, gender *models.GenderEnum, cu
g = *gender
}
if box == nil {
var box *imageBox
switch g {
case models.GenderEnumFemale, models.GenderEnumTransgenderFemale:
box = performerBox
@ -109,15 +125,10 @@ func getRandomPerformerImageUsingName(name string, gender *models.GenderEnum, cu
default:
box = performerBox
}
}
imageFiles := box.files
index := hash.IntFromString(name) % uint64(len(imageFiles))
img, err := box.box.Open(imageFiles[index])
ret, err := box.GetRandomImageByName(name)
if err != nil {
return nil, err
logger.Warnf("error loading default performer image: %v", err)
}
defer img.Close()
return io.ReadAll(img)
return ret
}

View file

@ -17,9 +17,7 @@ import (
"net/http"
"time"
"github.com/stashapp/stash/internal/manager"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
)
type contextKey struct{ name string }
@ -49,8 +47,7 @@ type Loaders struct {
}
type Middleware struct {
DatabaseProvider txn.DatabaseProvider
Repository manager.Repository
Repository models.Repository
}
func (m Middleware) Middleware(next http.Handler) http.Handler {
@ -131,13 +128,9 @@ func toErrorSlice(err error) []error {
return nil
}
func (m Middleware) withTxn(ctx context.Context, fn func(ctx context.Context) error) error {
return txn.WithDatabase(ctx, m.DatabaseProvider, fn)
}
func (m Middleware) fetchScenes(ctx context.Context) func(keys []int) ([]*models.Scene, []error) {
return func(keys []int) (ret []*models.Scene, errs []error) {
err := m.withTxn(ctx, func(ctx context.Context) error {
err := m.Repository.WithDB(ctx, func(ctx context.Context) error {
var err error
ret, err = m.Repository.Scene.FindMany(ctx, keys)
return err
@ -148,7 +141,7 @@ func (m Middleware) fetchScenes(ctx context.Context) func(keys []int) ([]*models
func (m Middleware) fetchImages(ctx context.Context) func(keys []int) ([]*models.Image, []error) {
return func(keys []int) (ret []*models.Image, errs []error) {
err := m.withTxn(ctx, func(ctx context.Context) error {
err := m.Repository.WithDB(ctx, func(ctx context.Context) error {
var err error
ret, err = m.Repository.Image.FindMany(ctx, keys)
return err
@ -160,7 +153,7 @@ func (m Middleware) fetchImages(ctx context.Context) func(keys []int) ([]*models
func (m Middleware) fetchGalleries(ctx context.Context) func(keys []int) ([]*models.Gallery, []error) {
return func(keys []int) (ret []*models.Gallery, errs []error) {
err := m.withTxn(ctx, func(ctx context.Context) error {
err := m.Repository.WithDB(ctx, func(ctx context.Context) error {
var err error
ret, err = m.Repository.Gallery.FindMany(ctx, keys)
return err
@ -172,7 +165,7 @@ func (m Middleware) fetchGalleries(ctx context.Context) func(keys []int) ([]*mod
func (m Middleware) fetchPerformers(ctx context.Context) func(keys []int) ([]*models.Performer, []error) {
return func(keys []int) (ret []*models.Performer, errs []error) {
err := m.withTxn(ctx, func(ctx context.Context) error {
err := m.Repository.WithDB(ctx, func(ctx context.Context) error {
var err error
ret, err = m.Repository.Performer.FindMany(ctx, keys)
return err
@ -184,7 +177,7 @@ func (m Middleware) fetchPerformers(ctx context.Context) func(keys []int) ([]*mo
func (m Middleware) fetchStudios(ctx context.Context) func(keys []int) ([]*models.Studio, []error) {
return func(keys []int) (ret []*models.Studio, errs []error) {
err := m.withTxn(ctx, func(ctx context.Context) error {
err := m.Repository.WithDB(ctx, func(ctx context.Context) error {
var err error
ret, err = m.Repository.Studio.FindMany(ctx, keys)
return err
@ -195,7 +188,7 @@ func (m Middleware) fetchStudios(ctx context.Context) func(keys []int) ([]*model
func (m Middleware) fetchTags(ctx context.Context) func(keys []int) ([]*models.Tag, []error) {
return func(keys []int) (ret []*models.Tag, errs []error) {
err := m.withTxn(ctx, func(ctx context.Context) error {
err := m.Repository.WithDB(ctx, func(ctx context.Context) error {
var err error
ret, err = m.Repository.Tag.FindMany(ctx, keys)
return err
@ -206,7 +199,7 @@ func (m Middleware) fetchTags(ctx context.Context) func(keys []int) ([]*models.T
func (m Middleware) fetchMovies(ctx context.Context) func(keys []int) ([]*models.Movie, []error) {
return func(keys []int) (ret []*models.Movie, errs []error) {
err := m.withTxn(ctx, func(ctx context.Context) error {
err := m.Repository.WithDB(ctx, func(ctx context.Context) error {
var err error
ret, err = m.Repository.Movie.FindMany(ctx, keys)
return err
@ -217,7 +210,7 @@ func (m Middleware) fetchMovies(ctx context.Context) func(keys []int) ([]*models
func (m Middleware) fetchFiles(ctx context.Context) func(keys []models.FileID) ([]models.File, []error) {
return func(keys []models.FileID) (ret []models.File, errs []error) {
err := m.withTxn(ctx, func(ctx context.Context) error {
err := m.Repository.WithDB(ctx, func(ctx context.Context) error {
var err error
ret, err = m.Repository.File.Find(ctx, keys...)
return err
@ -228,7 +221,7 @@ func (m Middleware) fetchFiles(ctx context.Context) func(keys []models.FileID) (
func (m Middleware) fetchScenesFileIDs(ctx context.Context) func(keys []int) ([][]models.FileID, []error) {
return func(keys []int) (ret [][]models.FileID, errs []error) {
err := m.withTxn(ctx, func(ctx context.Context) error {
err := m.Repository.WithDB(ctx, func(ctx context.Context) error {
var err error
ret, err = m.Repository.Scene.GetManyFileIDs(ctx, keys)
return err
@ -239,7 +232,7 @@ func (m Middleware) fetchScenesFileIDs(ctx context.Context) func(keys []int) ([]
func (m Middleware) fetchImagesFileIDs(ctx context.Context) func(keys []int) ([][]models.FileID, []error) {
return func(keys []int) (ret [][]models.FileID, errs []error) {
err := m.withTxn(ctx, func(ctx context.Context) error {
err := m.Repository.WithDB(ctx, func(ctx context.Context) error {
var err error
ret, err = m.Repository.Image.GetManyFileIDs(ctx, keys)
return err
@ -250,7 +243,7 @@ func (m Middleware) fetchImagesFileIDs(ctx context.Context) func(keys []int) ([]
func (m Middleware) fetchGalleriesFileIDs(ctx context.Context) func(keys []int) ([][]models.FileID, []error) {
return func(keys []int) (ret [][]models.FileID, errs []error) {
err := m.withTxn(ctx, func(ctx context.Context) error {
err := m.Repository.WithDB(ctx, func(ctx context.Context) error {
var err error
ret, err = m.Repository.Gallery.GetManyFileIDs(ctx, keys)
return err

View file

@ -13,7 +13,7 @@ import (
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/plugin"
"github.com/stashapp/stash/pkg/scraper"
"github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/scraper/stashbox"
)
var (
@ -33,8 +33,7 @@ type hookExecutor interface {
}
type Resolver struct {
txnManager txn.Manager
repository manager.Repository
repository models.Repository
sceneService manager.SceneService
imageService manager.ImageService
galleryService manager.GalleryService
@ -102,11 +101,15 @@ type tagResolver struct{ *Resolver }
type savedFilterResolver struct{ *Resolver }
func (r *Resolver) withTxn(ctx context.Context, fn func(ctx context.Context) error) error {
return txn.WithTxn(ctx, r.txnManager, fn)
return r.repository.WithTxn(ctx, fn)
}
func (r *Resolver) withReadTxn(ctx context.Context, fn func(ctx context.Context) error) error {
return txn.WithReadTxn(ctx, r.txnManager, fn)
return r.repository.WithReadTxn(ctx, fn)
}
// stashboxRepository builds a stashbox.Repository backed by the resolver's
// models.Repository.
func (r *Resolver) stashboxRepository() stashbox.Repository {
return stashbox.NewRepository(r.repository)
}
func (r *queryResolver) MarkerWall(ctx context.Context, q *string) (ret []*models.SceneMarker, err error) {

View file

@ -316,7 +316,7 @@ func (r *mutationResolver) ConfigureGeneral(ctx context.Context, input ConfigGen
if input.CustomPerformerImageLocation != nil {
c.Set(config.CustomPerformerImageLocation, *input.CustomPerformerImageLocation)
initialiseCustomImages()
initCustomPerformerImages(*input.CustomPerformerImageLocation)
}
if input.ScraperUserAgent != nil {

View file

@ -17,7 +17,7 @@ func (r *mutationResolver) MoveFiles(ctx context.Context, input MoveFilesInput)
fileStore := r.repository.File
folderStore := r.repository.Folder
mover := file.NewMover(fileStore, folderStore)
mover.RegisterHooks(ctx, r.txnManager)
mover.RegisterHooks(ctx)
var (
folder *models.Folder

View file

@ -11,30 +11,30 @@ import (
)
func (r *mutationResolver) MigrateSceneScreenshots(ctx context.Context, input MigrateSceneScreenshotsInput) (string, error) {
db := manager.GetInstance().Database
mgr := manager.GetInstance()
t := &task.MigrateSceneScreenshotsJob{
ScreenshotsPath: manager.GetInstance().Paths.Generated.Screenshots,
Input: scene.MigrateSceneScreenshotsInput{
DeleteFiles: utils.IsTrue(input.DeleteFiles),
OverwriteExisting: utils.IsTrue(input.OverwriteExisting),
},
SceneRepo: db.Scene,
TxnManager: db,
SceneRepo: mgr.Repository.Scene,
TxnManager: mgr.Repository.TxnManager,
}
jobID := manager.GetInstance().JobManager.Add(ctx, "Migrating scene screenshots to blobs...", t)
jobID := mgr.JobManager.Add(ctx, "Migrating scene screenshots to blobs...", t)
return strconv.Itoa(jobID), nil
}
func (r *mutationResolver) MigrateBlobs(ctx context.Context, input MigrateBlobsInput) (string, error) {
db := manager.GetInstance().Database
mgr := manager.GetInstance()
t := &task.MigrateBlobsJob{
TxnManager: db,
BlobStore: db.Blobs,
Vacuumer: db,
TxnManager: mgr.Database,
BlobStore: mgr.Database.Blobs,
Vacuumer: mgr.Database,
DeleteOld: utils.IsTrue(input.DeleteOld),
}
jobID := manager.GetInstance().JobManager.Add(ctx, "Migrating blobs...", t)
jobID := mgr.JobManager.Add(ctx, "Migrating blobs...", t)
return strconv.Itoa(jobID), nil
}

View file

@ -5,6 +5,7 @@ import (
"fmt"
"strconv"
"github.com/stashapp/stash/internal/static"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/plugin"
"github.com/stashapp/stash/pkg/sliceutil/stringslice"
@ -50,12 +51,6 @@ func (r *mutationResolver) MovieCreate(ctx context.Context, input MovieCreateInp
return nil, fmt.Errorf("converting studio id: %w", err)
}
// HACK: if back image is being set, set the front image to the default.
// This is because we can't have a null front image with a non-null back image.
if input.FrontImage == nil && input.BackImage != nil {
input.FrontImage = &models.DefaultMovieImage
}
// Process the base 64 encoded image string
var frontimageData []byte
if input.FrontImage != nil {
@ -74,6 +69,12 @@ func (r *mutationResolver) MovieCreate(ctx context.Context, input MovieCreateInp
}
}
// HACK: if back image is being set, set the front image to the default.
// This is because we can't have a null front image with a non-null back image.
if len(frontimageData) == 0 && len(backimageData) != 0 {
frontimageData = static.ReadAll(static.DefaultMovieImage)
}
// Start the transaction and save the movie
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Movie

View file

@ -11,15 +11,6 @@ import (
"github.com/stashapp/stash/pkg/scraper/stashbox"
)
func (r *Resolver) stashboxRepository() stashbox.Repository {
return stashbox.Repository{
Scene: r.repository.Scene,
Performer: r.repository.Performer,
Tag: r.repository.Tag,
Studio: r.repository.Studio,
}
}
func (r *mutationResolver) SubmitStashBoxFingerprints(ctx context.Context, input StashBoxFingerprintSubmissionInput) (bool, error) {
boxes := config.GetInstance().GetStashBoxes()
@ -27,7 +18,7 @@ func (r *mutationResolver) SubmitStashBoxFingerprints(ctx context.Context, input
return false, fmt.Errorf("invalid stash_box_index %d", input.StashBoxIndex)
}
client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager, r.stashboxRepository())
client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.stashboxRepository())
return client.SubmitStashBoxFingerprints(ctx, input.SceneIds, boxes[input.StashBoxIndex].Endpoint)
}
@ -49,7 +40,7 @@ func (r *mutationResolver) SubmitStashBoxSceneDraft(ctx context.Context, input S
return nil, fmt.Errorf("invalid stash_box_index %d", input.StashBoxIndex)
}
client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager, r.stashboxRepository())
client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.stashboxRepository())
id, err := strconv.Atoi(input.ID)
if err != nil {
@ -91,7 +82,7 @@ func (r *mutationResolver) SubmitStashBoxPerformerDraft(ctx context.Context, inp
return nil, fmt.Errorf("invalid stash_box_index %d", input.StashBoxIndex)
}
client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager, r.stashboxRepository())
client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.stashboxRepository())
id, err := strconv.Atoi(input.ID)
if err != nil {

View file

@ -5,7 +5,6 @@ import (
"errors"
"testing"
"github.com/stashapp/stash/internal/manager"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/mocks"
"github.com/stashapp/stash/pkg/plugin"
@ -15,14 +14,9 @@ import (
)
// TODO - move this into a common area
func newResolver() *Resolver {
txnMgr := &mocks.TxnManager{}
func newResolver(db *mocks.Database) *Resolver {
return &Resolver{
txnManager: txnMgr,
repository: manager.Repository{
TxnManager: txnMgr,
Tag: &mocks.TagReaderWriter{},
},
repository: db.Repository(),
hookExecutor: &mockHookExecutor{},
}
}
@ -45,9 +39,8 @@ func (*mockHookExecutor) ExecutePostHooks(ctx context.Context, id int, hookType
}
func TestTagCreate(t *testing.T) {
r := newResolver()
tagRW := r.repository.Tag.(*mocks.TagReaderWriter)
db := mocks.NewDatabase()
r := newResolver(db)
pp := 1
findFilter := &models.FindFilterType{
@ -72,17 +65,17 @@ func TestTagCreate(t *testing.T) {
}
}
tagRW.On("Query", mock.Anything, tagFilterForName(existingTagName), findFilter).Return([]*models.Tag{
db.Tag.On("Query", mock.Anything, tagFilterForName(existingTagName), findFilter).Return([]*models.Tag{
{
ID: existingTagID,
Name: existingTagName,
},
}, 1, nil).Once()
tagRW.On("Query", mock.Anything, tagFilterForName(errTagName), findFilter).Return(nil, 0, nil).Once()
tagRW.On("Query", mock.Anything, tagFilterForAlias(errTagName), findFilter).Return(nil, 0, nil).Once()
db.Tag.On("Query", mock.Anything, tagFilterForName(errTagName), findFilter).Return(nil, 0, nil).Once()
db.Tag.On("Query", mock.Anything, tagFilterForAlias(errTagName), findFilter).Return(nil, 0, nil).Once()
expectedErr := errors.New("TagCreate error")
tagRW.On("Create", mock.Anything, mock.AnythingOfType("*models.Tag")).Return(expectedErr)
db.Tag.On("Create", mock.Anything, mock.AnythingOfType("*models.Tag")).Return(expectedErr)
// fails here because testCtx is empty
// TODO: Fix this
@ -101,22 +94,22 @@ func TestTagCreate(t *testing.T) {
})
assert.Equal(t, expectedErr, err)
tagRW.AssertExpectations(t)
db.AssertExpectations(t)
r = newResolver()
tagRW = r.repository.Tag.(*mocks.TagReaderWriter)
db = mocks.NewDatabase()
r = newResolver(db)
tagRW.On("Query", mock.Anything, tagFilterForName(tagName), findFilter).Return(nil, 0, nil).Once()
tagRW.On("Query", mock.Anything, tagFilterForAlias(tagName), findFilter).Return(nil, 0, nil).Once()
db.Tag.On("Query", mock.Anything, tagFilterForName(tagName), findFilter).Return(nil, 0, nil).Once()
db.Tag.On("Query", mock.Anything, tagFilterForAlias(tagName), findFilter).Return(nil, 0, nil).Once()
newTag := &models.Tag{
ID: newTagID,
Name: tagName,
}
tagRW.On("Create", mock.Anything, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
db.Tag.On("Create", mock.Anything, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
arg := args.Get(1).(*models.Tag)
arg.ID = newTagID
}).Return(nil)
tagRW.On("Find", mock.Anything, newTagID).Return(newTag, nil)
db.Tag.On("Find", mock.Anything, newTagID).Return(newTag, nil)
tag, err := r.Mutation().TagCreate(testCtx, TagCreateInput{
Name: tagName,
@ -124,4 +117,5 @@ func TestTagCreate(t *testing.T) {
assert.Nil(t, err)
assert.NotNil(t, tag)
db.AssertExpectations(t)
}

View file

@ -243,7 +243,9 @@ func makeConfigUIResult() map[string]interface{} {
}
func (r *queryResolver) ValidateStashBoxCredentials(ctx context.Context, input config.StashBoxInput) (*StashBoxValidationResult, error) {
client := stashbox.NewClient(models.StashBox{Endpoint: input.Endpoint, APIKey: input.APIKey}, r.txnManager, r.stashboxRepository())
box := models.StashBox{Endpoint: input.Endpoint, APIKey: input.APIKey}
client := stashbox.NewClient(box, r.stashboxRepository())
user, err := client.GetUser(ctx)
valid := user != nil && user.Me != nil

View file

@ -191,16 +191,11 @@ func (r *queryResolver) FindScenesByPathRegex(ctx context.Context, filter *model
}
func (r *queryResolver) ParseSceneFilenames(ctx context.Context, filter *models.FindFilterType, config models.SceneParserInput) (ret *SceneParserResultType, err error) {
parser := scene.NewFilenameParser(filter, config)
repo := scene.NewFilenameParserRepository(r.repository)
parser := scene.NewFilenameParser(filter, config, repo)
if err := r.withReadTxn(ctx, func(ctx context.Context) error {
result, count, err := parser.Parse(ctx, scene.FilenameParserRepository{
Scene: r.repository.Scene,
Performer: r.repository.Performer,
Studio: r.repository.Studio,
Movie: r.repository.Movie,
Tag: r.repository.Tag,
})
result, count, err := parser.Parse(ctx)
if err != nil {
return err

View file

@ -238,7 +238,7 @@ func (r *queryResolver) getStashBoxClient(index int) (*stashbox.Client, error) {
return nil, fmt.Errorf("%w: invalid stash_box_index %d", ErrInput, index)
}
return stashbox.NewClient(*boxes[index], r.txnManager, r.stashboxRepository()), nil
return stashbox.NewClient(*boxes[index], r.stashboxRepository()), nil
}
func (r *queryResolver) ScrapeSingleScene(ctx context.Context, source scraper.Source, input ScrapeSingleSceneInput) ([]*scraper.ScrapedScene, error) {

15
internal/api/routes.go Normal file
View file

@ -0,0 +1,15 @@
package api
import (
"net/http"
"github.com/stashapp/stash/pkg/txn"
)
// routes holds dependencies shared by the API route groups that embed it.
type routes struct {
// txnManager is used to wrap request handling in database transactions.
txnManager txn.Manager
}
// withReadTxn runs fn inside a read-only transaction scoped to the
// request's context.
func (rs routes) withReadTxn(r *http.Request, fn txn.TxnFunc) error {
	ctx := r.Context()
	return txn.WithReadTxn(ctx, rs.txnManager, fn)
}

View file

@ -12,6 +12,10 @@ type customRoutes struct {
servedFolders config.URLMap
}
// getCustomRoutes builds the router that serves the user-configured
// custom folder mappings.
func getCustomRoutes(servedFolders config.URLMap) chi.Router {
	rs := customRoutes{servedFolders: servedFolders}
	return rs.Routes()
}
func (rs customRoutes) Routes() chi.Router {
r := chi.NewRouter()

View file

@ -10,6 +10,10 @@ import (
type downloadsRoutes struct{}
// getDownloadsRoutes builds the router for the download endpoints.
func getDownloadsRoutes() chi.Router {
	var rs downloadsRoutes
	return rs.Routes()
}
func (rs downloadsRoutes) Routes() chi.Router {
r := chi.NewRouter()

View file

@ -3,7 +3,6 @@ package api
import (
"context"
"errors"
"io"
"io/fs"
"net/http"
"os/exec"
@ -17,7 +16,6 @@ import (
"github.com/stashapp/stash/pkg/image"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils"
)
@ -27,11 +25,19 @@ type ImageFinder interface {
}
type imageRoutes struct {
txnManager txn.Manager
routes
imageFinder ImageFinder
fileGetter models.FileGetter
}
// getImageRoutes wires the image endpoints to the repository's image and
// file stores and returns the configured router.
func getImageRoutes(repo models.Repository) chi.Router {
	rs := imageRoutes{
		routes:      routes{txnManager: repo.TxnManager},
		imageFinder: repo.Image,
		fileGetter:  repo.File,
	}
	return rs.Routes()
}
func (rs imageRoutes) Routes() chi.Router {
r := chi.NewRouter()
@ -46,8 +52,6 @@ func (rs imageRoutes) Routes() chi.Router {
return r
}
// region Handlers
func (rs imageRoutes) Thumbnail(w http.ResponseWriter, r *http.Request) {
img := r.Context().Value(imageKey).(*models.Image)
filepath := manager.GetInstance().Paths.Generated.GetThumbnailPath(img.Checksum, models.DefaultGthumbWidth)
@ -119,8 +123,6 @@ func (rs imageRoutes) Image(w http.ResponseWriter, r *http.Request) {
}
func (rs imageRoutes) serveImage(w http.ResponseWriter, r *http.Request, i *models.Image, useDefault bool) {
const defaultImageImage = "image/image.svg"
if i.Files.Primary() != nil {
err := i.Files.Primary().Base().Serve(&file.OsFS{}, w, r)
if err == nil {
@ -141,22 +143,18 @@ func (rs imageRoutes) serveImage(w http.ResponseWriter, r *http.Request, i *mode
return
}
// fall back to static image
f, _ := static.Image.Open(defaultImageImage)
defer f.Close()
image, _ := io.ReadAll(f)
// fallback to default image
image := static.ReadAll(static.DefaultImageImage)
utils.ServeImage(w, r, image)
}
// endregion
func (rs imageRoutes) ImageCtx(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
imageIdentifierQueryParam := chi.URLParam(r, "imageId")
imageID, _ := strconv.Atoi(imageIdentifierQueryParam)
var image *models.Image
_ = txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
_ = rs.withReadTxn(r, func(ctx context.Context) error {
qb := rs.imageFinder
if imageID == 0 {
images, _ := qb.FindByChecksum(ctx, imageIdentifierQueryParam)

View file

@ -7,9 +7,10 @@ import (
"strconv"
"github.com/go-chi/chi"
"github.com/stashapp/stash/internal/static"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils"
)
@ -20,10 +21,17 @@ type MovieFinder interface {
}
type movieRoutes struct {
txnManager txn.Manager
routes
movieFinder MovieFinder
}
// getMovieRoutes wires the movie endpoints to the repository's movie store
// and returns the configured router.
func getMovieRoutes(repo models.Repository) chi.Router {
	rs := movieRoutes{
		routes:      routes{txnManager: repo.TxnManager},
		movieFinder: repo.Movie,
	}
	return rs.Routes()
}
func (rs movieRoutes) Routes() chi.Router {
r := chi.NewRouter()
@ -41,7 +49,7 @@ func (rs movieRoutes) FrontImage(w http.ResponseWriter, r *http.Request) {
defaultParam := r.URL.Query().Get("default")
var image []byte
if defaultParam != "true" {
readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error {
var err error
image, err = rs.movieFinder.GetFrontImage(ctx, movie.ID)
return err
@ -54,8 +62,9 @@ func (rs movieRoutes) FrontImage(w http.ResponseWriter, r *http.Request) {
}
}
// fallback to default image
if len(image) == 0 {
image, _ = utils.ProcessBase64Image(models.DefaultMovieImage)
image = static.ReadAll(static.DefaultMovieImage)
}
utils.ServeImage(w, r, image)
@ -66,7 +75,7 @@ func (rs movieRoutes) BackImage(w http.ResponseWriter, r *http.Request) {
defaultParam := r.URL.Query().Get("default")
var image []byte
if defaultParam != "true" {
readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error {
var err error
image, err = rs.movieFinder.GetBackImage(ctx, movie.ID)
return err
@ -79,8 +88,9 @@ func (rs movieRoutes) BackImage(w http.ResponseWriter, r *http.Request) {
}
}
// fallback to default image
if len(image) == 0 {
image, _ = utils.ProcessBase64Image(models.DefaultMovieImage)
image = static.ReadAll(static.DefaultMovieImage)
}
utils.ServeImage(w, r, image)
@ -95,7 +105,7 @@ func (rs movieRoutes) MovieCtx(next http.Handler) http.Handler {
}
var movie *models.Movie
_ = txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
_ = rs.withReadTxn(r, func(ctx context.Context) error {
movie, _ = rs.movieFinder.Find(ctx, movieID)
return nil
})

View file

@ -7,10 +7,8 @@ import (
"strconv"
"github.com/go-chi/chi"
"github.com/stashapp/stash/internal/manager/config"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils"
)
@ -20,10 +18,17 @@ type PerformerFinder interface {
}
type performerRoutes struct {
txnManager txn.Manager
routes
performerFinder PerformerFinder
}
// getPerformerRoutes wires the performer endpoints to the repository's
// performer store and returns the configured router.
func getPerformerRoutes(repo models.Repository) chi.Router {
	rs := performerRoutes{
		routes:          routes{txnManager: repo.TxnManager},
		performerFinder: repo.Performer,
	}
	return rs.Routes()
}
func (rs performerRoutes) Routes() chi.Router {
r := chi.NewRouter()
@ -41,7 +46,7 @@ func (rs performerRoutes) Image(w http.ResponseWriter, r *http.Request) {
var image []byte
if defaultParam != "true" {
readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error {
var err error
image, err = rs.performerFinder.GetImage(ctx, performer.ID)
return err
@ -55,7 +60,7 @@ func (rs performerRoutes) Image(w http.ResponseWriter, r *http.Request) {
}
if len(image) == 0 {
image, _ = getRandomPerformerImageUsingName(performer.Name, performer.Gender, config.GetInstance().GetCustomPerformerImageLocation())
image = getDefaultPerformerImage(performer.Name, performer.Gender)
}
utils.ServeImage(w, r, image)
@ -70,7 +75,7 @@ func (rs performerRoutes) PerformerCtx(next http.Handler) http.Handler {
}
var performer *models.Performer
_ = txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
_ = rs.withReadTxn(r, func(ctx context.Context) error {
var err error
performer, err = rs.performerFinder.Find(ctx, performerID)
return err

View file

@ -16,7 +16,6 @@ import (
"github.com/stashapp/stash/pkg/fsutil"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils"
)
@ -43,7 +42,7 @@ type CaptionFinder interface {
}
type sceneRoutes struct {
txnManager txn.Manager
routes
sceneFinder SceneFinder
fileGetter models.FileGetter
captionFinder CaptionFinder
@ -51,6 +50,17 @@ type sceneRoutes struct {
tagFinder SceneMarkerTagFinder
}
// getSceneRoutes wires the scene endpoints to the repository's scene, file,
// scene-marker and tag stores and returns the configured router.
// Note: repo.File backs both the file getter and the caption finder.
func getSceneRoutes(repo models.Repository) chi.Router {
	rs := sceneRoutes{
		routes:            routes{txnManager: repo.TxnManager},
		sceneFinder:       repo.Scene,
		fileGetter:        repo.File,
		captionFinder:     repo.File,
		sceneMarkerFinder: repo.SceneMarker,
		tagFinder:         repo.Tag,
	}
	return rs.Routes()
}
func (rs sceneRoutes) Routes() chi.Router {
r := chi.NewRouter()
@ -89,8 +99,6 @@ func (rs sceneRoutes) Routes() chi.Router {
return r
}
// region Handlers
func (rs sceneRoutes) StreamDirect(w http.ResponseWriter, r *http.Request) {
scene := r.Context().Value(sceneKey).(*models.Scene)
ss := manager.SceneServer{
@ -270,13 +278,13 @@ func (rs sceneRoutes) Webp(w http.ResponseWriter, r *http.Request) {
utils.ServeStaticFile(w, r, filepath)
}
func (rs sceneRoutes) getChapterVttTitle(ctx context.Context, marker *models.SceneMarker) (*string, error) {
func (rs sceneRoutes) getChapterVttTitle(r *http.Request, marker *models.SceneMarker) (*string, error) {
if marker.Title != "" {
return &marker.Title, nil
}
var title string
if err := txn.WithReadTxn(ctx, rs.txnManager, func(ctx context.Context) error {
if err := rs.withReadTxn(r, func(ctx context.Context) error {
qb := rs.tagFinder
primaryTag, err := qb.Find(ctx, marker.PrimaryTagID)
if err != nil {
@ -305,7 +313,7 @@ func (rs sceneRoutes) getChapterVttTitle(ctx context.Context, marker *models.Sce
func (rs sceneRoutes) VttChapter(w http.ResponseWriter, r *http.Request) {
scene := r.Context().Value(sceneKey).(*models.Scene)
var sceneMarkers []*models.SceneMarker
readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error {
var err error
sceneMarkers, err = rs.sceneMarkerFinder.FindBySceneID(ctx, scene.ID)
return err
@ -325,7 +333,7 @@ func (rs sceneRoutes) VttChapter(w http.ResponseWriter, r *http.Request) {
time := utils.GetVTTTime(marker.Seconds)
vttLines = append(vttLines, time+" --> "+time)
vttTitle, err := rs.getChapterVttTitle(r.Context(), marker)
vttTitle, err := rs.getChapterVttTitle(r, marker)
if errors.Is(err, context.Canceled) {
return
}
@ -404,7 +412,7 @@ func (rs sceneRoutes) Caption(w http.ResponseWriter, r *http.Request, lang strin
s := r.Context().Value(sceneKey).(*models.Scene)
var captions []*models.VideoCaption
readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error {
var err error
primaryFile := s.Files.Primary()
if primaryFile == nil {
@ -466,7 +474,7 @@ func (rs sceneRoutes) SceneMarkerStream(w http.ResponseWriter, r *http.Request)
sceneHash := scene.GetHash(config.GetInstance().GetVideoFileNamingAlgorithm())
sceneMarkerID, _ := strconv.Atoi(chi.URLParam(r, "sceneMarkerId"))
var sceneMarker *models.SceneMarker
readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error {
var err error
sceneMarker, err = rs.sceneMarkerFinder.Find(ctx, sceneMarkerID)
return err
@ -494,7 +502,7 @@ func (rs sceneRoutes) SceneMarkerPreview(w http.ResponseWriter, r *http.Request)
sceneHash := scene.GetHash(config.GetInstance().GetVideoFileNamingAlgorithm())
sceneMarkerID, _ := strconv.Atoi(chi.URLParam(r, "sceneMarkerId"))
var sceneMarker *models.SceneMarker
readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error {
var err error
sceneMarker, err = rs.sceneMarkerFinder.Find(ctx, sceneMarkerID)
return err
@ -530,7 +538,7 @@ func (rs sceneRoutes) SceneMarkerScreenshot(w http.ResponseWriter, r *http.Reque
sceneHash := scene.GetHash(config.GetInstance().GetVideoFileNamingAlgorithm())
sceneMarkerID, _ := strconv.Atoi(chi.URLParam(r, "sceneMarkerId"))
var sceneMarker *models.SceneMarker
readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error {
var err error
sceneMarker, err = rs.sceneMarkerFinder.Find(ctx, sceneMarkerID)
return err
@ -561,8 +569,6 @@ func (rs sceneRoutes) SceneMarkerScreenshot(w http.ResponseWriter, r *http.Reque
}
}
// endregion
func (rs sceneRoutes) SceneCtx(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sceneID, err := strconv.Atoi(chi.URLParam(r, "sceneId"))
@ -572,7 +578,7 @@ func (rs sceneRoutes) SceneCtx(next http.Handler) http.Handler {
}
var scene *models.Scene
_ = txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
_ = rs.withReadTxn(r, func(ctx context.Context) error {
qb := rs.sceneFinder
scene, _ = qb.Find(ctx, sceneID)

View file

@ -3,7 +3,6 @@ package api
import (
"context"
"errors"
"io"
"net/http"
"strconv"
@ -11,7 +10,6 @@ import (
"github.com/stashapp/stash/internal/static"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils"
)
@ -21,10 +19,17 @@ type StudioFinder interface {
}
type studioRoutes struct {
txnManager txn.Manager
routes
studioFinder StudioFinder
}
func getStudioRoutes(repo models.Repository) chi.Router {
return studioRoutes{
routes: routes{txnManager: repo.TxnManager},
studioFinder: repo.Studio,
}.Routes()
}
func (rs studioRoutes) Routes() chi.Router {
r := chi.NewRouter()
@ -42,7 +47,7 @@ func (rs studioRoutes) Image(w http.ResponseWriter, r *http.Request) {
var image []byte
if defaultParam != "true" {
readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error {
var err error
image, err = rs.studioFinder.GetImage(ctx, studio.ID)
return err
@ -55,15 +60,9 @@ func (rs studioRoutes) Image(w http.ResponseWriter, r *http.Request) {
}
}
// fallback to default image
if len(image) == 0 {
const defaultStudioImage = "studio/studio.svg"
// fall back to static image
f, _ := static.Studio.Open(defaultStudioImage)
defer f.Close()
stat, _ := f.Stat()
http.ServeContent(w, r, "studio.svg", stat.ModTime(), f.(io.ReadSeeker))
return
image = static.ReadAll(static.DefaultStudioImage)
}
utils.ServeImage(w, r, image)
@ -78,7 +77,7 @@ func (rs studioRoutes) StudioCtx(next http.Handler) http.Handler {
}
var studio *models.Studio
_ = txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
_ = rs.withReadTxn(r, func(ctx context.Context) error {
var err error
studio, err = rs.studioFinder.Find(ctx, studioID)
return err

View file

@ -3,7 +3,6 @@ package api
import (
"context"
"errors"
"io"
"net/http"
"strconv"
@ -11,7 +10,6 @@ import (
"github.com/stashapp/stash/internal/static"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils"
)
@ -21,10 +19,17 @@ type TagFinder interface {
}
type tagRoutes struct {
txnManager txn.Manager
routes
tagFinder TagFinder
}
func getTagRoutes(repo models.Repository) chi.Router {
return tagRoutes{
routes: routes{txnManager: repo.TxnManager},
tagFinder: repo.Tag,
}.Routes()
}
func (rs tagRoutes) Routes() chi.Router {
r := chi.NewRouter()
@ -42,7 +47,7 @@ func (rs tagRoutes) Image(w http.ResponseWriter, r *http.Request) {
var image []byte
if defaultParam != "true" {
readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error {
var err error
image, err = rs.tagFinder.GetImage(ctx, tag.ID)
return err
@ -55,15 +60,9 @@ func (rs tagRoutes) Image(w http.ResponseWriter, r *http.Request) {
}
}
// fallback to default image
if len(image) == 0 {
const defaultTagImage = "tag/tag.svg"
// fall back to static image
f, _ := static.Tag.Open(defaultTagImage)
defer f.Close()
stat, _ := f.Stat()
http.ServeContent(w, r, "tag.svg", stat.ModTime(), f.(io.ReadSeeker))
return
image = static.ReadAll(static.DefaultTagImage)
}
utils.ServeImage(w, r, image)
@ -78,7 +77,7 @@ func (rs tagRoutes) TagCtx(next http.Handler) http.Handler {
}
var tag *models.Tag
_ = txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
_ = rs.withReadTxn(r, func(ctx context.Context) error {
var err error
tag, err = rs.tagFinder.Find(ctx, tagID)
return err

View file

@ -50,7 +50,9 @@ var uiBox = ui.UIBox
var loginUIBox = ui.LoginUIBox
func Start() error {
initialiseImages()
c := config.GetInstance()
initCustomPerformerImages(c.GetCustomPerformerImageLocation())
r := chi.NewRouter()
@ -62,7 +64,6 @@ func Start() error {
r.Use(middleware.Recoverer)
c := config.GetInstance()
if c.GetLogAccess() {
httpLogger := httplog.NewLogger("Stash", httplog.Options{
Concise: true,
@ -82,11 +83,10 @@ func Start() error {
return errors.New(message)
}
txnManager := manager.GetInstance().Repository
repo := manager.GetInstance().Repository
dataloaders := loaders.Middleware{
DatabaseProvider: txnManager,
Repository: txnManager,
Repository: repo,
}
r.Use(dataloaders.Middleware)
@ -96,8 +96,7 @@ func Start() error {
imageService := manager.GetInstance().ImageService
galleryService := manager.GetInstance().GalleryService
resolver := &Resolver{
txnManager: txnManager,
repository: txnManager,
repository: repo,
sceneService: sceneService,
imageService: imageService,
galleryService: galleryService,
@ -144,36 +143,13 @@ func Start() error {
gqlPlayground.Handler("GraphQL playground", endpoint)(w, r)
})
r.Mount("/performer", performerRoutes{
txnManager: txnManager,
performerFinder: txnManager.Performer,
}.Routes())
r.Mount("/scene", sceneRoutes{
txnManager: txnManager,
sceneFinder: txnManager.Scene,
fileGetter: txnManager.File,
captionFinder: txnManager.File,
sceneMarkerFinder: txnManager.SceneMarker,
tagFinder: txnManager.Tag,
}.Routes())
r.Mount("/image", imageRoutes{
txnManager: txnManager,
imageFinder: txnManager.Image,
fileGetter: txnManager.File,
}.Routes())
r.Mount("/studio", studioRoutes{
txnManager: txnManager,
studioFinder: txnManager.Studio,
}.Routes())
r.Mount("/movie", movieRoutes{
txnManager: txnManager,
movieFinder: txnManager.Movie,
}.Routes())
r.Mount("/tag", tagRoutes{
txnManager: txnManager,
tagFinder: txnManager.Tag,
}.Routes())
r.Mount("/downloads", downloadsRoutes{}.Routes())
r.Mount("/performer", getPerformerRoutes(repo))
r.Mount("/scene", getSceneRoutes(repo))
r.Mount("/image", getImageRoutes(repo))
r.Mount("/studio", getStudioRoutes(repo))
r.Mount("/movie", getMovieRoutes(repo))
r.Mount("/tag", getTagRoutes(repo))
r.Mount("/downloads", getDownloadsRoutes())
r.HandleFunc("/css", cssHandler(c, pluginCache))
r.HandleFunc("/javascript", javascriptHandler(c, pluginCache))
@ -193,9 +169,7 @@ func Start() error {
// Serve static folders
customServedFolders := c.GetCustomServedFolders()
if customServedFolders != nil {
r.Mount("/custom", customRoutes{
servedFolders: customServedFolders,
}.Routes())
r.Mount("/custom", getCustomRoutes(customServedFolders))
}
customUILocation := c.GetCustomUILocation()

View file

@ -52,11 +52,10 @@ func TestGalleryPerformers(t *testing.T) {
assert := assert.New(t)
for _, test := range testTables {
mockPerformerReader := &mocks.PerformerReaderWriter{}
mockGalleryReader := &mocks.GalleryReaderWriter{}
db := mocks.NewDatabase()
mockPerformerReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockPerformerReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once()
db.Performer.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
db.Performer.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once()
if test.Matches {
matchPartial := mock.MatchedBy(func(got models.GalleryPartial) bool {
@ -69,7 +68,7 @@ func TestGalleryPerformers(t *testing.T) {
return galleryPartialsEqual(got, expected)
})
mockGalleryReader.On("UpdatePartial", testCtx, galleryID, matchPartial).Return(nil, nil).Once()
db.Gallery.On("UpdatePartial", testCtx, galleryID, matchPartial).Return(nil, nil).Once()
}
gallery := models.Gallery{
@ -77,11 +76,10 @@ func TestGalleryPerformers(t *testing.T) {
Path: test.Path,
PerformerIDs: models.NewRelatedIDs([]int{}),
}
err := GalleryPerformers(testCtx, &gallery, mockGalleryReader, mockPerformerReader, nil)
err := GalleryPerformers(testCtx, &gallery, db.Gallery, db.Performer, nil)
assert.Nil(err)
mockPerformerReader.AssertExpectations(t)
mockGalleryReader.AssertExpectations(t)
db.AssertExpectations(t)
}
}
@ -107,7 +105,7 @@ func TestGalleryStudios(t *testing.T) {
assert := assert.New(t)
doTest := func(mockStudioReader *mocks.StudioReaderWriter, mockGalleryReader *mocks.GalleryReaderWriter, test pathTestTable) {
doTest := func(db *mocks.Database, test pathTestTable) {
if test.Matches {
matchPartial := mock.MatchedBy(func(got models.GalleryPartial) bool {
expected := models.GalleryPartial{
@ -116,29 +114,27 @@ func TestGalleryStudios(t *testing.T) {
return galleryPartialsEqual(got, expected)
})
mockGalleryReader.On("UpdatePartial", testCtx, galleryID, matchPartial).Return(nil, nil).Once()
db.Gallery.On("UpdatePartial", testCtx, galleryID, matchPartial).Return(nil, nil).Once()
}
gallery := models.Gallery{
ID: galleryID,
Path: test.Path,
}
err := GalleryStudios(testCtx, &gallery, mockGalleryReader, mockStudioReader, nil)
err := GalleryStudios(testCtx, &gallery, db.Gallery, db.Studio, nil)
assert.Nil(err)
mockStudioReader.AssertExpectations(t)
mockGalleryReader.AssertExpectations(t)
db.AssertExpectations(t)
}
for _, test := range testTables {
mockStudioReader := &mocks.StudioReaderWriter{}
mockGalleryReader := &mocks.GalleryReaderWriter{}
db := mocks.NewDatabase()
mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
mockStudioReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
db.Studio.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
db.Studio.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
db.Studio.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
doTest(mockStudioReader, mockGalleryReader, test)
doTest(db, test)
}
// test against aliases
@ -146,17 +142,16 @@ func TestGalleryStudios(t *testing.T) {
studio.Name = unmatchedName
for _, test := range testTables {
mockStudioReader := &mocks.StudioReaderWriter{}
mockGalleryReader := &mocks.GalleryReaderWriter{}
db := mocks.NewDatabase()
mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
mockStudioReader.On("GetAliases", testCtx, studioID).Return([]string{
db.Studio.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
db.Studio.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
db.Studio.On("GetAliases", testCtx, studioID).Return([]string{
studioName,
}, nil).Once()
mockStudioReader.On("GetAliases", testCtx, reversedStudioID).Return([]string{}, nil).Once()
db.Studio.On("GetAliases", testCtx, reversedStudioID).Return([]string{}, nil).Once()
doTest(mockStudioReader, mockGalleryReader, test)
doTest(db, test)
}
}
@ -182,7 +177,7 @@ func TestGalleryTags(t *testing.T) {
assert := assert.New(t)
doTest := func(mockTagReader *mocks.TagReaderWriter, mockGalleryReader *mocks.GalleryReaderWriter, test pathTestTable) {
doTest := func(db *mocks.Database, test pathTestTable) {
if test.Matches {
matchPartial := mock.MatchedBy(func(got models.GalleryPartial) bool {
expected := models.GalleryPartial{
@ -194,7 +189,7 @@ func TestGalleryTags(t *testing.T) {
return galleryPartialsEqual(got, expected)
})
mockGalleryReader.On("UpdatePartial", testCtx, galleryID, matchPartial).Return(nil, nil).Once()
db.Gallery.On("UpdatePartial", testCtx, galleryID, matchPartial).Return(nil, nil).Once()
}
gallery := models.Gallery{
@ -202,38 +197,35 @@ func TestGalleryTags(t *testing.T) {
Path: test.Path,
TagIDs: models.NewRelatedIDs([]int{}),
}
err := GalleryTags(testCtx, &gallery, mockGalleryReader, mockTagReader, nil)
err := GalleryTags(testCtx, &gallery, db.Gallery, db.Tag, nil)
assert.Nil(err)
mockTagReader.AssertExpectations(t)
mockGalleryReader.AssertExpectations(t)
db.AssertExpectations(t)
}
for _, test := range testTables {
mockTagReader := &mocks.TagReaderWriter{}
mockGalleryReader := &mocks.GalleryReaderWriter{}
db := mocks.NewDatabase()
mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
mockTagReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
db.Tag.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
db.Tag.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
db.Tag.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
doTest(mockTagReader, mockGalleryReader, test)
doTest(db, test)
}
const unmatchedName = "unmatched"
tag.Name = unmatchedName
for _, test := range testTables {
mockTagReader := &mocks.TagReaderWriter{}
mockGalleryReader := &mocks.GalleryReaderWriter{}
db := mocks.NewDatabase()
mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
mockTagReader.On("GetAliases", testCtx, tagID).Return([]string{
db.Tag.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
db.Tag.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
db.Tag.On("GetAliases", testCtx, tagID).Return([]string{
tagName,
}, nil).Once()
mockTagReader.On("GetAliases", testCtx, reversedTagID).Return([]string{}, nil).Once()
db.Tag.On("GetAliases", testCtx, reversedTagID).Return([]string{}, nil).Once()
doTest(mockTagReader, mockGalleryReader, test)
doTest(db, test)
}
}

View file

@ -49,11 +49,10 @@ func TestImagePerformers(t *testing.T) {
assert := assert.New(t)
for _, test := range testTables {
mockPerformerReader := &mocks.PerformerReaderWriter{}
mockImageReader := &mocks.ImageReaderWriter{}
db := mocks.NewDatabase()
mockPerformerReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockPerformerReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once()
db.Performer.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
db.Performer.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once()
if test.Matches {
matchPartial := mock.MatchedBy(func(got models.ImagePartial) bool {
@ -66,7 +65,7 @@ func TestImagePerformers(t *testing.T) {
return imagePartialsEqual(got, expected)
})
mockImageReader.On("UpdatePartial", testCtx, imageID, matchPartial).Return(nil, nil).Once()
db.Image.On("UpdatePartial", testCtx, imageID, matchPartial).Return(nil, nil).Once()
}
image := models.Image{
@ -74,11 +73,10 @@ func TestImagePerformers(t *testing.T) {
Path: test.Path,
PerformerIDs: models.NewRelatedIDs([]int{}),
}
err := ImagePerformers(testCtx, &image, mockImageReader, mockPerformerReader, nil)
err := ImagePerformers(testCtx, &image, db.Image, db.Performer, nil)
assert.Nil(err)
mockPerformerReader.AssertExpectations(t)
mockImageReader.AssertExpectations(t)
db.AssertExpectations(t)
}
}
@ -104,7 +102,7 @@ func TestImageStudios(t *testing.T) {
assert := assert.New(t)
doTest := func(mockStudioReader *mocks.StudioReaderWriter, mockImageReader *mocks.ImageReaderWriter, test pathTestTable) {
doTest := func(db *mocks.Database, test pathTestTable) {
if test.Matches {
matchPartial := mock.MatchedBy(func(got models.ImagePartial) bool {
expected := models.ImagePartial{
@ -113,29 +111,27 @@ func TestImageStudios(t *testing.T) {
return imagePartialsEqual(got, expected)
})
mockImageReader.On("UpdatePartial", testCtx, imageID, matchPartial).Return(nil, nil).Once()
db.Image.On("UpdatePartial", testCtx, imageID, matchPartial).Return(nil, nil).Once()
}
image := models.Image{
ID: imageID,
Path: test.Path,
}
err := ImageStudios(testCtx, &image, mockImageReader, mockStudioReader, nil)
err := ImageStudios(testCtx, &image, db.Image, db.Studio, nil)
assert.Nil(err)
mockStudioReader.AssertExpectations(t)
mockImageReader.AssertExpectations(t)
db.AssertExpectations(t)
}
for _, test := range testTables {
mockStudioReader := &mocks.StudioReaderWriter{}
mockImageReader := &mocks.ImageReaderWriter{}
db := mocks.NewDatabase()
mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
mockStudioReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
db.Studio.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
db.Studio.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
db.Studio.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
doTest(mockStudioReader, mockImageReader, test)
doTest(db, test)
}
// test against aliases
@ -143,17 +139,16 @@ func TestImageStudios(t *testing.T) {
studio.Name = unmatchedName
for _, test := range testTables {
mockStudioReader := &mocks.StudioReaderWriter{}
mockImageReader := &mocks.ImageReaderWriter{}
db := mocks.NewDatabase()
mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
mockStudioReader.On("GetAliases", testCtx, studioID).Return([]string{
db.Studio.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
db.Studio.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
db.Studio.On("GetAliases", testCtx, studioID).Return([]string{
studioName,
}, nil).Once()
mockStudioReader.On("GetAliases", testCtx, reversedStudioID).Return([]string{}, nil).Once()
db.Studio.On("GetAliases", testCtx, reversedStudioID).Return([]string{}, nil).Once()
doTest(mockStudioReader, mockImageReader, test)
doTest(db, test)
}
}
@ -179,7 +174,7 @@ func TestImageTags(t *testing.T) {
assert := assert.New(t)
doTest := func(mockTagReader *mocks.TagReaderWriter, mockImageReader *mocks.ImageReaderWriter, test pathTestTable) {
doTest := func(db *mocks.Database, test pathTestTable) {
if test.Matches {
matchPartial := mock.MatchedBy(func(got models.ImagePartial) bool {
expected := models.ImagePartial{
@ -191,7 +186,7 @@ func TestImageTags(t *testing.T) {
return imagePartialsEqual(got, expected)
})
mockImageReader.On("UpdatePartial", testCtx, imageID, matchPartial).Return(nil, nil).Once()
db.Image.On("UpdatePartial", testCtx, imageID, matchPartial).Return(nil, nil).Once()
}
image := models.Image{
@ -199,22 +194,20 @@ func TestImageTags(t *testing.T) {
Path: test.Path,
TagIDs: models.NewRelatedIDs([]int{}),
}
err := ImageTags(testCtx, &image, mockImageReader, mockTagReader, nil)
err := ImageTags(testCtx, &image, db.Image, db.Tag, nil)
assert.Nil(err)
mockTagReader.AssertExpectations(t)
mockImageReader.AssertExpectations(t)
db.AssertExpectations(t)
}
for _, test := range testTables {
mockTagReader := &mocks.TagReaderWriter{}
mockImageReader := &mocks.ImageReaderWriter{}
db := mocks.NewDatabase()
mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
mockTagReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
db.Tag.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
db.Tag.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
db.Tag.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
doTest(mockTagReader, mockImageReader, test)
doTest(db, test)
}
// test against aliases
@ -222,16 +215,15 @@ func TestImageTags(t *testing.T) {
tag.Name = unmatchedName
for _, test := range testTables {
mockTagReader := &mocks.TagReaderWriter{}
mockImageReader := &mocks.ImageReaderWriter{}
db := mocks.NewDatabase()
mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
mockTagReader.On("GetAliases", testCtx, tagID).Return([]string{
db.Tag.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
db.Tag.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
db.Tag.On("GetAliases", testCtx, tagID).Return([]string{
tagName,
}, nil).Once()
mockTagReader.On("GetAliases", testCtx, reversedTagID).Return([]string{}, nil).Once()
db.Tag.On("GetAliases", testCtx, reversedTagID).Return([]string{}, nil).Once()
doTest(mockTagReader, mockImageReader, test)
doTest(db, test)
}
}

View file

@ -62,7 +62,7 @@ func runTests(m *testing.M) int {
panic(fmt.Sprintf("Could not initialize database: %s", err.Error()))
}
r = db.TxnRepository()
r = db.Repository()
// defer close and delete the database
defer testTeardown(databaseFile)
@ -474,11 +474,11 @@ func createGallery(ctx context.Context, w models.GalleryWriter, o *models.Galler
}
func withTxn(f func(ctx context.Context) error) error {
return txn.WithTxn(context.TODO(), db, f)
return txn.WithTxn(testCtx, db, f)
}
func withDB(f func(ctx context.Context) error) error {
return txn.WithDatabase(context.TODO(), db, f)
return txn.WithDatabase(testCtx, db, f)
}
func populateDB() error {

View file

@ -45,7 +45,7 @@ func TestPerformerScenes(t *testing.T) {
}
func testPerformerScenes(t *testing.T, performerName, expectedRegex string) {
mockSceneReader := &mocks.SceneReaderWriter{}
db := mocks.NewDatabase()
const performerID = 2
@ -84,7 +84,7 @@ func testPerformerScenes(t *testing.T, performerName, expectedRegex string) {
Direction: &direction,
}
mockSceneReader.On("Query", mock.Anything, scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false)).
db.Scene.On("Query", mock.Anything, scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false)).
Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once()
for i := range matchingPaths {
@ -100,19 +100,19 @@ func testPerformerScenes(t *testing.T, performerName, expectedRegex string) {
return scenePartialsEqual(got, expected)
})
mockSceneReader.On("UpdatePartial", mock.Anything, sceneID, matchPartial).Return(nil, nil).Once()
db.Scene.On("UpdatePartial", mock.Anything, sceneID, matchPartial).Return(nil, nil).Once()
}
tagger := Tagger{
TxnManager: &mocks.TxnManager{},
TxnManager: db,
}
err := tagger.PerformerScenes(testCtx, &performer, nil, mockSceneReader)
err := tagger.PerformerScenes(testCtx, &performer, nil, db.Scene)
assert := assert.New(t)
assert.Nil(err)
mockSceneReader.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestPerformerImages(t *testing.T) {
@ -140,7 +140,7 @@ func TestPerformerImages(t *testing.T) {
}
func testPerformerImages(t *testing.T, performerName, expectedRegex string) {
mockImageReader := &mocks.ImageReaderWriter{}
db := mocks.NewDatabase()
const performerID = 2
@ -179,7 +179,7 @@ func testPerformerImages(t *testing.T, performerName, expectedRegex string) {
Direction: &direction,
}
mockImageReader.On("Query", mock.Anything, image.QueryOptions(expectedImageFilter, expectedFindFilter, false)).
db.Image.On("Query", mock.Anything, image.QueryOptions(expectedImageFilter, expectedFindFilter, false)).
Return(mocks.ImageQueryResult(images, len(images)), nil).Once()
for i := range matchingPaths {
@ -195,19 +195,19 @@ func testPerformerImages(t *testing.T, performerName, expectedRegex string) {
return imagePartialsEqual(got, expected)
})
mockImageReader.On("UpdatePartial", mock.Anything, imageID, matchPartial).Return(nil, nil).Once()
db.Image.On("UpdatePartial", mock.Anything, imageID, matchPartial).Return(nil, nil).Once()
}
tagger := Tagger{
TxnManager: &mocks.TxnManager{},
TxnManager: db,
}
err := tagger.PerformerImages(testCtx, &performer, nil, mockImageReader)
err := tagger.PerformerImages(testCtx, &performer, nil, db.Image)
assert := assert.New(t)
assert.Nil(err)
mockImageReader.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestPerformerGalleries(t *testing.T) {
@ -235,7 +235,7 @@ func TestPerformerGalleries(t *testing.T) {
}
func testPerformerGalleries(t *testing.T, performerName, expectedRegex string) {
mockGalleryReader := &mocks.GalleryReaderWriter{}
db := mocks.NewDatabase()
const performerID = 2
@ -275,7 +275,7 @@ func testPerformerGalleries(t *testing.T, performerName, expectedRegex string) {
Direction: &direction,
}
mockGalleryReader.On("Query", mock.Anything, expectedGalleryFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once()
db.Gallery.On("Query", mock.Anything, expectedGalleryFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once()
for i := range matchingPaths {
galleryID := i + 1
@ -290,17 +290,17 @@ func testPerformerGalleries(t *testing.T, performerName, expectedRegex string) {
return galleryPartialsEqual(got, expected)
})
mockGalleryReader.On("UpdatePartial", mock.Anything, galleryID, matchPartial).Return(nil, nil).Once()
db.Gallery.On("UpdatePartial", mock.Anything, galleryID, matchPartial).Return(nil, nil).Once()
}
tagger := Tagger{
TxnManager: &mocks.TxnManager{},
TxnManager: db,
}
err := tagger.PerformerGalleries(testCtx, &performer, nil, mockGalleryReader)
err := tagger.PerformerGalleries(testCtx, &performer, nil, db.Gallery)
assert := assert.New(t)
assert.Nil(err)
mockGalleryReader.AssertExpectations(t)
db.AssertExpectations(t)
}

View file

@ -182,11 +182,10 @@ func TestScenePerformers(t *testing.T) {
assert := assert.New(t)
for _, test := range testTables {
mockPerformerReader := &mocks.PerformerReaderWriter{}
mockSceneReader := &mocks.SceneReaderWriter{}
db := mocks.NewDatabase()
mockPerformerReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockPerformerReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once()
db.Performer.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
db.Performer.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once()
scene := models.Scene{
ID: sceneID,
@ -205,14 +204,13 @@ func TestScenePerformers(t *testing.T) {
return scenePartialsEqual(got, expected)
})
mockSceneReader.On("UpdatePartial", testCtx, sceneID, matchPartial).Return(nil, nil).Once()
db.Scene.On("UpdatePartial", testCtx, sceneID, matchPartial).Return(nil, nil).Once()
}
err := ScenePerformers(testCtx, &scene, mockSceneReader, mockPerformerReader, nil)
err := ScenePerformers(testCtx, &scene, db.Scene, db.Performer, nil)
assert.Nil(err)
mockPerformerReader.AssertExpectations(t)
mockSceneReader.AssertExpectations(t)
db.AssertExpectations(t)
}
}
@ -240,7 +238,7 @@ func TestSceneStudios(t *testing.T) {
assert := assert.New(t)
doTest := func(mockStudioReader *mocks.StudioReaderWriter, mockSceneReader *mocks.SceneReaderWriter, test pathTestTable) {
doTest := func(db *mocks.Database, test pathTestTable) {
if test.Matches {
matchPartial := mock.MatchedBy(func(got models.ScenePartial) bool {
expected := models.ScenePartial{
@ -249,29 +247,27 @@ func TestSceneStudios(t *testing.T) {
return scenePartialsEqual(got, expected)
})
mockSceneReader.On("UpdatePartial", testCtx, sceneID, matchPartial).Return(nil, nil).Once()
db.Scene.On("UpdatePartial", testCtx, sceneID, matchPartial).Return(nil, nil).Once()
}
scene := models.Scene{
ID: sceneID,
Path: test.Path,
}
err := SceneStudios(testCtx, &scene, mockSceneReader, mockStudioReader, nil)
err := SceneStudios(testCtx, &scene, db.Scene, db.Studio, nil)
assert.Nil(err)
mockStudioReader.AssertExpectations(t)
mockSceneReader.AssertExpectations(t)
db.AssertExpectations(t)
}
for _, test := range testTables {
mockStudioReader := &mocks.StudioReaderWriter{}
mockSceneReader := &mocks.SceneReaderWriter{}
db := mocks.NewDatabase()
mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
mockStudioReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
db.Studio.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
db.Studio.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
db.Studio.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
doTest(mockStudioReader, mockSceneReader, test)
doTest(db, test)
}
const unmatchedName = "unmatched"
@ -279,17 +275,16 @@ func TestSceneStudios(t *testing.T) {
// test against aliases
for _, test := range testTables {
mockStudioReader := &mocks.StudioReaderWriter{}
mockSceneReader := &mocks.SceneReaderWriter{}
db := mocks.NewDatabase()
mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
mockStudioReader.On("GetAliases", testCtx, studioID).Return([]string{
db.Studio.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
db.Studio.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
db.Studio.On("GetAliases", testCtx, studioID).Return([]string{
studioName,
}, nil).Once()
mockStudioReader.On("GetAliases", testCtx, reversedStudioID).Return([]string{}, nil).Once()
db.Studio.On("GetAliases", testCtx, reversedStudioID).Return([]string{}, nil).Once()
doTest(mockStudioReader, mockSceneReader, test)
doTest(db, test)
}
}
@ -315,7 +310,7 @@ func TestSceneTags(t *testing.T) {
assert := assert.New(t)
doTest := func(mockTagReader *mocks.TagReaderWriter, mockSceneReader *mocks.SceneReaderWriter, test pathTestTable) {
doTest := func(db *mocks.Database, test pathTestTable) {
if test.Matches {
matchPartial := mock.MatchedBy(func(got models.ScenePartial) bool {
expected := models.ScenePartial{
@ -327,7 +322,7 @@ func TestSceneTags(t *testing.T) {
return scenePartialsEqual(got, expected)
})
mockSceneReader.On("UpdatePartial", testCtx, sceneID, matchPartial).Return(nil, nil).Once()
db.Scene.On("UpdatePartial", testCtx, sceneID, matchPartial).Return(nil, nil).Once()
}
scene := models.Scene{
@ -335,22 +330,20 @@ func TestSceneTags(t *testing.T) {
Path: test.Path,
TagIDs: models.NewRelatedIDs([]int{}),
}
err := SceneTags(testCtx, &scene, mockSceneReader, mockTagReader, nil)
err := SceneTags(testCtx, &scene, db.Scene, db.Tag, nil)
assert.Nil(err)
mockTagReader.AssertExpectations(t)
mockSceneReader.AssertExpectations(t)
db.AssertExpectations(t)
}
for _, test := range testTables {
mockTagReader := &mocks.TagReaderWriter{}
mockSceneReader := &mocks.SceneReaderWriter{}
db := mocks.NewDatabase()
mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
mockTagReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
db.Tag.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
db.Tag.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
db.Tag.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
doTest(mockTagReader, mockSceneReader, test)
doTest(db, test)
}
const unmatchedName = "unmatched"
@ -358,16 +351,15 @@ func TestSceneTags(t *testing.T) {
// test against aliases
for _, test := range testTables {
mockTagReader := &mocks.TagReaderWriter{}
mockSceneReader := &mocks.SceneReaderWriter{}
db := mocks.NewDatabase()
mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
mockTagReader.On("GetAliases", testCtx, tagID).Return([]string{
db.Tag.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
db.Tag.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
db.Tag.On("GetAliases", testCtx, tagID).Return([]string{
tagName,
}, nil).Once()
mockTagReader.On("GetAliases", testCtx, reversedTagID).Return([]string{}, nil).Once()
db.Tag.On("GetAliases", testCtx, reversedTagID).Return([]string{}, nil).Once()
doTest(mockTagReader, mockSceneReader, test)
doTest(db, test)
}
}

View file

@ -83,7 +83,7 @@ func testStudioScenes(t *testing.T, tc testStudioCase) {
aliasName := tc.aliasName
aliasRegex := tc.aliasRegex
mockSceneReader := &mocks.SceneReaderWriter{}
db := mocks.NewDatabase()
var studioID = 2
@ -130,7 +130,7 @@ func testStudioScenes(t *testing.T, tc testStudioCase) {
}
// if alias provided, then don't find by name
onNameQuery := mockSceneReader.On("Query", testCtx, scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false))
onNameQuery := db.Scene.On("Query", testCtx, scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false))
if aliasName == "" {
onNameQuery.Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once()
@ -145,7 +145,7 @@ func testStudioScenes(t *testing.T, tc testStudioCase) {
},
}
mockSceneReader.On("Query", mock.Anything, scene.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
db.Scene.On("Query", mock.Anything, scene.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once()
}
@ -159,19 +159,19 @@ func testStudioScenes(t *testing.T, tc testStudioCase) {
return scenePartialsEqual(got, expected)
})
mockSceneReader.On("UpdatePartial", mock.Anything, sceneID, matchPartial).Return(nil, nil).Once()
db.Scene.On("UpdatePartial", mock.Anything, sceneID, matchPartial).Return(nil, nil).Once()
}
tagger := Tagger{
TxnManager: &mocks.TxnManager{},
TxnManager: db,
}
err := tagger.StudioScenes(testCtx, &studio, nil, aliases, mockSceneReader)
err := tagger.StudioScenes(testCtx, &studio, nil, aliases, db.Scene)
assert := assert.New(t)
assert.Nil(err)
mockSceneReader.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestStudioImages(t *testing.T) {
@ -188,7 +188,7 @@ func testStudioImages(t *testing.T, tc testStudioCase) {
aliasName := tc.aliasName
aliasRegex := tc.aliasRegex
mockImageReader := &mocks.ImageReaderWriter{}
db := mocks.NewDatabase()
var studioID = 2
@ -234,7 +234,7 @@ func testStudioImages(t *testing.T, tc testStudioCase) {
}
// if alias provided, then don't find by name
onNameQuery := mockImageReader.On("Query", mock.Anything, image.QueryOptions(expectedImageFilter, expectedFindFilter, false))
onNameQuery := db.Image.On("Query", mock.Anything, image.QueryOptions(expectedImageFilter, expectedFindFilter, false))
if aliasName == "" {
onNameQuery.Return(mocks.ImageQueryResult(images, len(images)), nil).Once()
} else {
@ -248,7 +248,7 @@ func testStudioImages(t *testing.T, tc testStudioCase) {
},
}
mockImageReader.On("Query", mock.Anything, image.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
db.Image.On("Query", mock.Anything, image.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
Return(mocks.ImageQueryResult(images, len(images)), nil).Once()
}
@ -262,19 +262,19 @@ func testStudioImages(t *testing.T, tc testStudioCase) {
return imagePartialsEqual(got, expected)
})
mockImageReader.On("UpdatePartial", mock.Anything, imageID, matchPartial).Return(nil, nil).Once()
db.Image.On("UpdatePartial", mock.Anything, imageID, matchPartial).Return(nil, nil).Once()
}
tagger := Tagger{
TxnManager: &mocks.TxnManager{},
TxnManager: db,
}
err := tagger.StudioImages(testCtx, &studio, nil, aliases, mockImageReader)
err := tagger.StudioImages(testCtx, &studio, nil, aliases, db.Image)
assert := assert.New(t)
assert.Nil(err)
mockImageReader.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestStudioGalleries(t *testing.T) {
@ -290,7 +290,8 @@ func testStudioGalleries(t *testing.T, tc testStudioCase) {
expectedRegex := tc.expectedRegex
aliasName := tc.aliasName
aliasRegex := tc.aliasRegex
mockGalleryReader := &mocks.GalleryReaderWriter{}
db := mocks.NewDatabase()
var studioID = 2
@ -337,7 +338,7 @@ func testStudioGalleries(t *testing.T, tc testStudioCase) {
}
// if alias provided, then don't find by name
onNameQuery := mockGalleryReader.On("Query", mock.Anything, expectedGalleryFilter, expectedFindFilter)
onNameQuery := db.Gallery.On("Query", mock.Anything, expectedGalleryFilter, expectedFindFilter)
if aliasName == "" {
onNameQuery.Return(galleries, len(galleries), nil).Once()
} else {
@ -351,7 +352,7 @@ func testStudioGalleries(t *testing.T, tc testStudioCase) {
},
}
mockGalleryReader.On("Query", mock.Anything, expectedAliasFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once()
db.Gallery.On("Query", mock.Anything, expectedAliasFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once()
}
for i := range matchingPaths {
@ -364,17 +365,17 @@ func testStudioGalleries(t *testing.T, tc testStudioCase) {
return galleryPartialsEqual(got, expected)
})
mockGalleryReader.On("UpdatePartial", mock.Anything, galleryID, matchPartial).Return(nil, nil).Once()
db.Gallery.On("UpdatePartial", mock.Anything, galleryID, matchPartial).Return(nil, nil).Once()
}
tagger := Tagger{
TxnManager: &mocks.TxnManager{},
TxnManager: db,
}
err := tagger.StudioGalleries(testCtx, &studio, nil, aliases, mockGalleryReader)
err := tagger.StudioGalleries(testCtx, &studio, nil, aliases, db.Gallery)
assert := assert.New(t)
assert.Nil(err)
mockGalleryReader.AssertExpectations(t)
db.AssertExpectations(t)
}

View file

@ -83,7 +83,7 @@ func testTagScenes(t *testing.T, tc testTagCase) {
aliasName := tc.aliasName
aliasRegex := tc.aliasRegex
mockSceneReader := &mocks.SceneReaderWriter{}
db := mocks.NewDatabase()
const tagID = 2
@ -131,7 +131,7 @@ func testTagScenes(t *testing.T, tc testTagCase) {
}
// if alias provided, then don't find by name
onNameQuery := mockSceneReader.On("Query", testCtx, scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false))
onNameQuery := db.Scene.On("Query", testCtx, scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false))
if aliasName == "" {
onNameQuery.Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once()
} else {
@ -145,7 +145,7 @@ func testTagScenes(t *testing.T, tc testTagCase) {
},
}
mockSceneReader.On("Query", mock.Anything, scene.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
db.Scene.On("Query", mock.Anything, scene.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once()
}
@ -162,19 +162,19 @@ func testTagScenes(t *testing.T, tc testTagCase) {
return scenePartialsEqual(got, expected)
})
mockSceneReader.On("UpdatePartial", mock.Anything, sceneID, matchPartial).Return(nil, nil).Once()
db.Scene.On("UpdatePartial", mock.Anything, sceneID, matchPartial).Return(nil, nil).Once()
}
tagger := Tagger{
TxnManager: &mocks.TxnManager{},
TxnManager: db,
}
err := tagger.TagScenes(testCtx, &tag, nil, aliases, mockSceneReader)
err := tagger.TagScenes(testCtx, &tag, nil, aliases, db.Scene)
assert := assert.New(t)
assert.Nil(err)
mockSceneReader.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestTagImages(t *testing.T) {
@ -191,7 +191,7 @@ func testTagImages(t *testing.T, tc testTagCase) {
aliasName := tc.aliasName
aliasRegex := tc.aliasRegex
mockImageReader := &mocks.ImageReaderWriter{}
db := mocks.NewDatabase()
const tagID = 2
@ -238,7 +238,7 @@ func testTagImages(t *testing.T, tc testTagCase) {
}
// if alias provided, then don't find by name
onNameQuery := mockImageReader.On("Query", testCtx, image.QueryOptions(expectedImageFilter, expectedFindFilter, false))
onNameQuery := db.Image.On("Query", testCtx, image.QueryOptions(expectedImageFilter, expectedFindFilter, false))
if aliasName == "" {
onNameQuery.Return(mocks.ImageQueryResult(images, len(images)), nil).Once()
} else {
@ -252,7 +252,7 @@ func testTagImages(t *testing.T, tc testTagCase) {
},
}
mockImageReader.On("Query", mock.Anything, image.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
db.Image.On("Query", mock.Anything, image.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
Return(mocks.ImageQueryResult(images, len(images)), nil).Once()
}
@ -269,19 +269,19 @@ func testTagImages(t *testing.T, tc testTagCase) {
return imagePartialsEqual(got, expected)
})
mockImageReader.On("UpdatePartial", mock.Anything, imageID, matchPartial).Return(nil, nil).Once()
db.Image.On("UpdatePartial", mock.Anything, imageID, matchPartial).Return(nil, nil).Once()
}
tagger := Tagger{
TxnManager: &mocks.TxnManager{},
TxnManager: db,
}
err := tagger.TagImages(testCtx, &tag, nil, aliases, mockImageReader)
err := tagger.TagImages(testCtx, &tag, nil, aliases, db.Image)
assert := assert.New(t)
assert.Nil(err)
mockImageReader.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestTagGalleries(t *testing.T) {
@ -298,7 +298,7 @@ func testTagGalleries(t *testing.T, tc testTagCase) {
aliasName := tc.aliasName
aliasRegex := tc.aliasRegex
mockGalleryReader := &mocks.GalleryReaderWriter{}
db := mocks.NewDatabase()
const tagID = 2
@ -346,7 +346,7 @@ func testTagGalleries(t *testing.T, tc testTagCase) {
}
// if alias provided, then don't find by name
onNameQuery := mockGalleryReader.On("Query", testCtx, expectedGalleryFilter, expectedFindFilter)
onNameQuery := db.Gallery.On("Query", testCtx, expectedGalleryFilter, expectedFindFilter)
if aliasName == "" {
onNameQuery.Return(galleries, len(galleries), nil).Once()
} else {
@ -360,7 +360,7 @@ func testTagGalleries(t *testing.T, tc testTagCase) {
},
}
mockGalleryReader.On("Query", mock.Anything, expectedAliasFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once()
db.Gallery.On("Query", mock.Anything, expectedAliasFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once()
}
for i := range matchingPaths {
@ -376,18 +376,18 @@ func testTagGalleries(t *testing.T, tc testTagCase) {
return galleryPartialsEqual(got, expected)
})
mockGalleryReader.On("UpdatePartial", mock.Anything, galleryID, matchPartial).Return(nil, nil).Once()
db.Gallery.On("UpdatePartial", mock.Anything, galleryID, matchPartial).Return(nil, nil).Once()
}
tagger := Tagger{
TxnManager: &mocks.TxnManager{},
TxnManager: db,
}
err := tagger.TagGalleries(testCtx, &tag, nil, aliases, mockGalleryReader)
err := tagger.TagGalleries(testCtx, &tag, nil, aliases, db.Gallery)
assert := assert.New(t)
assert.Nil(err)
mockGalleryReader.AssertExpectations(t)
db.AssertExpectations(t)
}

View file

@ -41,7 +41,6 @@ import (
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scene"
"github.com/stashapp/stash/pkg/sliceutil/stringslice"
"github.com/stashapp/stash/pkg/txn"
)
var pageSize = 100
@ -360,10 +359,11 @@ func (me *contentDirectoryService) handleBrowseMetadata(obj object, host string)
} else {
var scene *models.Scene
if err := txn.WithReadTxn(context.TODO(), me.txnManager, func(ctx context.Context) error {
scene, err = me.repository.SceneFinder.Find(ctx, sceneID)
r := me.repository
if err := r.WithReadTxn(context.TODO(), func(ctx context.Context) error {
scene, err = r.SceneFinder.Find(ctx, sceneID)
if scene != nil {
err = scene.LoadPrimaryFile(ctx, me.repository.FileGetter)
err = scene.LoadPrimaryFile(ctx, r.FileGetter)
}
if err != nil {
@ -452,7 +452,8 @@ func getSortDirection(sceneFilter *models.SceneFilterType, sort string) models.S
func (me *contentDirectoryService) getVideos(sceneFilter *models.SceneFilterType, parentID string, host string) []interface{} {
var objs []interface{}
if err := txn.WithReadTxn(context.TODO(), me.txnManager, func(ctx context.Context) error {
r := me.repository
if err := r.WithReadTxn(context.TODO(), func(ctx context.Context) error {
sort := me.VideoSortOrder
direction := getSortDirection(sceneFilter, sort)
findFilter := &models.FindFilterType{
@ -461,7 +462,7 @@ func (me *contentDirectoryService) getVideos(sceneFilter *models.SceneFilterType
Direction: &direction,
}
scenes, total, err := scene.QueryWithCount(ctx, me.repository.SceneFinder, sceneFilter, findFilter)
scenes, total, err := scene.QueryWithCount(ctx, r.SceneFinder, sceneFilter, findFilter)
if err != nil {
return err
}
@ -472,13 +473,13 @@ func (me *contentDirectoryService) getVideos(sceneFilter *models.SceneFilterType
parentID: parentID,
}
objs, err = pager.getPages(ctx, me.repository.SceneFinder, total)
objs, err = pager.getPages(ctx, r.SceneFinder, total)
if err != nil {
return err
}
} else {
for _, s := range scenes {
if err := s.LoadPrimaryFile(ctx, me.repository.FileGetter); err != nil {
if err := s.LoadPrimaryFile(ctx, r.FileGetter); err != nil {
return err
}
@ -497,7 +498,8 @@ func (me *contentDirectoryService) getVideos(sceneFilter *models.SceneFilterType
func (me *contentDirectoryService) getPageVideos(sceneFilter *models.SceneFilterType, parentID string, page int, host string) []interface{} {
var objs []interface{}
if err := txn.WithReadTxn(context.TODO(), me.txnManager, func(ctx context.Context) error {
r := me.repository
if err := r.WithReadTxn(context.TODO(), func(ctx context.Context) error {
pager := scenePager{
sceneFilter: sceneFilter,
parentID: parentID,
@ -506,7 +508,7 @@ func (me *contentDirectoryService) getPageVideos(sceneFilter *models.SceneFilter
sort := me.VideoSortOrder
direction := getSortDirection(sceneFilter, sort)
var err error
objs, err = pager.getPageVideos(ctx, me.repository.SceneFinder, me.repository.FileGetter, page, host, sort, direction)
objs, err = pager.getPageVideos(ctx, r.SceneFinder, r.FileGetter, page, host, sort, direction)
if err != nil {
return err
}
@ -540,8 +542,9 @@ func (me *contentDirectoryService) getAllScenes(host string) []interface{} {
func (me *contentDirectoryService) getStudios() []interface{} {
var objs []interface{}
if err := txn.WithReadTxn(context.TODO(), me.txnManager, func(ctx context.Context) error {
studios, err := me.repository.StudioFinder.All(ctx)
r := me.repository
if err := r.WithReadTxn(context.TODO(), func(ctx context.Context) error {
studios, err := r.StudioFinder.All(ctx)
if err != nil {
return err
}
@ -579,8 +582,9 @@ func (me *contentDirectoryService) getStudioScenes(paths []string, host string)
func (me *contentDirectoryService) getTags() []interface{} {
var objs []interface{}
if err := txn.WithReadTxn(context.TODO(), me.txnManager, func(ctx context.Context) error {
tags, err := me.repository.TagFinder.All(ctx)
r := me.repository
if err := r.WithReadTxn(context.TODO(), func(ctx context.Context) error {
tags, err := r.TagFinder.All(ctx)
if err != nil {
return err
}
@ -618,8 +622,9 @@ func (me *contentDirectoryService) getTagScenes(paths []string, host string) []i
func (me *contentDirectoryService) getPerformers() []interface{} {
var objs []interface{}
if err := txn.WithReadTxn(context.TODO(), me.txnManager, func(ctx context.Context) error {
performers, err := me.repository.PerformerFinder.All(ctx)
r := me.repository
if err := r.WithReadTxn(context.TODO(), func(ctx context.Context) error {
performers, err := r.PerformerFinder.All(ctx)
if err != nil {
return err
}
@ -657,8 +662,9 @@ func (me *contentDirectoryService) getPerformerScenes(paths []string, host strin
func (me *contentDirectoryService) getMovies() []interface{} {
var objs []interface{}
if err := txn.WithReadTxn(context.TODO(), me.txnManager, func(ctx context.Context) error {
movies, err := me.repository.MovieFinder.All(ctx)
r := me.repository
if err := r.WithReadTxn(context.TODO(), func(ctx context.Context) error {
movies, err := r.MovieFinder.All(ctx)
if err != nil {
return err
}

View file

@ -48,7 +48,6 @@ import (
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
)
type SceneFinder interface {
@ -271,7 +270,6 @@ type Server struct {
// Time interval between SSPD announces
NotifyInterval time.Duration
txnManager txn.Manager
repository Repository
sceneServer sceneServer
ipWhitelistManager *ipWhitelistManager
@ -439,12 +437,13 @@ func (me *Server) serveIcon(w http.ResponseWriter, r *http.Request) {
}
var scene *models.Scene
err := txn.WithReadTxn(r.Context(), me.txnManager, func(ctx context.Context) error {
repo := me.repository
err := repo.WithReadTxn(r.Context(), func(ctx context.Context) error {
idInt, err := strconv.Atoi(sceneId)
if err != nil {
return nil
}
scene, _ = me.repository.SceneFinder.Find(ctx, idInt)
scene, _ = repo.SceneFinder.Find(ctx, idInt)
return nil
})
if err != nil {
@ -579,12 +578,13 @@ func (me *Server) initMux(mux *http.ServeMux) {
mux.HandleFunc(resPath, func(w http.ResponseWriter, r *http.Request) {
sceneId := r.URL.Query().Get("scene")
var scene *models.Scene
err := txn.WithReadTxn(r.Context(), me.txnManager, func(ctx context.Context) error {
repo := me.repository
err := repo.WithReadTxn(r.Context(), func(ctx context.Context) error {
sceneIdInt, err := strconv.Atoi(sceneId)
if err != nil {
return nil
}
scene, _ = me.repository.SceneFinder.Find(ctx, sceneIdInt)
scene, _ = repo.SceneFinder.Find(ctx, sceneIdInt)
return nil
})
if err != nil {

View file

@ -1,6 +1,7 @@
package dlna
import (
"context"
"fmt"
"net"
"net/http"
@ -14,6 +15,8 @@ import (
)
type Repository struct {
TxnManager models.TxnManager
SceneFinder SceneFinder
FileGetter models.FileGetter
StudioFinder StudioFinder
@ -22,6 +25,22 @@ type Repository struct {
MovieFinder MovieFinder
}
func NewRepository(repo models.Repository) Repository {
return Repository{
TxnManager: repo.TxnManager,
FileGetter: repo.File,
SceneFinder: repo.Scene,
StudioFinder: repo.Studio,
TagFinder: repo.Tag,
PerformerFinder: repo.Performer,
MovieFinder: repo.Movie,
}
}
func (r *Repository) WithReadTxn(ctx context.Context, fn txn.TxnFunc) error {
return txn.WithReadTxn(ctx, r.TxnManager, fn)
}
type Status struct {
Running bool `json:"running"`
// If not currently running, time until it will be started. If running, time until it will be stopped
@ -60,7 +79,6 @@ type Config interface {
}
type Service struct {
txnManager txn.Manager
repository Repository
config Config
sceneServer sceneServer
@ -133,9 +151,8 @@ func (s *Service) init() error {
}
s.server = &Server{
txnManager: s.txnManager,
sceneServer: s.sceneServer,
repository: s.repository,
sceneServer: s.sceneServer,
ipWhitelistManager: s.ipWhitelistMgr,
Interfaces: interfaces,
HTTPConn: func() net.Listener {
@ -197,9 +214,8 @@ func (s *Service) init() error {
// }
// NewService initialises and returns a new DLNA service.
func NewService(txnManager txn.Manager, repo Repository, cfg Config, sceneServer sceneServer) *Service {
func NewService(repo Repository, cfg Config, sceneServer sceneServer) *Service {
ret := &Service{
txnManager: txnManager,
repository: repo,
sceneServer: sceneServer,
config: cfg,

View file

@ -43,6 +43,7 @@ type ScraperSource struct {
}
type SceneIdentifier struct {
TxnManager txn.Manager
SceneReaderUpdater SceneReaderUpdater
StudioReaderWriter models.StudioReaderWriter
PerformerCreator PerformerCreator
@ -53,8 +54,8 @@ type SceneIdentifier struct {
SceneUpdatePostHookExecutor SceneUpdatePostHookExecutor
}
func (t *SceneIdentifier) Identify(ctx context.Context, txnManager txn.Manager, scene *models.Scene) error {
result, err := t.scrapeScene(ctx, txnManager, scene)
func (t *SceneIdentifier) Identify(ctx context.Context, scene *models.Scene) error {
result, err := t.scrapeScene(ctx, scene)
var multipleMatchErr *MultipleMatchesFoundError
if err != nil {
if !errors.As(err, &multipleMatchErr) {
@ -70,7 +71,7 @@ func (t *SceneIdentifier) Identify(ctx context.Context, txnManager txn.Manager,
options := t.getOptions(multipleMatchErr.Source)
if options.SkipMultipleMatchTag != nil && len(*options.SkipMultipleMatchTag) > 0 {
// Tag it with the multiple results tag
err := t.addTagToScene(ctx, txnManager, scene, *options.SkipMultipleMatchTag)
err := t.addTagToScene(ctx, scene, *options.SkipMultipleMatchTag)
if err != nil {
return err
}
@ -83,7 +84,7 @@ func (t *SceneIdentifier) Identify(ctx context.Context, txnManager txn.Manager,
}
// results were found, modify the scene
if err := t.modifyScene(ctx, txnManager, scene, result); err != nil {
if err := t.modifyScene(ctx, scene, result); err != nil {
return fmt.Errorf("error modifying scene: %v", err)
}
@ -95,7 +96,7 @@ type scrapeResult struct {
source ScraperSource
}
func (t *SceneIdentifier) scrapeScene(ctx context.Context, txnManager txn.Manager, scene *models.Scene) (*scrapeResult, error) {
func (t *SceneIdentifier) scrapeScene(ctx context.Context, scene *models.Scene) (*scrapeResult, error) {
// iterate through the input sources
for _, source := range t.Sources {
// scrape using the source
@ -261,9 +262,9 @@ func (t *SceneIdentifier) getSceneUpdater(ctx context.Context, s *models.Scene,
return ret, nil
}
func (t *SceneIdentifier) modifyScene(ctx context.Context, txnManager txn.Manager, s *models.Scene, result *scrapeResult) error {
func (t *SceneIdentifier) modifyScene(ctx context.Context, s *models.Scene, result *scrapeResult) error {
var updater *scene.UpdateSet
if err := txn.WithTxn(ctx, txnManager, func(ctx context.Context) error {
if err := txn.WithTxn(ctx, t.TxnManager, func(ctx context.Context) error {
// load scene relationships
if err := s.LoadURLs(ctx, t.SceneReaderUpdater); err != nil {
return err
@ -316,8 +317,8 @@ func (t *SceneIdentifier) modifyScene(ctx context.Context, txnManager txn.Manage
return nil
}
func (t *SceneIdentifier) addTagToScene(ctx context.Context, txnManager txn.Manager, s *models.Scene, tagToAdd string) error {
if err := txn.WithTxn(ctx, txnManager, func(ctx context.Context) error {
func (t *SceneIdentifier) addTagToScene(ctx context.Context, s *models.Scene, tagToAdd string) error {
if err := txn.WithTxn(ctx, t.TxnManager, func(ctx context.Context) error {
tagID, err := strconv.Atoi(tagToAdd)
if err != nil {
return fmt.Errorf("error converting tag ID %s: %w", tagToAdd, err)

View file

@ -108,17 +108,17 @@ func TestSceneIdentifier_Identify(t *testing.T) {
},
}
mockSceneReaderWriter := &mocks.SceneReaderWriter{}
mockSceneReaderWriter.On("GetURLs", mock.Anything, mock.Anything).Return(nil, nil)
mockSceneReaderWriter.On("UpdatePartial", mock.Anything, mock.MatchedBy(func(id int) bool {
db := mocks.NewDatabase()
db.Scene.On("GetURLs", mock.Anything, mock.Anything).Return(nil, nil)
db.Scene.On("UpdatePartial", mock.Anything, mock.MatchedBy(func(id int) bool {
return id == errUpdateID
}), mock.Anything).Return(nil, errors.New("update error"))
mockSceneReaderWriter.On("UpdatePartial", mock.Anything, mock.MatchedBy(func(id int) bool {
db.Scene.On("UpdatePartial", mock.Anything, mock.MatchedBy(func(id int) bool {
return id != errUpdateID
}), mock.Anything).Return(nil, nil)
mockTagFinderCreator := &mocks.TagReaderWriter{}
mockTagFinderCreator.On("Find", mock.Anything, skipMultipleTagID).Return(&models.Tag{
db.Tag.On("Find", mock.Anything, skipMultipleTagID).Return(&models.Tag{
ID: skipMultipleTagID,
Name: skipMultipleTagIDStr,
}, nil)
@ -185,8 +185,11 @@ func TestSceneIdentifier_Identify(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
identifier := SceneIdentifier{
SceneReaderUpdater: mockSceneReaderWriter,
TagFinderCreator: mockTagFinderCreator,
TxnManager: db,
SceneReaderUpdater: db.Scene,
StudioReaderWriter: db.Studio,
PerformerCreator: db.Performer,
TagFinderCreator: db.Tag,
DefaultOptions: defaultOptions,
Sources: sources,
SceneUpdatePostHookExecutor: mockHookExecutor{},
@ -202,7 +205,7 @@ func TestSceneIdentifier_Identify(t *testing.T) {
TagIDs: models.NewRelatedIDs([]int{}),
StashIDs: models.NewRelatedStashIDs([]models.StashID{}),
}
if err := identifier.Identify(testCtx, &mocks.TxnManager{}, scene); (err != nil) != tt.wantErr {
if err := identifier.Identify(testCtx, scene); (err != nil) != tt.wantErr {
t.Errorf("SceneIdentifier.Identify() error = %v, wantErr %v", err, tt.wantErr)
}
})
@ -210,9 +213,8 @@ func TestSceneIdentifier_Identify(t *testing.T) {
}
func TestSceneIdentifier_modifyScene(t *testing.T) {
repo := models.Repository{
TxnManager: &mocks.TxnManager{},
}
db := mocks.NewDatabase()
boolFalse := false
defaultOptions := &MetadataOptions{
SetOrganized: &boolFalse,
@ -221,6 +223,11 @@ func TestSceneIdentifier_modifyScene(t *testing.T) {
SkipSingleNamePerformers: &boolFalse,
}
tr := &SceneIdentifier{
TxnManager: db,
SceneReaderUpdater: db.Scene,
StudioReaderWriter: db.Studio,
PerformerCreator: db.Performer,
TagFinderCreator: db.Tag,
DefaultOptions: defaultOptions,
}
@ -254,7 +261,7 @@ func TestSceneIdentifier_modifyScene(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tr.modifyScene(testCtx, repo, tt.args.scene, tt.args.result); (err != nil) != tt.wantErr {
if err := tr.modifyScene(testCtx, tt.args.scene, tt.args.result); (err != nil) != tt.wantErr {
t.Errorf("SceneIdentifier.modifyScene() error = %v, wantErr %v", err, tt.wantErr)
}
})

View file

@ -22,8 +22,9 @@ func Test_getPerformerID(t *testing.T) {
remoteSiteID := "2"
name := "name"
mockPerformerReaderWriter := mocks.PerformerReaderWriter{}
mockPerformerReaderWriter.On("Create", testCtx, mock.Anything).Run(func(args mock.Arguments) {
db := mocks.NewDatabase()
db.Performer.On("Create", testCtx, mock.Anything).Run(func(args mock.Arguments) {
p := args.Get(1).(*models.Performer)
p.ID = validStoredID
}).Return(nil)
@ -131,7 +132,7 @@ func Test_getPerformerID(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := getPerformerID(testCtx, tt.args.endpoint, &mockPerformerReaderWriter, tt.args.p, tt.args.createMissing, tt.args.skipSingleName)
got, err := getPerformerID(testCtx, tt.args.endpoint, db.Performer, tt.args.p, tt.args.createMissing, tt.args.skipSingleName)
if (err != nil) != tt.wantErr {
t.Errorf("getPerformerID() error = %v, wantErr %v", err, tt.wantErr)
return
@ -151,15 +152,16 @@ func Test_createMissingPerformer(t *testing.T) {
invalidName := "invalidName"
performerID := 1
mockPerformerReaderWriter := mocks.PerformerReaderWriter{}
mockPerformerReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p *models.Performer) bool {
db := mocks.NewDatabase()
db.Performer.On("Create", testCtx, mock.MatchedBy(func(p *models.Performer) bool {
return p.Name == validName
})).Run(func(args mock.Arguments) {
p := args.Get(1).(*models.Performer)
p.ID = performerID
}).Return(nil)
mockPerformerReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p *models.Performer) bool {
db.Performer.On("Create", testCtx, mock.MatchedBy(func(p *models.Performer) bool {
return p.Name == invalidName
})).Return(errors.New("error creating performer"))
@ -212,7 +214,7 @@ func Test_createMissingPerformer(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := createMissingPerformer(testCtx, tt.args.endpoint, &mockPerformerReaderWriter, tt.args.p)
got, err := createMissingPerformer(testCtx, tt.args.endpoint, db.Performer, tt.args.p)
if (err != nil) != tt.wantErr {
t.Errorf("createMissingPerformer() error = %v, wantErr %v", err, tt.wantErr)
return

View file

@ -24,14 +24,15 @@ func Test_sceneRelationships_studio(t *testing.T) {
Strategy: FieldStrategyMerge,
}
mockStudioReaderWriter := &mocks.StudioReaderWriter{}
mockStudioReaderWriter.On("Create", testCtx, mock.Anything).Run(func(args mock.Arguments) {
db := mocks.NewDatabase()
db.Studio.On("Create", testCtx, mock.Anything).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.Studio)
s.ID = validStoredIDInt
}).Return(nil)
tr := sceneRelationships{
studioReaderWriter: mockStudioReaderWriter,
studioReaderWriter: db.Studio,
fieldOptions: make(map[string]*FieldOptions),
}
@ -174,8 +175,10 @@ func Test_sceneRelationships_performers(t *testing.T) {
}),
}
db := mocks.NewDatabase()
tr := sceneRelationships{
sceneReader: &mocks.SceneReaderWriter{},
sceneReader: db.Scene,
fieldOptions: make(map[string]*FieldOptions),
}
@ -363,22 +366,21 @@ func Test_sceneRelationships_tags(t *testing.T) {
StashIDs: models.NewRelatedStashIDs([]models.StashID{}),
}
mockSceneReaderWriter := &mocks.SceneReaderWriter{}
mockTagReaderWriter := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
mockTagReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p *models.Tag) bool {
db.Tag.On("Create", testCtx, mock.MatchedBy(func(p *models.Tag) bool {
return p.Name == validName
})).Run(func(args mock.Arguments) {
t := args.Get(1).(*models.Tag)
t.ID = validStoredIDInt
}).Return(nil)
mockTagReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p *models.Tag) bool {
db.Tag.On("Create", testCtx, mock.MatchedBy(func(p *models.Tag) bool {
return p.Name == invalidName
})).Return(errors.New("error creating tag"))
tr := sceneRelationships{
sceneReader: mockSceneReaderWriter,
tagCreator: mockTagReaderWriter,
sceneReader: db.Scene,
tagCreator: db.Tag,
fieldOptions: make(map[string]*FieldOptions),
}
@ -552,10 +554,10 @@ func Test_sceneRelationships_stashIDs(t *testing.T) {
}),
}
mockSceneReaderWriter := &mocks.SceneReaderWriter{}
db := mocks.NewDatabase()
tr := sceneRelationships{
sceneReader: mockSceneReaderWriter,
sceneReader: db.Scene,
fieldOptions: make(map[string]*FieldOptions),
}
@ -706,12 +708,13 @@ func Test_sceneRelationships_cover(t *testing.T) {
newDataEncoded := base64Prefix + utils.GetBase64StringFromData(newData)
invalidData := newDataEncoded + "!!!"
mockSceneReaderWriter := &mocks.SceneReaderWriter{}
mockSceneReaderWriter.On("GetCover", testCtx, sceneID).Return(existingData, nil)
mockSceneReaderWriter.On("GetCover", testCtx, errSceneID).Return(nil, errors.New("error getting cover"))
db := mocks.NewDatabase()
db.Scene.On("GetCover", testCtx, sceneID).Return(existingData, nil)
db.Scene.On("GetCover", testCtx, errSceneID).Return(nil, errors.New("error getting cover"))
tr := sceneRelationships{
sceneReader: mockSceneReaderWriter,
sceneReader: db.Scene,
fieldOptions: make(map[string]*FieldOptions),
}

View file

@ -19,18 +19,19 @@ func Test_createMissingStudio(t *testing.T) {
invalidName := "invalidName"
createdID := 1
mockStudioReaderWriter := &mocks.StudioReaderWriter{}
mockStudioReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p *models.Studio) bool {
db := mocks.NewDatabase()
db.Studio.On("Create", testCtx, mock.MatchedBy(func(p *models.Studio) bool {
return p.Name == validName
})).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.Studio)
s.ID = createdID
}).Return(nil)
mockStudioReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p *models.Studio) bool {
db.Studio.On("Create", testCtx, mock.MatchedBy(func(p *models.Studio) bool {
return p.Name == invalidName
})).Return(errors.New("error creating studio"))
mockStudioReaderWriter.On("UpdatePartial", testCtx, models.StudioPartial{
db.Studio.On("UpdatePartial", testCtx, models.StudioPartial{
ID: createdID,
StashIDs: &models.UpdateStashIDs{
StashIDs: []models.StashID{
@ -42,7 +43,7 @@ func Test_createMissingStudio(t *testing.T) {
Mode: models.RelationshipUpdateModeSet,
},
}).Return(nil, errors.New("error updating stash ids"))
mockStudioReaderWriter.On("UpdatePartial", testCtx, models.StudioPartial{
db.Studio.On("UpdatePartial", testCtx, models.StudioPartial{
ID: createdID,
StashIDs: &models.UpdateStashIDs{
StashIDs: []models.StashID{
@ -106,7 +107,7 @@ func Test_createMissingStudio(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := createMissingStudio(testCtx, tt.args.endpoint, mockStudioReaderWriter, tt.args.studio)
got, err := createMissingStudio(testCtx, tt.args.endpoint, db.Studio, tt.args.studio)
if (err != nil) != tt.wantErr {
t.Errorf("createMissingStudio() error = %v, wantErr %v", err, tt.wantErr)
return

View file

@ -131,7 +131,7 @@ type Manager struct {
DLNAService *dlna.Service
Database *sqlite.Database
Repository Repository
Repository models.Repository
SceneService SceneService
ImageService ImageService
@ -174,6 +174,7 @@ func initialize() error {
initProfiling(cfg.GetCPUProfilePath())
db := sqlite.NewDatabase()
repo := db.Repository()
// start with empty paths
emptyPaths := paths.Paths{}
@ -186,49 +187,43 @@ func initialize() error {
PluginCache: plugin.NewCache(cfg),
Database: db,
Repository: sqliteRepository(db),
Repository: repo,
Paths: &emptyPaths,
scanSubs: &subscriptionManager{},
}
instance.SceneService = &scene.Service{
File: db.File,
Repository: db.Scene,
MarkerRepository: db.SceneMarker,
File: repo.File,
Repository: repo.Scene,
MarkerRepository: repo.SceneMarker,
PluginCache: instance.PluginCache,
Paths: instance.Paths,
Config: cfg,
}
instance.ImageService = &image.Service{
File: db.File,
Repository: db.Image,
File: repo.File,
Repository: repo.Image,
}
instance.GalleryService = &gallery.Service{
Repository: db.Gallery,
ImageFinder: db.Image,
Repository: repo.Gallery,
ImageFinder: repo.Image,
ImageService: instance.ImageService,
File: db.File,
Folder: db.Folder,
File: repo.File,
Folder: repo.Folder,
}
instance.JobManager = initJobManager()
sceneServer := SceneServer{
TxnManager: instance.Repository,
SceneCoverGetter: instance.Repository.Scene,
TxnManager: repo.TxnManager,
SceneCoverGetter: repo.Scene,
}
instance.DLNAService = dlna.NewService(instance.Repository, dlna.Repository{
SceneFinder: instance.Repository.Scene,
FileGetter: instance.Repository.File,
StudioFinder: instance.Repository.Studio,
TagFinder: instance.Repository.Tag,
PerformerFinder: instance.Repository.Performer,
MovieFinder: instance.Repository.Movie,
}, instance.Config, &sceneServer)
dlnaRepository := dlna.NewRepository(repo)
instance.DLNAService = dlna.NewService(dlnaRepository, cfg, &sceneServer)
if !cfg.IsNewSystem() {
logger.Infof("using config file: %s", cfg.GetConfigFile())
@ -268,8 +263,8 @@ func initialize() error {
logger.Warnf("could not initialize FFMPEG subsystem: %v", err)
}
instance.Scanner = makeScanner(db, instance.PluginCache)
instance.Cleaner = makeCleaner(db, instance.PluginCache)
instance.Scanner = makeScanner(repo, instance.PluginCache)
instance.Cleaner = makeCleaner(repo, instance.PluginCache)
// if DLNA is enabled, start it now
if instance.Config.GetDLNADefaultEnabled() {
@ -293,14 +288,9 @@ func galleryFileFilter(ctx context.Context, f models.File) bool {
return isZip(f.Base().Basename)
}
func makeScanner(db *sqlite.Database, pluginCache *plugin.Cache) *file.Scanner {
func makeScanner(repo models.Repository, pluginCache *plugin.Cache) *file.Scanner {
return &file.Scanner{
Repository: file.Repository{
Manager: db,
DatabaseProvider: db,
FileStore: db.File,
FolderStore: db.Folder,
},
Repository: file.NewRepository(repo),
FileDecorators: []file.Decorator{
&file.FilteredDecorator{
Decorator: &video.Decorator{
@ -320,15 +310,10 @@ func makeScanner(db *sqlite.Database, pluginCache *plugin.Cache) *file.Scanner {
}
}
func makeCleaner(db *sqlite.Database, pluginCache *plugin.Cache) *file.Cleaner {
func makeCleaner(repo models.Repository, pluginCache *plugin.Cache) *file.Cleaner {
return &file.Cleaner{
FS: &file.OsFS{},
Repository: file.Repository{
Manager: db,
DatabaseProvider: db,
FileStore: db.File,
FolderStore: db.Folder,
},
Repository: file.NewRepository(repo),
Handlers: []file.CleanHandler{
&cleanHandler{},
},
@ -523,14 +508,8 @@ func writeStashIcon() {
// initScraperCache initializes a new scraper cache and returns it.
func (s *Manager) initScraperCache() *scraper.Cache {
ret, err := scraper.NewCache(config.GetInstance(), s.Repository, scraper.Repository{
SceneFinder: s.Repository.Scene,
GalleryFinder: s.Repository.Gallery,
TagFinder: s.Repository.Tag,
PerformerFinder: s.Repository.Performer,
MovieFinder: s.Repository.Movie,
StudioFinder: s.Repository.Studio,
})
scraperRepository := scraper.NewRepository(s.Repository)
ret, err := scraper.NewCache(s.Config, scraperRepository)
if err != nil {
logger.Errorf("Error reading scraper configs: %s", err.Error())
@ -697,7 +676,7 @@ func (s *Manager) Setup(ctx context.Context, input SetupInput) error {
return fmt.Errorf("error initializing FFMPEG subsystem: %v", err)
}
instance.Scanner = makeScanner(instance.Database, instance.PluginCache)
instance.Scanner = makeScanner(instance.Repository, instance.PluginCache)
return nil
}

View file

@ -112,7 +112,8 @@ func (s *Manager) Import(ctx context.Context) (int, error) {
j := job.MakeJobExec(func(ctx context.Context, progress *job.Progress) {
task := ImportTask{
txnManager: s.Repository,
repository: s.Repository,
resetter: s.Database,
BaseDir: metadataPath,
Reset: true,
DuplicateBehaviour: ImportDuplicateEnumFail,
@ -136,7 +137,7 @@ func (s *Manager) Export(ctx context.Context) (int, error) {
var wg sync.WaitGroup
wg.Add(1)
task := ExportTask{
txnManager: s.Repository,
repository: s.Repository,
full: true,
fileNamingAlgorithm: config.GetVideoFileNamingAlgorithm(),
}
@ -167,7 +168,7 @@ func (s *Manager) Generate(ctx context.Context, input GenerateMetadataInput) (in
}
j := &GenerateJob{
txnManager: s.Repository,
repository: s.Repository,
input: input,
}
@ -212,7 +213,7 @@ func (s *Manager) generateScreenshot(ctx context.Context, sceneId string, at *fl
}
task := GenerateCoverTask{
txnManager: s.Repository,
repository: s.Repository,
Scene: *scene,
ScreenshotAt: at,
Overwrite: true,
@ -239,7 +240,7 @@ type AutoTagMetadataInput struct {
func (s *Manager) AutoTag(ctx context.Context, input AutoTagMetadataInput) int {
j := autoTagJob{
txnManager: s.Repository,
repository: s.Repository,
input: input,
}
@ -255,7 +256,7 @@ type CleanMetadataInput struct {
func (s *Manager) Clean(ctx context.Context, input CleanMetadataInput) int {
j := cleanJob{
cleaner: s.Cleaner,
txnManager: s.Repository,
repository: s.Repository,
sceneService: s.SceneService,
imageService: s.ImageService,
input: input,

View file

@ -6,59 +6,8 @@ import (
"github.com/stashapp/stash/pkg/image"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scene"
"github.com/stashapp/stash/pkg/sqlite"
"github.com/stashapp/stash/pkg/txn"
)
type Repository struct {
models.TxnManager
File models.FileReaderWriter
Folder models.FolderReaderWriter
Gallery models.GalleryReaderWriter
GalleryChapter models.GalleryChapterReaderWriter
Image models.ImageReaderWriter
Movie models.MovieReaderWriter
Performer models.PerformerReaderWriter
Scene models.SceneReaderWriter
SceneMarker models.SceneMarkerReaderWriter
Studio models.StudioReaderWriter
Tag models.TagReaderWriter
SavedFilter models.SavedFilterReaderWriter
}
func (r *Repository) WithTxn(ctx context.Context, fn txn.TxnFunc) error {
return txn.WithTxn(ctx, r, fn)
}
func (r *Repository) WithReadTxn(ctx context.Context, fn txn.TxnFunc) error {
return txn.WithReadTxn(ctx, r, fn)
}
func (r *Repository) WithDB(ctx context.Context, fn txn.TxnFunc) error {
return txn.WithDatabase(ctx, r, fn)
}
func sqliteRepository(d *sqlite.Database) Repository {
txnRepo := d.TxnRepository()
return Repository{
TxnManager: txnRepo,
File: d.File,
Folder: d.Folder,
Gallery: d.Gallery,
GalleryChapter: txnRepo.GalleryChapter,
Image: d.Image,
Movie: txnRepo.Movie,
Performer: txnRepo.Performer,
Scene: d.Scene,
SceneMarker: txnRepo.SceneMarker,
Studio: txnRepo.Studio,
Tag: txnRepo.Tag,
SavedFilter: txnRepo.SavedFilter,
}
}
type SceneService interface {
Create(ctx context.Context, input *models.Scene, fileIDs []models.FileID, coverImage []byte) (*models.Scene, error)
AssignFile(ctx context.Context, sceneID int, fileID models.FileID) error

View file

@ -3,7 +3,6 @@ package manager
import (
"context"
"errors"
"io"
"net/http"
"github.com/stashapp/stash/internal/manager/config"
@ -58,8 +57,6 @@ func (s *SceneServer) StreamSceneDirect(scene *models.Scene, w http.ResponseWrit
}
func (s *SceneServer) ServeScreenshot(scene *models.Scene, w http.ResponseWriter, r *http.Request) {
const defaultSceneImage = "scene/scene.svg"
var cover []byte
readTxnErr := txn.WithReadTxn(r.Context(), s.TxnManager, func(ctx context.Context) error {
cover, _ = s.SceneCoverGetter.GetCover(ctx, scene.ID)
@ -92,10 +89,7 @@ func (s *SceneServer) ServeScreenshot(scene *models.Scene, w http.ResponseWriter
}
// fallback to default cover if none found
// should always be there
f, _ := static.Scene.Open(defaultSceneImage)
defer f.Close()
cover, _ = io.ReadAll(f)
cover = static.ReadAll(static.DefaultSceneImage)
}
utils.ServeImage(w, r, cover)

View file

@ -19,7 +19,7 @@ import (
)
type autoTagJob struct {
txnManager Repository
repository models.Repository
input AutoTagMetadataInput
cache match.Cache
@ -56,7 +56,7 @@ func (j *autoTagJob) autoTagFiles(ctx context.Context, progress *job.Progress, p
studios: studios,
tags: tags,
progress: progress,
txnManager: j.txnManager,
repository: j.repository,
cache: &j.cache,
}
@ -73,8 +73,8 @@ func (j *autoTagJob) autoTagSpecific(ctx context.Context, progress *job.Progress
studioCount := len(studioIds)
tagCount := len(tagIds)
if err := j.txnManager.WithReadTxn(ctx, func(ctx context.Context) error {
r := j.txnManager
r := j.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
performerQuery := r.Performer
studioQuery := r.Studio
tagQuery := r.Tag
@ -123,16 +123,17 @@ func (j *autoTagJob) autoTagPerformers(ctx context.Context, progress *job.Progre
return
}
r := j.repository
tagger := autotag.Tagger{
TxnManager: j.txnManager,
TxnManager: r.TxnManager,
Cache: &j.cache,
}
for _, performerId := range performerIds {
var performers []*models.Performer
if err := j.txnManager.WithDB(ctx, func(ctx context.Context) error {
performerQuery := j.txnManager.Performer
if err := r.WithDB(ctx, func(ctx context.Context) error {
performerQuery := r.Performer
ignoreAutoTag := false
perPage := -1
@ -161,7 +162,7 @@ func (j *autoTagJob) autoTagPerformers(ctx context.Context, progress *job.Progre
return fmt.Errorf("performer with id %s not found", performerId)
}
if err := performer.LoadAliases(ctx, j.txnManager.Performer); err != nil {
if err := performer.LoadAliases(ctx, r.Performer); err != nil {
return fmt.Errorf("loading aliases for performer %d: %w", performer.ID, err)
}
performers = append(performers, performer)
@ -173,7 +174,6 @@ func (j *autoTagJob) autoTagPerformers(ctx context.Context, progress *job.Progre
}
err := func() error {
r := j.txnManager
if err := tagger.PerformerScenes(ctx, performer, paths, r.Scene); err != nil {
return fmt.Errorf("processing scenes: %w", err)
}
@ -215,9 +215,9 @@ func (j *autoTagJob) autoTagStudios(ctx context.Context, progress *job.Progress,
return
}
r := j.txnManager
r := j.repository
tagger := autotag.Tagger{
TxnManager: j.txnManager,
TxnManager: r.TxnManager,
Cache: &j.cache,
}
@ -308,15 +308,15 @@ func (j *autoTagJob) autoTagTags(ctx context.Context, progress *job.Progress, pa
return
}
r := j.txnManager
r := j.repository
tagger := autotag.Tagger{
TxnManager: j.txnManager,
TxnManager: r.TxnManager,
Cache: &j.cache,
}
for _, tagId := range tagIds {
var tags []*models.Tag
if err := j.txnManager.WithDB(ctx, func(ctx context.Context) error {
if err := r.WithDB(ctx, func(ctx context.Context) error {
tagQuery := r.Tag
ignoreAutoTag := false
perPage := -1
@ -402,7 +402,7 @@ type autoTagFilesTask struct {
tags bool
progress *job.Progress
txnManager Repository
repository models.Repository
cache *match.Cache
}
@ -482,7 +482,9 @@ func (t *autoTagFilesTask) makeGalleryFilter() *models.GalleryFilterType {
return ret
}
func (t *autoTagFilesTask) getCount(ctx context.Context, r Repository) (int, error) {
func (t *autoTagFilesTask) getCount(ctx context.Context) (int, error) {
r := t.repository
pp := 0
findFilter := &models.FindFilterType{
PerPage: &pp,
@ -522,7 +524,7 @@ func (t *autoTagFilesTask) getCount(ctx context.Context, r Repository) (int, err
return sceneCount + imageCount + galleryCount, nil
}
func (t *autoTagFilesTask) processScenes(ctx context.Context, r Repository) {
func (t *autoTagFilesTask) processScenes(ctx context.Context) {
if job.IsCancelled(ctx) {
return
}
@ -534,10 +536,12 @@ func (t *autoTagFilesTask) processScenes(ctx context.Context, r Repository) {
findFilter := models.BatchFindFilter(batchSize)
sceneFilter := t.makeSceneFilter()
r := t.repository
more := true
for more {
var scenes []*models.Scene
if err := t.txnManager.WithReadTxn(ctx, func(ctx context.Context) error {
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
var err error
scenes, err = scene.Query(ctx, r.Scene, sceneFilter, findFilter)
return err
@ -555,7 +559,7 @@ func (t *autoTagFilesTask) processScenes(ctx context.Context, r Repository) {
}
tt := autoTagSceneTask{
txnManager: t.txnManager,
repository: r,
scene: ss,
performers: t.performers,
studios: t.studios,
@ -583,7 +587,7 @@ func (t *autoTagFilesTask) processScenes(ctx context.Context, r Repository) {
}
}
func (t *autoTagFilesTask) processImages(ctx context.Context, r Repository) {
func (t *autoTagFilesTask) processImages(ctx context.Context) {
if job.IsCancelled(ctx) {
return
}
@ -595,10 +599,12 @@ func (t *autoTagFilesTask) processImages(ctx context.Context, r Repository) {
findFilter := models.BatchFindFilter(batchSize)
imageFilter := t.makeImageFilter()
r := t.repository
more := true
for more {
var images []*models.Image
if err := t.txnManager.WithReadTxn(ctx, func(ctx context.Context) error {
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
var err error
images, err = image.Query(ctx, r.Image, imageFilter, findFilter)
return err
@ -616,7 +622,7 @@ func (t *autoTagFilesTask) processImages(ctx context.Context, r Repository) {
}
tt := autoTagImageTask{
txnManager: t.txnManager,
repository: t.repository,
image: ss,
performers: t.performers,
studios: t.studios,
@ -644,7 +650,7 @@ func (t *autoTagFilesTask) processImages(ctx context.Context, r Repository) {
}
}
func (t *autoTagFilesTask) processGalleries(ctx context.Context, r Repository) {
func (t *autoTagFilesTask) processGalleries(ctx context.Context) {
if job.IsCancelled(ctx) {
return
}
@ -656,10 +662,12 @@ func (t *autoTagFilesTask) processGalleries(ctx context.Context, r Repository) {
findFilter := models.BatchFindFilter(batchSize)
galleryFilter := t.makeGalleryFilter()
r := t.repository
more := true
for more {
var galleries []*models.Gallery
if err := t.txnManager.WithReadTxn(ctx, func(ctx context.Context) error {
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
var err error
galleries, _, err = r.Gallery.Query(ctx, galleryFilter, findFilter)
return err
@ -677,7 +685,7 @@ func (t *autoTagFilesTask) processGalleries(ctx context.Context, r Repository) {
}
tt := autoTagGalleryTask{
txnManager: t.txnManager,
repository: t.repository,
gallery: ss,
performers: t.performers,
studios: t.studios,
@ -706,9 +714,8 @@ func (t *autoTagFilesTask) processGalleries(ctx context.Context, r Repository) {
}
func (t *autoTagFilesTask) process(ctx context.Context) {
r := t.txnManager
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
total, err := t.getCount(ctx, t.txnManager)
if err := t.repository.WithReadTxn(ctx, func(ctx context.Context) error {
total, err := t.getCount(ctx)
if err != nil {
return err
}
@ -724,13 +731,13 @@ func (t *autoTagFilesTask) process(ctx context.Context) {
return
}
t.processScenes(ctx, r)
t.processImages(ctx, r)
t.processGalleries(ctx, r)
t.processScenes(ctx)
t.processImages(ctx)
t.processGalleries(ctx)
}
type autoTagSceneTask struct {
txnManager Repository
repository models.Repository
scene *models.Scene
performers bool
@ -742,8 +749,8 @@ type autoTagSceneTask struct {
func (t *autoTagSceneTask) Start(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done()
r := t.txnManager
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
r := t.repository
if err := r.WithTxn(ctx, func(ctx context.Context) error {
if t.scene.Path == "" {
// nothing to do
return nil
@ -774,7 +781,7 @@ func (t *autoTagSceneTask) Start(ctx context.Context, wg *sync.WaitGroup) {
}
type autoTagImageTask struct {
txnManager Repository
repository models.Repository
image *models.Image
performers bool
@ -786,8 +793,8 @@ type autoTagImageTask struct {
func (t *autoTagImageTask) Start(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done()
r := t.txnManager
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
r := t.repository
if err := r.WithTxn(ctx, func(ctx context.Context) error {
if t.performers {
if err := autotag.ImagePerformers(ctx, t.image, r.Image, r.Performer, t.cache); err != nil {
return fmt.Errorf("tagging image performers for %s: %v", t.image.DisplayName(), err)
@ -813,7 +820,7 @@ func (t *autoTagImageTask) Start(ctx context.Context, wg *sync.WaitGroup) {
}
type autoTagGalleryTask struct {
txnManager Repository
repository models.Repository
gallery *models.Gallery
performers bool
@ -825,8 +832,8 @@ type autoTagGalleryTask struct {
func (t *autoTagGalleryTask) Start(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done()
r := t.txnManager
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
r := t.repository
if err := r.WithTxn(ctx, func(ctx context.Context) error {
if t.performers {
if err := autotag.GalleryPerformers(ctx, t.gallery, r.Gallery, r.Performer, t.cache); err != nil {
return fmt.Errorf("tagging gallery performers for %s: %v", t.gallery.DisplayName(), err)

View file

@ -16,7 +16,6 @@ import (
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/plugin"
"github.com/stashapp/stash/pkg/scene"
"github.com/stashapp/stash/pkg/txn"
)
type cleaner interface {
@ -25,7 +24,7 @@ type cleaner interface {
type cleanJob struct {
cleaner cleaner
txnManager Repository
repository models.Repository
input CleanMetadataInput
sceneService SceneService
imageService ImageService
@ -61,10 +60,11 @@ func (j *cleanJob) cleanEmptyGalleries(ctx context.Context) {
const batchSize = 1000
var toClean []int
findFilter := models.BatchFindFilter(batchSize)
if err := txn.WithTxn(ctx, j.txnManager, func(ctx context.Context) error {
r := j.repository
if err := r.WithTxn(ctx, func(ctx context.Context) error {
found := true
for found {
emptyGalleries, _, err := j.txnManager.Gallery.Query(ctx, &models.GalleryFilterType{
emptyGalleries, _, err := r.Gallery.Query(ctx, &models.GalleryFilterType{
ImageCount: &models.IntCriterionInput{
Value: 0,
Modifier: models.CriterionModifierEquals,
@ -108,9 +108,10 @@ func (j *cleanJob) cleanEmptyGalleries(ctx context.Context) {
func (j *cleanJob) deleteGallery(ctx context.Context, id int) {
pluginCache := GetInstance().PluginCache
qb := j.txnManager.Gallery
if err := txn.WithTxn(ctx, j.txnManager, func(ctx context.Context) error {
r := j.repository
if err := r.WithTxn(ctx, func(ctx context.Context) error {
qb := r.Gallery
g, err := qb.Find(ctx, id)
if err != nil {
return err
@ -120,7 +121,7 @@ func (j *cleanJob) deleteGallery(ctx context.Context, id int) {
return fmt.Errorf("gallery with id %d not found", id)
}
if err := g.LoadPrimaryFile(ctx, j.txnManager.File); err != nil {
if err := g.LoadPrimaryFile(ctx, r.File); err != nil {
return err
}
@ -253,9 +254,7 @@ func (f *cleanFilter) shouldCleanImage(path string, stash *config.StashConfig) b
return false
}
type cleanHandler struct {
PluginCache *plugin.Cache
}
type cleanHandler struct{}
func (h *cleanHandler) HandleFile(ctx context.Context, fileDeleter *file.Deleter, fileID models.FileID) error {
if err := h.handleRelatedScenes(ctx, fileDeleter, fileID); err != nil {
@ -277,7 +276,7 @@ func (h *cleanHandler) HandleFolder(ctx context.Context, fileDeleter *file.Delet
func (h *cleanHandler) handleRelatedScenes(ctx context.Context, fileDeleter *file.Deleter, fileID models.FileID) error {
mgr := GetInstance()
sceneQB := mgr.Database.Scene
sceneQB := mgr.Repository.Scene
scenes, err := sceneQB.FindByFileID(ctx, fileID)
if err != nil {
return err
@ -303,12 +302,9 @@ func (h *cleanHandler) handleRelatedScenes(ctx context.Context, fileDeleter *fil
return err
}
checksum := scene.Checksum
oshash := scene.OSHash
mgr.PluginCache.RegisterPostHooks(ctx, scene.ID, plugin.SceneDestroyPost, plugin.SceneDestroyInput{
Checksum: checksum,
OSHash: oshash,
Checksum: scene.Checksum,
OSHash: scene.OSHash,
Path: scene.Path,
}, nil)
} else {
@ -335,7 +331,7 @@ func (h *cleanHandler) handleRelatedScenes(ctx context.Context, fileDeleter *fil
func (h *cleanHandler) handleRelatedGalleries(ctx context.Context, fileID models.FileID) error {
mgr := GetInstance()
qb := mgr.Database.Gallery
qb := mgr.Repository.Gallery
galleries, err := qb.FindByFileID(ctx, fileID)
if err != nil {
return err
@ -381,7 +377,7 @@ func (h *cleanHandler) handleRelatedGalleries(ctx context.Context, fileID models
func (h *cleanHandler) deleteRelatedFolderGalleries(ctx context.Context, folderID models.FolderID) error {
mgr := GetInstance()
qb := mgr.Database.Gallery
qb := mgr.Repository.Gallery
galleries, err := qb.FindByFolderID(ctx, folderID)
if err != nil {
return err
@ -405,7 +401,7 @@ func (h *cleanHandler) deleteRelatedFolderGalleries(ctx context.Context, folderI
func (h *cleanHandler) handleRelatedImages(ctx context.Context, fileDeleter *file.Deleter, fileID models.FileID) error {
mgr := GetInstance()
imageQB := mgr.Database.Image
imageQB := mgr.Repository.Image
images, err := imageQB.FindByFileID(ctx, fileID)
if err != nil {
return err
@ -413,7 +409,7 @@ func (h *cleanHandler) handleRelatedImages(ctx context.Context, fileDeleter *fil
imageFileDeleter := &image.FileDeleter{
Deleter: fileDeleter,
Paths: GetInstance().Paths,
Paths: mgr.Paths,
}
for _, i := range images {

View file

@ -31,7 +31,7 @@ import (
)
type ExportTask struct {
txnManager Repository
repository models.Repository
full bool
baseDir string
@ -98,7 +98,7 @@ func CreateExportTask(a models.HashAlgorithm, input ExportObjectsInput) *ExportT
}
return &ExportTask{
txnManager: GetInstance().Repository,
repository: GetInstance().Repository,
fileNamingAlgorithm: a,
scenes: newExportSpec(input.Scenes),
images: newExportSpec(input.Images),
@ -148,29 +148,27 @@ func (t *ExportTask) Start(ctx context.Context, wg *sync.WaitGroup) {
paths.EmptyJSONDirs(t.baseDir)
paths.EnsureJSONDirs(t.baseDir)
txnErr := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
r := t.txnManager
txnErr := t.repository.WithTxn(ctx, func(ctx context.Context) error {
// include movie scenes and gallery images
if !t.full {
// only include movie scenes if includeDependencies is also set
if !t.scenes.all && t.includeDependencies {
t.populateMovieScenes(ctx, r)
t.populateMovieScenes(ctx)
}
// always export gallery images
if !t.images.all {
t.populateGalleryImages(ctx, r)
t.populateGalleryImages(ctx)
}
}
t.ExportScenes(ctx, workerCount, r)
t.ExportImages(ctx, workerCount, r)
t.ExportGalleries(ctx, workerCount, r)
t.ExportMovies(ctx, workerCount, r)
t.ExportPerformers(ctx, workerCount, r)
t.ExportStudios(ctx, workerCount, r)
t.ExportTags(ctx, workerCount, r)
t.ExportScenes(ctx, workerCount)
t.ExportImages(ctx, workerCount)
t.ExportGalleries(ctx, workerCount)
t.ExportMovies(ctx, workerCount)
t.ExportPerformers(ctx, workerCount)
t.ExportStudios(ctx, workerCount)
t.ExportTags(ctx, workerCount)
return nil
})
@ -277,9 +275,10 @@ func (t *ExportTask) zipFile(fn, outDir string, z *zip.Writer) error {
return nil
}
func (t *ExportTask) populateMovieScenes(ctx context.Context, repo Repository) {
reader := repo.Movie
sceneReader := repo.Scene
func (t *ExportTask) populateMovieScenes(ctx context.Context) {
r := t.repository
reader := r.Movie
sceneReader := r.Scene
var movies []*models.Movie
var err error
@ -307,9 +306,10 @@ func (t *ExportTask) populateMovieScenes(ctx context.Context, repo Repository) {
}
}
func (t *ExportTask) populateGalleryImages(ctx context.Context, repo Repository) {
reader := repo.Gallery
imageReader := repo.Image
func (t *ExportTask) populateGalleryImages(ctx context.Context) {
r := t.repository
reader := r.Gallery
imageReader := r.Image
var galleries []*models.Gallery
var err error
@ -342,10 +342,10 @@ func (t *ExportTask) populateGalleryImages(ctx context.Context, repo Repository)
}
}
func (t *ExportTask) ExportScenes(ctx context.Context, workers int, repo Repository) {
func (t *ExportTask) ExportScenes(ctx context.Context, workers int) {
var scenesWg sync.WaitGroup
sceneReader := repo.Scene
sceneReader := t.repository.Scene
var scenes []*models.Scene
var err error
@ -367,7 +367,7 @@ func (t *ExportTask) ExportScenes(ctx context.Context, workers int, repo Reposit
for w := 0; w < workers; w++ { // create export Scene workers
scenesWg.Add(1)
go exportScene(ctx, &scenesWg, jobCh, repo, t)
go t.exportScene(ctx, &scenesWg, jobCh)
}
for i, scene := range scenes {
@ -385,7 +385,7 @@ func (t *ExportTask) ExportScenes(ctx context.Context, workers int, repo Reposit
logger.Infof("[scenes] export complete in %s. %d workers used.", time.Since(startTime), workers)
}
func exportFile(f models.File, t *ExportTask) {
func (t *ExportTask) exportFile(f models.File) {
newFileJSON := fileToJSON(f)
fn := newFileJSON.Filename()
@ -449,7 +449,7 @@ func fileToJSON(f models.File) jsonschema.DirEntry {
return &base
}
func exportFolder(f models.Folder, t *ExportTask) {
func (t *ExportTask) exportFolder(f models.Folder) {
newFileJSON := folderToJSON(f)
fn := newFileJSON.Filename()
@ -475,15 +475,17 @@ func folderToJSON(f models.Folder) jsonschema.DirEntry {
return &base
}
func exportScene(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Scene, repo Repository, t *ExportTask) {
func (t *ExportTask) exportScene(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Scene) {
defer wg.Done()
sceneReader := repo.Scene
studioReader := repo.Studio
movieReader := repo.Movie
galleryReader := repo.Gallery
performerReader := repo.Performer
tagReader := repo.Tag
sceneMarkerReader := repo.SceneMarker
r := t.repository
sceneReader := r.Scene
studioReader := r.Studio
movieReader := r.Movie
galleryReader := r.Gallery
performerReader := r.Performer
tagReader := r.Tag
sceneMarkerReader := r.SceneMarker
for s := range jobChan {
sceneHash := s.GetHash(t.fileNamingAlgorithm)
@ -500,7 +502,7 @@ func exportScene(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models
// export files
for _, f := range s.Files.List() {
exportFile(f, t)
t.exportFile(f)
}
newSceneJSON.Studio, err = scene.GetStudioName(ctx, studioReader, s)
@ -589,10 +591,11 @@ func exportScene(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models
}
}
func (t *ExportTask) ExportImages(ctx context.Context, workers int, repo Repository) {
func (t *ExportTask) ExportImages(ctx context.Context, workers int) {
var imagesWg sync.WaitGroup
imageReader := repo.Image
r := t.repository
imageReader := r.Image
var images []*models.Image
var err error
@ -614,7 +617,7 @@ func (t *ExportTask) ExportImages(ctx context.Context, workers int, repo Reposit
for w := 0; w < workers; w++ { // create export Image workers
imagesWg.Add(1)
go exportImage(ctx, &imagesWg, jobCh, repo, t)
go t.exportImage(ctx, &imagesWg, jobCh)
}
for i, image := range images {
@ -632,22 +635,24 @@ func (t *ExportTask) ExportImages(ctx context.Context, workers int, repo Reposit
logger.Infof("[images] export complete in %s. %d workers used.", time.Since(startTime), workers)
}
func exportImage(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Image, repo Repository, t *ExportTask) {
func (t *ExportTask) exportImage(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Image) {
defer wg.Done()
studioReader := repo.Studio
galleryReader := repo.Gallery
performerReader := repo.Performer
tagReader := repo.Tag
r := t.repository
studioReader := r.Studio
galleryReader := r.Gallery
performerReader := r.Performer
tagReader := r.Tag
for s := range jobChan {
imageHash := s.Checksum
if err := s.LoadFiles(ctx, repo.Image); err != nil {
if err := s.LoadFiles(ctx, r.Image); err != nil {
logger.Errorf("[images] <%s> error getting image files: %s", imageHash, err.Error())
continue
}
if err := s.LoadURLs(ctx, repo.Image); err != nil {
if err := s.LoadURLs(ctx, r.Image); err != nil {
logger.Errorf("[images] <%s> error getting image urls: %s", imageHash, err.Error())
continue
}
@ -656,7 +661,7 @@ func exportImage(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models
// export files
for _, f := range s.Files.List() {
exportFile(f, t)
t.exportFile(f)
}
var err error
@ -715,10 +720,10 @@ func exportImage(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models
}
}
func (t *ExportTask) ExportGalleries(ctx context.Context, workers int, repo Repository) {
func (t *ExportTask) ExportGalleries(ctx context.Context, workers int) {
var galleriesWg sync.WaitGroup
reader := repo.Gallery
reader := t.repository.Gallery
var galleries []*models.Gallery
var err error
@ -740,7 +745,7 @@ func (t *ExportTask) ExportGalleries(ctx context.Context, workers int, repo Repo
for w := 0; w < workers; w++ { // create export Scene workers
galleriesWg.Add(1)
go exportGallery(ctx, &galleriesWg, jobCh, repo, t)
go t.exportGallery(ctx, &galleriesWg, jobCh)
}
for i, gallery := range galleries {
@ -759,15 +764,17 @@ func (t *ExportTask) ExportGalleries(ctx context.Context, workers int, repo Repo
logger.Infof("[galleries] export complete in %s. %d workers used.", time.Since(startTime), workers)
}
func exportGallery(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Gallery, repo Repository, t *ExportTask) {
func (t *ExportTask) exportGallery(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Gallery) {
defer wg.Done()
studioReader := repo.Studio
performerReader := repo.Performer
tagReader := repo.Tag
galleryChapterReader := repo.GalleryChapter
r := t.repository
studioReader := r.Studio
performerReader := r.Performer
tagReader := r.Tag
galleryChapterReader := r.GalleryChapter
for g := range jobChan {
if err := g.LoadFiles(ctx, repo.Gallery); err != nil {
if err := g.LoadFiles(ctx, r.Gallery); err != nil {
logger.Errorf("[galleries] <%s> failed to fetch files for gallery: %s", g.DisplayName(), err.Error())
continue
}
@ -782,12 +789,12 @@ func exportGallery(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *mode
// export files
for _, f := range g.Files.List() {
exportFile(f, t)
t.exportFile(f)
}
// export folder if necessary
if g.FolderID != nil {
folder, err := repo.Folder.Find(ctx, *g.FolderID)
folder, err := r.Folder.Find(ctx, *g.FolderID)
if err != nil {
logger.Errorf("[galleries] <%s> error getting gallery folder: %v", galleryHash, err)
continue
@ -798,7 +805,7 @@ func exportGallery(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *mode
continue
}
exportFolder(*folder, t)
t.exportFolder(*folder)
}
newGalleryJSON.Studio, err = gallery.GetStudioName(ctx, studioReader, g)
@ -857,10 +864,10 @@ func exportGallery(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *mode
}
}
func (t *ExportTask) ExportPerformers(ctx context.Context, workers int, repo Repository) {
func (t *ExportTask) ExportPerformers(ctx context.Context, workers int) {
var performersWg sync.WaitGroup
reader := repo.Performer
reader := t.repository.Performer
var performers []*models.Performer
var err error
all := t.full || (t.performers != nil && t.performers.all)
@ -880,7 +887,7 @@ func (t *ExportTask) ExportPerformers(ctx context.Context, workers int, repo Rep
for w := 0; w < workers; w++ { // create export Performer workers
performersWg.Add(1)
go t.exportPerformer(ctx, &performersWg, jobCh, repo)
go t.exportPerformer(ctx, &performersWg, jobCh)
}
for i, performer := range performers {
@ -896,10 +903,11 @@ func (t *ExportTask) ExportPerformers(ctx context.Context, workers int, repo Rep
logger.Infof("[performers] export complete in %s. %d workers used.", time.Since(startTime), workers)
}
func (t *ExportTask) exportPerformer(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Performer, repo Repository) {
func (t *ExportTask) exportPerformer(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Performer) {
defer wg.Done()
performerReader := repo.Performer
r := t.repository
performerReader := r.Performer
for p := range jobChan {
newPerformerJSON, err := performer.ToJSON(ctx, performerReader, p)
@ -909,7 +917,7 @@ func (t *ExportTask) exportPerformer(ctx context.Context, wg *sync.WaitGroup, jo
continue
}
tags, err := repo.Tag.FindByPerformerID(ctx, p.ID)
tags, err := r.Tag.FindByPerformerID(ctx, p.ID)
if err != nil {
logger.Errorf("[performers] <%s> error getting performer tags: %s", p.Name, err.Error())
continue
@ -929,10 +937,10 @@ func (t *ExportTask) exportPerformer(ctx context.Context, wg *sync.WaitGroup, jo
}
}
func (t *ExportTask) ExportStudios(ctx context.Context, workers int, repo Repository) {
func (t *ExportTask) ExportStudios(ctx context.Context, workers int) {
var studiosWg sync.WaitGroup
reader := repo.Studio
reader := t.repository.Studio
var studios []*models.Studio
var err error
all := t.full || (t.studios != nil && t.studios.all)
@ -953,7 +961,7 @@ func (t *ExportTask) ExportStudios(ctx context.Context, workers int, repo Reposi
for w := 0; w < workers; w++ { // create export Studio workers
studiosWg.Add(1)
go t.exportStudio(ctx, &studiosWg, jobCh, repo)
go t.exportStudio(ctx, &studiosWg, jobCh)
}
for i, studio := range studios {
@ -969,10 +977,10 @@ func (t *ExportTask) ExportStudios(ctx context.Context, workers int, repo Reposi
logger.Infof("[studios] export complete in %s. %d workers used.", time.Since(startTime), workers)
}
func (t *ExportTask) exportStudio(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Studio, repo Repository) {
func (t *ExportTask) exportStudio(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Studio) {
defer wg.Done()
studioReader := repo.Studio
studioReader := t.repository.Studio
for s := range jobChan {
newStudioJSON, err := studio.ToJSON(ctx, studioReader, s)
@ -990,10 +998,10 @@ func (t *ExportTask) exportStudio(ctx context.Context, wg *sync.WaitGroup, jobCh
}
}
func (t *ExportTask) ExportTags(ctx context.Context, workers int, repo Repository) {
func (t *ExportTask) ExportTags(ctx context.Context, workers int) {
var tagsWg sync.WaitGroup
reader := repo.Tag
reader := t.repository.Tag
var tags []*models.Tag
var err error
all := t.full || (t.tags != nil && t.tags.all)
@ -1014,7 +1022,7 @@ func (t *ExportTask) ExportTags(ctx context.Context, workers int, repo Repositor
for w := 0; w < workers; w++ { // create export Tag workers
tagsWg.Add(1)
go t.exportTag(ctx, &tagsWg, jobCh, repo)
go t.exportTag(ctx, &tagsWg, jobCh)
}
for i, tag := range tags {
@ -1030,10 +1038,10 @@ func (t *ExportTask) ExportTags(ctx context.Context, workers int, repo Repositor
logger.Infof("[tags] export complete in %s. %d workers used.", time.Since(startTime), workers)
}
func (t *ExportTask) exportTag(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Tag, repo Repository) {
func (t *ExportTask) exportTag(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Tag) {
defer wg.Done()
tagReader := repo.Tag
tagReader := t.repository.Tag
for thisTag := range jobChan {
newTagJSON, err := tag.ToJSON(ctx, tagReader, thisTag)
@ -1051,10 +1059,10 @@ func (t *ExportTask) exportTag(ctx context.Context, wg *sync.WaitGroup, jobChan
}
}
func (t *ExportTask) ExportMovies(ctx context.Context, workers int, repo Repository) {
func (t *ExportTask) ExportMovies(ctx context.Context, workers int) {
var moviesWg sync.WaitGroup
reader := repo.Movie
reader := t.repository.Movie
var movies []*models.Movie
var err error
all := t.full || (t.movies != nil && t.movies.all)
@ -1075,7 +1083,7 @@ func (t *ExportTask) ExportMovies(ctx context.Context, workers int, repo Reposit
for w := 0; w < workers; w++ { // create export Studio workers
moviesWg.Add(1)
go t.exportMovie(ctx, &moviesWg, jobCh, repo)
go t.exportMovie(ctx, &moviesWg, jobCh)
}
for i, movie := range movies {
@ -1091,11 +1099,12 @@ func (t *ExportTask) ExportMovies(ctx context.Context, workers int, repo Reposit
logger.Infof("[movies] export complete in %s. %d workers used.", time.Since(startTime), workers)
}
func (t *ExportTask) exportMovie(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Movie, repo Repository) {
func (t *ExportTask) exportMovie(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Movie) {
defer wg.Done()
movieReader := repo.Movie
studioReader := repo.Studio
r := t.repository
movieReader := r.Movie
studioReader := r.Studio
for m := range jobChan {
newMovieJSON, err := movie.ToJSON(ctx, movieReader, studioReader, m)

View file

@ -55,7 +55,7 @@ type GeneratePreviewOptionsInput struct {
const generateQueueSize = 200000
type GenerateJob struct {
txnManager Repository
repository models.Repository
input GenerateMetadataInput
overwrite bool
@ -112,8 +112,9 @@ func (j *GenerateJob) Execute(ctx context.Context, progress *job.Progress) {
Overwrite: j.overwrite,
}
if err := j.txnManager.WithReadTxn(ctx, func(ctx context.Context) error {
qb := j.txnManager.Scene
r := j.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
qb := r.Scene
if len(j.input.SceneIDs) == 0 && len(j.input.MarkerIDs) == 0 {
totals = j.queueTasks(ctx, g, queue)
} else {
@ -129,7 +130,7 @@ func (j *GenerateJob) Execute(ctx context.Context, progress *job.Progress) {
}
if len(j.input.MarkerIDs) > 0 {
markers, err = j.txnManager.SceneMarker.FindMany(ctx, markerIDs)
markers, err = r.SceneMarker.FindMany(ctx, markerIDs)
if err != nil {
return err
}
@ -229,12 +230,14 @@ func (j *GenerateJob) queueTasks(ctx context.Context, g *generate.Generator, que
findFilter := models.BatchFindFilter(batchSize)
r := j.repository
for more := true; more; {
if job.IsCancelled(ctx) {
return totals
}
scenes, err := scene.Query(ctx, j.txnManager.Scene, nil, findFilter)
scenes, err := scene.Query(ctx, r.Scene, nil, findFilter)
if err != nil {
logger.Errorf("Error encountered queuing files to scan: %s", err.Error())
return totals
@ -245,7 +248,7 @@ func (j *GenerateJob) queueTasks(ctx context.Context, g *generate.Generator, que
return totals
}
if err := ss.LoadFiles(ctx, j.txnManager.Scene); err != nil {
if err := ss.LoadFiles(ctx, r.Scene); err != nil {
logger.Errorf("Error encountered queuing files to scan: %s", err.Error())
return totals
}
@ -266,7 +269,7 @@ func (j *GenerateJob) queueTasks(ctx context.Context, g *generate.Generator, que
return totals
}
images, err := image.Query(ctx, j.txnManager.Image, nil, findFilter)
images, err := image.Query(ctx, r.Image, nil, findFilter)
if err != nil {
logger.Errorf("Error encountered queuing files to scan: %s", err.Error())
return totals
@ -277,7 +280,7 @@ func (j *GenerateJob) queueTasks(ctx context.Context, g *generate.Generator, que
return totals
}
if err := ss.LoadFiles(ctx, j.txnManager.Image); err != nil {
if err := ss.LoadFiles(ctx, r.Image); err != nil {
logger.Errorf("Error encountered queuing files to scan: %s", err.Error())
return totals
}
@ -331,9 +334,11 @@ func getGeneratePreviewOptions(optionsInput GeneratePreviewOptionsInput) generat
}
func (j *GenerateJob) queueSceneJobs(ctx context.Context, g *generate.Generator, scene *models.Scene, queue chan<- Task, totals *totalsGenerate) {
r := j.repository
if j.input.Covers {
task := &GenerateCoverTask{
txnManager: j.txnManager,
repository: r,
Scene: *scene,
Overwrite: j.overwrite,
}
@ -390,7 +395,7 @@ func (j *GenerateJob) queueSceneJobs(ctx context.Context, g *generate.Generator,
if j.input.Markers {
task := &GenerateMarkersTask{
TxnManager: j.txnManager,
repository: r,
Scene: scene,
Overwrite: j.overwrite,
fileNamingAlgorithm: j.fileNamingAlgo,
@ -429,10 +434,9 @@ func (j *GenerateJob) queueSceneJobs(ctx context.Context, g *generate.Generator,
// generate for all files in scene
for _, f := range scene.Files.List() {
task := &GeneratePhashTask{
repository: r,
File: f,
fileNamingAlgorithm: j.fileNamingAlgo,
txnManager: j.txnManager,
fileUpdater: j.txnManager.File,
Overwrite: j.overwrite,
}
@ -446,10 +450,10 @@ func (j *GenerateJob) queueSceneJobs(ctx context.Context, g *generate.Generator,
if j.input.InteractiveHeatmapsSpeeds {
task := &GenerateInteractiveHeatmapSpeedTask{
repository: r,
Scene: *scene,
Overwrite: j.overwrite,
fileNamingAlgorithm: j.fileNamingAlgo,
TxnManager: j.txnManager,
}
if task.required() {
@ -462,7 +466,7 @@ func (j *GenerateJob) queueSceneJobs(ctx context.Context, g *generate.Generator,
func (j *GenerateJob) queueMarkerJob(g *generate.Generator, marker *models.SceneMarker, queue chan<- Task, totals *totalsGenerate) {
task := &GenerateMarkersTask{
TxnManager: j.txnManager,
repository: j.repository,
Marker: marker,
Overwrite: j.overwrite,
fileNamingAlgorithm: j.fileNamingAlgo,

View file

@ -11,10 +11,10 @@ import (
)
type GenerateInteractiveHeatmapSpeedTask struct {
repository models.Repository
Scene models.Scene
Overwrite bool
fileNamingAlgorithm models.HashAlgorithm
TxnManager Repository
}
func (t *GenerateInteractiveHeatmapSpeedTask) GetDescription() string {
@ -42,10 +42,11 @@ func (t *GenerateInteractiveHeatmapSpeedTask) Start(ctx context.Context) {
median := generator.InteractiveSpeed
if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
r := t.repository
if err := r.WithTxn(ctx, func(ctx context.Context) error {
primaryFile := t.Scene.Files.Primary()
primaryFile.InteractiveSpeed = &median
qb := t.TxnManager.File
qb := r.File
return qb.Update(ctx, primaryFile)
}); err != nil && ctx.Err() == nil {
logger.Error(err.Error())

View file

@ -12,7 +12,7 @@ import (
)
type GenerateMarkersTask struct {
TxnManager Repository
repository models.Repository
Scene *models.Scene
Marker *models.SceneMarker
Overwrite bool
@ -41,9 +41,10 @@ func (t *GenerateMarkersTask) Start(ctx context.Context) {
if t.Marker != nil {
var scene *models.Scene
if err := t.TxnManager.WithReadTxn(ctx, func(ctx context.Context) error {
r := t.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
var err error
scene, err = t.TxnManager.Scene.Find(ctx, t.Marker.SceneID)
scene, err = r.Scene.Find(ctx, t.Marker.SceneID)
if err != nil {
return err
}
@ -51,7 +52,7 @@ func (t *GenerateMarkersTask) Start(ctx context.Context) {
return fmt.Errorf("scene with id %d not found", t.Marker.SceneID)
}
return scene.LoadPrimaryFile(ctx, t.TxnManager.File)
return scene.LoadPrimaryFile(ctx, r.File)
}); err != nil {
logger.Errorf("error finding scene for marker generation: %v", err)
return
@ -70,9 +71,10 @@ func (t *GenerateMarkersTask) Start(ctx context.Context) {
func (t *GenerateMarkersTask) generateSceneMarkers(ctx context.Context) {
var sceneMarkers []*models.SceneMarker
if err := t.TxnManager.WithReadTxn(ctx, func(ctx context.Context) error {
r := t.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
var err error
sceneMarkers, err = t.TxnManager.SceneMarker.FindBySceneID(ctx, t.Scene.ID)
sceneMarkers, err = r.SceneMarker.FindBySceneID(ctx, t.Scene.ID)
return err
}); err != nil {
logger.Errorf("error getting scene markers: %s", err.Error())
@ -129,7 +131,7 @@ func (t *GenerateMarkersTask) generateMarker(videoFile *models.VideoFile, scene
func (t *GenerateMarkersTask) markersNeeded(ctx context.Context) int {
markers := 0
sceneMarkers, err := t.TxnManager.SceneMarker.FindBySceneID(ctx, t.Scene.ID)
sceneMarkers, err := t.repository.SceneMarker.FindBySceneID(ctx, t.Scene.ID)
if err != nil {
logger.Errorf("error finding scene markers: %s", err.Error())
return 0

View file

@ -7,15 +7,13 @@ import (
"github.com/stashapp/stash/pkg/hash/videophash"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
)
type GeneratePhashTask struct {
repository models.Repository
File *models.VideoFile
Overwrite bool
fileNamingAlgorithm models.HashAlgorithm
txnManager txn.Manager
fileUpdater models.FileUpdater
}
func (t *GeneratePhashTask) GetDescription() string {
@ -34,15 +32,15 @@ func (t *GeneratePhashTask) Start(ctx context.Context) {
return
}
if err := txn.WithTxn(ctx, t.txnManager, func(ctx context.Context) error {
qb := t.fileUpdater
r := t.repository
if err := r.WithTxn(ctx, func(ctx context.Context) error {
hashValue := int64(*hash)
t.File.Fingerprints = t.File.Fingerprints.AppendUnique(models.Fingerprint{
Type: models.FingerprintTypePhash,
Fingerprint: hashValue,
})
return qb.Update(ctx, t.File)
return r.File.Update(ctx, t.File)
}); err != nil && ctx.Err() == nil {
logger.Errorf("Error setting phash: %v", err)
}

View file

@ -10,9 +10,9 @@ import (
)
type GenerateCoverTask struct {
repository models.Repository
Scene models.Scene
ScreenshotAt *float64
txnManager Repository
Overwrite bool
}
@ -23,11 +23,13 @@ func (t *GenerateCoverTask) GetDescription() string {
func (t *GenerateCoverTask) Start(ctx context.Context) {
scenePath := t.Scene.Path
r := t.repository
var required bool
if err := t.txnManager.WithReadTxn(ctx, func(ctx context.Context) error {
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
required = t.required(ctx)
return t.Scene.LoadPrimaryFile(ctx, t.txnManager.File)
return t.Scene.LoadPrimaryFile(ctx, r.File)
}); err != nil {
logger.Error(err)
}
@ -70,8 +72,8 @@ func (t *GenerateCoverTask) Start(ctx context.Context) {
return
}
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
qb := t.txnManager.Scene
if err := r.WithTxn(ctx, func(ctx context.Context) error {
qb := r.Scene
scenePartial := models.NewScenePartial()
// update the scene cover table
@ -103,7 +105,7 @@ func (t *GenerateCoverTask) required(ctx context.Context) bool {
}
// if the scene has a cover, then we don't need to generate it
hasCover, err := t.txnManager.Scene.HasCover(ctx, t.Scene.ID)
hasCover, err := t.repository.Scene.HasCover(ctx, t.Scene.ID)
if err != nil {
logger.Errorf("Error getting cover: %v", err)
return false

View file

@ -14,7 +14,6 @@ import (
"github.com/stashapp/stash/pkg/scraper"
"github.com/stashapp/stash/pkg/scraper/stashbox"
"github.com/stashapp/stash/pkg/sliceutil/stringslice"
"github.com/stashapp/stash/pkg/txn"
)
var ErrInput = errors.New("invalid request input")
@ -52,7 +51,8 @@ func (j *IdentifyJob) Execute(ctx context.Context, progress *job.Progress) {
// if scene ids provided, use those
// otherwise, batch query for all scenes - ordering by path
// don't use a transaction to query scenes
if err := txn.WithDatabase(ctx, instance.Repository, func(ctx context.Context) error {
r := instance.Repository
if err := r.WithDB(ctx, func(ctx context.Context) error {
if len(j.input.SceneIDs) == 0 {
return j.identifyAllScenes(ctx, sources)
}
@ -70,7 +70,7 @@ func (j *IdentifyJob) Execute(ctx context.Context, progress *job.Progress) {
// find the scene
var err error
scene, err := instance.Repository.Scene.Find(ctx, id)
scene, err := r.Scene.Find(ctx, id)
if err != nil {
return fmt.Errorf("finding scene id %d: %w", id, err)
}
@ -89,6 +89,8 @@ func (j *IdentifyJob) Execute(ctx context.Context, progress *job.Progress) {
}
func (j *IdentifyJob) identifyAllScenes(ctx context.Context, sources []identify.ScraperSource) error {
r := instance.Repository
// exclude organised
organised := false
sceneFilter := scene.FilterFromPaths(j.input.Paths)
@ -102,7 +104,7 @@ func (j *IdentifyJob) identifyAllScenes(ctx context.Context, sources []identify.
// get the count
pp := 0
findFilter.PerPage = &pp
countResult, err := instance.Repository.Scene.Query(ctx, models.SceneQueryOptions{
countResult, err := r.Scene.Query(ctx, models.SceneQueryOptions{
QueryOptions: models.QueryOptions{
FindFilter: findFilter,
Count: true,
@ -115,7 +117,7 @@ func (j *IdentifyJob) identifyAllScenes(ctx context.Context, sources []identify.
j.progress.SetTotal(countResult.Count)
return scene.BatchProcess(ctx, instance.Repository.Scene, sceneFilter, findFilter, func(scene *models.Scene) error {
return scene.BatchProcess(ctx, r.Scene, sceneFilter, findFilter, func(scene *models.Scene) error {
if job.IsCancelled(ctx) {
return nil
}
@ -132,18 +134,20 @@ func (j *IdentifyJob) identifyScene(ctx context.Context, s *models.Scene, source
var taskError error
j.progress.ExecuteTask("Identifying "+s.Path, func() {
r := instance.Repository
task := identify.SceneIdentifier{
SceneReaderUpdater: instance.Repository.Scene,
StudioReaderWriter: instance.Repository.Studio,
PerformerCreator: instance.Repository.Performer,
TagFinderCreator: instance.Repository.Tag,
TxnManager: r.TxnManager,
SceneReaderUpdater: r.Scene,
StudioReaderWriter: r.Studio,
PerformerCreator: r.Performer,
TagFinderCreator: r.Tag,
DefaultOptions: j.input.Options,
Sources: sources,
SceneUpdatePostHookExecutor: j.postHookExecutor,
}
taskError = task.Identify(ctx, instance.Repository, s)
taskError = task.Identify(ctx, s)
})
if taskError != nil {
@ -164,15 +168,11 @@ func (j *IdentifyJob) getSources() ([]identify.ScraperSource, error) {
var src identify.ScraperSource
if stashBox != nil {
stashboxRepository := stashbox.NewRepository(instance.Repository)
src = identify.ScraperSource{
Name: "stash-box: " + stashBox.Endpoint,
Scraper: stashboxSource{
stashbox.NewClient(*stashBox, instance.Repository, stashbox.Repository{
Scene: instance.Repository.Scene,
Performer: instance.Repository.Performer,
Tag: instance.Repository.Tag,
Studio: instance.Repository.Studio,
}),
stashbox.NewClient(*stashBox, stashboxRepository),
stashBox.Endpoint,
},
RemoteSite: stashBox.Endpoint,

View file

@ -25,8 +25,13 @@ import (
"github.com/stashapp/stash/pkg/tag"
)
type Resetter interface {
Reset() error
}
type ImportTask struct {
txnManager Repository
repository models.Repository
resetter Resetter
json jsonUtils
BaseDir string
@ -66,8 +71,10 @@ func CreateImportTask(a models.HashAlgorithm, input ImportObjectsInput) (*Import
}
}
mgr := GetInstance()
return &ImportTask{
txnManager: GetInstance().Repository,
repository: mgr.Repository,
resetter: mgr.Database,
BaseDir: baseDir,
TmpZip: tmpZip,
Reset: false,
@ -109,7 +116,7 @@ func (t *ImportTask) Start(ctx context.Context) {
}
if t.Reset {
err := t.txnManager.Reset()
err := t.resetter.Reset()
if err != nil {
logger.Errorf("Error resetting database: %s", err.Error())
@ -194,6 +201,8 @@ func (t *ImportTask) ImportPerformers(ctx context.Context) {
return
}
r := t.repository
for i, fi := range files {
index := i + 1
performerJSON, err := jsonschema.LoadPerformerFile(filepath.Join(path, fi.Name()))
@ -204,11 +213,9 @@ func (t *ImportTask) ImportPerformers(ctx context.Context) {
logger.Progressf("[performers] %d of %d", index, len(files))
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
r := t.txnManager
readerWriter := r.Performer
if err := r.WithTxn(ctx, func(ctx context.Context) error {
importer := &performer.Importer{
ReaderWriter: readerWriter,
ReaderWriter: r.Performer,
TagWriter: r.Tag,
Input: *performerJSON,
}
@ -237,6 +244,8 @@ func (t *ImportTask) ImportStudios(ctx context.Context) {
return
}
r := t.repository
for i, fi := range files {
index := i + 1
studioJSON, err := jsonschema.LoadStudioFile(filepath.Join(path, fi.Name()))
@ -247,8 +256,8 @@ func (t *ImportTask) ImportStudios(ctx context.Context) {
logger.Progressf("[studios] %d of %d", index, len(files))
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
return t.ImportStudio(ctx, studioJSON, pendingParent, t.txnManager.Studio)
if err := r.WithTxn(ctx, func(ctx context.Context) error {
return t.importStudio(ctx, studioJSON, pendingParent)
}); err != nil {
if errors.Is(err, studio.ErrParentStudioNotExist) {
// add to the pending parent list so that it is created after the parent
@ -269,8 +278,8 @@ func (t *ImportTask) ImportStudios(ctx context.Context) {
for _, s := range pendingParent {
for _, orphanStudioJSON := range s {
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
return t.ImportStudio(ctx, orphanStudioJSON, nil, t.txnManager.Studio)
if err := r.WithTxn(ctx, func(ctx context.Context) error {
return t.importStudio(ctx, orphanStudioJSON, nil)
}); err != nil {
logger.Errorf("[studios] <%s> failed to create: %s", orphanStudioJSON.Name, err.Error())
continue
@ -282,9 +291,9 @@ func (t *ImportTask) ImportStudios(ctx context.Context) {
logger.Info("[studios] import complete")
}
func (t *ImportTask) ImportStudio(ctx context.Context, studioJSON *jsonschema.Studio, pendingParent map[string][]*jsonschema.Studio, readerWriter studio.ImporterReaderWriter) error {
func (t *ImportTask) importStudio(ctx context.Context, studioJSON *jsonschema.Studio, pendingParent map[string][]*jsonschema.Studio) error {
importer := &studio.Importer{
ReaderWriter: readerWriter,
ReaderWriter: t.repository.Studio,
Input: *studioJSON,
MissingRefBehaviour: t.MissingRefBehaviour,
}
@ -302,7 +311,7 @@ func (t *ImportTask) ImportStudio(ctx context.Context, studioJSON *jsonschema.St
s := pendingParent[studioJSON.Name]
for _, childStudioJSON := range s {
// map is nil since we're not checking parent studios at this point
if err := t.ImportStudio(ctx, childStudioJSON, nil, readerWriter); err != nil {
if err := t.importStudio(ctx, childStudioJSON, nil); err != nil {
return fmt.Errorf("failed to create child studio <%s>: %s", childStudioJSON.Name, err.Error())
}
}
@ -326,6 +335,8 @@ func (t *ImportTask) ImportMovies(ctx context.Context) {
return
}
r := t.repository
for i, fi := range files {
index := i + 1
movieJSON, err := jsonschema.LoadMovieFile(filepath.Join(path, fi.Name()))
@ -336,14 +347,10 @@ func (t *ImportTask) ImportMovies(ctx context.Context) {
logger.Progressf("[movies] %d of %d", index, len(files))
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
r := t.txnManager
readerWriter := r.Movie
studioReaderWriter := r.Studio
if err := r.WithTxn(ctx, func(ctx context.Context) error {
movieImporter := &movie.Importer{
ReaderWriter: readerWriter,
StudioWriter: studioReaderWriter,
ReaderWriter: r.Movie,
StudioWriter: r.Studio,
Input: *movieJSON,
MissingRefBehaviour: t.MissingRefBehaviour,
}
@ -371,6 +378,8 @@ func (t *ImportTask) ImportFiles(ctx context.Context) {
return
}
r := t.repository
pendingParent := make(map[string][]jsonschema.DirEntry)
for i, fi := range files {
@ -383,8 +392,8 @@ func (t *ImportTask) ImportFiles(ctx context.Context) {
logger.Progressf("[files] %d of %d", index, len(files))
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
return t.ImportFile(ctx, fileJSON, pendingParent)
if err := r.WithTxn(ctx, func(ctx context.Context) error {
return t.importFile(ctx, fileJSON, pendingParent)
}); err != nil {
if errors.Is(err, file.ErrZipFileNotExist) {
// add to the pending parent list so that it is created after the parent
@ -405,8 +414,8 @@ func (t *ImportTask) ImportFiles(ctx context.Context) {
for _, s := range pendingParent {
for _, orphanFileJSON := range s {
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
return t.ImportFile(ctx, orphanFileJSON, nil)
if err := r.WithTxn(ctx, func(ctx context.Context) error {
return t.importFile(ctx, orphanFileJSON, nil)
}); err != nil {
logger.Errorf("[files] <%s> failed to create: %s", orphanFileJSON.DirEntry().Path, err.Error())
continue
@ -418,12 +427,11 @@ func (t *ImportTask) ImportFiles(ctx context.Context) {
logger.Info("[files] import complete")
}
func (t *ImportTask) ImportFile(ctx context.Context, fileJSON jsonschema.DirEntry, pendingParent map[string][]jsonschema.DirEntry) error {
r := t.txnManager
readerWriter := r.File
func (t *ImportTask) importFile(ctx context.Context, fileJSON jsonschema.DirEntry, pendingParent map[string][]jsonschema.DirEntry) error {
r := t.repository
fileImporter := &file.Importer{
ReaderWriter: readerWriter,
ReaderWriter: r.File,
FolderStore: r.Folder,
Input: fileJSON,
}
@ -437,7 +445,7 @@ func (t *ImportTask) ImportFile(ctx context.Context, fileJSON jsonschema.DirEntr
s := pendingParent[fileJSON.DirEntry().Path]
for _, childFileJSON := range s {
// map is nil since we're not checking parent studios at this point
if err := t.ImportFile(ctx, childFileJSON, nil); err != nil {
if err := t.importFile(ctx, childFileJSON, nil); err != nil {
return fmt.Errorf("failed to create child file <%s>: %s", childFileJSON.DirEntry().Path, err.Error())
}
}
@ -461,6 +469,8 @@ func (t *ImportTask) ImportGalleries(ctx context.Context) {
return
}
r := t.repository
for i, fi := range files {
index := i + 1
galleryJSON, err := jsonschema.LoadGalleryFile(filepath.Join(path, fi.Name()))
@ -471,21 +481,14 @@ func (t *ImportTask) ImportGalleries(ctx context.Context) {
logger.Progressf("[galleries] %d of %d", index, len(files))
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
r := t.txnManager
readerWriter := r.Gallery
tagWriter := r.Tag
performerWriter := r.Performer
studioWriter := r.Studio
chapterWriter := r.GalleryChapter
if err := r.WithTxn(ctx, func(ctx context.Context) error {
galleryImporter := &gallery.Importer{
ReaderWriter: readerWriter,
ReaderWriter: r.Gallery,
FolderFinder: r.Folder,
FileFinder: r.File,
PerformerWriter: performerWriter,
StudioWriter: studioWriter,
TagWriter: tagWriter,
PerformerWriter: r.Performer,
StudioWriter: r.Studio,
TagWriter: r.Tag,
Input: *galleryJSON,
MissingRefBehaviour: t.MissingRefBehaviour,
}
@ -500,7 +503,7 @@ func (t *ImportTask) ImportGalleries(ctx context.Context) {
GalleryID: galleryImporter.ID,
Input: m,
MissingRefBehaviour: t.MissingRefBehaviour,
ReaderWriter: chapterWriter,
ReaderWriter: r.GalleryChapter,
}
if err := performImport(ctx, chapterImporter, t.DuplicateBehaviour); err != nil {
@ -532,6 +535,8 @@ func (t *ImportTask) ImportTags(ctx context.Context) {
return
}
r := t.repository
for i, fi := range files {
index := i + 1
tagJSON, err := jsonschema.LoadTagFile(filepath.Join(path, fi.Name()))
@ -542,8 +547,8 @@ func (t *ImportTask) ImportTags(ctx context.Context) {
logger.Progressf("[tags] %d of %d", index, len(files))
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
return t.ImportTag(ctx, tagJSON, pendingParent, false, t.txnManager.Tag)
if err := r.WithTxn(ctx, func(ctx context.Context) error {
return t.importTag(ctx, tagJSON, pendingParent, false)
}); err != nil {
var parentError tag.ParentTagNotExistError
if errors.As(err, &parentError) {
@ -558,8 +563,8 @@ func (t *ImportTask) ImportTags(ctx context.Context) {
for _, s := range pendingParent {
for _, orphanTagJSON := range s {
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
return t.ImportTag(ctx, orphanTagJSON, nil, true, t.txnManager.Tag)
if err := r.WithTxn(ctx, func(ctx context.Context) error {
return t.importTag(ctx, orphanTagJSON, nil, true)
}); err != nil {
logger.Errorf("[tags] <%s> failed to create: %s", orphanTagJSON.Name, err.Error())
continue
@ -570,9 +575,9 @@ func (t *ImportTask) ImportTags(ctx context.Context) {
logger.Info("[tags] import complete")
}
func (t *ImportTask) ImportTag(ctx context.Context, tagJSON *jsonschema.Tag, pendingParent map[string][]*jsonschema.Tag, fail bool, readerWriter tag.ImporterReaderWriter) error {
func (t *ImportTask) importTag(ctx context.Context, tagJSON *jsonschema.Tag, pendingParent map[string][]*jsonschema.Tag, fail bool) error {
importer := &tag.Importer{
ReaderWriter: readerWriter,
ReaderWriter: t.repository.Tag,
Input: *tagJSON,
MissingRefBehaviour: t.MissingRefBehaviour,
}
@ -587,7 +592,7 @@ func (t *ImportTask) ImportTag(ctx context.Context, tagJSON *jsonschema.Tag, pen
}
for _, childTagJSON := range pendingParent[tagJSON.Name] {
if err := t.ImportTag(ctx, childTagJSON, pendingParent, fail, readerWriter); err != nil {
if err := t.importTag(ctx, childTagJSON, pendingParent, fail); err != nil {
var parentError tag.ParentTagNotExistError
if errors.As(err, &parentError) {
pendingParent[parentError.MissingParent()] = append(pendingParent[parentError.MissingParent()], childTagJSON)
@ -616,6 +621,8 @@ func (t *ImportTask) ImportScenes(ctx context.Context) {
return
}
r := t.repository
for i, fi := range files {
index := i + 1
@ -627,29 +634,20 @@ func (t *ImportTask) ImportScenes(ctx context.Context) {
continue
}
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
r := t.txnManager
readerWriter := r.Scene
tagWriter := r.Tag
galleryWriter := r.Gallery
movieWriter := r.Movie
performerWriter := r.Performer
studioWriter := r.Studio
markerWriter := r.SceneMarker
if err := r.WithTxn(ctx, func(ctx context.Context) error {
sceneImporter := &scene.Importer{
ReaderWriter: readerWriter,
ReaderWriter: r.Scene,
Input: *sceneJSON,
FileFinder: r.File,
FileNamingAlgorithm: t.fileNamingAlgorithm,
MissingRefBehaviour: t.MissingRefBehaviour,
GalleryFinder: galleryWriter,
MovieWriter: movieWriter,
PerformerWriter: performerWriter,
StudioWriter: studioWriter,
TagWriter: tagWriter,
GalleryFinder: r.Gallery,
MovieWriter: r.Movie,
PerformerWriter: r.Performer,
StudioWriter: r.Studio,
TagWriter: r.Tag,
}
if err := performImport(ctx, sceneImporter, t.DuplicateBehaviour); err != nil {
@ -662,8 +660,8 @@ func (t *ImportTask) ImportScenes(ctx context.Context) {
SceneID: sceneImporter.ID,
Input: m,
MissingRefBehaviour: t.MissingRefBehaviour,
ReaderWriter: markerWriter,
TagWriter: tagWriter,
ReaderWriter: r.SceneMarker,
TagWriter: r.Tag,
}
if err := performImport(ctx, markerImporter, t.DuplicateBehaviour); err != nil {
@ -693,6 +691,8 @@ func (t *ImportTask) ImportImages(ctx context.Context) {
return
}
r := t.repository
for i, fi := range files {
index := i + 1
@ -704,25 +704,18 @@ func (t *ImportTask) ImportImages(ctx context.Context) {
continue
}
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
r := t.txnManager
readerWriter := r.Image
tagWriter := r.Tag
galleryWriter := r.Gallery
performerWriter := r.Performer
studioWriter := r.Studio
if err := r.WithTxn(ctx, func(ctx context.Context) error {
imageImporter := &image.Importer{
ReaderWriter: readerWriter,
ReaderWriter: r.Image,
FileFinder: r.File,
Input: *imageJSON,
MissingRefBehaviour: t.MissingRefBehaviour,
GalleryFinder: galleryWriter,
PerformerWriter: performerWriter,
StudioWriter: studioWriter,
TagWriter: tagWriter,
GalleryFinder: r.Gallery,
PerformerWriter: r.Performer,
StudioWriter: r.Studio,
TagWriter: r.Tag,
}
return performImport(ctx, imageImporter, t.DuplicateBehaviour)

View file

@ -19,6 +19,7 @@ import (
"github.com/stashapp/stash/pkg/job"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/paths"
"github.com/stashapp/stash/pkg/scene"
"github.com/stashapp/stash/pkg/scene/generate"
"github.com/stashapp/stash/pkg/txn"
@ -48,10 +49,14 @@ func (j *ScanJob) Execute(ctx context.Context, progress *job.Progress) {
paths[i] = p.Path
}
mgr := GetInstance()
c := mgr.Config
repo := mgr.Repository
start := time.Now()
const taskQueueSize = 200000
taskQueue := job.NewTaskQueue(ctx, progress, taskQueueSize, instance.Config.GetParallelTasksWithAutoDetection())
taskQueue := job.NewTaskQueue(ctx, progress, taskQueueSize, c.GetParallelTasksWithAutoDetection())
var minModTime time.Time
if j.input.Filter != nil && j.input.Filter.MinModTime != nil {
@ -60,12 +65,10 @@ func (j *ScanJob) Execute(ctx context.Context, progress *job.Progress) {
j.scanner.Scan(ctx, getScanHandlers(j.input, taskQueue, progress), file.ScanOptions{
Paths: paths,
ScanFilters: []file.PathFilter{newScanFilter(instance.Config, minModTime)},
ZipFileExtensions: instance.Config.GetGalleryExtensions(),
ParallelTasks: instance.Config.GetParallelTasksWithAutoDetection(),
HandlerRequiredFilters: []file.Filter{
newHandlerRequiredFilter(instance.Config),
},
ScanFilters: []file.PathFilter{newScanFilter(c, repo, minModTime)},
ZipFileExtensions: c.GetGalleryExtensions(),
ParallelTasks: c.GetParallelTasksWithAutoDetection(),
HandlerRequiredFilters: []file.Filter{newHandlerRequiredFilter(c, repo)},
}, progress)
taskQueue.Close()
@ -123,17 +126,16 @@ type handlerRequiredFilter struct {
videoFileNamingAlgorithm models.HashAlgorithm
}
func newHandlerRequiredFilter(c *config.Instance) *handlerRequiredFilter {
db := instance.Database
func newHandlerRequiredFilter(c *config.Instance, repo models.Repository) *handlerRequiredFilter {
processes := c.GetParallelTasksWithAutoDetection()
return &handlerRequiredFilter{
extensionConfig: newExtensionConfig(c),
txnManager: db,
SceneFinder: db.Scene,
ImageFinder: db.Image,
GalleryFinder: db.Gallery,
CaptionUpdater: db.File,
txnManager: repo.TxnManager,
SceneFinder: repo.Scene,
ImageFinder: repo.Image,
GalleryFinder: repo.Gallery,
CaptionUpdater: repo.File,
FolderCache: lru.New(processes * 2),
videoFileNamingAlgorithm: c.GetVideoFileNamingAlgorithm(),
}
@ -226,6 +228,10 @@ func (f *handlerRequiredFilter) Accept(ctx context.Context, ff models.File) bool
type scanFilter struct {
extensionConfig
txnManager txn.Manager
FileFinder models.FileFinder
CaptionUpdater video.CaptionUpdater
stashPaths config.StashConfigs
generatedPath string
videoExcludeRegex []*regexp.Regexp
@ -233,9 +239,12 @@ type scanFilter struct {
minModTime time.Time
}
func newScanFilter(c *config.Instance, minModTime time.Time) *scanFilter {
func newScanFilter(c *config.Instance, repo models.Repository, minModTime time.Time) *scanFilter {
return &scanFilter{
extensionConfig: newExtensionConfig(c),
txnManager: repo.TxnManager,
FileFinder: repo.File,
CaptionUpdater: repo.File,
stashPaths: c.GetStashPaths(),
generatedPath: c.GetGeneratedPath(),
videoExcludeRegex: generateRegexps(c.GetExcludes()),
@ -263,7 +272,7 @@ func (f *scanFilter) Accept(ctx context.Context, path string, info fs.FileInfo)
if fsutil.MatchExtension(path, video.CaptionExts) {
// we don't include caption files in the file scan, but we do need
// to handle them
video.AssociateCaptions(ctx, path, instance.Repository, instance.Database.File, instance.Database.File)
video.AssociateCaptions(ctx, path, f.txnManager, f.FileFinder, f.CaptionUpdater)
return false
}
@ -308,30 +317,37 @@ func (f *scanFilter) Accept(ctx context.Context, path string, info fs.FileInfo)
type scanConfig struct {
isGenerateThumbnails bool
isGenerateClipPreviews bool
createGalleriesFromFolders bool
}
// GetCreateGalleriesFromFolders reports whether galleries should be created
// from folders, using the value captured into the scanConfig when it was built.
func (c *scanConfig) GetCreateGalleriesFromFolders() bool {
	// Read the snapshotted field only; the old unreachable fallback to the
	// global config instance has been removed.
	return c.createGalleriesFromFolders
}
func getScanHandlers(options ScanMetadataInput, taskQueue *job.TaskQueue, progress *job.Progress) []file.Handler {
db := instance.Database
pluginCache := instance.PluginCache
mgr := GetInstance()
c := mgr.Config
r := mgr.Repository
pluginCache := mgr.PluginCache
return []file.Handler{
&file.FilteredHandler{
Filter: file.FilterFunc(imageFileFilter),
Handler: &image.ScanHandler{
CreatorUpdater: db.Image,
GalleryFinder: db.Gallery,
CreatorUpdater: r.Image,
GalleryFinder: r.Gallery,
ScanGenerator: &imageGenerators{
input: options,
taskQueue: taskQueue,
progress: progress,
paths: mgr.Paths,
sequentialScanning: c.GetSequentialScanning(),
},
ScanConfig: &scanConfig{
isGenerateThumbnails: options.ScanGenerateThumbnails,
isGenerateClipPreviews: options.ScanGenerateClipPreviews,
createGalleriesFromFolders: c.GetCreateGalleriesFromFolders(),
},
PluginCache: pluginCache,
Paths: instance.Paths,
@ -340,25 +356,28 @@ func getScanHandlers(options ScanMetadataInput, taskQueue *job.TaskQueue, progre
&file.FilteredHandler{
Filter: file.FilterFunc(galleryFileFilter),
Handler: &gallery.ScanHandler{
CreatorUpdater: db.Gallery,
SceneFinderUpdater: db.Scene,
ImageFinderUpdater: db.Image,
CreatorUpdater: r.Gallery,
SceneFinderUpdater: r.Scene,
ImageFinderUpdater: r.Image,
PluginCache: pluginCache,
},
},
&file.FilteredHandler{
Filter: file.FilterFunc(videoFileFilter),
Handler: &scene.ScanHandler{
CreatorUpdater: db.Scene,
CreatorUpdater: r.Scene,
CaptionUpdater: r.File,
PluginCache: pluginCache,
CaptionUpdater: db.File,
ScanGenerator: &sceneGenerators{
input: options,
taskQueue: taskQueue,
progress: progress,
paths: mgr.Paths,
fileNamingAlgorithm: c.GetVideoFileNamingAlgorithm(),
sequentialScanning: c.GetSequentialScanning(),
},
FileNamingAlgorithm: instance.Config.GetVideoFileNamingAlgorithm(),
Paths: instance.Paths,
FileNamingAlgorithm: c.GetVideoFileNamingAlgorithm(),
Paths: mgr.Paths,
},
},
}
@ -368,6 +387,9 @@ type imageGenerators struct {
input ScanMetadataInput
taskQueue *job.TaskQueue
progress *job.Progress
paths *paths.Paths
sequentialScanning bool
}
func (g *imageGenerators) Generate(ctx context.Context, i *models.Image, f models.File) error {
@ -376,8 +398,6 @@ func (g *imageGenerators) Generate(ctx context.Context, i *models.Image, f model
progress := g.progress
t := g.input
path := f.Base().Path
config := instance.Config
sequentialScanning := config.GetSequentialScanning()
if t.ScanGenerateThumbnails {
// this should be quick, so always generate sequentially
@ -405,7 +425,7 @@ func (g *imageGenerators) Generate(ctx context.Context, i *models.Image, f model
progress.Increment()
}
if sequentialScanning {
if g.sequentialScanning {
previewsFn(ctx)
} else {
g.taskQueue.Add(fmt.Sprintf("Generating preview for %s", path), previewsFn)
@ -416,7 +436,7 @@ func (g *imageGenerators) Generate(ctx context.Context, i *models.Image, f model
}
func (g *imageGenerators) generateThumbnail(ctx context.Context, i *models.Image, f models.File) error {
thumbPath := GetInstance().Paths.Generated.GetThumbnailPath(i.Checksum, models.DefaultGthumbWidth)
thumbPath := g.paths.Generated.GetThumbnailPath(i.Checksum, models.DefaultGthumbWidth)
exists, _ := fsutil.FileExists(thumbPath)
if exists {
return nil
@ -435,13 +455,16 @@ func (g *imageGenerators) generateThumbnail(ctx context.Context, i *models.Image
logger.Debugf("Generating thumbnail for %s", path)
mgr := GetInstance()
c := mgr.Config
clipPreviewOptions := image.ClipPreviewOptions{
InputArgs: instance.Config.GetTranscodeInputArgs(),
OutputArgs: instance.Config.GetTranscodeOutputArgs(),
Preset: instance.Config.GetPreviewPreset().String(),
InputArgs: c.GetTranscodeInputArgs(),
OutputArgs: c.GetTranscodeOutputArgs(),
Preset: c.GetPreviewPreset().String(),
}
encoder := image.NewThumbnailEncoder(instance.FFMPEG, instance.FFProbe, clipPreviewOptions)
encoder := image.NewThumbnailEncoder(mgr.FFMPEG, mgr.FFProbe, clipPreviewOptions)
data, err := encoder.GetThumbnail(f, models.DefaultGthumbWidth)
if err != nil {
@ -464,6 +487,10 @@ type sceneGenerators struct {
input ScanMetadataInput
taskQueue *job.TaskQueue
progress *job.Progress
paths *paths.Paths
fileNamingAlgorithm models.HashAlgorithm
sequentialScanning bool
}
func (g *sceneGenerators) Generate(ctx context.Context, s *models.Scene, f *models.VideoFile) error {
@ -472,9 +499,8 @@ func (g *sceneGenerators) Generate(ctx context.Context, s *models.Scene, f *mode
progress := g.progress
t := g.input
path := f.Path
config := instance.Config
fileNamingAlgorithm := config.GetVideoFileNamingAlgorithm()
sequentialScanning := config.GetSequentialScanning()
mgr := GetInstance()
if t.ScanGenerateSprites {
progress.AddTotal(1)
@ -482,13 +508,13 @@ func (g *sceneGenerators) Generate(ctx context.Context, s *models.Scene, f *mode
taskSprite := GenerateSpriteTask{
Scene: *s,
Overwrite: overwrite,
fileNamingAlgorithm: fileNamingAlgorithm,
fileNamingAlgorithm: g.fileNamingAlgorithm,
}
taskSprite.Start(ctx)
progress.Increment()
}
if sequentialScanning {
if g.sequentialScanning {
spriteFn(ctx)
} else {
g.taskQueue.Add(fmt.Sprintf("Generating sprites for %s", path), spriteFn)
@ -499,17 +525,16 @@ func (g *sceneGenerators) Generate(ctx context.Context, s *models.Scene, f *mode
progress.AddTotal(1)
phashFn := func(ctx context.Context) {
taskPhash := GeneratePhashTask{
repository: mgr.Repository,
File: f,
fileNamingAlgorithm: fileNamingAlgorithm,
txnManager: instance.Database,
fileUpdater: instance.Database.File,
Overwrite: overwrite,
fileNamingAlgorithm: g.fileNamingAlgorithm,
}
taskPhash.Start(ctx)
progress.Increment()
}
if sequentialScanning {
if g.sequentialScanning {
phashFn(ctx)
} else {
g.taskQueue.Add(fmt.Sprintf("Generating phash for %s", path), phashFn)
@ -521,12 +546,12 @@ func (g *sceneGenerators) Generate(ctx context.Context, s *models.Scene, f *mode
previewsFn := func(ctx context.Context) {
options := getGeneratePreviewOptions(GeneratePreviewOptionsInput{})
g := &generate.Generator{
Encoder: instance.FFMPEG,
FFMpegConfig: instance.Config,
LockManager: instance.ReadLockManager,
MarkerPaths: instance.Paths.SceneMarkers,
ScenePaths: instance.Paths.Scene,
generator := &generate.Generator{
Encoder: mgr.FFMPEG,
FFMpegConfig: mgr.Config,
LockManager: mgr.ReadLockManager,
MarkerPaths: g.paths.SceneMarkers,
ScenePaths: g.paths.Scene,
Overwrite: overwrite,
}
@ -535,14 +560,14 @@ func (g *sceneGenerators) Generate(ctx context.Context, s *models.Scene, f *mode
ImagePreview: t.ScanGenerateImagePreviews,
Options: options,
Overwrite: overwrite,
fileNamingAlgorithm: fileNamingAlgorithm,
generator: g,
fileNamingAlgorithm: g.fileNamingAlgorithm,
generator: generator,
}
taskPreview.Start(ctx)
progress.Increment()
}
if sequentialScanning {
if g.sequentialScanning {
previewsFn(ctx)
} else {
g.taskQueue.Add(fmt.Sprintf("Generating preview for %s", path), previewsFn)
@ -553,8 +578,8 @@ func (g *sceneGenerators) Generate(ctx context.Context, s *models.Scene, f *mode
progress.AddTotal(1)
g.taskQueue.Add(fmt.Sprintf("Generating cover for %s", path), func(ctx context.Context) {
taskCover := GenerateCoverTask{
repository: mgr.Repository,
Scene: *s,
txnManager: instance.Repository,
Overwrite: overwrite,
}
taskCover.Start(ctx)

View file

@ -9,7 +9,6 @@ import (
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scraper/stashbox"
"github.com/stashapp/stash/pkg/studio"
"github.com/stashapp/stash/pkg/txn"
)
type StashBoxTagTaskType int
@ -92,18 +91,18 @@ func (t *StashBoxBatchTagTask) findStashBoxPerformer(ctx context.Context) (*mode
var performer *models.ScrapedPerformer
var err error
client := stashbox.NewClient(*t.box, instance.Repository, stashbox.Repository{
Scene: instance.Repository.Scene,
Performer: instance.Repository.Performer,
Tag: instance.Repository.Tag,
Studio: instance.Repository.Studio,
})
r := instance.Repository
stashboxRepository := stashbox.NewRepository(r)
client := stashbox.NewClient(*t.box, stashboxRepository)
if t.refresh {
var remoteID string
if err := txn.WithReadTxn(ctx, instance.Repository, func(ctx context.Context) error {
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
qb := r.Performer
if !t.performer.StashIDs.Loaded() {
err = t.performer.LoadStashIDs(ctx, instance.Repository.Performer)
err = t.performer.LoadStashIDs(ctx, qb)
if err != nil {
return err
}
@ -145,8 +144,9 @@ func (t *StashBoxBatchTagTask) processMatchedPerformer(ctx context.Context, p *m
}
// Start the transaction and update the performer
err = txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error {
qb := instance.Repository.Performer
r := instance.Repository
err = r.WithTxn(ctx, func(ctx context.Context) error {
qb := r.Performer
existingStashIDs, err := qb.GetStashIDs(ctx, storedID)
if err != nil {
@ -181,8 +181,10 @@ func (t *StashBoxBatchTagTask) processMatchedPerformer(ctx context.Context, p *m
return
}
err = txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error {
qb := instance.Repository.Performer
r := instance.Repository
err = r.WithTxn(ctx, func(ctx context.Context) error {
qb := r.Performer
if err := qb.Create(ctx, newPerformer); err != nil {
return err
}
@ -233,18 +235,16 @@ func (t *StashBoxBatchTagTask) findStashBoxStudio(ctx context.Context) (*models.
var studio *models.ScrapedStudio
var err error
client := stashbox.NewClient(*t.box, instance.Repository, stashbox.Repository{
Scene: instance.Repository.Scene,
Performer: instance.Repository.Performer,
Tag: instance.Repository.Tag,
Studio: instance.Repository.Studio,
})
r := instance.Repository
stashboxRepository := stashbox.NewRepository(r)
client := stashbox.NewClient(*t.box, stashboxRepository)
if t.refresh {
var remoteID string
if err := txn.WithReadTxn(ctx, instance.Repository, func(ctx context.Context) error {
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
if !t.studio.StashIDs.Loaded() {
err = t.studio.LoadStashIDs(ctx, instance.Repository.Studio)
err = t.studio.LoadStashIDs(ctx, r.Studio)
if err != nil {
return err
}
@ -293,8 +293,9 @@ func (t *StashBoxBatchTagTask) processMatchedStudio(ctx context.Context, s *mode
}
// Start the transaction and update the studio
err = txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error {
qb := instance.Repository.Studio
r := instance.Repository
err = r.WithTxn(ctx, func(ctx context.Context) error {
qb := r.Studio
existingStashIDs, err := qb.GetStashIDs(ctx, storedID)
if err != nil {
@ -341,8 +342,10 @@ func (t *StashBoxBatchTagTask) processMatchedStudio(ctx context.Context, s *mode
}
// Start the transaction and save the studio
err = txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error {
qb := instance.Repository.Studio
r := instance.Repository
err = r.WithTxn(ctx, func(ctx context.Context) error {
qb := r.Studio
if err := qb.Create(ctx, newStudio); err != nil {
return err
}
@ -375,8 +378,10 @@ func (t *StashBoxBatchTagTask) processParentStudio(ctx context.Context, parent *
}
// Start the transaction and save the studio
err = txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error {
qb := instance.Repository.Studio
r := instance.Repository
err = r.WithTxn(ctx, func(ctx context.Context) error {
qb := r.Studio
if err := qb.Create(ctx, newParentStudio); err != nil {
return err
}
@ -408,8 +413,9 @@ func (t *StashBoxBatchTagTask) processParentStudio(ctx context.Context, parent *
}
// Start the transaction and update the studio
err = txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error {
qb := instance.Repository.Studio
r := instance.Repository
err = r.WithTxn(ctx, func(ctx context.Context) error {
qb := r.Studio
existingStashIDs, err := qb.GetStashIDs(ctx, storedID)
if err != nil {

View file

@ -1,21 +1,62 @@
package static
import "embed"
import (
"embed"
"fmt"
"io"
"io/fs"
)
//go:embed performer
var Performer embed.FS
//go:embed performer performer_male scene image tag studio movie
var data embed.FS
//go:embed performer_male
var PerformerMale embed.FS
const (
Performer = "performer"
PerformerMale = "performer_male"
//go:embed scene
var Scene embed.FS
Scene = "scene"
DefaultSceneImage = "scene/scene.svg"
//go:embed image
var Image embed.FS
Image = "image"
DefaultImageImage = "image/image.svg"
//go:embed tag
var Tag embed.FS
Tag = "tag"
DefaultTagImage = "tag/tag.svg"
//go:embed studio
var Studio embed.FS
Studio = "studio"
DefaultStudioImage = "studio/studio.svg"
Movie = "movie"
DefaultMovieImage = "movie/movie.png"
)
// Sub returns an FS rooted at path, using fs.Sub.
// It will panic if an error occurs.
func Sub(path string) fs.FS {
	subFS, err := fs.Sub(data, path)
	if err == nil {
		return subFS
	}
	panic(fmt.Sprintf("creating static SubFS: %v", err))
}
// Open opens the file at path for reading.
// It will panic if an error occurs.
func Open(path string) fs.File {
	file, err := data.Open(path)
	if err == nil {
		return file
	}
	panic(fmt.Sprintf("opening static file: %v", err))
}
// ReadAll returns the contents of the file at path.
// It will panic if an error occurs.
func ReadAll(path string) []byte {
	f := Open(path)
	// Close the handle once the contents have been read; any close error is
	// irrelevant after a successful full read.
	defer f.Close()
	ret, err := io.ReadAll(f)
	if err != nil {
		panic(fmt.Sprintf("reading static file: %v", err))
	}
	return ret
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 405 B

View file

@ -11,7 +11,6 @@ import (
"github.com/stashapp/stash/pkg/job"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
)
// Cleaner scans through stored file and folder instances and removes those that are no longer present on disk.
@ -112,14 +111,15 @@ func (j *cleanJob) execute(ctx context.Context) error {
folderCount int
)
if err := txn.WithReadTxn(ctx, j.Repository, func(ctx context.Context) error {
r := j.Repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
var err error
fileCount, err = j.Repository.FileStore.CountAllInPaths(ctx, j.options.Paths)
fileCount, err = r.File.CountAllInPaths(ctx, j.options.Paths)
if err != nil {
return err
}
folderCount, err = j.Repository.FolderStore.CountAllInPaths(ctx, j.options.Paths)
folderCount, err = r.Folder.CountAllInPaths(ctx, j.options.Paths)
if err != nil {
return err
}
@ -172,13 +172,14 @@ func (j *cleanJob) assessFiles(ctx context.Context, toDelete *deleteSet) error {
progress := j.progress
more := true
if err := txn.WithReadTxn(ctx, j.Repository, func(ctx context.Context) error {
r := j.Repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
for more {
if job.IsCancelled(ctx) {
return nil
}
files, err := j.Repository.FileStore.FindAllInPaths(ctx, j.options.Paths, batchSize, offset)
files, err := r.File.FindAllInPaths(ctx, j.options.Paths, batchSize, offset)
if err != nil {
return fmt.Errorf("error querying for files: %w", err)
}
@ -223,8 +224,9 @@ func (j *cleanJob) assessFiles(ctx context.Context, toDelete *deleteSet) error {
// flagFolderForDelete adds folders to the toDelete set, with the leaf folders added first
func (j *cleanJob) flagFileForDelete(ctx context.Context, toDelete *deleteSet, f models.File) error {
r := j.Repository
// add contained files first
containedFiles, err := j.Repository.FileStore.FindByZipFileID(ctx, f.Base().ID)
containedFiles, err := r.File.FindByZipFileID(ctx, f.Base().ID)
if err != nil {
return fmt.Errorf("error finding contained files for %q: %w", f.Base().Path, err)
}
@ -235,7 +237,7 @@ func (j *cleanJob) flagFileForDelete(ctx context.Context, toDelete *deleteSet, f
}
// add contained folders as well
containedFolders, err := j.Repository.FolderStore.FindByZipFileID(ctx, f.Base().ID)
containedFolders, err := r.Folder.FindByZipFileID(ctx, f.Base().ID)
if err != nil {
return fmt.Errorf("error finding contained folders for %q: %w", f.Base().Path, err)
}
@ -256,13 +258,14 @@ func (j *cleanJob) assessFolders(ctx context.Context, toDelete *deleteSet) error
progress := j.progress
more := true
if err := txn.WithReadTxn(ctx, j.Repository, func(ctx context.Context) error {
r := j.Repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
for more {
if job.IsCancelled(ctx) {
return nil
}
folders, err := j.Repository.FolderStore.FindAllInPaths(ctx, j.options.Paths, batchSize, offset)
folders, err := r.Folder.FindAllInPaths(ctx, j.options.Paths, batchSize, offset)
if err != nil {
return fmt.Errorf("error querying for folders: %w", err)
}
@ -380,14 +383,15 @@ func (j *cleanJob) shouldCleanFolder(ctx context.Context, f *models.Folder) bool
func (j *cleanJob) deleteFile(ctx context.Context, fileID models.FileID, fn string) {
// delete associated objects
fileDeleter := NewDeleter()
if err := txn.WithTxn(ctx, j.Repository, func(ctx context.Context) error {
r := j.Repository
if err := r.WithTxn(ctx, func(ctx context.Context) error {
fileDeleter.RegisterHooks(ctx)
if err := j.fireHandlers(ctx, fileDeleter, fileID); err != nil {
return err
}
return j.Repository.FileStore.Destroy(ctx, fileID)
return r.File.Destroy(ctx, fileID)
}); err != nil {
logger.Errorf("Error deleting file %q from database: %s", fn, err.Error())
return
@ -397,14 +401,15 @@ func (j *cleanJob) deleteFile(ctx context.Context, fileID models.FileID, fn stri
func (j *cleanJob) deleteFolder(ctx context.Context, folderID models.FolderID, fn string) {
// delete associated objects
fileDeleter := NewDeleter()
if err := txn.WithTxn(ctx, j.Repository, func(ctx context.Context) error {
r := j.Repository
if err := r.WithTxn(ctx, func(ctx context.Context) error {
fileDeleter.RegisterHooks(ctx)
if err := j.fireFolderHandlers(ctx, fileDeleter, folderID); err != nil {
return err
}
return j.Repository.FolderStore.Destroy(ctx, folderID)
return r.Folder.Destroy(ctx, folderID)
}); err != nil {
logger.Errorf("Error deleting folder %q from database: %s", fn, err.Error())
return

View file

@ -1,15 +1,36 @@
package file
import (
"context"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
)
// Repository provides access to storage methods for files and folders.
type Repository struct {
txn.Manager
txn.DatabaseProvider
TxnManager models.TxnManager
FileStore models.FileReaderWriter
FolderStore models.FolderReaderWriter
File models.FileReaderWriter
Folder models.FolderReaderWriter
}
// NewRepository builds a file Repository from the stores held by repo.
func NewRepository(repo models.Repository) Repository {
	ret := Repository{
		TxnManager: repo.TxnManager,
		File:       repo.File,
		Folder:     repo.Folder,
	}
	return ret
}
// WithTxn executes fn within a transaction via txn.WithTxn,
// using the repository's TxnManager.
func (r *Repository) WithTxn(ctx context.Context, fn txn.TxnFunc) error {
	return txn.WithTxn(ctx, r.TxnManager, fn)
}
// WithReadTxn executes fn within a read-only transaction via txn.WithReadTxn,
// using the repository's TxnManager.
func (r *Repository) WithReadTxn(ctx context.Context, fn txn.TxnFunc) error {
	return txn.WithReadTxn(ctx, r.TxnManager, fn)
}
// WithDB executes fn with database access via txn.WithDatabase,
// using the repository's TxnManager.
func (r *Repository) WithDB(ctx context.Context, fn txn.TxnFunc) error {
	return txn.WithDatabase(ctx, r.TxnManager, fn)
}

View file

@ -86,6 +86,8 @@ func (s *scanJob) detectFolderMove(ctx context.Context, file scanFile) (*models.
}
// rejects is a set of folder ids which were found to still exist
r := s.Repository
if err := symWalk(file.fs, file.Path, func(path string, d fs.DirEntry, err error) error {
if err != nil {
// don't let errors prevent scanning
@ -118,7 +120,7 @@ func (s *scanJob) detectFolderMove(ctx context.Context, file scanFile) (*models.
}
// check if the file exists in the database based on basename, size and mod time
existing, err := s.Repository.FileStore.FindByFileInfo(ctx, info, size)
existing, err := r.File.FindByFileInfo(ctx, info, size)
if err != nil {
return fmt.Errorf("checking for existing file %q: %w", path, err)
}
@ -140,7 +142,7 @@ func (s *scanJob) detectFolderMove(ctx context.Context, file scanFile) (*models.
if c == nil {
// need to check if the folder exists in the filesystem
pf, err := s.Repository.FolderStore.Find(ctx, e.Base().ParentFolderID)
pf, err := r.Folder.Find(ctx, e.Base().ParentFolderID)
if err != nil {
return fmt.Errorf("getting parent folder %d: %w", e.Base().ParentFolderID, err)
}
@ -164,7 +166,7 @@ func (s *scanJob) detectFolderMove(ctx context.Context, file scanFile) (*models.
// parent folder is missing, possible candidate
// count the total number of files in the existing folder
count, err := s.Repository.FileStore.CountByFolderID(ctx, parentFolderID)
count, err := r.File.CountByFolderID(ctx, parentFolderID)
if err != nil {
return fmt.Errorf("counting files in folder %d: %w", parentFolderID, err)
}

View file

@ -181,7 +181,7 @@ func (m *Mover) moveFile(oldPath, newPath string) error {
return nil
}
func (m *Mover) RegisterHooks(ctx context.Context, mgr txn.Manager) {
func (m *Mover) RegisterHooks(ctx context.Context) {
txn.AddPostCommitHook(ctx, func(ctx context.Context) {
m.commit()
})

View file

@ -144,7 +144,7 @@ func (s *Scanner) Scan(ctx context.Context, handlers []Handler, options ScanOpti
ProgressReports: progressReporter,
options: options,
txnRetryer: txn.Retryer{
Manager: s.Repository,
Manager: s.Repository.TxnManager,
Retries: maxRetries,
},
}
@ -163,7 +163,7 @@ func (s *scanJob) withTxn(ctx context.Context, fn func(ctx context.Context) erro
}
func (s *scanJob) withDB(ctx context.Context, fn func(ctx context.Context) error) error {
return txn.WithDatabase(ctx, s.Repository, fn)
return s.Repository.WithDB(ctx, fn)
}
func (s *scanJob) execute(ctx context.Context) {
@ -439,7 +439,7 @@ func (s *scanJob) getFolderID(ctx context.Context, path string) (*models.FolderI
return &v, nil
}
ret, err := s.Repository.FolderStore.FindByPath(ctx, path)
ret, err := s.Repository.Folder.FindByPath(ctx, path)
if err != nil {
return nil, err
}
@ -469,7 +469,7 @@ func (s *scanJob) getZipFileID(ctx context.Context, zipFile *scanFile) (*models.
return &v, nil
}
ret, err := s.Repository.FileStore.FindByPath(ctx, path)
ret, err := s.Repository.File.FindByPath(ctx, path)
if err != nil {
return nil, fmt.Errorf("getting zip file ID for %q: %w", path, err)
}
@ -489,7 +489,7 @@ func (s *scanJob) handleFolder(ctx context.Context, file scanFile) error {
defer s.incrementProgress(file)
// determine if folder already exists in data store (by path)
f, err := s.Repository.FolderStore.FindByPath(ctx, path)
f, err := s.Repository.Folder.FindByPath(ctx, path)
if err != nil {
return fmt.Errorf("checking for existing folder %q: %w", path, err)
}
@ -553,7 +553,7 @@ func (s *scanJob) onNewFolder(ctx context.Context, file scanFile) (*models.Folde
logger.Infof("%s doesn't exist. Creating new folder entry...", file.Path)
})
if err := s.Repository.FolderStore.Create(ctx, toCreate); err != nil {
if err := s.Repository.Folder.Create(ctx, toCreate); err != nil {
return nil, fmt.Errorf("creating folder %q: %w", file.Path, err)
}
@ -589,12 +589,12 @@ func (s *scanJob) handleFolderRename(ctx context.Context, file scanFile) (*model
renamedFrom.ParentFolderID = parentFolderID
if err := s.Repository.FolderStore.Update(ctx, renamedFrom); err != nil {
if err := s.Repository.Folder.Update(ctx, renamedFrom); err != nil {
return nil, fmt.Errorf("updating folder for rename %q: %w", renamedFrom.Path, err)
}
// #4146 - correct sub-folders to have the correct path
if err := correctSubFolderHierarchy(ctx, s.Repository.FolderStore, renamedFrom); err != nil {
if err := correctSubFolderHierarchy(ctx, s.Repository.Folder, renamedFrom); err != nil {
return nil, fmt.Errorf("correcting sub folder hierarchy for %q: %w", renamedFrom.Path, err)
}
@ -626,7 +626,7 @@ func (s *scanJob) onExistingFolder(ctx context.Context, f scanFile, existing *mo
if update {
var err error
if err = s.Repository.FolderStore.Update(ctx, existing); err != nil {
if err = s.Repository.Folder.Update(ctx, existing); err != nil {
return nil, fmt.Errorf("updating folder %q: %w", f.Path, err)
}
}
@ -647,7 +647,7 @@ func (s *scanJob) handleFile(ctx context.Context, f scanFile) error {
if err := s.withDB(ctx, func(ctx context.Context) error {
// determine if file already exists in data store
var err error
ff, err = s.Repository.FileStore.FindByPath(ctx, f.Path)
ff, err = s.Repository.File.FindByPath(ctx, f.Path)
if err != nil {
return fmt.Errorf("checking for existing file %q: %w", f.Path, err)
}
@ -745,7 +745,7 @@ func (s *scanJob) onNewFile(ctx context.Context, f scanFile) (models.File, error
// if not renamed, queue file for creation
if err := s.withTxn(ctx, func(ctx context.Context) error {
if err := s.Repository.FileStore.Create(ctx, file); err != nil {
if err := s.Repository.File.Create(ctx, file); err != nil {
return fmt.Errorf("creating file %q: %w", path, err)
}
@ -838,7 +838,7 @@ func (s *scanJob) handleRename(ctx context.Context, f models.File, fp []models.F
var others []models.File
for _, tfp := range fp {
thisOthers, err := s.Repository.FileStore.FindByFingerprint(ctx, tfp)
thisOthers, err := s.Repository.File.FindByFingerprint(ctx, tfp)
if err != nil {
return nil, fmt.Errorf("getting files by fingerprint %v: %w", tfp, err)
}
@ -896,12 +896,12 @@ func (s *scanJob) handleRename(ctx context.Context, f models.File, fp []models.F
fBase.Fingerprints = otherBase.Fingerprints
if err := s.withTxn(ctx, func(ctx context.Context) error {
if err := s.Repository.FileStore.Update(ctx, f); err != nil {
if err := s.Repository.File.Update(ctx, f); err != nil {
return fmt.Errorf("updating file for rename %q: %w", fBase.Path, err)
}
if s.isZipFile(fBase.Basename) {
if err := TransferZipFolderHierarchy(ctx, s.Repository.FolderStore, fBase.ID, otherBase.Path, fBase.Path); err != nil {
if err := TransferZipFolderHierarchy(ctx, s.Repository.Folder, fBase.ID, otherBase.Path, fBase.Path); err != nil {
return fmt.Errorf("moving folder hierarchy for renamed zip file %q: %w", fBase.Path, err)
}
}
@ -963,7 +963,7 @@ func (s *scanJob) setMissingMetadata(ctx context.Context, f scanFile, existing m
// queue file for update
if err := s.withTxn(ctx, func(ctx context.Context) error {
if err := s.Repository.FileStore.Update(ctx, existing); err != nil {
if err := s.Repository.File.Update(ctx, existing); err != nil {
return fmt.Errorf("updating file %q: %w", path, err)
}
@ -986,7 +986,7 @@ func (s *scanJob) setMissingFingerprints(ctx context.Context, f scanFile, existi
existing.SetFingerprints(fp)
if err := s.withTxn(ctx, func(ctx context.Context) error {
if err := s.Repository.FileStore.Update(ctx, existing); err != nil {
if err := s.Repository.File.Update(ctx, existing); err != nil {
return fmt.Errorf("updating file %q: %w", f.Path, err)
}
@ -1035,7 +1035,7 @@ func (s *scanJob) onExistingFile(ctx context.Context, f scanFile, existing model
// queue file for update
if err := s.withTxn(ctx, func(ctx context.Context) error {
if err := s.Repository.FileStore.Update(ctx, existing); err != nil {
if err := s.Repository.File.Update(ctx, existing); err != nil {
return fmt.Errorf("updating file %q: %w", path, err)
}

View file

@ -157,19 +157,19 @@ var getStudioScenarios = []stringTestScenario{
}
func TestGetStudioName(t *testing.T) {
mockStudioReader := &mocks.StudioReaderWriter{}
db := mocks.NewDatabase()
studioErr := errors.New("error getting image")
mockStudioReader.On("Find", testCtx, studioID).Return(&models.Studio{
db.Studio.On("Find", testCtx, studioID).Return(&models.Studio{
Name: studioName,
}, nil).Once()
mockStudioReader.On("Find", testCtx, missingStudioID).Return(nil, nil).Once()
mockStudioReader.On("Find", testCtx, errStudioID).Return(nil, studioErr).Once()
db.Studio.On("Find", testCtx, missingStudioID).Return(nil, nil).Once()
db.Studio.On("Find", testCtx, errStudioID).Return(nil, studioErr).Once()
for i, s := range getStudioScenarios {
gallery := s.input
json, err := GetStudioName(testCtx, mockStudioReader, &gallery)
json, err := GetStudioName(testCtx, db.Studio, &gallery)
switch {
case !s.err && err != nil:
@ -181,7 +181,7 @@ func TestGetStudioName(t *testing.T) {
}
}
mockStudioReader.AssertExpectations(t)
db.AssertExpectations(t)
}
const (
@ -258,17 +258,17 @@ var validChapters = []*models.GalleryChapter{
}
func TestGetGalleryChaptersJSON(t *testing.T) {
mockChapterReader := &mocks.GalleryChapterReaderWriter{}
db := mocks.NewDatabase()
chaptersErr := errors.New("error getting gallery chapters")
mockChapterReader.On("FindByGalleryID", testCtx, galleryID).Return(validChapters, nil).Once()
mockChapterReader.On("FindByGalleryID", testCtx, noChaptersID).Return(nil, nil).Once()
mockChapterReader.On("FindByGalleryID", testCtx, errChaptersID).Return(nil, chaptersErr).Once()
db.GalleryChapter.On("FindByGalleryID", testCtx, galleryID).Return(validChapters, nil).Once()
db.GalleryChapter.On("FindByGalleryID", testCtx, noChaptersID).Return(nil, nil).Once()
db.GalleryChapter.On("FindByGalleryID", testCtx, errChaptersID).Return(nil, chaptersErr).Once()
for i, s := range getGalleryChaptersJSONScenarios {
gallery := s.input
json, err := GetGalleryChaptersJSON(testCtx, mockChapterReader, &gallery)
json, err := GetGalleryChaptersJSON(testCtx, db.GalleryChapter, &gallery)
switch {
case !s.err && err != nil:
@ -280,4 +280,5 @@ func TestGetGalleryChaptersJSON(t *testing.T) {
}
}
db.AssertExpectations(t)
}

View file

@ -78,19 +78,19 @@ func TestImporterPreImport(t *testing.T) {
}
func TestImporterPreImportWithStudio(t *testing.T) {
studioReaderWriter := &mocks.StudioReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
StudioWriter: studioReaderWriter,
StudioWriter: db.Studio,
Input: jsonschema.Gallery{
Studio: existingStudioName,
},
}
studioReaderWriter.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{
db.Studio.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{
ID: existingStudioID,
}, nil).Once()
studioReaderWriter.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
db.Studio.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
err := i.PreImport(testCtx)
assert.Nil(t, err)
@ -100,22 +100,22 @@ func TestImporterPreImportWithStudio(t *testing.T) {
err = i.PreImport(testCtx)
assert.NotNil(t, err)
studioReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingStudio(t *testing.T) {
studioReaderWriter := &mocks.StudioReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
StudioWriter: studioReaderWriter,
StudioWriter: db.Studio,
Input: jsonschema.Gallery{
Studio: missingStudioName,
},
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.Studio)
s.ID = existingStudioID
}).Return(nil)
@ -132,32 +132,34 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, existingStudioID, *i.gallery.StudioID)
studioReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
studioReaderWriter := &mocks.StudioReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
StudioWriter: studioReaderWriter,
StudioWriter: db.Studio,
Input: jsonschema.Gallery{
Studio: missingStudioName,
},
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
err := i.PreImport(testCtx)
assert.NotNil(t, err)
db.AssertExpectations(t)
}
func TestImporterPreImportWithPerformer(t *testing.T) {
performerReaderWriter := &mocks.PerformerReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
PerformerWriter: performerReaderWriter,
PerformerWriter: db.Performer,
MissingRefBehaviour: models.ImportMissingRefEnumFail,
Input: jsonschema.Gallery{
Performers: []string{
@ -166,13 +168,13 @@ func TestImporterPreImportWithPerformer(t *testing.T) {
},
}
performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{
db.Performer.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{
{
ID: existingPerformerID,
Name: existingPerformerName,
},
}, nil).Once()
performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once()
db.Performer.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once()
err := i.PreImport(testCtx)
assert.Nil(t, err)
@ -182,14 +184,14 @@ func TestImporterPreImportWithPerformer(t *testing.T) {
err = i.PreImport(testCtx)
assert.NotNil(t, err)
performerReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingPerformer(t *testing.T) {
performerReaderWriter := &mocks.PerformerReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
PerformerWriter: performerReaderWriter,
PerformerWriter: db.Performer,
Input: jsonschema.Gallery{
Performers: []string{
missingPerformerName,
@ -198,8 +200,8 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3)
performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) {
db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3)
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) {
performer := args.Get(1).(*models.Performer)
performer.ID = existingPerformerID
}).Return(nil)
@ -216,14 +218,14 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, []int{existingPerformerID}, i.gallery.PerformerIDs.List())
performerReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) {
performerReaderWriter := &mocks.PerformerReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
PerformerWriter: performerReaderWriter,
PerformerWriter: db.Performer,
Input: jsonschema.Gallery{
Performers: []string{
missingPerformerName,
@ -232,18 +234,20 @@ func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once()
performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error"))
db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once()
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error"))
err := i.PreImport(testCtx)
assert.NotNil(t, err)
db.AssertExpectations(t)
}
func TestImporterPreImportWithTag(t *testing.T) {
tagReaderWriter := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
TagWriter: tagReaderWriter,
TagWriter: db.Tag,
MissingRefBehaviour: models.ImportMissingRefEnumFail,
Input: jsonschema.Gallery{
Tags: []string{
@ -252,13 +256,13 @@ func TestImporterPreImportWithTag(t *testing.T) {
},
}
tagReaderWriter.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{
db.Tag.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{
{
ID: existingTagID,
Name: existingTagName,
},
}, nil).Once()
tagReaderWriter.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
db.Tag.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
err := i.PreImport(testCtx)
assert.Nil(t, err)
@ -268,14 +272,14 @@ func TestImporterPreImportWithTag(t *testing.T) {
err = i.PreImport(testCtx)
assert.NotNil(t, err)
tagReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingTag(t *testing.T) {
tagReaderWriter := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
TagWriter: tagReaderWriter,
TagWriter: db.Tag,
Input: jsonschema.Gallery{
Tags: []string{
missingTagName,
@ -284,8 +288,8 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
t := args.Get(1).(*models.Tag)
t.ID = existingTagID
}).Return(nil)
@ -302,14 +306,14 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, []int{existingTagID}, i.gallery.TagIDs.List())
tagReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
tagReaderWriter := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
TagWriter: tagReaderWriter,
TagWriter: db.Tag,
Input: jsonschema.Gallery{
Tags: []string{
missingTagName,
@ -318,9 +322,11 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error"))
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error"))
err := i.PreImport(testCtx)
assert.NotNil(t, err)
db.AssertExpectations(t)
}

View file

@ -130,19 +130,19 @@ var getStudioScenarios = []stringTestScenario{
}
func TestGetStudioName(t *testing.T) {
mockStudioReader := &mocks.StudioReaderWriter{}
db := mocks.NewDatabase()
studioErr := errors.New("error getting image")
mockStudioReader.On("Find", testCtx, studioID).Return(&models.Studio{
db.Studio.On("Find", testCtx, studioID).Return(&models.Studio{
Name: studioName,
}, nil).Once()
mockStudioReader.On("Find", testCtx, missingStudioID).Return(nil, nil).Once()
mockStudioReader.On("Find", testCtx, errStudioID).Return(nil, studioErr).Once()
db.Studio.On("Find", testCtx, missingStudioID).Return(nil, nil).Once()
db.Studio.On("Find", testCtx, errStudioID).Return(nil, studioErr).Once()
for i, s := range getStudioScenarios {
image := s.input
json, err := GetStudioName(testCtx, mockStudioReader, &image)
json, err := GetStudioName(testCtx, db.Studio, &image)
switch {
case !s.err && err != nil:
@ -154,5 +154,5 @@ func TestGetStudioName(t *testing.T) {
}
}
mockStudioReader.AssertExpectations(t)
db.AssertExpectations(t)
}

View file

@ -40,19 +40,19 @@ func TestImporterPreImport(t *testing.T) {
}
func TestImporterPreImportWithStudio(t *testing.T) {
studioReaderWriter := &mocks.StudioReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
StudioWriter: studioReaderWriter,
StudioWriter: db.Studio,
Input: jsonschema.Image{
Studio: existingStudioName,
},
}
studioReaderWriter.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{
db.Studio.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{
ID: existingStudioID,
}, nil).Once()
studioReaderWriter.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
db.Studio.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
err := i.PreImport(testCtx)
assert.Nil(t, err)
@ -62,22 +62,22 @@ func TestImporterPreImportWithStudio(t *testing.T) {
err = i.PreImport(testCtx)
assert.NotNil(t, err)
studioReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingStudio(t *testing.T) {
studioReaderWriter := &mocks.StudioReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
StudioWriter: studioReaderWriter,
StudioWriter: db.Studio,
Input: jsonschema.Image{
Studio: missingStudioName,
},
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.Studio)
s.ID = existingStudioID
}).Return(nil)
@ -94,32 +94,34 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, existingStudioID, *i.image.StudioID)
studioReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
studioReaderWriter := &mocks.StudioReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
StudioWriter: studioReaderWriter,
StudioWriter: db.Studio,
Input: jsonschema.Image{
Studio: missingStudioName,
},
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
err := i.PreImport(testCtx)
assert.NotNil(t, err)
db.AssertExpectations(t)
}
func TestImporterPreImportWithPerformer(t *testing.T) {
performerReaderWriter := &mocks.PerformerReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
PerformerWriter: performerReaderWriter,
PerformerWriter: db.Performer,
MissingRefBehaviour: models.ImportMissingRefEnumFail,
Input: jsonschema.Image{
Performers: []string{
@ -128,13 +130,13 @@ func TestImporterPreImportWithPerformer(t *testing.T) {
},
}
performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{
db.Performer.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{
{
ID: existingPerformerID,
Name: existingPerformerName,
},
}, nil).Once()
performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once()
db.Performer.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once()
err := i.PreImport(testCtx)
assert.Nil(t, err)
@ -144,14 +146,14 @@ func TestImporterPreImportWithPerformer(t *testing.T) {
err = i.PreImport(testCtx)
assert.NotNil(t, err)
performerReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingPerformer(t *testing.T) {
performerReaderWriter := &mocks.PerformerReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
PerformerWriter: performerReaderWriter,
PerformerWriter: db.Performer,
Input: jsonschema.Image{
Performers: []string{
missingPerformerName,
@ -160,8 +162,8 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3)
performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) {
db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3)
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) {
performer := args.Get(1).(*models.Performer)
performer.ID = existingPerformerID
}).Return(nil)
@ -178,14 +180,14 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, []int{existingPerformerID}, i.image.PerformerIDs.List())
performerReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) {
performerReaderWriter := &mocks.PerformerReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
PerformerWriter: performerReaderWriter,
PerformerWriter: db.Performer,
Input: jsonschema.Image{
Performers: []string{
missingPerformerName,
@ -194,18 +196,20 @@ func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once()
performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error"))
db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once()
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error"))
err := i.PreImport(testCtx)
assert.NotNil(t, err)
db.AssertExpectations(t)
}
func TestImporterPreImportWithTag(t *testing.T) {
tagReaderWriter := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
TagWriter: tagReaderWriter,
TagWriter: db.Tag,
MissingRefBehaviour: models.ImportMissingRefEnumFail,
Input: jsonschema.Image{
Tags: []string{
@ -214,13 +218,13 @@ func TestImporterPreImportWithTag(t *testing.T) {
},
}
tagReaderWriter.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{
db.Tag.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{
{
ID: existingTagID,
Name: existingTagName,
},
}, nil).Once()
tagReaderWriter.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
db.Tag.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
err := i.PreImport(testCtx)
assert.Nil(t, err)
@ -230,14 +234,14 @@ func TestImporterPreImportWithTag(t *testing.T) {
err = i.PreImport(testCtx)
assert.NotNil(t, err)
tagReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingTag(t *testing.T) {
tagReaderWriter := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
TagWriter: tagReaderWriter,
TagWriter: db.Tag,
Input: jsonschema.Image{
Tags: []string{
missingTagName,
@ -246,8 +250,8 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
t := args.Get(1).(*models.Tag)
t.ID = existingTagID
}).Return(nil)
@ -264,14 +268,14 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, []int{existingTagID}, i.image.TagIDs.List())
tagReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
tagReaderWriter := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
TagWriter: tagReaderWriter,
TagWriter: db.Tag,
Input: jsonschema.Image{
Tags: []string{
missingTagName,
@ -280,9 +284,11 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error"))
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error"))
err := i.PreImport(testCtx)
assert.NotNil(t, err)
db.AssertExpectations(t)
}

View file

@ -0,0 +1,107 @@
package mocks
import (
"context"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
"github.com/stretchr/testify/mock"
)
// Database is a mock database for use in tests.
// It bundles a testify mock for every repository type, and itself acts
// as a no-op transaction manager (see Begin/Commit/etc. below), so it can
// stand in for the real database in code under test.
type Database struct {
	File           *FileReaderWriter
	Folder         *FolderReaderWriter
	Gallery        *GalleryReaderWriter
	GalleryChapter *GalleryChapterReaderWriter
	Image          *ImageReaderWriter
	Movie          *MovieReaderWriter
	Performer      *PerformerReaderWriter
	Scene          *SceneReaderWriter
	SceneMarker    *SceneMarkerReaderWriter
	Studio         *StudioReaderWriter
	Tag            *TagReaderWriter
	SavedFilter    *SavedFilterReaderWriter
}
// Begin is a no-op: it returns ctx unchanged, so tests can "open"
// transactions without a real database. The exclusive flag is ignored.
func (*Database) Begin(ctx context.Context, exclusive bool) (context.Context, error) {
	return ctx, nil
}
// WithDatabase is a no-op: it returns ctx unchanged.
func (*Database) WithDatabase(ctx context.Context) (context.Context, error) {
	return ctx, nil
}
// Commit is a no-op that always succeeds.
func (*Database) Commit(ctx context.Context) error {
	return nil
}
// Rollback is a no-op that always succeeds.
func (*Database) Rollback(ctx context.Context) error {
	return nil
}
// Complete is a no-op.
func (*Database) Complete(ctx context.Context) {
}
// AddPostCommitHook is a no-op: hooks registered here are never invoked.
func (*Database) AddPostCommitHook(ctx context.Context, hook txn.TxnFunc) {
}
// AddPostRollbackHook is a no-op: hooks registered here are never invoked.
func (*Database) AddPostRollbackHook(ctx context.Context, hook txn.TxnFunc) {
}
// IsLocked always reports false: the mock database can never be locked.
func (*Database) IsLocked(err error) bool {
	return false
}
// Reset is a no-op that always succeeds.
func (*Database) Reset() error {
	return nil
}
// NewDatabase returns a Database with a fresh, empty mock for every
// repository field, ready to have expectations set on it.
func NewDatabase() *Database {
	db := &Database{}
	db.File = &FileReaderWriter{}
	db.Folder = &FolderReaderWriter{}
	db.Gallery = &GalleryReaderWriter{}
	db.GalleryChapter = &GalleryChapterReaderWriter{}
	db.Image = &ImageReaderWriter{}
	db.Movie = &MovieReaderWriter{}
	db.Performer = &PerformerReaderWriter{}
	db.Scene = &SceneReaderWriter{}
	db.SceneMarker = &SceneMarkerReaderWriter{}
	db.Studio = &StudioReaderWriter{}
	db.Tag = &TagReaderWriter{}
	db.SavedFilter = &SavedFilterReaderWriter{}
	return db
}
// AssertExpectations asserts, via t, that every expectation set on each
// of the repository mocks has been met.
func (db *Database) AssertExpectations(t mock.TestingT) {
	// Each generated mock embeds mock.Mock, so all expose the same
	// AssertExpectations method; drive them through one loop.
	all := []interface {
		AssertExpectations(mock.TestingT) bool
	}{
		db.File, db.Folder, db.Gallery, db.GalleryChapter,
		db.Image, db.Movie, db.Performer, db.Scene,
		db.SceneMarker, db.Studio, db.Tag, db.SavedFilter,
	}
	for _, m := range all {
		m.AssertExpectations(t)
	}
}
// Repository returns a models.Repository backed by the mock database,
// with db itself serving as the (no-op) transaction manager.
func (db *Database) Repository() models.Repository {
	return models.Repository{
		TxnManager:     db,
		File:           db.File,
		Folder:         db.Folder,
		Gallery:        db.Gallery,
		GalleryChapter: db.GalleryChapter,
		Image:          db.Image,
		Movie:          db.Movie,
		Performer:      db.Performer,
		Scene:          db.Scene,
		SceneMarker:    db.SceneMarker,
		Studio:         db.Studio,
		Tag:            db.Tag,
		SavedFilter:    db.SavedFilter,
	}
}

View file

@ -1,59 +0,0 @@
package mocks
import (
context "context"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
)
// TxnManager is a stateless no-op transaction manager for tests.
type TxnManager struct{}
// Begin is a no-op: it returns ctx unchanged; exclusive is ignored.
func (*TxnManager) Begin(ctx context.Context, exclusive bool) (context.Context, error) {
	return ctx, nil
}
// WithDatabase is a no-op: it returns ctx unchanged.
func (*TxnManager) WithDatabase(ctx context.Context) (context.Context, error) {
	return ctx, nil
}
// Commit is a no-op that always succeeds.
func (*TxnManager) Commit(ctx context.Context) error {
	return nil
}
// Rollback is a no-op that always succeeds.
func (*TxnManager) Rollback(ctx context.Context) error {
	return nil
}
// Complete is a no-op.
func (*TxnManager) Complete(ctx context.Context) {
}
// AddPostCommitHook is a no-op: hooks registered here are never invoked.
func (*TxnManager) AddPostCommitHook(ctx context.Context, hook txn.TxnFunc) {
}
// AddPostRollbackHook is a no-op: hooks registered here are never invoked.
func (*TxnManager) AddPostRollbackHook(ctx context.Context, hook txn.TxnFunc) {
}
// IsLocked always reports false: the no-op manager can never be locked.
func (*TxnManager) IsLocked(err error) bool {
	return false
}
// Reset is a no-op that always succeeds.
func (*TxnManager) Reset() error {
	return nil
}
// NewTxnRepository returns a models.Repository whose transaction manager
// is a no-op and whose repositories are fresh, empty mocks.
// NOTE(review): the File and Folder fields are not populated here and are
// left as their zero values — confirm callers do not rely on them.
func NewTxnRepository() models.Repository {
	return models.Repository{
		TxnManager:     &TxnManager{},
		Gallery:        &GalleryReaderWriter{},
		GalleryChapter: &GalleryChapterReaderWriter{},
		Image:          &ImageReaderWriter{},
		Movie:          &MovieReaderWriter{},
		Performer:      &PerformerReaderWriter{},
		Scene:          &SceneReaderWriter{},
		SceneMarker:    &SceneMarkerReaderWriter{},
		Studio:         &StudioReaderWriter{},
		Tag:            &TagReaderWriter{},
		SavedFilter:    &SavedFilterReaderWriter{},
	}
}

View file

@ -49,5 +49,3 @@ func NewMoviePartial() MoviePartial {
UpdatedAt: NewOptionalTime(currentTime),
}
}
var DefaultMovieImage = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAYAAABw4pVUAAAABmJLR0QA/wD/AP+gvaeTAAAACXBIWXMAAA3XAAAN1wFCKJt4AAAAB3RJTUUH4wgVBQsJl1CMZAAAASJJREFUeNrt3N0JwyAYhlEj3cj9R3Cm5rbkqtAP+qrnGaCYHPwJpLlaa++mmLpbAERAgAgIEAEBIiBABERAgAgIEAEBIiBABERAgAgIEAHZuVflj40x4i94zhk9vqsVvEq6AsQqMP1EjORx20OACAgQRRx7T+zzcFBxcjNDfoB4ntQqTm5Awo7MlqywZxcgYQ+RlqywJ3ozJAQCSBiEJSsQA0gYBpDAgAARECACAkRAgAgIEAERECACAmSjUv6eAOSB8m8YIGGzBUjYbAESBgMkbBkDEjZbgITBAClcxiqQvEoatreYIWEBASIgJ4Gkf11ntXH3nS9uxfGWfJ5J9hAgAgJEQAQEiIAAERAgAgJEQAQEiIAAERAgAgJEQAQEiL7qBuc6RKLHxr0CAAAAAElFTkSuQmCC"

View file

@ -1,17 +1,18 @@
package models
import (
"context"
"github.com/stashapp/stash/pkg/txn"
)
type TxnManager interface {
txn.Manager
txn.DatabaseProvider
Reset() error
}
type Repository struct {
TxnManager
TxnManager TxnManager
File FileReaderWriter
Folder FolderReaderWriter
@ -26,3 +27,15 @@ type Repository struct {
Tag TagReaderWriter
SavedFilter SavedFilterReaderWriter
}
// WithTxn executes fn within a transaction opened on r.TxnManager,
// delegating begin/commit/rollback handling to txn.WithTxn.
func (r *Repository) WithTxn(ctx context.Context, fn txn.TxnFunc) error {
	return txn.WithTxn(ctx, r.TxnManager, fn)
}
// WithReadTxn executes fn within a read transaction on r.TxnManager,
// delegating to txn.WithReadTxn.
func (r *Repository) WithReadTxn(ctx context.Context, fn txn.TxnFunc) error {
	return txn.WithReadTxn(ctx, r.TxnManager, fn)
}
// WithDB executes fn with a database-enabled context from r.TxnManager
// (no transaction), delegating to txn.WithDatabase.
func (r *Repository) WithDB(ctx context.Context, fn txn.TxnFunc) error {
	return txn.WithDatabase(ctx, r.TxnManager, fn)
}

View file

@ -168,34 +168,32 @@ func initTestTable() {
func TestToJSON(t *testing.T) {
initTestTable()
mockMovieReader := &mocks.MovieReaderWriter{}
db := mocks.NewDatabase()
imageErr := errors.New("error getting image")
mockMovieReader.On("GetFrontImage", testCtx, movieID).Return(frontImageBytes, nil).Once()
mockMovieReader.On("GetFrontImage", testCtx, missingStudioMovieID).Return(frontImageBytes, nil).Once()
mockMovieReader.On("GetFrontImage", testCtx, emptyID).Return(nil, nil).Once().Maybe()
mockMovieReader.On("GetFrontImage", testCtx, errFrontImageID).Return(nil, imageErr).Once()
mockMovieReader.On("GetFrontImage", testCtx, errBackImageID).Return(frontImageBytes, nil).Once()
db.Movie.On("GetFrontImage", testCtx, movieID).Return(frontImageBytes, nil).Once()
db.Movie.On("GetFrontImage", testCtx, missingStudioMovieID).Return(frontImageBytes, nil).Once()
db.Movie.On("GetFrontImage", testCtx, emptyID).Return(nil, nil).Once().Maybe()
db.Movie.On("GetFrontImage", testCtx, errFrontImageID).Return(nil, imageErr).Once()
db.Movie.On("GetFrontImage", testCtx, errBackImageID).Return(frontImageBytes, nil).Once()
mockMovieReader.On("GetBackImage", testCtx, movieID).Return(backImageBytes, nil).Once()
mockMovieReader.On("GetBackImage", testCtx, missingStudioMovieID).Return(backImageBytes, nil).Once()
mockMovieReader.On("GetBackImage", testCtx, emptyID).Return(nil, nil).Once()
mockMovieReader.On("GetBackImage", testCtx, errBackImageID).Return(nil, imageErr).Once()
mockMovieReader.On("GetBackImage", testCtx, errFrontImageID).Return(backImageBytes, nil).Maybe()
mockMovieReader.On("GetBackImage", testCtx, errStudioMovieID).Return(backImageBytes, nil).Maybe()
mockStudioReader := &mocks.StudioReaderWriter{}
db.Movie.On("GetBackImage", testCtx, movieID).Return(backImageBytes, nil).Once()
db.Movie.On("GetBackImage", testCtx, missingStudioMovieID).Return(backImageBytes, nil).Once()
db.Movie.On("GetBackImage", testCtx, emptyID).Return(nil, nil).Once()
db.Movie.On("GetBackImage", testCtx, errBackImageID).Return(nil, imageErr).Once()
db.Movie.On("GetBackImage", testCtx, errFrontImageID).Return(backImageBytes, nil).Maybe()
db.Movie.On("GetBackImage", testCtx, errStudioMovieID).Return(backImageBytes, nil).Maybe()
studioErr := errors.New("error getting studio")
mockStudioReader.On("Find", testCtx, studioID).Return(&movieStudio, nil)
mockStudioReader.On("Find", testCtx, missingStudioID).Return(nil, nil)
mockStudioReader.On("Find", testCtx, errStudioID).Return(nil, studioErr)
db.Studio.On("Find", testCtx, studioID).Return(&movieStudio, nil)
db.Studio.On("Find", testCtx, missingStudioID).Return(nil, nil)
db.Studio.On("Find", testCtx, errStudioID).Return(nil, studioErr)
for i, s := range scenarios {
movie := s.movie
json, err := ToJSON(testCtx, mockMovieReader, mockStudioReader, &movie)
json, err := ToJSON(testCtx, db.Movie, db.Studio, &movie)
switch {
case !s.err && err != nil:
@ -207,6 +205,5 @@ func TestToJSON(t *testing.T) {
}
}
mockMovieReader.AssertExpectations(t)
mockStudioReader.AssertExpectations(t)
db.AssertExpectations(t)
}

View file

@ -69,10 +69,11 @@ func TestImporterPreImport(t *testing.T) {
}
func TestImporterPreImportWithStudio(t *testing.T) {
studioReaderWriter := &mocks.StudioReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
StudioWriter: studioReaderWriter,
ReaderWriter: db.Movie,
StudioWriter: db.Studio,
Input: jsonschema.Movie{
Name: movieName,
FrontImage: frontImage,
@ -82,10 +83,10 @@ func TestImporterPreImportWithStudio(t *testing.T) {
},
}
studioReaderWriter.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{
db.Studio.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{
ID: existingStudioID,
}, nil).Once()
studioReaderWriter.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
db.Studio.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
err := i.PreImport(testCtx)
assert.Nil(t, err)
@ -95,14 +96,15 @@ func TestImporterPreImportWithStudio(t *testing.T) {
err = i.PreImport(testCtx)
assert.NotNil(t, err)
studioReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingStudio(t *testing.T) {
studioReaderWriter := &mocks.StudioReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
StudioWriter: studioReaderWriter,
ReaderWriter: db.Movie,
StudioWriter: db.Studio,
Input: jsonschema.Movie{
Name: movieName,
FrontImage: frontImage,
@ -111,8 +113,8 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.Studio)
s.ID = existingStudioID
}).Return(nil)
@ -129,14 +131,15 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, existingStudioID, *i.movie.StudioID)
studioReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
studioReaderWriter := &mocks.StudioReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
StudioWriter: studioReaderWriter,
ReaderWriter: db.Movie,
StudioWriter: db.Studio,
Input: jsonschema.Movie{
Name: movieName,
FrontImage: frontImage,
@ -145,27 +148,30 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
err := i.PreImport(testCtx)
assert.NotNil(t, err)
db.AssertExpectations(t)
}
func TestImporterPostImport(t *testing.T) {
readerWriter := &mocks.MovieReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Movie,
StudioWriter: db.Studio,
frontImageData: frontImageBytes,
backImageData: backImageBytes,
}
updateMovieImageErr := errors.New("UpdateImages error")
readerWriter.On("UpdateFrontImage", testCtx, movieID, frontImageBytes).Return(nil).Once()
readerWriter.On("UpdateBackImage", testCtx, movieID, backImageBytes).Return(nil).Once()
readerWriter.On("UpdateFrontImage", testCtx, errImageID, frontImageBytes).Return(updateMovieImageErr).Once()
db.Movie.On("UpdateFrontImage", testCtx, movieID, frontImageBytes).Return(nil).Once()
db.Movie.On("UpdateBackImage", testCtx, movieID, backImageBytes).Return(nil).Once()
db.Movie.On("UpdateFrontImage", testCtx, errImageID, frontImageBytes).Return(updateMovieImageErr).Once()
err := i.PostImport(testCtx, movieID)
assert.Nil(t, err)
@ -173,25 +179,26 @@ func TestImporterPostImport(t *testing.T) {
err = i.PostImport(testCtx, errImageID)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterFindExistingID(t *testing.T) {
readerWriter := &mocks.MovieReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Movie,
StudioWriter: db.Studio,
Input: jsonschema.Movie{
Name: movieName,
},
}
errFindByName := errors.New("FindByName error")
readerWriter.On("FindByName", testCtx, movieName, false).Return(nil, nil).Once()
readerWriter.On("FindByName", testCtx, existingMovieName, false).Return(&models.Movie{
db.Movie.On("FindByName", testCtx, movieName, false).Return(nil, nil).Once()
db.Movie.On("FindByName", testCtx, existingMovieName, false).Return(&models.Movie{
ID: existingMovieID,
}, nil).Once()
readerWriter.On("FindByName", testCtx, movieNameErr, false).Return(nil, errFindByName).Once()
db.Movie.On("FindByName", testCtx, movieNameErr, false).Return(nil, errFindByName).Once()
id, err := i.FindExistingID(testCtx)
assert.Nil(t, id)
@ -207,11 +214,11 @@ func TestImporterFindExistingID(t *testing.T) {
assert.Nil(t, id)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestCreate(t *testing.T) {
readerWriter := &mocks.MovieReaderWriter{}
db := mocks.NewDatabase()
movie := models.Movie{
Name: movieName,
@ -222,16 +229,17 @@ func TestCreate(t *testing.T) {
}
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Movie,
StudioWriter: db.Studio,
movie: movie,
}
errCreate := errors.New("Create error")
readerWriter.On("Create", testCtx, &movie).Run(func(args mock.Arguments) {
db.Movie.On("Create", testCtx, &movie).Run(func(args mock.Arguments) {
m := args.Get(1).(*models.Movie)
m.ID = movieID
}).Return(nil).Once()
readerWriter.On("Create", testCtx, &movieErr).Return(errCreate).Once()
db.Movie.On("Create", testCtx, &movieErr).Return(errCreate).Once()
id, err := i.Create(testCtx)
assert.Equal(t, movieID, *id)
@ -242,11 +250,11 @@ func TestCreate(t *testing.T) {
assert.Nil(t, id)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestUpdate(t *testing.T) {
readerWriter := &mocks.MovieReaderWriter{}
db := mocks.NewDatabase()
movie := models.Movie{
Name: movieName,
@ -257,7 +265,8 @@ func TestUpdate(t *testing.T) {
}
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Movie,
StudioWriter: db.Studio,
movie: movie,
}
@ -265,7 +274,7 @@ func TestUpdate(t *testing.T) {
// id needs to be set for the mock input
movie.ID = movieID
readerWriter.On("Update", testCtx, &movie).Return(nil).Once()
db.Movie.On("Update", testCtx, &movie).Return(nil).Once()
err := i.Update(testCtx, movieID)
assert.Nil(t, err)
@ -274,10 +283,10 @@ func TestUpdate(t *testing.T) {
// need to set id separately
movieErr.ID = errImageID
readerWriter.On("Update", testCtx, &movieErr).Return(errUpdate).Once()
db.Movie.On("Update", testCtx, &movieErr).Return(errUpdate).Once()
err = i.Update(testCtx, errImageID)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
db.AssertExpectations(t)
}

View file

@ -203,17 +203,17 @@ func initTestTable() {
func TestToJSON(t *testing.T) {
initTestTable()
mockPerformerReader := &mocks.PerformerReaderWriter{}
db := mocks.NewDatabase()
imageErr := errors.New("error getting image")
mockPerformerReader.On("GetImage", testCtx, performerID).Return(imageBytes, nil).Once()
mockPerformerReader.On("GetImage", testCtx, noImageID).Return(nil, nil).Once()
mockPerformerReader.On("GetImage", testCtx, errImageID).Return(nil, imageErr).Once()
db.Performer.On("GetImage", testCtx, performerID).Return(imageBytes, nil).Once()
db.Performer.On("GetImage", testCtx, noImageID).Return(nil, nil).Once()
db.Performer.On("GetImage", testCtx, errImageID).Return(nil, imageErr).Once()
for i, s := range scenarios {
tag := s.input
json, err := ToJSON(testCtx, mockPerformerReader, &tag)
json, err := ToJSON(testCtx, db.Performer, &tag)
switch {
case !s.err && err != nil:
@ -225,5 +225,5 @@ func TestToJSON(t *testing.T) {
}
}
mockPerformerReader.AssertExpectations(t)
db.AssertExpectations(t)
}

View file

@ -63,10 +63,11 @@ func TestImporterPreImport(t *testing.T) {
}
func TestImporterPreImportWithTag(t *testing.T) {
tagReaderWriter := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
TagWriter: tagReaderWriter,
ReaderWriter: db.Performer,
TagWriter: db.Tag,
MissingRefBehaviour: models.ImportMissingRefEnumFail,
Input: jsonschema.Performer{
Tags: []string{
@ -75,13 +76,13 @@ func TestImporterPreImportWithTag(t *testing.T) {
},
}
tagReaderWriter.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{
db.Tag.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{
{
ID: existingTagID,
Name: existingTagName,
},
}, nil).Once()
tagReaderWriter.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
db.Tag.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
err := i.PreImport(testCtx)
assert.Nil(t, err)
@ -91,14 +92,15 @@ func TestImporterPreImportWithTag(t *testing.T) {
err = i.PreImport(testCtx)
assert.NotNil(t, err)
tagReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingTag(t *testing.T) {
tagReaderWriter := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
TagWriter: tagReaderWriter,
ReaderWriter: db.Performer,
TagWriter: db.Tag,
Input: jsonschema.Performer{
Tags: []string{
missingTagName,
@ -107,8 +109,8 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
t := args.Get(1).(*models.Tag)
t.ID = existingTagID
}).Return(nil)
@ -125,14 +127,15 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, existingTagID, i.performer.TagIDs.List()[0])
tagReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
tagReaderWriter := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
TagWriter: tagReaderWriter,
ReaderWriter: db.Performer,
TagWriter: db.Tag,
Input: jsonschema.Performer{
Tags: []string{
missingTagName,
@ -141,25 +144,28 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error"))
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error"))
err := i.PreImport(testCtx)
assert.NotNil(t, err)
db.AssertExpectations(t)
}
func TestImporterPostImport(t *testing.T) {
readerWriter := &mocks.PerformerReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Performer,
TagWriter: db.Tag,
imageData: imageBytes,
}
updatePerformerImageErr := errors.New("UpdateImage error")
readerWriter.On("UpdateImage", testCtx, performerID, imageBytes).Return(nil).Once()
readerWriter.On("UpdateImage", testCtx, errImageID, imageBytes).Return(updatePerformerImageErr).Once()
db.Performer.On("UpdateImage", testCtx, performerID, imageBytes).Return(nil).Once()
db.Performer.On("UpdateImage", testCtx, errImageID, imageBytes).Return(updatePerformerImageErr).Once()
err := i.PostImport(testCtx, performerID)
assert.Nil(t, err)
@ -167,14 +173,15 @@ func TestImporterPostImport(t *testing.T) {
err = i.PostImport(testCtx, errImageID)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterFindExistingID(t *testing.T) {
readerWriter := &mocks.PerformerReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Performer,
TagWriter: db.Tag,
Input: jsonschema.Performer{
Name: performerName,
},
@ -195,13 +202,13 @@ func TestImporterFindExistingID(t *testing.T) {
}
errFindByNames := errors.New("FindByNames error")
readerWriter.On("Query", testCtx, performerFilter(performerName), findFilter).Return(nil, 0, nil).Once()
readerWriter.On("Query", testCtx, performerFilter(existingPerformerName), findFilter).Return([]*models.Performer{
db.Performer.On("Query", testCtx, performerFilter(performerName), findFilter).Return(nil, 0, nil).Once()
db.Performer.On("Query", testCtx, performerFilter(existingPerformerName), findFilter).Return([]*models.Performer{
{
ID: existingPerformerID,
},
}, 1, nil).Once()
readerWriter.On("Query", testCtx, performerFilter(performerNameErr), findFilter).Return(nil, 0, errFindByNames).Once()
db.Performer.On("Query", testCtx, performerFilter(performerNameErr), findFilter).Return(nil, 0, errFindByNames).Once()
id, err := i.FindExistingID(testCtx)
assert.Nil(t, id)
@ -217,11 +224,11 @@ func TestImporterFindExistingID(t *testing.T) {
assert.Nil(t, id)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestCreate(t *testing.T) {
readerWriter := &mocks.PerformerReaderWriter{}
db := mocks.NewDatabase()
performer := models.Performer{
Name: performerName,
@ -232,16 +239,17 @@ func TestCreate(t *testing.T) {
}
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Performer,
TagWriter: db.Tag,
performer: performer,
}
errCreate := errors.New("Create error")
readerWriter.On("Create", testCtx, &performer).Run(func(args mock.Arguments) {
db.Performer.On("Create", testCtx, &performer).Run(func(args mock.Arguments) {
arg := args.Get(1).(*models.Performer)
arg.ID = performerID
}).Return(nil).Once()
readerWriter.On("Create", testCtx, &performerErr).Return(errCreate).Once()
db.Performer.On("Create", testCtx, &performerErr).Return(errCreate).Once()
id, err := i.Create(testCtx)
assert.Equal(t, performerID, *id)
@ -252,11 +260,11 @@ func TestCreate(t *testing.T) {
assert.Nil(t, id)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestUpdate(t *testing.T) {
readerWriter := &mocks.PerformerReaderWriter{}
db := mocks.NewDatabase()
performer := models.Performer{
Name: performerName,
@ -267,7 +275,8 @@ func TestUpdate(t *testing.T) {
}
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Performer,
TagWriter: db.Tag,
performer: performer,
}
@ -275,7 +284,7 @@ func TestUpdate(t *testing.T) {
// id needs to be set for the mock input
performer.ID = performerID
readerWriter.On("Update", testCtx, &performer).Return(nil).Once()
db.Performer.On("Update", testCtx, &performer).Return(nil).Once()
err := i.Update(testCtx, performerID)
assert.Nil(t, err)
@ -284,10 +293,10 @@ func TestUpdate(t *testing.T) {
// need to set id separately
performerErr.ID = errImageID
readerWriter.On("Update", testCtx, &performerErr).Return(errUpdate).Once()
db.Performer.On("Update", testCtx, &performerErr).Return(errUpdate).Once()
err = i.Update(testCtx, errImageID)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
db.AssertExpectations(t)
}

View file

@ -186,17 +186,17 @@ var scenarios = []basicTestScenario{
}
func TestToJSON(t *testing.T) {
mockSceneReader := &mocks.SceneReaderWriter{}
db := mocks.NewDatabase()
imageErr := errors.New("error getting image")
mockSceneReader.On("GetCover", testCtx, sceneID).Return(imageBytes, nil).Once()
mockSceneReader.On("GetCover", testCtx, noImageID).Return(nil, nil).Once()
mockSceneReader.On("GetCover", testCtx, errImageID).Return(nil, imageErr).Once()
db.Scene.On("GetCover", testCtx, sceneID).Return(imageBytes, nil).Once()
db.Scene.On("GetCover", testCtx, noImageID).Return(nil, nil).Once()
db.Scene.On("GetCover", testCtx, errImageID).Return(nil, imageErr).Once()
for i, s := range scenarios {
scene := s.input
json, err := ToBasicJSON(testCtx, mockSceneReader, &scene)
json, err := ToBasicJSON(testCtx, db.Scene, &scene)
switch {
case !s.err && err != nil:
@ -208,7 +208,7 @@ func TestToJSON(t *testing.T) {
}
}
mockSceneReader.AssertExpectations(t)
db.AssertExpectations(t)
}
func createStudioScene(studioID int) models.Scene {
@ -242,19 +242,19 @@ var getStudioScenarios = []stringTestScenario{
}
func TestGetStudioName(t *testing.T) {
mockStudioReader := &mocks.StudioReaderWriter{}
db := mocks.NewDatabase()
studioErr := errors.New("error getting image")
mockStudioReader.On("Find", testCtx, studioID).Return(&models.Studio{
db.Studio.On("Find", testCtx, studioID).Return(&models.Studio{
Name: studioName,
}, nil).Once()
mockStudioReader.On("Find", testCtx, missingStudioID).Return(nil, nil).Once()
mockStudioReader.On("Find", testCtx, errStudioID).Return(nil, studioErr).Once()
db.Studio.On("Find", testCtx, missingStudioID).Return(nil, nil).Once()
db.Studio.On("Find", testCtx, errStudioID).Return(nil, studioErr).Once()
for i, s := range getStudioScenarios {
scene := s.input
json, err := GetStudioName(testCtx, mockStudioReader, &scene)
json, err := GetStudioName(testCtx, db.Studio, &scene)
switch {
case !s.err && err != nil:
@ -266,7 +266,7 @@ func TestGetStudioName(t *testing.T) {
}
}
mockStudioReader.AssertExpectations(t)
db.AssertExpectations(t)
}
type stringSliceTestScenario struct {
@ -305,17 +305,17 @@ func getTags(names []string) []*models.Tag {
}
func TestGetTagNames(t *testing.T) {
mockTagReader := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
tagErr := errors.New("error getting tag")
mockTagReader.On("FindBySceneID", testCtx, sceneID).Return(getTags(names), nil).Once()
mockTagReader.On("FindBySceneID", testCtx, noTagsID).Return(nil, nil).Once()
mockTagReader.On("FindBySceneID", testCtx, errTagsID).Return(nil, tagErr).Once()
db.Tag.On("FindBySceneID", testCtx, sceneID).Return(getTags(names), nil).Once()
db.Tag.On("FindBySceneID", testCtx, noTagsID).Return(nil, nil).Once()
db.Tag.On("FindBySceneID", testCtx, errTagsID).Return(nil, tagErr).Once()
for i, s := range getTagNamesScenarios {
scene := s.input
json, err := GetTagNames(testCtx, mockTagReader, &scene)
json, err := GetTagNames(testCtx, db.Tag, &scene)
switch {
case !s.err && err != nil:
@ -327,7 +327,7 @@ func TestGetTagNames(t *testing.T) {
}
}
mockTagReader.AssertExpectations(t)
db.AssertExpectations(t)
}
type sceneMoviesTestScenario struct {
@ -391,20 +391,21 @@ var getSceneMoviesJSONScenarios = []sceneMoviesTestScenario{
}
func TestGetSceneMoviesJSON(t *testing.T) {
mockMovieReader := &mocks.MovieReaderWriter{}
db := mocks.NewDatabase()
movieErr := errors.New("error getting movie")
mockMovieReader.On("Find", testCtx, validMovie1).Return(&models.Movie{
db.Movie.On("Find", testCtx, validMovie1).Return(&models.Movie{
Name: movie1Name,
}, nil).Once()
mockMovieReader.On("Find", testCtx, validMovie2).Return(&models.Movie{
db.Movie.On("Find", testCtx, validMovie2).Return(&models.Movie{
Name: movie2Name,
}, nil).Once()
mockMovieReader.On("Find", testCtx, invalidMovie).Return(nil, movieErr).Once()
db.Movie.On("Find", testCtx, invalidMovie).Return(nil, movieErr).Once()
for i, s := range getSceneMoviesJSONScenarios {
scene := s.input
json, err := GetSceneMoviesJSON(testCtx, mockMovieReader, &scene)
json, err := GetSceneMoviesJSON(testCtx, db.Movie, &scene)
switch {
case !s.err && err != nil:
@ -416,7 +417,7 @@ func TestGetSceneMoviesJSON(t *testing.T) {
}
}
mockMovieReader.AssertExpectations(t)
db.AssertExpectations(t)
}
const (
@ -542,27 +543,26 @@ var invalidMarkers2 = []*models.SceneMarker{
}
func TestGetSceneMarkersJSON(t *testing.T) {
mockTagReader := &mocks.TagReaderWriter{}
mockMarkerReader := &mocks.SceneMarkerReaderWriter{}
db := mocks.NewDatabase()
markersErr := errors.New("error getting scene markers")
tagErr := errors.New("error getting tags")
mockMarkerReader.On("FindBySceneID", testCtx, sceneID).Return(validMarkers, nil).Once()
mockMarkerReader.On("FindBySceneID", testCtx, noMarkersID).Return(nil, nil).Once()
mockMarkerReader.On("FindBySceneID", testCtx, errMarkersID).Return(nil, markersErr).Once()
mockMarkerReader.On("FindBySceneID", testCtx, errFindPrimaryTagID).Return(invalidMarkers1, nil).Once()
mockMarkerReader.On("FindBySceneID", testCtx, errFindByMarkerID).Return(invalidMarkers2, nil).Once()
db.SceneMarker.On("FindBySceneID", testCtx, sceneID).Return(validMarkers, nil).Once()
db.SceneMarker.On("FindBySceneID", testCtx, noMarkersID).Return(nil, nil).Once()
db.SceneMarker.On("FindBySceneID", testCtx, errMarkersID).Return(nil, markersErr).Once()
db.SceneMarker.On("FindBySceneID", testCtx, errFindPrimaryTagID).Return(invalidMarkers1, nil).Once()
db.SceneMarker.On("FindBySceneID", testCtx, errFindByMarkerID).Return(invalidMarkers2, nil).Once()
mockTagReader.On("Find", testCtx, validTagID1).Return(&models.Tag{
db.Tag.On("Find", testCtx, validTagID1).Return(&models.Tag{
Name: validTagName1,
}, nil)
mockTagReader.On("Find", testCtx, validTagID2).Return(&models.Tag{
db.Tag.On("Find", testCtx, validTagID2).Return(&models.Tag{
Name: validTagName2,
}, nil)
mockTagReader.On("Find", testCtx, invalidTagID).Return(nil, tagErr)
db.Tag.On("Find", testCtx, invalidTagID).Return(nil, tagErr)
mockTagReader.On("FindBySceneMarkerID", testCtx, validMarkerID1).Return([]*models.Tag{
db.Tag.On("FindBySceneMarkerID", testCtx, validMarkerID1).Return([]*models.Tag{
{
Name: validTagName1,
},
@ -570,16 +570,16 @@ func TestGetSceneMarkersJSON(t *testing.T) {
Name: validTagName2,
},
}, nil)
mockTagReader.On("FindBySceneMarkerID", testCtx, validMarkerID2).Return([]*models.Tag{
db.Tag.On("FindBySceneMarkerID", testCtx, validMarkerID2).Return([]*models.Tag{
{
Name: validTagName2,
},
}, nil)
mockTagReader.On("FindBySceneMarkerID", testCtx, invalidMarkerID2).Return(nil, tagErr).Once()
db.Tag.On("FindBySceneMarkerID", testCtx, invalidMarkerID2).Return(nil, tagErr).Once()
for i, s := range getSceneMarkersJSONScenarios {
scene := s.input
json, err := GetSceneMarkersJSON(testCtx, mockMarkerReader, mockTagReader, &scene)
json, err := GetSceneMarkersJSON(testCtx, db.SceneMarker, db.Tag, &scene)
switch {
case !s.err && err != nil:
@ -591,5 +591,5 @@ func TestGetSceneMarkersJSON(t *testing.T) {
}
}
mockTagReader.AssertExpectations(t)
db.AssertExpectations(t)
}

View file

@ -410,17 +410,19 @@ type FilenameParser struct {
ParserInput models.SceneParserInput
Filter *models.FindFilterType
whitespaceRE *regexp.Regexp
repository FilenameParserRepository
performerCache map[string]*models.Performer
studioCache map[string]*models.Studio
movieCache map[string]*models.Movie
tagCache map[string]*models.Tag
}
func NewFilenameParser(filter *models.FindFilterType, config models.SceneParserInput) *FilenameParser {
func NewFilenameParser(filter *models.FindFilterType, config models.SceneParserInput, repo FilenameParserRepository) *FilenameParser {
p := &FilenameParser{
Pattern: *filter.Q,
ParserInput: config,
Filter: filter,
repository: repo,
}
p.performerCache = make(map[string]*models.Performer)
@ -457,7 +459,17 @@ type FilenameParserRepository struct {
Tag models.TagQueryer
}
func (p *FilenameParser) Parse(ctx context.Context, repo FilenameParserRepository) ([]*models.SceneParserResult, int, error) {
func NewFilenameParserRepository(repo models.Repository) FilenameParserRepository {
return FilenameParserRepository{
Scene: repo.Scene,
Performer: repo.Performer,
Studio: repo.Studio,
Movie: repo.Movie,
Tag: repo.Tag,
}
}
func (p *FilenameParser) Parse(ctx context.Context) ([]*models.SceneParserResult, int, error) {
// perform the query to find the scenes
mapper, err := newParseMapper(p.Pattern, p.ParserInput.IgnoreWords)
@ -479,17 +491,17 @@ func (p *FilenameParser) Parse(ctx context.Context, repo FilenameParserRepositor
p.Filter.Q = nil
scenes, total, err := QueryWithCount(ctx, repo.Scene, sceneFilter, p.Filter)
scenes, total, err := QueryWithCount(ctx, p.repository.Scene, sceneFilter, p.Filter)
if err != nil {
return nil, 0, err
}
ret := p.parseScenes(ctx, repo, scenes, mapper)
ret := p.parseScenes(ctx, scenes, mapper)
return ret, total, nil
}
func (p *FilenameParser) parseScenes(ctx context.Context, repo FilenameParserRepository, scenes []*models.Scene, mapper *parseMapper) []*models.SceneParserResult {
func (p *FilenameParser) parseScenes(ctx context.Context, scenes []*models.Scene, mapper *parseMapper) []*models.SceneParserResult {
var ret []*models.SceneParserResult
for _, scene := range scenes {
sceneHolder := mapper.parse(scene)
@ -498,7 +510,7 @@ func (p *FilenameParser) parseScenes(ctx context.Context, repo FilenameParserRep
r := &models.SceneParserResult{
Scene: scene,
}
p.setParserResult(ctx, repo, *sceneHolder, r)
p.setParserResult(ctx, *sceneHolder, r)
ret = append(ret, r)
}
@ -671,7 +683,7 @@ func (p *FilenameParser) setMovies(ctx context.Context, qb MovieNameFinder, h sc
}
}
func (p *FilenameParser) setParserResult(ctx context.Context, repo FilenameParserRepository, h sceneHolder, result *models.SceneParserResult) {
func (p *FilenameParser) setParserResult(ctx context.Context, h sceneHolder, result *models.SceneParserResult) {
if h.result.Title != "" {
title := h.result.Title
title = p.replaceWhitespaceCharacters(title)
@ -692,15 +704,17 @@ func (p *FilenameParser) setParserResult(ctx context.Context, repo FilenameParse
result.Rating = h.result.Rating
}
r := p.repository
if len(h.performers) > 0 {
p.setPerformers(ctx, repo.Performer, h, result)
p.setPerformers(ctx, r.Performer, h, result)
}
if len(h.tags) > 0 {
p.setTags(ctx, repo.Tag, h, result)
p.setTags(ctx, r.Tag, h, result)
}
p.setStudio(ctx, repo.Studio, h, result)
p.setStudio(ctx, r.Studio, h, result)
if len(h.movies) > 0 {
p.setMovies(ctx, repo.Movie, h, result)
p.setMovies(ctx, r.Movie, h, result)
}
}

View file

@ -56,20 +56,19 @@ func TestImporterPreImport(t *testing.T) {
}
func TestImporterPreImportWithStudio(t *testing.T) {
studioReaderWriter := &mocks.StudioReaderWriter{}
testCtx := context.Background()
db := mocks.NewDatabase()
i := Importer{
StudioWriter: studioReaderWriter,
StudioWriter: db.Studio,
Input: jsonschema.Scene{
Studio: existingStudioName,
},
}
studioReaderWriter.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{
db.Studio.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{
ID: existingStudioID,
}, nil).Once()
studioReaderWriter.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
db.Studio.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
err := i.PreImport(testCtx)
assert.Nil(t, err)
@ -79,22 +78,22 @@ func TestImporterPreImportWithStudio(t *testing.T) {
err = i.PreImport(testCtx)
assert.NotNil(t, err)
studioReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingStudio(t *testing.T) {
studioReaderWriter := &mocks.StudioReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
StudioWriter: studioReaderWriter,
StudioWriter: db.Studio,
Input: jsonschema.Scene{
Studio: missingStudioName,
},
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.Studio)
s.ID = existingStudioID
}).Return(nil)
@ -111,32 +110,34 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, existingStudioID, *i.scene.StudioID)
studioReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
studioReaderWriter := &mocks.StudioReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
StudioWriter: studioReaderWriter,
StudioWriter: db.Studio,
Input: jsonschema.Scene{
Studio: missingStudioName,
},
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
err := i.PreImport(testCtx)
assert.NotNil(t, err)
db.AssertExpectations(t)
}
func TestImporterPreImportWithPerformer(t *testing.T) {
performerReaderWriter := &mocks.PerformerReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
PerformerWriter: performerReaderWriter,
PerformerWriter: db.Performer,
MissingRefBehaviour: models.ImportMissingRefEnumFail,
Input: jsonschema.Scene{
Performers: []string{
@ -145,13 +146,13 @@ func TestImporterPreImportWithPerformer(t *testing.T) {
},
}
performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{
db.Performer.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{
{
ID: existingPerformerID,
Name: existingPerformerName,
},
}, nil).Once()
performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once()
db.Performer.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once()
err := i.PreImport(testCtx)
assert.Nil(t, err)
@ -161,14 +162,14 @@ func TestImporterPreImportWithPerformer(t *testing.T) {
err = i.PreImport(testCtx)
assert.NotNil(t, err)
performerReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingPerformer(t *testing.T) {
performerReaderWriter := &mocks.PerformerReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
PerformerWriter: performerReaderWriter,
PerformerWriter: db.Performer,
Input: jsonschema.Scene{
Performers: []string{
missingPerformerName,
@ -177,8 +178,8 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3)
performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) {
db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3)
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) {
p := args.Get(1).(*models.Performer)
p.ID = existingPerformerID
}).Return(nil)
@ -195,14 +196,14 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, []int{existingPerformerID}, i.scene.PerformerIDs.List())
performerReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) {
performerReaderWriter := &mocks.PerformerReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
PerformerWriter: performerReaderWriter,
PerformerWriter: db.Performer,
Input: jsonschema.Scene{
Performers: []string{
missingPerformerName,
@ -211,19 +212,20 @@ func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once()
performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error"))
db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once()
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error"))
err := i.PreImport(testCtx)
assert.NotNil(t, err)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMovie(t *testing.T) {
movieReaderWriter := &mocks.MovieReaderWriter{}
testCtx := context.Background()
db := mocks.NewDatabase()
i := Importer{
MovieWriter: movieReaderWriter,
MovieWriter: db.Movie,
MissingRefBehaviour: models.ImportMissingRefEnumFail,
Input: jsonschema.Scene{
Movies: []jsonschema.SceneMovie{
@ -235,11 +237,11 @@ func TestImporterPreImportWithMovie(t *testing.T) {
},
}
movieReaderWriter.On("FindByName", testCtx, existingMovieName, false).Return(&models.Movie{
db.Movie.On("FindByName", testCtx, existingMovieName, false).Return(&models.Movie{
ID: existingMovieID,
Name: existingMovieName,
}, nil).Once()
movieReaderWriter.On("FindByName", testCtx, existingMovieErr, false).Return(nil, errors.New("FindByName error")).Once()
db.Movie.On("FindByName", testCtx, existingMovieErr, false).Return(nil, errors.New("FindByName error")).Once()
err := i.PreImport(testCtx)
assert.Nil(t, err)
@ -249,15 +251,14 @@ func TestImporterPreImportWithMovie(t *testing.T) {
err = i.PreImport(testCtx)
assert.NotNil(t, err)
movieReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingMovie(t *testing.T) {
movieReaderWriter := &mocks.MovieReaderWriter{}
testCtx := context.Background()
db := mocks.NewDatabase()
i := Importer{
MovieWriter: movieReaderWriter,
MovieWriter: db.Movie,
Input: jsonschema.Scene{
Movies: []jsonschema.SceneMovie{
{
@ -268,8 +269,8 @@ func TestImporterPreImportWithMissingMovie(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
movieReaderWriter.On("FindByName", testCtx, missingMovieName, false).Return(nil, nil).Times(3)
movieReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Movie")).Run(func(args mock.Arguments) {
db.Movie.On("FindByName", testCtx, missingMovieName, false).Return(nil, nil).Times(3)
db.Movie.On("Create", testCtx, mock.AnythingOfType("*models.Movie")).Run(func(args mock.Arguments) {
m := args.Get(1).(*models.Movie)
m.ID = existingMovieID
}).Return(nil)
@ -286,14 +287,14 @@ func TestImporterPreImportWithMissingMovie(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, existingMovieID, i.scene.Movies.List()[0].MovieID)
movieReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingMovieCreateErr(t *testing.T) {
movieReaderWriter := &mocks.MovieReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
MovieWriter: movieReaderWriter,
MovieWriter: db.Movie,
Input: jsonschema.Scene{
Movies: []jsonschema.SceneMovie{
{
@ -304,18 +305,20 @@ func TestImporterPreImportWithMissingMovieCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
movieReaderWriter.On("FindByName", testCtx, missingMovieName, false).Return(nil, nil).Once()
movieReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Movie")).Return(errors.New("Create error"))
db.Movie.On("FindByName", testCtx, missingMovieName, false).Return(nil, nil).Once()
db.Movie.On("Create", testCtx, mock.AnythingOfType("*models.Movie")).Return(errors.New("Create error"))
err := i.PreImport(testCtx)
assert.NotNil(t, err)
db.AssertExpectations(t)
}
func TestImporterPreImportWithTag(t *testing.T) {
tagReaderWriter := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
TagWriter: tagReaderWriter,
TagWriter: db.Tag,
MissingRefBehaviour: models.ImportMissingRefEnumFail,
Input: jsonschema.Scene{
Tags: []string{
@ -324,13 +327,13 @@ func TestImporterPreImportWithTag(t *testing.T) {
},
}
tagReaderWriter.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{
db.Tag.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{
{
ID: existingTagID,
Name: existingTagName,
},
}, nil).Once()
tagReaderWriter.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
db.Tag.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
err := i.PreImport(testCtx)
assert.Nil(t, err)
@ -340,14 +343,14 @@ func TestImporterPreImportWithTag(t *testing.T) {
err = i.PreImport(testCtx)
assert.NotNil(t, err)
tagReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingTag(t *testing.T) {
tagReaderWriter := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
TagWriter: tagReaderWriter,
TagWriter: db.Tag,
Input: jsonschema.Scene{
Tags: []string{
missingTagName,
@ -356,8 +359,8 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
t := args.Get(1).(*models.Tag)
t.ID = existingTagID
}).Return(nil)
@ -374,14 +377,14 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, []int{existingTagID}, i.scene.TagIDs.List())
tagReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
tagReaderWriter := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
TagWriter: tagReaderWriter,
TagWriter: db.Tag,
Input: jsonschema.Scene{
Tags: []string{
missingTagName,
@ -390,9 +393,11 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error"))
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error"))
err := i.PreImport(testCtx)
assert.NotNil(t, err)
db.AssertExpectations(t)
}

View file

@ -1,7 +1,6 @@
package scene
import (
"context"
"errors"
"strconv"
"testing"
@ -105,8 +104,6 @@ func TestUpdater_Update(t *testing.T) {
tagID
)
ctx := context.Background()
performerIDs := []int{performerID}
tagIDs := []int{tagID}
stashID := "stashID"
@ -119,14 +116,15 @@ func TestUpdater_Update(t *testing.T) {
updateErr := errors.New("error updating")
qb := mocks.SceneReaderWriter{}
qb.On("UpdatePartial", ctx, mock.MatchedBy(func(id int) bool {
db := mocks.NewDatabase()
db.Scene.On("UpdatePartial", testCtx, mock.MatchedBy(func(id int) bool {
return id != badUpdateID
}), mock.Anything).Return(validScene, nil)
qb.On("UpdatePartial", ctx, badUpdateID, mock.Anything).Return(nil, updateErr)
db.Scene.On("UpdatePartial", testCtx, badUpdateID, mock.Anything).Return(nil, updateErr)
qb.On("UpdateCover", ctx, sceneID, cover).Return(nil).Once()
qb.On("UpdateCover", ctx, badCoverID, cover).Return(updateErr).Once()
db.Scene.On("UpdateCover", testCtx, sceneID, cover).Return(nil).Once()
db.Scene.On("UpdateCover", testCtx, badCoverID, cover).Return(updateErr).Once()
tests := []struct {
name string
@ -204,7 +202,7 @@ func TestUpdater_Update(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.u.Update(ctx, &qb)
got, err := tt.u.Update(testCtx, db.Scene)
if (err != nil) != tt.wantErr {
t.Errorf("Updater.Update() error = %v, wantErr %v", err, tt.wantErr)
return
@ -215,7 +213,7 @@ func TestUpdater_Update(t *testing.T) {
})
}
qb.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestUpdateSet_UpdateInput(t *testing.T) {

View file

@ -18,7 +18,6 @@ const (
)
type autotagScraper struct {
// repository models.Repository
txnManager txn.Manager
performerReader models.PerformerAutoTagQueryer
studioReader models.StudioAutoTagQueryer
@ -208,9 +207,9 @@ func (s autotagScraper) spec() Scraper {
}
}
func getAutoTagScraper(txnManager txn.Manager, repo Repository, globalConfig GlobalConfig) scraper {
func getAutoTagScraper(repo Repository, globalConfig GlobalConfig) scraper {
base := autotagScraper{
txnManager: txnManager,
txnManager: repo.TxnManager,
performerReader: repo.PerformerFinder,
studioReader: repo.StudioFinder,
tagReader: repo.TagFinder,

View file

@ -77,6 +77,8 @@ type GalleryFinder interface {
}
type Repository struct {
TxnManager models.TxnManager
SceneFinder SceneFinder
GalleryFinder GalleryFinder
TagFinder TagFinder
@ -85,12 +87,27 @@ type Repository struct {
StudioFinder StudioFinder
}
func NewRepository(repo models.Repository) Repository {
return Repository{
TxnManager: repo.TxnManager,
SceneFinder: repo.Scene,
GalleryFinder: repo.Gallery,
TagFinder: repo.Tag,
PerformerFinder: repo.Performer,
MovieFinder: repo.Movie,
StudioFinder: repo.Studio,
}
}
func (r *Repository) WithReadTxn(ctx context.Context, fn txn.TxnFunc) error {
return txn.WithReadTxn(ctx, r.TxnManager, fn)
}
// Cache stores the database of scrapers
type Cache struct {
client *http.Client
scrapers map[string]scraper // Scraper ID -> Scraper
globalConfig GlobalConfig
txnManager txn.Manager
repository Repository
}
@ -122,14 +139,13 @@ func newClient(gc GlobalConfig) *http.Client {
//
// Scraper configurations are loaded from yml files in the provided scrapers
// directory and any subdirectories.
func NewCache(globalConfig GlobalConfig, txnManager txn.Manager, repo Repository) (*Cache, error) {
func NewCache(globalConfig GlobalConfig, repo Repository) (*Cache, error) {
// HTTP Client setup
client := newClient(globalConfig)
ret := &Cache{
client: client,
globalConfig: globalConfig,
txnManager: txnManager,
repository: repo,
}
@ -148,7 +164,7 @@ func (c *Cache) loadScrapers() (map[string]scraper, error) {
// Add built-in scrapers
freeOnes := getFreeonesScraper(c.globalConfig)
autoTag := getAutoTagScraper(c.txnManager, c.repository, c.globalConfig)
autoTag := getAutoTagScraper(c.repository, c.globalConfig)
scrapers[freeOnes.spec().ID] = freeOnes
scrapers[autoTag.spec().ID] = autoTag
@ -369,9 +385,12 @@ func (c Cache) ScrapeID(ctx context.Context, scraperID string, id int, ty Scrape
func (c Cache) getScene(ctx context.Context, sceneID int) (*models.Scene, error) {
var ret *models.Scene
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
r := c.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
qb := r.SceneFinder
var err error
ret, err = c.repository.SceneFinder.Find(ctx, sceneID)
ret, err = qb.Find(ctx, sceneID)
if err != nil {
return err
}
@ -380,7 +399,7 @@ func (c Cache) getScene(ctx context.Context, sceneID int) (*models.Scene, error)
return fmt.Errorf("scene with id %d not found", sceneID)
}
return ret.LoadURLs(ctx, c.repository.SceneFinder)
return ret.LoadURLs(ctx, qb)
}); err != nil {
return nil, err
}
@ -389,9 +408,12 @@ func (c Cache) getScene(ctx context.Context, sceneID int) (*models.Scene, error)
func (c Cache) getGallery(ctx context.Context, galleryID int) (*models.Gallery, error) {
var ret *models.Gallery
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
r := c.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
qb := r.GalleryFinder
var err error
ret, err = c.repository.GalleryFinder.Find(ctx, galleryID)
ret, err = qb.Find(ctx, galleryID)
if err != nil {
return err
}
@ -400,12 +422,12 @@ func (c Cache) getGallery(ctx context.Context, galleryID int) (*models.Gallery,
return fmt.Errorf("gallery with id %d not found", galleryID)
}
err = ret.LoadFiles(ctx, c.repository.GalleryFinder)
err = ret.LoadFiles(ctx, qb)
if err != nil {
return err
}
return ret.LoadURLs(ctx, c.repository.GalleryFinder)
return ret.LoadURLs(ctx, qb)
}); err != nil {
return nil, err
}

View file

@ -6,7 +6,6 @@ import (
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/match"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
)
// postScrape handles post-processing of scraped content. If the content
@ -46,8 +45,9 @@ func (c Cache) postScrape(ctx context.Context, content ScrapedContent) (ScrapedC
}
func (c Cache) postScrapePerformer(ctx context.Context, p models.ScrapedPerformer) (ScrapedContent, error) {
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
tqb := c.repository.TagFinder
r := c.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
tqb := r.TagFinder
tags, err := postProcessTags(ctx, tqb, p.Tags)
if err != nil {
@ -72,8 +72,9 @@ func (c Cache) postScrapePerformer(ctx context.Context, p models.ScrapedPerforme
func (c Cache) postScrapeMovie(ctx context.Context, m models.ScrapedMovie) (ScrapedContent, error) {
if m.Studio != nil {
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
return match.ScrapedStudio(ctx, c.repository.StudioFinder, m.Studio, nil)
r := c.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
return match.ScrapedStudio(ctx, r.StudioFinder, m.Studio, nil)
}); err != nil {
return nil, err
}
@ -113,11 +114,12 @@ func (c Cache) postScrapeScene(ctx context.Context, scene ScrapedScene) (Scraped
scene.URLs = []string{*scene.URL}
}
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
pqb := c.repository.PerformerFinder
mqb := c.repository.MovieFinder
tqb := c.repository.TagFinder
sqb := c.repository.StudioFinder
r := c.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
pqb := r.PerformerFinder
mqb := r.MovieFinder
tqb := r.TagFinder
sqb := r.StudioFinder
for _, p := range scene.Performers {
if p == nil {
@ -175,10 +177,11 @@ func (c Cache) postScrapeGallery(ctx context.Context, g ScrapedGallery) (Scraped
g.URLs = []string{*g.URL}
}
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
pqb := c.repository.PerformerFinder
tqb := c.repository.TagFinder
sqb := c.repository.StudioFinder
r := c.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
pqb := r.PerformerFinder
tqb := r.TagFinder
sqb := r.StudioFinder
for _, p := range g.Performers {
err := match.ScrapedPerformer(ctx, pqb, p, nil)

View file

@ -56,22 +56,37 @@ type TagFinder interface {
}
type Repository struct {
TxnManager models.TxnManager
Scene SceneReader
Performer PerformerReader
Tag TagFinder
Studio StudioReader
}
func NewRepository(repo models.Repository) Repository {
return Repository{
TxnManager: repo.TxnManager,
Scene: repo.Scene,
Performer: repo.Performer,
Tag: repo.Tag,
Studio: repo.Studio,
}
}
func (r *Repository) WithReadTxn(ctx context.Context, fn txn.TxnFunc) error {
return txn.WithReadTxn(ctx, r.TxnManager, fn)
}
// Client represents the client interface to a stash-box server instance.
type Client struct {
client *graphql.Client
txnManager txn.Manager
repository Repository
box models.StashBox
}
// NewClient returns a new instance of a stash-box client.
func NewClient(box models.StashBox, txnManager txn.Manager, repo Repository) *Client {
func NewClient(box models.StashBox, repo Repository) *Client {
authHeader := func(req *http.Request) {
req.Header.Set("ApiKey", box.APIKey)
}
@ -82,7 +97,6 @@ func NewClient(box models.StashBox, txnManager txn.Manager, repo Repository) *Cl
return &Client{
client: client,
txnManager: txnManager,
repository: repo,
box: box,
}
@ -129,8 +143,9 @@ func (c Client) FindStashBoxSceneByFingerprints(ctx context.Context, sceneID int
func (c Client) FindStashBoxScenesByFingerprints(ctx context.Context, ids []int) ([][]*scraper.ScrapedScene, error) {
var fingerprints [][]*graphql.FingerprintQueryInput
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
qb := c.repository.Scene
r := c.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
qb := r.Scene
for _, sceneID := range ids {
scene, err := qb.Find(ctx, sceneID)
@ -142,7 +157,7 @@ func (c Client) FindStashBoxScenesByFingerprints(ctx context.Context, ids []int)
return fmt.Errorf("scene with id %d not found", sceneID)
}
if err := scene.LoadFiles(ctx, c.repository.Scene); err != nil {
if err := scene.LoadFiles(ctx, r.Scene); err != nil {
return err
}
@ -243,8 +258,9 @@ func (c Client) SubmitStashBoxFingerprints(ctx context.Context, sceneIDs []strin
var fingerprints []graphql.FingerprintSubmission
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
qb := c.repository.Scene
r := c.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
qb := r.Scene
for _, sceneID := range ids {
scene, err := qb.Find(ctx, sceneID)
@ -382,9 +398,9 @@ func (c Client) FindStashBoxPerformersByNames(ctx context.Context, performerIDs
}
var performers []*models.Performer
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
qb := c.repository.Performer
r := c.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
qb := r.Performer
for _, performerID := range ids {
performer, err := qb.Find(ctx, performerID)
@ -417,8 +433,9 @@ func (c Client) FindStashBoxPerformersByPerformerNames(ctx context.Context, perf
var performers []*models.Performer
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
qb := c.repository.Performer
r := c.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
qb := r.Performer
for _, performerID := range ids {
performer, err := qb.Find(ctx, performerID)
@ -739,14 +756,15 @@ func (c Client) sceneFragmentToScrapedScene(ctx context.Context, s *graphql.Scen
ss.URL = &s.Urls[0].URL
}
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
pqb := c.repository.Performer
tqb := c.repository.Tag
r := c.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
pqb := r.Performer
tqb := r.Tag
if s.Studio != nil {
ss.Studio = studioFragmentToScrapedStudio(*s.Studio)
err := match.ScrapedStudio(ctx, c.repository.Studio, ss.Studio, &c.box.Endpoint)
err := match.ScrapedStudio(ctx, r.Studio, ss.Studio, &c.box.Endpoint)
if err != nil {
return err
}
@ -761,7 +779,7 @@ func (c Client) sceneFragmentToScrapedScene(ctx context.Context, s *graphql.Scen
if parentStudio.FindStudio != nil {
ss.Studio.Parent = studioFragmentToScrapedStudio(*parentStudio.FindStudio)
err = match.ScrapedStudio(ctx, c.repository.Studio, ss.Studio.Parent, &c.box.Endpoint)
err = match.ScrapedStudio(ctx, r.Studio, ss.Studio.Parent, &c.box.Endpoint)
if err != nil {
return err
}
@ -809,8 +827,9 @@ func (c Client) FindStashBoxPerformerByID(ctx context.Context, id string) (*mode
ret := performerFragmentToScrapedPerformer(*performer.FindPerformer)
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
err := match.ScrapedPerformer(ctx, c.repository.Performer, ret, &c.box.Endpoint)
r := c.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
err := match.ScrapedPerformer(ctx, r.Performer, ret, &c.box.Endpoint)
return err
}); err != nil {
return nil, err
@ -836,8 +855,9 @@ func (c Client) FindStashBoxPerformerByName(ctx context.Context, name string) (*
return nil, nil
}
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
err := match.ScrapedPerformer(ctx, c.repository.Performer, ret, &c.box.Endpoint)
r := c.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
err := match.ScrapedPerformer(ctx, r.Performer, ret, &c.box.Endpoint)
return err
}); err != nil {
return nil, err
@ -864,10 +884,11 @@ func (c Client) FindStashBoxStudio(ctx context.Context, query string) (*models.S
var ret *models.ScrapedStudio
if studio.FindStudio != nil {
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
r := c.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
ret = studioFragmentToScrapedStudio(*studio.FindStudio)
err = match.ScrapedStudio(ctx, c.repository.Studio, ret, &c.box.Endpoint)
err = match.ScrapedStudio(ctx, r.Studio, ret, &c.box.Endpoint)
if err != nil {
return err
}
@ -881,7 +902,7 @@ func (c Client) FindStashBoxStudio(ctx context.Context, query string) (*models.S
if parentStudio.FindStudio != nil {
ret.Parent = studioFragmentToScrapedStudio(*parentStudio.FindStudio)
err = match.ScrapedStudio(ctx, c.repository.Studio, ret.Parent, &c.box.Endpoint)
err = match.ScrapedStudio(ctx, r.Studio, ret.Parent, &c.box.Endpoint)
if err != nil {
return err
}

View file

@ -1176,7 +1176,7 @@ func makeImage(i int) *models.Image {
}
func createImages(ctx context.Context, n int) error {
qb := db.TxnRepository().Image
qb := db.Image
fqb := db.File
for i := 0; i < n; i++ {
@ -1273,7 +1273,7 @@ func makeGallery(i int, includeScenes bool) *models.Gallery {
}
func createGalleries(ctx context.Context, n int) error {
gqb := db.TxnRepository().Gallery
gqb := db.Gallery
fqb := db.File
for i := 0; i < n; i++ {

View file

@ -123,7 +123,7 @@ func (db *Database) IsLocked(err error) bool {
return false
}
func (db *Database) TxnRepository() models.Repository {
func (db *Database) Repository() models.Repository {
return models.Repository{
TxnManager: db,
File: db.File,

View file

@ -1,7 +1,6 @@
package studio
import (
"context"
"errors"
"github.com/stashapp/stash/pkg/models"
@ -162,27 +161,26 @@ func initTestTable() {
func TestToJSON(t *testing.T) {
initTestTable()
ctx := context.Background()
mockStudioReader := &mocks.StudioReaderWriter{}
db := mocks.NewDatabase()
imageErr := errors.New("error getting image")
mockStudioReader.On("GetImage", ctx, studioID).Return(imageBytes, nil).Once()
mockStudioReader.On("GetImage", ctx, noImageID).Return(nil, nil).Once()
mockStudioReader.On("GetImage", ctx, errImageID).Return(nil, imageErr).Once()
mockStudioReader.On("GetImage", ctx, missingParentStudioID).Return(imageBytes, nil).Maybe()
mockStudioReader.On("GetImage", ctx, errStudioID).Return(imageBytes, nil).Maybe()
db.Studio.On("GetImage", testCtx, studioID).Return(imageBytes, nil).Once()
db.Studio.On("GetImage", testCtx, noImageID).Return(nil, nil).Once()
db.Studio.On("GetImage", testCtx, errImageID).Return(nil, imageErr).Once()
db.Studio.On("GetImage", testCtx, missingParentStudioID).Return(imageBytes, nil).Maybe()
db.Studio.On("GetImage", testCtx, errStudioID).Return(imageBytes, nil).Maybe()
parentStudioErr := errors.New("error getting parent studio")
mockStudioReader.On("Find", ctx, parentStudioID).Return(&parentStudio, nil)
mockStudioReader.On("Find", ctx, missingStudioID).Return(nil, nil)
mockStudioReader.On("Find", ctx, errParentStudioID).Return(nil, parentStudioErr)
db.Studio.On("Find", testCtx, parentStudioID).Return(&parentStudio, nil)
db.Studio.On("Find", testCtx, missingStudioID).Return(nil, nil)
db.Studio.On("Find", testCtx, errParentStudioID).Return(nil, parentStudioErr)
for i, s := range scenarios {
studio := s.input
json, err := ToJSON(ctx, mockStudioReader, &studio)
json, err := ToJSON(testCtx, db.Studio, &studio)
switch {
case !s.err && err != nil:
@ -194,5 +192,5 @@ func TestToJSON(t *testing.T) {
}
}
mockStudioReader.AssertExpectations(t)
db.AssertExpectations(t)
}

View file

@ -25,6 +25,8 @@ const (
missingParentStudioName = "existingParentStudioName"
)
var testCtx = context.Background()
func TestImporterName(t *testing.T) {
i := Importer{
Input: jsonschema.Studio{
@ -43,22 +45,21 @@ func TestImporterPreImport(t *testing.T) {
IgnoreAutoTag: autoTagIgnored,
},
}
ctx := context.Background()
err := i.PreImport(ctx)
err := i.PreImport(testCtx)
assert.NotNil(t, err)
i.Input.Image = image
err = i.PreImport(ctx)
err = i.PreImport(testCtx)
assert.Nil(t, err)
i.Input = *createFullJSONStudio(studioName, image, []string{"alias"})
i.Input.ParentStudio = ""
err = i.PreImport(ctx)
err = i.PreImport(testCtx)
assert.Nil(t, err)
expectedStudio := createFullStudio(0, 0)
@ -67,11 +68,10 @@ func TestImporterPreImport(t *testing.T) {
}
func TestImporterPreImportWithParent(t *testing.T) {
readerWriter := &mocks.StudioReaderWriter{}
ctx := context.Background()
db := mocks.NewDatabase()
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Studio,
Input: jsonschema.Studio{
Name: studioName,
Image: image,
@ -79,28 +79,27 @@ func TestImporterPreImportWithParent(t *testing.T) {
},
}
readerWriter.On("FindByName", ctx, existingParentStudioName, false).Return(&models.Studio{
db.Studio.On("FindByName", testCtx, existingParentStudioName, false).Return(&models.Studio{
ID: existingStudioID,
}, nil).Once()
readerWriter.On("FindByName", ctx, existingParentStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
db.Studio.On("FindByName", testCtx, existingParentStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
err := i.PreImport(ctx)
err := i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, existingStudioID, *i.studio.ParentID)
i.Input.ParentStudio = existingParentStudioErr
err = i.PreImport(ctx)
err = i.PreImport(testCtx)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingParent(t *testing.T) {
readerWriter := &mocks.StudioReaderWriter{}
ctx := context.Background()
db := mocks.NewDatabase()
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Studio,
Input: jsonschema.Studio{
Name: studioName,
Image: image,
@ -109,33 +108,32 @@ func TestImporterPreImportWithMissingParent(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
readerWriter.On("FindByName", ctx, missingParentStudioName, false).Return(nil, nil).Times(3)
readerWriter.On("Create", ctx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
db.Studio.On("FindByName", testCtx, missingParentStudioName, false).Return(nil, nil).Times(3)
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.Studio)
s.ID = existingStudioID
}).Return(nil)
err := i.PreImport(ctx)
err := i.PreImport(testCtx)
assert.NotNil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore
err = i.PreImport(ctx)
err = i.PreImport(testCtx)
assert.Nil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumCreate
err = i.PreImport(ctx)
err = i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, existingStudioID, *i.studio.ParentID)
readerWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingParentCreateErr(t *testing.T) {
readerWriter := &mocks.StudioReaderWriter{}
ctx := context.Background()
db := mocks.NewDatabase()
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Studio,
Input: jsonschema.Studio{
Name: studioName,
Image: image,
@ -144,19 +142,20 @@ func TestImporterPreImportWithMissingParentCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
readerWriter.On("FindByName", ctx, missingParentStudioName, false).Return(nil, nil).Once()
readerWriter.On("Create", ctx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
db.Studio.On("FindByName", testCtx, missingParentStudioName, false).Return(nil, nil).Once()
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
err := i.PreImport(ctx)
err := i.PreImport(testCtx)
assert.NotNil(t, err)
db.AssertExpectations(t)
}
func TestImporterPostImport(t *testing.T) {
readerWriter := &mocks.StudioReaderWriter{}
ctx := context.Background()
db := mocks.NewDatabase()
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Studio,
Input: jsonschema.Studio{
Aliases: []string{"alias"},
},
@ -165,56 +164,54 @@ func TestImporterPostImport(t *testing.T) {
updateStudioImageErr := errors.New("UpdateImage error")
readerWriter.On("UpdateImage", ctx, studioID, imageBytes).Return(nil).Once()
readerWriter.On("UpdateImage", ctx, errImageID, imageBytes).Return(updateStudioImageErr).Once()
db.Studio.On("UpdateImage", testCtx, studioID, imageBytes).Return(nil).Once()
db.Studio.On("UpdateImage", testCtx, errImageID, imageBytes).Return(updateStudioImageErr).Once()
err := i.PostImport(ctx, studioID)
err := i.PostImport(testCtx, studioID)
assert.Nil(t, err)
err = i.PostImport(ctx, errImageID)
err = i.PostImport(testCtx, errImageID)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterFindExistingID(t *testing.T) {
readerWriter := &mocks.StudioReaderWriter{}
ctx := context.Background()
db := mocks.NewDatabase()
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Studio,
Input: jsonschema.Studio{
Name: studioName,
},
}
errFindByName := errors.New("FindByName error")
readerWriter.On("FindByName", ctx, studioName, false).Return(nil, nil).Once()
readerWriter.On("FindByName", ctx, existingStudioName, false).Return(&models.Studio{
db.Studio.On("FindByName", testCtx, studioName, false).Return(nil, nil).Once()
db.Studio.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{
ID: existingStudioID,
}, nil).Once()
readerWriter.On("FindByName", ctx, studioNameErr, false).Return(nil, errFindByName).Once()
db.Studio.On("FindByName", testCtx, studioNameErr, false).Return(nil, errFindByName).Once()
id, err := i.FindExistingID(ctx)
id, err := i.FindExistingID(testCtx)
assert.Nil(t, id)
assert.Nil(t, err)
i.Input.Name = existingStudioName
id, err = i.FindExistingID(ctx)
id, err = i.FindExistingID(testCtx)
assert.Equal(t, existingStudioID, *id)
assert.Nil(t, err)
i.Input.Name = studioNameErr
id, err = i.FindExistingID(ctx)
id, err = i.FindExistingID(testCtx)
assert.Nil(t, id)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestCreate(t *testing.T) {
readerWriter := &mocks.StudioReaderWriter{}
ctx := context.Background()
db := mocks.NewDatabase()
studio := models.Studio{
Name: studioName,
@ -225,32 +222,31 @@ func TestCreate(t *testing.T) {
}
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Studio,
studio: studio,
}
errCreate := errors.New("Create error")
readerWriter.On("Create", ctx, &studio).Run(func(args mock.Arguments) {
db.Studio.On("Create", testCtx, &studio).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.Studio)
s.ID = studioID
}).Return(nil).Once()
readerWriter.On("Create", ctx, &studioErr).Return(errCreate).Once()
db.Studio.On("Create", testCtx, &studioErr).Return(errCreate).Once()
id, err := i.Create(ctx)
id, err := i.Create(testCtx)
assert.Equal(t, studioID, *id)
assert.Nil(t, err)
i.studio = studioErr
id, err = i.Create(ctx)
id, err = i.Create(testCtx)
assert.Nil(t, id)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestUpdate(t *testing.T) {
readerWriter := &mocks.StudioReaderWriter{}
ctx := context.Background()
db := mocks.NewDatabase()
studio := models.Studio{
Name: studioName,
@ -261,7 +257,7 @@ func TestUpdate(t *testing.T) {
}
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Studio,
studio: studio,
}
@ -269,19 +265,19 @@ func TestUpdate(t *testing.T) {
// id needs to be set for the mock input
studio.ID = studioID
readerWriter.On("Update", ctx, &studio).Return(nil).Once()
db.Studio.On("Update", testCtx, &studio).Return(nil).Once()
err := i.Update(ctx, studioID)
err := i.Update(testCtx, studioID)
assert.Nil(t, err)
i.studio = studioErr
// need to set id separately
studioErr.ID = errImageID
readerWriter.On("Update", ctx, &studioErr).Return(errUpdate).Once()
db.Studio.On("Update", testCtx, &studioErr).Return(errUpdate).Once()
err = i.Update(ctx, errImageID)
err = i.Update(testCtx, errImageID)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
db.AssertExpectations(t)
}

View file

@ -1,7 +1,6 @@
package tag
import (
"context"
"errors"
"github.com/stashapp/stash/pkg/models"
@ -109,35 +108,34 @@ func initTestTable() {
func TestToJSON(t *testing.T) {
initTestTable()
ctx := context.Background()
mockTagReader := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
imageErr := errors.New("error getting image")
aliasErr := errors.New("error getting aliases")
parentsErr := errors.New("error getting parents")
mockTagReader.On("GetAliases", ctx, tagID).Return([]string{"alias"}, nil).Once()
mockTagReader.On("GetAliases", ctx, noImageID).Return(nil, nil).Once()
mockTagReader.On("GetAliases", ctx, errImageID).Return(nil, nil).Once()
mockTagReader.On("GetAliases", ctx, errAliasID).Return(nil, aliasErr).Once()
mockTagReader.On("GetAliases", ctx, withParentsID).Return(nil, nil).Once()
mockTagReader.On("GetAliases", ctx, errParentsID).Return(nil, nil).Once()
db.Tag.On("GetAliases", testCtx, tagID).Return([]string{"alias"}, nil).Once()
db.Tag.On("GetAliases", testCtx, noImageID).Return(nil, nil).Once()
db.Tag.On("GetAliases", testCtx, errImageID).Return(nil, nil).Once()
db.Tag.On("GetAliases", testCtx, errAliasID).Return(nil, aliasErr).Once()
db.Tag.On("GetAliases", testCtx, withParentsID).Return(nil, nil).Once()
db.Tag.On("GetAliases", testCtx, errParentsID).Return(nil, nil).Once()
mockTagReader.On("GetImage", ctx, tagID).Return(imageBytes, nil).Once()
mockTagReader.On("GetImage", ctx, noImageID).Return(nil, nil).Once()
mockTagReader.On("GetImage", ctx, errImageID).Return(nil, imageErr).Once()
mockTagReader.On("GetImage", ctx, withParentsID).Return(imageBytes, nil).Once()
mockTagReader.On("GetImage", ctx, errParentsID).Return(nil, nil).Once()
db.Tag.On("GetImage", testCtx, tagID).Return(imageBytes, nil).Once()
db.Tag.On("GetImage", testCtx, noImageID).Return(nil, nil).Once()
db.Tag.On("GetImage", testCtx, errImageID).Return(nil, imageErr).Once()
db.Tag.On("GetImage", testCtx, withParentsID).Return(imageBytes, nil).Once()
db.Tag.On("GetImage", testCtx, errParentsID).Return(nil, nil).Once()
mockTagReader.On("FindByChildTagID", ctx, tagID).Return(nil, nil).Once()
mockTagReader.On("FindByChildTagID", ctx, noImageID).Return(nil, nil).Once()
mockTagReader.On("FindByChildTagID", ctx, withParentsID).Return([]*models.Tag{{Name: "parent"}}, nil).Once()
mockTagReader.On("FindByChildTagID", ctx, errParentsID).Return(nil, parentsErr).Once()
mockTagReader.On("FindByChildTagID", ctx, errImageID).Return(nil, nil).Once()
db.Tag.On("FindByChildTagID", testCtx, tagID).Return(nil, nil).Once()
db.Tag.On("FindByChildTagID", testCtx, noImageID).Return(nil, nil).Once()
db.Tag.On("FindByChildTagID", testCtx, withParentsID).Return([]*models.Tag{{Name: "parent"}}, nil).Once()
db.Tag.On("FindByChildTagID", testCtx, errParentsID).Return(nil, parentsErr).Once()
db.Tag.On("FindByChildTagID", testCtx, errImageID).Return(nil, nil).Once()
for i, s := range scenarios {
tag := s.tag
json, err := ToJSON(ctx, mockTagReader, &tag)
json, err := ToJSON(testCtx, db.Tag, &tag)
switch {
case !s.err && err != nil:
@ -149,5 +147,5 @@ func TestToJSON(t *testing.T) {
}
}
mockTagReader.AssertExpectations(t)
db.AssertExpectations(t)
}

View file

@ -58,10 +58,10 @@ func TestImporterPreImport(t *testing.T) {
}
func TestImporterPostImport(t *testing.T) {
readerWriter := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Tag,
Input: jsonschema.Tag{
Aliases: []string{"alias"},
},
@ -72,23 +72,23 @@ func TestImporterPostImport(t *testing.T) {
updateTagAliasErr := errors.New("UpdateAlias error")
updateTagParentsErr := errors.New("UpdateParentTags error")
readerWriter.On("UpdateAliases", testCtx, tagID, i.Input.Aliases).Return(nil).Once()
readerWriter.On("UpdateAliases", testCtx, errAliasID, i.Input.Aliases).Return(updateTagAliasErr).Once()
readerWriter.On("UpdateAliases", testCtx, withParentsID, i.Input.Aliases).Return(nil).Once()
readerWriter.On("UpdateAliases", testCtx, errParentsID, i.Input.Aliases).Return(nil).Once()
db.Tag.On("UpdateAliases", testCtx, tagID, i.Input.Aliases).Return(nil).Once()
db.Tag.On("UpdateAliases", testCtx, errAliasID, i.Input.Aliases).Return(updateTagAliasErr).Once()
db.Tag.On("UpdateAliases", testCtx, withParentsID, i.Input.Aliases).Return(nil).Once()
db.Tag.On("UpdateAliases", testCtx, errParentsID, i.Input.Aliases).Return(nil).Once()
readerWriter.On("UpdateImage", testCtx, tagID, imageBytes).Return(nil).Once()
readerWriter.On("UpdateImage", testCtx, errAliasID, imageBytes).Return(nil).Once()
readerWriter.On("UpdateImage", testCtx, errImageID, imageBytes).Return(updateTagImageErr).Once()
readerWriter.On("UpdateImage", testCtx, withParentsID, imageBytes).Return(nil).Once()
readerWriter.On("UpdateImage", testCtx, errParentsID, imageBytes).Return(nil).Once()
db.Tag.On("UpdateImage", testCtx, tagID, imageBytes).Return(nil).Once()
db.Tag.On("UpdateImage", testCtx, errAliasID, imageBytes).Return(nil).Once()
db.Tag.On("UpdateImage", testCtx, errImageID, imageBytes).Return(updateTagImageErr).Once()
db.Tag.On("UpdateImage", testCtx, withParentsID, imageBytes).Return(nil).Once()
db.Tag.On("UpdateImage", testCtx, errParentsID, imageBytes).Return(nil).Once()
var parentTags []int
readerWriter.On("UpdateParentTags", testCtx, tagID, parentTags).Return(nil).Once()
readerWriter.On("UpdateParentTags", testCtx, withParentsID, []int{100}).Return(nil).Once()
readerWriter.On("UpdateParentTags", testCtx, errParentsID, []int{100}).Return(updateTagParentsErr).Once()
db.Tag.On("UpdateParentTags", testCtx, tagID, parentTags).Return(nil).Once()
db.Tag.On("UpdateParentTags", testCtx, withParentsID, []int{100}).Return(nil).Once()
db.Tag.On("UpdateParentTags", testCtx, errParentsID, []int{100}).Return(updateTagParentsErr).Once()
readerWriter.On("FindByName", testCtx, "Parent", false).Return(&models.Tag{ID: 100}, nil)
db.Tag.On("FindByName", testCtx, "Parent", false).Return(&models.Tag{ID: 100}, nil)
err := i.PostImport(testCtx, tagID)
assert.Nil(t, err)
@ -106,14 +106,14 @@ func TestImporterPostImport(t *testing.T) {
err = i.PostImport(testCtx, errParentsID)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPostImportParentMissing(t *testing.T) {
readerWriter := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Tag,
Input: jsonschema.Tag{},
imageData: imageBytes,
}
@ -133,33 +133,33 @@ func TestImporterPostImportParentMissing(t *testing.T) {
var emptyParents []int
readerWriter.On("UpdateImage", testCtx, mock.Anything, mock.Anything).Return(nil)
readerWriter.On("UpdateAliases", testCtx, mock.Anything, mock.Anything).Return(nil)
db.Tag.On("UpdateImage", testCtx, mock.Anything, mock.Anything).Return(nil)
db.Tag.On("UpdateAliases", testCtx, mock.Anything, mock.Anything).Return(nil)
readerWriter.On("FindByName", testCtx, "Create", false).Return(nil, nil).Once()
readerWriter.On("FindByName", testCtx, "CreateError", false).Return(nil, nil).Once()
readerWriter.On("FindByName", testCtx, "CreateFindError", false).Return(nil, findError).Once()
readerWriter.On("FindByName", testCtx, "CreateFound", false).Return(&models.Tag{ID: 101}, nil).Once()
readerWriter.On("FindByName", testCtx, "Fail", false).Return(nil, nil).Once()
readerWriter.On("FindByName", testCtx, "FailFindError", false).Return(nil, findError)
readerWriter.On("FindByName", testCtx, "FailFound", false).Return(&models.Tag{ID: 102}, nil).Once()
readerWriter.On("FindByName", testCtx, "Ignore", false).Return(nil, nil).Once()
readerWriter.On("FindByName", testCtx, "IgnoreFindError", false).Return(nil, findError)
readerWriter.On("FindByName", testCtx, "IgnoreFound", false).Return(&models.Tag{ID: 103}, nil).Once()
db.Tag.On("FindByName", testCtx, "Create", false).Return(nil, nil).Once()
db.Tag.On("FindByName", testCtx, "CreateError", false).Return(nil, nil).Once()
db.Tag.On("FindByName", testCtx, "CreateFindError", false).Return(nil, findError).Once()
db.Tag.On("FindByName", testCtx, "CreateFound", false).Return(&models.Tag{ID: 101}, nil).Once()
db.Tag.On("FindByName", testCtx, "Fail", false).Return(nil, nil).Once()
db.Tag.On("FindByName", testCtx, "FailFindError", false).Return(nil, findError)
db.Tag.On("FindByName", testCtx, "FailFound", false).Return(&models.Tag{ID: 102}, nil).Once()
db.Tag.On("FindByName", testCtx, "Ignore", false).Return(nil, nil).Once()
db.Tag.On("FindByName", testCtx, "IgnoreFindError", false).Return(nil, findError)
db.Tag.On("FindByName", testCtx, "IgnoreFound", false).Return(&models.Tag{ID: 103}, nil).Once()
readerWriter.On("UpdateParentTags", testCtx, createID, []int{100}).Return(nil).Once()
readerWriter.On("UpdateParentTags", testCtx, createFoundID, []int{101}).Return(nil).Once()
readerWriter.On("UpdateParentTags", testCtx, failFoundID, []int{102}).Return(nil).Once()
readerWriter.On("UpdateParentTags", testCtx, ignoreID, emptyParents).Return(nil).Once()
readerWriter.On("UpdateParentTags", testCtx, ignoreFoundID, []int{103}).Return(nil).Once()
db.Tag.On("UpdateParentTags", testCtx, createID, []int{100}).Return(nil).Once()
db.Tag.On("UpdateParentTags", testCtx, createFoundID, []int{101}).Return(nil).Once()
db.Tag.On("UpdateParentTags", testCtx, failFoundID, []int{102}).Return(nil).Once()
db.Tag.On("UpdateParentTags", testCtx, ignoreID, emptyParents).Return(nil).Once()
db.Tag.On("UpdateParentTags", testCtx, ignoreFoundID, []int{103}).Return(nil).Once()
readerWriter.On("Create", testCtx, mock.MatchedBy(func(t *models.Tag) bool {
db.Tag.On("Create", testCtx, mock.MatchedBy(func(t *models.Tag) bool {
return t.Name == "Create"
})).Run(func(args mock.Arguments) {
t := args.Get(1).(*models.Tag)
t.ID = 100
}).Return(nil).Once()
readerWriter.On("Create", testCtx, mock.MatchedBy(func(t *models.Tag) bool {
db.Tag.On("Create", testCtx, mock.MatchedBy(func(t *models.Tag) bool {
return t.Name == "CreateError"
})).Return(errors.New("failed creating parent")).Once()
@ -206,25 +206,25 @@ func TestImporterPostImportParentMissing(t *testing.T) {
err = i.PostImport(testCtx, ignoreFoundID)
assert.Nil(t, err)
readerWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterFindExistingID(t *testing.T) {
readerWriter := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Tag,
Input: jsonschema.Tag{
Name: tagName,
},
}
errFindByName := errors.New("FindByName error")
readerWriter.On("FindByName", testCtx, tagName, false).Return(nil, nil).Once()
readerWriter.On("FindByName", testCtx, existingTagName, false).Return(&models.Tag{
db.Tag.On("FindByName", testCtx, tagName, false).Return(nil, nil).Once()
db.Tag.On("FindByName", testCtx, existingTagName, false).Return(&models.Tag{
ID: existingTagID,
}, nil).Once()
readerWriter.On("FindByName", testCtx, tagNameErr, false).Return(nil, errFindByName).Once()
db.Tag.On("FindByName", testCtx, tagNameErr, false).Return(nil, errFindByName).Once()
id, err := i.FindExistingID(testCtx)
assert.Nil(t, id)
@ -240,11 +240,11 @@ func TestImporterFindExistingID(t *testing.T) {
assert.Nil(t, id)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestCreate(t *testing.T) {
readerWriter := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
tag := models.Tag{
Name: tagName,
@ -255,16 +255,16 @@ func TestCreate(t *testing.T) {
}
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Tag,
tag: tag,
}
errCreate := errors.New("Create error")
readerWriter.On("Create", testCtx, &tag).Run(func(args mock.Arguments) {
db.Tag.On("Create", testCtx, &tag).Run(func(args mock.Arguments) {
t := args.Get(1).(*models.Tag)
t.ID = tagID
}).Return(nil).Once()
readerWriter.On("Create", testCtx, &tagErr).Return(errCreate).Once()
db.Tag.On("Create", testCtx, &tagErr).Return(errCreate).Once()
id, err := i.Create(testCtx)
assert.Equal(t, tagID, *id)
@ -275,11 +275,11 @@ func TestCreate(t *testing.T) {
assert.Nil(t, id)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestUpdate(t *testing.T) {
readerWriter := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
tag := models.Tag{
Name: tagName,
@ -290,7 +290,7 @@ func TestUpdate(t *testing.T) {
}
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Tag,
tag: tag,
}
@ -298,7 +298,7 @@ func TestUpdate(t *testing.T) {
// id needs to be set for the mock input
tag.ID = tagID
readerWriter.On("Update", testCtx, &tag).Return(nil).Once()
db.Tag.On("Update", testCtx, &tag).Return(nil).Once()
err := i.Update(testCtx, tagID)
assert.Nil(t, err)
@ -307,10 +307,10 @@ func TestUpdate(t *testing.T) {
// need to set id separately
tagErr.ID = errImageID
readerWriter.On("Update", testCtx, &tagErr).Return(errUpdate).Once()
db.Tag.On("Update", testCtx, &tagErr).Return(errUpdate).Once()
err = i.Update(testCtx, errImageID)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
db.AssertExpectations(t)
}

View file

@ -219,8 +219,7 @@ func TestEnsureHierarchy(t *testing.T) {
}
func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents, queryChildren bool) {
mockTagReader := &mocks.TagReaderWriter{}
ctx := context.Background()
db := mocks.NewDatabase()
var parentIDs, childIDs []int
find := make(map[int]*models.Tag)
@ -247,15 +246,15 @@ func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents,
if queryParents {
parentIDs = nil
mockTagReader.On("FindByChildTagID", ctx, tc.id).Return(tc.parents, nil).Once()
db.Tag.On("FindByChildTagID", testCtx, tc.id).Return(tc.parents, nil).Once()
}
if queryChildren {
childIDs = nil
mockTagReader.On("FindByParentTagID", ctx, tc.id).Return(tc.children, nil).Once()
db.Tag.On("FindByParentTagID", testCtx, tc.id).Return(tc.children, nil).Once()
}
mockTagReader.On("FindAllAncestors", ctx, mock.AnythingOfType("int"), []int(nil)).Return(func(ctx context.Context, tagID int, excludeIDs []int) []*models.TagPath {
db.Tag.On("FindAllAncestors", testCtx, mock.AnythingOfType("int"), []int(nil)).Return(func(ctx context.Context, tagID int, excludeIDs []int) []*models.TagPath {
return tc.onFindAllAncestors
}, func(ctx context.Context, tagID int, excludeIDs []int) error {
if tc.onFindAllAncestors != nil {
@ -264,7 +263,7 @@ func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents,
return fmt.Errorf("undefined ancestors for: %d", tagID)
}).Maybe()
mockTagReader.On("FindAllDescendants", ctx, mock.AnythingOfType("int"), []int(nil)).Return(func(ctx context.Context, tagID int, excludeIDs []int) []*models.TagPath {
db.Tag.On("FindAllDescendants", testCtx, mock.AnythingOfType("int"), []int(nil)).Return(func(ctx context.Context, tagID int, excludeIDs []int) []*models.TagPath {
return tc.onFindAllDescendants
}, func(ctx context.Context, tagID int, excludeIDs []int) error {
if tc.onFindAllDescendants != nil {
@ -273,7 +272,7 @@ func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents,
return fmt.Errorf("undefined descendants for: %d", tagID)
}).Maybe()
res := ValidateHierarchy(ctx, testUniqueHierarchyTags[tc.id], parentIDs, childIDs, mockTagReader)
res := ValidateHierarchy(testCtx, testUniqueHierarchyTags[tc.id], parentIDs, childIDs, db.Tag)
assert := assert.New(t)
@ -285,5 +284,5 @@ func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents,
assert.Nil(res)
}
mockTagReader.AssertExpectations(t)
db.AssertExpectations(t)
}