Studio Tagger (#3510)

* Studio image and parent studio support in scene tagger
* Refactor studio backend and add studio tagger
---------
Co-authored-by: WithoutPants <53250216+WithoutPants@users.noreply.github.com>
This commit is contained in:
Flashy78 2023-07-30 16:50:24 -07:00 committed by GitHub
parent d48dbeb864
commit a665a56ef0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
79 changed files with 5224 additions and 1039 deletions

View file

@ -40,6 +40,7 @@ NOTE: The `make` command in Windows will be `mingw32-make` with MinGW. For examp
* `make pre-ui` - Installs the UI dependencies. This only needs to be run once after cloning the repository, or if the dependencies are updated. * `make pre-ui` - Installs the UI dependencies. This only needs to be run once after cloning the repository, or if the dependencies are updated.
* `make generate` - Generates Go and UI GraphQL files. Requires `make pre-ui` to have been run. * `make generate` - Generates Go and UI GraphQL files. Requires `make pre-ui` to have been run.
* `make generate-stash-box-client` - Generate Go files for the Stash-box client code.
* `make ui` - Builds the UI. Requires `make pre-ui` to have been run. * `make ui` - Builds the UI. Requires `make pre-ui` to have been run.
* `make stash` - Builds the `stash` binary (make sure to build the UI as well... see below) * `make stash` - Builds the `stash` binary (make sure to build the UI as well... see below)
* `make stash-release` - Builds a release version the `stash` binary, with debug information removed * `make stash-release` - Builds a release version the `stash` binary, with debug information removed

1
go.mod
View file

@ -10,6 +10,7 @@ require (
github.com/corona10/goimagehash v1.0.3 github.com/corona10/goimagehash v1.0.3
github.com/disintegration/imaging v1.6.0 github.com/disintegration/imaging v1.6.0
github.com/go-chi/chi v4.0.2+incompatible github.com/go-chi/chi v4.0.2+incompatible
github.com/gofrs/uuid v4.4.0+incompatible
github.com/golang-jwt/jwt/v4 v4.0.0 github.com/golang-jwt/jwt/v4 v4.0.0
github.com/golang-migrate/migrate/v4 v4.15.0-beta.1 github.com/golang-migrate/migrate/v4 v4.15.0-beta.1
github.com/gorilla/securecookie v1.1.1 github.com/gorilla/securecookie v1.1.1

2
go.sum
View file

@ -295,6 +295,8 @@ github.com/gocql/gocql v0.0.0-20190301043612-f6df8288f9b4/go.mod h1:4Fw1eo5iaEhD
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
github.com/gofrs/uuid v4.4.0+incompatible h1:3qXRTX8/NbyulANqlc0lchS1gqAVxRgsuW1YrTJupqA=
github.com/gofrs/uuid v4.4.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt/v4 v4.0.0 h1:RAqyYixv1p7uEnocuy8P1nru5wprCh/MH2BIlW5z5/o= github.com/golang-jwt/jwt/v4 v4.0.0 h1:RAqyYixv1p7uEnocuy8P1nru5wprCh/MH2BIlW5z5/o=

View file

@ -74,8 +74,8 @@ models:
model: github.com/stashapp/stash/internal/manager.AutoTagMetadataInput model: github.com/stashapp/stash/internal/manager.AutoTagMetadataInput
CleanMetadataInput: CleanMetadataInput:
model: github.com/stashapp/stash/internal/manager.CleanMetadataInput model: github.com/stashapp/stash/internal/manager.CleanMetadataInput
StashBoxBatchPerformerTagInput: StashBoxBatchTagInput:
model: github.com/stashapp/stash/internal/manager.StashBoxBatchPerformerTagInput model: github.com/stashapp/stash/internal/manager.StashBoxBatchTagInput
SceneStreamEndpoint: SceneStreamEndpoint:
model: github.com/stashapp/stash/internal/manager.SceneStreamEndpoint model: github.com/stashapp/stash/internal/manager.SceneStreamEndpoint
ExportObjectTypeInput: ExportObjectTypeInput:

View file

@ -1,3 +1,18 @@
fragment ScrapedStudioData on ScrapedStudio {
stored_id
name
url
parent {
stored_id
name
url
image
remote_site_id
}
image
remote_site_id
}
fragment ScrapedPerformerData on ScrapedPerformer { fragment ScrapedPerformerData on ScrapedPerformer {
stored_id stored_id
name name
@ -101,6 +116,14 @@ fragment ScrapedSceneStudioData on ScrapedStudio {
stored_id stored_id
name name
url url
parent {
stored_id
name
url
image
remote_site_id
}
image
remote_site_id remote_site_id
} }

View file

@ -4,10 +4,14 @@ mutation SubmitStashBoxFingerprints(
submitStashBoxFingerprints(input: $input) submitStashBoxFingerprints(input: $input)
} }
mutation StashBoxBatchPerformerTag($input: StashBoxBatchPerformerTagInput!) { mutation StashBoxBatchPerformerTag($input: StashBoxBatchTagInput!) {
stashBoxBatchPerformerTag(input: $input) stashBoxBatchPerformerTag(input: $input)
} }
mutation StashBoxBatchStudioTag($input: StashBoxBatchTagInput!) {
stashBoxBatchStudioTag(input: $input)
}
mutation SubmitStashBoxSceneDraft($input: StashBoxDraftSubmissionInput!) { mutation SubmitStashBoxSceneDraft($input: StashBoxDraftSubmissionInput!) {
submitStashBoxSceneDraft(input: $input) submitStashBoxSceneDraft(input: $input)
} }

View file

@ -42,6 +42,15 @@ query ListMovieScrapers {
} }
} }
query ScrapeSingleStudio(
$source: ScraperSourceInput!
$input: ScrapeSingleStudioInput!
) {
scrapeSingleStudio(source: $source, input: $input) {
...ScrapedStudioData
}
}
query ScrapeSinglePerformer( query ScrapeSinglePerformer(
$source: ScraperSourceInput! $source: ScraperSourceInput!
$input: ScrapeSinglePerformerInput! $input: ScrapeSinglePerformerInput!

View file

@ -128,6 +128,12 @@ type Query {
input: ScrapeMultiScenesInput! input: ScrapeMultiScenesInput!
): [[ScrapedScene!]!]! ): [[ScrapedScene!]!]!
"Scrape for a single studio"
scrapeSingleStudio(
source: ScraperSourceInput!
input: ScrapeSingleStudioInput!
): [ScrapedStudio!]!
"Scrape for a single performer" "Scrape for a single performer"
scrapeSinglePerformer( scrapeSinglePerformer(
source: ScraperSourceInput! source: ScraperSourceInput!
@ -416,7 +422,9 @@ type Mutation {
execSQL(sql: String!, args: [Any]): SQLExecResult! execSQL(sql: String!, args: [Any]): SQLExecResult!
"Run batch performer tag task. Returns the job ID." "Run batch performer tag task. Returns the job ID."
stashBoxBatchPerformerTag(input: StashBoxBatchPerformerTagInput!): String! stashBoxBatchPerformerTag(input: StashBoxBatchTagInput!): String!
"Run batch studio tag task. Returns the job ID."
stashBoxBatchStudioTag(input: StashBoxBatchTagInput!): String!
"Enables DLNA for an optional duration. Has no effect if DLNA is enabled by default" "Enables DLNA for an optional duration. Has no effect if DLNA is enabled by default"
enableDLNA(input: EnableDLNAInput!): Boolean! enableDLNA(input: EnableDLNAInput!): Boolean!

View file

@ -48,6 +48,7 @@ type ScrapedStudio {
stored_id: ID stored_id: ID
name: String! name: String!
url: String url: String
parent: ScrapedStudio
image: String image: String
remote_site_id: String remote_site_id: String
@ -148,6 +149,13 @@ input ScrapeMultiScenesInput {
scene_ids: [ID!] scene_ids: [ID!]
} }
input ScrapeSingleStudioInput {
"""
Query can be either a name or a Stash ID
"""
query: String
}
input ScrapeSinglePerformerInput { input ScrapeSinglePerformerInput {
"Instructs to query by string" "Instructs to query by string"
query: String query: String
@ -209,16 +217,22 @@ type StashBoxFingerprint {
duration: Int! duration: Int!
} }
"If neither performer_ids nor performer_names are set, tag all performers" "If neither ids nor names are set, tag all items"
input StashBoxBatchPerformerTagInput { input StashBoxBatchTagInput {
"Stash endpoint to use for the performer tagging" "Stash endpoint to use for the tagging"
endpoint: Int! endpoint: Int!
"Fields to exclude when executing the performer tagging" "Fields to exclude when executing the tagging"
exclude_fields: [String!] exclude_fields: [String!]
"Refresh performers already tagged by StashBox if true. Only tag performers with no StashBox tagging if false" "Refresh items already tagged by StashBox if true. Only tag items with no StashBox tagging if false"
refresh: Boolean! refresh: Boolean!
"If batch adding studios, should their parent studios also be created?"
createParent: Boolean!
"If set, only tag these ids"
ids: [ID!]
"If set, only tag these names"
names: [String!]
"If set, only tag these performer ids" "If set, only tag these performer ids"
performer_ids: [ID!] performer_ids: [ID!] @deprecated(reason: "use ids")
"If set, only tag these performer names" "If set, only tag these performer names"
performer_names: [String!] performer_names: [String!] @deprecated(reason: "use names")
} }

View file

@ -16,6 +16,10 @@ fragment StudioFragment on Studio {
urls { urls {
...URLFragment ...URLFragment
} }
parent {
name
id
}
images { images {
...ImageFragment ...ImageFragment
} }
@ -163,6 +167,12 @@ query FindSceneByID($id: ID!) {
} }
} }
query FindStudio($id: ID, $name: String) {
findStudio(id: $id, name: $name) {
...StudioFragment
}
}
mutation SubmitFingerprint($input: FingerprintSubmission!) { mutation SubmitFingerprint($input: FingerprintSubmission!) {
submitFingerprint(input: $input) submitFingerprint(input: $input)
} }

View file

@ -34,15 +34,16 @@ func (r *studioResolver) ImagePath(ctx context.Context, obj *models.Studio) (*st
return &imagePath, nil return &imagePath, nil
} }
func (r *studioResolver) Aliases(ctx context.Context, obj *models.Studio) (ret []string, err error) { func (r *studioResolver) Aliases(ctx context.Context, obj *models.Studio) ([]string, error) {
if err := r.withReadTxn(ctx, func(ctx context.Context) error { if !obj.Aliases.Loaded() {
ret, err = r.repository.Studio.GetAliases(ctx, obj.ID) if err := r.withReadTxn(ctx, func(ctx context.Context) error {
return err return obj.LoadAliases(ctx, r.repository.Studio)
}); err != nil { }); err != nil {
return nil, err return nil, err
}
} }
return ret, err return obj.Aliases.List(), nil
} }
func (r *studioResolver) SceneCount(ctx context.Context, obj *models.Studio, depth *int) (ret int, err error) { func (r *studioResolver) SceneCount(ctx context.Context, obj *models.Studio, depth *int) (ret int, err error) {
@ -120,16 +121,15 @@ func (r *studioResolver) ChildStudios(ctx context.Context, obj *models.Studio) (
} }
func (r *studioResolver) StashIds(ctx context.Context, obj *models.Studio) ([]*models.StashID, error) { func (r *studioResolver) StashIds(ctx context.Context, obj *models.Studio) ([]*models.StashID, error) {
var ret []models.StashID if !obj.StashIDs.Loaded() {
if err := r.withReadTxn(ctx, func(ctx context.Context) error { if err := r.withReadTxn(ctx, func(ctx context.Context) error {
var err error return obj.LoadStashIDs(ctx, r.repository.Studio)
ret, err = r.repository.Studio.GetStashIDs(ctx, obj.ID) }); err != nil {
return err return nil, err
}); err != nil { }
return nil, err
} }
return stashIDsSliceToPtrSlice(ret), nil return stashIDsSliceToPtrSlice(obj.StashIDs.List()), nil
} }
func (r *studioResolver) Rating(ctx context.Context, obj *models.Studio) (*int, error) { func (r *studioResolver) Rating(ctx context.Context, obj *models.Studio) (*int, error) {

View file

@ -32,11 +32,16 @@ func (r *mutationResolver) SubmitStashBoxFingerprints(ctx context.Context, input
return client.SubmitStashBoxFingerprints(ctx, input.SceneIds, boxes[input.StashBoxIndex].Endpoint) return client.SubmitStashBoxFingerprints(ctx, input.SceneIds, boxes[input.StashBoxIndex].Endpoint)
} }
func (r *mutationResolver) StashBoxBatchPerformerTag(ctx context.Context, input manager.StashBoxBatchPerformerTagInput) (string, error) { func (r *mutationResolver) StashBoxBatchPerformerTag(ctx context.Context, input manager.StashBoxBatchTagInput) (string, error) {
jobID := manager.GetInstance().StashBoxBatchPerformerTag(ctx, input) jobID := manager.GetInstance().StashBoxBatchPerformerTag(ctx, input)
return strconv.Itoa(jobID), nil return strconv.Itoa(jobID), nil
} }
func (r *mutationResolver) StashBoxBatchStudioTag(ctx context.Context, input manager.StashBoxBatchTagInput) (string, error) {
jobID := manager.GetInstance().StashBoxBatchStudioTag(ctx, input)
return strconv.Itoa(jobID), nil
}
func (r *mutationResolver) SubmitStashBoxSceneDraft(ctx context.Context, input StashBoxDraftSubmissionInput) (*string, error) { func (r *mutationResolver) SubmitStashBoxSceneDraft(ctx context.Context, input StashBoxDraftSubmissionInput) (*string, error) {
boxes := config.GetInstance().GetStashBoxes() boxes := config.GetInstance().GetStashBoxes()

View file

@ -6,7 +6,6 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/stashapp/stash/internal/manager"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/plugin" "github.com/stashapp/stash/pkg/plugin"
"github.com/stashapp/stash/pkg/sliceutil/stringslice" "github.com/stashapp/stash/pkg/sliceutil/stringslice"
@ -14,18 +13,54 @@ import (
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
func (r *mutationResolver) getStudio(ctx context.Context, id int) (ret *models.Studio, err error) { func (r *mutationResolver) StudioCreate(ctx context.Context, input StudioCreateInput) (*models.Studio, error) {
s, err := studioFromStudioCreateInput(ctx, input)
if err != nil {
return nil, err
}
// Process the base 64 encoded image string
var imageData []byte
if input.Image != nil {
var err error
imageData, err = utils.ProcessImageInput(ctx, *input.Image)
if err != nil {
return nil, err
}
}
// Start the transaction and save the studio
if err := r.withTxn(ctx, func(ctx context.Context) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Studio.Find(ctx, id) qb := r.repository.Studio
return err
if s.Aliases.Loaded() && len(s.Aliases.List()) > 0 {
if err := studio.EnsureAliasesUnique(ctx, 0, s.Aliases.List(), qb); err != nil {
return err
}
}
err = qb.Create(ctx, s)
if err != nil {
return err
}
if len(imageData) > 0 {
if err := qb.UpdateImage(ctx, s.ID, imageData); err != nil {
return err
}
}
return nil
}); err != nil { }); err != nil {
return nil, err return nil, err
} }
return ret, nil r.hookExecutor.ExecutePostHooks(ctx, s.ID, plugin.StudioCreatePost, input, nil)
return s, nil
} }
func (r *mutationResolver) StudioCreate(ctx context.Context, input StudioCreateInput) (*models.Studio, error) { func studioFromStudioCreateInput(ctx context.Context, input StudioCreateInput) (*models.Studio, error) {
translator := changesetTranslator{ translator := changesetTranslator{
inputMap: getUpdateInputMap(ctx), inputMap: getUpdateInputMap(ctx),
} }
@ -43,143 +78,110 @@ func (r *mutationResolver) StudioCreate(ctx context.Context, input StudioCreateI
} }
var err error var err error
newStudio.ParentID, err = translator.intPtrFromString(input.ParentID, "parent_id") newStudio.ParentID, err = translator.intPtrFromString(input.ParentID, "parent_id")
if err != nil { if err != nil {
return nil, fmt.Errorf("converting parent id: %w", err) return nil, fmt.Errorf("converting parent id: %w", err)
} }
// Process the base 64 encoded image string if input.Aliases != nil {
var imageData []byte newStudio.Aliases = models.NewRelatedStrings(input.Aliases)
if input.Image != nil { }
imageData, err = utils.ProcessImageInput(ctx, *input.Image) if input.StashIds != nil {
if err != nil { newStudio.StashIDs = models.NewRelatedStashIDs(stashIDPtrSliceToSlice(input.StashIds))
return nil, err
}
} }
// Start the transaction and save the studio return &newStudio, nil
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Studio
err = qb.Create(ctx, &newStudio)
if err != nil {
return err
}
// update image table
if len(imageData) > 0 {
if err := qb.UpdateImage(ctx, newStudio.ID, imageData); err != nil {
return err
}
}
// Save the stash_ids
if input.StashIds != nil {
stashIDJoins := stashIDPtrSliceToSlice(input.StashIds)
if err := qb.UpdateStashIDs(ctx, newStudio.ID, stashIDJoins); err != nil {
return err
}
}
if len(input.Aliases) > 0 {
if err := studio.EnsureAliasesUnique(ctx, newStudio.ID, input.Aliases, qb); err != nil {
return err
}
if err := qb.UpdateAliases(ctx, newStudio.ID, input.Aliases); err != nil {
return err
}
}
return nil
}); err != nil {
return nil, err
}
r.hookExecutor.ExecutePostHooks(ctx, newStudio.ID, plugin.StudioCreatePost, input, nil)
return r.getStudio(ctx, newStudio.ID)
} }
func (r *mutationResolver) StudioUpdate(ctx context.Context, input StudioUpdateInput) (*models.Studio, error) { func (r *mutationResolver) StudioUpdate(ctx context.Context, input StudioUpdateInput) (*models.Studio, error) {
studioID, err := strconv.Atoi(input.ID) var updatedStudio *models.Studio
if err != nil { var err error
return nil, err
}
translator := changesetTranslator{ translator := changesetTranslator{
inputMap: getUpdateInputMap(ctx), inputMap: getNamedUpdateInputMap(ctx, updateInputField),
}
// Populate studio from the input
updatedStudio := models.NewStudioPartial()
updatedStudio.Name = translator.optionalString(input.Name, "name")
updatedStudio.URL = translator.optionalString(input.URL, "url")
updatedStudio.Details = translator.optionalString(input.Details, "details")
updatedStudio.Rating = translator.ratingConversionOptional(input.Rating, input.Rating100)
updatedStudio.IgnoreAutoTag = translator.optionalBool(input.IgnoreAutoTag, "ignore_auto_tag")
updatedStudio.ParentID, err = translator.optionalIntFromString(input.ParentID, "parent_id")
if err != nil {
return nil, fmt.Errorf("converting parent id: %w", err)
} }
s := studioPartialFromStudioUpdateInput(input, &input.ID, translator)
// Process the base 64 encoded image string
var imageData []byte var imageData []byte
imageIncluded := translator.hasField("image") imageIncluded := translator.hasField("image")
if input.Image != nil { if input.Image != nil {
var err error
imageData, err = utils.ProcessImageInput(ctx, *input.Image) imageData, err = utils.ProcessImageInput(ctx, *input.Image)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
// Start the transaction and save the studio // Start the transaction and update the studio
var s *models.Studio
if err := r.withTxn(ctx, func(ctx context.Context) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Studio qb := r.repository.Studio
if err := manager.ValidateModifyStudio(ctx, studioID, updatedStudio, qb); err != nil { if err := studio.ValidateModify(ctx, *s, qb); err != nil {
return err return err
} }
var err error updatedStudio, err = qb.UpdatePartial(ctx, *s)
s, err = qb.UpdatePartial(ctx, studioID, updatedStudio)
if err != nil { if err != nil {
return err return err
} }
// update image table
if imageIncluded { if imageIncluded {
if err := qb.UpdateImage(ctx, s.ID, imageData); err != nil { if err := qb.UpdateImage(ctx, s.ID, imageData); err != nil {
return err return err
} }
} }
// Save the stash_ids
if translator.hasField("stash_ids") {
stashIDJoins := stashIDPtrSliceToSlice(input.StashIds)
if err := qb.UpdateStashIDs(ctx, studioID, stashIDJoins); err != nil {
return err
}
}
if translator.hasField("aliases") {
if err := studio.EnsureAliasesUnique(ctx, studioID, input.Aliases, qb); err != nil {
return err
}
if err := qb.UpdateAliases(ctx, studioID, input.Aliases); err != nil {
return err
}
}
return nil return nil
}); err != nil { }); err != nil {
return nil, err return nil, err
} }
r.hookExecutor.ExecutePostHooks(ctx, s.ID, plugin.StudioUpdatePost, input, translator.getFields()) r.hookExecutor.ExecutePostHooks(ctx, updatedStudio.ID, plugin.StudioUpdatePost, input, translator.getFields())
return r.getStudio(ctx, s.ID)
return updatedStudio, nil
}
// This is slightly different to studioPartialFromStudioCreateInput in that Name is handled differently
// and ImageIncluded is not hardcoded to true
func studioPartialFromStudioUpdateInput(input StudioUpdateInput, id *string, translator changesetTranslator) *models.StudioPartial {
// Populate studio from the input
updatedStudio := models.StudioPartial{
Name: translator.optionalString(input.Name, "name"),
URL: translator.optionalString(input.URL, "url"),
Details: translator.optionalString(input.Details, "details"),
Rating: translator.ratingConversionOptional(input.Rating, input.Rating100),
IgnoreAutoTag: translator.optionalBool(input.IgnoreAutoTag, "ignore_auto_tag"),
UpdatedAt: models.NewOptionalTime(time.Now()),
}
updatedStudio.ID, _ = strconv.Atoi(*id)
if input.ParentID != nil {
parentID, _ := strconv.Atoi(*input.ParentID)
if parentID > 0 {
// This is to be set directly as we know it has a value and the translator won't have the field
updatedStudio.ParentID = models.NewOptionalInt(parentID)
}
} else {
updatedStudio.ParentID = translator.optionalInt(nil, "parent_id")
}
if translator.hasField("aliases") {
updatedStudio.Aliases = &models.UpdateStrings{
Values: input.Aliases,
Mode: models.RelationshipUpdateModeSet,
}
}
if translator.hasField("stash_ids") {
updatedStudio.StashIDs = &models.UpdateStashIDs{
StashIDs: stashIDPtrSliceToSlice(input.StashIds),
Mode: models.RelationshipUpdateModeSet,
}
}
return &updatedStudio
} }
func (r *mutationResolver) StudioDestroy(ctx context.Context, input StudioDestroyInput) (bool, error) { func (r *mutationResolver) StudioDestroy(ctx context.Context, input StudioDestroyInput) (bool, error) {

View file

@ -327,6 +327,32 @@ func (r *queryResolver) ScrapeMultiScenes(ctx context.Context, source scraper.So
return nil, errors.New("scraper_id or stash_box_index must be set") return nil, errors.New("scraper_id or stash_box_index must be set")
} }
func (r *queryResolver) ScrapeSingleStudio(ctx context.Context, source scraper.Source, input ScrapeSingleStudioInput) ([]*models.ScrapedStudio, error) {
if source.StashBoxIndex != nil {
client, err := r.getStashBoxClient(*source.StashBoxIndex)
if err != nil {
return nil, err
}
var ret []*models.ScrapedStudio
out, err := client.FindStashBoxStudio(ctx, *input.Query)
if err != nil {
return nil, err
} else if out != nil {
ret = append(ret, out)
}
if len(ret) > 0 {
return ret, nil
}
return nil, nil
}
return nil, errors.New("stash_box_index must be set")
}
func (r *queryResolver) ScrapeSinglePerformer(ctx context.Context, source scraper.Source, input ScrapeSinglePerformerInput) ([]*models.ScrapedPerformer, error) { func (r *queryResolver) ScrapeSinglePerformer(ctx context.Context, source scraper.Source, input ScrapeSinglePerformerInput) ([]*models.ScrapedPerformer, error) {
if source.ScraperID != nil { if source.ScraperID != nil {
if input.PerformerInput != nil { if input.PerformerInput != nil {

View file

@ -1,8 +1,9 @@
package urlbuilders package urlbuilders
import ( import (
"github.com/stashapp/stash/pkg/models"
"strconv" "strconv"
"github.com/stashapp/stash/pkg/models"
) )
type StudioURLBuilder struct { type StudioURLBuilder struct {

View file

@ -44,7 +44,7 @@ type ScraperSource struct {
type SceneIdentifier struct { type SceneIdentifier struct {
SceneReaderUpdater SceneReaderUpdater SceneReaderUpdater SceneReaderUpdater
StudioCreator StudioCreator StudioReaderWriter models.StudioReaderWriter
PerformerCreator PerformerCreator PerformerCreator PerformerCreator
TagCreatorFinder TagCreatorFinder TagCreatorFinder TagCreatorFinder
@ -174,7 +174,7 @@ func (t *SceneIdentifier) getSceneUpdater(ctx context.Context, s *models.Scene,
rel := sceneRelationships{ rel := sceneRelationships{
sceneReader: t.SceneReaderUpdater, sceneReader: t.SceneReaderUpdater,
studioCreator: t.StudioCreator, studioReaderWriter: t.StudioReaderWriter,
performerCreator: t.PerformerCreator, performerCreator: t.PerformerCreator,
tagCreatorFinder: t.TagCreatorFinder, tagCreatorFinder: t.TagCreatorFinder,
scene: s, scene: s,

View file

@ -34,7 +34,7 @@ type TagCreatorFinder interface {
type sceneRelationships struct { type sceneRelationships struct {
sceneReader SceneReaderUpdater sceneReader SceneReaderUpdater
studioCreator StudioCreator studioReaderWriter models.StudioReaderWriter
performerCreator PerformerCreator performerCreator PerformerCreator
tagCreatorFinder TagCreatorFinder tagCreatorFinder TagCreatorFinder
scene *models.Scene scene *models.Scene
@ -67,7 +67,7 @@ func (g sceneRelationships) studio(ctx context.Context) (*int, error) {
return &studioID, nil return &studioID, nil
} }
} else if createMissing { } else if createMissing {
return createMissingStudio(ctx, endpoint, g.studioCreator, scraped) return createMissingStudio(ctx, endpoint, g.studioReaderWriter, scraped)
} }
return nil, nil return nil, nil

View file

@ -16,6 +16,7 @@ import (
func Test_sceneRelationships_studio(t *testing.T) { func Test_sceneRelationships_studio(t *testing.T) {
validStoredID := "1" validStoredID := "1"
remoteSiteID := "2"
var validStoredIDInt = 1 var validStoredIDInt = 1
invalidStoredID := "invalidStoredID" invalidStoredID := "invalidStoredID"
createMissing := true createMissing := true
@ -31,8 +32,8 @@ func Test_sceneRelationships_studio(t *testing.T) {
}).Return(nil) }).Return(nil)
tr := sceneRelationships{ tr := sceneRelationships{
studioCreator: mockStudioReaderWriter, studioReaderWriter: mockStudioReaderWriter,
fieldOptions: make(map[string]*FieldOptions), fieldOptions: make(map[string]*FieldOptions),
} }
tests := []struct { tests := []struct {
@ -110,7 +111,7 @@ func Test_sceneRelationships_studio(t *testing.T) {
Strategy: FieldStrategyMerge, Strategy: FieldStrategyMerge,
CreateMissing: &createMissing, CreateMissing: &createMissing,
}, },
&models.ScrapedStudio{}, &models.ScrapedStudio{RemoteSiteID: &remoteSiteID},
&validStoredIDInt, &validStoredIDInt,
false, false,
}, },
@ -120,6 +121,9 @@ func Test_sceneRelationships_studio(t *testing.T) {
tr.scene = tt.scene tr.scene = tt.scene
tr.fieldOptions["studio"] = tt.fieldOptions tr.fieldOptions["studio"] = tt.fieldOptions
tr.result = &scrapeResult{ tr.result = &scrapeResult{
source: ScraperSource{
RemoteSite: "endpoint",
},
result: &scraper.ScrapedScene{ result: &scraper.ScrapedScene{
Studio: tt.result, Studio: tt.result,
}, },

View file

@ -2,64 +2,95 @@ package identify
import ( import (
"context" "context"
"fmt" "strconv"
"time"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/studio"
) )
type StudioCreator interface { func createMissingStudio(ctx context.Context, endpoint string, w models.StudioReaderWriter, s *models.ScrapedStudio) (*int, error) {
Create(ctx context.Context, newStudio *models.Studio) error var err error
UpdateStashIDs(ctx context.Context, studioID int, stashIDs []models.StashID) error
UpdateImage(ctx context.Context, studioID int, image []byte) error
}
func createMissingStudio(ctx context.Context, endpoint string, w StudioCreator, studio *models.ScrapedStudio) (*int, error) { if s.Parent != nil {
studioInput := scrapedToStudioInput(studio) if s.Parent.StoredID == nil {
err := w.Create(ctx, &studioInput) // The parent needs to be created
newParentStudio := s.Parent.ToStudio(endpoint, nil)
parentImage, err := s.Parent.GetImage(ctx, nil)
if err != nil {
logger.Errorf("Failed to make parent studio from scraped studio %s: %s", s.Parent.Name, err.Error())
return nil, err
}
// Create the studio
err = w.Create(ctx, newParentStudio)
if err != nil {
return nil, err
}
// Update image table
if len(parentImage) > 0 {
if err := w.UpdateImage(ctx, newParentStudio.ID, parentImage); err != nil {
return nil, err
}
}
storedId := strconv.Itoa(newParentStudio.ID)
s.Parent.StoredID = &storedId
} else {
// The parent studio matched an existing one and the user has chosen in the UI to link and/or update it
existingStashIDs := getStashIDsForStudio(ctx, *s.Parent.StoredID, w)
studioPartial := s.Parent.ToPartial(s.Parent.StoredID, endpoint, nil, existingStashIDs)
parentImage, err := s.Parent.GetImage(ctx, nil)
if err != nil {
return nil, err
}
if err := studio.ValidateModify(ctx, *studioPartial, w); err != nil {
return nil, err
}
_, err = w.UpdatePartial(ctx, *studioPartial)
if err != nil {
return nil, err
}
if len(parentImage) > 0 {
if err := w.UpdateImage(ctx, studioPartial.ID, parentImage); err != nil {
return nil, err
}
}
}
}
newStudio := s.ToStudio(endpoint, nil)
studioImage, err := s.GetImage(ctx, nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("error creating studio: %w", err) return nil, err
} }
// update image table err = w.Create(ctx, newStudio)
if studio.Image != nil && len(*studio.Image) > 0 { if err != nil {
imageData, err := utils.ReadImageFromURL(ctx, *studio.Image) return nil, err
if err != nil { }
return nil, err
}
err = w.UpdateImage(ctx, studioInput.ID, imageData) // Update image table
if err != nil { if len(studioImage) > 0 {
if err := w.UpdateImage(ctx, newStudio.ID, studioImage); err != nil {
return nil, err return nil, err
} }
} }
if endpoint != "" && studio.RemoteSiteID != nil { return &newStudio.ID, nil
if err := w.UpdateStashIDs(ctx, studioInput.ID, []models.StashID{
{
Endpoint: endpoint,
StashID: *studio.RemoteSiteID,
},
}); err != nil {
return nil, fmt.Errorf("error setting studio stash id: %w", err)
}
}
return &studioInput.ID, nil
} }
func scrapedToStudioInput(studio *models.ScrapedStudio) models.Studio { func getStashIDsForStudio(ctx context.Context, studioID string, w models.StudioReaderWriter) []models.StashID {
currentTime := time.Now() id, _ := strconv.Atoi(studioID)
ret := models.Studio{ tempStudio := &models.Studio{ID: id}
Name: studio.Name,
CreatedAt: currentTime,
UpdatedAt: currentTime,
}
if studio.URL != nil { err := tempStudio.LoadStashIDs(ctx, w)
ret.URL = *studio.URL if err != nil {
return nil
} }
return tempStudio.StashIDs.List()
return ret
} }

View file

@ -4,7 +4,6 @@ import (
"errors" "errors"
"reflect" "reflect"
"testing" "testing"
"time"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/mocks" "github.com/stashapp/stash/pkg/models/mocks"
@ -31,18 +30,32 @@ func Test_createMissingStudio(t *testing.T) {
return p.Name == invalidName return p.Name == invalidName
})).Return(errors.New("error creating studio")) })).Return(errors.New("error creating studio"))
mockStudioReaderWriter.On("UpdateStashIDs", testCtx, createdID, []models.StashID{ mockStudioReaderWriter.On("UpdatePartial", testCtx, models.StudioPartial{
{ ID: createdID,
Endpoint: invalidEndpoint, StashIDs: &models.UpdateStashIDs{
StashID: remoteSiteID, StashIDs: []models.StashID{
{
Endpoint: invalidEndpoint,
StashID: remoteSiteID,
},
},
Mode: models.RelationshipUpdateModeSet,
}, },
}).Return(errors.New("error updating stash ids")) }).Return(nil, errors.New("error updating stash ids"))
mockStudioReaderWriter.On("UpdateStashIDs", testCtx, createdID, []models.StashID{ mockStudioReaderWriter.On("UpdatePartial", testCtx, models.StudioPartial{
{ ID: createdID,
Endpoint: validEndpoint, StashIDs: &models.UpdateStashIDs{
StashID: remoteSiteID, StashIDs: []models.StashID{
{
Endpoint: validEndpoint,
StashID: remoteSiteID,
},
},
Mode: models.RelationshipUpdateModeSet,
}, },
}).Return(nil) }).Return(models.Studio{
ID: createdID,
}, nil)
type args struct { type args struct {
endpoint string endpoint string
@ -59,7 +72,8 @@ func Test_createMissingStudio(t *testing.T) {
args{ args{
emptyEndpoint, emptyEndpoint,
&models.ScrapedStudio{ &models.ScrapedStudio{
Name: validName, Name: validName,
RemoteSiteID: &remoteSiteID,
}, },
}, },
&createdID, &createdID,
@ -70,7 +84,8 @@ func Test_createMissingStudio(t *testing.T) {
args{ args{
emptyEndpoint, emptyEndpoint,
&models.ScrapedStudio{ &models.ScrapedStudio{
Name: invalidName, Name: invalidName,
RemoteSiteID: &remoteSiteID,
}, },
}, },
nil, nil,
@ -88,18 +103,6 @@ func Test_createMissingStudio(t *testing.T) {
&createdID, &createdID,
false, false,
}, },
{
"invalid stash id",
args{
invalidEndpoint,
&models.ScrapedStudio{
Name: validName,
RemoteSiteID: &remoteSiteID,
},
},
nil,
true,
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -114,48 +117,3 @@ func Test_createMissingStudio(t *testing.T) {
}) })
} }
} }
func Test_scrapedToStudioInput(t *testing.T) {
const name = "name"
url := "url"
tests := []struct {
name string
studio *models.ScrapedStudio
want models.Studio
}{
{
"set all",
&models.ScrapedStudio{
Name: name,
URL: &url,
},
models.Studio{
Name: name,
URL: url,
},
},
{
"set none",
&models.ScrapedStudio{
Name: name,
},
models.Studio{
Name: name,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := scrapedToStudioInput(tt.studio)
// clear created/updated dates
got.CreatedAt = time.Time{}
got.UpdatedAt = got.CreatedAt
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("scrapedToStudioInput() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -321,21 +321,31 @@ func (s *Manager) MigrateHash(ctx context.Context) int {
return s.JobManager.Add(ctx, "Migrating scene hashes...", j) return s.JobManager.Add(ctx, "Migrating scene hashes...", j)
} }
// If neither performer_ids nor performer_names are set, tag all performers // If neither ids nor names are set, tag all items
type StashBoxBatchPerformerTagInput struct { type StashBoxBatchTagInput struct {
// Stash endpoint to use for the performer tagging // Stash endpoint to use for the tagging
Endpoint int `json:"endpoint"` Endpoint int `json:"endpoint"`
// Fields to exclude when executing the performer tagging // Fields to exclude when executing the tagging
ExcludeFields []string `json:"exclude_fields"` ExcludeFields []string `json:"exclude_fields"`
// Refresh performers already tagged by StashBox if true. Only tag performers with no StashBox tagging if false // Refresh items already tagged by StashBox if true. Only tag items with no StashBox tagging if false
Refresh bool `json:"refresh"` Refresh bool `json:"refresh"`
// If batch adding studios, should their parent studios also be created?
CreateParent bool `json:"createParent"`
// If set, only tag these ids
Ids []string `json:"ids"`
// If set, only tag these names
Names []string `json:"names"`
// If set, only tag these performer ids // If set, only tag these performer ids
//
// Deprecated: please use Ids
PerformerIds []string `json:"performer_ids"` PerformerIds []string `json:"performer_ids"`
// If set, only tag these performer names // If set, only tag these performer names
//
// Deprecated: please use Names
PerformerNames []string `json:"performer_names"` PerformerNames []string `json:"performer_names"`
} }
func (s *Manager) StashBoxBatchPerformerTag(ctx context.Context, input StashBoxBatchPerformerTagInput) int { func (s *Manager) StashBoxBatchPerformerTag(ctx context.Context, input StashBoxBatchTagInput) int {
j := job.MakeJobExec(func(ctx context.Context, progress *job.Progress) { j := job.MakeJobExec(func(ctx context.Context, progress *job.Progress) {
logger.Infof("Initiating stash-box batch performer tag") logger.Infof("Initiating stash-box batch performer tag")
@ -346,7 +356,7 @@ func (s *Manager) StashBoxBatchPerformerTag(ctx context.Context, input StashBoxB
} }
box := boxes[input.Endpoint] box := boxes[input.Endpoint]
var tasks []StashBoxPerformerTagTask var tasks []StashBoxBatchTagTask
// The gocritic linter wants to turn this ifElseChain into a switch. // The gocritic linter wants to turn this ifElseChain into a switch.
// however, such a switch would contain quite large blocks for each section // however, such a switch would contain quite large blocks for each section
@ -354,24 +364,35 @@ func (s *Manager) StashBoxBatchPerformerTag(ctx context.Context, input StashBoxB
// //
// This is why we mark this section nolint. In principle, we should look to // This is why we mark this section nolint. In principle, we should look to
// rewrite the section at some point, to avoid the linter warning. // rewrite the section at some point, to avoid the linter warning.
if len(input.PerformerIds) > 0 { //nolint:gocritic if len(input.Ids) > 0 || len(input.PerformerIds) > 0 { //nolint:gocritic
// The user has chosen only to tag the items on the current page
if err := s.Repository.WithTxn(ctx, func(ctx context.Context) error { if err := s.Repository.WithTxn(ctx, func(ctx context.Context) error {
performerQuery := s.Repository.Performer performerQuery := s.Repository.Performer
for _, performerID := range input.PerformerIds { idsToUse := input.PerformerIds
if len(input.Ids) > 0 {
idsToUse = input.Ids
}
for _, performerID := range idsToUse {
if id, err := strconv.Atoi(performerID); err == nil { if id, err := strconv.Atoi(performerID); err == nil {
performer, err := performerQuery.Find(ctx, id) performer, err := performerQuery.Find(ctx, id)
if err == nil { if err == nil {
err = performer.LoadStashIDs(ctx, performerQuery) if err := performer.LoadStashIDs(ctx, performerQuery); err != nil {
} return fmt.Errorf("loading performer stash ids: %w", err)
}
if err == nil { // Check if the user wants to refresh existing or new items
tasks = append(tasks, StashBoxPerformerTagTask{ if (input.Refresh && len(performer.StashIDs.List()) > 0) ||
performer: performer, (!input.Refresh && len(performer.StashIDs.List()) == 0) {
refresh: input.Refresh, tasks = append(tasks, StashBoxBatchTagTask{
box: box, performer: performer,
excluded_fields: input.ExcludeFields, refresh: input.Refresh,
}) box: box,
excludedFields: input.ExcludeFields,
taskType: Performer,
})
}
} else { } else {
return err return err
} }
@ -381,14 +402,25 @@ func (s *Manager) StashBoxBatchPerformerTag(ctx context.Context, input StashBoxB
}); err != nil { }); err != nil {
logger.Error(err.Error()) logger.Error(err.Error())
} }
} else if len(input.PerformerNames) > 0 { } else if len(input.Names) > 0 || len(input.PerformerNames) > 0 {
for i := range input.PerformerNames { // The user is batch adding performers
if len(input.PerformerNames[i]) > 0 { namesToUse := input.PerformerNames
tasks = append(tasks, StashBoxPerformerTagTask{ if len(input.Names) > 0 {
name: &input.PerformerNames[i], namesToUse = input.Names
refresh: input.Refresh, }
box: box,
excluded_fields: input.ExcludeFields, for i := range namesToUse {
if len(namesToUse[i]) > 0 {
performer := models.Performer{
Name: namesToUse[i],
}
tasks = append(tasks, StashBoxBatchTagTask{
performer: &performer,
refresh: false,
box: box,
excludedFields: input.ExcludeFields,
taskType: Performer,
}) })
} }
} }
@ -397,6 +429,8 @@ func (s *Manager) StashBoxBatchPerformerTag(ctx context.Context, input StashBoxB
// However, this doesn't really help with readability of the current section. Mark it // However, this doesn't really help with readability of the current section. Mark it
// as nolint for now. In the future we'd like to rewrite this code by factoring some of // as nolint for now. In the future we'd like to rewrite this code by factoring some of
// this into separate functions. // this into separate functions.
// The user has chosen to tag every item in their database
if err := s.Repository.WithTxn(ctx, func(ctx context.Context) error { if err := s.Repository.WithTxn(ctx, func(ctx context.Context) error {
performerQuery := s.Repository.Performer performerQuery := s.Repository.Performer
var performers []*models.Performer var performers []*models.Performer
@ -406,6 +440,7 @@ func (s *Manager) StashBoxBatchPerformerTag(ctx context.Context, input StashBoxB
} else { } else {
performers, err = performerQuery.FindByStashIDStatus(ctx, false, box.Endpoint) performers, err = performerQuery.FindByStashIDStatus(ctx, false, box.Endpoint)
} }
if err != nil { if err != nil {
return fmt.Errorf("error querying performers: %v", err) return fmt.Errorf("error querying performers: %v", err)
} }
@ -415,11 +450,12 @@ func (s *Manager) StashBoxBatchPerformerTag(ctx context.Context, input StashBoxB
return fmt.Errorf("error loading stash ids for performer %s: %v", performer.Name, err) return fmt.Errorf("error loading stash ids for performer %s: %v", performer.Name, err)
} }
tasks = append(tasks, StashBoxPerformerTagTask{ tasks = append(tasks, StashBoxBatchTagTask{
performer: performer, performer: performer,
refresh: input.Refresh, refresh: input.Refresh,
box: box, box: box,
excluded_fields: input.ExcludeFields, excludedFields: input.ExcludeFields,
taskType: Performer,
}) })
} }
return nil return nil
@ -451,3 +487,132 @@ func (s *Manager) StashBoxBatchPerformerTag(ctx context.Context, input StashBoxB
return s.JobManager.Add(ctx, "Batch stash-box performer tag...", j) return s.JobManager.Add(ctx, "Batch stash-box performer tag...", j)
} }
func (s *Manager) StashBoxBatchStudioTag(ctx context.Context, input StashBoxBatchTagInput) int {
j := job.MakeJobExec(func(ctx context.Context, progress *job.Progress) {
logger.Infof("Initiating stash-box batch studio tag")
boxes := config.GetInstance().GetStashBoxes()
if input.Endpoint < 0 || input.Endpoint >= len(boxes) {
logger.Error(fmt.Errorf("invalid stash_box_index %d", input.Endpoint))
return
}
box := boxes[input.Endpoint]
var tasks []StashBoxBatchTagTask
// The gocritic linter wants to turn this ifElseChain into a switch.
// however, such a switch would contain quite large blocks for each section
// and would arguably be hard to read.
//
// This is why we mark this section nolint. In principle, we should look to
// rewrite the section at some point, to avoid the linter warning.
if len(input.Ids) > 0 { //nolint:gocritic
// The user has chosen only to tag the items on the current page
if err := s.Repository.WithTxn(ctx, func(ctx context.Context) error {
studioQuery := s.Repository.Studio
for _, studioID := range input.Ids {
if id, err := strconv.Atoi(studioID); err == nil {
studio, err := studioQuery.Find(ctx, id)
if err == nil {
if err := studio.LoadStashIDs(ctx, studioQuery); err != nil {
return fmt.Errorf("loading studio stash ids: %w", err)
}
// Check if the user wants to refresh existing or new items
if (input.Refresh && len(studio.StashIDs.List()) > 0) ||
(!input.Refresh && len(studio.StashIDs.List()) == 0) {
tasks = append(tasks, StashBoxBatchTagTask{
studio: studio,
refresh: input.Refresh,
createParent: input.CreateParent,
box: box,
excludedFields: input.ExcludeFields,
taskType: Studio,
})
}
} else {
return err
}
}
}
return nil
}); err != nil {
logger.Error(err.Error())
}
} else if len(input.Names) > 0 {
// The user is batch adding studios
for i := range input.Names {
if len(input.Names[i]) > 0 {
tasks = append(tasks, StashBoxBatchTagTask{
name: &input.Names[i],
refresh: false,
createParent: input.CreateParent,
box: box,
excludedFields: input.ExcludeFields,
taskType: Studio,
})
}
}
} else { //nolint:gocritic
// The gocritic linter wants to fold this if-block into the else on the line above.
// However, this doesn't really help with readability of the current section. Mark it
// as nolint for now. In the future we'd like to rewrite this code by factoring some of
// this into separate functions.
// The user has chosen to tag every item in their database
if err := s.Repository.WithTxn(ctx, func(ctx context.Context) error {
studioQuery := s.Repository.Studio
var studios []*models.Studio
var err error
if input.Refresh {
studios, err = studioQuery.FindByStashIDStatus(ctx, true, box.Endpoint)
} else {
studios, err = studioQuery.FindByStashIDStatus(ctx, false, box.Endpoint)
}
if err != nil {
return fmt.Errorf("error querying studios: %v", err)
}
for _, studio := range studios {
tasks = append(tasks, StashBoxBatchTagTask{
studio: studio,
refresh: input.Refresh,
createParent: input.CreateParent,
box: box,
excludedFields: input.ExcludeFields,
taskType: Studio,
})
}
return nil
}); err != nil {
logger.Error(err.Error())
return
}
}
if len(tasks) == 0 {
return
}
progress.SetTotal(len(tasks))
logger.Infof("Starting stash-box batch operation for %d studios", len(tasks))
var wg sync.WaitGroup
for _, task := range tasks {
wg.Add(1)
progress.ExecuteTask(task.Description(), func() {
task.Start(ctx)
wg.Done()
})
progress.Increment()
}
})
return s.JobManager.Add(ctx, "Batch stash-box studio tag...", j)
}

View file

@ -1,38 +0,0 @@
package manager
import (
"context"
"errors"
"fmt"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/studio"
)
func ValidateModifyStudio(ctx context.Context, studioID int, studio models.StudioPartial, qb studio.Finder) error {
if studio.ParentID.Ptr() == nil {
return nil
}
// ensure there is no cyclic dependency
currentParentID := studio.ParentID.Ptr()
for currentParentID != nil {
if *currentParentID == studioID {
return errors.New("studio cannot be an ancestor of itself")
}
currentStudio, err := qb.Find(ctx, *currentParentID)
if err != nil {
return fmt.Errorf("error finding parent studio: %v", err)
}
if currentStudio == nil {
return fmt.Errorf("studio with id %d not found", *currentParentID)
}
currentParentID = currentStudio.ParentID
}
return nil
}

View file

@ -134,7 +134,7 @@ func (j *IdentifyJob) identifyScene(ctx context.Context, s *models.Scene, source
j.progress.ExecuteTask("Identifying "+s.Path, func() { j.progress.ExecuteTask("Identifying "+s.Path, func() {
task := identify.SceneIdentifier{ task := identify.SceneIdentifier{
SceneReaderUpdater: instance.Repository.Scene, SceneReaderUpdater: instance.Repository.Scene,
StudioCreator: instance.Repository.Studio, StudioReaderWriter: instance.Repository.Studio,
PerformerCreator: instance.Repository.Performer, PerformerCreator: instance.Repository.Performer,
TagCreatorFinder: instance.Repository.Tag, TagCreatorFinder: instance.Repository.Tag,

View file

@ -10,34 +10,62 @@ import (
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scraper/stashbox" "github.com/stashapp/stash/pkg/scraper/stashbox"
"github.com/stashapp/stash/pkg/sliceutil/stringslice" "github.com/stashapp/stash/pkg/sliceutil/stringslice"
"github.com/stashapp/stash/pkg/studio"
"github.com/stashapp/stash/pkg/txn" "github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
type StashBoxPerformerTagTask struct { type StashBoxTagTaskType int
box *models.StashBox
name *string const (
performer *models.Performer Performer StashBoxTagTaskType = iota
refresh bool Studio
excluded_fields []string )
type StashBoxBatchTagTask struct {
box *models.StashBox
name *string
performer *models.Performer
studio *models.Studio
refresh bool
createParent bool
excludedFields []string
taskType StashBoxTagTaskType
} }
func (t *StashBoxPerformerTagTask) Start(ctx context.Context) { func (t *StashBoxBatchTagTask) Start(ctx context.Context) {
t.stashBoxPerformerTag(ctx) switch t.taskType {
} case Performer:
t.stashBoxPerformerTag(ctx)
func (t *StashBoxPerformerTagTask) Description() string { case Studio:
var name string t.stashBoxStudioTag(ctx)
if t.name != nil { default:
name = *t.name logger.Errorf("Error starting batch task, unknown task_type %d", t.taskType)
} else if t.performer != nil {
name = t.performer.Name
} }
return fmt.Sprintf("Tagging performer %s from stash-box", name)
} }
func (t *StashBoxPerformerTagTask) stashBoxPerformerTag(ctx context.Context) { func (t *StashBoxBatchTagTask) Description() string {
if t.taskType == Performer {
var name string
if t.name != nil {
name = *t.name
} else {
name = t.performer.Name
}
return fmt.Sprintf("Tagging performer %s from stash-box", name)
} else if t.taskType == Studio {
var name string
if t.name != nil {
name = *t.name
} else {
name = t.studio.Name
}
return fmt.Sprintf("Tagging studio %s from stash-box", name)
}
return fmt.Sprintf("Unknown tagging task type %d from stash-box", t.taskType)
}
func (t *StashBoxBatchTagTask) stashBoxPerformerTag(ctx context.Context) {
var performer *models.ScrapedPerformer var performer *models.ScrapedPerformer
var err error var err error
@ -74,7 +102,7 @@ func (t *StashBoxPerformerTagTask) stashBoxPerformerTag(ctx context.Context) {
} }
excluded := map[string]bool{} excluded := map[string]bool{}
for _, field := range t.excluded_fields { for _, field := range t.excludedFields {
excluded[field] = true excluded[field] = true
} }
@ -187,7 +215,246 @@ func (t *StashBoxPerformerTagTask) stashBoxPerformerTag(ctx context.Context) {
} }
} }
func (t *StashBoxPerformerTagTask) getPartial(performer *models.ScrapedPerformer, excluded map[string]bool) models.PerformerPartial { func (t *StashBoxBatchTagTask) stashBoxStudioTag(ctx context.Context) {
studio, err := t.findStashBoxStudio(ctx)
if err != nil {
logger.Errorf("Error fetching studio data from stash-box: %s", err.Error())
return
}
excluded := map[string]bool{}
for _, field := range t.excludedFields {
excluded[field] = true
}
// studio will have a value if pulling from Stash-box by Stash ID or name was successful
if studio != nil {
t.processMatchedStudio(ctx, studio, excluded)
} else {
var name string
if t.name != nil {
name = *t.name
} else if t.studio != nil {
name = t.studio.Name
}
logger.Infof("No match found for %s", name)
}
}
func (t *StashBoxBatchTagTask) findStashBoxStudio(ctx context.Context) (*models.ScrapedStudio, error) {
var studio *models.ScrapedStudio
var err error
client := stashbox.NewClient(*t.box, instance.Repository, stashbox.Repository{
Scene: instance.Repository.Scene,
Performer: instance.Repository.Performer,
Tag: instance.Repository.Tag,
Studio: instance.Repository.Studio,
})
if t.refresh {
var remoteID string
txnErr := txn.WithReadTxn(ctx, instance.Repository, func(ctx context.Context) error {
if !t.studio.StashIDs.Loaded() {
err = t.studio.LoadStashIDs(ctx, instance.Repository.Studio)
if err != nil {
return err
}
}
stashids := t.studio.StashIDs.List()
for _, id := range stashids {
if id.Endpoint == t.box.Endpoint {
remoteID = id.StashID
}
}
return nil
})
if txnErr != nil {
logger.Warnf("error while executing read transaction: %v", err)
return nil, err
}
if remoteID != "" {
studio, err = client.FindStashBoxStudio(ctx, remoteID)
}
} else {
var name string
if t.name != nil {
name = *t.name
} else {
name = t.studio.Name
}
studio, err = client.FindStashBoxStudio(ctx, name)
}
return studio, err
}
func (t *StashBoxBatchTagTask) processMatchedStudio(ctx context.Context, s *models.ScrapedStudio, excluded map[string]bool) {
// Refreshing an existing studio
if t.studio != nil {
if s.Parent != nil && t.createParent {
err := t.processParentStudio(ctx, s.Parent, excluded)
if err != nil {
return
}
}
existingStashIDs := getStashIDsForStudio(ctx, *s.StoredID)
studioPartial := s.ToPartial(s.StoredID, t.box.Endpoint, excluded, existingStashIDs)
studioImage, err := s.GetImage(ctx, excluded)
if err != nil {
logger.Errorf("Failed to make studio partial from scraped studio %s: %s", s.Name, err.Error())
return
}
// Start the transaction and update the studio
err = txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error {
qb := instance.Repository.Studio
if err := studio.ValidateModify(ctx, *studioPartial, qb); err != nil {
return err
}
if _, err := qb.UpdatePartial(ctx, *studioPartial); err != nil {
return err
}
if len(studioImage) > 0 {
if err := qb.UpdateImage(ctx, studioPartial.ID, studioImage); err != nil {
return err
}
}
return nil
})
if err != nil {
logger.Errorf("Failed to update studio %s: %s", s.Name, err.Error())
} else {
logger.Infof("Updated studio %s", s.Name)
}
} else if t.name != nil && s.Name != "" {
// Creating a new studio
if s.Parent != nil && t.createParent {
err := t.processParentStudio(ctx, s.Parent, excluded)
if err != nil {
return
}
}
newStudio := s.ToStudio(t.box.Endpoint, excluded)
studioImage, err := s.GetImage(ctx, excluded)
if err != nil {
logger.Errorf("Failed to make studio from scraped studio %s: %s", s.Name, err.Error())
return
}
// Start the transaction and save the studio
err = txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error {
qb := instance.Repository.Studio
if err := qb.Create(ctx, newStudio); err != nil {
return err
}
if len(studioImage) > 0 {
if err := qb.UpdateImage(ctx, newStudio.ID, studioImage); err != nil {
return err
}
}
return nil
})
if err != nil {
logger.Errorf("Failed to create studio %s: %s", s.Name, err.Error())
} else {
logger.Infof("Created studio %s", s.Name)
}
}
}
func (t *StashBoxBatchTagTask) processParentStudio(ctx context.Context, parent *models.ScrapedStudio, excluded map[string]bool) error {
if parent.StoredID == nil {
// The parent needs to be created
newParentStudio := parent.ToStudio(t.box.Endpoint, excluded)
studioImage, err := parent.GetImage(ctx, excluded)
if err != nil {
logger.Errorf("Failed to make parent studio from scraped studio %s: %s", parent.Name, err.Error())
return err
}
// Start the transaction and save the studio
err = txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error {
qb := instance.Repository.Studio
if err := qb.Create(ctx, newParentStudio); err != nil {
return err
}
if len(studioImage) > 0 {
if err := qb.UpdateImage(ctx, newParentStudio.ID, studioImage); err != nil {
return err
}
}
storedId := strconv.Itoa(newParentStudio.ID)
parent.StoredID = &storedId
return nil
})
if err != nil {
logger.Errorf("Failed to create studio %s: %s", parent.Name, err.Error())
return err
}
logger.Infof("Created studio %s", parent.Name)
} else {
// The parent studio matched an existing one and the user has chosen in the UI to link and/or update it
existingStashIDs := getStashIDsForStudio(ctx, *parent.StoredID)
studioPartial := parent.ToPartial(parent.StoredID, t.box.Endpoint, excluded, existingStashIDs)
studioImage, err := parent.GetImage(ctx, excluded)
if err != nil {
logger.Errorf("Failed to make parent studio partial from scraped studio %s: %s", parent.Name, err.Error())
return err
}
// Start the transaction and update the studio
err = txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error {
qb := instance.Repository.Studio
if err := studio.ValidateModify(ctx, *studioPartial, instance.Repository.Studio); err != nil {
return err
}
if _, err := qb.UpdatePartial(ctx, *studioPartial); err != nil {
return err
}
if len(studioImage) > 0 {
if err := qb.UpdateImage(ctx, studioPartial.ID, studioImage); err != nil {
return err
}
}
return nil
})
if err != nil {
logger.Errorf("Failed to update studio %s: %s", parent.Name, err.Error())
return err
}
logger.Infof("Updated studio %s", parent.Name)
}
return nil
}
func getStashIDsForStudio(ctx context.Context, studioID string) []models.StashID {
id, _ := strconv.Atoi(studioID)
tempStudio := &models.Studio{ID: id}
err := tempStudio.LoadStashIDs(ctx, instance.Repository.Studio)
if err != nil {
return nil
}
return tempStudio.StashIDs.List()
}
func (t *StashBoxBatchTagTask) getPartial(performer *models.ScrapedPerformer, excluded map[string]bool) models.PerformerPartial {
partial := models.NewPerformerPartial() partial := models.NewPerformerPartial()
if performer.Aliases != nil && !excluded["aliases"] { if performer.Aliases != nil && !excluded["aliases"] {
@ -243,7 +510,7 @@ func (t *StashBoxPerformerTagTask) getPartial(performer *models.ScrapedPerformer
if performer.Measurements != nil && !excluded["measurements"] { if performer.Measurements != nil && !excluded["measurements"] {
partial.Measurements = models.NewOptionalString(*performer.Measurements) partial.Measurements = models.NewOptionalString(*performer.Measurements)
} }
if excluded["name"] && performer.Name != nil { if performer.Name != nil && !excluded["name"] {
partial.Name = models.NewOptionalString(*performer.Name) partial.Name = models.NewOptionalString(*performer.Name)
} }
if performer.Disambiguation != nil && !excluded["disambiguation"] { if performer.Disambiguation != nil && !excluded["disambiguation"] {

View file

@ -119,7 +119,9 @@ func (i *Importer) populateStudio(ctx context.Context) error {
} }
func (i *Importer) createStudio(ctx context.Context, name string) (int, error) { func (i *Importer) createStudio(ctx context.Context, name string) (int, error) {
newStudio := models.NewStudio(name) newStudio := &models.Studio{
Name: name,
}
err := i.StudioWriter.Create(ctx, newStudio) err := i.StudioWriter.Create(ctx, newStudio)
if err != nil { if err != nil {

View file

@ -152,7 +152,9 @@ func (i *Importer) populateStudio(ctx context.Context) error {
} }
func (i *Importer) createStudio(ctx context.Context, name string) (int, error) { func (i *Importer) createStudio(ctx context.Context, name string) (int, error) {
newStudio := models.NewStudio(name) newStudio := &models.Studio{
Name: name,
}
err := i.StudioWriter.Create(ctx, newStudio) err := i.StudioWriter.Create(ctx, newStudio)
if err != nil { if err != nil {

View file

@ -58,13 +58,13 @@ func (_m *StudioReaderWriter) Count(ctx context.Context) (int, error) {
return r0, r1 return r0, r1
} }
// Create provides a mock function with given fields: ctx, newStudio // Create provides a mock function with given fields: ctx, input
func (_m *StudioReaderWriter) Create(ctx context.Context, newStudio *models.Studio) error { func (_m *StudioReaderWriter) Create(ctx context.Context, input *models.Studio) error {
ret := _m.Called(ctx, newStudio) ret := _m.Called(ctx, input)
var r0 error var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *models.Studio) error); ok { if rf, ok := ret.Get(0).(func(context.Context, *models.Studio) error); ok {
r0 = rf(ctx, newStudio) r0 = rf(ctx, input)
} else { } else {
r0 = ret.Error(0) r0 = ret.Error(0)
} }
@ -155,6 +155,29 @@ func (_m *StudioReaderWriter) FindByStashID(ctx context.Context, stashID models.
return r0, r1 return r0, r1
} }
// FindByStashIDStatus provides a mock function with given fields: ctx, hasStashID, stashboxEndpoint
func (_m *StudioReaderWriter) FindByStashIDStatus(ctx context.Context, hasStashID bool, stashboxEndpoint string) ([]*models.Studio, error) {
ret := _m.Called(ctx, hasStashID, stashboxEndpoint)
var r0 []*models.Studio
if rf, ok := ret.Get(0).(func(context.Context, bool, string) []*models.Studio); ok {
r0 = rf(ctx, hasStashID, stashboxEndpoint)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Studio)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, bool, string) error); ok {
r1 = rf(ctx, hasStashID, stashboxEndpoint)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// FindChildren provides a mock function with given fields: ctx, id // FindChildren provides a mock function with given fields: ctx, id
func (_m *StudioReaderWriter) FindChildren(ctx context.Context, id int) ([]*models.Studio, error) { func (_m *StudioReaderWriter) FindChildren(ctx context.Context, id int) ([]*models.Studio, error) {
ret := _m.Called(ctx, id) ret := _m.Called(ctx, id)
@ -201,13 +224,13 @@ func (_m *StudioReaderWriter) FindMany(ctx context.Context, ids []int) ([]*model
return r0, r1 return r0, r1
} }
// GetAliases provides a mock function with given fields: ctx, studioID // GetAliases provides a mock function with given fields: ctx, relatedID
func (_m *StudioReaderWriter) GetAliases(ctx context.Context, studioID int) ([]string, error) { func (_m *StudioReaderWriter) GetAliases(ctx context.Context, relatedID int) ([]string, error) {
ret := _m.Called(ctx, studioID) ret := _m.Called(ctx, relatedID)
var r0 []string var r0 []string
if rf, ok := ret.Get(0).(func(context.Context, int) []string); ok { if rf, ok := ret.Get(0).(func(context.Context, int) []string); ok {
r0 = rf(ctx, studioID) r0 = rf(ctx, relatedID)
} else { } else {
if ret.Get(0) != nil { if ret.Get(0) != nil {
r0 = ret.Get(0).([]string) r0 = ret.Get(0).([]string)
@ -216,7 +239,7 @@ func (_m *StudioReaderWriter) GetAliases(ctx context.Context, studioID int) ([]s
var r1 error var r1 error
if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
r1 = rf(ctx, studioID) r1 = rf(ctx, relatedID)
} else { } else {
r1 = ret.Error(1) r1 = ret.Error(1)
} }
@ -358,20 +381,6 @@ func (_m *StudioReaderWriter) Update(ctx context.Context, updatedStudio *models.
return r0 return r0
} }
// UpdateAliases provides a mock function with given fields: ctx, studioID, aliases
func (_m *StudioReaderWriter) UpdateAliases(ctx context.Context, studioID int, aliases []string) error {
ret := _m.Called(ctx, studioID, aliases)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, int, []string) error); ok {
r0 = rf(ctx, studioID, aliases)
} else {
r0 = ret.Error(0)
}
return r0
}
// UpdateImage provides a mock function with given fields: ctx, studioID, image // UpdateImage provides a mock function with given fields: ctx, studioID, image
func (_m *StudioReaderWriter) UpdateImage(ctx context.Context, studioID int, image []byte) error { func (_m *StudioReaderWriter) UpdateImage(ctx context.Context, studioID int, image []byte) error {
ret := _m.Called(ctx, studioID, image) ret := _m.Called(ctx, studioID, image)
@ -386,13 +395,13 @@ func (_m *StudioReaderWriter) UpdateImage(ctx context.Context, studioID int, ima
return r0 return r0
} }
// UpdatePartial provides a mock function with given fields: ctx, id, updatedStudio // UpdatePartial provides a mock function with given fields: ctx, input
func (_m *StudioReaderWriter) UpdatePartial(ctx context.Context, id int, updatedStudio models.StudioPartial) (*models.Studio, error) { func (_m *StudioReaderWriter) UpdatePartial(ctx context.Context, input models.StudioPartial) (*models.Studio, error) {
ret := _m.Called(ctx, id, updatedStudio) ret := _m.Called(ctx, input)
var r0 *models.Studio var r0 *models.Studio
if rf, ok := ret.Get(0).(func(context.Context, int, models.StudioPartial) *models.Studio); ok { if rf, ok := ret.Get(0).(func(context.Context, models.StudioPartial) *models.Studio); ok {
r0 = rf(ctx, id, updatedStudio) r0 = rf(ctx, input)
} else { } else {
if ret.Get(0) != nil { if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Studio) r0 = ret.Get(0).(*models.Studio)
@ -400,25 +409,11 @@ func (_m *StudioReaderWriter) UpdatePartial(ctx context.Context, id int, updated
} }
var r1 error var r1 error
if rf, ok := ret.Get(1).(func(context.Context, int, models.StudioPartial) error); ok { if rf, ok := ret.Get(1).(func(context.Context, models.StudioPartial) error); ok {
r1 = rf(ctx, id, updatedStudio) r1 = rf(ctx, input)
} else { } else {
r1 = ret.Error(1) r1 = ret.Error(1)
} }
return r0, r1 return r0, r1
} }
// UpdateStashIDs provides a mock function with given fields: ctx, studioID, stashIDs
func (_m *StudioReaderWriter) UpdateStashIDs(ctx context.Context, studioID int, stashIDs []models.StashID) error {
ret := _m.Called(ctx, studioID, stashIDs)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, int, []models.StashID) error); ok {
r0 = rf(ctx, studioID, stashIDs)
} else {
r0 = ret.Error(0)
}
return r0
}

View file

@ -1,16 +1,108 @@
package models package models
import (
"context"
"strconv"
"time"
"github.com/stashapp/stash/pkg/utils"
)
type ScrapedStudio struct { type ScrapedStudio struct {
// Set if studio matched // Set if studio matched
StoredID *string `json:"stored_id"` StoredID *string `json:"stored_id"`
Name string `json:"name"` Name string `json:"name"`
URL *string `json:"url"` URL *string `json:"url"`
Image *string `json:"image"` Parent *ScrapedStudio `json:"parent"`
RemoteSiteID *string `json:"remote_site_id"` Image *string `json:"image"`
Images []string `json:"images"`
RemoteSiteID *string `json:"remote_site_id"`
} }
func (ScrapedStudio) IsScrapedContent() {} func (ScrapedStudio) IsScrapedContent() {}
func (s *ScrapedStudio) ToStudio(endpoint string, excluded map[string]bool) *Studio {
now := time.Now()
// Populate a new studio from the input
newStudio := Studio{
Name: s.Name,
StashIDs: NewRelatedStashIDs([]StashID{
{
Endpoint: endpoint,
StashID: *s.RemoteSiteID,
},
}),
CreatedAt: now,
UpdatedAt: now,
}
if s.URL != nil && !excluded["url"] {
newStudio.URL = *s.URL
}
if s.Parent != nil && s.Parent.StoredID != nil && !excluded["parent"] {
parentId, _ := strconv.Atoi(*s.Parent.StoredID)
newStudio.ParentID = &parentId
}
return &newStudio
}
func (s *ScrapedStudio) GetImage(ctx context.Context, excluded map[string]bool) ([]byte, error) {
// Process the base 64 encoded image string
if len(s.Images) > 0 && !excluded["image"] {
var err error
img, err := utils.ProcessImageInput(ctx, *s.Image)
if err != nil {
return nil, err
}
return img, nil
}
return nil, nil
}
func (s *ScrapedStudio) ToPartial(id *string, endpoint string, excluded map[string]bool, existingStashIDs []StashID) *StudioPartial {
partial := StudioPartial{
UpdatedAt: NewOptionalTime(time.Now()),
}
partial.ID, _ = strconv.Atoi(*id)
if s.Name != "" && !excluded["name"] {
partial.Name = NewOptionalString(s.Name)
}
if s.URL != nil && !excluded["url"] {
partial.URL = NewOptionalString(*s.URL)
}
if s.Parent != nil && !excluded["parent"] {
if s.Parent.StoredID != nil {
parentID, _ := strconv.Atoi(*s.Parent.StoredID)
if parentID > 0 {
// This is to be set directly as we know it has a value and the translator won't have the field
partial.ParentID = NewOptionalInt(parentID)
}
}
} else {
partial.ParentID = NewOptionalIntPtr(nil)
}
partial.StashIDs = &UpdateStashIDs{
StashIDs: existingStashIDs,
Mode: RelationshipUpdateModeSet,
}
partial.StashIDs.Set(StashID{
Endpoint: endpoint,
StashID: *s.RemoteSiteID,
})
return &partial
}
// A performer from a scraping operation... // A performer from a scraping operation...
type ScrapedPerformer struct { type ScrapedPerformer struct {
// Set if performer matched // Set if performer matched

View file

@ -0,0 +1,65 @@
package models
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func Test_scrapedToStudioInput(t *testing.T) {
const name = "name"
url := "url"
remoteSiteID := "remoteSiteID"
tests := []struct {
name string
studio *ScrapedStudio
want *Studio
}{
{
"set all",
&ScrapedStudio{
Name: name,
URL: &url,
RemoteSiteID: &remoteSiteID,
},
&Studio{
Name: name,
URL: url,
StashIDs: NewRelatedStashIDs([]StashID{
{
StashID: remoteSiteID,
},
}),
},
},
{
"set none",
&ScrapedStudio{
Name: name,
RemoteSiteID: &remoteSiteID,
},
&Studio{
Name: name,
StashIDs: NewRelatedStashIDs([]StashID{
{
StashID: remoteSiteID,
},
}),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.studio.ToStudio("", nil)
assert.NotEqual(t, time.Time{}, got.CreatedAt)
assert.NotEqual(t, time.Time{}, got.UpdatedAt)
got.CreatedAt = time.Time{}
got.UpdatedAt = time.Time{}
assert.Equal(t, tt.want, got)
})
}
}

View file

@ -1,6 +1,7 @@
package models package models
import ( import (
"context"
"time" "time"
) )
@ -15,34 +16,50 @@ type Studio struct {
Rating *int `json:"rating"` Rating *int `json:"rating"`
Details string `json:"details"` Details string `json:"details"`
IgnoreAutoTag bool `json:"ignore_auto_tag"` IgnoreAutoTag bool `json:"ignore_auto_tag"`
Aliases RelatedStrings `json:"aliases"`
StashIDs RelatedStashIDs `json:"stash_ids"`
} }
func (s *Studio) LoadAliases(ctx context.Context, l AliasLoader) error {
return s.Aliases.load(func() ([]string, error) {
return l.GetAliases(ctx, s.ID)
})
}
func (s *Studio) LoadStashIDs(ctx context.Context, l StashIDLoader) error {
return s.StashIDs.load(func() ([]StashID, error) {
return l.GetStashIDs(ctx, s.ID)
})
}
func (s *Studio) LoadRelationships(ctx context.Context, l PerformerReader) error {
if err := s.LoadAliases(ctx, l); err != nil {
return err
}
if err := s.LoadStashIDs(ctx, l); err != nil {
return err
}
return nil
}
// StudioPartial represents part of a Studio object. It is used to update the database entry.
type StudioPartial struct { type StudioPartial struct {
Name OptionalString ID int
URL OptionalString Name OptionalString
ParentID OptionalInt URL OptionalString
CreatedAt OptionalTime ParentID OptionalInt
UpdatedAt OptionalTime
// Rating expressed in 1-100 scale // Rating expressed in 1-100 scale
Rating OptionalInt Rating OptionalInt
Details OptionalString Details OptionalString
CreatedAt OptionalTime
UpdatedAt OptionalTime
IgnoreAutoTag OptionalBool IgnoreAutoTag OptionalBool
}
func NewStudio(name string) *Studio { Aliases *UpdateStrings
currentTime := time.Now() StashIDs *UpdateStashIDs
return &Studio{
Name: name,
CreatedAt: currentTime,
UpdatedAt: currentTime,
}
}
func NewStudioPartial() StudioPartial {
updatedTime := time.Now()
return StudioPartial{
UpdatedAt: NewOptionalTime(updatedTime),
}
} }
type Studios []*Studio type Studios []*Studio

View file

@ -48,6 +48,7 @@ type StudioReader interface {
FindChildren(ctx context.Context, id int) ([]*Studio, error) FindChildren(ctx context.Context, id int) ([]*Studio, error)
FindByName(ctx context.Context, name string, nocase bool) (*Studio, error) FindByName(ctx context.Context, name string, nocase bool) (*Studio, error)
FindByStashID(ctx context.Context, stashID StashID) ([]*Studio, error) FindByStashID(ctx context.Context, stashID StashID) ([]*Studio, error)
FindByStashIDStatus(ctx context.Context, hasStashID bool, stashboxEndpoint string) ([]*Studio, error)
Count(ctx context.Context) (int, error) Count(ctx context.Context) (int, error)
All(ctx context.Context) ([]*Studio, error) All(ctx context.Context) ([]*Studio, error)
// TODO - this interface is temporary until the filter schema can fully // TODO - this interface is temporary until the filter schema can fully
@ -56,18 +57,16 @@ type StudioReader interface {
Query(ctx context.Context, studioFilter *StudioFilterType, findFilter *FindFilterType) ([]*Studio, int, error) Query(ctx context.Context, studioFilter *StudioFilterType, findFilter *FindFilterType) ([]*Studio, int, error)
GetImage(ctx context.Context, studioID int) ([]byte, error) GetImage(ctx context.Context, studioID int) ([]byte, error)
HasImage(ctx context.Context, studioID int) (bool, error) HasImage(ctx context.Context, studioID int) (bool, error)
AliasLoader
StashIDLoader StashIDLoader
GetAliases(ctx context.Context, studioID int) ([]string, error)
} }
type StudioWriter interface { type StudioWriter interface {
Create(ctx context.Context, newStudio *Studio) error Create(ctx context.Context, newStudio *Studio) error
UpdatePartial(ctx context.Context, id int, updatedStudio StudioPartial) (*Studio, error) UpdatePartial(ctx context.Context, input StudioPartial) (*Studio, error)
Update(ctx context.Context, updatedStudio *Studio) error Update(ctx context.Context, updatedStudio *Studio) error
Destroy(ctx context.Context, id int) error Destroy(ctx context.Context, id int) error
UpdateImage(ctx context.Context, studioID int, image []byte) error UpdateImage(ctx context.Context, studioID int, image []byte) error
UpdateStashIDs(ctx context.Context, studioID int, stashIDs []StashID) error
UpdateAliases(ctx context.Context, studioID int, aliases []string) error
} }
type StudioReaderWriter interface { type StudioReaderWriter interface {

View file

@ -5,6 +5,7 @@ import (
"io" "io"
"strconv" "strconv"
"github.com/stashapp/stash/pkg/sliceutil"
"github.com/stashapp/stash/pkg/sliceutil/intslice" "github.com/stashapp/stash/pkg/sliceutil/intslice"
) )
@ -94,16 +95,7 @@ func (u *UpdateIDs) EffectiveIDs(existing []int) []int {
return nil return nil
} }
switch u.Mode { return effectiveValues(u.IDs, u.Mode, existing)
case RelationshipUpdateModeAdd:
return intslice.IntAppendUniques(existing, u.IDs)
case RelationshipUpdateModeRemove:
return intslice.IntExclude(existing, u.IDs)
case RelationshipUpdateModeSet:
return u.IDs
}
return nil
} }
type UpdateStrings struct { type UpdateStrings struct {
@ -118,3 +110,26 @@ func (u *UpdateStrings) Strings() []string {
return u.Values return u.Values
} }
// GetEffectiveIDs returns the new IDs that will be effective after the update.
func (u *UpdateStrings) EffectiveValues(existing []string) []string {
if u == nil {
return nil
}
return effectiveValues(u.Values, u.Mode, existing)
}
// effectiveValues returns the new values that will be effective after the update.
func effectiveValues[T comparable](values []T, mode RelationshipUpdateMode, existing []T) []T {
switch mode {
case RelationshipUpdateModeAdd:
return sliceutil.AppendUniques(existing, values)
case RelationshipUpdateModeRemove:
return sliceutil.Exclude(existing, values)
case RelationshipUpdateModeSet:
return values
}
return nil
}

View file

@ -116,7 +116,9 @@ func (i *Importer) populateStudio(ctx context.Context) error {
} }
func (i *Importer) createStudio(ctx context.Context, name string) (int, error) { func (i *Importer) createStudio(ctx context.Context, name string) (int, error) {
newStudio := models.NewStudio(name) newStudio := &models.Studio{
Name: name,
}
err := i.StudioWriter.Create(ctx, newStudio) err := i.StudioWriter.Create(ctx, newStudio)
if err != nil { if err != nil {

View file

@ -176,7 +176,9 @@ func (i *Importer) populateStudio(ctx context.Context) error {
} }
func (i *Importer) createStudio(ctx context.Context, name string) (int, error) { func (i *Importer) createStudio(ctx context.Context, name string) (int, error) {
newStudio := models.NewStudio(name) newStudio := &models.Studio{
Name: name,
}
err := i.StudioWriter.Create(ctx, newStudio) err := i.StudioWriter.Create(ctx, newStudio)
if err != nil { if err != nil {

View file

@ -17,6 +17,7 @@ type StashBoxGraphQLClient interface {
SearchPerformer(ctx context.Context, term string, httpRequestOptions ...client.HTTPRequestOption) (*SearchPerformer, error) SearchPerformer(ctx context.Context, term string, httpRequestOptions ...client.HTTPRequestOption) (*SearchPerformer, error)
FindPerformerByID(ctx context.Context, id string, httpRequestOptions ...client.HTTPRequestOption) (*FindPerformerByID, error) FindPerformerByID(ctx context.Context, id string, httpRequestOptions ...client.HTTPRequestOption) (*FindPerformerByID, error)
FindSceneByID(ctx context.Context, id string, httpRequestOptions ...client.HTTPRequestOption) (*FindSceneByID, error) FindSceneByID(ctx context.Context, id string, httpRequestOptions ...client.HTTPRequestOption) (*FindSceneByID, error)
FindStudio(ctx context.Context, id *string, name *string, httpRequestOptions ...client.HTTPRequestOption) (*FindStudio, error)
SubmitFingerprint(ctx context.Context, input FingerprintSubmission, httpRequestOptions ...client.HTTPRequestOption) (*SubmitFingerprint, error) SubmitFingerprint(ctx context.Context, input FingerprintSubmission, httpRequestOptions ...client.HTTPRequestOption) (*SubmitFingerprint, error)
Me(ctx context.Context, httpRequestOptions ...client.HTTPRequestOption) (*Me, error) Me(ctx context.Context, httpRequestOptions ...client.HTTPRequestOption) (*Me, error)
SubmitSceneDraft(ctx context.Context, input SceneDraftInput, httpRequestOptions ...client.HTTPRequestOption) (*SubmitSceneDraft, error) SubmitSceneDraft(ctx context.Context, input SceneDraftInput, httpRequestOptions ...client.HTTPRequestOption) (*SubmitSceneDraft, error)
@ -125,9 +126,13 @@ type ImageFragment struct {
Height int "json:\"height\" graphql:\"height\"" Height int "json:\"height\" graphql:\"height\""
} }
type StudioFragment struct { type StudioFragment struct {
Name string "json:\"name\" graphql:\"name\"" Name string "json:\"name\" graphql:\"name\""
ID string "json:\"id\" graphql:\"id\"" ID string "json:\"id\" graphql:\"id\""
Urls []*URLFragment "json:\"urls\" graphql:\"urls\"" Urls []*URLFragment "json:\"urls\" graphql:\"urls\""
Parent *struct {
Name string "json:\"name\" graphql:\"name\""
ID string "json:\"id\" graphql:\"id\""
} "json:\"parent\" graphql:\"parent\""
Images []*ImageFragment "json:\"images\" graphql:\"images\"" Images []*ImageFragment "json:\"images\" graphql:\"images\""
} }
type TagFragment struct { type TagFragment struct {
@ -215,6 +220,9 @@ type FindPerformerByID struct {
type FindSceneByID struct { type FindSceneByID struct {
FindScene *SceneFragment "json:\"findScene\" graphql:\"findScene\"" FindScene *SceneFragment "json:\"findScene\" graphql:\"findScene\""
} }
type FindStudio struct {
FindStudio *StudioFragment "json:\"findStudio\" graphql:\"findStudio\""
}
type SubmitFingerprint struct { type SubmitFingerprint struct {
SubmitFingerprint bool "json:\"submitFingerprint\" graphql:\"submitFingerprint\"" SubmitFingerprint bool "json:\"submitFingerprint\" graphql:\"submitFingerprint\""
} }
@ -239,12 +247,77 @@ const FindSceneByFingerprintDocument = `query FindSceneByFingerprint ($fingerpri
... SceneFragment ... SceneFragment
} }
} }
fragment StudioFragment on Studio {
name
id
urls {
... URLFragment
}
parent {
name
id
}
images {
... ImageFragment
}
}
fragment BodyModificationFragment on BodyModification {
location
description
}
fragment MeasurementsFragment on Measurements {
band_size
cup_size
waist
hip
}
fragment SceneFragment on Scene {
id
title
code
details
director
duration
date
urls {
... URLFragment
}
images {
... ImageFragment
}
studio {
... StudioFragment
}
tags {
... TagFragment
}
performers {
... PerformerAppearanceFragment
}
fingerprints {
... FingerprintFragment
}
}
fragment URLFragment on URL {
url
type
}
fragment ImageFragment on Image { fragment ImageFragment on Image {
id id
url url
width width
height height
} }
fragment TagFragment on Tag {
name
id
}
fragment PerformerAppearanceFragment on PerformerAppearance {
as
performer {
... PerformerFragment
}
}
fragment PerformerFragment on Performer { fragment PerformerFragment on Performer {
id id
name name
@ -279,76 +352,15 @@ fragment PerformerFragment on Performer {
... BodyModificationFragment ... BodyModificationFragment
} }
} }
fragment SceneFragment on Scene {
id
title
code
details
director
duration
date
urls {
... URLFragment
}
images {
... ImageFragment
}
studio {
... StudioFragment
}
tags {
... TagFragment
}
performers {
... PerformerAppearanceFragment
}
fingerprints {
... FingerprintFragment
}
}
fragment URLFragment on URL {
url
type
}
fragment PerformerAppearanceFragment on PerformerAppearance {
as
performer {
... PerformerFragment
}
}
fragment FuzzyDateFragment on FuzzyDate { fragment FuzzyDateFragment on FuzzyDate {
date date
accuracy accuracy
} }
fragment MeasurementsFragment on Measurements {
band_size
cup_size
waist
hip
}
fragment BodyModificationFragment on BodyModification {
location
description
}
fragment FingerprintFragment on Fingerprint { fragment FingerprintFragment on Fingerprint {
algorithm algorithm
hash hash
duration duration
} }
fragment StudioFragment on Studio {
name
id
urls {
... URLFragment
}
images {
... ImageFragment
}
}
fragment TagFragment on Tag {
name
id
}
` `
func (c *Client) FindSceneByFingerprint(ctx context.Context, fingerprint FingerprintQueryInput, httpRequestOptions ...client.HTTPRequestOption) (*FindSceneByFingerprint, error) { func (c *Client) FindSceneByFingerprint(ctx context.Context, fingerprint FingerprintQueryInput, httpRequestOptions ...client.HTTPRequestOption) (*FindSceneByFingerprint, error) {
@ -369,6 +381,49 @@ const FindScenesByFullFingerprintsDocument = `query FindScenesByFullFingerprints
... SceneFragment ... SceneFragment
} }
} }
fragment FuzzyDateFragment on FuzzyDate {
date
accuracy
}
fragment MeasurementsFragment on Measurements {
band_size
cup_size
waist
hip
}
fragment BodyModificationFragment on BodyModification {
location
description
}
fragment FingerprintFragment on Fingerprint {
algorithm
hash
duration
}
fragment URLFragment on URL {
url
type
}
fragment ImageFragment on Image {
id
url
width
height
}
fragment StudioFragment on Studio {
name
id
urls {
... URLFragment
}
parent {
name
id
}
images {
... ImageFragment
}
}
fragment PerformerFragment on Performer { fragment PerformerFragment on Performer {
id id
name name
@ -403,16 +458,6 @@ fragment PerformerFragment on Performer {
... BodyModificationFragment ... BodyModificationFragment
} }
} }
fragment FuzzyDateFragment on FuzzyDate {
date
accuracy
}
fragment MeasurementsFragment on Measurements {
band_size
cup_size
waist
hip
}
fragment SceneFragment on Scene { fragment SceneFragment on Scene {
id id
title title
@ -440,35 +485,6 @@ fragment SceneFragment on Scene {
... FingerprintFragment ... FingerprintFragment
} }
} }
fragment URLFragment on URL {
url
type
}
fragment ImageFragment on Image {
id
url
width
height
}
fragment BodyModificationFragment on BodyModification {
location
description
}
fragment FingerprintFragment on Fingerprint {
algorithm
hash
duration
}
fragment StudioFragment on Studio {
name
id
urls {
... URLFragment
}
images {
... ImageFragment
}
}
fragment TagFragment on Tag { fragment TagFragment on Tag {
name name
id id
@ -499,28 +515,56 @@ const FindScenesBySceneFingerprintsDocument = `query FindScenesBySceneFingerprin
... SceneFragment ... SceneFragment
} }
} }
fragment StudioFragment on Studio { fragment URLFragment on URL {
url
type
}
fragment TagFragment on Tag {
name name
id id
urls { }
... URLFragment fragment PerformerAppearanceFragment on PerformerAppearance {
} as
images { performer {
... ImageFragment ... PerformerFragment
} }
} }
fragment FuzzyDateFragment on FuzzyDate { fragment FuzzyDateFragment on FuzzyDate {
date date
accuracy accuracy
} }
fragment FingerprintFragment on Fingerprint { fragment MeasurementsFragment on Measurements {
algorithm band_size
hash cup_size
duration waist
hip
} }
fragment URLFragment on URL { fragment SceneFragment on Scene {
url id
type title
code
details
director
duration
date
urls {
... URLFragment
}
images {
... ImageFragment
}
studio {
... StudioFragment
}
tags {
... TagFragment
}
performers {
... PerformerAppearanceFragment
}
fingerprints {
... FingerprintFragment
}
} }
fragment ImageFragment on Image { fragment ImageFragment on Image {
id id
@ -528,10 +572,18 @@ fragment ImageFragment on Image {
width width
height height
} }
fragment PerformerAppearanceFragment on PerformerAppearance { fragment StudioFragment on Studio {
as name
performer { id
... PerformerFragment urls {
... URLFragment
}
parent {
name
id
}
images {
... ImageFragment
} }
} }
fragment PerformerFragment on Performer { fragment PerformerFragment on Performer {
@ -568,46 +620,14 @@ fragment PerformerFragment on Performer {
... BodyModificationFragment ... BodyModificationFragment
} }
} }
fragment MeasurementsFragment on Measurements {
band_size
cup_size
waist
hip
}
fragment BodyModificationFragment on BodyModification { fragment BodyModificationFragment on BodyModification {
location location
description description
} }
fragment SceneFragment on Scene { fragment FingerprintFragment on Fingerprint {
id algorithm
title hash
code
details
director
duration duration
date
urls {
... URLFragment
}
images {
... ImageFragment
}
studio {
... StudioFragment
}
tags {
... TagFragment
}
performers {
... PerformerAppearanceFragment
}
fingerprints {
... FingerprintFragment
}
}
fragment TagFragment on Tag {
name
id
} }
` `
@ -629,6 +649,29 @@ const SearchSceneDocument = `query SearchScene ($term: String!) {
... SceneFragment ... SceneFragment
} }
} }
fragment ImageFragment on Image {
id
url
width
height
}
fragment TagFragment on Tag {
name
id
}
fragment FuzzyDateFragment on FuzzyDate {
date
accuracy
}
fragment BodyModificationFragment on BodyModification {
location
description
}
fragment FingerprintFragment on Fingerprint {
algorithm
hash
duration
}
fragment SceneFragment on Scene { fragment SceneFragment on Scene {
id id
title title
@ -660,32 +703,16 @@ fragment URLFragment on URL {
url url
type type
} }
fragment TagFragment on Tag {
name
id
}
fragment FuzzyDateFragment on FuzzyDate {
date
accuracy
}
fragment MeasurementsFragment on Measurements {
band_size
cup_size
waist
hip
}
fragment ImageFragment on Image {
id
url
width
height
}
fragment StudioFragment on Studio { fragment StudioFragment on Studio {
name name
id id
urls { urls {
... URLFragment ... URLFragment
} }
parent {
name
id
}
images { images {
... ImageFragment ... ImageFragment
} }
@ -730,14 +757,11 @@ fragment PerformerFragment on Performer {
... BodyModificationFragment ... BodyModificationFragment
} }
} }
fragment BodyModificationFragment on BodyModification { fragment MeasurementsFragment on Measurements {
location band_size
description cup_size
} waist
fragment FingerprintFragment on Fingerprint { hip
algorithm
hash
duration
} }
` `
@ -759,16 +783,6 @@ const SearchPerformerDocument = `query SearchPerformer ($term: String!) {
... PerformerFragment ... PerformerFragment
} }
} }
fragment FuzzyDateFragment on FuzzyDate {
date
accuracy
}
fragment MeasurementsFragment on Measurements {
band_size
cup_size
waist
hip
}
fragment BodyModificationFragment on BodyModification { fragment BodyModificationFragment on BodyModification {
location location
description description
@ -817,6 +831,16 @@ fragment ImageFragment on Image {
width width
height height
} }
fragment FuzzyDateFragment on FuzzyDate {
date
accuracy
}
fragment MeasurementsFragment on Measurements {
band_size
cup_size
waist
hip
}
` `
func (c *Client) SearchPerformer(ctx context.Context, term string, httpRequestOptions ...client.HTTPRequestOption) (*SearchPerformer, error) { func (c *Client) SearchPerformer(ctx context.Context, term string, httpRequestOptions ...client.HTTPRequestOption) (*SearchPerformer, error) {
@ -915,26 +939,25 @@ const FindSceneByIDDocument = `query FindSceneByID ($id: ID!) {
... SceneFragment ... SceneFragment
} }
} }
fragment FingerprintFragment on Fingerprint { fragment ImageFragment on Image {
algorithm id
hash
duration
}
fragment URLFragment on URL {
url url
type width
height
} }
fragment PerformerAppearanceFragment on PerformerAppearance { fragment StudioFragment on Studio {
as name
performer { id
... PerformerFragment urls {
... URLFragment
}
parent {
name
id
}
images {
... ImageFragment
} }
}
fragment MeasurementsFragment on Measurements {
band_size
cup_size
waist
hip
} }
fragment TagFragment on Tag { fragment TagFragment on Tag {
name name
@ -974,13 +997,11 @@ fragment PerformerFragment on Performer {
... BodyModificationFragment ... BodyModificationFragment
} }
} }
fragment FuzzyDateFragment on FuzzyDate { fragment MeasurementsFragment on Measurements {
date band_size
accuracy cup_size
} waist
fragment BodyModificationFragment on BodyModification { hip
location
description
} }
fragment SceneFragment on Scene { fragment SceneFragment on Scene {
id id
@ -1009,22 +1030,29 @@ fragment SceneFragment on Scene {
... FingerprintFragment ... FingerprintFragment
} }
} }
fragment ImageFragment on Image { fragment URLFragment on URL {
id
url url
width type
height
} }
fragment StudioFragment on Studio { fragment BodyModificationFragment on BodyModification {
name location
id description
urls { }
... URLFragment fragment FingerprintFragment on Fingerprint {
} algorithm
images { hash
... ImageFragment duration
}
fragment PerformerAppearanceFragment on PerformerAppearance {
as
performer {
... PerformerFragment
} }
} }
fragment FuzzyDateFragment on FuzzyDate {
date
accuracy
}
` `
func (c *Client) FindSceneByID(ctx context.Context, id string, httpRequestOptions ...client.HTTPRequestOption) (*FindSceneByID, error) { func (c *Client) FindSceneByID(ctx context.Context, id string, httpRequestOptions ...client.HTTPRequestOption) (*FindSceneByID, error) {
@ -1040,6 +1068,51 @@ func (c *Client) FindSceneByID(ctx context.Context, id string, httpRequestOption
return &res, nil return &res, nil
} }
const FindStudioDocument = `query FindStudio ($id: ID, $name: String) {
findStudio(id: $id, name: $name) {
... StudioFragment
}
}
fragment StudioFragment on Studio {
name
id
urls {
... URLFragment
}
parent {
name
id
}
images {
... ImageFragment
}
}
fragment URLFragment on URL {
url
type
}
fragment ImageFragment on Image {
id
url
width
height
}
`
func (c *Client) FindStudio(ctx context.Context, id *string, name *string, httpRequestOptions ...client.HTTPRequestOption) (*FindStudio, error) {
vars := map[string]interface{}{
"id": id,
"name": name,
}
var res FindStudio
if err := c.Client.Post(ctx, "FindStudio", FindStudioDocument, &res, vars, httpRequestOptions...); err != nil {
return nil, err
}
return &res, nil
}
const SubmitFingerprintDocument = `mutation SubmitFingerprint ($input: FingerprintSubmission!) { const SubmitFingerprintDocument = `mutation SubmitFingerprint ($input: FingerprintSubmission!) {
submitFingerprint(input: $input) submitFingerprint(input: $input)
} }

View file

@ -88,9 +88,9 @@ type DraftEntity struct {
ID *string `json:"id,omitempty"` ID *string `json:"id,omitempty"`
} }
func (DraftEntity) IsSceneDraftPerformer() {}
func (DraftEntity) IsSceneDraftStudio() {}
func (DraftEntity) IsSceneDraftTag() {} func (DraftEntity) IsSceneDraftTag() {}
func (DraftEntity) IsSceneDraftStudio() {}
func (DraftEntity) IsSceneDraftPerformer() {}
type DraftEntityInput struct { type DraftEntityInput struct {
Name string `json:"name"` Name string `json:"name"`
@ -116,6 +116,7 @@ type Edit struct {
// Objects to merge with the target. Only applicable to merges // Objects to merge with the target. Only applicable to merges
MergeSources []EditTarget `json:"merge_sources,omitempty"` MergeSources []EditTarget `json:"merge_sources,omitempty"`
Operation OperationEnum `json:"operation"` Operation OperationEnum `json:"operation"`
Bot bool `json:"bot"`
Details EditDetails `json:"details,omitempty"` Details EditDetails `json:"details,omitempty"`
// Previous state of fields being modified - null if operation is create or delete. // Previous state of fields being modified - null if operation is create or delete.
OldDetails EditDetails `json:"old_details,omitempty"` OldDetails EditDetails `json:"old_details,omitempty"`
@ -154,6 +155,8 @@ type EditInput struct {
// Only required for merge type // Only required for merge type
MergeSourceIds []string `json:"merge_source_ids,omitempty"` MergeSourceIds []string `json:"merge_source_ids,omitempty"`
Comment *string `json:"comment,omitempty"` Comment *string `json:"comment,omitempty"`
// Edit submitted by an automated script. Requires bot permission
Bot *bool `json:"bot,omitempty"`
} }
type EditQueryInput struct { type EditQueryInput struct {
@ -172,11 +175,15 @@ type EditQueryInput struct {
// Filter by target id // Filter by target id
TargetID *string `json:"target_id,omitempty"` TargetID *string `json:"target_id,omitempty"`
// Filter by favorite status // Filter by favorite status
IsFavorite *bool `json:"is_favorite,omitempty"` IsFavorite *bool `json:"is_favorite,omitempty"`
Page int `json:"page"` // Filter by user voted status
PerPage int `json:"per_page"` Voted *UserVotedFilterEnum `json:"voted,omitempty"`
Direction SortDirectionEnum `json:"direction"` // Filter to bot edits only
Sort EditSortEnum `json:"sort"` IsBot *bool `json:"is_bot,omitempty"`
Page int `json:"page"`
PerPage int `json:"per_page"`
Direction SortDirectionEnum `json:"direction"`
Sort EditSortEnum `json:"sort"`
} }
type EditVote struct { type EditVote struct {
@ -542,11 +549,24 @@ type PerformerQueryInput struct {
Tattoos *BodyModificationCriterionInput `json:"tattoos,omitempty"` Tattoos *BodyModificationCriterionInput `json:"tattoos,omitempty"`
Piercings *BodyModificationCriterionInput `json:"piercings,omitempty"` Piercings *BodyModificationCriterionInput `json:"piercings,omitempty"`
// Filter by performerfavorite status for the current user // Filter by performerfavorite status for the current user
IsFavorite *bool `json:"is_favorite,omitempty"` IsFavorite *bool `json:"is_favorite,omitempty"`
Page int `json:"page"` // Filter by a performer they have performed in scenes with
PerPage int `json:"per_page"` PerformedWith *string `json:"performed_with,omitempty"`
Direction SortDirectionEnum `json:"direction"` // Filter by a studio
Sort PerformerSortEnum `json:"sort"` StudioID *string `json:"studio_id,omitempty"`
Page int `json:"page"`
PerPage int `json:"per_page"`
Direction SortDirectionEnum `json:"direction"`
Sort PerformerSortEnum `json:"sort"`
}
type PerformerScenesInput struct {
// Filter by another performer that also performs in the scenes
PerformedWith *string `json:"performed_with,omitempty"`
// Filter by a studio
StudioID *string `json:"studio_id,omitempty"`
// Filter by tags
Tags *MultiIDCriterionInput `json:"tags,omitempty"`
} }
type PerformerStudio struct { type PerformerStudio struct {
@ -689,7 +709,9 @@ type SceneDestroyInput struct {
type SceneDraft struct { type SceneDraft struct {
ID *string `json:"id,omitempty"` ID *string `json:"id,omitempty"`
Title *string `json:"title,omitempty"` Title *string `json:"title,omitempty"`
Code *string `json:"code,omitempty"`
Details *string `json:"details,omitempty"` Details *string `json:"details,omitempty"`
Director *string `json:"director,omitempty"`
URL *URL `json:"url,omitempty"` URL *URL `json:"url,omitempty"`
Date *string `json:"date,omitempty"` Date *string `json:"date,omitempty"`
Studio SceneDraftStudio `json:"studio,omitempty"` Studio SceneDraftStudio `json:"studio,omitempty"`
@ -774,11 +796,13 @@ type SceneQueryInput struct {
// Filter to only include scenes with these fingerprints // Filter to only include scenes with these fingerprints
Fingerprints *MultiStringCriterionInput `json:"fingerprints,omitempty"` Fingerprints *MultiStringCriterionInput `json:"fingerprints,omitempty"`
// Filter by favorited entity // Filter by favorited entity
Favorites *FavoriteFilter `json:"favorites,omitempty"` Favorites *FavoriteFilter `json:"favorites,omitempty"`
Page int `json:"page"` // Filter to scenes with fingerprints submitted by the user
PerPage int `json:"per_page"` HasFingerprintSubmissions *bool `json:"has_fingerprint_submissions,omitempty"`
Direction SortDirectionEnum `json:"direction"` Page int `json:"page"`
Sort SceneSortEnum `json:"sort"` PerPage int `json:"per_page"`
Direction SortDirectionEnum `json:"direction"`
Sort SceneSortEnum `json:"sort"`
} }
type SceneUpdateInput struct { type SceneUpdateInput struct {
@ -847,16 +871,17 @@ type StringCriterionInput struct {
} }
type Studio struct { type Studio struct {
ID string `json:"id"` ID string `json:"id"`
Name string `json:"name"` Name string `json:"name"`
Urls []*URL `json:"urls,omitempty"` Urls []*URL `json:"urls,omitempty"`
Parent *Studio `json:"parent,omitempty"` Parent *Studio `json:"parent,omitempty"`
ChildStudios []*Studio `json:"child_studios,omitempty"` ChildStudios []*Studio `json:"child_studios,omitempty"`
Images []*Image `json:"images,omitempty"` Images []*Image `json:"images,omitempty"`
Deleted bool `json:"deleted"` Deleted bool `json:"deleted"`
IsFavorite bool `json:"is_favorite"` IsFavorite bool `json:"is_favorite"`
Created time.Time `json:"created"` Created time.Time `json:"created"`
Updated time.Time `json:"updated"` Updated time.Time `json:"updated"`
Performers *QueryPerformersResultType `json:"performers,omitempty"`
} }
func (Studio) IsSceneDraftStudio() {} func (Studio) IsSceneDraftStudio() {}
@ -1775,6 +1800,7 @@ const (
PerformerSortEnumOCounter PerformerSortEnum = "O_COUNTER" PerformerSortEnumOCounter PerformerSortEnum = "O_COUNTER"
PerformerSortEnumCareerStartYear PerformerSortEnum = "CAREER_START_YEAR" PerformerSortEnumCareerStartYear PerformerSortEnum = "CAREER_START_YEAR"
PerformerSortEnumDebut PerformerSortEnum = "DEBUT" PerformerSortEnumDebut PerformerSortEnum = "DEBUT"
PerformerSortEnumLastScene PerformerSortEnum = "LAST_SCENE"
PerformerSortEnumCreatedAt PerformerSortEnum = "CREATED_AT" PerformerSortEnumCreatedAt PerformerSortEnum = "CREATED_AT"
PerformerSortEnumUpdatedAt PerformerSortEnum = "UPDATED_AT" PerformerSortEnumUpdatedAt PerformerSortEnum = "UPDATED_AT"
) )
@ -1786,6 +1812,7 @@ var AllPerformerSortEnum = []PerformerSortEnum{
PerformerSortEnumOCounter, PerformerSortEnumOCounter,
PerformerSortEnumCareerStartYear, PerformerSortEnumCareerStartYear,
PerformerSortEnumDebut, PerformerSortEnumDebut,
PerformerSortEnumLastScene,
PerformerSortEnumCreatedAt, PerformerSortEnumCreatedAt,
PerformerSortEnumUpdatedAt, PerformerSortEnumUpdatedAt,
} }
@ -2136,6 +2163,51 @@ func (e TargetTypeEnum) MarshalGQL(w io.Writer) {
fmt.Fprint(w, strconv.Quote(e.String())) fmt.Fprint(w, strconv.Quote(e.String()))
} }
type UserVotedFilterEnum string
const (
UserVotedFilterEnumAbstain UserVotedFilterEnum = "ABSTAIN"
UserVotedFilterEnumAccept UserVotedFilterEnum = "ACCEPT"
UserVotedFilterEnumReject UserVotedFilterEnum = "REJECT"
UserVotedFilterEnumNotVoted UserVotedFilterEnum = "NOT_VOTED"
)
var AllUserVotedFilterEnum = []UserVotedFilterEnum{
UserVotedFilterEnumAbstain,
UserVotedFilterEnumAccept,
UserVotedFilterEnumReject,
UserVotedFilterEnumNotVoted,
}
func (e UserVotedFilterEnum) IsValid() bool {
switch e {
case UserVotedFilterEnumAbstain, UserVotedFilterEnumAccept, UserVotedFilterEnumReject, UserVotedFilterEnumNotVoted:
return true
}
return false
}
func (e UserVotedFilterEnum) String() string {
return string(e)
}
func (e *UserVotedFilterEnum) UnmarshalGQL(v interface{}) error {
str, ok := v.(string)
if !ok {
return fmt.Errorf("enums must be strings")
}
*e = UserVotedFilterEnum(str)
if !e.IsValid() {
return fmt.Errorf("%s is not a valid UserVotedFilterEnum", str)
}
return nil
}
func (e UserVotedFilterEnum) MarshalGQL(w io.Writer) {
fmt.Fprint(w, strconv.Quote(e.String()))
}
type ValidSiteTypeEnum string type ValidSiteTypeEnum string
const ( const (

View file

@ -2,6 +2,11 @@ package stashbox
import "github.com/stashapp/stash/pkg/models" import "github.com/stashapp/stash/pkg/models"
type StashBoxStudioQueryResult struct {
Query string `json:"query"`
Results []*models.ScrapedStudio `json:"results"`
}
type StashBoxPerformerQueryResult struct { type StashBoxPerformerQueryResult struct {
Query string `json:"query"` Query string `json:"query"`
Results []*models.ScrapedPerformer `json:"results"` Results []*models.ScrapedPerformer `json:"results"`

View file

@ -18,6 +18,7 @@ import (
"golang.org/x/text/language" "golang.org/x/text/language"
"github.com/Yamashou/gqlgenc/graphqljson" "github.com/Yamashou/gqlgenc/graphqljson"
"github.com/gofrs/uuid"
"github.com/stashapp/stash/pkg/file" "github.com/stashapp/stash/pkg/file"
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/match" "github.com/stashapp/stash/pkg/match"
@ -660,6 +661,26 @@ func performerFragmentToScrapedScenePerformer(p graphql.PerformerFragment) *mode
return sp return sp
} }
func studioFragmentToScrapedStudio(s graphql.StudioFragment) *models.ScrapedStudio {
images := []string{}
for _, image := range s.Images {
images = append(images, image.URL)
}
st := &models.ScrapedStudio{
Name: s.Name,
URL: findURL(s.Urls, "HOME"),
Images: images,
RemoteSiteID: &s.ID,
}
if len(st.Images) > 0 {
st.Image = &st.Images[0]
}
return st
}
func getFirstImage(ctx context.Context, client *http.Client, images []*graphql.ImageFragment) *string { func getFirstImage(ctx context.Context, client *http.Client, images []*graphql.ImageFragment) *string {
ret, err := fetchImage(ctx, client, images[0].URL) ret, err := fetchImage(ctx, client, images[0].URL)
if err != nil && !errors.Is(err, context.Canceled) { if err != nil && !errors.Is(err, context.Canceled) {
@ -725,20 +746,29 @@ func (c Client) sceneFragmentToScrapedScene(ctx context.Context, s *graphql.Scen
tqb := c.repository.Tag tqb := c.repository.Tag
if s.Studio != nil { if s.Studio != nil {
studioID := s.Studio.ID ss.Studio = studioFragmentToScrapedStudio(*s.Studio)
ss.Studio = &models.ScrapedStudio{
Name: s.Studio.Name,
URL: findURL(s.Studio.Urls, "HOME"),
RemoteSiteID: &studioID,
}
if s.Studio.Images != nil && len(s.Studio.Images) > 0 {
ss.Studio.Image = &s.Studio.Images[0].URL
}
err := match.ScrapedStudio(ctx, c.repository.Studio, ss.Studio, &c.box.Endpoint) err := match.ScrapedStudio(ctx, c.repository.Studio, ss.Studio, &c.box.Endpoint)
if err != nil { if err != nil {
return err return err
} }
var parentStudio *graphql.FindStudio
if s.Studio.Parent != nil {
parentStudio, err = c.client.FindStudio(ctx, &s.Studio.Parent.ID, nil)
if err != nil {
return err
}
if parentStudio.FindStudio != nil {
ss.Studio.Parent = studioFragmentToScrapedStudio(*parentStudio.FindStudio)
err = match.ScrapedStudio(ctx, c.repository.Studio, ss.Studio.Parent, &c.box.Endpoint)
if err != nil {
return err
}
}
}
} }
for _, p := range s.Performers { for _, p := range s.Performers {
@ -799,6 +829,56 @@ func (c Client) FindStashBoxPerformerByName(ctx context.Context, name string) (*
return ret, nil return ret, nil
} }
func (c Client) FindStashBoxStudio(ctx context.Context, query string) (*models.ScrapedStudio, error) {
var studio *graphql.FindStudio
_, err := uuid.FromString(query)
if err == nil {
// Confirmed the user passed in a Stash ID
studio, err = c.client.FindStudio(ctx, &query, nil)
} else {
// Otherwise assume they're searching on a name
studio, err = c.client.FindStudio(ctx, nil, &query)
}
if err != nil {
return nil, err
}
var ret *models.ScrapedStudio
if studio.FindStudio != nil {
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
ret = studioFragmentToScrapedStudio(*studio.FindStudio)
err = match.ScrapedStudio(ctx, c.repository.Studio, ret, &c.box.Endpoint)
if err != nil {
return err
}
if studio.FindStudio.Parent != nil {
parentStudio, err := c.client.FindStudio(ctx, &studio.FindStudio.Parent.ID, nil)
if err != nil {
return err
}
if parentStudio.FindStudio != nil {
ret.Parent = studioFragmentToScrapedStudio(*parentStudio.FindStudio)
err = match.ScrapedStudio(ctx, c.repository.Studio, ret.Parent, &c.box.Endpoint)
if err != nil {
return err
}
}
}
return nil
}); err != nil {
return nil, err
}
}
return ret, nil
}
func (c Client) GetUser(ctx context.Context) (*graphql.Me, error) { func (c Client) GetUser(ctx context.Context) (*graphql.Me, error) {
return c.client.Me(ctx) return c.client.Me(ctx)
} }

View file

@ -438,21 +438,6 @@ func (r *stashIDRepository) get(ctx context.Context, id int) ([]models.StashID,
return []models.StashID(ret), err return []models.StashID(ret), err
} }
func (r *stashIDRepository) replace(ctx context.Context, id int, newIDs []models.StashID) error {
if err := r.destroy(ctx, []int{id}); err != nil {
return err
}
query := fmt.Sprintf("INSERT INTO %s (%s, endpoint, stash_id) VALUES (?, ?, ?)", r.tableName, r.idColumn)
for _, stashID := range newIDs {
_, err := r.tx.Exec(ctx, query, id, stashID.Endpoint, stashID.StashID)
if err != nil {
return err
}
}
return nil
}
type filesRepository struct { type filesRepository struct {
repository repository
} }

View file

@ -631,7 +631,7 @@ func populateDB() error {
return fmt.Errorf("error creating performers: %s", err.Error()) return fmt.Errorf("error creating performers: %s", err.Error())
} }
if err := createStudios(ctx, db.Studio, studiosNameCase, studiosNameNoCase); err != nil { if err := createStudios(ctx, studiosNameCase, studiosNameNoCase); err != nil {
return fmt.Errorf("error creating studios: %s", err.Error()) return fmt.Errorf("error creating studios: %s", err.Error())
} }
@ -659,7 +659,7 @@ func populateDB() error {
return fmt.Errorf("error linking movie studios: %s", err.Error()) return fmt.Errorf("error linking movie studios: %s", err.Error())
} }
if err := linkStudiosParent(ctx, db.Studio); err != nil { if err := linkStudiosParent(ctx); err != nil {
return fmt.Errorf("error linking studios parent: %s", err.Error()) return fmt.Errorf("error linking studios parent: %s", err.Error())
} }
@ -1310,8 +1310,8 @@ func createMovies(ctx context.Context, mqb models.MovieReaderWriter, n int, o in
name = getMovieStringValue(index, name) name = getMovieStringValue(index, name)
movie := models.Movie{ movie := models.Movie{
Name: name, Name: name,
URL: getMovieNullStringValue(index, urlField), URL: getMovieNullStringValue(index, urlField),
} }
err := mqb.Create(ctx, &movie) err := mqb.Create(ctx, &movie)
@ -1573,9 +1573,9 @@ func getStudioNullStringValue(index int, field string) string {
return ret.String return ret.String
} }
func createStudio(ctx context.Context, sqb models.StudioReaderWriter, name string, parentID *int) (*models.Studio, error) { func createStudio(ctx context.Context, sqb *sqlite.StudioStore, name string, parentID *int) (*models.Studio, error) {
studio := models.Studio{ studio := models.Studio{
Name: name, Name: name,
} }
if parentID != nil { if parentID != nil {
@ -1590,7 +1590,7 @@ func createStudio(ctx context.Context, sqb models.StudioReaderWriter, name strin
return &studio, nil return &studio, nil
} }
func createStudioFromModel(ctx context.Context, sqb models.StudioReaderWriter, studio *models.Studio) error { func createStudioFromModel(ctx context.Context, sqb *sqlite.StudioStore, studio *models.Studio) error {
err := sqb.Create(ctx, studio) err := sqb.Create(ctx, studio)
if err != nil { if err != nil {
@ -1601,7 +1601,8 @@ func createStudioFromModel(ctx context.Context, sqb models.StudioReaderWriter, s
} }
// createStudios creates n studios with plain Name and o studios with camel cased NaMe included // createStudios creates n studios with plain Name and o studios with camel cased NaMe included
func createStudios(ctx context.Context, sqb models.StudioReaderWriter, n int, o int) error { func createStudios(ctx context.Context, n int, o int) error {
sqb := db.Studio
const namePlain = "Name" const namePlain = "Name"
const nameNoCase = "NaMe" const nameNoCase = "NaMe"
@ -1618,22 +1619,18 @@ func createStudios(ctx context.Context, sqb models.StudioReaderWriter, n int, o
name = getStudioStringValue(index, name) name = getStudioStringValue(index, name)
studio := models.Studio{ studio := models.Studio{
Name: name, Name: name,
URL: getStudioNullStringValue(index, urlField), URL: getStudioStringValue(index, urlField),
IgnoreAutoTag: getIgnoreAutoTag(i), IgnoreAutoTag: getIgnoreAutoTag(i),
} }
err := createStudioFromModel(ctx, sqb, &studio)
if err != nil {
return err
}
// add alias
// only add aliases for some scenes // only add aliases for some scenes
if i == studioIdxWithMovie || i%5 == 0 { if i == studioIdxWithMovie || i%5 == 0 {
alias := getStudioStringValue(i, "Alias") alias := getStudioStringValue(i, "Alias")
if err := sqb.UpdateAliases(ctx, studio.ID, []string{alias}); err != nil { studio.Aliases = models.NewRelatedStrings([]string{alias})
return fmt.Errorf("error setting studio alias: %s", err.Error()) }
} err := createStudioFromModel(ctx, sqb, &studio)
if err != nil {
return err
} }
studioIDs = append(studioIDs, studio.ID) studioIDs = append(studioIDs, studio.ID)
@ -1756,12 +1753,14 @@ func linkMovieStudios(ctx context.Context, mqb models.MovieWriter) error {
}) })
} }
func linkStudiosParent(ctx context.Context, qb models.StudioWriter) error { func linkStudiosParent(ctx context.Context) error {
qb := db.Studio
return doLinks(studioParentLinks, func(parentIndex, childIndex int) error { return doLinks(studioParentLinks, func(parentIndex, childIndex int) error {
studio := models.StudioPartial{ input := &models.StudioPartial{
ID: studioIDs[childIndex],
ParentID: models.NewOptionalInt(studioIDs[parentIndex]), ParentID: models.NewOptionalInt(studioIDs[parentIndex]),
} }
_, err := qb.UpdatePartial(ctx, studioIDs[childIndex], studio) _, err := qb.UpdatePartial(ctx, *input)
return err return err
}) })

View file

@ -5,7 +5,6 @@ import (
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
"strings"
"github.com/doug-martin/goqu/v9" "github.com/doug-martin/goqu/v9"
"github.com/doug-martin/goqu/v9/exp" "github.com/doug-martin/goqu/v9/exp"
@ -15,14 +14,16 @@ import (
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/sliceutil/intslice" "github.com/stashapp/stash/pkg/sliceutil/intslice"
"github.com/stashapp/stash/pkg/studio"
) )
const ( const (
studioTable = "studios" studioTable = "studios"
studioIDColumn = "studio_id" studioIDColumn = "studio_id"
studioAliasesTable = "studio_aliases" studioAliasesTable = "studio_aliases"
studioAliasColumn = "alias" studioAliasColumn = "alias"
studioParentIDColumn = "parent_id"
studioNameColumn = "name"
studioImageBlobColumn = "image_blob" studioImageBlobColumn = "image_blob"
) )
@ -39,7 +40,7 @@ type studioRow struct {
IgnoreAutoTag bool `db:"ignore_auto_tag"` IgnoreAutoTag bool `db:"ignore_auto_tag"`
// not used in resolutions or updates // not used in resolutions or updates
CoverBlob zero.String `db:"image_blob"` ImageBlob zero.String `db:"image_blob"`
} }
func (r *studioRow) fromStudio(o models.Studio) { func (r *studioRow) fromStudio(o models.Studio) {
@ -116,6 +117,8 @@ func (qb *StudioStore) selectDataset() *goqu.SelectDataset {
} }
func (qb *StudioStore) Create(ctx context.Context, newObject *models.Studio) error { func (qb *StudioStore) Create(ctx context.Context, newObject *models.Studio) error {
var err error
var r studioRow var r studioRow
r.fromStudio(*newObject) r.fromStudio(*newObject)
@ -124,34 +127,66 @@ func (qb *StudioStore) Create(ctx context.Context, newObject *models.Studio) err
return err return err
} }
updated, err := qb.find(ctx, id) if newObject.Aliases.Loaded() {
if err := studio.EnsureAliasesUnique(ctx, id, newObject.Aliases.List(), qb); err != nil {
return err
}
if err := studiosAliasesTableMgr.insertJoins(ctx, id, newObject.Aliases.List()); err != nil {
return err
}
}
if newObject.StashIDs.Loaded() {
if err := studiosStashIDsTableMgr.insertJoins(ctx, id, newObject.StashIDs.List()); err != nil {
return err
}
}
updated, _ := qb.find(ctx, id)
if err != nil { if err != nil {
return fmt.Errorf("finding after create: %w", err) return fmt.Errorf("finding after create: %w", err)
} }
*newObject = *updated *newObject = *updated
return nil return nil
} }
func (qb *StudioStore) UpdatePartial(ctx context.Context, id int, partial models.StudioPartial) (*models.Studio, error) { func (qb *StudioStore) UpdatePartial(ctx context.Context, input models.StudioPartial) (*models.Studio, error) {
r := studioRowRecord{ r := studioRowRecord{
updateRecord{ updateRecord{
Record: make(exp.Record), Record: make(exp.Record),
}, },
} }
r.fromPartial(partial) r.fromPartial(input)
if len(r.Record) > 0 { if len(r.Record) > 0 {
if err := qb.tableMgr.updateByID(ctx, id, r.Record); err != nil { if err := qb.tableMgr.updateByID(ctx, input.ID, r.Record); err != nil {
return nil, err return nil, err
} }
} }
return qb.find(ctx, id) if input.Aliases != nil {
if err := studio.EnsureAliasesUnique(ctx, input.ID, input.Aliases.Values, qb); err != nil {
return nil, err
}
if err := studiosAliasesTableMgr.modifyJoins(ctx, input.ID, input.Aliases.Values, input.Aliases.Mode); err != nil {
return nil, err
}
}
if input.StashIDs != nil {
if err := studiosStashIDsTableMgr.modifyJoins(ctx, input.ID, input.StashIDs.StashIDs, input.StashIDs.Mode); err != nil {
return nil, err
}
}
return qb.Find(ctx, input.ID)
} }
// This is only used by the Import/Export functionality
func (qb *StudioStore) Update(ctx context.Context, updatedObject *models.Studio) error { func (qb *StudioStore) Update(ctx context.Context, updatedObject *models.Studio) error {
var r studioRow var r studioRow
r.fromStudio(*updatedObject) r.fromStudio(*updatedObject)
@ -160,6 +195,18 @@ func (qb *StudioStore) Update(ctx context.Context, updatedObject *models.Studio)
return err return err
} }
if updatedObject.Aliases.Loaded() {
if err := studiosAliasesTableMgr.replaceJoins(ctx, updatedObject.ID, updatedObject.Aliases.List()); err != nil {
return err
}
}
if updatedObject.StashIDs.Loaded() {
if err := studiosStashIDsTableMgr.replaceJoins(ctx, updatedObject.ID, updatedObject.StashIDs.List()); err != nil {
return err
}
}
return nil return nil
} }
@ -257,10 +304,22 @@ func (qb *StudioStore) getMany(ctx context.Context, q *goqu.SelectDataset) ([]*m
return ret, nil return ret, nil
} }
func (qb *StudioStore) findBySubquery(ctx context.Context, sq *goqu.SelectDataset) ([]*models.Studio, error) {
table := qb.table()
q := qb.selectDataset().Where(
table.Col(idColumn).Eq(
sq,
),
)
return qb.getMany(ctx, q)
}
func (qb *StudioStore) FindChildren(ctx context.Context, id int) ([]*models.Studio, error) { func (qb *StudioStore) FindChildren(ctx context.Context, id int) ([]*models.Studio, error) {
// SELECT studios.* FROM studios WHERE studios.parent_id = ? // SELECT studios.* FROM studios WHERE studios.parent_id = ?
table := qb.table() table := qb.table()
sq := qb.selectDataset().Where(table.Col("parent_id").Eq(id)) sq := qb.selectDataset().Where(table.Col(studioParentIDColumn).Eq(id))
ret, err := qb.getMany(ctx, sq) ret, err := qb.getMany(ctx, sq)
if err != nil { if err != nil {
@ -309,13 +368,44 @@ func (qb *StudioStore) FindByName(ctx context.Context, name string, nocase bool)
} }
func (qb *StudioStore) FindByStashID(ctx context.Context, stashID models.StashID) ([]*models.Studio, error) { func (qb *StudioStore) FindByStashID(ctx context.Context, stashID models.StashID) ([]*models.Studio, error) {
query := selectAll("studios") + ` sq := dialect.From(studiosStashIDsJoinTable).Select(studiosStashIDsJoinTable.Col(studioIDColumn)).Where(
LEFT JOIN studio_stash_ids on studio_stash_ids.studio_id = studios.id studiosStashIDsJoinTable.Col("stash_id").Eq(stashID.StashID),
WHERE studio_stash_ids.stash_id = ? studiosStashIDsJoinTable.Col("endpoint").Eq(stashID.Endpoint),
AND studio_stash_ids.endpoint = ? )
` ret, err := qb.findBySubquery(ctx, sq)
args := []interface{}{stashID.StashID, stashID.Endpoint}
return qb.queryStudios(ctx, query, args) if err != nil {
return nil, fmt.Errorf("getting studios for stash ID %s: %w", stashID.StashID, err)
}
return ret, nil
}
func (qb *StudioStore) FindByStashIDStatus(ctx context.Context, hasStashID bool, stashboxEndpoint string) ([]*models.Studio, error) {
table := qb.table()
sq := dialect.From(table).LeftJoin(
studiosStashIDsJoinTable,
goqu.On(table.Col(idColumn).Eq(studiosStashIDsJoinTable.Col(studioIDColumn))),
).Select(table.Col(idColumn))
if hasStashID {
sq = sq.Where(
studiosStashIDsJoinTable.Col("stash_id").IsNotNull(),
studiosStashIDsJoinTable.Col("endpoint").Eq(stashboxEndpoint),
)
} else {
sq = sq.Where(
studiosStashIDsJoinTable.Col("stash_id").IsNull(),
)
}
ret, err := qb.findBySubquery(ctx, sq)
if err != nil {
return nil, fmt.Errorf("getting studios for stash-box endpoint %s: %w", stashboxEndpoint, err)
}
return ret, nil
} }
func (qb *StudioStore) Count(ctx context.Context) (int, error) { func (qb *StudioStore) Count(ctx context.Context) (int, error) {
@ -325,38 +415,37 @@ func (qb *StudioStore) Count(ctx context.Context) (int, error) {
func (qb *StudioStore) All(ctx context.Context) ([]*models.Studio, error) { func (qb *StudioStore) All(ctx context.Context) ([]*models.Studio, error) {
table := qb.table() table := qb.table()
return qb.getMany(ctx, qb.selectDataset().Order(table.Col(studioNameColumn).Asc()))
return qb.getMany(ctx, qb.selectDataset().Order(
table.Col("name").Asc(),
table.Col(idColumn).Asc(),
))
} }
func (qb *StudioStore) QueryForAutoTag(ctx context.Context, words []string) ([]*models.Studio, error) { func (qb *StudioStore) QueryForAutoTag(ctx context.Context, words []string) ([]*models.Studio, error) {
// TODO - Query needs to be changed to support queries of this type, and // TODO - Query needs to be changed to support queries of this type, and
// this method should be removed // this method should be removed
query := selectAll(studioTable) table := qb.table()
query += " LEFT JOIN studio_aliases ON studio_aliases.studio_id = studios.id" sq := dialect.From(table).Select(table.Col(idColumn)).LeftJoin(
studiosAliasesJoinTable,
goqu.On(studiosAliasesJoinTable.Col(studioIDColumn).Eq(table.Col(idColumn))),
)
var whereClauses []string var whereClauses []exp.Expression
var args []interface{}
for _, w := range words { for _, w := range words {
ww := w + "%" whereClauses = append(whereClauses, table.Col(studioNameColumn).Like(w+"%"))
whereClauses = append(whereClauses, "studios.name like ?") whereClauses = append(whereClauses, studiosAliasesJoinTable.Col("alias").Like(w+"%"))
args = append(args, ww)
// include aliases
whereClauses = append(whereClauses, "studio_aliases.alias like ?")
args = append(args, ww)
} }
whereOr := "(" + strings.Join(whereClauses, " OR ") + ")" sq = sq.Where(
where := strings.Join([]string{ goqu.Or(whereClauses...),
"studios.ignore_auto_tag = 0", table.Col("ignore_auto_tag").Eq(0),
whereOr, )
}, " AND ")
return qb.queryStudios(ctx, query+" WHERE "+where, args) ret, err := qb.findBySubquery(ctx, sq)
if err != nil {
return nil, fmt.Errorf("getting performers for autotag: %w", err)
}
return ret, nil
} }
func (qb *StudioStore) validateFilter(filter *models.StudioFilterType) error { func (qb *StudioStore) validateFilter(filter *models.StudioFilterType) error {
@ -430,13 +519,13 @@ func (qb *StudioStore) makeFilter(ctx context.Context, studioFilter *models.Stud
query.handleCriterion(ctx, studioGalleryCountCriterionHandler(qb, studioFilter.GalleryCount)) query.handleCriterion(ctx, studioGalleryCountCriterionHandler(qb, studioFilter.GalleryCount))
query.handleCriterion(ctx, studioParentCriterionHandler(qb, studioFilter.Parents)) query.handleCriterion(ctx, studioParentCriterionHandler(qb, studioFilter.Parents))
query.handleCriterion(ctx, studioAliasCriterionHandler(qb, studioFilter.Aliases)) query.handleCriterion(ctx, studioAliasCriterionHandler(qb, studioFilter.Aliases))
query.handleCriterion(ctx, timestampCriterionHandler(studioFilter.CreatedAt, "studios.created_at")) query.handleCriterion(ctx, timestampCriterionHandler(studioFilter.CreatedAt, studioTable+".created_at"))
query.handleCriterion(ctx, timestampCriterionHandler(studioFilter.UpdatedAt, "studios.updated_at")) query.handleCriterion(ctx, timestampCriterionHandler(studioFilter.UpdatedAt, studioTable+".updated_at"))
return query return query
} }
func (qb *StudioStore) Query(ctx context.Context, studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) ([]*models.Studio, int, error) { func (qb *StudioStore) makeQuery(ctx context.Context, studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) (*queryBuilder, error) {
if studioFilter == nil { if studioFilter == nil {
studioFilter = &models.StudioFilterType{} studioFilter = &models.StudioFilterType{}
} }
@ -450,20 +539,29 @@ func (qb *StudioStore) Query(ctx context.Context, studioFilter *models.StudioFil
if q := findFilter.Q; q != nil && *q != "" { if q := findFilter.Q; q != nil && *q != "" {
query.join(studioAliasesTable, "", "studio_aliases.studio_id = studios.id") query.join(studioAliasesTable, "", "studio_aliases.studio_id = studios.id")
searchColumns := []string{"studios.name", "studio_aliases.alias"} searchColumns := []string{"studios.name", "studio_aliases.alias"}
query.parseQueryString(searchColumns, *q) query.parseQueryString(searchColumns, *q)
} }
if err := qb.validateFilter(studioFilter); err != nil { if err := qb.validateFilter(studioFilter); err != nil {
return nil, 0, err return nil, err
} }
filter := qb.makeFilter(ctx, studioFilter) filter := qb.makeFilter(ctx, studioFilter)
if err := query.addFilter(filter); err != nil { if err := query.addFilter(filter); err != nil {
return nil, 0, err return nil, err
} }
query.sortAndPagination = qb.getStudioSort(findFilter) + getPagination(findFilter) query.sortAndPagination = qb.getStudioSort(findFilter) + getPagination(findFilter)
return &query, nil
}
func (qb *StudioStore) Query(ctx context.Context, studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) ([]*models.Studio, int, error) {
query, err := qb.makeQuery(ctx, studioFilter, findFilter)
if err != nil {
return nil, 0, err
}
idsResult, countResult, err := query.executeFind(ctx) idsResult, countResult, err := query.executeFind(ctx)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
@ -546,7 +644,7 @@ func studioAliasCriterionHandler(qb *StudioStore, alias *models.StringCriterionI
joinTable: studioAliasesTable, joinTable: studioAliasesTable,
stringColumn: studioAliasColumn, stringColumn: studioAliasColumn,
addJoinTable: func(f *filterBuilder) { addJoinTable: func(f *filterBuilder) {
qb.aliasRepository().join(f, "", "studios.id") studiosAliasesTableMgr.join(f, "", "studios.id")
}, },
} }
@ -581,26 +679,6 @@ func (qb *StudioStore) getStudioSort(findFilter *models.FindFilterType) string {
return sortQuery return sortQuery
} }
func (qb *StudioStore) queryStudios(ctx context.Context, query string, args []interface{}) ([]*models.Studio, error) {
const single = false
var ret []*models.Studio
if err := qb.queryFunc(ctx, query, args, single, func(r *sqlx.Rows) error {
var f studioRow
if err := r.StructScan(&f); err != nil {
return err
}
s := f.resolve()
ret = append(ret, s)
return nil
}); err != nil {
return nil, err
}
return ret, nil
}
func (qb *StudioStore) GetImage(ctx context.Context, studioID int) ([]byte, error) { func (qb *StudioStore) GetImage(ctx context.Context, studioID int) ([]byte, error) {
return qb.blobJoinQueryBuilder.GetImage(ctx, studioID, studioImageBlobColumn) return qb.blobJoinQueryBuilder.GetImage(ctx, studioID, studioImageBlobColumn)
} }
@ -628,28 +706,9 @@ func (qb *StudioStore) stashIDRepository() *stashIDRepository {
} }
func (qb *StudioStore) GetStashIDs(ctx context.Context, studioID int) ([]models.StashID, error) { func (qb *StudioStore) GetStashIDs(ctx context.Context, studioID int) ([]models.StashID, error) {
return qb.stashIDRepository().get(ctx, studioID) return studiosStashIDsTableMgr.get(ctx, studioID)
}
func (qb *StudioStore) UpdateStashIDs(ctx context.Context, studioID int, stashIDs []models.StashID) error {
return qb.stashIDRepository().replace(ctx, studioID, stashIDs)
}
func (qb *StudioStore) aliasRepository() *stringRepository {
return &stringRepository{
repository: repository{
tx: qb.tx,
tableName: studioAliasesTable,
idColumn: studioIDColumn,
},
stringColumn: studioAliasColumn,
}
} }
func (qb *StudioStore) GetAliases(ctx context.Context, studioID int) ([]string, error) { func (qb *StudioStore) GetAliases(ctx context.Context, studioID int) ([]string, error) {
return qb.aliasRepository().get(ctx, studioID) return studiosAliasesTableMgr.get(ctx, studioID)
}
func (qb *StudioStore) UpdateAliases(ctx context.Context, studioID int, aliases []string) error {
return qb.aliasRepository().replace(ctx, studioID, aliases)
} }

View file

@ -219,18 +219,15 @@ func TestStudioQueryForAutoTag(t *testing.T) {
assert.Len(t, studios, 1) assert.Len(t, studios, 1)
assert.Equal(t, strings.ToLower(studioNames[studioIdxWithMovie]), strings.ToLower(studios[0].Name)) assert.Equal(t, strings.ToLower(studioNames[studioIdxWithMovie]), strings.ToLower(studios[0].Name))
// find by alias
name = getStudioStringValue(studioIdxWithMovie, "Alias") name = getStudioStringValue(studioIdxWithMovie, "Alias")
studios, err = tqb.QueryForAutoTag(ctx, []string{name}) studios, err = tqb.QueryForAutoTag(ctx, []string{name})
if err != nil { if err != nil {
t.Errorf("Error finding studios: %s", err.Error()) t.Errorf("Error finding studios: %s", err.Error())
} }
if assert.Len(t, studios, 1) { if assert.Len(t, studios, 1) {
assert.Equal(t, studioIDs[studioIdxWithMovie], studios[0].ID) assert.Equal(t, studioIDs[studioIdxWithMovie], studios[0].ID)
} }
return nil return nil
}) })
} }
@ -363,11 +360,12 @@ func TestStudioUpdateClearParent(t *testing.T) {
sqb := db.Studio sqb := db.Studio
// clear the parent id from the child // clear the parent id from the child
updatePartial := models.StudioPartial{ input := models.StudioPartial{
ID: createdChild.ID,
ParentID: models.NewOptionalIntPtr(nil), ParentID: models.NewOptionalIntPtr(nil),
} }
updatedStudio, err := sqb.UpdatePartial(ctx, createdChild.ID, updatePartial) updatedStudio, err := sqb.UpdatePartial(ctx, input)
if err != nil { if err != nil {
return fmt.Errorf("Error updated studio: %s", err.Error()) return fmt.Errorf("Error updated studio: %s", err.Error())
@ -548,7 +546,7 @@ func verifyStudiosGalleryCount(t *testing.T, galleryCountCriterion models.IntCri
} }
func TestStudioStashIDs(t *testing.T) { func TestStudioStashIDs(t *testing.T) {
if err := withTxn(func(ctx context.Context) error { if err := withRollbackTxn(func(ctx context.Context) error {
qb := db.Studio qb := db.Studio
// create studio to test against // create studio to test against
@ -558,13 +556,83 @@ func TestStudioStashIDs(t *testing.T) {
return fmt.Errorf("Error creating studio: %s", err.Error()) return fmt.Errorf("Error creating studio: %s", err.Error())
} }
testStashIDReaderWriter(ctx, t, qb, created.ID) studio, err := qb.Find(ctx, created.ID)
if err != nil {
return fmt.Errorf("Error getting studio: %s", err.Error())
}
if err := studio.LoadStashIDs(ctx, qb); err != nil {
return err
}
testStudioStashIDs(ctx, t, studio)
return nil return nil
}); err != nil { }); err != nil {
t.Error(err.Error()) t.Error(err.Error())
} }
} }
func testStudioStashIDs(ctx context.Context, t *testing.T, s *models.Studio) {
qb := db.Studio
if err := s.LoadStashIDs(ctx, qb); err != nil {
t.Error(err.Error())
return
}
// ensure no stash IDs to begin with
assert.Len(t, s.StashIDs.List(), 0)
// add stash ids
const stashIDStr = "stashID"
const endpoint = "endpoint"
stashID := models.StashID{
StashID: stashIDStr,
Endpoint: endpoint,
}
// update stash ids and ensure was updated
input := models.StudioPartial{
ID: s.ID,
StashIDs: &models.UpdateStashIDs{
StashIDs: []models.StashID{stashID},
Mode: models.RelationshipUpdateModeSet,
},
}
var err error
s, err = qb.UpdatePartial(ctx, input)
if err != nil {
t.Error(err.Error())
}
if err := s.LoadStashIDs(ctx, qb); err != nil {
t.Error(err.Error())
return
}
assert.Equal(t, []models.StashID{stashID}, s.StashIDs.List())
// remove stash ids and ensure was updated
input = models.StudioPartial{
ID: s.ID,
StashIDs: &models.UpdateStashIDs{
StashIDs: []models.StashID{stashID},
Mode: models.RelationshipUpdateModeRemove,
},
}
s, err = qb.UpdatePartial(ctx, input)
if err != nil {
t.Error(err.Error())
}
if err := s.LoadStashIDs(ctx, qb); err != nil {
t.Error(err.Error())
return
}
assert.Len(t, s.StashIDs.List(), 0)
}
func TestStudioQueryURL(t *testing.T) { func TestStudioQueryURL(t *testing.T) {
const sceneIdx = 1 const sceneIdx = 1
studioURL := getStudioStringValue(sceneIdx, urlField) studioURL := getStudioStringValue(sceneIdx, urlField)
@ -684,7 +752,7 @@ func TestStudioQueryIsMissingRating(t *testing.T) {
assert.True(t, len(studios) > 0) assert.True(t, len(studios) > 0)
for _, studio := range studios { for _, studio := range studios {
assert.True(t, studio.Rating == nil) assert.Nil(t, studio.Rating)
} }
return nil return nil
@ -778,36 +846,87 @@ func TestStudioQueryAlias(t *testing.T) {
verifyStudioQuery(t, studioFilter, verifyFn) verifyStudioQuery(t, studioFilter, verifyFn)
} }
func TestStudioUpdateAlias(t *testing.T) { func TestStudioAlias(t *testing.T) {
if err := withTxn(func(ctx context.Context) error { if err := withRollbackTxn(func(ctx context.Context) error {
qb := db.Studio qb := db.Studio
// create studio to test against // create studio to test against
const name = "TestStudioUpdateAlias" const name = "TestStudioAlias"
created, err := createStudio(ctx, qb, name, nil) created, err := createStudio(ctx, db.Studio, name, nil)
if err != nil { if err != nil {
return fmt.Errorf("Error creating studio: %s", err.Error()) return fmt.Errorf("Error creating studio: %s", err.Error())
} }
aliases := []string{"alias1", "alias2"} studio, err := qb.Find(ctx, created.ID)
err = qb.UpdateAliases(ctx, created.ID, aliases)
if err != nil { if err != nil {
return fmt.Errorf("Error updating studio aliases: %s", err.Error()) return fmt.Errorf("Error getting studio: %s", err.Error())
} }
// ensure aliases set if err := studio.LoadStashIDs(ctx, qb); err != nil {
storedAliases, err := qb.GetAliases(ctx, created.ID) return err
if err != nil {
return fmt.Errorf("Error getting aliases: %s", err.Error())
} }
assert.Equal(t, aliases, storedAliases)
testStudioAlias(ctx, t, studio)
return nil return nil
}); err != nil { }); err != nil {
t.Error(err.Error()) t.Error(err.Error())
} }
} }
func testStudioAlias(ctx context.Context, t *testing.T, s *models.Studio) {
qb := db.Studio
if err := s.LoadAliases(ctx, qb); err != nil {
t.Error(err.Error())
return
}
// ensure no alias to begin with
assert.Len(t, s.Aliases.List(), 0)
aliases := []string{"alias1", "alias2"}
// update alias and ensure was updated
input := models.StudioPartial{
ID: s.ID,
Aliases: &models.UpdateStrings{
Values: aliases,
Mode: models.RelationshipUpdateModeSet,
},
}
var err error
s, err = qb.UpdatePartial(ctx, input)
if err != nil {
t.Error(err.Error())
}
if err := s.LoadAliases(ctx, qb); err != nil {
t.Error(err.Error())
return
}
assert.Equal(t, aliases, s.Aliases.List())
// remove alias and ensure was updated
input = models.StudioPartial{
ID: s.ID,
Aliases: &models.UpdateStrings{
Values: aliases,
Mode: models.RelationshipUpdateModeRemove,
},
}
s, err = qb.UpdatePartial(ctx, input)
if err != nil {
t.Error(err.Error())
}
if err := s.LoadAliases(ctx, qb); err != nil {
t.Error(err.Error())
return
}
assert.Len(t, s.Aliases.List(), 0)
}
// TestStudioQueryFast does a quick test for major errors, no result verification // TestStudioQueryFast does a quick test for major errors, no result verification
func TestStudioQueryFast(t *testing.T) { func TestStudioQueryFast(t *testing.T) {

View file

@ -29,6 +29,9 @@ var (
performersAliasesJoinTable = goqu.T(performersAliasesTable) performersAliasesJoinTable = goqu.T(performersAliasesTable)
performersTagsJoinTable = goqu.T(performersTagsTable) performersTagsJoinTable = goqu.T(performersTagsTable)
performersStashIDsJoinTable = goqu.T("performer_stash_ids") performersStashIDsJoinTable = goqu.T("performer_stash_ids")
studiosAliasesJoinTable = goqu.T(studioAliasesTable)
studiosStashIDsJoinTable = goqu.T("studio_stash_ids")
) )
var ( var (
@ -233,6 +236,21 @@ var (
table: goqu.T(studioTable), table: goqu.T(studioTable),
idColumn: goqu.T(studioTable).Col(idColumn), idColumn: goqu.T(studioTable).Col(idColumn),
} }
studiosAliasesTableMgr = &stringTable{
table: table{
table: studiosAliasesJoinTable,
idColumn: studiosAliasesJoinTable.Col(studioIDColumn),
},
stringColumn: studiosAliasesJoinTable.Col(studioAliasColumn),
}
studiosStashIDsTableMgr = &stashIDTable{
table: table{
table: studiosStashIDsJoinTable,
idColumn: studiosStashIDsJoinTable.Col(studioIDColumn),
},
}
) )
var ( var (

View file

@ -11,15 +11,15 @@ import (
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
type FinderImageStashIDGetter interface { type FinderImageAliasStashIDGetter interface {
Finder Finder
GetAliases(ctx context.Context, studioID int) ([]string, error)
GetImage(ctx context.Context, studioID int) ([]byte, error) GetImage(ctx context.Context, studioID int) ([]byte, error)
models.AliasLoader
models.StashIDLoader models.StashIDLoader
} }
// ToJSON converts a Studio object into its JSON equivalent. // ToJSON converts a Studio object into its JSON equivalent.
func ToJSON(ctx context.Context, reader FinderImageStashIDGetter, studio *models.Studio) (*jsonschema.Studio, error) { func ToJSON(ctx context.Context, reader FinderImageAliasStashIDGetter, studio *models.Studio) (*jsonschema.Studio, error) {
newStudioJSON := jsonschema.Studio{ newStudioJSON := jsonschema.Studio{
Name: studio.Name, Name: studio.Name,
URL: studio.URL, URL: studio.URL,
@ -44,12 +44,15 @@ func ToJSON(ctx context.Context, reader FinderImageStashIDGetter, studio *models
newStudioJSON.Rating = *studio.Rating newStudioJSON.Rating = *studio.Rating
} }
aliases, err := reader.GetAliases(ctx, studio.ID) if err := studio.LoadAliases(ctx, reader); err != nil {
if err != nil { return nil, fmt.Errorf("loading studio aliases: %w", err)
return nil, fmt.Errorf("error getting studio aliases: %v", err)
} }
newStudioJSON.Aliases = studio.Aliases.List()
newStudioJSON.Aliases = aliases if err := studio.LoadStashIDs(ctx, reader); err != nil {
return nil, fmt.Errorf("loading studio stash ids: %w", err)
}
newStudioJSON.StashIDs = studio.StashIDs.List()
image, err := reader.GetImage(ctx, studio.ID) image, err := reader.GetImage(ctx, studio.ID)
if err != nil { if err != nil {
@ -60,17 +63,5 @@ func ToJSON(ctx context.Context, reader FinderImageStashIDGetter, studio *models
newStudioJSON.Image = utils.GetBase64StringFromData(image) newStudioJSON.Image = utils.GetBase64StringFromData(image)
} }
stashIDs, _ := reader.GetStashIDs(ctx, studio.ID)
var ret []models.StashID
for _, stashID := range stashIDs {
newJoin := models.StashID{
StashID: stashID.StashID,
Endpoint: stashID.Endpoint,
}
ret = append(ret, newJoin)
}
newStudioJSON.StashIDs = ret
return &newStudioJSON, nil return &newStudioJSON, nil
} }

View file

@ -15,12 +15,10 @@ import (
) )
const ( const (
studioID = 1
noImageID = 2 noImageID = 2
errImageID = 3 errImageID = 3
missingParentStudioID = 4 missingParentStudioID = 4
errStudioID = 5 errStudioID = 5
errAliasID = 6
parentStudioID = 10 parentStudioID = 10
missingStudioID = 11 missingStudioID = 11
@ -31,17 +29,19 @@ var (
studioName = "testStudio" studioName = "testStudio"
url = "url" url = "url"
details = "details" details = "details"
rating = 5
parentStudioName = "parentStudio" parentStudioName = "parentStudio"
autoTagIgnored = true autoTagIgnored = true
) )
var studioID = 1
var rating = 5
var parentStudio models.Studio = models.Studio{ var parentStudio models.Studio = models.Studio{
Name: parentStudioName, Name: parentStudioName,
} }
var imageBytes = []byte("imageBytes") var imageBytes = []byte("imageBytes")
var aliases = []string{"alias"}
var stashID = models.StashID{ var stashID = models.StashID{
StashID: "StashID", StashID: "StashID",
Endpoint: "Endpoint", Endpoint: "Endpoint",
@ -67,6 +67,8 @@ func createFullStudio(id int, parentID int) models.Studio {
UpdatedAt: updateTime, UpdatedAt: updateTime,
Rating: &rating, Rating: &rating,
IgnoreAutoTag: autoTagIgnored, IgnoreAutoTag: autoTagIgnored,
Aliases: models.NewRelatedStrings(aliases),
StashIDs: models.NewRelatedStashIDs(stashIDs),
} }
if parentID != 0 { if parentID != 0 {
@ -81,6 +83,8 @@ func createEmptyStudio(id int) models.Studio {
ID: id, ID: id,
CreatedAt: createTime, CreatedAt: createTime,
UpdatedAt: updateTime, UpdatedAt: updateTime,
Aliases: models.NewRelatedStrings([]string{}),
StashIDs: models.NewRelatedStashIDs([]models.StashID{}),
} }
} }
@ -95,13 +99,11 @@ func createFullJSONStudio(parentStudio, image string, aliases []string) *jsonsch
UpdatedAt: json.JSONTime{ UpdatedAt: json.JSONTime{
Time: updateTime, Time: updateTime,
}, },
ParentStudio: parentStudio, ParentStudio: parentStudio,
Image: image, Image: image,
Rating: rating, Rating: rating,
Aliases: aliases, Aliases: aliases,
StashIDs: []models.StashID{ StashIDs: stashIDs,
stashID,
},
IgnoreAutoTag: autoTagIgnored, IgnoreAutoTag: autoTagIgnored,
} }
} }
@ -114,6 +116,8 @@ func createEmptyJSONStudio() *jsonschema.Studio {
UpdatedAt: json.JSONTime{ UpdatedAt: json.JSONTime{
Time: updateTime, Time: updateTime,
}, },
Aliases: []string{},
StashIDs: []models.StashID{},
} }
} }
@ -139,13 +143,13 @@ func initTestTable() {
}, },
{ {
createFullStudio(errImageID, parentStudioID), createFullStudio(errImageID, parentStudioID),
createFullJSONStudio(parentStudioName, "", nil), createFullJSONStudio(parentStudioName, "", []string{"alias"}),
// failure to get image is not an error // failure to get image is not an error
false, false,
}, },
{ {
createFullStudio(missingParentStudioID, missingStudioID), createFullStudio(missingParentStudioID, missingStudioID),
createFullJSONStudio("", image, nil), createFullJSONStudio("", image, []string{"alias"}),
false, false,
}, },
{ {
@ -153,11 +157,6 @@ func initTestTable() {
nil, nil,
true, true,
}, },
{
createFullStudio(errAliasID, parentStudioID),
nil,
true,
},
} }
} }
@ -174,7 +173,6 @@ func TestToJSON(t *testing.T) {
mockStudioReader.On("GetImage", ctx, errImageID).Return(nil, imageErr).Once() mockStudioReader.On("GetImage", ctx, errImageID).Return(nil, imageErr).Once()
mockStudioReader.On("GetImage", ctx, missingParentStudioID).Return(imageBytes, nil).Maybe() mockStudioReader.On("GetImage", ctx, missingParentStudioID).Return(imageBytes, nil).Maybe()
mockStudioReader.On("GetImage", ctx, errStudioID).Return(imageBytes, nil).Maybe() mockStudioReader.On("GetImage", ctx, errStudioID).Return(imageBytes, nil).Maybe()
mockStudioReader.On("GetImage", ctx, errAliasID).Return(imageBytes, nil).Maybe()
parentStudioErr := errors.New("error getting parent studio") parentStudioErr := errors.New("error getting parent studio")
@ -182,19 +180,6 @@ func TestToJSON(t *testing.T) {
mockStudioReader.On("Find", ctx, missingStudioID).Return(nil, nil) mockStudioReader.On("Find", ctx, missingStudioID).Return(nil, nil)
mockStudioReader.On("Find", ctx, errParentStudioID).Return(nil, parentStudioErr) mockStudioReader.On("Find", ctx, errParentStudioID).Return(nil, parentStudioErr)
aliasErr := errors.New("error getting aliases")
mockStudioReader.On("GetAliases", ctx, studioID).Return([]string{"alias"}, nil).Once()
mockStudioReader.On("GetAliases", ctx, noImageID).Return(nil, nil).Once()
mockStudioReader.On("GetAliases", ctx, errImageID).Return(nil, nil).Once()
mockStudioReader.On("GetAliases", ctx, missingParentStudioID).Return(nil, nil).Once()
mockStudioReader.On("GetAliases", ctx, errAliasID).Return(nil, aliasErr).Once()
mockStudioReader.On("GetStashIDs", ctx, studioID).Return(stashIDs, nil).Once()
mockStudioReader.On("GetStashIDs", ctx, noImageID).Return(nil, nil).Once()
mockStudioReader.On("GetStashIDs", ctx, missingParentStudioID).Return(stashIDs, nil).Once()
mockStudioReader.On("GetStashIDs", ctx, errImageID).Return(stashIDs, nil).Once()
for i, s := range scenarios { for i, s := range scenarios {
studio := s.input studio := s.input
json, err := ToJSON(ctx, mockStudioReader, &studio) json, err := ToJSON(ctx, mockStudioReader, &studio)

View file

@ -14,8 +14,6 @@ type NameFinderCreatorUpdater interface {
NameFinderCreator NameFinderCreator
Update(ctx context.Context, updatedStudio *models.Studio) error Update(ctx context.Context, updatedStudio *models.Studio) error
UpdateImage(ctx context.Context, studioID int, image []byte) error UpdateImage(ctx context.Context, studioID int, image []byte) error
UpdateAliases(ctx context.Context, studioID int, aliases []string) error
UpdateStashIDs(ctx context.Context, studioID int, stashIDs []models.StashID) error
} }
var ErrParentStudioNotExist = errors.New("parent studio does not exist") var ErrParentStudioNotExist = errors.New("parent studio does not exist")
@ -25,20 +23,13 @@ type Importer struct {
Input jsonschema.Studio Input jsonschema.Studio
MissingRefBehaviour models.ImportMissingRefEnum MissingRefBehaviour models.ImportMissingRefEnum
ID int
studio models.Studio studio models.Studio
imageData []byte imageData []byte
} }
func (i *Importer) PreImport(ctx context.Context) error { func (i *Importer) PreImport(ctx context.Context) error {
i.studio = models.Studio{ i.studio = studioJSONtoStudio(i.Input)
Name: i.Input.Name,
URL: i.Input.URL,
Details: i.Input.Details,
IgnoreAutoTag: i.Input.IgnoreAutoTag,
CreatedAt: i.Input.CreatedAt.GetTime(),
UpdatedAt: i.Input.UpdatedAt.GetTime(),
Rating: &i.Input.Rating,
}
if err := i.populateParentStudio(ctx); err != nil { if err := i.populateParentStudio(ctx); err != nil {
return err return err
@ -87,7 +78,9 @@ func (i *Importer) populateParentStudio(ctx context.Context) error {
} }
func (i *Importer) createParentStudio(ctx context.Context, name string) (int, error) { func (i *Importer) createParentStudio(ctx context.Context, name string) (int, error) {
newStudio := models.NewStudio(name) newStudio := &models.Studio{
Name: name,
}
err := i.ReaderWriter.Create(ctx, newStudio) err := i.ReaderWriter.Create(ctx, newStudio)
if err != nil { if err != nil {
@ -104,16 +97,6 @@ func (i *Importer) PostImport(ctx context.Context, id int) error {
} }
} }
if len(i.Input.StashIDs) > 0 {
if err := i.ReaderWriter.UpdateStashIDs(ctx, id, i.Input.StashIDs); err != nil {
return fmt.Errorf("error setting stash id: %v", err)
}
}
if err := i.ReaderWriter.UpdateAliases(ctx, id, i.Input.Aliases); err != nil {
return fmt.Errorf("error setting tag aliases: %v", err)
}
return nil return nil
} }
@ -156,3 +139,23 @@ func (i *Importer) Update(ctx context.Context, id int) error {
return nil return nil
} }
func studioJSONtoStudio(studioJSON jsonschema.Studio) models.Studio {
newStudio := models.Studio{
Name: studioJSON.Name,
URL: studioJSON.URL,
Aliases: models.NewRelatedStrings(studioJSON.Aliases),
Details: studioJSON.Details,
IgnoreAutoTag: studioJSON.IgnoreAutoTag,
CreatedAt: studioJSON.CreatedAt.GetTime(),
UpdatedAt: studioJSON.UpdatedAt.GetTime(),
StashIDs: models.NewRelatedStashIDs(studioJSON.StashIDs),
}
if studioJSON.Rating != 0 {
newStudio.Rating = &studioJSON.Rating
}
return newStudio
}

View file

@ -164,15 +164,9 @@ func TestImporterPostImport(t *testing.T) {
} }
updateStudioImageErr := errors.New("UpdateImage error") updateStudioImageErr := errors.New("UpdateImage error")
updateTagAliasErr := errors.New("UpdateAlias error")
readerWriter.On("UpdateImage", ctx, studioID, imageBytes).Return(nil).Once() readerWriter.On("UpdateImage", ctx, studioID, imageBytes).Return(nil).Once()
readerWriter.On("UpdateImage", ctx, errImageID, imageBytes).Return(updateStudioImageErr).Once() readerWriter.On("UpdateImage", ctx, errImageID, imageBytes).Return(updateStudioImageErr).Once()
readerWriter.On("UpdateImage", ctx, errAliasID, imageBytes).Return(nil).Once()
readerWriter.On("UpdateAliases", ctx, studioID, i.Input.Aliases).Return(nil).Once()
readerWriter.On("UpdateAliases", ctx, errImageID, i.Input.Aliases).Return(nil).Maybe()
readerWriter.On("UpdateAliases", ctx, errAliasID, i.Input.Aliases).Return(updateTagAliasErr).Once()
err := i.PostImport(ctx, studioID) err := i.PostImport(ctx, studioID)
assert.Nil(t, err) assert.Nil(t, err)
@ -180,9 +174,6 @@ func TestImporterPostImport(t *testing.T) {
err = i.PostImport(ctx, errImageID) err = i.PostImport(ctx, errImageID)
assert.NotNil(t, err) assert.NotNil(t, err)
err = i.PostImport(ctx, errAliasID)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t) readerWriter.AssertExpectations(t)
} }

View file

@ -14,6 +14,12 @@ type Queryer interface {
Query(ctx context.Context, studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) ([]*models.Studio, int, error) Query(ctx context.Context, studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) ([]*models.Studio, int, error)
} }
type FinderQueryer interface {
Finder
Queryer
models.AliasLoader
}
func ByName(ctx context.Context, qb Queryer, name string) (*models.Studio, error) { func ByName(ctx context.Context, qb Queryer, name string) (*models.Studio, error) {
f := &models.StudioFilterType{ f := &models.StudioFilterType{
Name: &models.StringCriterionInput{ Name: &models.StringCriterionInput{

View file

@ -2,11 +2,16 @@ package studio
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
) )
var (
ErrStudioOwnAncestor = errors.New("studio cannot be an ancestor of itself")
)
type NameFinderCreator interface { type NameFinderCreator interface {
FindByName(ctx context.Context, name string, nocase bool) (*models.Studio, error) FindByName(ctx context.Context, name string, nocase bool) (*models.Studio, error)
Create(ctx context.Context, newStudio *models.Studio) error Create(ctx context.Context, newStudio *models.Studio) error
@ -69,3 +74,60 @@ func EnsureAliasesUnique(ctx context.Context, id int, aliases []string, qb Query
return nil return nil
} }
// Checks to make sure that:
// 1. The studio exists locally
// 2. The studio is not its own ancestor
// 3. The studio's aliases are unique
func ValidateModify(ctx context.Context, s models.StudioPartial, qb FinderQueryer) error {
existing, err := qb.Find(ctx, s.ID)
if err != nil {
return err
}
if existing == nil {
return fmt.Errorf("studio with id %d not found", s.ID)
}
newParentID := s.ParentID.Ptr()
if newParentID != nil {
if err := validateParent(ctx, s.ID, *newParentID, qb); err != nil {
return err
}
}
if s.Aliases != nil {
if err := existing.LoadAliases(ctx, qb); err != nil {
return err
}
effectiveAliases := s.Aliases.EffectiveValues(existing.Aliases.List())
if err := EnsureAliasesUnique(ctx, s.ID, effectiveAliases, qb); err != nil {
return err
}
}
return nil
}
func validateParent(ctx context.Context, studioID int, newParentID int, qb FinderQueryer) error {
if newParentID == studioID {
return ErrStudioOwnAncestor
}
// ensure there is no cyclic dependency
parentStudio, err := qb.Find(ctx, newParentID)
if err != nil {
return fmt.Errorf("error finding parent studio: %v", err)
}
if parentStudio == nil {
return fmt.Errorf("studio with id %d not found", newParentID)
}
if parentStudio.ParentID != nil {
return validateParent(ctx, studioID, *parentStudio.ParentID, qb)
}
return nil
}

View file

@ -36,7 +36,7 @@ type config struct {
Markers int `yaml:"markers"` Markers int `yaml:"markers"`
Images int `yaml:"images"` Images int `yaml:"images"`
Galleries int `yaml:"galleries"` Galleries int `yaml:"galleries"`
Chapters int `yaml:"chapters"` Chapters int `yaml:"chapters"`
Performers int `yaml:"performers"` Performers int `yaml:"performers"`
Studios int `yaml:"studios"` Studios int `yaml:"studios"`
Tags int `yaml:"tags"` Tags int `yaml:"tags"`
@ -98,7 +98,7 @@ func populateDB() {
makeScenes(c.Scenes) makeScenes(c.Scenes)
makeImages(c.Images) makeImages(c.Images)
makeGalleries(c.Galleries) makeGalleries(c.Galleries)
makeChapters(c.Chapters) makeChapters(c.Chapters)
makeMarkers(c.Markers) makeMarkers(c.Markers)
} }
@ -504,35 +504,35 @@ func generateGallery(i int) models.Gallery {
} }
func makeChapters(n int) { func makeChapters(n int) {
logf("creating %d chapters...", n) logf("creating %d chapters...", n)
for i := 0; i < n; { for i := 0; i < n; {
// do in batches of 1000 // do in batches of 1000
batch := i + batchSize batch := i + batchSize
if err := withTxn(func(ctx context.Context) error { if err := withTxn(func(ctx context.Context) error {
for ; i < batch && i < n; i++ { for ; i < batch && i < n; i++ {
chapter := generateChapter(i) chapter := generateChapter(i)
chapter.GalleryID = models.NullInt64(int64(getRandomGallery())) chapter.GalleryID = models.NullInt64(int64(getRandomGallery()))
created, err := repo.GalleryChapter.Create(ctx, chapter) created, err := repo.GalleryChapter.Create(ctx, chapter)
if err != nil { if err != nil {
return err return err
} }
} }
logf("... created %d chapters", i) logf("... created %d chapters", i)
return nil return nil
}); err != nil { }); err != nil {
panic(err) panic(err)
} }
} }
} }
func generateChapter(i int) models.GalleryChapter { func generateChapter(i int) models.GalleryChapter {
return models.GalleryChapter{ return models.GalleryChapter{
Title: names[c.Naming.Galleries].generateName(rand.Intn(7) + 1), Title: names[c.Naming.Galleries].generateName(rand.Intn(7) + 1),
ImageIndex: rand.Intn(200), ImageIndex: rand.Intn(200),
} }
} }
func makeMarkers(n int) { func makeMarkers(n int) {
@ -657,7 +657,7 @@ func getRandomScene() int {
} }
func getRandomGallery() int { func getRandomGallery() int {
return rand.Intn(c.Galleries) + 1 return rand.Intn(c.Galleries) + 1
} }
func getRandomTags(ctx context.Context, min, max int) []int { func getRandomTags(ctx context.Context, min, max int) []int {

View file

@ -31,6 +31,8 @@ export const multiValueSceneFields: SceneField[] = [
export function sceneFieldMessageID(field: SceneField) { export function sceneFieldMessageID(field: SceneField) {
if (field === "code") { if (field === "code") {
return "scene_code"; return "scene_code";
} else if (field === "studio") {
return "studio_and_parent";
} }
return field; return field;

View file

@ -58,7 +58,7 @@ export const StudioDetailsPanel: React.FC<IStudioDetailsPanel> = ({
return ( return (
<> <>
<dt> <dt>
<FormattedMessage id="StashIDs" /> <FormattedMessage id="stash_ids" />
</dt> </dt>
<dd> <dd>
<ul className="pl-0"> <ul className="pl-0">

View file

@ -19,6 +19,7 @@ import { DisplayMode } from "src/models/list-filter/types";
import { ExportDialog } from "../Shared/ExportDialog"; import { ExportDialog } from "../Shared/ExportDialog";
import { DeleteEntityDialog } from "../Shared/DeleteEntityDialog"; import { DeleteEntityDialog } from "../Shared/DeleteEntityDialog";
import { StudioCard } from "./StudioCard"; import { StudioCard } from "./StudioCard";
import { StudioTagger } from "../Tagger/studios/StudioTagger";
const StudioItemList = makeItemList({ const StudioItemList = makeItemList({
filterMode: GQL.FilterMode.Studios, filterMode: GQL.FilterMode.Studios,
@ -156,6 +157,9 @@ export const StudioList: React.FC<IStudioList> = ({
if (filter.displayMode === DisplayMode.Wall) { if (filter.displayMode === DisplayMode.Wall) {
return <h1>TODO</h1>; return <h1>TODO</h1>;
} }
if (filter.displayMode === DisplayMode.Tagger) {
return <StudioTagger studios={result.data.findStudios.studios} />;
}
} }
return ( return (

View file

@ -232,7 +232,7 @@ const PerformerModal: React.FC<IPerformerModalProps> = ({
{link && ( {link && (
<h6 className="mt-2"> <h6 className="mt-2">
<a href={link} target="_blank" rel="noopener noreferrer"> <a href={link} target="_blank" rel="noopener noreferrer">
Stash-Box Source <FormattedMessage id="stashbox.source" />
<Icon icon={faExternalLinkAlt} className="ml-2" /> <Icon icon={faExternalLinkAlt} className="ml-2" />
</a> </a>
</h6> </h6>

View file

@ -25,6 +25,7 @@ export const DEFAULT_BLACKLIST = [
"\\]", "\\]",
]; ];
export const DEFAULT_EXCLUDED_PERFORMER_FIELDS = ["name"]; export const DEFAULT_EXCLUDED_PERFORMER_FIELDS = ["name"];
export const DEFAULT_EXCLUDED_STUDIO_FIELDS = ["name"];
export const initialConfig: ITaggerConfig = { export const initialConfig: ITaggerConfig = {
blacklist: DEFAULT_BLACKLIST, blacklist: DEFAULT_BLACKLIST,
@ -35,6 +36,8 @@ export const initialConfig: ITaggerConfig = {
tagOperation: "merge", tagOperation: "merge",
fingerprintQueue: {}, fingerprintQueue: {},
excludedPerformerFields: DEFAULT_EXCLUDED_PERFORMER_FIELDS, excludedPerformerFields: DEFAULT_EXCLUDED_PERFORMER_FIELDS,
excludedStudioFields: DEFAULT_EXCLUDED_STUDIO_FIELDS,
createParentStudios: true,
}; };
export type ParseMode = "auto" | "filename" | "dir" | "path" | "metadata"; export type ParseMode = "auto" | "filename" | "dir" | "path" | "metadata";
@ -49,6 +52,8 @@ export interface ITaggerConfig {
selectedEndpoint?: string; selectedEndpoint?: string;
fingerprintQueue: Record<string, string[]>; fingerprintQueue: Record<string, string[]>;
excludedPerformerFields?: string[]; excludedPerformerFields?: string[];
excludedStudioFields?: string[];
createParentStudios: boolean;
} }
export const PERFORMER_FIELDS = [ export const PERFORMER_FIELDS = [
@ -74,3 +79,5 @@ export const PERFORMER_FIELDS = [
"death_date", "death_date",
"weight", "weight",
]; ];
export const STUDIO_FIELDS = ["name", "image", "url", "parent"];

View file

@ -55,6 +55,7 @@ export interface ITaggerContextState {
studio: GQL.ScrapedStudio, studio: GQL.ScrapedStudio,
toCreate: GQL.StudioCreateInput toCreate: GQL.StudioCreateInput
) => Promise<string | undefined>; ) => Promise<string | undefined>;
updateStudio: (studio: GQL.StudioUpdateInput) => Promise<void>;
linkStudio: (studio: GQL.ScrapedStudio, studioID: string) => Promise<void>; linkStudio: (studio: GQL.ScrapedStudio, studioID: string) => Promise<void>;
resolveScene: ( resolveScene: (
sceneID: string, sceneID: string,
@ -91,6 +92,7 @@ export const TaggerStateContext = React.createContext<ITaggerContextState>({
createNewPerformer: dummyValFn, createNewPerformer: dummyValFn,
linkPerformer: dummyFn, linkPerformer: dummyFn,
createNewStudio: dummyValFn, createNewStudio: dummyValFn,
updateStudio: dummyFn,
linkStudio: dummyFn, linkStudio: dummyFn,
resolveScene: dummyFn, resolveScene: dummyFn,
submitFingerprints: dummyFn, submitFingerprints: dummyFn,
@ -701,6 +703,53 @@ export const TaggerContext: React.FC = ({ children }) => {
} }
} }
async function updateExistingStudio(input: GQL.StudioUpdateInput) {
try {
const result = await updateStudio({
variables: {
input: input,
},
});
const studioID = result.data?.studioUpdate?.id;
const stashID = input.stash_ids?.find((e) => {
return e.endpoint === currentSource?.stashboxEndpoint;
})?.stash_id;
if (stashID) {
const newSearchResults = mapResults((r) => {
if (!r.studio) {
return r;
}
return {
...r,
studio:
r.remote_site_id === stashID
? {
...r.studio,
stored_id: studioID,
}
: r.studio,
};
});
setSearchResults(newSearchResults);
}
Toast.success({
content: (
<span>
Created studio: <b>{input.name}</b>
</span>
),
});
} catch (e) {
Toast.error(e);
}
}
async function linkStudio(studio: GQL.ScrapedStudio, studioID: string) { async function linkStudio(studio: GQL.ScrapedStudio, studioID: string) {
if (!studio.remote_site_id || !currentSource?.stashboxEndpoint) return; if (!studio.remote_site_id || !currentSource?.stashboxEndpoint) return;
@ -780,6 +829,7 @@ export const TaggerContext: React.FC = ({ children }) => {
createNewPerformer, createNewPerformer,
linkPerformer, linkPerformer,
createNewStudio, createNewStudio,
updateStudio: updateExistingStudio,
linkStudio, linkStudio,
resolveScene, resolveScene,
saveScene, saveScene,

View file

@ -112,7 +112,7 @@ const PerformerBatchUpdateModal: React.FC<IPerformerBatchUpdateModal> = ({
type="radio" type="radio"
name="performer-query" name="performer-query"
label={<FormattedMessage id="performer_tagger.current_page" />} label={<FormattedMessage id="performer_tagger.current_page" />}
defaultChecked defaultChecked={!queryAll}
onChange={() => setQueryAll(false)} onChange={() => setQueryAll(false)}
/> />
<Form.Check <Form.Check
@ -123,7 +123,7 @@ const PerformerBatchUpdateModal: React.FC<IPerformerBatchUpdateModal> = ({
id: "performer_tagger.query_all_performers_in_the_database", id: "performer_tagger.query_all_performers_in_the_database",
})} })}
defaultChecked={false} defaultChecked={false}
onChange={() => setQueryAll(true)} onChange={() => setQueryAll(queryAll)}
/> />
</Form.Group> </Form.Group>
<Form.Group> <Form.Group>
@ -139,7 +139,7 @@ const PerformerBatchUpdateModal: React.FC<IPerformerBatchUpdateModal> = ({
label={intl.formatMessage({ label={intl.formatMessage({
id: "performer_tagger.untagged_performers", id: "performer_tagger.untagged_performers",
})} })}
defaultChecked defaultChecked={!refresh}
onChange={() => setRefresh(false)} onChange={() => setRefresh(false)}
/> />
<Form.Text> <Form.Text>
@ -153,7 +153,7 @@ const PerformerBatchUpdateModal: React.FC<IPerformerBatchUpdateModal> = ({
id: "performer_tagger.refresh_tagged_performers", id: "performer_tagger.refresh_tagged_performers",
})} })}
defaultChecked={false} defaultChecked={false}
onChange={() => setRefresh(true)} onChange={() => setRefresh(refresh)}
/> />
<Form.Text> <Form.Text>
<FormattedMessage id="performer_tagger.refreshing_will_update_the_data" /> <FormattedMessage id="performer_tagger.refreshing_will_update_the_data" />
@ -656,9 +656,10 @@ export const PerformerTagger: React.FC<ITaggerProps> = ({ performers }) => {
if (names.length > 0) { if (names.length > 0) {
const ret = await mutateStashBoxBatchPerformerTag({ const ret = await mutateStashBoxBatchPerformerTag({
performer_names: names, names: names,
endpoint: selectedEndpointIndex, endpoint: selectedEndpointIndex,
refresh: false, refresh: false,
createParent: false,
}); });
setBatchJobID(ret.data?.stashBoxBatchPerformerTag); setBatchJobID(ret.data?.stashBoxBatchPerformerTag);
@ -669,10 +670,11 @@ export const PerformerTagger: React.FC<ITaggerProps> = ({ performers }) => {
async function batchUpdate(ids: string[] | undefined, refresh: boolean) { async function batchUpdate(ids: string[] | undefined, refresh: boolean) {
if (config && selectedEndpoint) { if (config && selectedEndpoint) {
const ret = await mutateStashBoxBatchPerformerTag({ const ret = await mutateStashBoxBatchPerformerTag({
performer_ids: ids, ids: ids,
endpoint: selectedEndpointIndex, endpoint: selectedEndpointIndex,
refresh, refresh,
exclude_fields: config.excludedPerformerFields ?? [], exclude_fields: config.excludedPerformerFields ?? [],
createParent: false,
}); });
setBatchJobID(ret.data?.stashBoxBatchPerformerTag); setBatchJobID(ret.data?.stashBoxBatchPerformerTag);

View file

@ -1,5 +1,10 @@
import * as GQL from "src/core/generated-graphql"; import * as GQL from "src/core/generated-graphql";
import sortBy from "lodash-es/sortBy"; import sortBy from "lodash-es/sortBy";
import {
evictQueries,
getClient,
studioMutationImpactedQueries,
} from "src/core/StashService";
export const useUpdatePerformerStashID = () => { export const useUpdatePerformerStashID = () => {
const [updatePerformer] = GQL.usePerformerUpdateMutation({ const [updatePerformer] = GQL.usePerformerUpdateMutation({
@ -204,6 +209,54 @@ export const useUpdateStudioStashID = () => {
return handleUpdate; return handleUpdate;
}; };
export const useUpdateStudio = () => {
const [updateStudio] = GQL.useStudioUpdateMutation({
onError: (errors) => errors,
errorPolicy: "all",
});
const updateStudioHandler = (input: GQL.StudioUpdateInput) =>
updateStudio({
variables: {
input,
},
update: (store, updatedStudio) => {
if (!updatedStudio.data?.studioUpdate) return;
if (updatedStudio.data?.studioUpdate?.parent_studio) {
const ac = getClient();
evictQueries(ac.cache, studioMutationImpactedQueries);
} else {
updatedStudio.data.studioUpdate.stash_ids.forEach((id) => {
store.writeQuery<
GQL.FindStudiosQuery,
GQL.FindStudiosQueryVariables
>({
query: GQL.FindStudiosDocument,
variables: {
studio_filter: {
stash_id: {
value: id.stash_id,
modifier: GQL.CriterionModifier.Equals,
},
},
},
data: {
findStudios: {
count: 1,
studios: [updatedStudio.data!.studioUpdate!],
__typename: "FindStudiosResultType",
},
},
});
});
}
},
});
return updateStudioHandler;
};
export const useCreateStudio = () => { export const useCreateStudio = () => {
const [createStudio] = GQL.useStudioCreateMutation({ const [createStudio] = GQL.useStudioCreateMutation({
onError: (errors) => errors, onError: (errors) => errors,

View file

@ -204,6 +204,7 @@ const StashSearchResult: React.FC<IStashSearchResultProps> = ({
createNewPerformer, createNewPerformer,
linkPerformer, linkPerformer,
createNewStudio, createNewStudio,
updateStudio,
linkStudio, linkStudio,
resolveScene, resolveScene,
currentSource, currentSource,
@ -404,11 +405,32 @@ const StashSearchResult: React.FC<IStashSearchResultProps> = ({
}); });
} }
function showStudioModal(t: GQL.ScrapedStudio) { async function studioModalCallback(
createStudioModal(t, (toCreate) => { studio: GQL.ScrapedStudio,
if (toCreate) { toCreate?: GQL.StudioCreateInput,
createNewStudio(t, toCreate); parentInput?: GQL.StudioCreateInput
) {
if (toCreate) {
if (parentInput && studio.parent) {
if (toCreate.parent_id) {
const parentUpdateData: GQL.StudioUpdateInput = {
...parentInput,
id: toCreate.parent_id,
};
await updateStudio(parentUpdateData);
} else {
const parentID = await createNewStudio(studio.parent, parentInput);
toCreate.parent_id = parentID;
}
} }
createNewStudio(studio, toCreate);
}
}
function showStudioModal(t: GQL.ScrapedStudio) {
createStudioModal(t, (toCreate, parentInput) => {
studioModalCallback(t, toCreate, parentInput);
}); });
} }

View file

@ -1,66 +1,54 @@
import React, { useContext } from "react"; import React, { useState } from "react";
import { FormattedMessage, useIntl } from "react-intl"; import { FormattedMessage, useIntl } from "react-intl";
import cx from "classnames";
import { IconDefinition } from "@fortawesome/fontawesome-svg-core"; import { IconDefinition } from "@fortawesome/fontawesome-svg-core";
import * as GQL from "src/core/generated-graphql"; import * as GQL from "src/core/generated-graphql";
import { useFindStudio } from "src/core/StashService";
import { Icon } from "src/components/Shared/Icon"; import { Icon } from "src/components/Shared/Icon";
import { ModalComponent } from "src/components/Shared/Modal"; import { ModalComponent } from "src/components/Shared/Modal";
import {
faCheck,
faExternalLinkAlt,
faTimes,
} from "@fortawesome/free-solid-svg-icons";
import { Button, Form } from "react-bootstrap";
import { TruncatedText } from "src/components/Shared/TruncatedText"; import { TruncatedText } from "src/components/Shared/TruncatedText";
import { TaggerStateContext } from "../context"; import { excludeFields } from "src/utils/data";
import { faExternalLinkAlt } from "@fortawesome/free-solid-svg-icons";
interface IStudioModalProps { interface IStudioDetailsProps {
studio: GQL.ScrapedSceneStudioDataFragment; studio: GQL.ScrapedSceneStudioDataFragment;
modalVisible: boolean; link?: string;
closeModal: () => void; excluded: Record<string, boolean>;
handleStudioCreate: (input: GQL.StudioCreateInput) => void; toggleField: (field: string) => void;
header: string; isNew?: boolean;
icon: IconDefinition;
} }
const StudioModal: React.FC<IStudioModalProps> = ({ const StudioDetails: React.FC<IStudioDetailsProps> = ({
modalVisible,
studio, studio,
handleStudioCreate, link,
closeModal, excluded,
header, toggleField,
icon, isNew = false,
}) => { }) => {
const { currentSource } = useContext(TaggerStateContext);
const intl = useIntl();
function onSave() {
if (!studio.name) {
throw new Error("studio name must set");
}
const studioData: GQL.StudioCreateInput = {
name: studio.name ?? "",
url: studio.url,
};
// stashid handling code
const remoteSiteID = studio.remote_site_id;
if (remoteSiteID && currentSource?.stashboxEndpoint) {
studioData.stash_ids = [
{
endpoint: currentSource.stashboxEndpoint,
stash_id: remoteSiteID,
},
];
}
handleStudioCreate(studioData);
}
const renderField = ( const renderField = (
id: string, id: string,
text: string | null | undefined, text: string | null | undefined,
isSelectable: boolean = true,
truncate: boolean = true truncate: boolean = true
) => ) =>
text && ( text && (
<div className="row no-gutters"> <div className="row no-gutters">
<div className="col-5 studio-create-modal-field" key={id}> <div className="col-5 studio-create-modal-field" key={id}>
{isSelectable && (
<Button
onClick={() => toggleField(id)}
variant="secondary"
className={excluded[id] ? "text-muted" : "text-success"}
>
<Icon icon={excluded[id] ? faTimes : faCheck} />
</Button>
)}
<strong> <strong>
<FormattedMessage id={id} />: <FormattedMessage id={id} />:
</strong> </strong>
@ -73,8 +61,226 @@ const StudioModal: React.FC<IStudioModalProps> = ({
</div> </div>
); );
const base = currentSource?.stashboxEndpoint?.match(/https?:\/\/.*?\//)?.[0]; return (
<div>
<div className="row">
<div className="col-12 image-selection">
<div className="studio-image">
<Button
onClick={() => toggleField("image")}
variant="secondary"
className={cx(
"studio-image-exclude",
excluded.image ? "text-muted" : "text-success"
)}
>
<Icon icon={excluded.image ? faTimes : faCheck} />
</Button>
<img src={studio.image ?? ""} alt="" />
</div>
</div>
</div>
<div className="row">
<div className="col-12">
{renderField("name", studio.name, !isNew)}
{renderField("url", studio.url)}
{renderField("parent_studio", studio.parent?.name, false)}
{link && (
<h6 className="mt-2">
<a href={link} target="_blank" rel="noopener noreferrer">
<FormattedMessage id="stashbox.source" />
<Icon icon={faExternalLinkAlt} className="ml-2" />
</a>
</h6>
)}
</div>
</div>
</div>
);
};
interface IStudioModalProps {
studio: GQL.ScrapedSceneStudioDataFragment;
modalVisible: boolean;
closeModal: () => void;
handleStudioCreate: (
input: GQL.StudioCreateInput,
parent?: GQL.StudioCreateInput
) => void;
excludedStudioFields?: string[];
header: string;
icon: IconDefinition;
endpoint?: string;
}
const StudioModal: React.FC<IStudioModalProps> = ({
modalVisible,
studio,
handleStudioCreate,
closeModal,
excludedStudioFields = [],
header,
icon,
endpoint,
}) => {
const intl = useIntl();
const [excluded, setExcluded] = useState<Record<string, boolean>>(
excludedStudioFields.reduce(
(dict, field) => ({ ...dict, [field]: true }),
{}
)
);
const toggleField = (name: string) =>
setExcluded({
...excluded,
[name]: !excluded[name],
});
const [parentExcluded, setParentExcluded] = useState<Record<string, boolean>>(
excludedStudioFields.reduce(
(dict, field) => ({ ...dict, [field]: true }),
{}
)
);
const toggleParentField = (name: string) =>
setParentExcluded({
...parentExcluded,
[name]: !parentExcluded[name],
});
const [createParentStudio, setCreateParentStudio] = useState<boolean>(
!!studio.parent
);
let sendParentStudio = true;
// The parent studio exists, need to check if it has a Stash ID.
const queryResult = useFindStudio(studio.parent?.stored_id ?? "");
if (
queryResult.data?.findStudio?.stash_ids?.length &&
queryResult.data?.findStudio?.stash_ids?.length > 0
) {
// It already has a Stash ID, so we can skip worrying about it
sendParentStudio = false;
}
const parentStudioCreateText = () => {
if (studio.parent && studio.parent.stored_id) {
return "actions.assign_stashid_to_parent_studio";
}
return "actions.create_parent_studio";
};
function onSave() {
if (!studio.name) {
throw new Error("studio name must set");
}
const studioData: GQL.StudioCreateInput & {
[index: string]: unknown;
} = {
name: studio.name,
url: studio.url,
image: studio.image,
parent_id: studio.parent?.stored_id,
};
// stashid handling code
const remoteSiteID = studio.remote_site_id;
if (remoteSiteID && endpoint) {
studioData.stash_ids = [
{
endpoint,
stash_id: remoteSiteID,
},
];
}
// handle exclusions
excludeFields(studioData, excluded);
let parentData:
| (GQL.StudioCreateInput & {
[index: string]: unknown;
})
| undefined = undefined;
if (createParentStudio && sendParentStudio) {
if (!studio.parent?.name) {
throw new Error("parent studio name must set");
}
parentData = {
name: studio.parent?.name,
url: studio.parent?.url,
image: studio.parent?.image,
};
// stashid handling code
const parentRemoteSiteID = studio.parent?.remote_site_id;
if (parentRemoteSiteID && endpoint) {
parentData.stash_ids = [
{
endpoint,
stash_id: parentRemoteSiteID,
},
];
}
// handle exclusions
// Can't exclude parent studio name when creating a new one
parentExcluded.name = false;
excludeFields(parentData, parentExcluded);
}
handleStudioCreate(studioData, parentData);
}
const base = endpoint?.match(/https?:\/\/.*?\//)?.[0];
const link = base ? `${base}studios/${studio.remote_site_id}` : undefined; const link = base ? `${base}studios/${studio.remote_site_id}` : undefined;
const parentLink = base
? `${base}studios/${studio.parent?.remote_site_id}`
: undefined;
function maybeRenderParentStudio() {
// There is no parent studio or it already has a Stash ID
if (!studio.parent || !sendParentStudio) {
return;
}
return (
<div>
<div className="mb-4 mt-4">
<Form.Check
id="create-parent"
checked={createParentStudio}
label={intl.formatMessage({
id: parentStudioCreateText(),
})}
onChange={() => setCreateParentStudio(!createParentStudio)}
/>
</div>
{maybeRenderParentStudioDetails()}
</div>
);
}
function maybeRenderParentStudioDetails() {
if (!createParentStudio || !studio.parent) {
return;
}
return (
<StudioDetails
studio={studio.parent}
excluded={parentExcluded}
toggleField={(field) => toggleParentField(field)}
link={parentLink}
isNew
/>
);
}
return ( return (
<ModalComponent <ModalComponent
@ -83,33 +289,20 @@ const StudioModal: React.FC<IStudioModalProps> = ({
text: intl.formatMessage({ id: "actions.save" }), text: intl.formatMessage({ id: "actions.save" }),
onClick: onSave, onClick: onSave,
}} }}
onHide={() => closeModal()}
cancel={{ onClick: () => closeModal(), variant: "secondary" }} cancel={{ onClick: () => closeModal(), variant: "secondary" }}
onHide={() => closeModal()}
dialogClassName="studio-create-modal"
icon={icon} icon={icon}
header={header} header={header}
> >
<div className="row"> <StudioDetails
<div className="col-12"> studio={studio}
{renderField("name", studio.name)} excluded={excluded}
{renderField("url", studio.url)} toggleField={(field) => toggleField(field)}
{link && ( link={link}
<h6 className="mt-2"> />
<a href={link} target="_blank" rel="noopener noreferrer">
Stash-Box Source
<Icon icon={faExternalLinkAlt} className="ml-2" />
</a>
</h6>
)}
</div>
</div>
{/* TODO - add image */} {maybeRenderParentStudio()}
{/* <div className="row">
<strong className="col-2">Logo:</strong>
<span className="col-10">
<img src={studio?.image ?? ""} alt="" />
</span>
</div> */}
</ModalComponent> </ModalComponent>
); );
}; };

View file

@ -8,16 +8,19 @@ import { useIntl } from "react-intl";
import { faTags } from "@fortawesome/free-solid-svg-icons"; import { faTags } from "@fortawesome/free-solid-svg-icons";
type PerformerModalCallback = (toCreate?: GQL.PerformerCreateInput) => void; type PerformerModalCallback = (toCreate?: GQL.PerformerCreateInput) => void;
type StudioModalCallback = (toCreate?: GQL.StudioCreateInput) => void; type StudioModalCallback = (
toCreate?: GQL.StudioCreateInput,
parentInput?: GQL.StudioCreateInput
) => void;
export interface ISceneTaggerModalsContextState { export interface ISceneTaggerModalsContextState {
createPerformerModal: ( createPerformerModal: (
performer: GQL.ScrapedPerformerDataFragment, performer: GQL.ScrapedPerformerDataFragment,
callback: (toCreate?: GQL.PerformerCreateInput) => void callback: PerformerModalCallback
) => void; ) => void;
createStudioModal: ( createStudioModal: (
studio: GQL.ScrapedSceneStudioDataFragment, studio: GQL.ScrapedSceneStudioDataFragment,
callback: (toCreate?: GQL.StudioCreateInput) => void callback: StudioModalCallback
) => void; ) => void;
} }
@ -73,9 +76,12 @@ export const SceneTaggerModals: React.FC = ({ children }) => {
setPerformerCallback(() => callback); setPerformerCallback(() => callback);
} }
function handleStudioSave(toCreate: GQL.StudioCreateInput) { function handleStudioSave(
toCreate: GQL.StudioCreateInput,
parentInput?: GQL.StudioCreateInput
) {
if (studioCallback) { if (studioCallback) {
studioCallback(toCreate); studioCallback(toCreate, parentInput);
} }
setStudioToCreate(undefined); setStudioToCreate(undefined);
@ -132,6 +138,7 @@ export const SceneTaggerModals: React.FC = ({ children }) => {
{ id: "actions.create_entity" }, { id: "actions.create_entity" },
{ entityType: intl.formatMessage({ id: "studio" }) } { entityType: intl.formatMessage({ id: "studio" }) }
)} )}
endpoint={endpoint}
/> />
)} )}
{children} {children}

View file

@ -0,0 +1,132 @@
import React, { Dispatch, useState } from "react";
import { Badge, Button, Card, Collapse, Form } from "react-bootstrap";
import { FormattedMessage } from "react-intl";
import { ConfigurationContext } from "src/hooks/Config";
import TextUtils from "src/utils/text";
import { ITaggerConfig, STUDIO_FIELDS } from "../constants";
import StudioFieldSelector from "./StudioFieldSelector";
interface IConfigProps {
show: boolean;
config: ITaggerConfig;
setConfig: Dispatch<ITaggerConfig>;
}
const Config: React.FC<IConfigProps> = ({ show, config, setConfig }) => {
const { configuration: stashConfig } = React.useContext(ConfigurationContext);
const [showExclusionModal, setShowExclusionModal] = useState(false);
const excludedFields = config.excludedStudioFields ?? [];
const handleInstanceSelect = (e: React.ChangeEvent<HTMLSelectElement>) => {
const selectedEndpoint = e.currentTarget.value;
setConfig({
...config,
selectedEndpoint,
});
};
const stashBoxes = stashConfig?.general.stashBoxes ?? [];
const handleFieldSelect = (fields: string[]) => {
setConfig({ ...config, excludedStudioFields: fields });
setShowExclusionModal(false);
};
return (
<>
<Collapse in={show}>
<Card>
<div className="row">
<h4 className="col-12">
<FormattedMessage id="configuration" />
</h4>
<hr className="w-100" />
<div className="col-md-6">
<Form.Group
controlId="create-parent"
className="align-items-center"
>
<Form.Check
label={
<FormattedMessage id="studio_tagger.config.create_parent_label" />
}
checked={config.createParentStudios}
onChange={(e: React.ChangeEvent<HTMLInputElement>) =>
setConfig({
...config,
createParentStudios: e.currentTarget.checked,
})
}
/>
<Form.Text>
<FormattedMessage id="studio_tagger.config.create_parent_desc" />
</Form.Text>
</Form.Group>
<Form.Group controlId="excluded-studio-fields">
<h6>
<FormattedMessage id="studio_tagger.config.excluded_fields" />
</h6>
<span>
{excludedFields.length > 0 ? (
excludedFields.map((f) => (
<Badge variant="secondary" className="tag-item" key={f}>
{TextUtils.capitalize(f)}
</Badge>
))
) : (
<FormattedMessage id="studio_tagger.config.no_fields_are_excluded" />
)}
</span>
<Form.Text>
<FormattedMessage id="studio_tagger.config.these_fields_will_not_be_changed_when_updating_studios" />
</Form.Text>
<Button
onClick={() => setShowExclusionModal(true)}
className="mt-2"
>
<FormattedMessage id="studio_tagger.config.edit_excluded_fields" />
</Button>
</Form.Group>
<Form.Group
controlId="stash-box-endpoint"
className="align-items-center row no-gutters mt-4"
>
<Form.Label className="mr-4">
<FormattedMessage id="studio_tagger.config.active_stash-box_instance" />
</Form.Label>
<Form.Control
as="select"
value={config.selectedEndpoint}
className="col-md-4 col-6 input-control"
disabled={!stashBoxes.length}
onChange={handleInstanceSelect}
>
{!stashBoxes.length && (
<option>
<FormattedMessage id="studio_tagger.config.no_instances_found" />
</option>
)}
{stashConfig?.general.stashBoxes.map((i) => (
<option value={i.endpoint} key={i.endpoint}>
{i.endpoint}
</option>
))}
</Form.Control>
</Form.Group>
</div>
</div>
</Card>
</Collapse>
<StudioFieldSelector
fields={STUDIO_FIELDS}
show={showExclusionModal}
onSelect={handleFieldSelect}
excludedFields={excludedFields}
/>
</>
);
};
export default Config;

View file

@ -0,0 +1,144 @@
import React, { useState } from "react";
import { Button } from "react-bootstrap";
import * as GQL from "src/core/generated-graphql";
import { useUpdateStudio } from "../queries";
import StudioModal from "../scenes/StudioModal";
import { faTags } from "@fortawesome/free-solid-svg-icons";
import { useStudioCreate } from "src/core/StashService";
import { useIntl } from "react-intl";
interface IStashSearchResultProps {
studio: GQL.SlimStudioDataFragment;
stashboxStudios: GQL.ScrapedStudioDataFragment[];
endpoint: string;
onStudioTagged: (
studio: Pick<GQL.SlimStudioDataFragment, "id"> &
Partial<Omit<GQL.SlimStudioDataFragment, "id">>
) => void;
excludedStudioFields: string[];
}
const StashSearchResult: React.FC<IStashSearchResultProps> = ({
studio,
stashboxStudios,
onStudioTagged,
excludedStudioFields,
endpoint,
}) => {
const intl = useIntl();
const [modalStudio, setModalStudio] = useState<
GQL.ScrapedStudioDataFragment | undefined
>();
const [saveState, setSaveState] = useState<string>("");
const [error, setError] = useState<{ message?: string; details?: string }>(
{}
);
const [createStudio] = useStudioCreate();
const updateStudio = useUpdateStudio();
function handleSaveError(name: string, message: string) {
setError({
message: intl.formatMessage(
{ id: "studio_tagger.failed_to_save_studio" },
{ studio: name }
),
details:
message === "UNIQUE constraint failed: studios.checksum"
? "Name already exists"
: message,
});
}
const handleSave = async (
input: GQL.StudioCreateInput,
parentInput?: GQL.StudioCreateInput
) => {
setError({});
setModalStudio(undefined);
if (parentInput) {
setSaveState("Saving parent studio");
try {
// if parent id is set, then update the existing studio
if (input.parent_id) {
const parentUpdateData: GQL.StudioUpdateInput = {
...parentInput,
id: input.parent_id,
};
await updateStudio(parentUpdateData);
} else {
const parentRes = await createStudio({
variables: { input: parentInput },
});
input.parent_id = parentRes.data?.studioCreate?.id;
}
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} catch (e: any) {
handleSaveError(parentInput.name, e.message ?? "");
}
}
setSaveState("Saving studio");
const updateData: GQL.StudioUpdateInput = {
...input,
id: studio.id,
};
const res = await updateStudio(updateData);
if (!res?.data?.studioUpdate)
handleSaveError(studio.name, res?.errors?.[0]?.message ?? "");
else onStudioTagged(studio);
setSaveState("");
};
const studios = stashboxStudios.map((p) => (
<Button
className="StudioTagger-studio-search-item minimal col-6"
variant="link"
key={p.remote_site_id}
onClick={() => setModalStudio(p)}
>
<img src={(p.image ?? [])[0]} alt="" className="StudioTagger-thumb" />
<span>{p.name}</span>
</Button>
));
return (
<>
{modalStudio && (
<StudioModal
closeModal={() => setModalStudio(undefined)}
modalVisible={modalStudio !== undefined}
studio={modalStudio}
handleStudioCreate={handleSave}
icon={faTags}
header="Update Studio"
excludedStudioFields={excludedStudioFields}
endpoint={endpoint}
/>
)}
<div className="StudioTagger-studio-search">{studios}</div>
<div className="row no-gutters mt-2 align-items-center justify-content-end">
{error.message && (
<div className="text-right text-danger mt-1">
<strong>
<span className="mr-2">Error:</span>
{error.message}
</strong>
<div>{error.details}</div>
</div>
)}
{saveState && (
<strong className="col-4 mt-1 mr-2 text-right">{saveState}</strong>
)}
</div>
</>
);
};
export default StashSearchResult;

View file

@ -0,0 +1,67 @@
import { faCheck, faList, faTimes } from "@fortawesome/free-solid-svg-icons";
import React, { useState } from "react";
import { Button, Row, Col } from "react-bootstrap";
import { useIntl } from "react-intl";
import { ModalComponent } from "../../Shared/Modal";
import { Icon } from "../../Shared/Icon";
import TextUtils from "src/utils/text";
interface IProps {
fields: string[];
show: boolean;
excludedFields: string[];
onSelect: (fields: string[]) => void;
}
const StudioFieldSelect: React.FC<IProps> = ({
fields,
show,
excludedFields,
onSelect,
}) => {
const intl = useIntl();
const [excluded, setExcluded] = useState<Record<string, boolean>>(
excludedFields.reduce((dict, field) => ({ ...dict, [field]: true }), {})
);
const toggleField = (name: string) =>
setExcluded({
...excluded,
[name]: !excluded[name],
});
const renderField = (name: string) => (
<Col xs={6} className="mb-1" key={name}>
<Button
onClick={() => toggleField(name)}
variant="secondary"
className={excluded[name] ? "text-muted" : "text-success"}
>
<Icon icon={excluded[name] ? faTimes : faCheck} />
</Button>
<span className="ml-3">{TextUtils.capitalize(name)}</span>
</Col>
);
return (
<ModalComponent
show={show}
icon={faList}
dialogClassName="FieldSelect"
accept={{
text: intl.formatMessage({ id: "actions.save" }),
onClick: () =>
onSelect(Object.keys(excluded).filter((f) => excluded[f])),
}}
>
<h4>Select tagged fields</h4>
<div className="mb-2">
These fields will be tagged by default. Click the button to toggle.
</div>
<Row>{fields.map((f) => renderField(f))}</Row>
</ModalComponent>
);
};
export default StudioFieldSelect;

View file

@ -0,0 +1,870 @@
import React, { useEffect, useMemo, useRef, useState } from "react";
import { Button, Card, Form, InputGroup, ProgressBar } from "react-bootstrap";
import { FormattedMessage, useIntl } from "react-intl";
import { Link } from "react-router-dom";
import { HashLink } from "react-router-hash-link";
import { useLocalForage } from "src/hooks/LocalForage";
import * as GQL from "src/core/generated-graphql";
import { LoadingIndicator } from "src/components/Shared/LoadingIndicator";
import { ModalComponent } from "src/components/Shared/Modal";
import {
stashBoxStudioQuery,
useJobsSubscribe,
mutateStashBoxBatchStudioTag,
getClient,
studioMutationImpactedQueries,
useStudioCreate,
evictQueries,
} from "src/core/StashService";
import { Manual } from "src/components/Help/Manual";
import { ConfigurationContext } from "src/hooks/Config";
import StashSearchResult from "./StashSearchResult";
import StudioConfig from "./Config";
import { LOCAL_FORAGE_KEY, ITaggerConfig, initialConfig } from "../constants";
import StudioModal from "../scenes/StudioModal";
import { useUpdateStudio } from "../queries";
import { faStar, faTags } from "@fortawesome/free-solid-svg-icons";
type JobFragment = Pick<
GQL.Job,
"id" | "status" | "subTasks" | "description" | "progress"
>;
const CLASSNAME = "StudioTagger";
interface IStudioBatchUpdateModal {
studios: GQL.StudioDataFragment[];
isIdle: boolean;
selectedEndpoint: { endpoint: string; index: number };
onBatchUpdate: (queryAll: boolean, refresh: boolean) => void;
batchAddParents: boolean;
setBatchAddParents: (addParents: boolean) => void;
close: () => void;
}
const StudioBatchUpdateModal: React.FC<IStudioBatchUpdateModal> = ({
studios,
isIdle,
selectedEndpoint,
onBatchUpdate,
batchAddParents,
setBatchAddParents,
close,
}) => {
const intl = useIntl();
const [queryAll, setQueryAll] = useState(false);
const [refresh, setRefresh] = useState(false);
const { data: allStudios } = GQL.useFindStudiosQuery({
variables: {
studio_filter: {
stash_id_endpoint: {
endpoint: selectedEndpoint.endpoint,
modifier: refresh
? GQL.CriterionModifier.NotNull
: GQL.CriterionModifier.IsNull,
},
},
filter: {
per_page: 0,
},
},
});
const studioCount = useMemo(() => {
// get all stash ids for the selected endpoint
const filteredStashIDs = studios.map((p) =>
p.stash_ids.filter((s) => s.endpoint === selectedEndpoint.endpoint)
);
return queryAll
? allStudios?.findStudios.count
: filteredStashIDs.filter((s) =>
// if refresh, then we filter out the studios without a stash id
// otherwise, we want untagged studios, filtering out those with a stash id
refresh ? s.length > 0 : s.length === 0
).length;
}, [queryAll, refresh, studios, allStudios, selectedEndpoint.endpoint]);
return (
<ModalComponent
show
icon={faTags}
header={intl.formatMessage({
id: "studio_tagger.update_studios",
})}
accept={{
text: intl.formatMessage({
id: "studio_tagger.update_studios",
}),
onClick: () => onBatchUpdate(queryAll, refresh),
}}
cancel={{
text: intl.formatMessage({ id: "actions.cancel" }),
variant: "danger",
onClick: () => close(),
}}
disabled={!isIdle}
>
<Form.Group>
<Form.Label>
<h6>
<FormattedMessage id="studio_tagger.studio_selection" />
</h6>
</Form.Label>
<Form.Check
id="query-page"
type="radio"
name="studio-query"
label={<FormattedMessage id="studio_tagger.current_page" />}
defaultChecked={!queryAll}
onChange={() => setQueryAll(false)}
/>
<Form.Check
id="query-all"
type="radio"
name="studio-query"
label={intl.formatMessage({
id: "studio_tagger.query_all_studios_in_the_database",
})}
defaultChecked={queryAll}
onChange={() => setQueryAll(true)}
/>
</Form.Group>
<Form.Group>
<Form.Label>
<h6>
<FormattedMessage id="studio_tagger.tag_status" />
</h6>
</Form.Label>
<Form.Check
id="untagged-studios"
type="radio"
name="studio-refresh"
label={intl.formatMessage({
id: "studio_tagger.untagged_studios",
})}
defaultChecked={!refresh}
onChange={() => setRefresh(false)}
/>
<Form.Text>
<FormattedMessage id="studio_tagger.updating_untagged_studios_description" />
</Form.Text>
<Form.Check
id="tagged-studios"
type="radio"
name="studio-refresh"
label={intl.formatMessage({
id: "studio_tagger.refresh_tagged_studios",
})}
defaultChecked={refresh}
onChange={() => setRefresh(true)}
/>
<Form.Text>
<FormattedMessage id="studio_tagger.refreshing_will_update_the_data" />
</Form.Text>
<div className="mt-4">
<Form.Check
id="add-parent"
checked={batchAddParents}
label={intl.formatMessage({
id: "studio_tagger.create_or_tag_parent_studios",
})}
onChange={() => setBatchAddParents(!batchAddParents)}
/>
</div>
</Form.Group>
<b>
<FormattedMessage
id="studio_tagger.number_of_studios_will_be_processed"
values={{
studio_count: studioCount,
}}
/>
</b>
</ModalComponent>
);
};
interface IStudioBatchAddModal {
isIdle: boolean;
onBatchAdd: (input: string) => void;
batchAddParents: boolean;
setBatchAddParents: (addParents: boolean) => void;
close: () => void;
}
const StudioBatchAddModal: React.FC<IStudioBatchAddModal> = ({
isIdle,
onBatchAdd,
batchAddParents,
setBatchAddParents,
close,
}) => {
const intl = useIntl();
const studioInput = useRef<HTMLTextAreaElement | null>(null);
return (
<ModalComponent
show
icon={faStar}
header={intl.formatMessage({
id: "studio_tagger.add_new_studios",
})}
accept={{
text: intl.formatMessage({
id: "studio_tagger.add_new_studios",
}),
onClick: () => {
if (studioInput.current) {
onBatchAdd(studioInput.current.value);
} else {
close();
}
},
}}
cancel={{
text: intl.formatMessage({ id: "actions.cancel" }),
variant: "danger",
onClick: () => close(),
}}
disabled={!isIdle}
>
<Form.Control
className="text-input"
as="textarea"
ref={studioInput}
placeholder={intl.formatMessage({
id: "studio_tagger.studio_names_separated_by_comma",
})}
rows={6}
/>
<Form.Text>
<FormattedMessage id="studio_tagger.any_names_entered_will_be_queried" />
</Form.Text>
<div className="mt-2">
<Form.Check
id="add-parent"
checked={batchAddParents}
label={intl.formatMessage({
id: "studio_tagger.create_or_tag_parent_studios",
})}
onChange={() => setBatchAddParents(!batchAddParents)}
/>
</div>
</ModalComponent>
);
};
interface IStudioTaggerListProps {
studios: GQL.StudioDataFragment[];
selectedEndpoint: { endpoint: string; index: number };
isIdle: boolean;
config: ITaggerConfig;
stashBoxes?: GQL.StashBox[];
onBatchAdd: (studioInput: string, createParent: boolean) => void;
onBatchUpdate: (
ids: string[] | undefined,
refresh: boolean,
createParent: boolean
) => void;
}
const StudioTaggerList: React.FC<IStudioTaggerListProps> = ({
studios,
selectedEndpoint,
isIdle,
config,
stashBoxes,
onBatchAdd,
onBatchUpdate,
}) => {
const intl = useIntl();
const [loading, setLoading] = useState(false);
const [searchResults, setSearchResults] = useState<
Record<string, GQL.ScrapedStudioDataFragment[]>
>({});
const [searchErrors, setSearchErrors] = useState<
Record<string, string | undefined>
>({});
const [taggedStudios, setTaggedStudios] = useState<
Record<string, Partial<GQL.SlimStudioDataFragment>>
>({});
const [queries, setQueries] = useState<Record<string, string>>({});
const [showBatchAdd, setShowBatchAdd] = useState(false);
const [showBatchUpdate, setShowBatchUpdate] = useState(false);
const [batchAddParents, setBatchAddParents] = useState(
config.createParentStudios || false
);
const [error, setError] = useState<
Record<string, { message?: string; details?: string } | undefined>
>({});
const [loadingUpdate, setLoadingUpdate] = useState<string | undefined>();
const [modalStudio, setModalStudio] = useState<
GQL.ScrapedStudioDataFragment | undefined
>();
const doBoxSearch = (studioID: string, searchVal: string) => {
stashBoxStudioQuery(searchVal, selectedEndpoint.index)
.then((queryData) => {
const s = queryData.data?.scrapeSingleStudio ?? [];
setSearchResults({
...searchResults,
[studioID]: s,
});
setSearchErrors({
...searchErrors,
[studioID]: undefined,
});
setLoading(false);
})
.catch(() => {
setLoading(false);
// Destructure to remove existing result
const { [studioID]: unassign, ...results } = searchResults;
setSearchResults(results);
setSearchErrors({
...searchErrors,
[studioID]: intl.formatMessage({
id: "studio_tagger.network_error",
}),
});
});
setLoading(true);
};
const doBoxUpdate = (
studioID: string,
stashID: string,
endpointIndex: number
) => {
setLoadingUpdate(stashID);
setError({
...error,
[studioID]: undefined,
});
stashBoxStudioQuery(stashID, endpointIndex)
.then((queryData) => {
const data = queryData.data?.scrapeSingleStudio ?? [];
if (data.length > 0) {
setModalStudio({
...data[0],
stored_id: studioID,
});
}
})
.finally(() => setLoadingUpdate(undefined));
};
async function handleBatchAdd(input: string) {
onBatchAdd(input, batchAddParents);
setShowBatchAdd(false);
}
const handleBatchUpdate = (queryAll: boolean, refresh: boolean) => {
onBatchUpdate(
!queryAll ? studios.map((p) => p.id) : undefined,
refresh,
batchAddParents
);
setShowBatchUpdate(false);
};
const handleTaggedStudio = (
studio: Pick<GQL.SlimStudioDataFragment, "id"> &
Partial<Omit<GQL.SlimStudioDataFragment, "id">>
) => {
setTaggedStudios({
...taggedStudios,
[studio.id]: studio,
});
};
const [createStudio] = useStudioCreate();
const updateStudio = useUpdateStudio();
function handleSaveError(studioID: string, name: string, message: string) {
setError({
...error,
[studioID]: {
message: intl.formatMessage(
{ id: "studio_tagger.failed_to_save_studio" },
{ studio: modalStudio?.name }
),
details:
message === "UNIQUE constraint failed: studios.checksum"
? intl.formatMessage({
id: "studio_tagger.name_already_exists",
})
: message,
},
});
}
const handleStudioUpdate = async (
input: GQL.StudioCreateInput,
parentInput?: GQL.StudioCreateInput
) => {
setModalStudio(undefined);
const studioID = modalStudio?.stored_id;
if (studioID) {
if (parentInput) {
try {
// if parent id is set, then update the existing studio
if (input.parent_id) {
const parentUpdateData: GQL.StudioUpdateInput = {
...parentInput,
id: input.parent_id,
};
await updateStudio(parentUpdateData);
} else {
const parentRes = await createStudio({
variables: { input: parentInput },
});
input.parent_id = parentRes.data?.studioCreate?.id;
}
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} catch (e: any) {
handleSaveError(studioID, parentInput.name, e.message ?? "");
}
}
const updateData: GQL.StudioUpdateInput = {
...input,
id: studioID,
};
const res = await updateStudio(updateData);
if (!res.data?.studioUpdate)
handleSaveError(
studioID,
modalStudio?.name ?? "",
res?.errors?.[0]?.message ?? ""
);
}
};
const renderStudios = () =>
studios.map((studio) => {
const isTagged = taggedStudios[studio.id];
const stashID = studio.stash_ids.find((s) => {
return s.endpoint === selectedEndpoint.endpoint;
});
let mainContent;
if (!isTagged && stashID !== undefined) {
mainContent = (
<div className="text-left">
<h5 className="text-bold">
<FormattedMessage id="studio_tagger.studio_already_tagged" />
</h5>
</div>
);
} else if (!isTagged && !stashID) {
mainContent = (
<InputGroup>
<Form.Control
className="text-input"
defaultValue={studio.name ?? ""}
onChange={(e) =>
setQueries({
...queries,
[studio.id]: e.currentTarget.value,
})
}
onKeyPress={(e: React.KeyboardEvent<HTMLInputElement>) =>
e.key === "Enter" &&
doBoxSearch(studio.id, queries[studio.id] ?? studio.name ?? "")
}
/>
<InputGroup.Append>
<Button
disabled={loading}
onClick={() =>
doBoxSearch(
studio.id,
queries[studio.id] ?? studio.name ?? ""
)
}
>
<FormattedMessage id="actions.search" />
</Button>
</InputGroup.Append>
</InputGroup>
);
} else if (isTagged) {
mainContent = (
<div className="d-flex flex-column text-left">
<h5>
<FormattedMessage id="studio_tagger.studio_successfully_tagged" />
</h5>
</div>
);
}
let subContent;
if (stashID !== undefined) {
const base = stashID.endpoint.match(/https?:\/\/.*?\//)?.[0];
const link = base ? (
<a
className="small d-block"
href={`${base}studios/${stashID.stash_id}`}
target="_blank"
rel="noopener noreferrer"
>
{stashID.stash_id}
</a>
) : (
<div className="small">{stashID.stash_id}</div>
);
const endpointIndex =
stashBoxes?.findIndex((box) => box.endpoint === stashID.endpoint) ??
-1;
subContent = (
<div key={studio.id}>
<InputGroup className="StudioTagger-box-link">
<InputGroup.Text>{link}</InputGroup.Text>
<InputGroup.Append>
{endpointIndex !== -1 && (
<Button
onClick={() =>
doBoxUpdate(studio.id, stashID.stash_id, endpointIndex)
}
disabled={!!loadingUpdate}
>
{loadingUpdate === stashID.stash_id ? (
<LoadingIndicator inline small message="" />
) : (
<FormattedMessage id="actions.refresh" />
)}
</Button>
)}
</InputGroup.Append>
</InputGroup>
{error[studio.id] && (
<div className="text-danger mt-1">
<strong>
<span className="mr-2">Error:</span>
{error[studio.id]?.message}
</strong>
<div>{error[studio.id]?.details}</div>
</div>
)}
</div>
);
} else if (searchErrors[studio.id]) {
subContent = (
<div className="text-danger font-weight-bold">
{searchErrors[studio.id]}
</div>
);
} else if (searchResults[studio.id]?.length === 0) {
subContent = (
<div className="text-danger font-weight-bold">
<FormattedMessage id="studio_tagger.no_results_found" />
</div>
);
}
let searchResult;
if (searchResults[studio.id]?.length > 0 && !isTagged) {
searchResult = (
<StashSearchResult
key={studio.id}
stashboxStudios={searchResults[studio.id]}
studio={studio}
endpoint={selectedEndpoint.endpoint}
onStudioTagged={handleTaggedStudio}
excludedStudioFields={config.excludedStudioFields ?? []}
/>
);
}
return (
<div key={studio.id} className={`${CLASSNAME}-studio`}>
{modalStudio && (
<StudioModal
closeModal={() => setModalStudio(undefined)}
modalVisible={modalStudio.stored_id === studio.id}
studio={modalStudio}
handleStudioCreate={handleStudioUpdate}
excludedStudioFields={config.excludedStudioFields}
icon={faTags}
header={intl.formatMessage({
id: "studio_tagger.update_studio",
})}
endpoint={selectedEndpoint.endpoint}
/>
)}
<div className={`${CLASSNAME}-details`}>
<div></div>
<div>
<Card className="studio-card">
<img src={studio.image_path ?? ""} alt="" />
</Card>
</div>
<div className={`${CLASSNAME}-details-text`}>
<Link
to={`/studios/${studio.id}`}
className={`${CLASSNAME}-header`}
>
<h2>{studio.name}</h2>
</Link>
{mainContent}
<div className="sub-content text-left">{subContent}</div>
{searchResult}
</div>
</div>
</div>
);
});
return (
<Card>
{showBatchUpdate && (
<StudioBatchUpdateModal
close={() => setShowBatchUpdate(false)}
isIdle={isIdle}
selectedEndpoint={selectedEndpoint}
studios={studios}
onBatchUpdate={handleBatchUpdate}
batchAddParents={batchAddParents}
setBatchAddParents={setBatchAddParents}
/>
)}
{showBatchAdd && (
<StudioBatchAddModal
close={() => setShowBatchAdd(false)}
isIdle={isIdle}
onBatchAdd={handleBatchAdd}
batchAddParents={batchAddParents}
setBatchAddParents={setBatchAddParents}
/>
)}
<div className="ml-auto mb-3">
<Button onClick={() => setShowBatchAdd(true)}>
<FormattedMessage id="studio_tagger.batch_add_studios" />
</Button>
<Button className="ml-3" onClick={() => setShowBatchUpdate(true)}>
<FormattedMessage id="studio_tagger.batch_update_studios" />
</Button>
</div>
<div className={CLASSNAME}>{renderStudios()}</div>
</Card>
);
};
interface ITaggerProps {
studios: GQL.StudioDataFragment[];
}
export const StudioTagger: React.FC<ITaggerProps> = ({ studios }) => {
const jobsSubscribe = useJobsSubscribe();
const intl = useIntl();
const { configuration: stashConfig } = React.useContext(ConfigurationContext);
const [{ data: config }, setConfig] = useLocalForage<ITaggerConfig>(
LOCAL_FORAGE_KEY,
initialConfig
);
const [showConfig, setShowConfig] = useState(false);
const [showManual, setShowManual] = useState(false);
const [batchJobID, setBatchJobID] = useState<string | undefined | null>();
const [batchJob, setBatchJob] = useState<JobFragment | undefined>();
// monitor batch operation
useEffect(() => {
if (!jobsSubscribe.data) {
return;
}
const event = jobsSubscribe.data.jobsSubscribe;
if (event.job.id !== batchJobID) {
return;
}
if (event.type !== GQL.JobStatusUpdateType.Remove) {
setBatchJob(event.job);
} else {
setBatchJob(undefined);
setBatchJobID(undefined);
// Once the studio batch is complete, refresh all local studio data
const ac = getClient();
evictQueries(ac.cache, studioMutationImpactedQueries);
}
}, [jobsSubscribe, batchJobID]);
if (!config) return <LoadingIndicator />;
const savedEndpointIndex =
stashConfig?.general.stashBoxes.findIndex(
(s) => s.endpoint === config.selectedEndpoint
) ?? -1;
const selectedEndpointIndex =
savedEndpointIndex === -1 && stashConfig?.general.stashBoxes.length
? 0
: savedEndpointIndex;
const selectedEndpoint =
stashConfig?.general.stashBoxes[selectedEndpointIndex];
async function batchAdd(studioInput: string, createParent: boolean) {
if (studioInput && selectedEndpoint) {
const names = studioInput
.split(",")
.map((n) => n.trim())
.filter((n) => n.length > 0);
if (names.length > 0) {
const ret = await mutateStashBoxBatchStudioTag({
names: names,
endpoint: selectedEndpointIndex,
refresh: false,
exclude_fields: config?.excludedStudioFields ?? [],
createParent: createParent,
});
setBatchJobID(ret.data?.stashBoxBatchStudioTag);
}
}
}
async function batchUpdate(
ids: string[] | undefined,
refresh: boolean,
createParent: boolean
) {
if (selectedEndpoint) {
const ret = await mutateStashBoxBatchStudioTag({
ids: ids,
endpoint: selectedEndpointIndex,
refresh,
exclude_fields: config?.excludedStudioFields ?? [],
createParent: createParent,
});
setBatchJobID(ret.data?.stashBoxBatchStudioTag);
}
}
// const progress =
// jobStatus.data?.metadataUpdate.status ===
// "Stash-Box Studio Batch Operation" &&
// jobStatus.data.metadataUpdate.progress >= 0
// ? jobStatus.data.metadataUpdate.progress * 100
// : null;
function renderStatus() {
if (batchJob) {
const progress =
batchJob.progress !== undefined && batchJob.progress !== null
? batchJob.progress * 100
: undefined;
return (
<Form.Group className="px-4">
<h5>
<FormattedMessage id="studio_tagger.status_tagging_studios" />
</h5>
{progress !== undefined && (
<ProgressBar
animated
now={progress}
label={`${progress.toFixed(0)}%`}
/>
)}
</Form.Group>
);
}
if (batchJobID !== undefined) {
return (
<Form.Group className="px-4">
<h5>
<FormattedMessage id="studio_tagger.status_tagging_job_queued" />
</h5>
</Form.Group>
);
}
}
const showHideConfigId = showConfig
? "actions.hide_configuration"
: "actions.show_configuration";
return (
<>
<Manual
show={showManual}
onClose={() => setShowManual(false)}
defaultActiveTab="Tagger.md"
/>
{renderStatus()}
<div className="tagger-container mx-md-auto">
{selectedEndpointIndex !== -1 && selectedEndpoint ? (
<>
<div className="row mb-2 no-gutters">
<Button onClick={() => setShowConfig(!showConfig)} variant="link">
{intl.formatMessage({ id: showHideConfigId })}
</Button>
<Button
className="ml-auto"
onClick={() => setShowManual(true)}
title={intl.formatMessage({ id: "help" })}
variant="link"
>
<FormattedMessage id="help" />
</Button>
</div>
<StudioConfig
config={config}
setConfig={setConfig}
show={showConfig}
/>
<StudioTaggerList
studios={studios}
selectedEndpoint={{
endpoint: selectedEndpoint.endpoint,
index: selectedEndpointIndex,
}}
isIdle={batchJobID === undefined}
config={config}
stashBoxes={stashConfig?.general.stashBoxes}
onBatchAdd={batchAdd}
onBatchUpdate={batchUpdate}
/>
</>
) : (
<div className="my-4">
<h3 className="text-center mt-4">
<FormattedMessage id="studio_tagger.to_use_the_studio_tagger" />
</h3>
<h5 className="text-center">
Please see{" "}
<HashLink
to="/settings?tab=metadata-providers#stash-boxes"
scroll={(el) =>
el.scrollIntoView({ behavior: "smooth", block: "center" })
}
>
Settings.
</HashLink>
</h5>
</div>
)}
</div>
</>
);
};

View file

@ -227,6 +227,128 @@
} }
} }
.studio-create-modal {
font-size: 1.2rem;
max-width: 800px;
.image-selection {
text-align: center;
.studio-image {
height: 85%;
position: relative;
&-exclude {
position: absolute;
right: 20px;
top: 10px;
}
}
img {
max-height: 100%;
max-width: 100%;
}
}
.LoadingIndicator {
height: 100%;
}
&-field {
margin-bottom: 5px;
.btn {
margin-right: 5px;
}
.fa-icon {
width: 12px;
}
}
}
.StudioTagger {
display: flex;
flex-wrap: wrap;
justify-content: center;
max-width: 1600px;
&-header {
color: white;
&:hover {
color: white;
}
}
&-studio {
background-color: #495b68;
border-radius: 3px;
display: flex;
margin: 1rem;
max-width: 100%;
padding: 1rem;
.studio-card {
box-shadow: none;
flex-shrink: 0;
margin: 0;
padding: 0;
img {
background-color: #495b68;
max-height: 150px;
object-fit: contain;
vertical-align: middle;
width: 100%;
}
}
}
&-details {
//flex-grow: 1;
display: flex;
flex-direction: column;
justify-content: space-between;
margin: 0.5rem;
width: 24rem;
}
&-details-image {
vertical-align: bottom;
}
&-details-text {
vertical-align: bottom;
}
&-studio-search {
display: flex;
flex-wrap: wrap;
&-item {
align-items: center;
display: flex;
overflow: hidden;
text-align: left;
}
}
&-thumb {
height: 40px;
margin-right: 10px;
}
&-box-link {
margin-bottom: 5px;
.input-group-text {
font-family: monospace;
}
}
}
.FieldSelect { .FieldSelect {
.fa-icon { .fa-icon {
width: 12px; width: 12px;

View file

@ -19,7 +19,10 @@ export const getClient = () => client;
// Evicts cached results for the given queries. // Evicts cached results for the given queries.
// Will also call a cache GC afterwards. // Will also call a cache GC afterwards.
function evictQueries(cache: ApolloCache<unknown>, queries: DocumentNode[]) { export function evictQueries(
cache: ApolloCache<unknown>,
queries: DocumentNode[]
) {
const fields: Modifiers = {}; const fields: Modifiers = {};
for (const query of queries) { for (const query of queries) {
const { selections } = getQueryDefinition(query).selectionSet; const { selections } = getQueryDefinition(query).selectionSet;
@ -111,7 +114,7 @@ function deleteObject(
/// Object queries /// Object queries
export const useFindScene = (id: string) => { export const useFindScene = (id: string) => {
const skip = id === "new"; const skip = id === "new" || id === "";
return GQL.useFindSceneQuery({ variables: { id }, skip }); return GQL.useFindSceneQuery({ variables: { id }, skip });
}; };
@ -172,7 +175,7 @@ export const queryFindImages = (filter: ListFilterModel) =>
}); });
export const useFindMovie = (id: string) => { export const useFindMovie = (id: string) => {
const skip = id === "new"; const skip = id === "new" || id === "";
return GQL.useFindMovieQuery({ variables: { id }, skip }); return GQL.useFindMovieQuery({ variables: { id }, skip });
}; };
@ -217,7 +220,7 @@ export const queryFindSceneMarkers = (filter: ListFilterModel) =>
export const useMarkerStrings = () => GQL.useMarkerStringsQuery(); export const useMarkerStrings = () => GQL.useMarkerStringsQuery();
export const useFindGallery = (id: string) => { export const useFindGallery = (id: string) => {
const skip = id === "new"; const skip = id === "new" || id === "";
return GQL.useFindGalleryQuery({ variables: { id }, skip }); return GQL.useFindGalleryQuery({ variables: { id }, skip });
}; };
@ -240,7 +243,7 @@ export const queryFindGalleries = (filter: ListFilterModel) =>
}); });
export const useFindPerformer = (id: string) => { export const useFindPerformer = (id: string) => {
const skip = id === "new"; const skip = id === "new" || id === "";
return GQL.useFindPerformerQuery({ variables: { id }, skip }); return GQL.useFindPerformerQuery({ variables: { id }, skip });
}; };
@ -272,7 +275,7 @@ export const useAllPerformersForFilter = () =>
GQL.useAllPerformersForFilterQuery(); GQL.useAllPerformersForFilterQuery();
export const useFindStudio = (id: string) => { export const useFindStudio = (id: string) => {
const skip = id === "new"; const skip = id === "new" || id === "";
return GQL.useFindStudioQuery({ variables: { id }, skip }); return GQL.useFindStudioQuery({ variables: { id }, skip });
}; };
@ -303,7 +306,7 @@ export const queryFindStudios = (filter: ListFilterModel) =>
export const useAllStudiosForFilter = () => GQL.useAllStudiosForFilterQuery(); export const useAllStudiosForFilter = () => GQL.useAllStudiosForFilterQuery();
export const useFindTag = (id: string) => { export const useFindTag = (id: string) => {
const skip = id === "new"; const skip = id === "new" || id === "";
return GQL.useFindTagQuery({ variables: { id }, skip }); return GQL.useFindTagQuery({ variables: { id }, skip });
}; };
@ -1475,7 +1478,7 @@ const studioMutationImpactedTypeFields = {
Studio: ["child_studios"], Studio: ["child_studios"],
}; };
const studioMutationImpactedQueries = [ export const studioMutationImpactedQueries = [
GQL.FindScenesDocument, // filter by studio GQL.FindScenesDocument, // filter by studio
GQL.FindImagesDocument, // filter by studio GQL.FindImagesDocument, // filter by studio
GQL.FindMoviesDocument, // filter by studio GQL.FindMoviesDocument, // filter by studio
@ -1868,16 +1871,42 @@ export const stashBoxPerformerQuery = (
query: searchVal, query: searchVal,
}, },
}, },
fetchPolicy: "network-only",
});
export const stashBoxStudioQuery = (
query: string | null,
stashBoxIndex: number
) =>
client.query<GQL.ScrapeSingleStudioQuery>({
query: GQL.ScrapeSingleStudioDocument,
variables: {
source: {
stash_box_index: stashBoxIndex,
},
input: {
query: query,
},
},
fetchPolicy: "network-only",
}); });
export const mutateStashBoxBatchPerformerTag = ( export const mutateStashBoxBatchPerformerTag = (
input: GQL.StashBoxBatchPerformerTagInput input: GQL.StashBoxBatchTagInput
) => ) =>
client.mutate<GQL.StashBoxBatchPerformerTagMutation>({ client.mutate<GQL.StashBoxBatchPerformerTagMutation>({
mutation: GQL.StashBoxBatchPerformerTagDocument, mutation: GQL.StashBoxBatchPerformerTagDocument,
variables: { input }, variables: { input },
}); });
export const mutateStashBoxBatchStudioTag = (
input: GQL.StashBoxBatchTagInput
) =>
client.mutate<GQL.StashBoxBatchStudioTagMutation>({
mutation: GQL.StashBoxBatchStudioTagDocument,
variables: { input },
});
export const useListMovieScrapers = () => GQL.useListMovieScrapersQuery(); export const useListMovieScrapers = () => GQL.useListMovieScrapersQuery();
export const queryScrapeMovieURL = (url: string) => export const queryScrapeMovieURL = (url: string) =>

View file

@ -8,6 +8,7 @@
"allow_temporarily": "Allow temporarily", "allow_temporarily": "Allow temporarily",
"anonymise": "Anonymise", "anonymise": "Anonymise",
"apply": "Apply", "apply": "Apply",
"assign_stashid_to_parent_studio": "Assign Stash ID to existing parent studio and update metadata",
"auto_tag": "Auto Tag", "auto_tag": "Auto Tag",
"backup": "Backup", "backup": "Backup",
"browse_for_image": "Browse for image…", "browse_for_image": "Browse for image…",
@ -24,6 +25,7 @@
"create_chapters": "Create Chapter", "create_chapters": "Create Chapter",
"create_entity": "Create {entityType}", "create_entity": "Create {entityType}",
"create_marker": "Create Marker", "create_marker": "Create Marker",
"create_parent_studio": "Create parent studio",
"created_entity": "Created {entity_type}: {entity_name}", "created_entity": "Created {entity_type}: {entity_name}",
"customise": "Customise", "customise": "Customise",
"delete": "Delete", "delete": "Delete",
@ -1040,6 +1042,7 @@
"previous": "Previous" "previous": "Previous"
}, },
"parent_of": "Parent of {children}", "parent_of": "Parent of {children}",
"parent_studio": "Parent Studio",
"parent_studios": "Parent Studios", "parent_studios": "Parent Studios",
"parent_tag_count": "Parent Tag Count", "parent_tag_count": "Parent Tag Count",
"parent_tags": "Parent Tags", "parent_tags": "Parent Tags",
@ -1221,6 +1224,7 @@
"stashbox": { "stashbox": {
"go_review_draft": "Go to {endpoint_name} to review draft.", "go_review_draft": "Go to {endpoint_name} to review draft.",
"selected_stash_box": "Selected Stash-Box endpoint", "selected_stash_box": "Selected Stash-Box endpoint",
"source": "Stash-Box Source",
"submission_failed": "Submission failed", "submission_failed": "Submission failed",
"submission_successful": "Submission successful", "submission_successful": "Submission successful",
"submit_update": "Already exists in {endpoint_name}" "submit_update": "Already exists in {endpoint_name}"
@ -1237,7 +1241,46 @@
}, },
"status": "Status: {statusText}", "status": "Status: {statusText}",
"studio": "Studio", "studio": "Studio",
"studio_and_parent": "Studio & Parent",
"studio_depth": "Levels (empty for all)", "studio_depth": "Levels (empty for all)",
"studio_tagger": {
"add_new_studios": "Add New Studios",
"any_names_entered_will_be_queried": "Any names entered will be queried from the remote Stash-Box instance and added if found. Only exact matches will be considered a match.",
"batch_add_studios": "Batch Add Studios",
"batch_update_studios": "Batch Update Studios",
"config": {
"active_stash-box_instance": "Active stash-box instance:",
"create_parent_desc": "Create missing parent studios, or tag and update data/image for existing parent studios with exact name matches",
"create_parent_label": "Create parent studios",
"edit_excluded_fields": "Edit Excluded Fields",
"excluded_fields": "Excluded fields:",
"no_fields_are_excluded": "No fields are excluded",
"no_instances_found": "No instances found",
"these_fields_will_not_be_changed_when_updating_studios": "These fields will not be changed when updating studios."
},
"create_or_tag_parent_studios": "Create missing or tag existing parent studios",
"current_page": "Current page",
"failed_to_save_studio": "Failed to save studio \"{studio}\"",
"name_already_exists": "Name already exists",
"network_error": "Network Error",
"no_results_found": "No results found.",
"number_of_studios_will_be_processed": "{studio_count} studios will be processed",
"studio_already_tagged": "Studio already tagged",
"studio_names_separated_by_comma": "Studio names separated by comma",
"studio_selection": "Studio selection",
"studio_successfully_tagged": "Studio successfully tagged",
"query_all_studios_in_the_database": "All studios in the database",
"refresh_tagged_studios": "Refresh tagged studios",
"refreshing_will_update_the_data": "Refreshing will update the data of any tagged studios from the stash-box instance.",
"status_tagging_job_queued": "Status: Tagging job queued",
"status_tagging_studios": "Status: Tagging studios",
"tag_status": "Tag Status",
"to_use_the_studio_tagger": "To use the studio tagger a stash-box instance needs to be configured.",
"untagged_studios": "Untagged studios",
"update_studio": "Update Studio",
"update_studios": "Update Studios",
"updating_untagged_studios_description": "Updating untagged studios will try to match any studios that lack a stashid and update the metadata."
},
"studios": "Studios", "studios": "Studios",
"sub_tag_count": "Sub-Tag Count", "sub_tag_count": "Sub-Tag Count",
"sub_tag_of": "Sub-tag of {parent}", "sub_tag_of": "Sub-tag of {parent}",
@ -1288,4 +1331,4 @@
"weight_kg": "Weight (kg)", "weight_kg": "Weight (kg)",
"years_old": "years old", "years_old": "years old",
"zip_file_count": "Zip File Count" "zip_file_count": "Zip File Count"
} }

View file

@ -30,7 +30,7 @@ const sortByOptions = ["name", "random", "rating"]
}, },
]); ]);
const displayModeOptions = [DisplayMode.Grid]; const displayModeOptions = [DisplayMode.Grid, DisplayMode.Tagger];
const criterionOptions = [ const criterionOptions = [
createMandatoryStringCriterionOption("name"), createMandatoryStringCriterionOption("name"),
createStringCriterionOption("details"), createStringCriterionOption("details"),

View file

@ -30,3 +30,15 @@ export function withoutTypename<T extends ITypename>(
{} as Omit<T, "__typename"> {} as Omit<T, "__typename">
); );
} }
// excludeFields removes fields from data that are in the excluded object
export function excludeFields(
data: { [index: string]: unknown },
excluded: Record<string, boolean>
) {
Object.keys(data).forEach((k) => {
if (excluded[k] || !data[k]) {
data[k] = undefined;
}
});
}

15
vendor/github.com/gofrs/uuid/.gitignore generated vendored Normal file
View file

@ -0,0 +1,15 @@
# Binaries for programs and plugins
*.exe
*.exe~
*.dll
*.so
*.dylib
# Test binary, build with `go test -c`
*.test
# Output of the go coverage tool, specifically when used with LiteIDE
*.out
# binary bundle generated by go-fuzz
uuid-fuzz.zip

20
vendor/github.com/gofrs/uuid/LICENSE generated vendored Normal file
View file

@ -0,0 +1,20 @@
Copyright (C) 2013-2018 by Maxim Bublis <b@codemonkey.ru>
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

117
vendor/github.com/gofrs/uuid/README.md generated vendored Normal file
View file

@ -0,0 +1,117 @@
# UUID
[![License](https://img.shields.io/github/license/gofrs/uuid.svg)](https://github.com/gofrs/uuid/blob/master/LICENSE)
[![Build Status](https://travis-ci.org/gofrs/uuid.svg?branch=master)](https://travis-ci.org/gofrs/uuid)
[![GoDoc](http://godoc.org/github.com/gofrs/uuid?status.svg)](http://godoc.org/github.com/gofrs/uuid)
[![Coverage Status](https://codecov.io/gh/gofrs/uuid/branch/master/graphs/badge.svg?branch=master)](https://codecov.io/gh/gofrs/uuid/)
[![Go Report Card](https://goreportcard.com/badge/github.com/gofrs/uuid)](https://goreportcard.com/report/github.com/gofrs/uuid)
Package uuid provides a pure Go implementation of Universally Unique Identifiers
(UUID) variant as defined in RFC-4122. This package supports both the creation
and parsing of UUIDs in different formats.
This package supports the following UUID versions:
* Version 1, based on timestamp and MAC address (RFC-4122)
* Version 3, based on MD5 hashing of a named value (RFC-4122)
* Version 4, based on random numbers (RFC-4122)
* Version 5, based on SHA-1 hashing of a named value (RFC-4122)
This package also supports experimental Universally Unique Identifier implementations based on a
[draft RFC](https://www.ietf.org/archive/id/draft-peabody-dispatch-new-uuid-format-04.html) that updates RFC-4122
* Version 6, a k-sortable id based on timestamp, and field-compatible with v1 (draft-peabody-dispatch-new-uuid-format, RFC-4122)
* Version 7, a k-sortable id based on timestamp (draft-peabody-dispatch-new-uuid-format, RFC-4122)
The v6 and v7 IDs are **not** considered a part of the stable API, and may be subject to behavior or API changes as part of minor releases
to this package. They will be updated as the draft RFC changes, and will become stable if and when the draft RFC is accepted.
## Project History
This project was originally forked from the
[github.com/satori/go.uuid](https://github.com/satori/go.uuid) repository after
it appeared to be no longer maintained, while exhibiting [critical
flaws](https://github.com/satori/go.uuid/issues/73). We have decided to take
over this project to ensure it receives regular maintenance for the benefit of
the larger Go community.
We'd like to thank Maxim Bublis for his hard work on the original iteration of
the package.
## License
This source code of this package is released under the MIT License. Please see
the [LICENSE](https://github.com/gofrs/uuid/blob/master/LICENSE) for the full
content of the license.
## Recommended Package Version
We recommend using v2.0.0+ of this package, as versions prior to 2.0.0 were
created before our fork of the original package and have some known
deficiencies.
## Installation
It is recommended to use a package manager like `dep` that understands tagged
releases of a package, as well as semantic versioning.
If you are unable to make use of a dependency manager with your project, you can
use the `go get` command to download it directly:
```Shell
$ go get github.com/gofrs/uuid
```
## Requirements
Due to subtests not being supported in older versions of Go, this package is
only regularly tested against Go 1.7+. This package may work perfectly fine with
Go 1.2+, but support for these older versions is not actively maintained.
## Go 1.11 Modules
As of v3.2.0, this repository no longer adopts Go modules, and v3.2.0 no longer has a `go.mod` file. As a result, v3.2.0 also drops support for the `github.com/gofrs/uuid/v3` import path. Only module-based consumers are impacted. With the v3.2.0 release, _all_ gofrs/uuid consumers should use the `github.com/gofrs/uuid` import path.
An existing module-based consumer will continue to be able to build using the `github.com/gofrs/uuid/v3` import path using any valid consumer `go.mod` that worked prior to the publishing of v3.2.0, but any module-based consumer should start using the `github.com/gofrs/uuid` import path when possible and _must_ use the `github.com/gofrs/uuid` import path prior to upgrading to v3.2.0.
Please refer to [Issue #61](https://github.com/gofrs/uuid/issues/61) and [Issue #66](https://github.com/gofrs/uuid/issues/66) for more details.
## Usage
Here is a quick overview of how to use this package. For more detailed
documentation, please see the [GoDoc Page](http://godoc.org/github.com/gofrs/uuid).
```go
package main
import (
"log"
"github.com/gofrs/uuid"
)
// Create a Version 4 UUID, panicking on error.
// Use this form to initialize package-level variables.
var u1 = uuid.Must(uuid.NewV4())
func main() {
// Create a Version 4 UUID.
u2, err := uuid.NewV4()
if err != nil {
log.Fatalf("failed to generate UUID: %v", err)
}
log.Printf("generated Version 4 UUID %v", u2)
// Parse a UUID from a string.
s := "6ba7b810-9dad-11d1-80b4-00c04fd430c8"
u3, err := uuid.FromString(s)
if err != nil {
log.Fatalf("failed to parse UUID %q: %v", s, err)
}
log.Printf("successfully parsed UUID %v", u3)
}
```
## References
* [RFC-4122](https://tools.ietf.org/html/rfc4122)
* [DCE 1.1: Authentication and Security Services](http://pubs.opengroup.org/onlinepubs/9696989899/chap5.htm#tagcjh_08_02_01_01)
* [New UUID Formats RFC Draft (Peabody) Rev 04](https://www.ietf.org/archive/id/draft-peabody-dispatch-new-uuid-format-04.html#)

234
vendor/github.com/gofrs/uuid/codec.go generated vendored Normal file
View file

@ -0,0 +1,234 @@
// Copyright (C) 2013-2018 by Maxim Bublis <b@codemonkey.ru>
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the
// "Software"), to deal in the Software without restriction, including
// without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to
// permit persons to whom the Software is furnished to do so, subject to
// the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package uuid
import (
"errors"
"fmt"
)
// FromBytes returns a UUID generated from the raw byte slice input.
// It will return an error if the slice isn't 16 bytes long.
func FromBytes(input []byte) (UUID, error) {
u := UUID{}
err := u.UnmarshalBinary(input)
return u, err
}
// FromBytesOrNil returns a UUID generated from the raw byte slice input.
// Same behavior as FromBytes(), but returns uuid.Nil instead of an error.
func FromBytesOrNil(input []byte) UUID {
uuid, err := FromBytes(input)
if err != nil {
return Nil
}
return uuid
}
var errInvalidFormat = errors.New("uuid: invalid UUID format")
func fromHexChar(c byte) byte {
switch {
case '0' <= c && c <= '9':
return c - '0'
case 'a' <= c && c <= 'f':
return c - 'a' + 10
case 'A' <= c && c <= 'F':
return c - 'A' + 10
}
return 255
}
// Parse parses the UUID stored in the string text. Parsing and supported
// formats are the same as UnmarshalText.
func (u *UUID) Parse(s string) error {
switch len(s) {
case 32: // hash
case 36: // canonical
case 34, 38:
if s[0] != '{' || s[len(s)-1] != '}' {
return fmt.Errorf("uuid: incorrect UUID format in string %q", s)
}
s = s[1 : len(s)-1]
case 41, 45:
if s[:9] != "urn:uuid:" {
return fmt.Errorf("uuid: incorrect UUID format in string %q", s[:9])
}
s = s[9:]
default:
return fmt.Errorf("uuid: incorrect UUID length %d in string %q", len(s), s)
}
// canonical
if len(s) == 36 {
if s[8] != '-' || s[13] != '-' || s[18] != '-' || s[23] != '-' {
return fmt.Errorf("uuid: incorrect UUID format in string %q", s)
}
for i, x := range [16]byte{
0, 2, 4, 6,
9, 11,
14, 16,
19, 21,
24, 26, 28, 30, 32, 34,
} {
v1 := fromHexChar(s[x])
v2 := fromHexChar(s[x+1])
if v1|v2 == 255 {
return errInvalidFormat
}
u[i] = (v1 << 4) | v2
}
return nil
}
// hash like
for i := 0; i < 32; i += 2 {
v1 := fromHexChar(s[i])
v2 := fromHexChar(s[i+1])
if v1|v2 == 255 {
return errInvalidFormat
}
u[i/2] = (v1 << 4) | v2
}
return nil
}
// FromString returns a UUID parsed from the input string.
// Input is expected in a form accepted by UnmarshalText.
func FromString(text string) (UUID, error) {
var u UUID
err := u.Parse(text)
return u, err
}
// FromStringOrNil returns a UUID parsed from the input string.
// Same behavior as FromString(), but returns uuid.Nil instead of an error.
func FromStringOrNil(input string) UUID {
uuid, err := FromString(input)
if err != nil {
return Nil
}
return uuid
}
// MarshalText implements the encoding.TextMarshaler interface.
// The encoding is the same as returned by the String() method.
func (u UUID) MarshalText() ([]byte, error) {
var buf [36]byte
encodeCanonical(buf[:], u)
return buf[:], nil
}
// UnmarshalText implements the encoding.TextUnmarshaler interface.
// Following formats are supported:
//
// "6ba7b810-9dad-11d1-80b4-00c04fd430c8",
// "{6ba7b810-9dad-11d1-80b4-00c04fd430c8}",
// "urn:uuid:6ba7b810-9dad-11d1-80b4-00c04fd430c8"
// "6ba7b8109dad11d180b400c04fd430c8"
// "{6ba7b8109dad11d180b400c04fd430c8}",
// "urn:uuid:6ba7b8109dad11d180b400c04fd430c8"
//
// ABNF for supported UUID text representation follows:
//
// URN := 'urn'
// UUID-NID := 'uuid'
//
// hexdig := '0' | '1' | '2' | '3' | '4' | '5' | '6' | '7' | '8' | '9' |
// 'a' | 'b' | 'c' | 'd' | 'e' | 'f' |
// 'A' | 'B' | 'C' | 'D' | 'E' | 'F'
//
// hexoct := hexdig hexdig
// 2hexoct := hexoct hexoct
// 4hexoct := 2hexoct 2hexoct
// 6hexoct := 4hexoct 2hexoct
// 12hexoct := 6hexoct 6hexoct
//
// hashlike := 12hexoct
// canonical := 4hexoct '-' 2hexoct '-' 2hexoct '-' 6hexoct
//
// plain := canonical | hashlike
// uuid := canonical | hashlike | braced | urn
//
// braced := '{' plain '}' | '{' hashlike '}'
// urn := URN ':' UUID-NID ':' plain
func (u *UUID) UnmarshalText(b []byte) error {
switch len(b) {
case 32: // hash
case 36: // canonical
case 34, 38:
if b[0] != '{' || b[len(b)-1] != '}' {
return fmt.Errorf("uuid: incorrect UUID format in string %q", b)
}
b = b[1 : len(b)-1]
case 41, 45:
if string(b[:9]) != "urn:uuid:" {
return fmt.Errorf("uuid: incorrect UUID format in string %q", b[:9])
}
b = b[9:]
default:
return fmt.Errorf("uuid: incorrect UUID length %d in string %q", len(b), b)
}
if len(b) == 36 {
if b[8] != '-' || b[13] != '-' || b[18] != '-' || b[23] != '-' {
return fmt.Errorf("uuid: incorrect UUID format in string %q", b)
}
for i, x := range [16]byte{
0, 2, 4, 6,
9, 11,
14, 16,
19, 21,
24, 26, 28, 30, 32, 34,
} {
v1 := fromHexChar(b[x])
v2 := fromHexChar(b[x+1])
if v1|v2 == 255 {
return errInvalidFormat
}
u[i] = (v1 << 4) | v2
}
return nil
}
for i := 0; i < 32; i += 2 {
v1 := fromHexChar(b[i])
v2 := fromHexChar(b[i+1])
if v1|v2 == 255 {
return errInvalidFormat
}
u[i/2] = (v1 << 4) | v2
}
return nil
}
// MarshalBinary implements the encoding.BinaryMarshaler interface.
func (u UUID) MarshalBinary() ([]byte, error) {
return u.Bytes(), nil
}
// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface.
// It will return an error if the slice isn't 16 bytes long.
func (u *UUID) UnmarshalBinary(data []byte) error {
if len(data) != Size {
return fmt.Errorf("uuid: UUID must be exactly 16 bytes long, got %d bytes", len(data))
}
copy(u[:], data)
return nil
}

48
vendor/github.com/gofrs/uuid/fuzz.go generated vendored Normal file
View file

@ -0,0 +1,48 @@
// Copyright (c) 2018 Andrei Tudor Călin <mail@acln.ro>
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the
// "Software"), to deal in the Software without restriction, including
// without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to
// permit persons to whom the Software is furnished to do so, subject to
// the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
//go:build gofuzz
// +build gofuzz
package uuid
// Fuzz implements a simple fuzz test for FromString / UnmarshalText.
//
// To run:
//
// $ go get github.com/dvyukov/go-fuzz/...
// $ cd $GOPATH/src/github.com/gofrs/uuid
// $ go-fuzz-build github.com/gofrs/uuid
// $ go-fuzz -bin=uuid-fuzz.zip -workdir=./testdata
//
// If you make significant changes to FromString / UnmarshalText and add
// new cases to fromStringTests (in codec_test.go), please run
//
// $ go test -seed_fuzz_corpus
//
// to seed the corpus with the new interesting inputs, then run the fuzzer.
func Fuzz(data []byte) int {
_, err := FromString(string(data))
if err != nil {
return 0
}
return 1
}

456
vendor/github.com/gofrs/uuid/generator.go generated vendored Normal file
View file

@ -0,0 +1,456 @@
// Copyright (C) 2013-2018 by Maxim Bublis <b@codemonkey.ru>
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the
// "Software"), to deal in the Software without restriction, including
// without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to
// permit persons to whom the Software is furnished to do so, subject to
// the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package uuid
import (
"crypto/md5"
"crypto/rand"
"crypto/sha1"
"encoding/binary"
"fmt"
"hash"
"io"
"net"
"sync"
"time"
)
// Difference in 100-nanosecond intervals between
// UUID epoch (October 15, 1582) and Unix epoch (January 1, 1970).
const epochStart = 122192928000000000
// EpochFunc is the function type used to provide the current time.
type EpochFunc func() time.Time
// HWAddrFunc is the function type used to provide hardware (MAC) addresses.
type HWAddrFunc func() (net.HardwareAddr, error)
// DefaultGenerator is the default UUID Generator used by this package.
var DefaultGenerator Generator = NewGen()
// NewV1 returns a UUID based on the current timestamp and MAC address.
func NewV1() (UUID, error) {
return DefaultGenerator.NewV1()
}
// NewV3 returns a UUID based on the MD5 hash of the namespace UUID and name.
func NewV3(ns UUID, name string) UUID {
return DefaultGenerator.NewV3(ns, name)
}
// NewV4 returns a randomly generated UUID.
func NewV4() (UUID, error) {
return DefaultGenerator.NewV4()
}
// NewV5 returns a UUID based on SHA-1 hash of the namespace UUID and name.
func NewV5(ns UUID, name string) UUID {
return DefaultGenerator.NewV5(ns, name)
}
// NewV6 returns a k-sortable UUID based on a timestamp and 48 bits of
// pseudorandom data. The timestamp in a V6 UUID is the same as V1, with the bit
// order being adjusted to allow the UUID to be k-sortable.
//
// This is implemented based on revision 03 of the Peabody UUID draft, and may
// be subject to change pending further revisions. Until the final specification
// revision is finished, changes required to implement updates to the spec will
// not be considered a breaking change. They will happen as a minor version
// releases until the spec is final.
func NewV6() (UUID, error) {
return DefaultGenerator.NewV6()
}
// NewV7 returns a k-sortable UUID based on the current millisecond precision
// UNIX epoch and 74 bits of pseudorandom data. It supports single-node batch generation (multiple UUIDs in the same timestamp) with a Monotonic Random counter.
//
// This is implemented based on revision 04 of the Peabody UUID draft, and may
// be subject to change pending further revisions. Until the final specification
// revision is finished, changes required to implement updates to the spec will
// not be considered a breaking change. They will happen as a minor version
// releases until the spec is final.
func NewV7() (UUID, error) {
return DefaultGenerator.NewV7()
}
// Generator provides an interface for generating UUIDs.
type Generator interface {
NewV1() (UUID, error)
NewV3(ns UUID, name string) UUID
NewV4() (UUID, error)
NewV5(ns UUID, name string) UUID
NewV6() (UUID, error)
NewV7() (UUID, error)
}
// Gen is a reference UUID generator based on the specifications laid out in
// RFC-4122 and DCE 1.1: Authentication and Security Services. This type
// satisfies the Generator interface as defined in this package.
//
// For consumers who are generating V1 UUIDs, but don't want to expose the MAC
// address of the node generating the UUIDs, the NewGenWithHWAF() function has been
// provided as a convenience. See the function's documentation for more info.
//
// The authors of this package do not feel that the majority of users will need
// to obfuscate their MAC address, and so we recommend using NewGen() to create
// a new generator.
type Gen struct {
clockSequenceOnce sync.Once
hardwareAddrOnce sync.Once
storageMutex sync.Mutex
rand io.Reader
epochFunc EpochFunc
hwAddrFunc HWAddrFunc
lastTime uint64
clockSequence uint16
hardwareAddr [6]byte
}
// GenOption is a function type that can be used to configure a Gen generator.
type GenOption func(*Gen)
// interface check -- build will fail if *Gen doesn't satisfy Generator
var _ Generator = (*Gen)(nil)
// NewGen returns a new instance of Gen with some default values set. Most
// people should use this.
func NewGen() *Gen {
return NewGenWithHWAF(defaultHWAddrFunc)
}
// NewGenWithHWAF builds a new UUID generator with the HWAddrFunc provided. Most
// consumers should use NewGen() instead.
//
// This is used so that consumers can generate their own MAC addresses, for use
// in the generated UUIDs, if there is some concern about exposing the physical
// address of the machine generating the UUID.
//
// The Gen generator will only invoke the HWAddrFunc once, and cache that MAC
// address for all the future UUIDs generated by it. If you'd like to switch the
// MAC address being used, you'll need to create a new generator using this
// function.
func NewGenWithHWAF(hwaf HWAddrFunc) *Gen {
return NewGenWithOptions(WithHWAddrFunc(hwaf))
}
// NewGenWithOptions returns a new instance of Gen with the options provided.
// Most people should use NewGen() or NewGenWithHWAF() instead.
//
// To customize the generator, you can pass in one or more GenOption functions.
// For example:
//
// gen := NewGenWithOptions(
// WithHWAddrFunc(myHWAddrFunc),
// WithEpochFunc(myEpochFunc),
// WithRandomReader(myRandomReader),
// )
//
// NewGenWithOptions(WithHWAddrFunc(myHWAddrFunc)) is equivalent to calling
// NewGenWithHWAF(myHWAddrFunc)
// NewGenWithOptions() is equivalent to calling NewGen()
func NewGenWithOptions(opts ...GenOption) *Gen {
gen := &Gen{
epochFunc: time.Now,
hwAddrFunc: defaultHWAddrFunc,
rand: rand.Reader,
}
for _, opt := range opts {
opt(gen)
}
return gen
}
// WithHWAddrFunc is a GenOption that allows you to provide your own HWAddrFunc
// function.
// When this option is nil, the defaultHWAddrFunc is used.
func WithHWAddrFunc(hwaf HWAddrFunc) GenOption {
return func(gen *Gen) {
if hwaf == nil {
hwaf = defaultHWAddrFunc
}
gen.hwAddrFunc = hwaf
}
}
// WithEpochFunc is a GenOption that allows you to provide your own EpochFunc
// function.
// When this option is nil, time.Now is used.
func WithEpochFunc(epochf EpochFunc) GenOption {
return func(gen *Gen) {
if epochf == nil {
epochf = time.Now
}
gen.epochFunc = epochf
}
}
// WithRandomReader is a GenOption that allows you to provide your own random
// reader.
// When this option is nil, the default rand.Reader is used.
func WithRandomReader(reader io.Reader) GenOption {
return func(gen *Gen) {
if reader == nil {
reader = rand.Reader
}
gen.rand = reader
}
}
// NewV1 returns a UUID based on the current timestamp and MAC address.
func (g *Gen) NewV1() (UUID, error) {
u := UUID{}
timeNow, clockSeq, err := g.getClockSequence(false)
if err != nil {
return Nil, err
}
binary.BigEndian.PutUint32(u[0:], uint32(timeNow))
binary.BigEndian.PutUint16(u[4:], uint16(timeNow>>32))
binary.BigEndian.PutUint16(u[6:], uint16(timeNow>>48))
binary.BigEndian.PutUint16(u[8:], clockSeq)
hardwareAddr, err := g.getHardwareAddr()
if err != nil {
return Nil, err
}
copy(u[10:], hardwareAddr)
u.SetVersion(V1)
u.SetVariant(VariantRFC4122)
return u, nil
}
// NewV3 returns a UUID based on the MD5 hash of the namespace UUID and name.
func (g *Gen) NewV3(ns UUID, name string) UUID {
u := newFromHash(md5.New(), ns, name)
u.SetVersion(V3)
u.SetVariant(VariantRFC4122)
return u
}
// NewV4 returns a randomly generated UUID.
func (g *Gen) NewV4() (UUID, error) {
u := UUID{}
if _, err := io.ReadFull(g.rand, u[:]); err != nil {
return Nil, err
}
u.SetVersion(V4)
u.SetVariant(VariantRFC4122)
return u, nil
}
// NewV5 returns a UUID based on SHA-1 hash of the namespace UUID and name.
func (g *Gen) NewV5(ns UUID, name string) UUID {
u := newFromHash(sha1.New(), ns, name)
u.SetVersion(V5)
u.SetVariant(VariantRFC4122)
return u
}
// NewV6 returns a k-sortable UUID based on a timestamp and 48 bits of
// pseudorandom data. The timestamp in a V6 UUID is the same as V1, with the bit
// order being adjusted to allow the UUID to be k-sortable.
//
// This is implemented based on revision 03 of the Peabody UUID draft, and may
// be subject to change pending further revisions. Until the final specification
// revision is finished, changes required to implement updates to the spec will
// not be considered a breaking change. They will happen as a minor version
// releases until the spec is final.
func (g *Gen) NewV6() (UUID, error) {
var u UUID
if _, err := io.ReadFull(g.rand, u[10:]); err != nil {
return Nil, err
}
timeNow, clockSeq, err := g.getClockSequence(false)
if err != nil {
return Nil, err
}
binary.BigEndian.PutUint32(u[0:], uint32(timeNow>>28)) // set time_high
binary.BigEndian.PutUint16(u[4:], uint16(timeNow>>12)) // set time_mid
binary.BigEndian.PutUint16(u[6:], uint16(timeNow&0xfff)) // set time_low (minus four version bits)
binary.BigEndian.PutUint16(u[8:], clockSeq&0x3fff) // set clk_seq_hi_res (minus two variant bits)
u.SetVersion(V6)
u.SetVariant(VariantRFC4122)
return u, nil
}
// getClockSequence returns the epoch and clock sequence for V1,V6 and V7 UUIDs.
//
// When useUnixTSMs is false, it uses the Coordinated Universal Time (UTC) as a count of 100-
//
// nanosecond intervals since 00:00:00.00, 15 October 1582 (the date of Gregorian reform to the Christian calendar).
func (g *Gen) getClockSequence(useUnixTSMs bool) (uint64, uint16, error) {
var err error
g.clockSequenceOnce.Do(func() {
buf := make([]byte, 2)
if _, err = io.ReadFull(g.rand, buf); err != nil {
return
}
g.clockSequence = binary.BigEndian.Uint16(buf)
})
if err != nil {
return 0, 0, err
}
g.storageMutex.Lock()
defer g.storageMutex.Unlock()
var timeNow uint64
if useUnixTSMs {
timeNow = uint64(g.epochFunc().UnixMilli())
} else {
timeNow = g.getEpoch()
}
// Clock didn't change since last UUID generation.
// Should increase clock sequence.
if timeNow <= g.lastTime {
g.clockSequence++
}
g.lastTime = timeNow
return timeNow, g.clockSequence, nil
}
// NewV7 returns a k-sortable UUID based on the current millisecond precision
// UNIX epoch and 74 bits of pseudorandom data.
//
// This is implemented based on revision 04 of the Peabody UUID draft, and may
// be subject to change pending further revisions. Until the final specification
// revision is finished, changes required to implement updates to the spec will
// not be considered a breaking change. They will happen as a minor version
// releases until the spec is final.
func (g *Gen) NewV7() (UUID, error) {
var u UUID
/* https://www.ietf.org/archive/id/draft-peabody-dispatch-new-uuid-format-04.html#name-uuid-version-7
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| unix_ts_ms |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| unix_ts_ms | ver | rand_a |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|var| rand_b |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| rand_b |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ */
ms, clockSeq, err := g.getClockSequence(true)
if err != nil {
return Nil, err
}
//UUIDv7 features a 48 bit timestamp. First 32bit (4bytes) represents seconds since 1970, followed by 2 bytes for the ms granularity.
u[0] = byte(ms >> 40) //1-6 bytes: big-endian unsigned number of Unix epoch timestamp
u[1] = byte(ms >> 32)
u[2] = byte(ms >> 24)
u[3] = byte(ms >> 16)
u[4] = byte(ms >> 8)
u[5] = byte(ms)
//support batching by using a monotonic pseudo-random sequence
//The 6th byte contains the version and partially rand_a data.
//We will lose the most significant bites from the clockSeq (with SetVersion), but it is ok, we need the least significant that contains the counter to ensure the monotonic property
binary.BigEndian.PutUint16(u[6:8], clockSeq) // set rand_a with clock seq which is random and monotonic
//override first 4bits of u[6].
u.SetVersion(V7)
//set rand_b 64bits of pseudo-random bits (first 2 will be overridden)
if _, err = io.ReadFull(g.rand, u[8:16]); err != nil {
return Nil, err
}
//override first 2 bits of byte[8] for the variant
u.SetVariant(VariantRFC4122)
return u, nil
}
// Returns the hardware address.
func (g *Gen) getHardwareAddr() ([]byte, error) {
var err error
g.hardwareAddrOnce.Do(func() {
var hwAddr net.HardwareAddr
if hwAddr, err = g.hwAddrFunc(); err == nil {
copy(g.hardwareAddr[:], hwAddr)
return
}
// Initialize hardwareAddr randomly in case
// of real network interfaces absence.
if _, err = io.ReadFull(g.rand, g.hardwareAddr[:]); err != nil {
return
}
// Set multicast bit as recommended by RFC-4122
g.hardwareAddr[0] |= 0x01
})
if err != nil {
return []byte{}, err
}
return g.hardwareAddr[:], nil
}
// Returns the difference between UUID epoch (October 15, 1582)
// and current time in 100-nanosecond intervals.
func (g *Gen) getEpoch() uint64 {
return epochStart + uint64(g.epochFunc().UnixNano()/100)
}
// Returns the UUID based on the hashing of the namespace UUID and name.
func newFromHash(h hash.Hash, ns UUID, name string) UUID {
u := UUID{}
h.Write(ns[:])
h.Write([]byte(name))
copy(u[:], h.Sum(nil))
return u
}
var netInterfaces = net.Interfaces
// Returns the hardware address.
func defaultHWAddrFunc() (net.HardwareAddr, error) {
ifaces, err := netInterfaces()
if err != nil {
return []byte{}, err
}
for _, iface := range ifaces {
if len(iface.HardwareAddr) >= 6 {
return iface.HardwareAddr, nil
}
}
return []byte{}, fmt.Errorf("uuid: no HW address found")
}

116
vendor/github.com/gofrs/uuid/sql.go generated vendored Normal file
View file

@ -0,0 +1,116 @@
// Copyright (C) 2013-2018 by Maxim Bublis <b@codemonkey.ru>
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the
// "Software"), to deal in the Software without restriction, including
// without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to
// permit persons to whom the Software is furnished to do so, subject to
// the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package uuid
import (
"database/sql"
"database/sql/driver"
"fmt"
)
var _ driver.Valuer = UUID{}
var _ sql.Scanner = (*UUID)(nil)
// Value implements the driver.Valuer interface.
func (u UUID) Value() (driver.Value, error) {
return u.String(), nil
}
// Scan implements the sql.Scanner interface.
// A 16-byte slice will be handled by UnmarshalBinary, while
// a longer byte slice or a string will be handled by UnmarshalText.
func (u *UUID) Scan(src interface{}) error {
switch src := src.(type) {
case UUID: // support gorm convert from UUID to NullUUID
*u = src
return nil
case []byte:
if len(src) == Size {
return u.UnmarshalBinary(src)
}
return u.UnmarshalText(src)
case string:
uu, err := FromString(src)
*u = uu
return err
}
return fmt.Errorf("uuid: cannot convert %T to UUID", src)
}
// NullUUID can be used with the standard sql package to represent a
// UUID value that can be NULL in the database.
type NullUUID struct {
UUID UUID
Valid bool
}
// Value implements the driver.Valuer interface.
func (u NullUUID) Value() (driver.Value, error) {
if !u.Valid {
return nil, nil
}
// Delegate to UUID Value function
return u.UUID.Value()
}
// Scan implements the sql.Scanner interface.
func (u *NullUUID) Scan(src interface{}) error {
if src == nil {
u.UUID, u.Valid = Nil, false
return nil
}
// Delegate to UUID Scan function
u.Valid = true
return u.UUID.Scan(src)
}
var nullJSON = []byte("null")
// MarshalJSON marshals the NullUUID as null or the nested UUID
func (u NullUUID) MarshalJSON() ([]byte, error) {
if !u.Valid {
return nullJSON, nil
}
var buf [38]byte
buf[0] = '"'
encodeCanonical(buf[1:37], u.UUID)
buf[37] = '"'
return buf[:], nil
}
// UnmarshalJSON unmarshals a NullUUID
func (u *NullUUID) UnmarshalJSON(b []byte) error {
if string(b) == "null" {
u.UUID, u.Valid = Nil, false
return nil
}
if n := len(b); n >= 2 && b[0] == '"' {
b = b[1 : n-1]
}
err := u.UUID.UnmarshalText(b)
u.Valid = (err == nil)
return err
}

285
vendor/github.com/gofrs/uuid/uuid.go generated vendored Normal file
View file

@ -0,0 +1,285 @@
// Copyright (C) 2013-2018 by Maxim Bublis <b@codemonkey.ru>
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the
// "Software"), to deal in the Software without restriction, including
// without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to
// permit persons to whom the Software is furnished to do so, subject to
// the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
// Package uuid provides implementations of the Universally Unique Identifier
// (UUID), as specified in RFC-4122 and the Peabody RFC Draft (revision 03).
//
// RFC-4122[1] provides the specification for versions 1, 3, 4, and 5. The
// Peabody UUID RFC Draft[2] provides the specification for the new k-sortable
// UUIDs, versions 6 and 7.
//
// DCE 1.1[3] provides the specification for version 2, but version 2 support
// was removed from this package in v4 due to some concerns with the
// specification itself. Reading the spec, it seems that it would result in
// generating UUIDs that aren't very unique. In having read the spec it seemed
// that our implementation did not meet the spec. It also seems to be at-odds
// with RFC 4122, meaning we would need quite a bit of special code to support
// it. Lastly, there were no Version 2 implementations that we could find to
// ensure we were understanding the specification correctly.
//
// [1] https://tools.ietf.org/html/rfc4122
// [2] https://datatracker.ietf.org/doc/html/draft-peabody-dispatch-new-uuid-format-03
// [3] http://pubs.opengroup.org/onlinepubs/9696989899/chap5.htm#tagcjh_08_02_01_01
package uuid
import (
"encoding/binary"
"encoding/hex"
"fmt"
"time"
)
// Size of a UUID in bytes.
const Size = 16
// UUID is an array type to represent the value of a UUID, as defined in RFC-4122.
type UUID [Size]byte
// UUID versions.
const (
_ byte = iota
V1 // Version 1 (date-time and MAC address)
_ // Version 2 (date-time and MAC address, DCE security version) [removed]
V3 // Version 3 (namespace name-based)
V4 // Version 4 (random)
V5 // Version 5 (namespace name-based)
V6 // Version 6 (k-sortable timestamp and random data, field-compatible with v1) [peabody draft]
V7 // Version 7 (k-sortable timestamp and random data) [peabody draft]
_ // Version 8 (k-sortable timestamp, meant for custom implementations) [peabody draft] [not implemented]
)
// UUID layout variants.
const (
VariantNCS byte = iota
VariantRFC4122
VariantMicrosoft
VariantFuture
)
// UUID DCE domains.
const (
DomainPerson = iota
DomainGroup
DomainOrg
)
// Timestamp is the count of 100-nanosecond intervals since 00:00:00.00,
// 15 October 1582 within a V1 UUID. This type has no meaning for other
// UUID versions since they don't have an embedded timestamp.
type Timestamp uint64
const _100nsPerSecond = 10000000
// Time returns the UTC time.Time representation of a Timestamp
func (t Timestamp) Time() (time.Time, error) {
secs := uint64(t) / _100nsPerSecond
nsecs := 100 * (uint64(t) % _100nsPerSecond)
return time.Unix(int64(secs)-(epochStart/_100nsPerSecond), int64(nsecs)), nil
}
// TimestampFromV1 returns the Timestamp embedded within a V1 UUID.
// Returns an error if the UUID is any version other than 1.
func TimestampFromV1(u UUID) (Timestamp, error) {
if u.Version() != 1 {
err := fmt.Errorf("uuid: %s is version %d, not version 1", u, u.Version())
return 0, err
}
low := binary.BigEndian.Uint32(u[0:4])
mid := binary.BigEndian.Uint16(u[4:6])
hi := binary.BigEndian.Uint16(u[6:8]) & 0xfff
return Timestamp(uint64(low) + (uint64(mid) << 32) + (uint64(hi) << 48)), nil
}
// TimestampFromV6 returns the Timestamp embedded within a V6 UUID. This
// function returns an error if the UUID is any version other than 6.
//
// This is implemented based on revision 03 of the Peabody UUID draft, and may
// be subject to change pending further revisions. Until the final specification
// revision is finished, changes required to implement updates to the spec will
// not be considered a breaking change. They will happen as a minor version
// releases until the spec is final.
func TimestampFromV6(u UUID) (Timestamp, error) {
if u.Version() != 6 {
return 0, fmt.Errorf("uuid: %s is version %d, not version 6", u, u.Version())
}
hi := binary.BigEndian.Uint32(u[0:4])
mid := binary.BigEndian.Uint16(u[4:6])
low := binary.BigEndian.Uint16(u[6:8]) & 0xfff
return Timestamp(uint64(low) + (uint64(mid) << 12) + (uint64(hi) << 28)), nil
}
// Nil is the nil UUID, as specified in RFC-4122, that has all 128 bits set to
// zero.
var Nil = UUID{}
// Predefined namespace UUIDs.
var (
NamespaceDNS = Must(FromString("6ba7b810-9dad-11d1-80b4-00c04fd430c8"))
NamespaceURL = Must(FromString("6ba7b811-9dad-11d1-80b4-00c04fd430c8"))
NamespaceOID = Must(FromString("6ba7b812-9dad-11d1-80b4-00c04fd430c8"))
NamespaceX500 = Must(FromString("6ba7b814-9dad-11d1-80b4-00c04fd430c8"))
)
// IsNil returns if the UUID is equal to the nil UUID
func (u UUID) IsNil() bool {
return u == Nil
}
// Version returns the algorithm version used to generate the UUID.
func (u UUID) Version() byte {
return u[6] >> 4
}
// Variant returns the UUID layout variant.
func (u UUID) Variant() byte {
switch {
case (u[8] >> 7) == 0x00:
return VariantNCS
case (u[8] >> 6) == 0x02:
return VariantRFC4122
case (u[8] >> 5) == 0x06:
return VariantMicrosoft
case (u[8] >> 5) == 0x07:
fallthrough
default:
return VariantFuture
}
}
// Bytes returns a byte slice representation of the UUID.
func (u UUID) Bytes() []byte {
return u[:]
}
// encodeCanonical encodes the canonical RFC-4122 form of UUID u into the
// first 36 bytes dst.
func encodeCanonical(dst []byte, u UUID) {
const hextable = "0123456789abcdef"
dst[8] = '-'
dst[13] = '-'
dst[18] = '-'
dst[23] = '-'
for i, x := range [16]byte{
0, 2, 4, 6,
9, 11,
14, 16,
19, 21,
24, 26, 28, 30, 32, 34,
} {
c := u[i]
dst[x] = hextable[c>>4]
dst[x+1] = hextable[c&0x0f]
}
}
// String returns a canonical RFC-4122 string representation of the UUID:
// xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx.
func (u UUID) String() string {
var buf [36]byte
encodeCanonical(buf[:], u)
return string(buf[:])
}
// Format implements fmt.Formatter for UUID values.
//
// The behavior is as follows:
// The 'x' and 'X' verbs output only the hex digits of the UUID, using a-f for 'x' and A-F for 'X'.
// The 'v', '+v', 's' and 'q' verbs return the canonical RFC-4122 string representation.
// The 'S' verb returns the RFC-4122 format, but with capital hex digits.
// The '#v' verb returns the "Go syntax" representation, which is a 16 byte array initializer.
// All other verbs not handled directly by the fmt package (like '%p') are unsupported and will return
// "%!verb(uuid.UUID=value)" as recommended by the fmt package.
func (u UUID) Format(f fmt.State, c rune) {
if c == 'v' && f.Flag('#') {
fmt.Fprintf(f, "%#v", [Size]byte(u))
return
}
switch c {
case 'x', 'X':
b := make([]byte, 32)
hex.Encode(b, u[:])
if c == 'X' {
toUpperHex(b)
}
_, _ = f.Write(b)
case 'v', 's', 'S':
b, _ := u.MarshalText()
if c == 'S' {
toUpperHex(b)
}
_, _ = f.Write(b)
case 'q':
b := make([]byte, 38)
b[0] = '"'
encodeCanonical(b[1:], u)
b[37] = '"'
_, _ = f.Write(b)
default:
// invalid/unsupported format verb
fmt.Fprintf(f, "%%!%c(uuid.UUID=%s)", c, u.String())
}
}
func toUpperHex(b []byte) {
for i, c := range b {
if 'a' <= c && c <= 'f' {
b[i] = c - ('a' - 'A')
}
}
}
// SetVersion sets the version bits.
func (u *UUID) SetVersion(v byte) {
u[6] = (u[6] & 0x0f) | (v << 4)
}
// SetVariant sets the variant bits.
func (u *UUID) SetVariant(v byte) {
switch v {
case VariantNCS:
u[8] = (u[8]&(0xff>>1) | (0x00 << 7))
case VariantRFC4122:
u[8] = (u[8]&(0xff>>2) | (0x02 << 6))
case VariantMicrosoft:
u[8] = (u[8]&(0xff>>3) | (0x06 << 5))
case VariantFuture:
fallthrough
default:
u[8] = (u[8]&(0xff>>3) | (0x07 << 5))
}
}
// Must is a helper that wraps a call to a function returning (UUID, error)
// and panics if the error is non-nil. It is intended for use in variable
// initializations such as
//
// var packageUUID = uuid.Must(uuid.FromString("123e4567-e89b-12d3-a456-426655440000"))
func Must(u UUID, err error) UUID {
if err != nil {
panic(err)
}
return u
}

3
vendor/modules.txt vendored
View file

@ -176,6 +176,9 @@ github.com/gobwas/pool/pbytes
## explicit; go 1.15 ## explicit; go 1.15
github.com/gobwas/ws github.com/gobwas/ws
github.com/gobwas/ws/wsutil github.com/gobwas/ws/wsutil
# github.com/gofrs/uuid v4.4.0+incompatible
## explicit
github.com/gofrs/uuid
# github.com/golang-jwt/jwt/v4 v4.0.0 # github.com/golang-jwt/jwt/v4 v4.0.0
## explicit; go 1.15 ## explicit; go 1.15
github.com/golang-jwt/jwt/v4 github.com/golang-jwt/jwt/v4