From c571687c99ce88ece9d3c28e83636dffc86bde8b Mon Sep 17 00:00:00 2001 From: InfiniteTF Date: Sun, 14 Nov 2021 21:51:52 +0100 Subject: [PATCH] Resolve performer/studio stashIDs for scraped scenes (#2006) * Resolve performer/studio stashIDs for scraped scenes * Check endpoint when matching stashids --- pkg/match/scraped.go | 36 +++++++++++++++++++++-- pkg/models/mocks/PerformerReaderWriter.go | 23 +++++++++++++++ pkg/models/mocks/StudioReaderWriter.go | 23 +++++++++++++++ pkg/models/performer.go | 1 + pkg/models/studio.go | 1 + pkg/scraper/scrapers.go | 10 +++---- pkg/scraper/stashbox/stash_box.go | 16 +++++----- pkg/sqlite/performer.go | 10 +++++++ pkg/sqlite/studio.go | 10 +++++++ 9 files changed, 116 insertions(+), 14 deletions(-) diff --git a/pkg/match/scraped.go b/pkg/match/scraped.go index 839fe3786..1e9de81e1 100644 --- a/pkg/match/scraped.go +++ b/pkg/match/scraped.go @@ -10,11 +10,27 @@ import ( // ScrapedPerformer matches the provided performer with the // performers in the database and sets the ID field if one is found. -func ScrapedPerformer(qb models.PerformerReader, p *models.ScrapedPerformer) error { +func ScrapedPerformer(qb models.PerformerReader, p *models.ScrapedPerformer, stashBoxEndpoint *string) error { if p.StoredID != nil || p.Name == nil { return nil } + // Check if a performer with the StashID already exists + if stashBoxEndpoint != nil && p.RemoteSiteID != nil { + performers, err := qb.FindByStashID(models.StashID{ + StashID: *p.RemoteSiteID, + Endpoint: *stashBoxEndpoint, + }) + if err != nil { + return err + } + if len(performers) > 0 { + id := strconv.Itoa(performers[0].ID) + p.StoredID = &id + return nil + } + } + performers, err := qb.FindByNames([]string{*p.Name}, true) if err != nil { @@ -33,11 +49,27 @@ func ScrapedPerformer(qb models.PerformerReader, p *models.ScrapedPerformer) err // ScrapedStudio matches the provided studio with the studios // in the database and sets the ID field if one is found. -func ScrapedStudio(qb models.StudioReader, s *models.ScrapedStudio) error { +func ScrapedStudio(qb models.StudioReader, s *models.ScrapedStudio, stashBoxEndpoint *string) error { if s.StoredID != nil { return nil } + // Check if a studio with the StashID already exists + if stashBoxEndpoint != nil && s.RemoteSiteID != nil { + studios, err := qb.FindByStashID(models.StashID{ + StashID: *s.RemoteSiteID, + Endpoint: *stashBoxEndpoint, + }) + if err != nil { + return err + } + if len(studios) > 0 { + id := strconv.Itoa(studios[0].ID) + s.StoredID = &id + return nil + } + } + st, err := studio.ByName(qb, s.Name) if err != nil { diff --git a/pkg/models/mocks/PerformerReaderWriter.go b/pkg/models/mocks/PerformerReaderWriter.go index 986074405..0ccaddb33 100644 --- a/pkg/models/mocks/PerformerReaderWriter.go +++ b/pkg/models/mocks/PerformerReaderWriter.go @@ -243,6 +243,29 @@ func (_m *PerformerReaderWriter) FindBySceneID(sceneID int) ([]*models.Performer return r0, r1 } +// FindByStashID provides a mock function with given fields: stashID +func (_m *PerformerReaderWriter) FindByStashID(stashID models.StashID) ([]*models.Performer, error) { + ret := _m.Called(stashID) + + var r0 []*models.Performer + if rf, ok := ret.Get(0).(func(models.StashID) []*models.Performer); ok { + r0 = rf(stashID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.Performer) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(models.StashID) error); ok { + r1 = rf(stashID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // FindByStashIDStatus provides a mock function with given fields: hasStashID, stashboxEndpoint func (_m *PerformerReaderWriter) FindByStashIDStatus(hasStashID bool, stashboxEndpoint string) ([]*models.Performer, error) { ret := _m.Called(hasStashID, stashboxEndpoint) diff --git a/pkg/models/mocks/StudioReaderWriter.go b/pkg/models/mocks/StudioReaderWriter.go index 3c7b61ab0..c433fe305 100644 --- a/pkg/models/mocks/StudioReaderWriter.go +++ b/pkg/models/mocks/StudioReaderWriter.go @@ -153,6 +153,29 @@ func (_m *StudioReaderWriter) FindByName(name string, nocase bool) (*models.Stud return r0, r1 } +// FindByStashID provides a mock function with given fields: stashID +func (_m *StudioReaderWriter) FindByStashID(stashID models.StashID) ([]*models.Studio, error) { + ret := _m.Called(stashID) + + var r0 []*models.Studio + if rf, ok := ret.Get(0).(func(models.StashID) []*models.Studio); ok { + r0 = rf(stashID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.Studio) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(models.StashID) error); ok { + r1 = rf(stashID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // FindChildren provides a mock function with given fields: id func (_m *StudioReaderWriter) FindChildren(id int) ([]*models.Studio, error) { ret := _m.Called(id) diff --git a/pkg/models/performer.go b/pkg/models/performer.go index ea316be2d..04173b47e 100644 --- a/pkg/models/performer.go +++ b/pkg/models/performer.go @@ -8,6 +8,7 @@ type PerformerReader interface { FindByImageID(imageID int) ([]*Performer, error) FindByGalleryID(galleryID int) ([]*Performer, error) FindByNames(names []string, nocase bool) ([]*Performer, error) + FindByStashID(stashID StashID) ([]*Performer, error) FindByStashIDStatus(hasStashID bool, stashboxEndpoint string) ([]*Performer, error) CountByTagID(tagID int) (int, error) Count() (int, error) diff --git a/pkg/models/studio.go b/pkg/models/studio.go index 6eec0cdf2..e5d6bfb19 100644 --- a/pkg/models/studio.go +++ b/pkg/models/studio.go @@ -5,6 +5,7 @@ type StudioReader interface { FindMany(ids []int) ([]*Studio, error) FindChildren(id int) ([]*Studio, error) FindByName(name string, nocase bool) (*Studio, error) + FindByStashID(stashID StashID) ([]*Studio, error) Count() (int, error) All() ([]*Studio, error) // TODO - this interface is temporary until the filter schema can fully diff --git a/pkg/scraper/scrapers.go b/pkg/scraper/scrapers.go index 2a1fc8efc..590991000 100644 --- a/pkg/scraper/scrapers.go +++ b/pkg/scraper/scrapers.go @@ -347,7 +347,7 @@ func (c Cache) postScrapeScene(ctx context.Context, ret *models.ScrapedScene) er return err } - if err := match.ScrapedPerformer(pqb, p); err != nil { + if err := match.ScrapedPerformer(pqb, p, nil); err != nil { return err } } @@ -366,7 +366,7 @@ func (c Cache) postScrapeScene(ctx context.Context, ret *models.ScrapedScene) er ret.Tags = tags if ret.Studio != nil { - err := match.ScrapedStudio(sqb, ret.Studio) + err := match.ScrapedStudio(sqb, ret.Studio, nil) if err != nil { return err } @@ -392,7 +392,7 @@ func (c Cache) postScrapeGallery(ret *models.ScrapedGallery) error { sqb := r.Studio() for _, p := range ret.Performers { - err := match.ScrapedPerformer(pqb, p) + err := match.ScrapedPerformer(pqb, p, nil) if err != nil { return err } @@ -405,7 +405,7 @@ func (c Cache) postScrapeGallery(ret *models.ScrapedGallery) error { ret.Tags = tags if ret.Studio != nil { - err := match.ScrapedStudio(sqb, ret.Studio) + err := match.ScrapedStudio(sqb, ret.Studio, nil) if err != nil { return err } @@ -599,7 +599,7 @@ func (c Cache) ScrapeMovieURL(url string) (*models.ScrapedMovie, error) { if ret.Studio != nil { if err := c.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { - return match.ScrapedStudio(r.Studio(), ret.Studio) + return match.ScrapedStudio(r.Studio(), ret.Studio, nil) }); err != nil { return nil, err } diff --git a/pkg/scraper/stashbox/stash_box.go b/pkg/scraper/stashbox/stash_box.go index 6973a78c8..cd4638809 100644 --- a/pkg/scraper/stashbox/stash_box.go +++ b/pkg/scraper/stashbox/stash_box.go @@ -21,6 +21,7 @@ import ( type Client struct { client *graphql.Client txnManager models.TransactionManager + box models.StashBox } // NewClient returns a new instance of a stash-box client. @@ -36,6 +37,7 @@ func NewClient(box models.StashBox, txnManager models.TransactionManager) *Clien return &Client{ client: client, txnManager: txnManager, + box: box, } } @@ -54,7 +56,7 @@ func (c Client) QueryStashBoxScene(ctx context.Context, queryStr string) ([]*mod var ret []*models.ScrapedScene for _, s := range sceneFragments { - ss, err := sceneFragmentToScrapedScene(context.TODO(), c.getHTTPClient(), c.txnManager, s) + ss, err := c.sceneFragmentToScrapedScene(context.TODO(), s) if err != nil { return nil, err } @@ -217,7 +219,7 @@ func (c Client) findStashBoxScenesByFingerprints(ctx context.Context, fingerprin sceneFragments := scenes.FindScenesByFullFingerprints for _, s := range sceneFragments { - ss, err := sceneFragmentToScrapedScene(ctx, c.getHTTPClient(), c.txnManager, s) + ss, err := c.sceneFragmentToScrapedScene(ctx, s) if err != nil { return nil, err } @@ -633,7 +635,7 @@ func getFingerprints(scene *graphql.SceneFragment) []*models.StashBoxFingerprint return fingerprints } -func sceneFragmentToScrapedScene(ctx context.Context, client *http.Client, txnManager models.TransactionManager, s *graphql.SceneFragment) (*models.ScrapedScene, error) { +func (c Client) sceneFragmentToScrapedScene(ctx context.Context, s *graphql.SceneFragment) (*models.ScrapedScene, error) { stashID := s.ID ss := &models.ScrapedScene{ Title: s.Title, @@ -650,10 +652,10 @@ func sceneFragmentToScrapedScene(ctx context.Context, client *http.Client, txnMa if len(s.Images) > 0 { // TODO - #454 code sorts images by aspect ratio according to a wanted // orientation. I'm just grabbing the first for now - ss.Image = getFirstImage(ctx, client, s.Images) + ss.Image = getFirstImage(ctx, c.getHTTPClient(), s.Images) } - if err := txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { + if err := c.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { pqb := r.Performer() tqb := r.Tag() @@ -665,7 +667,7 @@ func sceneFragmentToScrapedScene(ctx context.Context, client *http.Client, txnMa RemoteSiteID: &studioID, } - err := match.ScrapedStudio(r.Studio(), ss.Studio) + err := match.ScrapedStudio(r.Studio(), ss.Studio, &c.box.Endpoint) if err != nil { return err } @@ -674,7 +676,7 @@ func sceneFragmentToScrapedScene(ctx context.Context, client *http.Client, txnMa for _, p := range s.Performers { sp := performerFragmentToScrapedScenePerformer(p.Performer) - err := match.ScrapedPerformer(pqb, sp) + err := match.ScrapedPerformer(pqb, sp, &c.box.Endpoint) if err != nil { return err } diff --git a/pkg/sqlite/performer.go b/pkg/sqlite/performer.go index 33ad50e43..d33d63539 100644 --- a/pkg/sqlite/performer.go +++ b/pkg/sqlite/performer.go @@ -610,6 +610,16 @@ func (qb *performerQueryBuilder) UpdateStashIDs(performerID int, stashIDs []mode return qb.stashIDRepository().replace(performerID, stashIDs) } +func (qb *performerQueryBuilder) FindByStashID(stashID models.StashID) ([]*models.Performer, error) { + query := selectAll("performers") + ` + LEFT JOIN performer_stash_ids on performer_stash_ids.performer_id = performers.id + WHERE performer_stash_ids.stash_id = ? + AND performer_stash_ids.endpoint = ? + ` + args := []interface{}{stashID.StashID, stashID.Endpoint} + return qb.queryPerformers(query, args) +} + func (qb *performerQueryBuilder) FindByStashIDStatus(hasStashID bool, stashboxEndpoint string) ([]*models.Performer, error) { query := selectAll("performers") + ` LEFT JOIN performer_stash_ids on performer_stash_ids.performer_id = performers.id diff --git a/pkg/sqlite/studio.go b/pkg/sqlite/studio.go index 8198217a6..b0c7745a1 100644 --- a/pkg/sqlite/studio.go +++ b/pkg/sqlite/studio.go @@ -117,6 +117,16 @@ func (qb *studioQueryBuilder) FindByName(name string, nocase bool) (*models.Stud return qb.queryStudio(query, args) } +func (qb *studioQueryBuilder) FindByStashID(stashID models.StashID) ([]*models.Studio, error) { + query := selectAll("studios") + ` + LEFT JOIN studio_stash_ids on studio_stash_ids.studio_id = studios.id + WHERE studio_stash_ids.stash_id = ? + AND studio_stash_ids.endpoint = ? + ` + args := []interface{}{stashID.StashID, stashID.Endpoint} + return qb.queryStudios(query, args) +} + func (qb *studioQueryBuilder) Count() (int, error) { return qb.runCountQuery(qb.buildCountQuery("SELECT studios.id FROM studios"), nil) }