Studio Tagger (#3510)

* Studio image and parent studio support in scene tagger
* Refactor studio backend and add studio tagger
---------
Co-authored-by: WithoutPants <53250216+WithoutPants@users.noreply.github.com>
This commit is contained in:
Flashy78 2023-07-30 16:50:24 -07:00 committed by GitHub
parent d48dbeb864
commit a665a56ef0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
79 changed files with 5224 additions and 1039 deletions

View file

@ -40,6 +40,7 @@ NOTE: The `make` command in Windows will be `mingw32-make` with MinGW. For examp
* `make pre-ui` - Installs the UI dependencies. This only needs to be run once after cloning the repository, or if the dependencies are updated.
* `make generate` - Generates Go and UI GraphQL files. Requires `make pre-ui` to have been run.
* `make generate-stash-box-client` - Generate Go files for the Stash-box client code.
* `make ui` - Builds the UI. Requires `make pre-ui` to have been run.
* `make stash` - Builds the `stash` binary (make sure to build the UI as well... see below)
* `make stash-release` - Builds a release version the `stash` binary, with debug information removed

1
go.mod
View file

@ -10,6 +10,7 @@ require (
github.com/corona10/goimagehash v1.0.3
github.com/disintegration/imaging v1.6.0
github.com/go-chi/chi v4.0.2+incompatible
github.com/gofrs/uuid v4.4.0+incompatible
github.com/golang-jwt/jwt/v4 v4.0.0
github.com/golang-migrate/migrate/v4 v4.15.0-beta.1
github.com/gorilla/securecookie v1.1.1

2
go.sum
View file

@ -295,6 +295,8 @@ github.com/gocql/gocql v0.0.0-20190301043612-f6df8288f9b4/go.mod h1:4Fw1eo5iaEhD
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
github.com/gofrs/uuid v4.4.0+incompatible h1:3qXRTX8/NbyulANqlc0lchS1gqAVxRgsuW1YrTJupqA=
github.com/gofrs/uuid v4.4.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt/v4 v4.0.0 h1:RAqyYixv1p7uEnocuy8P1nru5wprCh/MH2BIlW5z5/o=

View file

@ -74,8 +74,8 @@ models:
model: github.com/stashapp/stash/internal/manager.AutoTagMetadataInput
CleanMetadataInput:
model: github.com/stashapp/stash/internal/manager.CleanMetadataInput
StashBoxBatchPerformerTagInput:
model: github.com/stashapp/stash/internal/manager.StashBoxBatchPerformerTagInput
StashBoxBatchTagInput:
model: github.com/stashapp/stash/internal/manager.StashBoxBatchTagInput
SceneStreamEndpoint:
model: github.com/stashapp/stash/internal/manager.SceneStreamEndpoint
ExportObjectTypeInput:

View file

@ -1,3 +1,18 @@
fragment ScrapedStudioData on ScrapedStudio {
stored_id
name
url
parent {
stored_id
name
url
image
remote_site_id
}
image
remote_site_id
}
fragment ScrapedPerformerData on ScrapedPerformer {
stored_id
name
@ -101,6 +116,14 @@ fragment ScrapedSceneStudioData on ScrapedStudio {
stored_id
name
url
parent {
stored_id
name
url
image
remote_site_id
}
image
remote_site_id
}

View file

@ -4,10 +4,14 @@ mutation SubmitStashBoxFingerprints(
submitStashBoxFingerprints(input: $input)
}
mutation StashBoxBatchPerformerTag($input: StashBoxBatchPerformerTagInput!) {
mutation StashBoxBatchPerformerTag($input: StashBoxBatchTagInput!) {
stashBoxBatchPerformerTag(input: $input)
}
mutation StashBoxBatchStudioTag($input: StashBoxBatchTagInput!) {
stashBoxBatchStudioTag(input: $input)
}
mutation SubmitStashBoxSceneDraft($input: StashBoxDraftSubmissionInput!) {
submitStashBoxSceneDraft(input: $input)
}

View file

@ -42,6 +42,15 @@ query ListMovieScrapers {
}
}
query ScrapeSingleStudio(
$source: ScraperSourceInput!
$input: ScrapeSingleStudioInput!
) {
scrapeSingleStudio(source: $source, input: $input) {
...ScrapedStudioData
}
}
query ScrapeSinglePerformer(
$source: ScraperSourceInput!
$input: ScrapeSinglePerformerInput!

View file

@ -128,6 +128,12 @@ type Query {
input: ScrapeMultiScenesInput!
): [[ScrapedScene!]!]!
"Scrape for a single studio"
scrapeSingleStudio(
source: ScraperSourceInput!
input: ScrapeSingleStudioInput!
): [ScrapedStudio!]!
"Scrape for a single performer"
scrapeSinglePerformer(
source: ScraperSourceInput!
@ -416,7 +422,9 @@ type Mutation {
execSQL(sql: String!, args: [Any]): SQLExecResult!
"Run batch performer tag task. Returns the job ID."
stashBoxBatchPerformerTag(input: StashBoxBatchPerformerTagInput!): String!
stashBoxBatchPerformerTag(input: StashBoxBatchTagInput!): String!
"Run batch studio tag task. Returns the job ID."
stashBoxBatchStudioTag(input: StashBoxBatchTagInput!): String!
"Enables DLNA for an optional duration. Has no effect if DLNA is enabled by default"
enableDLNA(input: EnableDLNAInput!): Boolean!

View file

@ -48,6 +48,7 @@ type ScrapedStudio {
stored_id: ID
name: String!
url: String
parent: ScrapedStudio
image: String
remote_site_id: String
@ -148,6 +149,13 @@ input ScrapeMultiScenesInput {
scene_ids: [ID!]
}
input ScrapeSingleStudioInput {
"""
Query can be either a name or a Stash ID
"""
query: String
}
input ScrapeSinglePerformerInput {
"Instructs to query by string"
query: String
@ -209,16 +217,22 @@ type StashBoxFingerprint {
duration: Int!
}
"If neither performer_ids nor performer_names are set, tag all performers"
input StashBoxBatchPerformerTagInput {
"Stash endpoint to use for the performer tagging"
"If neither ids nor names are set, tag all items"
input StashBoxBatchTagInput {
"Stash endpoint to use for the tagging"
endpoint: Int!
"Fields to exclude when executing the performer tagging"
"Fields to exclude when executing the tagging"
exclude_fields: [String!]
"Refresh performers already tagged by StashBox if true. Only tag performers with no StashBox tagging if false"
"Refresh items already tagged by StashBox if true. Only tag items with no StashBox tagging if false"
refresh: Boolean!
"If batch adding studios, should their parent studios also be created?"
createParent: Boolean!
"If set, only tag these ids"
ids: [ID!]
"If set, only tag these names"
names: [String!]
"If set, only tag these performer ids"
performer_ids: [ID!]
performer_ids: [ID!] @deprecated(reason: "use ids")
"If set, only tag these performer names"
performer_names: [String!]
performer_names: [String!] @deprecated(reason: "use names")
}

View file

@ -16,6 +16,10 @@ fragment StudioFragment on Studio {
urls {
...URLFragment
}
parent {
name
id
}
images {
...ImageFragment
}
@ -163,6 +167,12 @@ query FindSceneByID($id: ID!) {
}
}
query FindStudio($id: ID, $name: String) {
findStudio(id: $id, name: $name) {
...StudioFragment
}
}
mutation SubmitFingerprint($input: FingerprintSubmission!) {
submitFingerprint(input: $input)
}

View file

@ -34,15 +34,16 @@ func (r *studioResolver) ImagePath(ctx context.Context, obj *models.Studio) (*st
return &imagePath, nil
}
func (r *studioResolver) Aliases(ctx context.Context, obj *models.Studio) (ret []string, err error) {
func (r *studioResolver) Aliases(ctx context.Context, obj *models.Studio) ([]string, error) {
if !obj.Aliases.Loaded() {
if err := r.withReadTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Studio.GetAliases(ctx, obj.ID)
return err
return obj.LoadAliases(ctx, r.repository.Studio)
}); err != nil {
return nil, err
}
}
return ret, err
return obj.Aliases.List(), nil
}
func (r *studioResolver) SceneCount(ctx context.Context, obj *models.Studio, depth *int) (ret int, err error) {
@ -120,16 +121,15 @@ func (r *studioResolver) ChildStudios(ctx context.Context, obj *models.Studio) (
}
func (r *studioResolver) StashIds(ctx context.Context, obj *models.Studio) ([]*models.StashID, error) {
var ret []models.StashID
if !obj.StashIDs.Loaded() {
if err := r.withReadTxn(ctx, func(ctx context.Context) error {
var err error
ret, err = r.repository.Studio.GetStashIDs(ctx, obj.ID)
return err
return obj.LoadStashIDs(ctx, r.repository.Studio)
}); err != nil {
return nil, err
}
}
return stashIDsSliceToPtrSlice(ret), nil
return stashIDsSliceToPtrSlice(obj.StashIDs.List()), nil
}
func (r *studioResolver) Rating(ctx context.Context, obj *models.Studio) (*int, error) {

View file

@ -32,11 +32,16 @@ func (r *mutationResolver) SubmitStashBoxFingerprints(ctx context.Context, input
return client.SubmitStashBoxFingerprints(ctx, input.SceneIds, boxes[input.StashBoxIndex].Endpoint)
}
func (r *mutationResolver) StashBoxBatchPerformerTag(ctx context.Context, input manager.StashBoxBatchPerformerTagInput) (string, error) {
func (r *mutationResolver) StashBoxBatchPerformerTag(ctx context.Context, input manager.StashBoxBatchTagInput) (string, error) {
jobID := manager.GetInstance().StashBoxBatchPerformerTag(ctx, input)
return strconv.Itoa(jobID), nil
}
func (r *mutationResolver) StashBoxBatchStudioTag(ctx context.Context, input manager.StashBoxBatchTagInput) (string, error) {
jobID := manager.GetInstance().StashBoxBatchStudioTag(ctx, input)
return strconv.Itoa(jobID), nil
}
func (r *mutationResolver) SubmitStashBoxSceneDraft(ctx context.Context, input StashBoxDraftSubmissionInput) (*string, error) {
boxes := config.GetInstance().GetStashBoxes()

View file

@ -6,7 +6,6 @@ import (
"strconv"
"time"
"github.com/stashapp/stash/internal/manager"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/plugin"
"github.com/stashapp/stash/pkg/sliceutil/stringslice"
@ -14,18 +13,54 @@ import (
"github.com/stashapp/stash/pkg/utils"
)
func (r *mutationResolver) getStudio(ctx context.Context, id int) (ret *models.Studio, err error) {
func (r *mutationResolver) StudioCreate(ctx context.Context, input StudioCreateInput) (*models.Studio, error) {
s, err := studioFromStudioCreateInput(ctx, input)
if err != nil {
return nil, err
}
// Process the base 64 encoded image string
var imageData []byte
if input.Image != nil {
var err error
imageData, err = utils.ProcessImageInput(ctx, *input.Image)
if err != nil {
return nil, err
}
}
// Start the transaction and save the studio
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Studio.Find(ctx, id)
qb := r.repository.Studio
if s.Aliases.Loaded() && len(s.Aliases.List()) > 0 {
if err := studio.EnsureAliasesUnique(ctx, 0, s.Aliases.List(), qb); err != nil {
return err
}
}
err = qb.Create(ctx, s)
if err != nil {
return err
}
if len(imageData) > 0 {
if err := qb.UpdateImage(ctx, s.ID, imageData); err != nil {
return err
}
}
return nil
}); err != nil {
return nil, err
}
return ret, nil
r.hookExecutor.ExecutePostHooks(ctx, s.ID, plugin.StudioCreatePost, input, nil)
return s, nil
}
func (r *mutationResolver) StudioCreate(ctx context.Context, input StudioCreateInput) (*models.Studio, error) {
func studioFromStudioCreateInput(ctx context.Context, input StudioCreateInput) (*models.Studio, error) {
translator := changesetTranslator{
inputMap: getUpdateInputMap(ctx),
}
@ -43,143 +78,110 @@ func (r *mutationResolver) StudioCreate(ctx context.Context, input StudioCreateI
}
var err error
newStudio.ParentID, err = translator.intPtrFromString(input.ParentID, "parent_id")
if err != nil {
return nil, fmt.Errorf("converting parent id: %w", err)
}
// Process the base 64 encoded image string
var imageData []byte
if input.Image != nil {
imageData, err = utils.ProcessImageInput(ctx, *input.Image)
if err != nil {
return nil, err
if input.Aliases != nil {
newStudio.Aliases = models.NewRelatedStrings(input.Aliases)
}
}
// Start the transaction and save the studio
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Studio
err = qb.Create(ctx, &newStudio)
if err != nil {
return err
}
// update image table
if len(imageData) > 0 {
if err := qb.UpdateImage(ctx, newStudio.ID, imageData); err != nil {
return err
}
}
// Save the stash_ids
if input.StashIds != nil {
stashIDJoins := stashIDPtrSliceToSlice(input.StashIds)
if err := qb.UpdateStashIDs(ctx, newStudio.ID, stashIDJoins); err != nil {
return err
}
newStudio.StashIDs = models.NewRelatedStashIDs(stashIDPtrSliceToSlice(input.StashIds))
}
if len(input.Aliases) > 0 {
if err := studio.EnsureAliasesUnique(ctx, newStudio.ID, input.Aliases, qb); err != nil {
return err
}
if err := qb.UpdateAliases(ctx, newStudio.ID, input.Aliases); err != nil {
return err
}
}
return nil
}); err != nil {
return nil, err
}
r.hookExecutor.ExecutePostHooks(ctx, newStudio.ID, plugin.StudioCreatePost, input, nil)
return r.getStudio(ctx, newStudio.ID)
return &newStudio, nil
}
func (r *mutationResolver) StudioUpdate(ctx context.Context, input StudioUpdateInput) (*models.Studio, error) {
studioID, err := strconv.Atoi(input.ID)
if err != nil {
return nil, err
}
var updatedStudio *models.Studio
var err error
translator := changesetTranslator{
inputMap: getUpdateInputMap(ctx),
}
// Populate studio from the input
updatedStudio := models.NewStudioPartial()
updatedStudio.Name = translator.optionalString(input.Name, "name")
updatedStudio.URL = translator.optionalString(input.URL, "url")
updatedStudio.Details = translator.optionalString(input.Details, "details")
updatedStudio.Rating = translator.ratingConversionOptional(input.Rating, input.Rating100)
updatedStudio.IgnoreAutoTag = translator.optionalBool(input.IgnoreAutoTag, "ignore_auto_tag")
updatedStudio.ParentID, err = translator.optionalIntFromString(input.ParentID, "parent_id")
if err != nil {
return nil, fmt.Errorf("converting parent id: %w", err)
inputMap: getNamedUpdateInputMap(ctx, updateInputField),
}
s := studioPartialFromStudioUpdateInput(input, &input.ID, translator)
// Process the base 64 encoded image string
var imageData []byte
imageIncluded := translator.hasField("image")
if input.Image != nil {
var err error
imageData, err = utils.ProcessImageInput(ctx, *input.Image)
if err != nil {
return nil, err
}
}
// Start the transaction and save the studio
var s *models.Studio
// Start the transaction and update the studio
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Studio
if err := manager.ValidateModifyStudio(ctx, studioID, updatedStudio, qb); err != nil {
if err := studio.ValidateModify(ctx, *s, qb); err != nil {
return err
}
var err error
s, err = qb.UpdatePartial(ctx, studioID, updatedStudio)
updatedStudio, err = qb.UpdatePartial(ctx, *s)
if err != nil {
return err
}
// update image table
if imageIncluded {
if err := qb.UpdateImage(ctx, s.ID, imageData); err != nil {
return err
}
}
// Save the stash_ids
if translator.hasField("stash_ids") {
stashIDJoins := stashIDPtrSliceToSlice(input.StashIds)
if err := qb.UpdateStashIDs(ctx, studioID, stashIDJoins); err != nil {
return err
}
}
if translator.hasField("aliases") {
if err := studio.EnsureAliasesUnique(ctx, studioID, input.Aliases, qb); err != nil {
return err
}
if err := qb.UpdateAliases(ctx, studioID, input.Aliases); err != nil {
return err
}
}
return nil
}); err != nil {
return nil, err
}
r.hookExecutor.ExecutePostHooks(ctx, s.ID, plugin.StudioUpdatePost, input, translator.getFields())
return r.getStudio(ctx, s.ID)
r.hookExecutor.ExecutePostHooks(ctx, updatedStudio.ID, plugin.StudioUpdatePost, input, translator.getFields())
return updatedStudio, nil
}
// This is slightly different to studioPartialFromStudioCreateInput in that Name is handled differently
// and ImageIncluded is not hardcoded to true
func studioPartialFromStudioUpdateInput(input StudioUpdateInput, id *string, translator changesetTranslator) *models.StudioPartial {
// Populate studio from the input
updatedStudio := models.StudioPartial{
Name: translator.optionalString(input.Name, "name"),
URL: translator.optionalString(input.URL, "url"),
Details: translator.optionalString(input.Details, "details"),
Rating: translator.ratingConversionOptional(input.Rating, input.Rating100),
IgnoreAutoTag: translator.optionalBool(input.IgnoreAutoTag, "ignore_auto_tag"),
UpdatedAt: models.NewOptionalTime(time.Now()),
}
updatedStudio.ID, _ = strconv.Atoi(*id)
if input.ParentID != nil {
parentID, _ := strconv.Atoi(*input.ParentID)
if parentID > 0 {
// This is to be set directly as we know it has a value and the translator won't have the field
updatedStudio.ParentID = models.NewOptionalInt(parentID)
}
} else {
updatedStudio.ParentID = translator.optionalInt(nil, "parent_id")
}
if translator.hasField("aliases") {
updatedStudio.Aliases = &models.UpdateStrings{
Values: input.Aliases,
Mode: models.RelationshipUpdateModeSet,
}
}
if translator.hasField("stash_ids") {
updatedStudio.StashIDs = &models.UpdateStashIDs{
StashIDs: stashIDPtrSliceToSlice(input.StashIds),
Mode: models.RelationshipUpdateModeSet,
}
}
return &updatedStudio
}
func (r *mutationResolver) StudioDestroy(ctx context.Context, input StudioDestroyInput) (bool, error) {

View file

@ -327,6 +327,32 @@ func (r *queryResolver) ScrapeMultiScenes(ctx context.Context, source scraper.So
return nil, errors.New("scraper_id or stash_box_index must be set")
}
func (r *queryResolver) ScrapeSingleStudio(ctx context.Context, source scraper.Source, input ScrapeSingleStudioInput) ([]*models.ScrapedStudio, error) {
if source.StashBoxIndex != nil {
client, err := r.getStashBoxClient(*source.StashBoxIndex)
if err != nil {
return nil, err
}
var ret []*models.ScrapedStudio
out, err := client.FindStashBoxStudio(ctx, *input.Query)
if err != nil {
return nil, err
} else if out != nil {
ret = append(ret, out)
}
if len(ret) > 0 {
return ret, nil
}
return nil, nil
}
return nil, errors.New("stash_box_index must be set")
}
func (r *queryResolver) ScrapeSinglePerformer(ctx context.Context, source scraper.Source, input ScrapeSinglePerformerInput) ([]*models.ScrapedPerformer, error) {
if source.ScraperID != nil {
if input.PerformerInput != nil {

View file

@ -1,8 +1,9 @@
package urlbuilders
import (
"github.com/stashapp/stash/pkg/models"
"strconv"
"github.com/stashapp/stash/pkg/models"
)
type StudioURLBuilder struct {

View file

@ -44,7 +44,7 @@ type ScraperSource struct {
type SceneIdentifier struct {
SceneReaderUpdater SceneReaderUpdater
StudioCreator StudioCreator
StudioReaderWriter models.StudioReaderWriter
PerformerCreator PerformerCreator
TagCreatorFinder TagCreatorFinder
@ -174,7 +174,7 @@ func (t *SceneIdentifier) getSceneUpdater(ctx context.Context, s *models.Scene,
rel := sceneRelationships{
sceneReader: t.SceneReaderUpdater,
studioCreator: t.StudioCreator,
studioReaderWriter: t.StudioReaderWriter,
performerCreator: t.PerformerCreator,
tagCreatorFinder: t.TagCreatorFinder,
scene: s,

View file

@ -34,7 +34,7 @@ type TagCreatorFinder interface {
type sceneRelationships struct {
sceneReader SceneReaderUpdater
studioCreator StudioCreator
studioReaderWriter models.StudioReaderWriter
performerCreator PerformerCreator
tagCreatorFinder TagCreatorFinder
scene *models.Scene
@ -67,7 +67,7 @@ func (g sceneRelationships) studio(ctx context.Context) (*int, error) {
return &studioID, nil
}
} else if createMissing {
return createMissingStudio(ctx, endpoint, g.studioCreator, scraped)
return createMissingStudio(ctx, endpoint, g.studioReaderWriter, scraped)
}
return nil, nil

View file

@ -16,6 +16,7 @@ import (
func Test_sceneRelationships_studio(t *testing.T) {
validStoredID := "1"
remoteSiteID := "2"
var validStoredIDInt = 1
invalidStoredID := "invalidStoredID"
createMissing := true
@ -31,7 +32,7 @@ func Test_sceneRelationships_studio(t *testing.T) {
}).Return(nil)
tr := sceneRelationships{
studioCreator: mockStudioReaderWriter,
studioReaderWriter: mockStudioReaderWriter,
fieldOptions: make(map[string]*FieldOptions),
}
@ -110,7 +111,7 @@ func Test_sceneRelationships_studio(t *testing.T) {
Strategy: FieldStrategyMerge,
CreateMissing: &createMissing,
},
&models.ScrapedStudio{},
&models.ScrapedStudio{RemoteSiteID: &remoteSiteID},
&validStoredIDInt,
false,
},
@ -120,6 +121,9 @@ func Test_sceneRelationships_studio(t *testing.T) {
tr.scene = tt.scene
tr.fieldOptions["studio"] = tt.fieldOptions
tr.result = &scrapeResult{
source: ScraperSource{
RemoteSite: "endpoint",
},
result: &scraper.ScrapedScene{
Studio: tt.result,
},

View file

@ -2,64 +2,95 @@ package identify
import (
"context"
"fmt"
"time"
"strconv"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/utils"
"github.com/stashapp/stash/pkg/studio"
)
type StudioCreator interface {
Create(ctx context.Context, newStudio *models.Studio) error
UpdateStashIDs(ctx context.Context, studioID int, stashIDs []models.StashID) error
UpdateImage(ctx context.Context, studioID int, image []byte) error
}
func createMissingStudio(ctx context.Context, endpoint string, w models.StudioReaderWriter, s *models.ScrapedStudio) (*int, error) {
var err error
func createMissingStudio(ctx context.Context, endpoint string, w StudioCreator, studio *models.ScrapedStudio) (*int, error) {
studioInput := scrapedToStudioInput(studio)
err := w.Create(ctx, &studioInput)
if s.Parent != nil {
if s.Parent.StoredID == nil {
// The parent needs to be created
newParentStudio := s.Parent.ToStudio(endpoint, nil)
parentImage, err := s.Parent.GetImage(ctx, nil)
if err != nil {
return nil, fmt.Errorf("error creating studio: %w", err)
logger.Errorf("Failed to make parent studio from scraped studio %s: %s", s.Parent.Name, err.Error())
return nil, err
}
// update image table
if studio.Image != nil && len(*studio.Image) > 0 {
imageData, err := utils.ReadImageFromURL(ctx, *studio.Image)
// Create the studio
err = w.Create(ctx, newParentStudio)
if err != nil {
return nil, err
}
err = w.UpdateImage(ctx, studioInput.ID, imageData)
if err != nil {
// Update image table
if len(parentImage) > 0 {
if err := w.UpdateImage(ctx, newParentStudio.ID, parentImage); err != nil {
return nil, err
}
}
if endpoint != "" && studio.RemoteSiteID != nil {
if err := w.UpdateStashIDs(ctx, studioInput.ID, []models.StashID{
{
Endpoint: endpoint,
StashID: *studio.RemoteSiteID,
},
}); err != nil {
return nil, fmt.Errorf("error setting studio stash id: %w", err)
storedId := strconv.Itoa(newParentStudio.ID)
s.Parent.StoredID = &storedId
} else {
// The parent studio matched an existing one and the user has chosen in the UI to link and/or update it
existingStashIDs := getStashIDsForStudio(ctx, *s.Parent.StoredID, w)
studioPartial := s.Parent.ToPartial(s.Parent.StoredID, endpoint, nil, existingStashIDs)
parentImage, err := s.Parent.GetImage(ctx, nil)
if err != nil {
return nil, err
}
if err := studio.ValidateModify(ctx, *studioPartial, w); err != nil {
return nil, err
}
_, err = w.UpdatePartial(ctx, *studioPartial)
if err != nil {
return nil, err
}
if len(parentImage) > 0 {
if err := w.UpdateImage(ctx, studioPartial.ID, parentImage); err != nil {
return nil, err
}
}
}
}
return &studioInput.ID, nil
newStudio := s.ToStudio(endpoint, nil)
studioImage, err := s.GetImage(ctx, nil)
if err != nil {
return nil, err
}
func scrapedToStudioInput(studio *models.ScrapedStudio) models.Studio {
currentTime := time.Now()
ret := models.Studio{
Name: studio.Name,
CreatedAt: currentTime,
UpdatedAt: currentTime,
err = w.Create(ctx, newStudio)
if err != nil {
return nil, err
}
if studio.URL != nil {
ret.URL = *studio.URL
// Update image table
if len(studioImage) > 0 {
if err := w.UpdateImage(ctx, newStudio.ID, studioImage); err != nil {
return nil, err
}
}
return ret
return &newStudio.ID, nil
}
func getStashIDsForStudio(ctx context.Context, studioID string, w models.StudioReaderWriter) []models.StashID {
id, _ := strconv.Atoi(studioID)
tempStudio := &models.Studio{ID: id}
err := tempStudio.LoadStashIDs(ctx, w)
if err != nil {
return nil
}
return tempStudio.StashIDs.List()
}

View file

@ -4,7 +4,6 @@ import (
"errors"
"reflect"
"testing"
"time"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/mocks"
@ -31,18 +30,32 @@ func Test_createMissingStudio(t *testing.T) {
return p.Name == invalidName
})).Return(errors.New("error creating studio"))
mockStudioReaderWriter.On("UpdateStashIDs", testCtx, createdID, []models.StashID{
mockStudioReaderWriter.On("UpdatePartial", testCtx, models.StudioPartial{
ID: createdID,
StashIDs: &models.UpdateStashIDs{
StashIDs: []models.StashID{
{
Endpoint: invalidEndpoint,
StashID: remoteSiteID,
},
}).Return(errors.New("error updating stash ids"))
mockStudioReaderWriter.On("UpdateStashIDs", testCtx, createdID, []models.StashID{
},
Mode: models.RelationshipUpdateModeSet,
},
}).Return(nil, errors.New("error updating stash ids"))
mockStudioReaderWriter.On("UpdatePartial", testCtx, models.StudioPartial{
ID: createdID,
StashIDs: &models.UpdateStashIDs{
StashIDs: []models.StashID{
{
Endpoint: validEndpoint,
StashID: remoteSiteID,
},
}).Return(nil)
},
Mode: models.RelationshipUpdateModeSet,
},
}).Return(models.Studio{
ID: createdID,
}, nil)
type args struct {
endpoint string
@ -60,6 +73,7 @@ func Test_createMissingStudio(t *testing.T) {
emptyEndpoint,
&models.ScrapedStudio{
Name: validName,
RemoteSiteID: &remoteSiteID,
},
},
&createdID,
@ -71,6 +85,7 @@ func Test_createMissingStudio(t *testing.T) {
emptyEndpoint,
&models.ScrapedStudio{
Name: invalidName,
RemoteSiteID: &remoteSiteID,
},
},
nil,
@ -88,18 +103,6 @@ func Test_createMissingStudio(t *testing.T) {
&createdID,
false,
},
{
"invalid stash id",
args{
invalidEndpoint,
&models.ScrapedStudio{
Name: validName,
RemoteSiteID: &remoteSiteID,
},
},
nil,
true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -114,48 +117,3 @@ func Test_createMissingStudio(t *testing.T) {
})
}
}
func Test_scrapedToStudioInput(t *testing.T) {
const name = "name"
url := "url"
tests := []struct {
name string
studio *models.ScrapedStudio
want models.Studio
}{
{
"set all",
&models.ScrapedStudio{
Name: name,
URL: &url,
},
models.Studio{
Name: name,
URL: url,
},
},
{
"set none",
&models.ScrapedStudio{
Name: name,
},
models.Studio{
Name: name,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := scrapedToStudioInput(tt.studio)
// clear created/updated dates
got.CreatedAt = time.Time{}
got.UpdatedAt = got.CreatedAt
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("scrapedToStudioInput() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -321,21 +321,31 @@ func (s *Manager) MigrateHash(ctx context.Context) int {
return s.JobManager.Add(ctx, "Migrating scene hashes...", j)
}
// If neither performer_ids nor performer_names are set, tag all performers
type StashBoxBatchPerformerTagInput struct {
// Stash endpoint to use for the performer tagging
// If neither ids nor names are set, tag all items
type StashBoxBatchTagInput struct {
// Stash endpoint to use for the tagging
Endpoint int `json:"endpoint"`
// Fields to exclude when executing the performer tagging
// Fields to exclude when executing the tagging
ExcludeFields []string `json:"exclude_fields"`
// Refresh performers already tagged by StashBox if true. Only tag performers with no StashBox tagging if false
// Refresh items already tagged by StashBox if true. Only tag items with no StashBox tagging if false
Refresh bool `json:"refresh"`
// If batch adding studios, should their parent studios also be created?
CreateParent bool `json:"createParent"`
// If set, only tag these ids
Ids []string `json:"ids"`
// If set, only tag these names
Names []string `json:"names"`
// If set, only tag these performer ids
//
// Deprecated: please use Ids
PerformerIds []string `json:"performer_ids"`
// If set, only tag these performer names
//
// Deprecated: please use Names
PerformerNames []string `json:"performer_names"`
}
func (s *Manager) StashBoxBatchPerformerTag(ctx context.Context, input StashBoxBatchPerformerTagInput) int {
func (s *Manager) StashBoxBatchPerformerTag(ctx context.Context, input StashBoxBatchTagInput) int {
j := job.MakeJobExec(func(ctx context.Context, progress *job.Progress) {
logger.Infof("Initiating stash-box batch performer tag")
@ -346,7 +356,7 @@ func (s *Manager) StashBoxBatchPerformerTag(ctx context.Context, input StashBoxB
}
box := boxes[input.Endpoint]
var tasks []StashBoxPerformerTagTask
var tasks []StashBoxBatchTagTask
// The gocritic linter wants to turn this ifElseChain into a switch.
// however, such a switch would contain quite large blocks for each section
@ -354,24 +364,35 @@ func (s *Manager) StashBoxBatchPerformerTag(ctx context.Context, input StashBoxB
//
// This is why we mark this section nolint. In principle, we should look to
// rewrite the section at some point, to avoid the linter warning.
if len(input.PerformerIds) > 0 { //nolint:gocritic
if len(input.Ids) > 0 || len(input.PerformerIds) > 0 { //nolint:gocritic
// The user has chosen only to tag the items on the current page
if err := s.Repository.WithTxn(ctx, func(ctx context.Context) error {
performerQuery := s.Repository.Performer
for _, performerID := range input.PerformerIds {
idsToUse := input.PerformerIds
if len(input.Ids) > 0 {
idsToUse = input.Ids
}
for _, performerID := range idsToUse {
if id, err := strconv.Atoi(performerID); err == nil {
performer, err := performerQuery.Find(ctx, id)
if err == nil {
err = performer.LoadStashIDs(ctx, performerQuery)
if err := performer.LoadStashIDs(ctx, performerQuery); err != nil {
return fmt.Errorf("loading performer stash ids: %w", err)
}
if err == nil {
tasks = append(tasks, StashBoxPerformerTagTask{
// Check if the user wants to refresh existing or new items
if (input.Refresh && len(performer.StashIDs.List()) > 0) ||
(!input.Refresh && len(performer.StashIDs.List()) == 0) {
tasks = append(tasks, StashBoxBatchTagTask{
performer: performer,
refresh: input.Refresh,
box: box,
excluded_fields: input.ExcludeFields,
excludedFields: input.ExcludeFields,
taskType: Performer,
})
}
} else {
return err
}
@ -381,14 +402,25 @@ func (s *Manager) StashBoxBatchPerformerTag(ctx context.Context, input StashBoxB
}); err != nil {
logger.Error(err.Error())
}
} else if len(input.PerformerNames) > 0 {
for i := range input.PerformerNames {
if len(input.PerformerNames[i]) > 0 {
tasks = append(tasks, StashBoxPerformerTagTask{
name: &input.PerformerNames[i],
refresh: input.Refresh,
} else if len(input.Names) > 0 || len(input.PerformerNames) > 0 {
// The user is batch adding performers
namesToUse := input.PerformerNames
if len(input.Names) > 0 {
namesToUse = input.Names
}
for i := range namesToUse {
if len(namesToUse[i]) > 0 {
performer := models.Performer{
Name: namesToUse[i],
}
tasks = append(tasks, StashBoxBatchTagTask{
performer: &performer,
refresh: false,
box: box,
excluded_fields: input.ExcludeFields,
excludedFields: input.ExcludeFields,
taskType: Performer,
})
}
}
@ -397,6 +429,8 @@ func (s *Manager) StashBoxBatchPerformerTag(ctx context.Context, input StashBoxB
// However, this doesn't really help with readability of the current section. Mark it
// as nolint for now. In the future we'd like to rewrite this code by factoring some of
// this into separate functions.
// The user has chosen to tag every item in their database
if err := s.Repository.WithTxn(ctx, func(ctx context.Context) error {
performerQuery := s.Repository.Performer
var performers []*models.Performer
@ -406,6 +440,7 @@ func (s *Manager) StashBoxBatchPerformerTag(ctx context.Context, input StashBoxB
} else {
performers, err = performerQuery.FindByStashIDStatus(ctx, false, box.Endpoint)
}
if err != nil {
return fmt.Errorf("error querying performers: %v", err)
}
@ -415,11 +450,12 @@ func (s *Manager) StashBoxBatchPerformerTag(ctx context.Context, input StashBoxB
return fmt.Errorf("error loading stash ids for performer %s: %v", performer.Name, err)
}
tasks = append(tasks, StashBoxPerformerTagTask{
tasks = append(tasks, StashBoxBatchTagTask{
performer: performer,
refresh: input.Refresh,
box: box,
excluded_fields: input.ExcludeFields,
excludedFields: input.ExcludeFields,
taskType: Performer,
})
}
return nil
@ -451,3 +487,132 @@ func (s *Manager) StashBoxBatchPerformerTag(ctx context.Context, input StashBoxB
return s.JobManager.Add(ctx, "Batch stash-box performer tag...", j)
}
func (s *Manager) StashBoxBatchStudioTag(ctx context.Context, input StashBoxBatchTagInput) int {
j := job.MakeJobExec(func(ctx context.Context, progress *job.Progress) {
logger.Infof("Initiating stash-box batch studio tag")
boxes := config.GetInstance().GetStashBoxes()
if input.Endpoint < 0 || input.Endpoint >= len(boxes) {
logger.Error(fmt.Errorf("invalid stash_box_index %d", input.Endpoint))
return
}
box := boxes[input.Endpoint]
var tasks []StashBoxBatchTagTask
// The gocritic linter wants to turn this ifElseChain into a switch.
// however, such a switch would contain quite large blocks for each section
// and would arguably be hard to read.
//
// This is why we mark this section nolint. In principle, we should look to
// rewrite the section at some point, to avoid the linter warning.
if len(input.Ids) > 0 { //nolint:gocritic
// The user has chosen only to tag the items on the current page
if err := s.Repository.WithTxn(ctx, func(ctx context.Context) error {
studioQuery := s.Repository.Studio
for _, studioID := range input.Ids {
if id, err := strconv.Atoi(studioID); err == nil {
studio, err := studioQuery.Find(ctx, id)
if err == nil {
if err := studio.LoadStashIDs(ctx, studioQuery); err != nil {
return fmt.Errorf("loading studio stash ids: %w", err)
}
// Check if the user wants to refresh existing or new items
if (input.Refresh && len(studio.StashIDs.List()) > 0) ||
(!input.Refresh && len(studio.StashIDs.List()) == 0) {
tasks = append(tasks, StashBoxBatchTagTask{
studio: studio,
refresh: input.Refresh,
createParent: input.CreateParent,
box: box,
excludedFields: input.ExcludeFields,
taskType: Studio,
})
}
} else {
return err
}
}
}
return nil
}); err != nil {
logger.Error(err.Error())
}
} else if len(input.Names) > 0 {
// The user is batch adding studios
for i := range input.Names {
if len(input.Names[i]) > 0 {
tasks = append(tasks, StashBoxBatchTagTask{
name: &input.Names[i],
refresh: false,
createParent: input.CreateParent,
box: box,
excludedFields: input.ExcludeFields,
taskType: Studio,
})
}
}
} else { //nolint:gocritic
// The gocritic linter wants to fold this if-block into the else on the line above.
// However, this doesn't really help with readability of the current section. Mark it
// as nolint for now. In the future we'd like to rewrite this code by factoring some of
// this into separate functions.
// The user has chosen to tag every item in their database
if err := s.Repository.WithTxn(ctx, func(ctx context.Context) error {
studioQuery := s.Repository.Studio
var studios []*models.Studio
var err error
if input.Refresh {
studios, err = studioQuery.FindByStashIDStatus(ctx, true, box.Endpoint)
} else {
studios, err = studioQuery.FindByStashIDStatus(ctx, false, box.Endpoint)
}
if err != nil {
return fmt.Errorf("error querying studios: %v", err)
}
for _, studio := range studios {
tasks = append(tasks, StashBoxBatchTagTask{
studio: studio,
refresh: input.Refresh,
createParent: input.CreateParent,
box: box,
excludedFields: input.ExcludeFields,
taskType: Studio,
})
}
return nil
}); err != nil {
logger.Error(err.Error())
return
}
}
if len(tasks) == 0 {
return
}
progress.SetTotal(len(tasks))
logger.Infof("Starting stash-box batch operation for %d studios", len(tasks))
var wg sync.WaitGroup
for _, task := range tasks {
wg.Add(1)
progress.ExecuteTask(task.Description(), func() {
task.Start(ctx)
wg.Done()
})
progress.Increment()
}
})
return s.JobManager.Add(ctx, "Batch stash-box studio tag...", j)
}

View file

@ -1,38 +0,0 @@
package manager
import (
"context"
"errors"
"fmt"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/studio"
)
func ValidateModifyStudio(ctx context.Context, studioID int, studio models.StudioPartial, qb studio.Finder) error {
if studio.ParentID.Ptr() == nil {
return nil
}
// ensure there is no cyclic dependency
currentParentID := studio.ParentID.Ptr()
for currentParentID != nil {
if *currentParentID == studioID {
return errors.New("studio cannot be an ancestor of itself")
}
currentStudio, err := qb.Find(ctx, *currentParentID)
if err != nil {
return fmt.Errorf("error finding parent studio: %v", err)
}
if currentStudio == nil {
return fmt.Errorf("studio with id %d not found", *currentParentID)
}
currentParentID = currentStudio.ParentID
}
return nil
}

View file

@ -134,7 +134,7 @@ func (j *IdentifyJob) identifyScene(ctx context.Context, s *models.Scene, source
j.progress.ExecuteTask("Identifying "+s.Path, func() {
task := identify.SceneIdentifier{
SceneReaderUpdater: instance.Repository.Scene,
StudioCreator: instance.Repository.Studio,
StudioReaderWriter: instance.Repository.Studio,
PerformerCreator: instance.Repository.Performer,
TagCreatorFinder: instance.Repository.Tag,

View file

@ -10,34 +10,62 @@ import (
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scraper/stashbox"
"github.com/stashapp/stash/pkg/sliceutil/stringslice"
"github.com/stashapp/stash/pkg/studio"
"github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils"
)
type StashBoxPerformerTagTask struct {
type StashBoxTagTaskType int
const (
Performer StashBoxTagTaskType = iota
Studio
)
type StashBoxBatchTagTask struct {
box *models.StashBox
name *string
performer *models.Performer
studio *models.Studio
refresh bool
excluded_fields []string
createParent bool
excludedFields []string
taskType StashBoxTagTaskType
}
func (t *StashBoxPerformerTagTask) Start(ctx context.Context) {
func (t *StashBoxBatchTagTask) Start(ctx context.Context) {
switch t.taskType {
case Performer:
t.stashBoxPerformerTag(ctx)
case Studio:
t.stashBoxStudioTag(ctx)
default:
logger.Errorf("Error starting batch task, unknown task_type %d", t.taskType)
}
}
func (t *StashBoxPerformerTagTask) Description() string {
func (t *StashBoxBatchTagTask) Description() string {
if t.taskType == Performer {
var name string
if t.name != nil {
name = *t.name
} else if t.performer != nil {
} else {
name = t.performer.Name
}
return fmt.Sprintf("Tagging performer %s from stash-box", name)
} else if t.taskType == Studio {
var name string
if t.name != nil {
name = *t.name
} else {
name = t.studio.Name
}
return fmt.Sprintf("Tagging studio %s from stash-box", name)
}
return fmt.Sprintf("Unknown tagging task type %d from stash-box", t.taskType)
}
func (t *StashBoxPerformerTagTask) stashBoxPerformerTag(ctx context.Context) {
func (t *StashBoxBatchTagTask) stashBoxPerformerTag(ctx context.Context) {
var performer *models.ScrapedPerformer
var err error
@ -74,7 +102,7 @@ func (t *StashBoxPerformerTagTask) stashBoxPerformerTag(ctx context.Context) {
}
excluded := map[string]bool{}
for _, field := range t.excluded_fields {
for _, field := range t.excludedFields {
excluded[field] = true
}
@ -187,7 +215,246 @@ func (t *StashBoxPerformerTagTask) stashBoxPerformerTag(ctx context.Context) {
}
}
func (t *StashBoxPerformerTagTask) getPartial(performer *models.ScrapedPerformer, excluded map[string]bool) models.PerformerPartial {
func (t *StashBoxBatchTagTask) stashBoxStudioTag(ctx context.Context) {
studio, err := t.findStashBoxStudio(ctx)
if err != nil {
logger.Errorf("Error fetching studio data from stash-box: %s", err.Error())
return
}
excluded := map[string]bool{}
for _, field := range t.excludedFields {
excluded[field] = true
}
// studio will have a value if pulling from Stash-box by Stash ID or name was successful
if studio != nil {
t.processMatchedStudio(ctx, studio, excluded)
} else {
var name string
if t.name != nil {
name = *t.name
} else if t.studio != nil {
name = t.studio.Name
}
logger.Infof("No match found for %s", name)
}
}
func (t *StashBoxBatchTagTask) findStashBoxStudio(ctx context.Context) (*models.ScrapedStudio, error) {
var studio *models.ScrapedStudio
var err error
client := stashbox.NewClient(*t.box, instance.Repository, stashbox.Repository{
Scene: instance.Repository.Scene,
Performer: instance.Repository.Performer,
Tag: instance.Repository.Tag,
Studio: instance.Repository.Studio,
})
if t.refresh {
var remoteID string
txnErr := txn.WithReadTxn(ctx, instance.Repository, func(ctx context.Context) error {
if !t.studio.StashIDs.Loaded() {
err = t.studio.LoadStashIDs(ctx, instance.Repository.Studio)
if err != nil {
return err
}
}
stashids := t.studio.StashIDs.List()
for _, id := range stashids {
if id.Endpoint == t.box.Endpoint {
remoteID = id.StashID
}
}
return nil
})
if txnErr != nil {
logger.Warnf("error while executing read transaction: %v", err)
return nil, err
}
if remoteID != "" {
studio, err = client.FindStashBoxStudio(ctx, remoteID)
}
} else {
var name string
if t.name != nil {
name = *t.name
} else {
name = t.studio.Name
}
studio, err = client.FindStashBoxStudio(ctx, name)
}
return studio, err
}
func (t *StashBoxBatchTagTask) processMatchedStudio(ctx context.Context, s *models.ScrapedStudio, excluded map[string]bool) {
// Refreshing an existing studio
if t.studio != nil {
if s.Parent != nil && t.createParent {
err := t.processParentStudio(ctx, s.Parent, excluded)
if err != nil {
return
}
}
existingStashIDs := getStashIDsForStudio(ctx, *s.StoredID)
studioPartial := s.ToPartial(s.StoredID, t.box.Endpoint, excluded, existingStashIDs)
studioImage, err := s.GetImage(ctx, excluded)
if err != nil {
logger.Errorf("Failed to make studio partial from scraped studio %s: %s", s.Name, err.Error())
return
}
// Start the transaction and update the studio
err = txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error {
qb := instance.Repository.Studio
if err := studio.ValidateModify(ctx, *studioPartial, qb); err != nil {
return err
}
if _, err := qb.UpdatePartial(ctx, *studioPartial); err != nil {
return err
}
if len(studioImage) > 0 {
if err := qb.UpdateImage(ctx, studioPartial.ID, studioImage); err != nil {
return err
}
}
return nil
})
if err != nil {
logger.Errorf("Failed to update studio %s: %s", s.Name, err.Error())
} else {
logger.Infof("Updated studio %s", s.Name)
}
} else if t.name != nil && s.Name != "" {
// Creating a new studio
if s.Parent != nil && t.createParent {
err := t.processParentStudio(ctx, s.Parent, excluded)
if err != nil {
return
}
}
newStudio := s.ToStudio(t.box.Endpoint, excluded)
studioImage, err := s.GetImage(ctx, excluded)
if err != nil {
logger.Errorf("Failed to make studio from scraped studio %s: %s", s.Name, err.Error())
return
}
// Start the transaction and save the studio
err = txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error {
qb := instance.Repository.Studio
if err := qb.Create(ctx, newStudio); err != nil {
return err
}
if len(studioImage) > 0 {
if err := qb.UpdateImage(ctx, newStudio.ID, studioImage); err != nil {
return err
}
}
return nil
})
if err != nil {
logger.Errorf("Failed to create studio %s: %s", s.Name, err.Error())
} else {
logger.Infof("Created studio %s", s.Name)
}
}
}
func (t *StashBoxBatchTagTask) processParentStudio(ctx context.Context, parent *models.ScrapedStudio, excluded map[string]bool) error {
if parent.StoredID == nil {
// The parent needs to be created
newParentStudio := parent.ToStudio(t.box.Endpoint, excluded)
studioImage, err := parent.GetImage(ctx, excluded)
if err != nil {
logger.Errorf("Failed to make parent studio from scraped studio %s: %s", parent.Name, err.Error())
return err
}
// Start the transaction and save the studio
err = txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error {
qb := instance.Repository.Studio
if err := qb.Create(ctx, newParentStudio); err != nil {
return err
}
if len(studioImage) > 0 {
if err := qb.UpdateImage(ctx, newParentStudio.ID, studioImage); err != nil {
return err
}
}
storedId := strconv.Itoa(newParentStudio.ID)
parent.StoredID = &storedId
return nil
})
if err != nil {
logger.Errorf("Failed to create studio %s: %s", parent.Name, err.Error())
return err
}
logger.Infof("Created studio %s", parent.Name)
} else {
// The parent studio matched an existing one and the user has chosen in the UI to link and/or update it
existingStashIDs := getStashIDsForStudio(ctx, *parent.StoredID)
studioPartial := parent.ToPartial(parent.StoredID, t.box.Endpoint, excluded, existingStashIDs)
studioImage, err := parent.GetImage(ctx, excluded)
if err != nil {
logger.Errorf("Failed to make parent studio partial from scraped studio %s: %s", parent.Name, err.Error())
return err
}
// Start the transaction and update the studio
err = txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error {
qb := instance.Repository.Studio
if err := studio.ValidateModify(ctx, *studioPartial, instance.Repository.Studio); err != nil {
return err
}
if _, err := qb.UpdatePartial(ctx, *studioPartial); err != nil {
return err
}
if len(studioImage) > 0 {
if err := qb.UpdateImage(ctx, studioPartial.ID, studioImage); err != nil {
return err
}
}
return nil
})
if err != nil {
logger.Errorf("Failed to update studio %s: %s", parent.Name, err.Error())
return err
}
logger.Infof("Updated studio %s", parent.Name)
}
return nil
}
func getStashIDsForStudio(ctx context.Context, studioID string) []models.StashID {
id, _ := strconv.Atoi(studioID)
tempStudio := &models.Studio{ID: id}
err := tempStudio.LoadStashIDs(ctx, instance.Repository.Studio)
if err != nil {
return nil
}
return tempStudio.StashIDs.List()
}
func (t *StashBoxBatchTagTask) getPartial(performer *models.ScrapedPerformer, excluded map[string]bool) models.PerformerPartial {
partial := models.NewPerformerPartial()
if performer.Aliases != nil && !excluded["aliases"] {
@ -243,7 +510,7 @@ func (t *StashBoxPerformerTagTask) getPartial(performer *models.ScrapedPerformer
if performer.Measurements != nil && !excluded["measurements"] {
partial.Measurements = models.NewOptionalString(*performer.Measurements)
}
if excluded["name"] && performer.Name != nil {
if performer.Name != nil && !excluded["name"] {
partial.Name = models.NewOptionalString(*performer.Name)
}
if performer.Disambiguation != nil && !excluded["disambiguation"] {

View file

@ -119,7 +119,9 @@ func (i *Importer) populateStudio(ctx context.Context) error {
}
func (i *Importer) createStudio(ctx context.Context, name string) (int, error) {
newStudio := models.NewStudio(name)
newStudio := &models.Studio{
Name: name,
}
err := i.StudioWriter.Create(ctx, newStudio)
if err != nil {

View file

@ -152,7 +152,9 @@ func (i *Importer) populateStudio(ctx context.Context) error {
}
func (i *Importer) createStudio(ctx context.Context, name string) (int, error) {
newStudio := models.NewStudio(name)
newStudio := &models.Studio{
Name: name,
}
err := i.StudioWriter.Create(ctx, newStudio)
if err != nil {

View file

@ -58,13 +58,13 @@ func (_m *StudioReaderWriter) Count(ctx context.Context) (int, error) {
return r0, r1
}
// Create provides a mock function with given fields: ctx, newStudio
func (_m *StudioReaderWriter) Create(ctx context.Context, newStudio *models.Studio) error {
ret := _m.Called(ctx, newStudio)
// Create provides a mock function with given fields: ctx, input
func (_m *StudioReaderWriter) Create(ctx context.Context, input *models.Studio) error {
ret := _m.Called(ctx, input)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *models.Studio) error); ok {
r0 = rf(ctx, newStudio)
r0 = rf(ctx, input)
} else {
r0 = ret.Error(0)
}
@ -155,6 +155,29 @@ func (_m *StudioReaderWriter) FindByStashID(ctx context.Context, stashID models.
return r0, r1
}
// FindByStashIDStatus provides a mock function with given fields: ctx, hasStashID, stashboxEndpoint
func (_m *StudioReaderWriter) FindByStashIDStatus(ctx context.Context, hasStashID bool, stashboxEndpoint string) ([]*models.Studio, error) {
ret := _m.Called(ctx, hasStashID, stashboxEndpoint)
var r0 []*models.Studio
if rf, ok := ret.Get(0).(func(context.Context, bool, string) []*models.Studio); ok {
r0 = rf(ctx, hasStashID, stashboxEndpoint)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Studio)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, bool, string) error); ok {
r1 = rf(ctx, hasStashID, stashboxEndpoint)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// FindChildren provides a mock function with given fields: ctx, id
func (_m *StudioReaderWriter) FindChildren(ctx context.Context, id int) ([]*models.Studio, error) {
ret := _m.Called(ctx, id)
@ -201,13 +224,13 @@ func (_m *StudioReaderWriter) FindMany(ctx context.Context, ids []int) ([]*model
return r0, r1
}
// GetAliases provides a mock function with given fields: ctx, studioID
func (_m *StudioReaderWriter) GetAliases(ctx context.Context, studioID int) ([]string, error) {
ret := _m.Called(ctx, studioID)
// GetAliases provides a mock function with given fields: ctx, relatedID
func (_m *StudioReaderWriter) GetAliases(ctx context.Context, relatedID int) ([]string, error) {
ret := _m.Called(ctx, relatedID)
var r0 []string
if rf, ok := ret.Get(0).(func(context.Context, int) []string); ok {
r0 = rf(ctx, studioID)
r0 = rf(ctx, relatedID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]string)
@ -216,7 +239,7 @@ func (_m *StudioReaderWriter) GetAliases(ctx context.Context, studioID int) ([]s
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
r1 = rf(ctx, studioID)
r1 = rf(ctx, relatedID)
} else {
r1 = ret.Error(1)
}
@ -358,20 +381,6 @@ func (_m *StudioReaderWriter) Update(ctx context.Context, updatedStudio *models.
return r0
}
// UpdateAliases provides a mock function with given fields: ctx, studioID, aliases
func (_m *StudioReaderWriter) UpdateAliases(ctx context.Context, studioID int, aliases []string) error {
ret := _m.Called(ctx, studioID, aliases)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, int, []string) error); ok {
r0 = rf(ctx, studioID, aliases)
} else {
r0 = ret.Error(0)
}
return r0
}
// UpdateImage provides a mock function with given fields: ctx, studioID, image
func (_m *StudioReaderWriter) UpdateImage(ctx context.Context, studioID int, image []byte) error {
ret := _m.Called(ctx, studioID, image)
@ -386,13 +395,13 @@ func (_m *StudioReaderWriter) UpdateImage(ctx context.Context, studioID int, ima
return r0
}
// UpdatePartial provides a mock function with given fields: ctx, id, updatedStudio
func (_m *StudioReaderWriter) UpdatePartial(ctx context.Context, id int, updatedStudio models.StudioPartial) (*models.Studio, error) {
ret := _m.Called(ctx, id, updatedStudio)
// UpdatePartial provides a mock function with given fields: ctx, input
func (_m *StudioReaderWriter) UpdatePartial(ctx context.Context, input models.StudioPartial) (*models.Studio, error) {
ret := _m.Called(ctx, input)
var r0 *models.Studio
if rf, ok := ret.Get(0).(func(context.Context, int, models.StudioPartial) *models.Studio); ok {
r0 = rf(ctx, id, updatedStudio)
if rf, ok := ret.Get(0).(func(context.Context, models.StudioPartial) *models.Studio); ok {
r0 = rf(ctx, input)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Studio)
@ -400,25 +409,11 @@ func (_m *StudioReaderWriter) UpdatePartial(ctx context.Context, id int, updated
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, int, models.StudioPartial) error); ok {
r1 = rf(ctx, id, updatedStudio)
if rf, ok := ret.Get(1).(func(context.Context, models.StudioPartial) error); ok {
r1 = rf(ctx, input)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// UpdateStashIDs provides a mock function with given fields: ctx, studioID, stashIDs
func (_m *StudioReaderWriter) UpdateStashIDs(ctx context.Context, studioID int, stashIDs []models.StashID) error {
ret := _m.Called(ctx, studioID, stashIDs)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, int, []models.StashID) error); ok {
r0 = rf(ctx, studioID, stashIDs)
} else {
r0 = ret.Error(0)
}
return r0
}

View file

@ -1,16 +1,108 @@
package models
import (
"context"
"strconv"
"time"
"github.com/stashapp/stash/pkg/utils"
)
type ScrapedStudio struct {
// Set if studio matched
StoredID *string `json:"stored_id"`
Name string `json:"name"`
URL *string `json:"url"`
Parent *ScrapedStudio `json:"parent"`
Image *string `json:"image"`
Images []string `json:"images"`
RemoteSiteID *string `json:"remote_site_id"`
}
func (ScrapedStudio) IsScrapedContent() {}
func (s *ScrapedStudio) ToStudio(endpoint string, excluded map[string]bool) *Studio {
now := time.Now()
// Populate a new studio from the input
newStudio := Studio{
Name: s.Name,
StashIDs: NewRelatedStashIDs([]StashID{
{
Endpoint: endpoint,
StashID: *s.RemoteSiteID,
},
}),
CreatedAt: now,
UpdatedAt: now,
}
if s.URL != nil && !excluded["url"] {
newStudio.URL = *s.URL
}
if s.Parent != nil && s.Parent.StoredID != nil && !excluded["parent"] {
parentId, _ := strconv.Atoi(*s.Parent.StoredID)
newStudio.ParentID = &parentId
}
return &newStudio
}
func (s *ScrapedStudio) GetImage(ctx context.Context, excluded map[string]bool) ([]byte, error) {
// Process the base 64 encoded image string
if len(s.Images) > 0 && !excluded["image"] {
var err error
img, err := utils.ProcessImageInput(ctx, *s.Image)
if err != nil {
return nil, err
}
return img, nil
}
return nil, nil
}
func (s *ScrapedStudio) ToPartial(id *string, endpoint string, excluded map[string]bool, existingStashIDs []StashID) *StudioPartial {
partial := StudioPartial{
UpdatedAt: NewOptionalTime(time.Now()),
}
partial.ID, _ = strconv.Atoi(*id)
if s.Name != "" && !excluded["name"] {
partial.Name = NewOptionalString(s.Name)
}
if s.URL != nil && !excluded["url"] {
partial.URL = NewOptionalString(*s.URL)
}
if s.Parent != nil && !excluded["parent"] {
if s.Parent.StoredID != nil {
parentID, _ := strconv.Atoi(*s.Parent.StoredID)
if parentID > 0 {
// This is to be set directly as we know it has a value and the translator won't have the field
partial.ParentID = NewOptionalInt(parentID)
}
}
} else {
partial.ParentID = NewOptionalIntPtr(nil)
}
partial.StashIDs = &UpdateStashIDs{
StashIDs: existingStashIDs,
Mode: RelationshipUpdateModeSet,
}
partial.StashIDs.Set(StashID{
Endpoint: endpoint,
StashID: *s.RemoteSiteID,
})
return &partial
}
// A performer from a scraping operation...
type ScrapedPerformer struct {
// Set if performer matched

View file

@ -0,0 +1,65 @@
package models
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func Test_scrapedToStudioInput(t *testing.T) {
const name = "name"
url := "url"
remoteSiteID := "remoteSiteID"
tests := []struct {
name string
studio *ScrapedStudio
want *Studio
}{
{
"set all",
&ScrapedStudio{
Name: name,
URL: &url,
RemoteSiteID: &remoteSiteID,
},
&Studio{
Name: name,
URL: url,
StashIDs: NewRelatedStashIDs([]StashID{
{
StashID: remoteSiteID,
},
}),
},
},
{
"set none",
&ScrapedStudio{
Name: name,
RemoteSiteID: &remoteSiteID,
},
&Studio{
Name: name,
StashIDs: NewRelatedStashIDs([]StashID{
{
StashID: remoteSiteID,
},
}),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.studio.ToStudio("", nil)
assert.NotEqual(t, time.Time{}, got.CreatedAt)
assert.NotEqual(t, time.Time{}, got.UpdatedAt)
got.CreatedAt = time.Time{}
got.UpdatedAt = time.Time{}
assert.Equal(t, tt.want, got)
})
}
}

View file

@ -1,6 +1,7 @@
package models
import (
"context"
"time"
)
@ -15,34 +16,50 @@ type Studio struct {
Rating *int `json:"rating"`
Details string `json:"details"`
IgnoreAutoTag bool `json:"ignore_auto_tag"`
Aliases RelatedStrings `json:"aliases"`
StashIDs RelatedStashIDs `json:"stash_ids"`
}
func (s *Studio) LoadAliases(ctx context.Context, l AliasLoader) error {
return s.Aliases.load(func() ([]string, error) {
return l.GetAliases(ctx, s.ID)
})
}
func (s *Studio) LoadStashIDs(ctx context.Context, l StashIDLoader) error {
return s.StashIDs.load(func() ([]StashID, error) {
return l.GetStashIDs(ctx, s.ID)
})
}
func (s *Studio) LoadRelationships(ctx context.Context, l PerformerReader) error {
if err := s.LoadAliases(ctx, l); err != nil {
return err
}
if err := s.LoadStashIDs(ctx, l); err != nil {
return err
}
return nil
}
// StudioPartial represents part of a Studio object. It is used to update the database entry.
type StudioPartial struct {
ID int
Name OptionalString
URL OptionalString
ParentID OptionalInt
CreatedAt OptionalTime
UpdatedAt OptionalTime
// Rating expressed in 1-100 scale
Rating OptionalInt
Details OptionalString
CreatedAt OptionalTime
UpdatedAt OptionalTime
IgnoreAutoTag OptionalBool
}
func NewStudio(name string) *Studio {
currentTime := time.Now()
return &Studio{
Name: name,
CreatedAt: currentTime,
UpdatedAt: currentTime,
}
}
func NewStudioPartial() StudioPartial {
updatedTime := time.Now()
return StudioPartial{
UpdatedAt: NewOptionalTime(updatedTime),
}
Aliases *UpdateStrings
StashIDs *UpdateStashIDs
}
type Studios []*Studio

View file

@ -48,6 +48,7 @@ type StudioReader interface {
FindChildren(ctx context.Context, id int) ([]*Studio, error)
FindByName(ctx context.Context, name string, nocase bool) (*Studio, error)
FindByStashID(ctx context.Context, stashID StashID) ([]*Studio, error)
FindByStashIDStatus(ctx context.Context, hasStashID bool, stashboxEndpoint string) ([]*Studio, error)
Count(ctx context.Context) (int, error)
All(ctx context.Context) ([]*Studio, error)
// TODO - this interface is temporary until the filter schema can fully
@ -56,18 +57,16 @@ type StudioReader interface {
Query(ctx context.Context, studioFilter *StudioFilterType, findFilter *FindFilterType) ([]*Studio, int, error)
GetImage(ctx context.Context, studioID int) ([]byte, error)
HasImage(ctx context.Context, studioID int) (bool, error)
AliasLoader
StashIDLoader
GetAliases(ctx context.Context, studioID int) ([]string, error)
}
type StudioWriter interface {
Create(ctx context.Context, newStudio *Studio) error
UpdatePartial(ctx context.Context, id int, updatedStudio StudioPartial) (*Studio, error)
UpdatePartial(ctx context.Context, input StudioPartial) (*Studio, error)
Update(ctx context.Context, updatedStudio *Studio) error
Destroy(ctx context.Context, id int) error
UpdateImage(ctx context.Context, studioID int, image []byte) error
UpdateStashIDs(ctx context.Context, studioID int, stashIDs []StashID) error
UpdateAliases(ctx context.Context, studioID int, aliases []string) error
}
type StudioReaderWriter interface {

View file

@ -5,6 +5,7 @@ import (
"io"
"strconv"
"github.com/stashapp/stash/pkg/sliceutil"
"github.com/stashapp/stash/pkg/sliceutil/intslice"
)
@ -94,16 +95,7 @@ func (u *UpdateIDs) EffectiveIDs(existing []int) []int {
return nil
}
switch u.Mode {
case RelationshipUpdateModeAdd:
return intslice.IntAppendUniques(existing, u.IDs)
case RelationshipUpdateModeRemove:
return intslice.IntExclude(existing, u.IDs)
case RelationshipUpdateModeSet:
return u.IDs
}
return nil
return effectiveValues(u.IDs, u.Mode, existing)
}
type UpdateStrings struct {
@ -118,3 +110,26 @@ func (u *UpdateStrings) Strings() []string {
return u.Values
}
// GetEffectiveIDs returns the new IDs that will be effective after the update.
func (u *UpdateStrings) EffectiveValues(existing []string) []string {
if u == nil {
return nil
}
return effectiveValues(u.Values, u.Mode, existing)
}
// effectiveValues returns the new values that will be effective after the update.
func effectiveValues[T comparable](values []T, mode RelationshipUpdateMode, existing []T) []T {
switch mode {
case RelationshipUpdateModeAdd:
return sliceutil.AppendUniques(existing, values)
case RelationshipUpdateModeRemove:
return sliceutil.Exclude(existing, values)
case RelationshipUpdateModeSet:
return values
}
return nil
}

View file

@ -116,7 +116,9 @@ func (i *Importer) populateStudio(ctx context.Context) error {
}
func (i *Importer) createStudio(ctx context.Context, name string) (int, error) {
newStudio := models.NewStudio(name)
newStudio := &models.Studio{
Name: name,
}
err := i.StudioWriter.Create(ctx, newStudio)
if err != nil {

View file

@ -176,7 +176,9 @@ func (i *Importer) populateStudio(ctx context.Context) error {
}
func (i *Importer) createStudio(ctx context.Context, name string) (int, error) {
newStudio := models.NewStudio(name)
newStudio := &models.Studio{
Name: name,
}
err := i.StudioWriter.Create(ctx, newStudio)
if err != nil {

View file

@ -17,6 +17,7 @@ type StashBoxGraphQLClient interface {
SearchPerformer(ctx context.Context, term string, httpRequestOptions ...client.HTTPRequestOption) (*SearchPerformer, error)
FindPerformerByID(ctx context.Context, id string, httpRequestOptions ...client.HTTPRequestOption) (*FindPerformerByID, error)
FindSceneByID(ctx context.Context, id string, httpRequestOptions ...client.HTTPRequestOption) (*FindSceneByID, error)
FindStudio(ctx context.Context, id *string, name *string, httpRequestOptions ...client.HTTPRequestOption) (*FindStudio, error)
SubmitFingerprint(ctx context.Context, input FingerprintSubmission, httpRequestOptions ...client.HTTPRequestOption) (*SubmitFingerprint, error)
Me(ctx context.Context, httpRequestOptions ...client.HTTPRequestOption) (*Me, error)
SubmitSceneDraft(ctx context.Context, input SceneDraftInput, httpRequestOptions ...client.HTTPRequestOption) (*SubmitSceneDraft, error)
@ -128,6 +129,10 @@ type StudioFragment struct {
Name string "json:\"name\" graphql:\"name\""
ID string "json:\"id\" graphql:\"id\""
Urls []*URLFragment "json:\"urls\" graphql:\"urls\""
Parent *struct {
Name string "json:\"name\" graphql:\"name\""
ID string "json:\"id\" graphql:\"id\""
} "json:\"parent\" graphql:\"parent\""
Images []*ImageFragment "json:\"images\" graphql:\"images\""
}
type TagFragment struct {
@ -215,6 +220,9 @@ type FindPerformerByID struct {
type FindSceneByID struct {
FindScene *SceneFragment "json:\"findScene\" graphql:\"findScene\""
}
type FindStudio struct {
FindStudio *StudioFragment "json:\"findStudio\" graphql:\"findStudio\""
}
type SubmitFingerprint struct {
SubmitFingerprint bool "json:\"submitFingerprint\" graphql:\"submitFingerprint\""
}
@ -239,12 +247,77 @@ const FindSceneByFingerprintDocument = `query FindSceneByFingerprint ($fingerpri
... SceneFragment
}
}
fragment StudioFragment on Studio {
name
id
urls {
... URLFragment
}
parent {
name
id
}
images {
... ImageFragment
}
}
fragment BodyModificationFragment on BodyModification {
location
description
}
fragment MeasurementsFragment on Measurements {
band_size
cup_size
waist
hip
}
fragment SceneFragment on Scene {
id
title
code
details
director
duration
date
urls {
... URLFragment
}
images {
... ImageFragment
}
studio {
... StudioFragment
}
tags {
... TagFragment
}
performers {
... PerformerAppearanceFragment
}
fingerprints {
... FingerprintFragment
}
}
fragment URLFragment on URL {
url
type
}
fragment ImageFragment on Image {
id
url
width
height
}
fragment TagFragment on Tag {
name
id
}
fragment PerformerAppearanceFragment on PerformerAppearance {
as
performer {
... PerformerFragment
}
}
fragment PerformerFragment on Performer {
id
name
@ -279,76 +352,15 @@ fragment PerformerFragment on Performer {
... BodyModificationFragment
}
}
fragment SceneFragment on Scene {
id
title
code
details
director
duration
date
urls {
... URLFragment
}
images {
... ImageFragment
}
studio {
... StudioFragment
}
tags {
... TagFragment
}
performers {
... PerformerAppearanceFragment
}
fingerprints {
... FingerprintFragment
}
}
fragment URLFragment on URL {
url
type
}
fragment PerformerAppearanceFragment on PerformerAppearance {
as
performer {
... PerformerFragment
}
}
fragment FuzzyDateFragment on FuzzyDate {
date
accuracy
}
fragment MeasurementsFragment on Measurements {
band_size
cup_size
waist
hip
}
fragment BodyModificationFragment on BodyModification {
location
description
}
fragment FingerprintFragment on Fingerprint {
algorithm
hash
duration
}
fragment StudioFragment on Studio {
name
id
urls {
... URLFragment
}
images {
... ImageFragment
}
}
fragment TagFragment on Tag {
name
id
}
`
func (c *Client) FindSceneByFingerprint(ctx context.Context, fingerprint FingerprintQueryInput, httpRequestOptions ...client.HTTPRequestOption) (*FindSceneByFingerprint, error) {
@ -369,6 +381,49 @@ const FindScenesByFullFingerprintsDocument = `query FindScenesByFullFingerprints
... SceneFragment
}
}
fragment FuzzyDateFragment on FuzzyDate {
date
accuracy
}
fragment MeasurementsFragment on Measurements {
band_size
cup_size
waist
hip
}
fragment BodyModificationFragment on BodyModification {
location
description
}
fragment FingerprintFragment on Fingerprint {
algorithm
hash
duration
}
fragment URLFragment on URL {
url
type
}
fragment ImageFragment on Image {
id
url
width
height
}
fragment StudioFragment on Studio {
name
id
urls {
... URLFragment
}
parent {
name
id
}
images {
... ImageFragment
}
}
fragment PerformerFragment on Performer {
id
name
@ -403,16 +458,6 @@ fragment PerformerFragment on Performer {
... BodyModificationFragment
}
}
fragment FuzzyDateFragment on FuzzyDate {
date
accuracy
}
fragment MeasurementsFragment on Measurements {
band_size
cup_size
waist
hip
}
fragment SceneFragment on Scene {
id
title
@ -440,35 +485,6 @@ fragment SceneFragment on Scene {
... FingerprintFragment
}
}
fragment URLFragment on URL {
url
type
}
fragment ImageFragment on Image {
id
url
width
height
}
fragment BodyModificationFragment on BodyModification {
location
description
}
fragment FingerprintFragment on Fingerprint {
algorithm
hash
duration
}
fragment StudioFragment on Studio {
name
id
urls {
... URLFragment
}
images {
... ImageFragment
}
}
fragment TagFragment on Tag {
name
id
@ -499,28 +515,56 @@ const FindScenesBySceneFingerprintsDocument = `query FindScenesBySceneFingerprin
... SceneFragment
}
}
fragment StudioFragment on Studio {
fragment URLFragment on URL {
url
type
}
fragment TagFragment on Tag {
name
id
urls {
... URLFragment
}
images {
... ImageFragment
fragment PerformerAppearanceFragment on PerformerAppearance {
as
performer {
... PerformerFragment
}
}
fragment FuzzyDateFragment on FuzzyDate {
date
accuracy
}
fragment FingerprintFragment on Fingerprint {
algorithm
hash
duration
fragment MeasurementsFragment on Measurements {
band_size
cup_size
waist
hip
}
fragment SceneFragment on Scene {
id
title
code
details
director
duration
date
urls {
... URLFragment
}
images {
... ImageFragment
}
studio {
... StudioFragment
}
tags {
... TagFragment
}
performers {
... PerformerAppearanceFragment
}
fingerprints {
... FingerprintFragment
}
fragment URLFragment on URL {
url
type
}
fragment ImageFragment on Image {
id
@ -528,10 +572,18 @@ fragment ImageFragment on Image {
width
height
}
fragment PerformerAppearanceFragment on PerformerAppearance {
as
performer {
... PerformerFragment
fragment StudioFragment on Studio {
name
id
urls {
... URLFragment
}
parent {
name
id
}
images {
... ImageFragment
}
}
fragment PerformerFragment on Performer {
@ -568,46 +620,14 @@ fragment PerformerFragment on Performer {
... BodyModificationFragment
}
}
fragment MeasurementsFragment on Measurements {
band_size
cup_size
waist
hip
}
fragment BodyModificationFragment on BodyModification {
location
description
}
fragment SceneFragment on Scene {
id
title
code
details
director
fragment FingerprintFragment on Fingerprint {
algorithm
hash
duration
date
urls {
... URLFragment
}
images {
... ImageFragment
}
studio {
... StudioFragment
}
tags {
... TagFragment
}
performers {
... PerformerAppearanceFragment
}
fingerprints {
... FingerprintFragment
}
}
fragment TagFragment on Tag {
name
id
}
`
@ -629,6 +649,29 @@ const SearchSceneDocument = `query SearchScene ($term: String!) {
... SceneFragment
}
}
fragment ImageFragment on Image {
id
url
width
height
}
fragment TagFragment on Tag {
name
id
}
fragment FuzzyDateFragment on FuzzyDate {
date
accuracy
}
fragment BodyModificationFragment on BodyModification {
location
description
}
fragment FingerprintFragment on Fingerprint {
algorithm
hash
duration
}
fragment SceneFragment on Scene {
id
title
@ -660,32 +703,16 @@ fragment URLFragment on URL {
url
type
}
fragment TagFragment on Tag {
name
id
}
fragment FuzzyDateFragment on FuzzyDate {
date
accuracy
}
fragment MeasurementsFragment on Measurements {
band_size
cup_size
waist
hip
}
fragment ImageFragment on Image {
id
url
width
height
}
fragment StudioFragment on Studio {
name
id
urls {
... URLFragment
}
parent {
name
id
}
images {
... ImageFragment
}
@ -730,14 +757,11 @@ fragment PerformerFragment on Performer {
... BodyModificationFragment
}
}
fragment BodyModificationFragment on BodyModification {
location
description
}
fragment FingerprintFragment on Fingerprint {
algorithm
hash
duration
fragment MeasurementsFragment on Measurements {
band_size
cup_size
waist
hip
}
`
@ -759,16 +783,6 @@ const SearchPerformerDocument = `query SearchPerformer ($term: String!) {
... PerformerFragment
}
}
fragment FuzzyDateFragment on FuzzyDate {
date
accuracy
}
fragment MeasurementsFragment on Measurements {
band_size
cup_size
waist
hip
}
fragment BodyModificationFragment on BodyModification {
location
description
@ -817,6 +831,16 @@ fragment ImageFragment on Image {
width
height
}
fragment FuzzyDateFragment on FuzzyDate {
date
accuracy
}
fragment MeasurementsFragment on Measurements {
band_size
cup_size
waist
hip
}
`
func (c *Client) SearchPerformer(ctx context.Context, term string, httpRequestOptions ...client.HTTPRequestOption) (*SearchPerformer, error) {
@ -915,26 +939,25 @@ const FindSceneByIDDocument = `query FindSceneByID ($id: ID!) {
... SceneFragment
}
}
fragment FingerprintFragment on Fingerprint {
algorithm
hash
duration
}
fragment URLFragment on URL {
fragment ImageFragment on Image {
id
url
type
width
height
}
fragment PerformerAppearanceFragment on PerformerAppearance {
as
performer {
... PerformerFragment
fragment StudioFragment on Studio {
name
id
urls {
... URLFragment
}
parent {
name
id
}
images {
... ImageFragment
}
fragment MeasurementsFragment on Measurements {
band_size
cup_size
waist
hip
}
fragment TagFragment on Tag {
name
@ -974,13 +997,11 @@ fragment PerformerFragment on Performer {
... BodyModificationFragment
}
}
fragment FuzzyDateFragment on FuzzyDate {
date
accuracy
}
fragment BodyModificationFragment on BodyModification {
location
description
fragment MeasurementsFragment on Measurements {
band_size
cup_size
waist
hip
}
fragment SceneFragment on Scene {
id
@ -1009,21 +1030,28 @@ fragment SceneFragment on Scene {
... FingerprintFragment
}
}
fragment ImageFragment on Image {
id
fragment URLFragment on URL {
url
width
height
type
}
fragment StudioFragment on Studio {
name
id
urls {
... URLFragment
fragment BodyModificationFragment on BodyModification {
location
description
}
images {
... ImageFragment
fragment FingerprintFragment on Fingerprint {
algorithm
hash
duration
}
fragment PerformerAppearanceFragment on PerformerAppearance {
as
performer {
... PerformerFragment
}
}
fragment FuzzyDateFragment on FuzzyDate {
date
accuracy
}
`
@ -1040,6 +1068,51 @@ func (c *Client) FindSceneByID(ctx context.Context, id string, httpRequestOption
return &res, nil
}
const FindStudioDocument = `query FindStudio ($id: ID, $name: String) {
findStudio(id: $id, name: $name) {
... StudioFragment
}
}
fragment StudioFragment on Studio {
name
id
urls {
... URLFragment
}
parent {
name
id
}
images {
... ImageFragment
}
}
fragment URLFragment on URL {
url
type
}
fragment ImageFragment on Image {
id
url
width
height
}
`
func (c *Client) FindStudio(ctx context.Context, id *string, name *string, httpRequestOptions ...client.HTTPRequestOption) (*FindStudio, error) {
vars := map[string]interface{}{
"id": id,
"name": name,
}
var res FindStudio
if err := c.Client.Post(ctx, "FindStudio", FindStudioDocument, &res, vars, httpRequestOptions...); err != nil {
return nil, err
}
return &res, nil
}
const SubmitFingerprintDocument = `mutation SubmitFingerprint ($input: FingerprintSubmission!) {
submitFingerprint(input: $input)
}

View file

@ -88,9 +88,9 @@ type DraftEntity struct {
ID *string `json:"id,omitempty"`
}
func (DraftEntity) IsSceneDraftPerformer() {}
func (DraftEntity) IsSceneDraftStudio() {}
func (DraftEntity) IsSceneDraftTag() {}
func (DraftEntity) IsSceneDraftStudio() {}
func (DraftEntity) IsSceneDraftPerformer() {}
type DraftEntityInput struct {
Name string `json:"name"`
@ -116,6 +116,7 @@ type Edit struct {
// Objects to merge with the target. Only applicable to merges
MergeSources []EditTarget `json:"merge_sources,omitempty"`
Operation OperationEnum `json:"operation"`
Bot bool `json:"bot"`
Details EditDetails `json:"details,omitempty"`
// Previous state of fields being modified - null if operation is create or delete.
OldDetails EditDetails `json:"old_details,omitempty"`
@ -154,6 +155,8 @@ type EditInput struct {
// Only required for merge type
MergeSourceIds []string `json:"merge_source_ids,omitempty"`
Comment *string `json:"comment,omitempty"`
// Edit submitted by an automated script. Requires bot permission
Bot *bool `json:"bot,omitempty"`
}
type EditQueryInput struct {
@ -173,6 +176,10 @@ type EditQueryInput struct {
TargetID *string `json:"target_id,omitempty"`
// Filter by favorite status
IsFavorite *bool `json:"is_favorite,omitempty"`
// Filter by user voted status
Voted *UserVotedFilterEnum `json:"voted,omitempty"`
// Filter to bot edits only
IsBot *bool `json:"is_bot,omitempty"`
Page int `json:"page"`
PerPage int `json:"per_page"`
Direction SortDirectionEnum `json:"direction"`
@ -543,12 +550,25 @@ type PerformerQueryInput struct {
Piercings *BodyModificationCriterionInput `json:"piercings,omitempty"`
// Filter by performerfavorite status for the current user
IsFavorite *bool `json:"is_favorite,omitempty"`
// Filter by a performer they have performed in scenes with
PerformedWith *string `json:"performed_with,omitempty"`
// Filter by a studio
StudioID *string `json:"studio_id,omitempty"`
Page int `json:"page"`
PerPage int `json:"per_page"`
Direction SortDirectionEnum `json:"direction"`
Sort PerformerSortEnum `json:"sort"`
}
type PerformerScenesInput struct {
// Filter by another performer that also performs in the scenes
PerformedWith *string `json:"performed_with,omitempty"`
// Filter by a studio
StudioID *string `json:"studio_id,omitempty"`
// Filter by tags
Tags *MultiIDCriterionInput `json:"tags,omitempty"`
}
type PerformerStudio struct {
Studio *Studio `json:"studio,omitempty"`
SceneCount int `json:"scene_count"`
@ -689,7 +709,9 @@ type SceneDestroyInput struct {
type SceneDraft struct {
ID *string `json:"id,omitempty"`
Title *string `json:"title,omitempty"`
Code *string `json:"code,omitempty"`
Details *string `json:"details,omitempty"`
Director *string `json:"director,omitempty"`
URL *URL `json:"url,omitempty"`
Date *string `json:"date,omitempty"`
Studio SceneDraftStudio `json:"studio,omitempty"`
@ -775,6 +797,8 @@ type SceneQueryInput struct {
Fingerprints *MultiStringCriterionInput `json:"fingerprints,omitempty"`
// Filter by favorited entity
Favorites *FavoriteFilter `json:"favorites,omitempty"`
// Filter to scenes with fingerprints submitted by the user
HasFingerprintSubmissions *bool `json:"has_fingerprint_submissions,omitempty"`
Page int `json:"page"`
PerPage int `json:"per_page"`
Direction SortDirectionEnum `json:"direction"`
@ -857,6 +881,7 @@ type Studio struct {
IsFavorite bool `json:"is_favorite"`
Created time.Time `json:"created"`
Updated time.Time `json:"updated"`
Performers *QueryPerformersResultType `json:"performers,omitempty"`
}
func (Studio) IsSceneDraftStudio() {}
@ -1775,6 +1800,7 @@ const (
PerformerSortEnumOCounter PerformerSortEnum = "O_COUNTER"
PerformerSortEnumCareerStartYear PerformerSortEnum = "CAREER_START_YEAR"
PerformerSortEnumDebut PerformerSortEnum = "DEBUT"
PerformerSortEnumLastScene PerformerSortEnum = "LAST_SCENE"
PerformerSortEnumCreatedAt PerformerSortEnum = "CREATED_AT"
PerformerSortEnumUpdatedAt PerformerSortEnum = "UPDATED_AT"
)
@ -1786,6 +1812,7 @@ var AllPerformerSortEnum = []PerformerSortEnum{
PerformerSortEnumOCounter,
PerformerSortEnumCareerStartYear,
PerformerSortEnumDebut,
PerformerSortEnumLastScene,
PerformerSortEnumCreatedAt,
PerformerSortEnumUpdatedAt,
}
@ -2136,6 +2163,51 @@ func (e TargetTypeEnum) MarshalGQL(w io.Writer) {
fmt.Fprint(w, strconv.Quote(e.String()))
}
type UserVotedFilterEnum string
const (
UserVotedFilterEnumAbstain UserVotedFilterEnum = "ABSTAIN"
UserVotedFilterEnumAccept UserVotedFilterEnum = "ACCEPT"
UserVotedFilterEnumReject UserVotedFilterEnum = "REJECT"
UserVotedFilterEnumNotVoted UserVotedFilterEnum = "NOT_VOTED"
)
var AllUserVotedFilterEnum = []UserVotedFilterEnum{
UserVotedFilterEnumAbstain,
UserVotedFilterEnumAccept,
UserVotedFilterEnumReject,
UserVotedFilterEnumNotVoted,
}
func (e UserVotedFilterEnum) IsValid() bool {
switch e {
case UserVotedFilterEnumAbstain, UserVotedFilterEnumAccept, UserVotedFilterEnumReject, UserVotedFilterEnumNotVoted:
return true
}
return false
}
func (e UserVotedFilterEnum) String() string {
return string(e)
}
func (e *UserVotedFilterEnum) UnmarshalGQL(v interface{}) error {
str, ok := v.(string)
if !ok {
return fmt.Errorf("enums must be strings")
}
*e = UserVotedFilterEnum(str)
if !e.IsValid() {
return fmt.Errorf("%s is not a valid UserVotedFilterEnum", str)
}
return nil
}
func (e UserVotedFilterEnum) MarshalGQL(w io.Writer) {
fmt.Fprint(w, strconv.Quote(e.String()))
}
type ValidSiteTypeEnum string
const (

View file

@ -2,6 +2,11 @@ package stashbox
import "github.com/stashapp/stash/pkg/models"
type StashBoxStudioQueryResult struct {
Query string `json:"query"`
Results []*models.ScrapedStudio `json:"results"`
}
type StashBoxPerformerQueryResult struct {
Query string `json:"query"`
Results []*models.ScrapedPerformer `json:"results"`

View file

@ -18,6 +18,7 @@ import (
"golang.org/x/text/language"
"github.com/Yamashou/gqlgenc/graphqljson"
"github.com/gofrs/uuid"
"github.com/stashapp/stash/pkg/file"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/match"
@ -660,6 +661,26 @@ func performerFragmentToScrapedScenePerformer(p graphql.PerformerFragment) *mode
return sp
}
func studioFragmentToScrapedStudio(s graphql.StudioFragment) *models.ScrapedStudio {
images := []string{}
for _, image := range s.Images {
images = append(images, image.URL)
}
st := &models.ScrapedStudio{
Name: s.Name,
URL: findURL(s.Urls, "HOME"),
Images: images,
RemoteSiteID: &s.ID,
}
if len(st.Images) > 0 {
st.Image = &st.Images[0]
}
return st
}
func getFirstImage(ctx context.Context, client *http.Client, images []*graphql.ImageFragment) *string {
ret, err := fetchImage(ctx, client, images[0].URL)
if err != nil && !errors.Is(err, context.Canceled) {
@ -725,20 +746,29 @@ func (c Client) sceneFragmentToScrapedScene(ctx context.Context, s *graphql.Scen
tqb := c.repository.Tag
if s.Studio != nil {
studioID := s.Studio.ID
ss.Studio = &models.ScrapedStudio{
Name: s.Studio.Name,
URL: findURL(s.Studio.Urls, "HOME"),
RemoteSiteID: &studioID,
}
if s.Studio.Images != nil && len(s.Studio.Images) > 0 {
ss.Studio.Image = &s.Studio.Images[0].URL
}
ss.Studio = studioFragmentToScrapedStudio(*s.Studio)
err := match.ScrapedStudio(ctx, c.repository.Studio, ss.Studio, &c.box.Endpoint)
if err != nil {
return err
}
var parentStudio *graphql.FindStudio
if s.Studio.Parent != nil {
parentStudio, err = c.client.FindStudio(ctx, &s.Studio.Parent.ID, nil)
if err != nil {
return err
}
if parentStudio.FindStudio != nil {
ss.Studio.Parent = studioFragmentToScrapedStudio(*parentStudio.FindStudio)
err = match.ScrapedStudio(ctx, c.repository.Studio, ss.Studio.Parent, &c.box.Endpoint)
if err != nil {
return err
}
}
}
}
for _, p := range s.Performers {
@ -799,6 +829,56 @@ func (c Client) FindStashBoxPerformerByName(ctx context.Context, name string) (*
return ret, nil
}
func (c Client) FindStashBoxStudio(ctx context.Context, query string) (*models.ScrapedStudio, error) {
var studio *graphql.FindStudio
_, err := uuid.FromString(query)
if err == nil {
// Confirmed the user passed in a Stash ID
studio, err = c.client.FindStudio(ctx, &query, nil)
} else {
// Otherwise assume they're searching on a name
studio, err = c.client.FindStudio(ctx, nil, &query)
}
if err != nil {
return nil, err
}
var ret *models.ScrapedStudio
if studio.FindStudio != nil {
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
ret = studioFragmentToScrapedStudio(*studio.FindStudio)
err = match.ScrapedStudio(ctx, c.repository.Studio, ret, &c.box.Endpoint)
if err != nil {
return err
}
if studio.FindStudio.Parent != nil {
parentStudio, err := c.client.FindStudio(ctx, &studio.FindStudio.Parent.ID, nil)
if err != nil {
return err
}
if parentStudio.FindStudio != nil {
ret.Parent = studioFragmentToScrapedStudio(*parentStudio.FindStudio)
err = match.ScrapedStudio(ctx, c.repository.Studio, ret.Parent, &c.box.Endpoint)
if err != nil {
return err
}
}
}
return nil
}); err != nil {
return nil, err
}
}
return ret, nil
}
func (c Client) GetUser(ctx context.Context) (*graphql.Me, error) {
return c.client.Me(ctx)
}

View file

@ -438,21 +438,6 @@ func (r *stashIDRepository) get(ctx context.Context, id int) ([]models.StashID,
return []models.StashID(ret), err
}
func (r *stashIDRepository) replace(ctx context.Context, id int, newIDs []models.StashID) error {
if err := r.destroy(ctx, []int{id}); err != nil {
return err
}
query := fmt.Sprintf("INSERT INTO %s (%s, endpoint, stash_id) VALUES (?, ?, ?)", r.tableName, r.idColumn)
for _, stashID := range newIDs {
_, err := r.tx.Exec(ctx, query, id, stashID.Endpoint, stashID.StashID)
if err != nil {
return err
}
}
return nil
}
type filesRepository struct {
repository
}

View file

@ -631,7 +631,7 @@ func populateDB() error {
return fmt.Errorf("error creating performers: %s", err.Error())
}
if err := createStudios(ctx, db.Studio, studiosNameCase, studiosNameNoCase); err != nil {
if err := createStudios(ctx, studiosNameCase, studiosNameNoCase); err != nil {
return fmt.Errorf("error creating studios: %s", err.Error())
}
@ -659,7 +659,7 @@ func populateDB() error {
return fmt.Errorf("error linking movie studios: %s", err.Error())
}
if err := linkStudiosParent(ctx, db.Studio); err != nil {
if err := linkStudiosParent(ctx); err != nil {
return fmt.Errorf("error linking studios parent: %s", err.Error())
}
@ -1573,7 +1573,7 @@ func getStudioNullStringValue(index int, field string) string {
return ret.String
}
func createStudio(ctx context.Context, sqb models.StudioReaderWriter, name string, parentID *int) (*models.Studio, error) {
func createStudio(ctx context.Context, sqb *sqlite.StudioStore, name string, parentID *int) (*models.Studio, error) {
studio := models.Studio{
Name: name,
}
@ -1590,7 +1590,7 @@ func createStudio(ctx context.Context, sqb models.StudioReaderWriter, name strin
return &studio, nil
}
func createStudioFromModel(ctx context.Context, sqb models.StudioReaderWriter, studio *models.Studio) error {
func createStudioFromModel(ctx context.Context, sqb *sqlite.StudioStore, studio *models.Studio) error {
err := sqb.Create(ctx, studio)
if err != nil {
@ -1601,7 +1601,8 @@ func createStudioFromModel(ctx context.Context, sqb models.StudioReaderWriter, s
}
// createStudios creates n studios with plain Name and o studios with camel cased NaMe included
func createStudios(ctx context.Context, sqb models.StudioReaderWriter, n int, o int) error {
func createStudios(ctx context.Context, n int, o int) error {
sqb := db.Studio
const namePlain = "Name"
const nameNoCase = "NaMe"
@ -1618,22 +1619,18 @@ func createStudios(ctx context.Context, sqb models.StudioReaderWriter, n int, o
name = getStudioStringValue(index, name)
studio := models.Studio{
Name: name,
URL: getStudioNullStringValue(index, urlField),
URL: getStudioStringValue(index, urlField),
IgnoreAutoTag: getIgnoreAutoTag(i),
}
err := createStudioFromModel(ctx, sqb, &studio)
if err != nil {
return err
}
// add alias
// only add aliases for some scenes
if i == studioIdxWithMovie || i%5 == 0 {
alias := getStudioStringValue(i, "Alias")
if err := sqb.UpdateAliases(ctx, studio.ID, []string{alias}); err != nil {
return fmt.Errorf("error setting studio alias: %s", err.Error())
studio.Aliases = models.NewRelatedStrings([]string{alias})
}
err := createStudioFromModel(ctx, sqb, &studio)
if err != nil {
return err
}
studioIDs = append(studioIDs, studio.ID)
@ -1756,12 +1753,14 @@ func linkMovieStudios(ctx context.Context, mqb models.MovieWriter) error {
})
}
func linkStudiosParent(ctx context.Context, qb models.StudioWriter) error {
func linkStudiosParent(ctx context.Context) error {
qb := db.Studio
return doLinks(studioParentLinks, func(parentIndex, childIndex int) error {
studio := models.StudioPartial{
input := &models.StudioPartial{
ID: studioIDs[childIndex],
ParentID: models.NewOptionalInt(studioIDs[parentIndex]),
}
_, err := qb.UpdatePartial(ctx, studioIDs[childIndex], studio)
_, err := qb.UpdatePartial(ctx, *input)
return err
})

View file

@ -5,7 +5,6 @@ import (
"database/sql"
"errors"
"fmt"
"strings"
"github.com/doug-martin/goqu/v9"
"github.com/doug-martin/goqu/v9/exp"
@ -15,6 +14,7 @@ import (
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/sliceutil/intslice"
"github.com/stashapp/stash/pkg/studio"
)
const (
@ -22,7 +22,8 @@ const (
studioIDColumn = "studio_id"
studioAliasesTable = "studio_aliases"
studioAliasColumn = "alias"
studioParentIDColumn = "parent_id"
studioNameColumn = "name"
studioImageBlobColumn = "image_blob"
)
@ -39,7 +40,7 @@ type studioRow struct {
IgnoreAutoTag bool `db:"ignore_auto_tag"`
// not used in resolutions or updates
CoverBlob zero.String `db:"image_blob"`
ImageBlob zero.String `db:"image_blob"`
}
func (r *studioRow) fromStudio(o models.Studio) {
@ -116,6 +117,8 @@ func (qb *StudioStore) selectDataset() *goqu.SelectDataset {
}
func (qb *StudioStore) Create(ctx context.Context, newObject *models.Studio) error {
var err error
var r studioRow
r.fromStudio(*newObject)
@ -124,34 +127,66 @@ func (qb *StudioStore) Create(ctx context.Context, newObject *models.Studio) err
return err
}
updated, err := qb.find(ctx, id)
if newObject.Aliases.Loaded() {
if err := studio.EnsureAliasesUnique(ctx, id, newObject.Aliases.List(), qb); err != nil {
return err
}
if err := studiosAliasesTableMgr.insertJoins(ctx, id, newObject.Aliases.List()); err != nil {
return err
}
}
if newObject.StashIDs.Loaded() {
if err := studiosStashIDsTableMgr.insertJoins(ctx, id, newObject.StashIDs.List()); err != nil {
return err
}
}
updated, _ := qb.find(ctx, id)
if err != nil {
return fmt.Errorf("finding after create: %w", err)
}
*newObject = *updated
return nil
}
func (qb *StudioStore) UpdatePartial(ctx context.Context, id int, partial models.StudioPartial) (*models.Studio, error) {
func (qb *StudioStore) UpdatePartial(ctx context.Context, input models.StudioPartial) (*models.Studio, error) {
r := studioRowRecord{
updateRecord{
Record: make(exp.Record),
},
}
r.fromPartial(partial)
r.fromPartial(input)
if len(r.Record) > 0 {
if err := qb.tableMgr.updateByID(ctx, id, r.Record); err != nil {
if err := qb.tableMgr.updateByID(ctx, input.ID, r.Record); err != nil {
return nil, err
}
}
return qb.find(ctx, id)
if input.Aliases != nil {
if err := studio.EnsureAliasesUnique(ctx, input.ID, input.Aliases.Values, qb); err != nil {
return nil, err
}
if err := studiosAliasesTableMgr.modifyJoins(ctx, input.ID, input.Aliases.Values, input.Aliases.Mode); err != nil {
return nil, err
}
}
if input.StashIDs != nil {
if err := studiosStashIDsTableMgr.modifyJoins(ctx, input.ID, input.StashIDs.StashIDs, input.StashIDs.Mode); err != nil {
return nil, err
}
}
return qb.Find(ctx, input.ID)
}
// This is only used by the Import/Export functionality
func (qb *StudioStore) Update(ctx context.Context, updatedObject *models.Studio) error {
var r studioRow
r.fromStudio(*updatedObject)
@ -160,6 +195,18 @@ func (qb *StudioStore) Update(ctx context.Context, updatedObject *models.Studio)
return err
}
if updatedObject.Aliases.Loaded() {
if err := studiosAliasesTableMgr.replaceJoins(ctx, updatedObject.ID, updatedObject.Aliases.List()); err != nil {
return err
}
}
if updatedObject.StashIDs.Loaded() {
if err := studiosStashIDsTableMgr.replaceJoins(ctx, updatedObject.ID, updatedObject.StashIDs.List()); err != nil {
return err
}
}
return nil
}
@ -257,10 +304,22 @@ func (qb *StudioStore) getMany(ctx context.Context, q *goqu.SelectDataset) ([]*m
return ret, nil
}
func (qb *StudioStore) findBySubquery(ctx context.Context, sq *goqu.SelectDataset) ([]*models.Studio, error) {
table := qb.table()
q := qb.selectDataset().Where(
table.Col(idColumn).Eq(
sq,
),
)
return qb.getMany(ctx, q)
}
func (qb *StudioStore) FindChildren(ctx context.Context, id int) ([]*models.Studio, error) {
// SELECT studios.* FROM studios WHERE studios.parent_id = ?
table := qb.table()
sq := qb.selectDataset().Where(table.Col("parent_id").Eq(id))
sq := qb.selectDataset().Where(table.Col(studioParentIDColumn).Eq(id))
ret, err := qb.getMany(ctx, sq)
if err != nil {
@ -309,13 +368,44 @@ func (qb *StudioStore) FindByName(ctx context.Context, name string, nocase bool)
}
func (qb *StudioStore) FindByStashID(ctx context.Context, stashID models.StashID) ([]*models.Studio, error) {
query := selectAll("studios") + `
LEFT JOIN studio_stash_ids on studio_stash_ids.studio_id = studios.id
WHERE studio_stash_ids.stash_id = ?
AND studio_stash_ids.endpoint = ?
`
args := []interface{}{stashID.StashID, stashID.Endpoint}
return qb.queryStudios(ctx, query, args)
sq := dialect.From(studiosStashIDsJoinTable).Select(studiosStashIDsJoinTable.Col(studioIDColumn)).Where(
studiosStashIDsJoinTable.Col("stash_id").Eq(stashID.StashID),
studiosStashIDsJoinTable.Col("endpoint").Eq(stashID.Endpoint),
)
ret, err := qb.findBySubquery(ctx, sq)
if err != nil {
return nil, fmt.Errorf("getting studios for stash ID %s: %w", stashID.StashID, err)
}
return ret, nil
}
func (qb *StudioStore) FindByStashIDStatus(ctx context.Context, hasStashID bool, stashboxEndpoint string) ([]*models.Studio, error) {
table := qb.table()
sq := dialect.From(table).LeftJoin(
studiosStashIDsJoinTable,
goqu.On(table.Col(idColumn).Eq(studiosStashIDsJoinTable.Col(studioIDColumn))),
).Select(table.Col(idColumn))
if hasStashID {
sq = sq.Where(
studiosStashIDsJoinTable.Col("stash_id").IsNotNull(),
studiosStashIDsJoinTable.Col("endpoint").Eq(stashboxEndpoint),
)
} else {
sq = sq.Where(
studiosStashIDsJoinTable.Col("stash_id").IsNull(),
)
}
ret, err := qb.findBySubquery(ctx, sq)
if err != nil {
return nil, fmt.Errorf("getting studios for stash-box endpoint %s: %w", stashboxEndpoint, err)
}
return ret, nil
}
func (qb *StudioStore) Count(ctx context.Context) (int, error) {
@ -325,38 +415,37 @@ func (qb *StudioStore) Count(ctx context.Context) (int, error) {
func (qb *StudioStore) All(ctx context.Context) ([]*models.Studio, error) {
table := qb.table()
return qb.getMany(ctx, qb.selectDataset().Order(
table.Col("name").Asc(),
table.Col(idColumn).Asc(),
))
return qb.getMany(ctx, qb.selectDataset().Order(table.Col(studioNameColumn).Asc()))
}
func (qb *StudioStore) QueryForAutoTag(ctx context.Context, words []string) ([]*models.Studio, error) {
// TODO - Query needs to be changed to support queries of this type, and
// this method should be removed
query := selectAll(studioTable)
query += " LEFT JOIN studio_aliases ON studio_aliases.studio_id = studios.id"
table := qb.table()
sq := dialect.From(table).Select(table.Col(idColumn)).LeftJoin(
studiosAliasesJoinTable,
goqu.On(studiosAliasesJoinTable.Col(studioIDColumn).Eq(table.Col(idColumn))),
)
var whereClauses []string
var args []interface{}
var whereClauses []exp.Expression
for _, w := range words {
ww := w + "%"
whereClauses = append(whereClauses, "studios.name like ?")
args = append(args, ww)
// include aliases
whereClauses = append(whereClauses, "studio_aliases.alias like ?")
args = append(args, ww)
whereClauses = append(whereClauses, table.Col(studioNameColumn).Like(w+"%"))
whereClauses = append(whereClauses, studiosAliasesJoinTable.Col("alias").Like(w+"%"))
}
whereOr := "(" + strings.Join(whereClauses, " OR ") + ")"
where := strings.Join([]string{
"studios.ignore_auto_tag = 0",
whereOr,
}, " AND ")
return qb.queryStudios(ctx, query+" WHERE "+where, args)
sq = sq.Where(
goqu.Or(whereClauses...),
table.Col("ignore_auto_tag").Eq(0),
)
ret, err := qb.findBySubquery(ctx, sq)
if err != nil {
return nil, fmt.Errorf("getting performers for autotag: %w", err)
}
return ret, nil
}
func (qb *StudioStore) validateFilter(filter *models.StudioFilterType) error {
@ -430,13 +519,13 @@ func (qb *StudioStore) makeFilter(ctx context.Context, studioFilter *models.Stud
query.handleCriterion(ctx, studioGalleryCountCriterionHandler(qb, studioFilter.GalleryCount))
query.handleCriterion(ctx, studioParentCriterionHandler(qb, studioFilter.Parents))
query.handleCriterion(ctx, studioAliasCriterionHandler(qb, studioFilter.Aliases))
query.handleCriterion(ctx, timestampCriterionHandler(studioFilter.CreatedAt, "studios.created_at"))
query.handleCriterion(ctx, timestampCriterionHandler(studioFilter.UpdatedAt, "studios.updated_at"))
query.handleCriterion(ctx, timestampCriterionHandler(studioFilter.CreatedAt, studioTable+".created_at"))
query.handleCriterion(ctx, timestampCriterionHandler(studioFilter.UpdatedAt, studioTable+".updated_at"))
return query
}
func (qb *StudioStore) Query(ctx context.Context, studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) ([]*models.Studio, int, error) {
func (qb *StudioStore) makeQuery(ctx context.Context, studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) (*queryBuilder, error) {
if studioFilter == nil {
studioFilter = &models.StudioFilterType{}
}
@ -450,20 +539,29 @@ func (qb *StudioStore) Query(ctx context.Context, studioFilter *models.StudioFil
if q := findFilter.Q; q != nil && *q != "" {
query.join(studioAliasesTable, "", "studio_aliases.studio_id = studios.id")
searchColumns := []string{"studios.name", "studio_aliases.alias"}
query.parseQueryString(searchColumns, *q)
}
if err := qb.validateFilter(studioFilter); err != nil {
return nil, 0, err
return nil, err
}
filter := qb.makeFilter(ctx, studioFilter)
if err := query.addFilter(filter); err != nil {
return nil, 0, err
return nil, err
}
query.sortAndPagination = qb.getStudioSort(findFilter) + getPagination(findFilter)
return &query, nil
}
func (qb *StudioStore) Query(ctx context.Context, studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) ([]*models.Studio, int, error) {
query, err := qb.makeQuery(ctx, studioFilter, findFilter)
if err != nil {
return nil, 0, err
}
idsResult, countResult, err := query.executeFind(ctx)
if err != nil {
return nil, 0, err
@ -546,7 +644,7 @@ func studioAliasCriterionHandler(qb *StudioStore, alias *models.StringCriterionI
joinTable: studioAliasesTable,
stringColumn: studioAliasColumn,
addJoinTable: func(f *filterBuilder) {
qb.aliasRepository().join(f, "", "studios.id")
studiosAliasesTableMgr.join(f, "", "studios.id")
},
}
@ -581,26 +679,6 @@ func (qb *StudioStore) getStudioSort(findFilter *models.FindFilterType) string {
return sortQuery
}
func (qb *StudioStore) queryStudios(ctx context.Context, query string, args []interface{}) ([]*models.Studio, error) {
const single = false
var ret []*models.Studio
if err := qb.queryFunc(ctx, query, args, single, func(r *sqlx.Rows) error {
var f studioRow
if err := r.StructScan(&f); err != nil {
return err
}
s := f.resolve()
ret = append(ret, s)
return nil
}); err != nil {
return nil, err
}
return ret, nil
}
func (qb *StudioStore) GetImage(ctx context.Context, studioID int) ([]byte, error) {
return qb.blobJoinQueryBuilder.GetImage(ctx, studioID, studioImageBlobColumn)
}
@ -628,28 +706,9 @@ func (qb *StudioStore) stashIDRepository() *stashIDRepository {
}
func (qb *StudioStore) GetStashIDs(ctx context.Context, studioID int) ([]models.StashID, error) {
return qb.stashIDRepository().get(ctx, studioID)
}
func (qb *StudioStore) UpdateStashIDs(ctx context.Context, studioID int, stashIDs []models.StashID) error {
return qb.stashIDRepository().replace(ctx, studioID, stashIDs)
}
func (qb *StudioStore) aliasRepository() *stringRepository {
return &stringRepository{
repository: repository{
tx: qb.tx,
tableName: studioAliasesTable,
idColumn: studioIDColumn,
},
stringColumn: studioAliasColumn,
}
return studiosStashIDsTableMgr.get(ctx, studioID)
}
func (qb *StudioStore) GetAliases(ctx context.Context, studioID int) ([]string, error) {
return qb.aliasRepository().get(ctx, studioID)
}
func (qb *StudioStore) UpdateAliases(ctx context.Context, studioID int, aliases []string) error {
return qb.aliasRepository().replace(ctx, studioID, aliases)
return studiosAliasesTableMgr.get(ctx, studioID)
}

View file

@ -219,18 +219,15 @@ func TestStudioQueryForAutoTag(t *testing.T) {
assert.Len(t, studios, 1)
assert.Equal(t, strings.ToLower(studioNames[studioIdxWithMovie]), strings.ToLower(studios[0].Name))
// find by alias
name = getStudioStringValue(studioIdxWithMovie, "Alias")
studios, err = tqb.QueryForAutoTag(ctx, []string{name})
if err != nil {
t.Errorf("Error finding studios: %s", err.Error())
}
if assert.Len(t, studios, 1) {
assert.Equal(t, studioIDs[studioIdxWithMovie], studios[0].ID)
}
return nil
})
}
@ -363,11 +360,12 @@ func TestStudioUpdateClearParent(t *testing.T) {
sqb := db.Studio
// clear the parent id from the child
updatePartial := models.StudioPartial{
input := models.StudioPartial{
ID: createdChild.ID,
ParentID: models.NewOptionalIntPtr(nil),
}
updatedStudio, err := sqb.UpdatePartial(ctx, createdChild.ID, updatePartial)
updatedStudio, err := sqb.UpdatePartial(ctx, input)
if err != nil {
return fmt.Errorf("Error updated studio: %s", err.Error())
@ -548,7 +546,7 @@ func verifyStudiosGalleryCount(t *testing.T, galleryCountCriterion models.IntCri
}
func TestStudioStashIDs(t *testing.T) {
if err := withTxn(func(ctx context.Context) error {
if err := withRollbackTxn(func(ctx context.Context) error {
qb := db.Studio
// create studio to test against
@ -558,13 +556,83 @@ func TestStudioStashIDs(t *testing.T) {
return fmt.Errorf("Error creating studio: %s", err.Error())
}
testStashIDReaderWriter(ctx, t, qb, created.ID)
studio, err := qb.Find(ctx, created.ID)
if err != nil {
return fmt.Errorf("Error getting studio: %s", err.Error())
}
if err := studio.LoadStashIDs(ctx, qb); err != nil {
return err
}
testStudioStashIDs(ctx, t, studio)
return nil
}); err != nil {
t.Error(err.Error())
}
}
func testStudioStashIDs(ctx context.Context, t *testing.T, s *models.Studio) {
qb := db.Studio
if err := s.LoadStashIDs(ctx, qb); err != nil {
t.Error(err.Error())
return
}
// ensure no stash IDs to begin with
assert.Len(t, s.StashIDs.List(), 0)
// add stash ids
const stashIDStr = "stashID"
const endpoint = "endpoint"
stashID := models.StashID{
StashID: stashIDStr,
Endpoint: endpoint,
}
// update stash ids and ensure was updated
input := models.StudioPartial{
ID: s.ID,
StashIDs: &models.UpdateStashIDs{
StashIDs: []models.StashID{stashID},
Mode: models.RelationshipUpdateModeSet,
},
}
var err error
s, err = qb.UpdatePartial(ctx, input)
if err != nil {
t.Error(err.Error())
}
if err := s.LoadStashIDs(ctx, qb); err != nil {
t.Error(err.Error())
return
}
assert.Equal(t, []models.StashID{stashID}, s.StashIDs.List())
// remove stash ids and ensure was updated
input = models.StudioPartial{
ID: s.ID,
StashIDs: &models.UpdateStashIDs{
StashIDs: []models.StashID{stashID},
Mode: models.RelationshipUpdateModeRemove,
},
}
s, err = qb.UpdatePartial(ctx, input)
if err != nil {
t.Error(err.Error())
}
if err := s.LoadStashIDs(ctx, qb); err != nil {
t.Error(err.Error())
return
}
assert.Len(t, s.StashIDs.List(), 0)
}
func TestStudioQueryURL(t *testing.T) {
const sceneIdx = 1
studioURL := getStudioStringValue(sceneIdx, urlField)
@ -684,7 +752,7 @@ func TestStudioQueryIsMissingRating(t *testing.T) {
assert.True(t, len(studios) > 0)
for _, studio := range studios {
assert.True(t, studio.Rating == nil)
assert.Nil(t, studio.Rating)
}
return nil
@ -778,36 +846,87 @@ func TestStudioQueryAlias(t *testing.T) {
verifyStudioQuery(t, studioFilter, verifyFn)
}
func TestStudioUpdateAlias(t *testing.T) {
if err := withTxn(func(ctx context.Context) error {
func TestStudioAlias(t *testing.T) {
if err := withRollbackTxn(func(ctx context.Context) error {
qb := db.Studio
// create studio to test against
const name = "TestStudioUpdateAlias"
created, err := createStudio(ctx, qb, name, nil)
const name = "TestStudioAlias"
created, err := createStudio(ctx, db.Studio, name, nil)
if err != nil {
return fmt.Errorf("Error creating studio: %s", err.Error())
}
aliases := []string{"alias1", "alias2"}
err = qb.UpdateAliases(ctx, created.ID, aliases)
studio, err := qb.Find(ctx, created.ID)
if err != nil {
return fmt.Errorf("Error updating studio aliases: %s", err.Error())
return fmt.Errorf("Error getting studio: %s", err.Error())
}
// ensure aliases set
storedAliases, err := qb.GetAliases(ctx, created.ID)
if err != nil {
return fmt.Errorf("Error getting aliases: %s", err.Error())
if err := studio.LoadStashIDs(ctx, qb); err != nil {
return err
}
assert.Equal(t, aliases, storedAliases)
testStudioAlias(ctx, t, studio)
return nil
}); err != nil {
t.Error(err.Error())
}
}
func testStudioAlias(ctx context.Context, t *testing.T, s *models.Studio) {
qb := db.Studio
if err := s.LoadAliases(ctx, qb); err != nil {
t.Error(err.Error())
return
}
// ensure no alias to begin with
assert.Len(t, s.Aliases.List(), 0)
aliases := []string{"alias1", "alias2"}
// update alias and ensure was updated
input := models.StudioPartial{
ID: s.ID,
Aliases: &models.UpdateStrings{
Values: aliases,
Mode: models.RelationshipUpdateModeSet,
},
}
var err error
s, err = qb.UpdatePartial(ctx, input)
if err != nil {
t.Error(err.Error())
}
if err := s.LoadAliases(ctx, qb); err != nil {
t.Error(err.Error())
return
}
assert.Equal(t, aliases, s.Aliases.List())
// remove alias and ensure was updated
input = models.StudioPartial{
ID: s.ID,
Aliases: &models.UpdateStrings{
Values: aliases,
Mode: models.RelationshipUpdateModeRemove,
},
}
s, err = qb.UpdatePartial(ctx, input)
if err != nil {
t.Error(err.Error())
}
if err := s.LoadAliases(ctx, qb); err != nil {
t.Error(err.Error())
return
}
assert.Len(t, s.Aliases.List(), 0)
}
// TestStudioQueryFast does a quick test for major errors, no result verification
func TestStudioQueryFast(t *testing.T) {

View file

@ -29,6 +29,9 @@ var (
performersAliasesJoinTable = goqu.T(performersAliasesTable)
performersTagsJoinTable = goqu.T(performersTagsTable)
performersStashIDsJoinTable = goqu.T("performer_stash_ids")
studiosAliasesJoinTable = goqu.T(studioAliasesTable)
studiosStashIDsJoinTable = goqu.T("studio_stash_ids")
)
var (
@ -233,6 +236,21 @@ var (
table: goqu.T(studioTable),
idColumn: goqu.T(studioTable).Col(idColumn),
}
studiosAliasesTableMgr = &stringTable{
table: table{
table: studiosAliasesJoinTable,
idColumn: studiosAliasesJoinTable.Col(studioIDColumn),
},
stringColumn: studiosAliasesJoinTable.Col(studioAliasColumn),
}
studiosStashIDsTableMgr = &stashIDTable{
table: table{
table: studiosStashIDsJoinTable,
idColumn: studiosStashIDsJoinTable.Col(studioIDColumn),
},
}
)
var (

View file

@ -11,15 +11,15 @@ import (
"github.com/stashapp/stash/pkg/utils"
)
type FinderImageStashIDGetter interface {
type FinderImageAliasStashIDGetter interface {
Finder
GetAliases(ctx context.Context, studioID int) ([]string, error)
GetImage(ctx context.Context, studioID int) ([]byte, error)
models.AliasLoader
models.StashIDLoader
}
// ToJSON converts a Studio object into its JSON equivalent.
func ToJSON(ctx context.Context, reader FinderImageStashIDGetter, studio *models.Studio) (*jsonschema.Studio, error) {
func ToJSON(ctx context.Context, reader FinderImageAliasStashIDGetter, studio *models.Studio) (*jsonschema.Studio, error) {
newStudioJSON := jsonschema.Studio{
Name: studio.Name,
URL: studio.URL,
@ -44,12 +44,15 @@ func ToJSON(ctx context.Context, reader FinderImageStashIDGetter, studio *models
newStudioJSON.Rating = *studio.Rating
}
aliases, err := reader.GetAliases(ctx, studio.ID)
if err != nil {
return nil, fmt.Errorf("error getting studio aliases: %v", err)
if err := studio.LoadAliases(ctx, reader); err != nil {
return nil, fmt.Errorf("loading studio aliases: %w", err)
}
newStudioJSON.Aliases = studio.Aliases.List()
newStudioJSON.Aliases = aliases
if err := studio.LoadStashIDs(ctx, reader); err != nil {
return nil, fmt.Errorf("loading studio stash ids: %w", err)
}
newStudioJSON.StashIDs = studio.StashIDs.List()
image, err := reader.GetImage(ctx, studio.ID)
if err != nil {
@ -60,17 +63,5 @@ func ToJSON(ctx context.Context, reader FinderImageStashIDGetter, studio *models
newStudioJSON.Image = utils.GetBase64StringFromData(image)
}
stashIDs, _ := reader.GetStashIDs(ctx, studio.ID)
var ret []models.StashID
for _, stashID := range stashIDs {
newJoin := models.StashID{
StashID: stashID.StashID,
Endpoint: stashID.Endpoint,
}
ret = append(ret, newJoin)
}
newStudioJSON.StashIDs = ret
return &newStudioJSON, nil
}

View file

@ -15,12 +15,10 @@ import (
)
const (
studioID = 1
noImageID = 2
errImageID = 3
missingParentStudioID = 4
errStudioID = 5
errAliasID = 6
parentStudioID = 10
missingStudioID = 11
@ -31,17 +29,19 @@ var (
studioName = "testStudio"
url = "url"
details = "details"
rating = 5
parentStudioName = "parentStudio"
autoTagIgnored = true
)
var studioID = 1
var rating = 5
var parentStudio models.Studio = models.Studio{
Name: parentStudioName,
}
var imageBytes = []byte("imageBytes")
var aliases = []string{"alias"}
var stashID = models.StashID{
StashID: "StashID",
Endpoint: "Endpoint",
@ -67,6 +67,8 @@ func createFullStudio(id int, parentID int) models.Studio {
UpdatedAt: updateTime,
Rating: &rating,
IgnoreAutoTag: autoTagIgnored,
Aliases: models.NewRelatedStrings(aliases),
StashIDs: models.NewRelatedStashIDs(stashIDs),
}
if parentID != 0 {
@ -81,6 +83,8 @@ func createEmptyStudio(id int) models.Studio {
ID: id,
CreatedAt: createTime,
UpdatedAt: updateTime,
Aliases: models.NewRelatedStrings([]string{}),
StashIDs: models.NewRelatedStashIDs([]models.StashID{}),
}
}
@ -99,9 +103,7 @@ func createFullJSONStudio(parentStudio, image string, aliases []string) *jsonsch
Image: image,
Rating: rating,
Aliases: aliases,
StashIDs: []models.StashID{
stashID,
},
StashIDs: stashIDs,
IgnoreAutoTag: autoTagIgnored,
}
}
@ -114,6 +116,8 @@ func createEmptyJSONStudio() *jsonschema.Studio {
UpdatedAt: json.JSONTime{
Time: updateTime,
},
Aliases: []string{},
StashIDs: []models.StashID{},
}
}
@ -139,13 +143,13 @@ func initTestTable() {
},
{
createFullStudio(errImageID, parentStudioID),
createFullJSONStudio(parentStudioName, "", nil),
createFullJSONStudio(parentStudioName, "", []string{"alias"}),
// failure to get image is not an error
false,
},
{
createFullStudio(missingParentStudioID, missingStudioID),
createFullJSONStudio("", image, nil),
createFullJSONStudio("", image, []string{"alias"}),
false,
},
{
@ -153,11 +157,6 @@ func initTestTable() {
nil,
true,
},
{
createFullStudio(errAliasID, parentStudioID),
nil,
true,
},
}
}
@ -174,7 +173,6 @@ func TestToJSON(t *testing.T) {
mockStudioReader.On("GetImage", ctx, errImageID).Return(nil, imageErr).Once()
mockStudioReader.On("GetImage", ctx, missingParentStudioID).Return(imageBytes, nil).Maybe()
mockStudioReader.On("GetImage", ctx, errStudioID).Return(imageBytes, nil).Maybe()
mockStudioReader.On("GetImage", ctx, errAliasID).Return(imageBytes, nil).Maybe()
parentStudioErr := errors.New("error getting parent studio")
@ -182,19 +180,6 @@ func TestToJSON(t *testing.T) {
mockStudioReader.On("Find", ctx, missingStudioID).Return(nil, nil)
mockStudioReader.On("Find", ctx, errParentStudioID).Return(nil, parentStudioErr)
aliasErr := errors.New("error getting aliases")
mockStudioReader.On("GetAliases", ctx, studioID).Return([]string{"alias"}, nil).Once()
mockStudioReader.On("GetAliases", ctx, noImageID).Return(nil, nil).Once()
mockStudioReader.On("GetAliases", ctx, errImageID).Return(nil, nil).Once()
mockStudioReader.On("GetAliases", ctx, missingParentStudioID).Return(nil, nil).Once()
mockStudioReader.On("GetAliases", ctx, errAliasID).Return(nil, aliasErr).Once()
mockStudioReader.On("GetStashIDs", ctx, studioID).Return(stashIDs, nil).Once()
mockStudioReader.On("GetStashIDs", ctx, noImageID).Return(nil, nil).Once()
mockStudioReader.On("GetStashIDs", ctx, missingParentStudioID).Return(stashIDs, nil).Once()
mockStudioReader.On("GetStashIDs", ctx, errImageID).Return(stashIDs, nil).Once()
for i, s := range scenarios {
studio := s.input
json, err := ToJSON(ctx, mockStudioReader, &studio)

View file

@ -14,8 +14,6 @@ type NameFinderCreatorUpdater interface {
NameFinderCreator
Update(ctx context.Context, updatedStudio *models.Studio) error
UpdateImage(ctx context.Context, studioID int, image []byte) error
UpdateAliases(ctx context.Context, studioID int, aliases []string) error
UpdateStashIDs(ctx context.Context, studioID int, stashIDs []models.StashID) error
}
var ErrParentStudioNotExist = errors.New("parent studio does not exist")
@ -25,20 +23,13 @@ type Importer struct {
Input jsonschema.Studio
MissingRefBehaviour models.ImportMissingRefEnum
ID int
studio models.Studio
imageData []byte
}
func (i *Importer) PreImport(ctx context.Context) error {
i.studio = models.Studio{
Name: i.Input.Name,
URL: i.Input.URL,
Details: i.Input.Details,
IgnoreAutoTag: i.Input.IgnoreAutoTag,
CreatedAt: i.Input.CreatedAt.GetTime(),
UpdatedAt: i.Input.UpdatedAt.GetTime(),
Rating: &i.Input.Rating,
}
i.studio = studioJSONtoStudio(i.Input)
if err := i.populateParentStudio(ctx); err != nil {
return err
@ -87,7 +78,9 @@ func (i *Importer) populateParentStudio(ctx context.Context) error {
}
func (i *Importer) createParentStudio(ctx context.Context, name string) (int, error) {
newStudio := models.NewStudio(name)
newStudio := &models.Studio{
Name: name,
}
err := i.ReaderWriter.Create(ctx, newStudio)
if err != nil {
@ -104,16 +97,6 @@ func (i *Importer) PostImport(ctx context.Context, id int) error {
}
}
if len(i.Input.StashIDs) > 0 {
if err := i.ReaderWriter.UpdateStashIDs(ctx, id, i.Input.StashIDs); err != nil {
return fmt.Errorf("error setting stash id: %v", err)
}
}
if err := i.ReaderWriter.UpdateAliases(ctx, id, i.Input.Aliases); err != nil {
return fmt.Errorf("error setting tag aliases: %v", err)
}
return nil
}
@ -156,3 +139,23 @@ func (i *Importer) Update(ctx context.Context, id int) error {
return nil
}
func studioJSONtoStudio(studioJSON jsonschema.Studio) models.Studio {
newStudio := models.Studio{
Name: studioJSON.Name,
URL: studioJSON.URL,
Aliases: models.NewRelatedStrings(studioJSON.Aliases),
Details: studioJSON.Details,
IgnoreAutoTag: studioJSON.IgnoreAutoTag,
CreatedAt: studioJSON.CreatedAt.GetTime(),
UpdatedAt: studioJSON.UpdatedAt.GetTime(),
StashIDs: models.NewRelatedStashIDs(studioJSON.StashIDs),
}
if studioJSON.Rating != 0 {
newStudio.Rating = &studioJSON.Rating
}
return newStudio
}

View file

@ -164,15 +164,9 @@ func TestImporterPostImport(t *testing.T) {
}
updateStudioImageErr := errors.New("UpdateImage error")
updateTagAliasErr := errors.New("UpdateAlias error")
readerWriter.On("UpdateImage", ctx, studioID, imageBytes).Return(nil).Once()
readerWriter.On("UpdateImage", ctx, errImageID, imageBytes).Return(updateStudioImageErr).Once()
readerWriter.On("UpdateImage", ctx, errAliasID, imageBytes).Return(nil).Once()
readerWriter.On("UpdateAliases", ctx, studioID, i.Input.Aliases).Return(nil).Once()
readerWriter.On("UpdateAliases", ctx, errImageID, i.Input.Aliases).Return(nil).Maybe()
readerWriter.On("UpdateAliases", ctx, errAliasID, i.Input.Aliases).Return(updateTagAliasErr).Once()
err := i.PostImport(ctx, studioID)
assert.Nil(t, err)
@ -180,9 +174,6 @@ func TestImporterPostImport(t *testing.T) {
err = i.PostImport(ctx, errImageID)
assert.NotNil(t, err)
err = i.PostImport(ctx, errAliasID)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
}

View file

@ -14,6 +14,12 @@ type Queryer interface {
Query(ctx context.Context, studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) ([]*models.Studio, int, error)
}
type FinderQueryer interface {
Finder
Queryer
models.AliasLoader
}
func ByName(ctx context.Context, qb Queryer, name string) (*models.Studio, error) {
f := &models.StudioFilterType{
Name: &models.StringCriterionInput{

View file

@ -2,11 +2,16 @@ package studio
import (
"context"
"errors"
"fmt"
"github.com/stashapp/stash/pkg/models"
)
var (
ErrStudioOwnAncestor = errors.New("studio cannot be an ancestor of itself")
)
type NameFinderCreator interface {
FindByName(ctx context.Context, name string, nocase bool) (*models.Studio, error)
Create(ctx context.Context, newStudio *models.Studio) error
@ -69,3 +74,60 @@ func EnsureAliasesUnique(ctx context.Context, id int, aliases []string, qb Query
return nil
}
// Checks to make sure that:
// 1. The studio exists locally
// 2. The studio is not its own ancestor
// 3. The studio's aliases are unique
func ValidateModify(ctx context.Context, s models.StudioPartial, qb FinderQueryer) error {
existing, err := qb.Find(ctx, s.ID)
if err != nil {
return err
}
if existing == nil {
return fmt.Errorf("studio with id %d not found", s.ID)
}
newParentID := s.ParentID.Ptr()
if newParentID != nil {
if err := validateParent(ctx, s.ID, *newParentID, qb); err != nil {
return err
}
}
if s.Aliases != nil {
if err := existing.LoadAliases(ctx, qb); err != nil {
return err
}
effectiveAliases := s.Aliases.EffectiveValues(existing.Aliases.List())
if err := EnsureAliasesUnique(ctx, s.ID, effectiveAliases, qb); err != nil {
return err
}
}
return nil
}
func validateParent(ctx context.Context, studioID int, newParentID int, qb FinderQueryer) error {
if newParentID == studioID {
return ErrStudioOwnAncestor
}
// ensure there is no cyclic dependency
parentStudio, err := qb.Find(ctx, newParentID)
if err != nil {
return fmt.Errorf("error finding parent studio: %v", err)
}
if parentStudio == nil {
return fmt.Errorf("studio with id %d not found", newParentID)
}
if parentStudio.ParentID != nil {
return validateParent(ctx, studioID, *parentStudio.ParentID, qb)
}
return nil
}

View file

@ -31,6 +31,8 @@ export const multiValueSceneFields: SceneField[] = [
export function sceneFieldMessageID(field: SceneField) {
if (field === "code") {
return "scene_code";
} else if (field === "studio") {
return "studio_and_parent";
}
return field;

View file

@ -58,7 +58,7 @@ export const StudioDetailsPanel: React.FC<IStudioDetailsPanel> = ({
return (
<>
<dt>
<FormattedMessage id="StashIDs" />
<FormattedMessage id="stash_ids" />
</dt>
<dd>
<ul className="pl-0">

View file

@ -19,6 +19,7 @@ import { DisplayMode } from "src/models/list-filter/types";
import { ExportDialog } from "../Shared/ExportDialog";
import { DeleteEntityDialog } from "../Shared/DeleteEntityDialog";
import { StudioCard } from "./StudioCard";
import { StudioTagger } from "../Tagger/studios/StudioTagger";
const StudioItemList = makeItemList({
filterMode: GQL.FilterMode.Studios,
@ -156,6 +157,9 @@ export const StudioList: React.FC<IStudioList> = ({
if (filter.displayMode === DisplayMode.Wall) {
return <h1>TODO</h1>;
}
if (filter.displayMode === DisplayMode.Tagger) {
return <StudioTagger studios={result.data.findStudios.studios} />;
}
}
return (

View file

@ -232,7 +232,7 @@ const PerformerModal: React.FC<IPerformerModalProps> = ({
{link && (
<h6 className="mt-2">
<a href={link} target="_blank" rel="noopener noreferrer">
Stash-Box Source
<FormattedMessage id="stashbox.source" />
<Icon icon={faExternalLinkAlt} className="ml-2" />
</a>
</h6>

View file

@ -25,6 +25,7 @@ export const DEFAULT_BLACKLIST = [
"\\]",
];
export const DEFAULT_EXCLUDED_PERFORMER_FIELDS = ["name"];
export const DEFAULT_EXCLUDED_STUDIO_FIELDS = ["name"];
export const initialConfig: ITaggerConfig = {
blacklist: DEFAULT_BLACKLIST,
@ -35,6 +36,8 @@ export const initialConfig: ITaggerConfig = {
tagOperation: "merge",
fingerprintQueue: {},
excludedPerformerFields: DEFAULT_EXCLUDED_PERFORMER_FIELDS,
excludedStudioFields: DEFAULT_EXCLUDED_STUDIO_FIELDS,
createParentStudios: true,
};
export type ParseMode = "auto" | "filename" | "dir" | "path" | "metadata";
@ -49,6 +52,8 @@ export interface ITaggerConfig {
selectedEndpoint?: string;
fingerprintQueue: Record<string, string[]>;
excludedPerformerFields?: string[];
excludedStudioFields?: string[];
createParentStudios: boolean;
}
export const PERFORMER_FIELDS = [
@ -74,3 +79,5 @@ export const PERFORMER_FIELDS = [
"death_date",
"weight",
];
export const STUDIO_FIELDS = ["name", "image", "url", "parent"];

View file

@ -55,6 +55,7 @@ export interface ITaggerContextState {
studio: GQL.ScrapedStudio,
toCreate: GQL.StudioCreateInput
) => Promise<string | undefined>;
updateStudio: (studio: GQL.StudioUpdateInput) => Promise<void>;
linkStudio: (studio: GQL.ScrapedStudio, studioID: string) => Promise<void>;
resolveScene: (
sceneID: string,
@ -91,6 +92,7 @@ export const TaggerStateContext = React.createContext<ITaggerContextState>({
createNewPerformer: dummyValFn,
linkPerformer: dummyFn,
createNewStudio: dummyValFn,
updateStudio: dummyFn,
linkStudio: dummyFn,
resolveScene: dummyFn,
submitFingerprints: dummyFn,
@ -701,6 +703,53 @@ export const TaggerContext: React.FC = ({ children }) => {
}
}
async function updateExistingStudio(input: GQL.StudioUpdateInput) {
try {
const result = await updateStudio({
variables: {
input: input,
},
});
const studioID = result.data?.studioUpdate?.id;
const stashID = input.stash_ids?.find((e) => {
return e.endpoint === currentSource?.stashboxEndpoint;
})?.stash_id;
if (stashID) {
const newSearchResults = mapResults((r) => {
if (!r.studio) {
return r;
}
return {
...r,
studio:
r.remote_site_id === stashID
? {
...r.studio,
stored_id: studioID,
}
: r.studio,
};
});
setSearchResults(newSearchResults);
}
Toast.success({
content: (
<span>
Created studio: <b>{input.name}</b>
</span>
),
});
} catch (e) {
Toast.error(e);
}
}
async function linkStudio(studio: GQL.ScrapedStudio, studioID: string) {
if (!studio.remote_site_id || !currentSource?.stashboxEndpoint) return;
@ -780,6 +829,7 @@ export const TaggerContext: React.FC = ({ children }) => {
createNewPerformer,
linkPerformer,
createNewStudio,
updateStudio: updateExistingStudio,
linkStudio,
resolveScene,
saveScene,

View file

@ -112,7 +112,7 @@ const PerformerBatchUpdateModal: React.FC<IPerformerBatchUpdateModal> = ({
type="radio"
name="performer-query"
label={<FormattedMessage id="performer_tagger.current_page" />}
defaultChecked
defaultChecked={!queryAll}
onChange={() => setQueryAll(false)}
/>
<Form.Check
@ -123,7 +123,7 @@ const PerformerBatchUpdateModal: React.FC<IPerformerBatchUpdateModal> = ({
id: "performer_tagger.query_all_performers_in_the_database",
})}
defaultChecked={false}
onChange={() => setQueryAll(true)}
onChange={() => setQueryAll(queryAll)}
/>
</Form.Group>
<Form.Group>
@ -139,7 +139,7 @@ const PerformerBatchUpdateModal: React.FC<IPerformerBatchUpdateModal> = ({
label={intl.formatMessage({
id: "performer_tagger.untagged_performers",
})}
defaultChecked
defaultChecked={!refresh}
onChange={() => setRefresh(false)}
/>
<Form.Text>
@ -153,7 +153,7 @@ const PerformerBatchUpdateModal: React.FC<IPerformerBatchUpdateModal> = ({
id: "performer_tagger.refresh_tagged_performers",
})}
defaultChecked={false}
onChange={() => setRefresh(true)}
onChange={() => setRefresh(refresh)}
/>
<Form.Text>
<FormattedMessage id="performer_tagger.refreshing_will_update_the_data" />
@ -656,9 +656,10 @@ export const PerformerTagger: React.FC<ITaggerProps> = ({ performers }) => {
if (names.length > 0) {
const ret = await mutateStashBoxBatchPerformerTag({
performer_names: names,
names: names,
endpoint: selectedEndpointIndex,
refresh: false,
createParent: false,
});
setBatchJobID(ret.data?.stashBoxBatchPerformerTag);
@ -669,10 +670,11 @@ export const PerformerTagger: React.FC<ITaggerProps> = ({ performers }) => {
async function batchUpdate(ids: string[] | undefined, refresh: boolean) {
if (config && selectedEndpoint) {
const ret = await mutateStashBoxBatchPerformerTag({
performer_ids: ids,
ids: ids,
endpoint: selectedEndpointIndex,
refresh,
exclude_fields: config.excludedPerformerFields ?? [],
createParent: false,
});
setBatchJobID(ret.data?.stashBoxBatchPerformerTag);

View file

@ -1,5 +1,10 @@
import * as GQL from "src/core/generated-graphql";
import sortBy from "lodash-es/sortBy";
import {
evictQueries,
getClient,
studioMutationImpactedQueries,
} from "src/core/StashService";
export const useUpdatePerformerStashID = () => {
const [updatePerformer] = GQL.usePerformerUpdateMutation({
@ -204,6 +209,54 @@ export const useUpdateStudioStashID = () => {
return handleUpdate;
};
export const useUpdateStudio = () => {
const [updateStudio] = GQL.useStudioUpdateMutation({
onError: (errors) => errors,
errorPolicy: "all",
});
const updateStudioHandler = (input: GQL.StudioUpdateInput) =>
updateStudio({
variables: {
input,
},
update: (store, updatedStudio) => {
if (!updatedStudio.data?.studioUpdate) return;
if (updatedStudio.data?.studioUpdate?.parent_studio) {
const ac = getClient();
evictQueries(ac.cache, studioMutationImpactedQueries);
} else {
updatedStudio.data.studioUpdate.stash_ids.forEach((id) => {
store.writeQuery<
GQL.FindStudiosQuery,
GQL.FindStudiosQueryVariables
>({
query: GQL.FindStudiosDocument,
variables: {
studio_filter: {
stash_id: {
value: id.stash_id,
modifier: GQL.CriterionModifier.Equals,
},
},
},
data: {
findStudios: {
count: 1,
studios: [updatedStudio.data!.studioUpdate!],
__typename: "FindStudiosResultType",
},
},
});
});
}
},
});
return updateStudioHandler;
};
export const useCreateStudio = () => {
const [createStudio] = GQL.useStudioCreateMutation({
onError: (errors) => errors,

View file

@ -204,6 +204,7 @@ const StashSearchResult: React.FC<IStashSearchResultProps> = ({
createNewPerformer,
linkPerformer,
createNewStudio,
updateStudio,
linkStudio,
resolveScene,
currentSource,
@ -404,11 +405,32 @@ const StashSearchResult: React.FC<IStashSearchResultProps> = ({
});
}
function showStudioModal(t: GQL.ScrapedStudio) {
createStudioModal(t, (toCreate) => {
async function studioModalCallback(
studio: GQL.ScrapedStudio,
toCreate?: GQL.StudioCreateInput,
parentInput?: GQL.StudioCreateInput
) {
if (toCreate) {
createNewStudio(t, toCreate);
if (parentInput && studio.parent) {
if (toCreate.parent_id) {
const parentUpdateData: GQL.StudioUpdateInput = {
...parentInput,
id: toCreate.parent_id,
};
await updateStudio(parentUpdateData);
} else {
const parentID = await createNewStudio(studio.parent, parentInput);
toCreate.parent_id = parentID;
}
}
createNewStudio(studio, toCreate);
}
}
function showStudioModal(t: GQL.ScrapedStudio) {
createStudioModal(t, (toCreate, parentInput) => {
studioModalCallback(t, toCreate, parentInput);
});
}

View file

@ -1,66 +1,54 @@
import React, { useContext } from "react";
import React, { useState } from "react";
import { FormattedMessage, useIntl } from "react-intl";
import cx from "classnames";
import { IconDefinition } from "@fortawesome/fontawesome-svg-core";
import * as GQL from "src/core/generated-graphql";
import { useFindStudio } from "src/core/StashService";
import { Icon } from "src/components/Shared/Icon";
import { ModalComponent } from "src/components/Shared/Modal";
import {
faCheck,
faExternalLinkAlt,
faTimes,
} from "@fortawesome/free-solid-svg-icons";
import { Button, Form } from "react-bootstrap";
import { TruncatedText } from "src/components/Shared/TruncatedText";
import { TaggerStateContext } from "../context";
import { faExternalLinkAlt } from "@fortawesome/free-solid-svg-icons";
import { excludeFields } from "src/utils/data";
interface IStudioModalProps {
interface IStudioDetailsProps {
studio: GQL.ScrapedSceneStudioDataFragment;
modalVisible: boolean;
closeModal: () => void;
handleStudioCreate: (input: GQL.StudioCreateInput) => void;
header: string;
icon: IconDefinition;
link?: string;
excluded: Record<string, boolean>;
toggleField: (field: string) => void;
isNew?: boolean;
}
const StudioModal: React.FC<IStudioModalProps> = ({
modalVisible,
const StudioDetails: React.FC<IStudioDetailsProps> = ({
studio,
handleStudioCreate,
closeModal,
header,
icon,
link,
excluded,
toggleField,
isNew = false,
}) => {
const { currentSource } = useContext(TaggerStateContext);
const intl = useIntl();
function onSave() {
if (!studio.name) {
throw new Error("studio name must set");
}
const studioData: GQL.StudioCreateInput = {
name: studio.name ?? "",
url: studio.url,
};
// stashid handling code
const remoteSiteID = studio.remote_site_id;
if (remoteSiteID && currentSource?.stashboxEndpoint) {
studioData.stash_ids = [
{
endpoint: currentSource.stashboxEndpoint,
stash_id: remoteSiteID,
},
];
}
handleStudioCreate(studioData);
}
const renderField = (
id: string,
text: string | null | undefined,
isSelectable: boolean = true,
truncate: boolean = true
) =>
text && (
<div className="row no-gutters">
<div className="col-5 studio-create-modal-field" key={id}>
{isSelectable && (
<Button
onClick={() => toggleField(id)}
variant="secondary"
className={excluded[id] ? "text-muted" : "text-success"}
>
<Icon icon={excluded[id] ? faTimes : faCheck} />
</Button>
)}
<strong>
<FormattedMessage id={id} />:
</strong>
@ -73,8 +61,226 @@ const StudioModal: React.FC<IStudioModalProps> = ({
</div>
);
const base = currentSource?.stashboxEndpoint?.match(/https?:\/\/.*?\//)?.[0];
return (
<div>
<div className="row">
<div className="col-12 image-selection">
<div className="studio-image">
<Button
onClick={() => toggleField("image")}
variant="secondary"
className={cx(
"studio-image-exclude",
excluded.image ? "text-muted" : "text-success"
)}
>
<Icon icon={excluded.image ? faTimes : faCheck} />
</Button>
<img src={studio.image ?? ""} alt="" />
</div>
</div>
</div>
<div className="row">
<div className="col-12">
{renderField("name", studio.name, !isNew)}
{renderField("url", studio.url)}
{renderField("parent_studio", studio.parent?.name, false)}
{link && (
<h6 className="mt-2">
<a href={link} target="_blank" rel="noopener noreferrer">
<FormattedMessage id="stashbox.source" />
<Icon icon={faExternalLinkAlt} className="ml-2" />
</a>
</h6>
)}
</div>
</div>
</div>
);
};
interface IStudioModalProps {
studio: GQL.ScrapedSceneStudioDataFragment;
modalVisible: boolean;
closeModal: () => void;
handleStudioCreate: (
input: GQL.StudioCreateInput,
parent?: GQL.StudioCreateInput
) => void;
excludedStudioFields?: string[];
header: string;
icon: IconDefinition;
endpoint?: string;
}
const StudioModal: React.FC<IStudioModalProps> = ({
modalVisible,
studio,
handleStudioCreate,
closeModal,
excludedStudioFields = [],
header,
icon,
endpoint,
}) => {
const intl = useIntl();
const [excluded, setExcluded] = useState<Record<string, boolean>>(
excludedStudioFields.reduce(
(dict, field) => ({ ...dict, [field]: true }),
{}
)
);
const toggleField = (name: string) =>
setExcluded({
...excluded,
[name]: !excluded[name],
});
const [parentExcluded, setParentExcluded] = useState<Record<string, boolean>>(
excludedStudioFields.reduce(
(dict, field) => ({ ...dict, [field]: true }),
{}
)
);
const toggleParentField = (name: string) =>
setParentExcluded({
...parentExcluded,
[name]: !parentExcluded[name],
});
const [createParentStudio, setCreateParentStudio] = useState<boolean>(
!!studio.parent
);
let sendParentStudio = true;
// The parent studio exists, need to check if it has a Stash ID.
const queryResult = useFindStudio(studio.parent?.stored_id ?? "");
if (
queryResult.data?.findStudio?.stash_ids?.length &&
queryResult.data?.findStudio?.stash_ids?.length > 0
) {
// It already has a Stash ID, so we can skip worrying about it
sendParentStudio = false;
}
const parentStudioCreateText = () => {
if (studio.parent && studio.parent.stored_id) {
return "actions.assign_stashid_to_parent_studio";
}
return "actions.create_parent_studio";
};
function onSave() {
if (!studio.name) {
throw new Error("studio name must set");
}
const studioData: GQL.StudioCreateInput & {
[index: string]: unknown;
} = {
name: studio.name,
url: studio.url,
image: studio.image,
parent_id: studio.parent?.stored_id,
};
// stashid handling code
const remoteSiteID = studio.remote_site_id;
if (remoteSiteID && endpoint) {
studioData.stash_ids = [
{
endpoint,
stash_id: remoteSiteID,
},
];
}
// handle exclusions
excludeFields(studioData, excluded);
let parentData:
| (GQL.StudioCreateInput & {
[index: string]: unknown;
})
| undefined = undefined;
if (createParentStudio && sendParentStudio) {
if (!studio.parent?.name) {
throw new Error("parent studio name must set");
}
parentData = {
name: studio.parent?.name,
url: studio.parent?.url,
image: studio.parent?.image,
};
// stashid handling code
const parentRemoteSiteID = studio.parent?.remote_site_id;
if (parentRemoteSiteID && endpoint) {
parentData.stash_ids = [
{
endpoint,
stash_id: parentRemoteSiteID,
},
];
}
// handle exclusions
// Can't exclude parent studio name when creating a new one
parentExcluded.name = false;
excludeFields(parentData, parentExcluded);
}
handleStudioCreate(studioData, parentData);
}
const base = endpoint?.match(/https?:\/\/.*?\//)?.[0];
const link = base ? `${base}studios/${studio.remote_site_id}` : undefined;
const parentLink = base
? `${base}studios/${studio.parent?.remote_site_id}`
: undefined;
function maybeRenderParentStudio() {
// There is no parent studio or it already has a Stash ID
if (!studio.parent || !sendParentStudio) {
return;
}
return (
<div>
<div className="mb-4 mt-4">
<Form.Check
id="create-parent"
checked={createParentStudio}
label={intl.formatMessage({
id: parentStudioCreateText(),
})}
onChange={() => setCreateParentStudio(!createParentStudio)}
/>
</div>
{maybeRenderParentStudioDetails()}
</div>
);
}
function maybeRenderParentStudioDetails() {
if (!createParentStudio || !studio.parent) {
return;
}
return (
<StudioDetails
studio={studio.parent}
excluded={parentExcluded}
toggleField={(field) => toggleParentField(field)}
link={parentLink}
isNew
/>
);
}
return (
<ModalComponent
@ -83,33 +289,20 @@ const StudioModal: React.FC<IStudioModalProps> = ({
text: intl.formatMessage({ id: "actions.save" }),
onClick: onSave,
}}
onHide={() => closeModal()}
cancel={{ onClick: () => closeModal(), variant: "secondary" }}
onHide={() => closeModal()}
dialogClassName="studio-create-modal"
icon={icon}
header={header}
>
<div className="row">
<div className="col-12">
{renderField("name", studio.name)}
{renderField("url", studio.url)}
{link && (
<h6 className="mt-2">
<a href={link} target="_blank" rel="noopener noreferrer">
Stash-Box Source
<Icon icon={faExternalLinkAlt} className="ml-2" />
</a>
</h6>
)}
</div>
</div>
<StudioDetails
studio={studio}
excluded={excluded}
toggleField={(field) => toggleField(field)}
link={link}
/>
{/* TODO - add image */}
{/* <div className="row">
<strong className="col-2">Logo:</strong>
<span className="col-10">
<img src={studio?.image ?? ""} alt="" />
</span>
</div> */}
{maybeRenderParentStudio()}
</ModalComponent>
);
};

View file

@ -8,16 +8,19 @@ import { useIntl } from "react-intl";
import { faTags } from "@fortawesome/free-solid-svg-icons";
type PerformerModalCallback = (toCreate?: GQL.PerformerCreateInput) => void;
type StudioModalCallback = (toCreate?: GQL.StudioCreateInput) => void;
type StudioModalCallback = (
toCreate?: GQL.StudioCreateInput,
parentInput?: GQL.StudioCreateInput
) => void;
export interface ISceneTaggerModalsContextState {
createPerformerModal: (
performer: GQL.ScrapedPerformerDataFragment,
callback: (toCreate?: GQL.PerformerCreateInput) => void
callback: PerformerModalCallback
) => void;
createStudioModal: (
studio: GQL.ScrapedSceneStudioDataFragment,
callback: (toCreate?: GQL.StudioCreateInput) => void
callback: StudioModalCallback
) => void;
}
@ -73,9 +76,12 @@ export const SceneTaggerModals: React.FC = ({ children }) => {
setPerformerCallback(() => callback);
}
function handleStudioSave(toCreate: GQL.StudioCreateInput) {
function handleStudioSave(
toCreate: GQL.StudioCreateInput,
parentInput?: GQL.StudioCreateInput
) {
if (studioCallback) {
studioCallback(toCreate);
studioCallback(toCreate, parentInput);
}
setStudioToCreate(undefined);
@ -132,6 +138,7 @@ export const SceneTaggerModals: React.FC = ({ children }) => {
{ id: "actions.create_entity" },
{ entityType: intl.formatMessage({ id: "studio" }) }
)}
endpoint={endpoint}
/>
)}
{children}

View file

@ -0,0 +1,132 @@
import React, { Dispatch, useState } from "react";
import { Badge, Button, Card, Collapse, Form } from "react-bootstrap";
import { FormattedMessage } from "react-intl";
import { ConfigurationContext } from "src/hooks/Config";
import TextUtils from "src/utils/text";
import { ITaggerConfig, STUDIO_FIELDS } from "../constants";
import StudioFieldSelector from "./StudioFieldSelector";
interface IConfigProps {
show: boolean;
config: ITaggerConfig;
setConfig: Dispatch<ITaggerConfig>;
}
const Config: React.FC<IConfigProps> = ({ show, config, setConfig }) => {
const { configuration: stashConfig } = React.useContext(ConfigurationContext);
const [showExclusionModal, setShowExclusionModal] = useState(false);
const excludedFields = config.excludedStudioFields ?? [];
const handleInstanceSelect = (e: React.ChangeEvent<HTMLSelectElement>) => {
const selectedEndpoint = e.currentTarget.value;
setConfig({
...config,
selectedEndpoint,
});
};
const stashBoxes = stashConfig?.general.stashBoxes ?? [];
const handleFieldSelect = (fields: string[]) => {
setConfig({ ...config, excludedStudioFields: fields });
setShowExclusionModal(false);
};
return (
<>
<Collapse in={show}>
<Card>
<div className="row">
<h4 className="col-12">
<FormattedMessage id="configuration" />
</h4>
<hr className="w-100" />
<div className="col-md-6">
<Form.Group
controlId="create-parent"
className="align-items-center"
>
<Form.Check
label={
<FormattedMessage id="studio_tagger.config.create_parent_label" />
}
checked={config.createParentStudios}
onChange={(e: React.ChangeEvent<HTMLInputElement>) =>
setConfig({
...config,
createParentStudios: e.currentTarget.checked,
})
}
/>
<Form.Text>
<FormattedMessage id="studio_tagger.config.create_parent_desc" />
</Form.Text>
</Form.Group>
<Form.Group controlId="excluded-studio-fields">
<h6>
<FormattedMessage id="studio_tagger.config.excluded_fields" />
</h6>
<span>
{excludedFields.length > 0 ? (
excludedFields.map((f) => (
<Badge variant="secondary" className="tag-item" key={f}>
{TextUtils.capitalize(f)}
</Badge>
))
) : (
<FormattedMessage id="studio_tagger.config.no_fields_are_excluded" />
)}
</span>
<Form.Text>
<FormattedMessage id="studio_tagger.config.these_fields_will_not_be_changed_when_updating_studios" />
</Form.Text>
<Button
onClick={() => setShowExclusionModal(true)}
className="mt-2"
>
<FormattedMessage id="studio_tagger.config.edit_excluded_fields" />
</Button>
</Form.Group>
<Form.Group
controlId="stash-box-endpoint"
className="align-items-center row no-gutters mt-4"
>
<Form.Label className="mr-4">
<FormattedMessage id="studio_tagger.config.active_stash-box_instance" />
</Form.Label>
<Form.Control
as="select"
value={config.selectedEndpoint}
className="col-md-4 col-6 input-control"
disabled={!stashBoxes.length}
onChange={handleInstanceSelect}
>
{!stashBoxes.length && (
<option>
<FormattedMessage id="studio_tagger.config.no_instances_found" />
</option>
)}
{stashConfig?.general.stashBoxes.map((i) => (
<option value={i.endpoint} key={i.endpoint}>
{i.endpoint}
</option>
))}
</Form.Control>
</Form.Group>
</div>
</div>
</Card>
</Collapse>
<StudioFieldSelector
fields={STUDIO_FIELDS}
show={showExclusionModal}
onSelect={handleFieldSelect}
excludedFields={excludedFields}
/>
</>
);
};
export default Config;

View file

@ -0,0 +1,144 @@
import React, { useState } from "react";
import { Button } from "react-bootstrap";
import * as GQL from "src/core/generated-graphql";
import { useUpdateStudio } from "../queries";
import StudioModal from "../scenes/StudioModal";
import { faTags } from "@fortawesome/free-solid-svg-icons";
import { useStudioCreate } from "src/core/StashService";
import { useIntl } from "react-intl";
interface IStashSearchResultProps {
studio: GQL.SlimStudioDataFragment;
stashboxStudios: GQL.ScrapedStudioDataFragment[];
endpoint: string;
onStudioTagged: (
studio: Pick<GQL.SlimStudioDataFragment, "id"> &
Partial<Omit<GQL.SlimStudioDataFragment, "id">>
) => void;
excludedStudioFields: string[];
}
const StashSearchResult: React.FC<IStashSearchResultProps> = ({
studio,
stashboxStudios,
onStudioTagged,
excludedStudioFields,
endpoint,
}) => {
const intl = useIntl();
const [modalStudio, setModalStudio] = useState<
GQL.ScrapedStudioDataFragment | undefined
>();
const [saveState, setSaveState] = useState<string>("");
const [error, setError] = useState<{ message?: string; details?: string }>(
{}
);
const [createStudio] = useStudioCreate();
const updateStudio = useUpdateStudio();
function handleSaveError(name: string, message: string) {
setError({
message: intl.formatMessage(
{ id: "studio_tagger.failed_to_save_studio" },
{ studio: name }
),
details:
message === "UNIQUE constraint failed: studios.checksum"
? "Name already exists"
: message,
});
}
const handleSave = async (
input: GQL.StudioCreateInput,
parentInput?: GQL.StudioCreateInput
) => {
setError({});
setModalStudio(undefined);
if (parentInput) {
setSaveState("Saving parent studio");
try {
// if parent id is set, then update the existing studio
if (input.parent_id) {
const parentUpdateData: GQL.StudioUpdateInput = {
...parentInput,
id: input.parent_id,
};
await updateStudio(parentUpdateData);
} else {
const parentRes = await createStudio({
variables: { input: parentInput },
});
input.parent_id = parentRes.data?.studioCreate?.id;
}
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} catch (e: any) {
handleSaveError(parentInput.name, e.message ?? "");
}
}
setSaveState("Saving studio");
const updateData: GQL.StudioUpdateInput = {
...input,
id: studio.id,
};
const res = await updateStudio(updateData);
if (!res?.data?.studioUpdate)
handleSaveError(studio.name, res?.errors?.[0]?.message ?? "");
else onStudioTagged(studio);
setSaveState("");
};
const studios = stashboxStudios.map((p) => (
<Button
className="StudioTagger-studio-search-item minimal col-6"
variant="link"
key={p.remote_site_id}
onClick={() => setModalStudio(p)}
>
<img src={(p.image ?? [])[0]} alt="" className="StudioTagger-thumb" />
<span>{p.name}</span>
</Button>
));
return (
<>
{modalStudio && (
<StudioModal
closeModal={() => setModalStudio(undefined)}
modalVisible={modalStudio !== undefined}
studio={modalStudio}
handleStudioCreate={handleSave}
icon={faTags}
header="Update Studio"
excludedStudioFields={excludedStudioFields}
endpoint={endpoint}
/>
)}
<div className="StudioTagger-studio-search">{studios}</div>
<div className="row no-gutters mt-2 align-items-center justify-content-end">
{error.message && (
<div className="text-right text-danger mt-1">
<strong>
<span className="mr-2">Error:</span>
{error.message}
</strong>
<div>{error.details}</div>
</div>
)}
{saveState && (
<strong className="col-4 mt-1 mr-2 text-right">{saveState}</strong>
)}
</div>
</>
);
};
export default StashSearchResult;

View file

@ -0,0 +1,67 @@
import { faCheck, faList, faTimes } from "@fortawesome/free-solid-svg-icons";
import React, { useState } from "react";
import { Button, Row, Col } from "react-bootstrap";
import { useIntl } from "react-intl";
import { ModalComponent } from "../../Shared/Modal";
import { Icon } from "../../Shared/Icon";
import TextUtils from "src/utils/text";
interface IProps {
fields: string[];
show: boolean;
excludedFields: string[];
onSelect: (fields: string[]) => void;
}
const StudioFieldSelect: React.FC<IProps> = ({
fields,
show,
excludedFields,
onSelect,
}) => {
const intl = useIntl();
const [excluded, setExcluded] = useState<Record<string, boolean>>(
excludedFields.reduce((dict, field) => ({ ...dict, [field]: true }), {})
);
const toggleField = (name: string) =>
setExcluded({
...excluded,
[name]: !excluded[name],
});
const renderField = (name: string) => (
<Col xs={6} className="mb-1" key={name}>
<Button
onClick={() => toggleField(name)}
variant="secondary"
className={excluded[name] ? "text-muted" : "text-success"}
>
<Icon icon={excluded[name] ? faTimes : faCheck} />
</Button>
<span className="ml-3">{TextUtils.capitalize(name)}</span>
</Col>
);
return (
<ModalComponent
show={show}
icon={faList}
dialogClassName="FieldSelect"
accept={{
text: intl.formatMessage({ id: "actions.save" }),
onClick: () =>
onSelect(Object.keys(excluded).filter((f) => excluded[f])),
}}
>
<h4>Select tagged fields</h4>
<div className="mb-2">
These fields will be tagged by default. Click the button to toggle.
</div>
<Row>{fields.map((f) => renderField(f))}</Row>
</ModalComponent>
);
};
export default StudioFieldSelect;

View file

@ -0,0 +1,870 @@
import React, { useEffect, useMemo, useRef, useState } from "react";
import { Button, Card, Form, InputGroup, ProgressBar } from "react-bootstrap";
import { FormattedMessage, useIntl } from "react-intl";
import { Link } from "react-router-dom";
import { HashLink } from "react-router-hash-link";
import { useLocalForage } from "src/hooks/LocalForage";
import * as GQL from "src/core/generated-graphql";
import { LoadingIndicator } from "src/components/Shared/LoadingIndicator";
import { ModalComponent } from "src/components/Shared/Modal";
import {
stashBoxStudioQuery,
useJobsSubscribe,
mutateStashBoxBatchStudioTag,
getClient,
studioMutationImpactedQueries,
useStudioCreate,
evictQueries,
} from "src/core/StashService";
import { Manual } from "src/components/Help/Manual";
import { ConfigurationContext } from "src/hooks/Config";
import StashSearchResult from "./StashSearchResult";
import StudioConfig from "./Config";
import { LOCAL_FORAGE_KEY, ITaggerConfig, initialConfig } from "../constants";
import StudioModal from "../scenes/StudioModal";
import { useUpdateStudio } from "../queries";
import { faStar, faTags } from "@fortawesome/free-solid-svg-icons";
type JobFragment = Pick<
GQL.Job,
"id" | "status" | "subTasks" | "description" | "progress"
>;
const CLASSNAME = "StudioTagger";
interface IStudioBatchUpdateModal {
studios: GQL.StudioDataFragment[];
isIdle: boolean;
selectedEndpoint: { endpoint: string; index: number };
onBatchUpdate: (queryAll: boolean, refresh: boolean) => void;
batchAddParents: boolean;
setBatchAddParents: (addParents: boolean) => void;
close: () => void;
}
const StudioBatchUpdateModal: React.FC<IStudioBatchUpdateModal> = ({
studios,
isIdle,
selectedEndpoint,
onBatchUpdate,
batchAddParents,
setBatchAddParents,
close,
}) => {
const intl = useIntl();
const [queryAll, setQueryAll] = useState(false);
const [refresh, setRefresh] = useState(false);
const { data: allStudios } = GQL.useFindStudiosQuery({
variables: {
studio_filter: {
stash_id_endpoint: {
endpoint: selectedEndpoint.endpoint,
modifier: refresh
? GQL.CriterionModifier.NotNull
: GQL.CriterionModifier.IsNull,
},
},
filter: {
per_page: 0,
},
},
});
const studioCount = useMemo(() => {
// get all stash ids for the selected endpoint
const filteredStashIDs = studios.map((p) =>
p.stash_ids.filter((s) => s.endpoint === selectedEndpoint.endpoint)
);
return queryAll
? allStudios?.findStudios.count
: filteredStashIDs.filter((s) =>
// if refresh, then we filter out the studios without a stash id
// otherwise, we want untagged studios, filtering out those with a stash id
refresh ? s.length > 0 : s.length === 0
).length;
}, [queryAll, refresh, studios, allStudios, selectedEndpoint.endpoint]);
return (
<ModalComponent
show
icon={faTags}
header={intl.formatMessage({
id: "studio_tagger.update_studios",
})}
accept={{
text: intl.formatMessage({
id: "studio_tagger.update_studios",
}),
onClick: () => onBatchUpdate(queryAll, refresh),
}}
cancel={{
text: intl.formatMessage({ id: "actions.cancel" }),
variant: "danger",
onClick: () => close(),
}}
disabled={!isIdle}
>
<Form.Group>
<Form.Label>
<h6>
<FormattedMessage id="studio_tagger.studio_selection" />
</h6>
</Form.Label>
<Form.Check
id="query-page"
type="radio"
name="studio-query"
label={<FormattedMessage id="studio_tagger.current_page" />}
defaultChecked={!queryAll}
onChange={() => setQueryAll(false)}
/>
<Form.Check
id="query-all"
type="radio"
name="studio-query"
label={intl.formatMessage({
id: "studio_tagger.query_all_studios_in_the_database",
})}
defaultChecked={queryAll}
onChange={() => setQueryAll(true)}
/>
</Form.Group>
<Form.Group>
<Form.Label>
<h6>
<FormattedMessage id="studio_tagger.tag_status" />
</h6>
</Form.Label>
<Form.Check
id="untagged-studios"
type="radio"
name="studio-refresh"
label={intl.formatMessage({
id: "studio_tagger.untagged_studios",
})}
defaultChecked={!refresh}
onChange={() => setRefresh(false)}
/>
<Form.Text>
<FormattedMessage id="studio_tagger.updating_untagged_studios_description" />
</Form.Text>
<Form.Check
id="tagged-studios"
type="radio"
name="studio-refresh"
label={intl.formatMessage({
id: "studio_tagger.refresh_tagged_studios",
})}
defaultChecked={refresh}
onChange={() => setRefresh(true)}
/>
<Form.Text>
<FormattedMessage id="studio_tagger.refreshing_will_update_the_data" />
</Form.Text>
<div className="mt-4">
<Form.Check
id="add-parent"
checked={batchAddParents}
label={intl.formatMessage({
id: "studio_tagger.create_or_tag_parent_studios",
})}
onChange={() => setBatchAddParents(!batchAddParents)}
/>
</div>
</Form.Group>
<b>
<FormattedMessage
id="studio_tagger.number_of_studios_will_be_processed"
values={{
studio_count: studioCount,
}}
/>
</b>
</ModalComponent>
);
};
interface IStudioBatchAddModal {
isIdle: boolean;
onBatchAdd: (input: string) => void;
batchAddParents: boolean;
setBatchAddParents: (addParents: boolean) => void;
close: () => void;
}
const StudioBatchAddModal: React.FC<IStudioBatchAddModal> = ({
isIdle,
onBatchAdd,
batchAddParents,
setBatchAddParents,
close,
}) => {
const intl = useIntl();
const studioInput = useRef<HTMLTextAreaElement | null>(null);
return (
<ModalComponent
show
icon={faStar}
header={intl.formatMessage({
id: "studio_tagger.add_new_studios",
})}
accept={{
text: intl.formatMessage({
id: "studio_tagger.add_new_studios",
}),
onClick: () => {
if (studioInput.current) {
onBatchAdd(studioInput.current.value);
} else {
close();
}
},
}}
cancel={{
text: intl.formatMessage({ id: "actions.cancel" }),
variant: "danger",
onClick: () => close(),
}}
disabled={!isIdle}
>
<Form.Control
className="text-input"
as="textarea"
ref={studioInput}
placeholder={intl.formatMessage({
id: "studio_tagger.studio_names_separated_by_comma",
})}
rows={6}
/>
<Form.Text>
<FormattedMessage id="studio_tagger.any_names_entered_will_be_queried" />
</Form.Text>
<div className="mt-2">
<Form.Check
id="add-parent"
checked={batchAddParents}
label={intl.formatMessage({
id: "studio_tagger.create_or_tag_parent_studios",
})}
onChange={() => setBatchAddParents(!batchAddParents)}
/>
</div>
</ModalComponent>
);
};
interface IStudioTaggerListProps {
studios: GQL.StudioDataFragment[];
selectedEndpoint: { endpoint: string; index: number };
isIdle: boolean;
config: ITaggerConfig;
stashBoxes?: GQL.StashBox[];
onBatchAdd: (studioInput: string, createParent: boolean) => void;
onBatchUpdate: (
ids: string[] | undefined,
refresh: boolean,
createParent: boolean
) => void;
}
const StudioTaggerList: React.FC<IStudioTaggerListProps> = ({
studios,
selectedEndpoint,
isIdle,
config,
stashBoxes,
onBatchAdd,
onBatchUpdate,
}) => {
const intl = useIntl();
const [loading, setLoading] = useState(false);
const [searchResults, setSearchResults] = useState<
Record<string, GQL.ScrapedStudioDataFragment[]>
>({});
const [searchErrors, setSearchErrors] = useState<
Record<string, string | undefined>
>({});
const [taggedStudios, setTaggedStudios] = useState<
Record<string, Partial<GQL.SlimStudioDataFragment>>
>({});
const [queries, setQueries] = useState<Record<string, string>>({});
const [showBatchAdd, setShowBatchAdd] = useState(false);
const [showBatchUpdate, setShowBatchUpdate] = useState(false);
const [batchAddParents, setBatchAddParents] = useState(
config.createParentStudios || false
);
const [error, setError] = useState<
Record<string, { message?: string; details?: string } | undefined>
>({});
const [loadingUpdate, setLoadingUpdate] = useState<string | undefined>();
const [modalStudio, setModalStudio] = useState<
GQL.ScrapedStudioDataFragment | undefined
>();
const doBoxSearch = (studioID: string, searchVal: string) => {
stashBoxStudioQuery(searchVal, selectedEndpoint.index)
.then((queryData) => {
const s = queryData.data?.scrapeSingleStudio ?? [];
setSearchResults({
...searchResults,
[studioID]: s,
});
setSearchErrors({
...searchErrors,
[studioID]: undefined,
});
setLoading(false);
})
.catch(() => {
setLoading(false);
// Destructure to remove existing result
const { [studioID]: unassign, ...results } = searchResults;
setSearchResults(results);
setSearchErrors({
...searchErrors,
[studioID]: intl.formatMessage({
id: "studio_tagger.network_error",
}),
});
});
setLoading(true);
};
const doBoxUpdate = (
studioID: string,
stashID: string,
endpointIndex: number
) => {
setLoadingUpdate(stashID);
setError({
...error,
[studioID]: undefined,
});
stashBoxStudioQuery(stashID, endpointIndex)
.then((queryData) => {
const data = queryData.data?.scrapeSingleStudio ?? [];
if (data.length > 0) {
setModalStudio({
...data[0],
stored_id: studioID,
});
}
})
.finally(() => setLoadingUpdate(undefined));
};
async function handleBatchAdd(input: string) {
onBatchAdd(input, batchAddParents);
setShowBatchAdd(false);
}
const handleBatchUpdate = (queryAll: boolean, refresh: boolean) => {
onBatchUpdate(
!queryAll ? studios.map((p) => p.id) : undefined,
refresh,
batchAddParents
);
setShowBatchUpdate(false);
};
const handleTaggedStudio = (
studio: Pick<GQL.SlimStudioDataFragment, "id"> &
Partial<Omit<GQL.SlimStudioDataFragment, "id">>
) => {
setTaggedStudios({
...taggedStudios,
[studio.id]: studio,
});
};
const [createStudio] = useStudioCreate();
const updateStudio = useUpdateStudio();
function handleSaveError(studioID: string, name: string, message: string) {
setError({
...error,
[studioID]: {
message: intl.formatMessage(
{ id: "studio_tagger.failed_to_save_studio" },
{ studio: modalStudio?.name }
),
details:
message === "UNIQUE constraint failed: studios.checksum"
? intl.formatMessage({
id: "studio_tagger.name_already_exists",
})
: message,
},
});
}
const handleStudioUpdate = async (
input: GQL.StudioCreateInput,
parentInput?: GQL.StudioCreateInput
) => {
setModalStudio(undefined);
const studioID = modalStudio?.stored_id;
if (studioID) {
if (parentInput) {
try {
// if parent id is set, then update the existing studio
if (input.parent_id) {
const parentUpdateData: GQL.StudioUpdateInput = {
...parentInput,
id: input.parent_id,
};
await updateStudio(parentUpdateData);
} else {
const parentRes = await createStudio({
variables: { input: parentInput },
});
input.parent_id = parentRes.data?.studioCreate?.id;
}
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} catch (e: any) {
handleSaveError(studioID, parentInput.name, e.message ?? "");
}
}
const updateData: GQL.StudioUpdateInput = {
...input,
id: studioID,
};
const res = await updateStudio(updateData);
if (!res.data?.studioUpdate)
handleSaveError(
studioID,
modalStudio?.name ?? "",
res?.errors?.[0]?.message ?? ""
);
}
};
const renderStudios = () =>
studios.map((studio) => {
const isTagged = taggedStudios[studio.id];
const stashID = studio.stash_ids.find((s) => {
return s.endpoint === selectedEndpoint.endpoint;
});
let mainContent;
if (!isTagged && stashID !== undefined) {
mainContent = (
<div className="text-left">
<h5 className="text-bold">
<FormattedMessage id="studio_tagger.studio_already_tagged" />
</h5>
</div>
);
} else if (!isTagged && !stashID) {
mainContent = (
<InputGroup>
<Form.Control
className="text-input"
defaultValue={studio.name ?? ""}
onChange={(e) =>
setQueries({
...queries,
[studio.id]: e.currentTarget.value,
})
}
onKeyPress={(e: React.KeyboardEvent<HTMLInputElement>) =>
e.key === "Enter" &&
doBoxSearch(studio.id, queries[studio.id] ?? studio.name ?? "")
}
/>
<InputGroup.Append>
<Button
disabled={loading}
onClick={() =>
doBoxSearch(
studio.id,
queries[studio.id] ?? studio.name ?? ""
)
}
>
<FormattedMessage id="actions.search" />
</Button>
</InputGroup.Append>
</InputGroup>
);
} else if (isTagged) {
mainContent = (
<div className="d-flex flex-column text-left">
<h5>
<FormattedMessage id="studio_tagger.studio_successfully_tagged" />
</h5>
</div>
);
}
let subContent;
if (stashID !== undefined) {
const base = stashID.endpoint.match(/https?:\/\/.*?\//)?.[0];
const link = base ? (
<a
className="small d-block"
href={`${base}studios/${stashID.stash_id}`}
target="_blank"
rel="noopener noreferrer"
>
{stashID.stash_id}
</a>
) : (
<div className="small">{stashID.stash_id}</div>
);
const endpointIndex =
stashBoxes?.findIndex((box) => box.endpoint === stashID.endpoint) ??
-1;
subContent = (
<div key={studio.id}>
<InputGroup className="StudioTagger-box-link">
<InputGroup.Text>{link}</InputGroup.Text>
<InputGroup.Append>
{endpointIndex !== -1 && (
<Button
onClick={() =>
doBoxUpdate(studio.id, stashID.stash_id, endpointIndex)
}
disabled={!!loadingUpdate}
>
{loadingUpdate === stashID.stash_id ? (
<LoadingIndicator inline small message="" />
) : (
<FormattedMessage id="actions.refresh" />
)}
</Button>
)}
</InputGroup.Append>
</InputGroup>
{error[studio.id] && (
<div className="text-danger mt-1">
<strong>
<span className="mr-2">Error:</span>
{error[studio.id]?.message}
</strong>
<div>{error[studio.id]?.details}</div>
</div>
)}
</div>
);
} else if (searchErrors[studio.id]) {
subContent = (
<div className="text-danger font-weight-bold">
{searchErrors[studio.id]}
</div>
);
} else if (searchResults[studio.id]?.length === 0) {
subContent = (
<div className="text-danger font-weight-bold">
<FormattedMessage id="studio_tagger.no_results_found" />
</div>
);
}
let searchResult;
if (searchResults[studio.id]?.length > 0 && !isTagged) {
searchResult = (
<StashSearchResult
key={studio.id}
stashboxStudios={searchResults[studio.id]}
studio={studio}
endpoint={selectedEndpoint.endpoint}
onStudioTagged={handleTaggedStudio}
excludedStudioFields={config.excludedStudioFields ?? []}
/>
);
}
return (
<div key={studio.id} className={`${CLASSNAME}-studio`}>
{modalStudio && (
<StudioModal
closeModal={() => setModalStudio(undefined)}
modalVisible={modalStudio.stored_id === studio.id}
studio={modalStudio}
handleStudioCreate={handleStudioUpdate}
excludedStudioFields={config.excludedStudioFields}
icon={faTags}
header={intl.formatMessage({
id: "studio_tagger.update_studio",
})}
endpoint={selectedEndpoint.endpoint}
/>
)}
<div className={`${CLASSNAME}-details`}>
<div></div>
<div>
<Card className="studio-card">
<img src={studio.image_path ?? ""} alt="" />
</Card>
</div>
<div className={`${CLASSNAME}-details-text`}>
<Link
to={`/studios/${studio.id}`}
className={`${CLASSNAME}-header`}
>
<h2>{studio.name}</h2>
</Link>
{mainContent}
<div className="sub-content text-left">{subContent}</div>
{searchResult}
</div>
</div>
</div>
);
});
return (
<Card>
{showBatchUpdate && (
<StudioBatchUpdateModal
close={() => setShowBatchUpdate(false)}
isIdle={isIdle}
selectedEndpoint={selectedEndpoint}
studios={studios}
onBatchUpdate={handleBatchUpdate}
batchAddParents={batchAddParents}
setBatchAddParents={setBatchAddParents}
/>
)}
{showBatchAdd && (
<StudioBatchAddModal
close={() => setShowBatchAdd(false)}
isIdle={isIdle}
onBatchAdd={handleBatchAdd}
batchAddParents={batchAddParents}
setBatchAddParents={setBatchAddParents}
/>
)}
<div className="ml-auto mb-3">
<Button onClick={() => setShowBatchAdd(true)}>
<FormattedMessage id="studio_tagger.batch_add_studios" />
</Button>
<Button className="ml-3" onClick={() => setShowBatchUpdate(true)}>
<FormattedMessage id="studio_tagger.batch_update_studios" />
</Button>
</div>
<div className={CLASSNAME}>{renderStudios()}</div>
</Card>
);
};
interface ITaggerProps {
studios: GQL.StudioDataFragment[];
}
export const StudioTagger: React.FC<ITaggerProps> = ({ studios }) => {
const jobsSubscribe = useJobsSubscribe();
const intl = useIntl();
const { configuration: stashConfig } = React.useContext(ConfigurationContext);
const [{ data: config }, setConfig] = useLocalForage<ITaggerConfig>(
LOCAL_FORAGE_KEY,
initialConfig
);
const [showConfig, setShowConfig] = useState(false);
const [showManual, setShowManual] = useState(false);
const [batchJobID, setBatchJobID] = useState<string | undefined | null>();
const [batchJob, setBatchJob] = useState<JobFragment | undefined>();
// monitor batch operation
useEffect(() => {
if (!jobsSubscribe.data) {
return;
}
const event = jobsSubscribe.data.jobsSubscribe;
if (event.job.id !== batchJobID) {
return;
}
if (event.type !== GQL.JobStatusUpdateType.Remove) {
setBatchJob(event.job);
} else {
setBatchJob(undefined);
setBatchJobID(undefined);
// Once the studio batch is complete, refresh all local studio data
const ac = getClient();
evictQueries(ac.cache, studioMutationImpactedQueries);
}
}, [jobsSubscribe, batchJobID]);
if (!config) return <LoadingIndicator />;
const savedEndpointIndex =
stashConfig?.general.stashBoxes.findIndex(
(s) => s.endpoint === config.selectedEndpoint
) ?? -1;
const selectedEndpointIndex =
savedEndpointIndex === -1 && stashConfig?.general.stashBoxes.length
? 0
: savedEndpointIndex;
const selectedEndpoint =
stashConfig?.general.stashBoxes[selectedEndpointIndex];
async function batchAdd(studioInput: string, createParent: boolean) {
if (studioInput && selectedEndpoint) {
const names = studioInput
.split(",")
.map((n) => n.trim())
.filter((n) => n.length > 0);
if (names.length > 0) {
const ret = await mutateStashBoxBatchStudioTag({
names: names,
endpoint: selectedEndpointIndex,
refresh: false,
exclude_fields: config?.excludedStudioFields ?? [],
createParent: createParent,
});
setBatchJobID(ret.data?.stashBoxBatchStudioTag);
}
}
}
async function batchUpdate(
ids: string[] | undefined,
refresh: boolean,
createParent: boolean
) {
if (selectedEndpoint) {
const ret = await mutateStashBoxBatchStudioTag({
ids: ids,
endpoint: selectedEndpointIndex,
refresh,
exclude_fields: config?.excludedStudioFields ?? [],
createParent: createParent,
});
setBatchJobID(ret.data?.stashBoxBatchStudioTag);
}
}
// const progress =
// jobStatus.data?.metadataUpdate.status ===
// "Stash-Box Studio Batch Operation" &&
// jobStatus.data.metadataUpdate.progress >= 0
// ? jobStatus.data.metadataUpdate.progress * 100
// : null;
function renderStatus() {
if (batchJob) {
const progress =
batchJob.progress !== undefined && batchJob.progress !== null
? batchJob.progress * 100
: undefined;
return (
<Form.Group className="px-4">
<h5>
<FormattedMessage id="studio_tagger.status_tagging_studios" />
</h5>
{progress !== undefined && (
<ProgressBar
animated
now={progress}
label={`${progress.toFixed(0)}%`}
/>
)}
</Form.Group>
);
}
if (batchJobID !== undefined) {
return (
<Form.Group className="px-4">
<h5>
<FormattedMessage id="studio_tagger.status_tagging_job_queued" />
</h5>
</Form.Group>
);
}
}
const showHideConfigId = showConfig
? "actions.hide_configuration"
: "actions.show_configuration";
return (
<>
<Manual
show={showManual}
onClose={() => setShowManual(false)}
defaultActiveTab="Tagger.md"
/>
{renderStatus()}
<div className="tagger-container mx-md-auto">
{selectedEndpointIndex !== -1 && selectedEndpoint ? (
<>
<div className="row mb-2 no-gutters">
<Button onClick={() => setShowConfig(!showConfig)} variant="link">
{intl.formatMessage({ id: showHideConfigId })}
</Button>
<Button
className="ml-auto"
onClick={() => setShowManual(true)}
title={intl.formatMessage({ id: "help" })}
variant="link"
>
<FormattedMessage id="help" />
</Button>
</div>
<StudioConfig
config={config}
setConfig={setConfig}
show={showConfig}
/>
<StudioTaggerList
studios={studios}
selectedEndpoint={{
endpoint: selectedEndpoint.endpoint,
index: selectedEndpointIndex,
}}
isIdle={batchJobID === undefined}
config={config}
stashBoxes={stashConfig?.general.stashBoxes}
onBatchAdd={batchAdd}
onBatchUpdate={batchUpdate}
/>
</>
) : (
<div className="my-4">
<h3 className="text-center mt-4">
<FormattedMessage id="studio_tagger.to_use_the_studio_tagger" />
</h3>
<h5 className="text-center">
Please see{" "}
<HashLink
to="/settings?tab=metadata-providers#stash-boxes"
scroll={(el) =>
el.scrollIntoView({ behavior: "smooth", block: "center" })
}
>
Settings.
</HashLink>
</h5>
</div>
)}
</div>
</>
);
};

View file

@ -227,6 +227,128 @@
}
}
.studio-create-modal {
font-size: 1.2rem;
max-width: 800px;
.image-selection {
text-align: center;
.studio-image {
height: 85%;
position: relative;
&-exclude {
position: absolute;
right: 20px;
top: 10px;
}
}
img {
max-height: 100%;
max-width: 100%;
}
}
.LoadingIndicator {
height: 100%;
}
&-field {
margin-bottom: 5px;
.btn {
margin-right: 5px;
}
.fa-icon {
width: 12px;
}
}
}
.StudioTagger {
display: flex;
flex-wrap: wrap;
justify-content: center;
max-width: 1600px;
&-header {
color: white;
&:hover {
color: white;
}
}
&-studio {
background-color: #495b68;
border-radius: 3px;
display: flex;
margin: 1rem;
max-width: 100%;
padding: 1rem;
.studio-card {
box-shadow: none;
flex-shrink: 0;
margin: 0;
padding: 0;
img {
background-color: #495b68;
max-height: 150px;
object-fit: contain;
vertical-align: middle;
width: 100%;
}
}
}
&-details {
//flex-grow: 1;
display: flex;
flex-direction: column;
justify-content: space-between;
margin: 0.5rem;
width: 24rem;
}
&-details-image {
vertical-align: bottom;
}
&-details-text {
vertical-align: bottom;
}
&-studio-search {
display: flex;
flex-wrap: wrap;
&-item {
align-items: center;
display: flex;
overflow: hidden;
text-align: left;
}
}
&-thumb {
height: 40px;
margin-right: 10px;
}
&-box-link {
margin-bottom: 5px;
.input-group-text {
font-family: monospace;
}
}
}
.FieldSelect {
.fa-icon {
width: 12px;

View file

@ -19,7 +19,10 @@ export const getClient = () => client;
// Evicts cached results for the given queries.
// Will also call a cache GC afterwards.
function evictQueries(cache: ApolloCache<unknown>, queries: DocumentNode[]) {
export function evictQueries(
cache: ApolloCache<unknown>,
queries: DocumentNode[]
) {
const fields: Modifiers = {};
for (const query of queries) {
const { selections } = getQueryDefinition(query).selectionSet;
@ -111,7 +114,7 @@ function deleteObject(
/// Object queries
export const useFindScene = (id: string) => {
const skip = id === "new";
const skip = id === "new" || id === "";
return GQL.useFindSceneQuery({ variables: { id }, skip });
};
@ -172,7 +175,7 @@ export const queryFindImages = (filter: ListFilterModel) =>
});
export const useFindMovie = (id: string) => {
const skip = id === "new";
const skip = id === "new" || id === "";
return GQL.useFindMovieQuery({ variables: { id }, skip });
};
@ -217,7 +220,7 @@ export const queryFindSceneMarkers = (filter: ListFilterModel) =>
export const useMarkerStrings = () => GQL.useMarkerStringsQuery();
export const useFindGallery = (id: string) => {
const skip = id === "new";
const skip = id === "new" || id === "";
return GQL.useFindGalleryQuery({ variables: { id }, skip });
};
@ -240,7 +243,7 @@ export const queryFindGalleries = (filter: ListFilterModel) =>
});
export const useFindPerformer = (id: string) => {
const skip = id === "new";
const skip = id === "new" || id === "";
return GQL.useFindPerformerQuery({ variables: { id }, skip });
};
@ -272,7 +275,7 @@ export const useAllPerformersForFilter = () =>
GQL.useAllPerformersForFilterQuery();
export const useFindStudio = (id: string) => {
const skip = id === "new";
const skip = id === "new" || id === "";
return GQL.useFindStudioQuery({ variables: { id }, skip });
};
@ -303,7 +306,7 @@ export const queryFindStudios = (filter: ListFilterModel) =>
export const useAllStudiosForFilter = () => GQL.useAllStudiosForFilterQuery();
export const useFindTag = (id: string) => {
const skip = id === "new";
const skip = id === "new" || id === "";
return GQL.useFindTagQuery({ variables: { id }, skip });
};
@ -1475,7 +1478,7 @@ const studioMutationImpactedTypeFields = {
Studio: ["child_studios"],
};
const studioMutationImpactedQueries = [
export const studioMutationImpactedQueries = [
GQL.FindScenesDocument, // filter by studio
GQL.FindImagesDocument, // filter by studio
GQL.FindMoviesDocument, // filter by studio
@ -1868,16 +1871,42 @@ export const stashBoxPerformerQuery = (
query: searchVal,
},
},
fetchPolicy: "network-only",
});
export const stashBoxStudioQuery = (
query: string | null,
stashBoxIndex: number
) =>
client.query<GQL.ScrapeSingleStudioQuery>({
query: GQL.ScrapeSingleStudioDocument,
variables: {
source: {
stash_box_index: stashBoxIndex,
},
input: {
query: query,
},
},
fetchPolicy: "network-only",
});
export const mutateStashBoxBatchPerformerTag = (
input: GQL.StashBoxBatchPerformerTagInput
input: GQL.StashBoxBatchTagInput
) =>
client.mutate<GQL.StashBoxBatchPerformerTagMutation>({
mutation: GQL.StashBoxBatchPerformerTagDocument,
variables: { input },
});
export const mutateStashBoxBatchStudioTag = (
input: GQL.StashBoxBatchTagInput
) =>
client.mutate<GQL.StashBoxBatchStudioTagMutation>({
mutation: GQL.StashBoxBatchStudioTagDocument,
variables: { input },
});
export const useListMovieScrapers = () => GQL.useListMovieScrapersQuery();
export const queryScrapeMovieURL = (url: string) =>

View file

@ -8,6 +8,7 @@
"allow_temporarily": "Allow temporarily",
"anonymise": "Anonymise",
"apply": "Apply",
"assign_stashid_to_parent_studio": "Assign Stash ID to existing parent studio and update metadata",
"auto_tag": "Auto Tag",
"backup": "Backup",
"browse_for_image": "Browse for image…",
@ -24,6 +25,7 @@
"create_chapters": "Create Chapter",
"create_entity": "Create {entityType}",
"create_marker": "Create Marker",
"create_parent_studio": "Create parent studio",
"created_entity": "Created {entity_type}: {entity_name}",
"customise": "Customise",
"delete": "Delete",
@ -1040,6 +1042,7 @@
"previous": "Previous"
},
"parent_of": "Parent of {children}",
"parent_studio": "Parent Studio",
"parent_studios": "Parent Studios",
"parent_tag_count": "Parent Tag Count",
"parent_tags": "Parent Tags",
@ -1221,6 +1224,7 @@
"stashbox": {
"go_review_draft": "Go to {endpoint_name} to review draft.",
"selected_stash_box": "Selected Stash-Box endpoint",
"source": "Stash-Box Source",
"submission_failed": "Submission failed",
"submission_successful": "Submission successful",
"submit_update": "Already exists in {endpoint_name}"
@ -1237,7 +1241,46 @@
},
"status": "Status: {statusText}",
"studio": "Studio",
"studio_and_parent": "Studio & Parent",
"studio_depth": "Levels (empty for all)",
"studio_tagger": {
"add_new_studios": "Add New Studios",
"any_names_entered_will_be_queried": "Any names entered will be queried from the remote Stash-Box instance and added if found. Only exact matches will be considered a match.",
"batch_add_studios": "Batch Add Studios",
"batch_update_studios": "Batch Update Studios",
"config": {
"active_stash-box_instance": "Active stash-box instance:",
"create_parent_desc": "Create missing parent studios, or tag and update data/image for existing parent studios with exact name matches",
"create_parent_label": "Create parent studios",
"edit_excluded_fields": "Edit Excluded Fields",
"excluded_fields": "Excluded fields:",
"no_fields_are_excluded": "No fields are excluded",
"no_instances_found": "No instances found",
"these_fields_will_not_be_changed_when_updating_studios": "These fields will not be changed when updating studios."
},
"create_or_tag_parent_studios": "Create missing or tag existing parent studios",
"current_page": "Current page",
"failed_to_save_studio": "Failed to save studio \"{studio}\"",
"name_already_exists": "Name already exists",
"network_error": "Network Error",
"no_results_found": "No results found.",
"number_of_studios_will_be_processed": "{studio_count} studios will be processed",
"studio_already_tagged": "Studio already tagged",
"studio_names_separated_by_comma": "Studio names separated by comma",
"studio_selection": "Studio selection",
"studio_successfully_tagged": "Studio successfully tagged",
"query_all_studios_in_the_database": "All studios in the database",
"refresh_tagged_studios": "Refresh tagged studios",
"refreshing_will_update_the_data": "Refreshing will update the data of any tagged studios from the stash-box instance.",
"status_tagging_job_queued": "Status: Tagging job queued",
"status_tagging_studios": "Status: Tagging studios",
"tag_status": "Tag Status",
"to_use_the_studio_tagger": "To use the studio tagger a stash-box instance needs to be configured.",
"untagged_studios": "Untagged studios",
"update_studio": "Update Studio",
"update_studios": "Update Studios",
"updating_untagged_studios_description": "Updating untagged studios will try to match any studios that lack a stashid and update the metadata."
},
"studios": "Studios",
"sub_tag_count": "Sub-Tag Count",
"sub_tag_of": "Sub-tag of {parent}",

View file

@ -30,7 +30,7 @@ const sortByOptions = ["name", "random", "rating"]
},
]);
const displayModeOptions = [DisplayMode.Grid];
const displayModeOptions = [DisplayMode.Grid, DisplayMode.Tagger];
const criterionOptions = [
createMandatoryStringCriterionOption("name"),
createStringCriterionOption("details"),

View file

@ -30,3 +30,15 @@ export function withoutTypename<T extends ITypename>(
{} as Omit<T, "__typename">
);
}
// excludeFields removes fields from data that are in the excluded object
export function excludeFields(
data: { [index: string]: unknown },
excluded: Record<string, boolean>
) {
Object.keys(data).forEach((k) => {
if (excluded[k] || !data[k]) {
data[k] = undefined;
}
});
}

15
vendor/github.com/gofrs/uuid/.gitignore generated vendored Normal file
View file

@ -0,0 +1,15 @@
# Binaries for programs and plugins
*.exe
*.exe~
*.dll
*.so
*.dylib
# Test binary, build with `go test -c`
*.test
# Output of the go coverage tool, specifically when used with LiteIDE
*.out
# binary bundle generated by go-fuzz
uuid-fuzz.zip

20
vendor/github.com/gofrs/uuid/LICENSE generated vendored Normal file
View file

@ -0,0 +1,20 @@
Copyright (C) 2013-2018 by Maxim Bublis <b@codemonkey.ru>
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

117
vendor/github.com/gofrs/uuid/README.md generated vendored Normal file
View file

@ -0,0 +1,117 @@
# UUID
[![License](https://img.shields.io/github/license/gofrs/uuid.svg)](https://github.com/gofrs/uuid/blob/master/LICENSE)
[![Build Status](https://travis-ci.org/gofrs/uuid.svg?branch=master)](https://travis-ci.org/gofrs/uuid)
[![GoDoc](http://godoc.org/github.com/gofrs/uuid?status.svg)](http://godoc.org/github.com/gofrs/uuid)
[![Coverage Status](https://codecov.io/gh/gofrs/uuid/branch/master/graphs/badge.svg?branch=master)](https://codecov.io/gh/gofrs/uuid/)
[![Go Report Card](https://goreportcard.com/badge/github.com/gofrs/uuid)](https://goreportcard.com/report/github.com/gofrs/uuid)
Package uuid provides a pure Go implementation of Universally Unique Identifiers
(UUID) variant as defined in RFC-4122. This package supports both the creation
and parsing of UUIDs in different formats.
This package supports the following UUID versions:
* Version 1, based on timestamp and MAC address (RFC-4122)
* Version 3, based on MD5 hashing of a named value (RFC-4122)
* Version 4, based on random numbers (RFC-4122)
* Version 5, based on SHA-1 hashing of a named value (RFC-4122)
This package also supports experimental Universally Unique Identifier implementations based on a
[draft RFC](https://www.ietf.org/archive/id/draft-peabody-dispatch-new-uuid-format-04.html) that updates RFC-4122
* Version 6, a k-sortable id based on timestamp, and field-compatible with v1 (draft-peabody-dispatch-new-uuid-format, RFC-4122)
* Version 7, a k-sortable id based on timestamp (draft-peabody-dispatch-new-uuid-format, RFC-4122)
The v6 and v7 IDs are **not** considered a part of the stable API, and may be subject to behavior or API changes as part of minor releases
to this package. They will be updated as the draft RFC changes, and will become stable if and when the draft RFC is accepted.
## Project History
This project was originally forked from the
[github.com/satori/go.uuid](https://github.com/satori/go.uuid) repository after
it appeared to be no longer maintained, while exhibiting [critical
flaws](https://github.com/satori/go.uuid/issues/73). We have decided to take
over this project to ensure it receives regular maintenance for the benefit of
the larger Go community.
We'd like to thank Maxim Bublis for his hard work on the original iteration of
the package.
## License
This source code of this package is released under the MIT License. Please see
the [LICENSE](https://github.com/gofrs/uuid/blob/master/LICENSE) for the full
content of the license.
## Recommended Package Version
We recommend using v2.0.0+ of this package, as versions prior to 2.0.0 were
created before our fork of the original package and have some known
deficiencies.
## Installation
It is recommended to use a package manager like `dep` that understands tagged
releases of a package, as well as semantic versioning.
If you are unable to make use of a dependency manager with your project, you can
use the `go get` command to download it directly:
```Shell
$ go get github.com/gofrs/uuid
```
## Requirements
Due to subtests not being supported in older versions of Go, this package is
only regularly tested against Go 1.7+. This package may work perfectly fine with
Go 1.2+, but support for these older versions is not actively maintained.
## Go 1.11 Modules
As of v3.2.0, this repository no longer adopts Go modules, and v3.2.0 no longer has a `go.mod` file. As a result, v3.2.0 also drops support for the `github.com/gofrs/uuid/v3` import path. Only module-based consumers are impacted. With the v3.2.0 release, _all_ gofrs/uuid consumers should use the `github.com/gofrs/uuid` import path.
An existing module-based consumer will continue to be able to build using the `github.com/gofrs/uuid/v3` import path using any valid consumer `go.mod` that worked prior to the publishing of v3.2.0, but any module-based consumer should start using the `github.com/gofrs/uuid` import path when possible and _must_ use the `github.com/gofrs/uuid` import path prior to upgrading to v3.2.0.
Please refer to [Issue #61](https://github.com/gofrs/uuid/issues/61) and [Issue #66](https://github.com/gofrs/uuid/issues/66) for more details.
## Usage
Here is a quick overview of how to use this package. For more detailed
documentation, please see the [GoDoc Page](http://godoc.org/github.com/gofrs/uuid).
```go
package main
import (
"log"
"github.com/gofrs/uuid"
)
// Create a Version 4 UUID, panicking on error.
// Use this form to initialize package-level variables.
var u1 = uuid.Must(uuid.NewV4())
func main() {
// Create a Version 4 UUID.
u2, err := uuid.NewV4()
if err != nil {
log.Fatalf("failed to generate UUID: %v", err)
}
log.Printf("generated Version 4 UUID %v", u2)
// Parse a UUID from a string.
s := "6ba7b810-9dad-11d1-80b4-00c04fd430c8"
u3, err := uuid.FromString(s)
if err != nil {
log.Fatalf("failed to parse UUID %q: %v", s, err)
}
log.Printf("successfully parsed UUID %v", u3)
}
```
## References
* [RFC-4122](https://tools.ietf.org/html/rfc4122)
* [DCE 1.1: Authentication and Security Services](http://pubs.opengroup.org/onlinepubs/9696989899/chap5.htm#tagcjh_08_02_01_01)
* [New UUID Formats RFC Draft (Peabody) Rev 04](https://www.ietf.org/archive/id/draft-peabody-dispatch-new-uuid-format-04.html#)

234
vendor/github.com/gofrs/uuid/codec.go generated vendored Normal file
View file

@ -0,0 +1,234 @@
// Copyright (C) 2013-2018 by Maxim Bublis <b@codemonkey.ru>
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the
// "Software"), to deal in the Software without restriction, including
// without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to
// permit persons to whom the Software is furnished to do so, subject to
// the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package uuid
import (
"errors"
"fmt"
)
// FromBytes returns a UUID generated from the raw byte slice input.
// It will return an error if the slice isn't 16 bytes long.
func FromBytes(input []byte) (UUID, error) {
u := UUID{}
err := u.UnmarshalBinary(input)
return u, err
}
// FromBytesOrNil returns a UUID generated from the raw byte slice input.
// Same behavior as FromBytes(), but returns uuid.Nil instead of an error.
func FromBytesOrNil(input []byte) UUID {
uuid, err := FromBytes(input)
if err != nil {
return Nil
}
return uuid
}
var errInvalidFormat = errors.New("uuid: invalid UUID format")
func fromHexChar(c byte) byte {
switch {
case '0' <= c && c <= '9':
return c - '0'
case 'a' <= c && c <= 'f':
return c - 'a' + 10
case 'A' <= c && c <= 'F':
return c - 'A' + 10
}
return 255
}
// Parse parses the UUID stored in the string text. Parsing and supported
// formats are the same as UnmarshalText.
func (u *UUID) Parse(s string) error {
switch len(s) {
case 32: // hash
case 36: // canonical
case 34, 38:
if s[0] != '{' || s[len(s)-1] != '}' {
return fmt.Errorf("uuid: incorrect UUID format in string %q", s)
}
s = s[1 : len(s)-1]
case 41, 45:
if s[:9] != "urn:uuid:" {
return fmt.Errorf("uuid: incorrect UUID format in string %q", s[:9])
}
s = s[9:]
default:
return fmt.Errorf("uuid: incorrect UUID length %d in string %q", len(s), s)
}
// canonical
if len(s) == 36 {
if s[8] != '-' || s[13] != '-' || s[18] != '-' || s[23] != '-' {
return fmt.Errorf("uuid: incorrect UUID format in string %q", s)
}
for i, x := range [16]byte{
0, 2, 4, 6,
9, 11,
14, 16,
19, 21,
24, 26, 28, 30, 32, 34,
} {
v1 := fromHexChar(s[x])
v2 := fromHexChar(s[x+1])
if v1|v2 == 255 {
return errInvalidFormat
}
u[i] = (v1 << 4) | v2
}
return nil
}
// hash like
for i := 0; i < 32; i += 2 {
v1 := fromHexChar(s[i])
v2 := fromHexChar(s[i+1])
if v1|v2 == 255 {
return errInvalidFormat
}
u[i/2] = (v1 << 4) | v2
}
return nil
}
// FromString returns a UUID parsed from the input string.
// Input is expected in a form accepted by UnmarshalText.
func FromString(text string) (UUID, error) {
var u UUID
err := u.Parse(text)
return u, err
}
// FromStringOrNil returns a UUID parsed from the input string.
// Same behavior as FromString(), but returns uuid.Nil instead of an error.
func FromStringOrNil(input string) UUID {
uuid, err := FromString(input)
if err != nil {
return Nil
}
return uuid
}
// MarshalText implements the encoding.TextMarshaler interface.
// The encoding is the same as returned by the String() method.
func (u UUID) MarshalText() ([]byte, error) {
var buf [36]byte
encodeCanonical(buf[:], u)
return buf[:], nil
}
// UnmarshalText implements the encoding.TextUnmarshaler interface.
// Following formats are supported:
//
// "6ba7b810-9dad-11d1-80b4-00c04fd430c8",
// "{6ba7b810-9dad-11d1-80b4-00c04fd430c8}",
// "urn:uuid:6ba7b810-9dad-11d1-80b4-00c04fd430c8"
// "6ba7b8109dad11d180b400c04fd430c8"
// "{6ba7b8109dad11d180b400c04fd430c8}",
// "urn:uuid:6ba7b8109dad11d180b400c04fd430c8"
//
// ABNF for supported UUID text representation follows:
//
// URN := 'urn'
// UUID-NID := 'uuid'
//
// hexdig := '0' | '1' | '2' | '3' | '4' | '5' | '6' | '7' | '8' | '9' |
// 'a' | 'b' | 'c' | 'd' | 'e' | 'f' |
// 'A' | 'B' | 'C' | 'D' | 'E' | 'F'
//
// hexoct := hexdig hexdig
// 2hexoct := hexoct hexoct
// 4hexoct := 2hexoct 2hexoct
// 6hexoct := 4hexoct 2hexoct
// 12hexoct := 6hexoct 6hexoct
//
// hashlike := 12hexoct
// canonical := 4hexoct '-' 2hexoct '-' 2hexoct '-' 6hexoct
//
// plain := canonical | hashlike
// uuid := canonical | hashlike | braced | urn
//
// braced := '{' plain '}' | '{' hashlike '}'
// urn := URN ':' UUID-NID ':' plain
func (u *UUID) UnmarshalText(b []byte) error {
switch len(b) {
case 32: // hash
case 36: // canonical
case 34, 38:
if b[0] != '{' || b[len(b)-1] != '}' {
return fmt.Errorf("uuid: incorrect UUID format in string %q", b)
}
b = b[1 : len(b)-1]
case 41, 45:
if string(b[:9]) != "urn:uuid:" {
return fmt.Errorf("uuid: incorrect UUID format in string %q", b[:9])
}
b = b[9:]
default:
return fmt.Errorf("uuid: incorrect UUID length %d in string %q", len(b), b)
}
if len(b) == 36 {
if b[8] != '-' || b[13] != '-' || b[18] != '-' || b[23] != '-' {
return fmt.Errorf("uuid: incorrect UUID format in string %q", b)
}
for i, x := range [16]byte{
0, 2, 4, 6,
9, 11,
14, 16,
19, 21,
24, 26, 28, 30, 32, 34,
} {
v1 := fromHexChar(b[x])
v2 := fromHexChar(b[x+1])
if v1|v2 == 255 {
return errInvalidFormat
}
u[i] = (v1 << 4) | v2
}
return nil
}
for i := 0; i < 32; i += 2 {
v1 := fromHexChar(b[i])
v2 := fromHexChar(b[i+1])
if v1|v2 == 255 {
return errInvalidFormat
}
u[i/2] = (v1 << 4) | v2
}
return nil
}
// MarshalBinary implements the encoding.BinaryMarshaler interface.
func (u UUID) MarshalBinary() ([]byte, error) {
return u.Bytes(), nil
}
// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface.
// It will return an error if the slice isn't 16 bytes long.
func (u *UUID) UnmarshalBinary(data []byte) error {
if len(data) != Size {
return fmt.Errorf("uuid: UUID must be exactly 16 bytes long, got %d bytes", len(data))
}
copy(u[:], data)
return nil
}

48
vendor/github.com/gofrs/uuid/fuzz.go generated vendored Normal file
View file

@ -0,0 +1,48 @@
// Copyright (c) 2018 Andrei Tudor Călin <mail@acln.ro>
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the
// "Software"), to deal in the Software without restriction, including
// without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to
// permit persons to whom the Software is furnished to do so, subject to
// the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
//go:build gofuzz
// +build gofuzz
package uuid
// Fuzz implements a simple fuzz test for FromString / UnmarshalText.
//
// To run:
//
// $ go get github.com/dvyukov/go-fuzz/...
// $ cd $GOPATH/src/github.com/gofrs/uuid
// $ go-fuzz-build github.com/gofrs/uuid
// $ go-fuzz -bin=uuid-fuzz.zip -workdir=./testdata
//
// If you make significant changes to FromString / UnmarshalText and add
// new cases to fromStringTests (in codec_test.go), please run
//
// $ go test -seed_fuzz_corpus
//
// to seed the corpus with the new interesting inputs, then run the fuzzer.
func Fuzz(data []byte) int {
_, err := FromString(string(data))
if err != nil {
return 0
}
return 1
}

456
vendor/github.com/gofrs/uuid/generator.go generated vendored Normal file
View file

@ -0,0 +1,456 @@
// Copyright (C) 2013-2018 by Maxim Bublis <b@codemonkey.ru>
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the
// "Software"), to deal in the Software without restriction, including
// without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to
// permit persons to whom the Software is furnished to do so, subject to
// the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package uuid
import (
"crypto/md5"
"crypto/rand"
"crypto/sha1"
"encoding/binary"
"fmt"
"hash"
"io"
"net"
"sync"
"time"
)
// Difference in 100-nanosecond intervals between
// UUID epoch (October 15, 1582) and Unix epoch (January 1, 1970).
const epochStart = 122192928000000000
// EpochFunc is the function type used to provide the current time.
type EpochFunc func() time.Time
// HWAddrFunc is the function type used to provide hardware (MAC) addresses.
type HWAddrFunc func() (net.HardwareAddr, error)
// DefaultGenerator is the default UUID Generator used by this package.
var DefaultGenerator Generator = NewGen()
// NewV1 returns a UUID based on the current timestamp and MAC address.
func NewV1() (UUID, error) {
return DefaultGenerator.NewV1()
}
// NewV3 returns a UUID based on the MD5 hash of the namespace UUID and name.
func NewV3(ns UUID, name string) UUID {
return DefaultGenerator.NewV3(ns, name)
}
// NewV4 returns a randomly generated UUID.
func NewV4() (UUID, error) {
return DefaultGenerator.NewV4()
}
// NewV5 returns a UUID based on SHA-1 hash of the namespace UUID and name.
func NewV5(ns UUID, name string) UUID {
return DefaultGenerator.NewV5(ns, name)
}
// NewV6 returns a k-sortable UUID based on a timestamp and 48 bits of
// pseudorandom data. The timestamp in a V6 UUID is the same as V1, with the bit
// order being adjusted to allow the UUID to be k-sortable.
//
// This is implemented based on revision 03 of the Peabody UUID draft, and may
// be subject to change pending further revisions. Until the final specification
// revision is finished, changes required to implement updates to the spec will
// not be considered a breaking change. They will happen as a minor version
// releases until the spec is final.
func NewV6() (UUID, error) {
return DefaultGenerator.NewV6()
}
// NewV7 returns a k-sortable UUID based on the current millisecond precision
// UNIX epoch and 74 bits of pseudorandom data. It supports single-node batch generation (multiple UUIDs in the same timestamp) with a Monotonic Random counter.
//
// This is implemented based on revision 04 of the Peabody UUID draft, and may
// be subject to change pending further revisions. Until the final specification
// revision is finished, changes required to implement updates to the spec will
// not be considered a breaking change. They will happen as a minor version
// releases until the spec is final.
func NewV7() (UUID, error) {
return DefaultGenerator.NewV7()
}
// Generator provides an interface for generating UUIDs.
type Generator interface {
NewV1() (UUID, error)
NewV3(ns UUID, name string) UUID
NewV4() (UUID, error)
NewV5(ns UUID, name string) UUID
NewV6() (UUID, error)
NewV7() (UUID, error)
}
// Gen is a reference UUID generator based on the specifications laid out in
// RFC-4122 and DCE 1.1: Authentication and Security Services. This type
// satisfies the Generator interface as defined in this package.
//
// For consumers who are generating V1 UUIDs, but don't want to expose the MAC
// address of the node generating the UUIDs, the NewGenWithHWAF() function has been
// provided as a convenience. See the function's documentation for more info.
//
// The authors of this package do not feel that the majority of users will need
// to obfuscate their MAC address, and so we recommend using NewGen() to create
// a new generator.
type Gen struct {
clockSequenceOnce sync.Once
hardwareAddrOnce sync.Once
storageMutex sync.Mutex
rand io.Reader
epochFunc EpochFunc
hwAddrFunc HWAddrFunc
lastTime uint64
clockSequence uint16
hardwareAddr [6]byte
}
// GenOption is a function type that can be used to configure a Gen generator.
type GenOption func(*Gen)
// interface check -- build will fail if *Gen doesn't satisfy Generator
var _ Generator = (*Gen)(nil)
// NewGen returns a new instance of Gen with some default values set. Most
// people should use this.
func NewGen() *Gen {
return NewGenWithHWAF(defaultHWAddrFunc)
}
// NewGenWithHWAF builds a new UUID generator with the HWAddrFunc provided. Most
// consumers should use NewGen() instead.
//
// This is used so that consumers can generate their own MAC addresses, for use
// in the generated UUIDs, if there is some concern about exposing the physical
// address of the machine generating the UUID.
//
// The Gen generator will only invoke the HWAddrFunc once, and cache that MAC
// address for all the future UUIDs generated by it. If you'd like to switch the
// MAC address being used, you'll need to create a new generator using this
// function.
func NewGenWithHWAF(hwaf HWAddrFunc) *Gen {
return NewGenWithOptions(WithHWAddrFunc(hwaf))
}
// NewGenWithOptions returns a new instance of Gen with the options provided.
// Most people should use NewGen() or NewGenWithHWAF() instead.
//
// To customize the generator, you can pass in one or more GenOption functions.
// For example:
//
// gen := NewGenWithOptions(
// WithHWAddrFunc(myHWAddrFunc),
// WithEpochFunc(myEpochFunc),
// WithRandomReader(myRandomReader),
// )
//
// NewGenWithOptions(WithHWAddrFunc(myHWAddrFunc)) is equivalent to calling
// NewGenWithHWAF(myHWAddrFunc)
// NewGenWithOptions() is equivalent to calling NewGen()
func NewGenWithOptions(opts ...GenOption) *Gen {
gen := &Gen{
epochFunc: time.Now,
hwAddrFunc: defaultHWAddrFunc,
rand: rand.Reader,
}
for _, opt := range opts {
opt(gen)
}
return gen
}
// WithHWAddrFunc is a GenOption that allows you to provide your own HWAddrFunc
// function.
// When this option is nil, the defaultHWAddrFunc is used.
func WithHWAddrFunc(hwaf HWAddrFunc) GenOption {
return func(gen *Gen) {
if hwaf == nil {
hwaf = defaultHWAddrFunc
}
gen.hwAddrFunc = hwaf
}
}
// WithEpochFunc is a GenOption that allows you to provide your own EpochFunc
// function.
// When this option is nil, time.Now is used.
func WithEpochFunc(epochf EpochFunc) GenOption {
return func(gen *Gen) {
if epochf == nil {
epochf = time.Now
}
gen.epochFunc = epochf
}
}
// WithRandomReader is a GenOption that allows you to provide your own random
// reader.
// When this option is nil, the default rand.Reader is used.
func WithRandomReader(reader io.Reader) GenOption {
return func(gen *Gen) {
if reader == nil {
reader = rand.Reader
}
gen.rand = reader
}
}
// NewV1 returns a UUID based on the current timestamp and MAC address.
func (g *Gen) NewV1() (UUID, error) {
u := UUID{}
timeNow, clockSeq, err := g.getClockSequence(false)
if err != nil {
return Nil, err
}
binary.BigEndian.PutUint32(u[0:], uint32(timeNow))
binary.BigEndian.PutUint16(u[4:], uint16(timeNow>>32))
binary.BigEndian.PutUint16(u[6:], uint16(timeNow>>48))
binary.BigEndian.PutUint16(u[8:], clockSeq)
hardwareAddr, err := g.getHardwareAddr()
if err != nil {
return Nil, err
}
copy(u[10:], hardwareAddr)
u.SetVersion(V1)
u.SetVariant(VariantRFC4122)
return u, nil
}
// NewV3 returns a UUID based on the MD5 hash of the namespace UUID and name.
func (g *Gen) NewV3(ns UUID, name string) UUID {
u := newFromHash(md5.New(), ns, name)
u.SetVersion(V3)
u.SetVariant(VariantRFC4122)
return u
}
// NewV4 returns a randomly generated UUID.
func (g *Gen) NewV4() (UUID, error) {
u := UUID{}
if _, err := io.ReadFull(g.rand, u[:]); err != nil {
return Nil, err
}
u.SetVersion(V4)
u.SetVariant(VariantRFC4122)
return u, nil
}
// NewV5 returns a UUID based on SHA-1 hash of the namespace UUID and name.
func (g *Gen) NewV5(ns UUID, name string) UUID {
u := newFromHash(sha1.New(), ns, name)
u.SetVersion(V5)
u.SetVariant(VariantRFC4122)
return u
}
// NewV6 returns a k-sortable UUID based on a timestamp and 48 bits of
// pseudorandom data. The timestamp in a V6 UUID is the same as V1, with the bit
// order being adjusted to allow the UUID to be k-sortable.
//
// This is implemented based on revision 03 of the Peabody UUID draft, and may
// be subject to change pending further revisions. Until the final specification
// revision is finished, changes required to implement updates to the spec will
// not be considered a breaking change. They will happen as a minor version
// releases until the spec is final.
func (g *Gen) NewV6() (UUID, error) {
var u UUID
if _, err := io.ReadFull(g.rand, u[10:]); err != nil {
return Nil, err
}
timeNow, clockSeq, err := g.getClockSequence(false)
if err != nil {
return Nil, err
}
binary.BigEndian.PutUint32(u[0:], uint32(timeNow>>28)) // set time_high
binary.BigEndian.PutUint16(u[4:], uint16(timeNow>>12)) // set time_mid
binary.BigEndian.PutUint16(u[6:], uint16(timeNow&0xfff)) // set time_low (minus four version bits)
binary.BigEndian.PutUint16(u[8:], clockSeq&0x3fff) // set clk_seq_hi_res (minus two variant bits)
u.SetVersion(V6)
u.SetVariant(VariantRFC4122)
return u, nil
}
// getClockSequence returns the epoch and clock sequence for V1,V6 and V7 UUIDs.
//
// When useUnixTSMs is false, it uses the Coordinated Universal Time (UTC) as a count of 100-
//
// nanosecond intervals since 00:00:00.00, 15 October 1582 (the date of Gregorian reform to the Christian calendar).
func (g *Gen) getClockSequence(useUnixTSMs bool) (uint64, uint16, error) {
var err error
g.clockSequenceOnce.Do(func() {
buf := make([]byte, 2)
if _, err = io.ReadFull(g.rand, buf); err != nil {
return
}
g.clockSequence = binary.BigEndian.Uint16(buf)
})
if err != nil {
return 0, 0, err
}
g.storageMutex.Lock()
defer g.storageMutex.Unlock()
var timeNow uint64
if useUnixTSMs {
timeNow = uint64(g.epochFunc().UnixMilli())
} else {
timeNow = g.getEpoch()
}
// Clock didn't change since last UUID generation.
// Should increase clock sequence.
if timeNow <= g.lastTime {
g.clockSequence++
}
g.lastTime = timeNow
return timeNow, g.clockSequence, nil
}
// NewV7 returns a k-sortable UUID based on the current millisecond precision
// UNIX epoch and 74 bits of pseudorandom data.
//
// This is implemented based on revision 04 of the Peabody UUID draft, and may
// be subject to change pending further revisions. Until the final specification
// revision is finished, changes required to implement updates to the spec will
// not be considered a breaking change. They will happen as a minor version
// releases until the spec is final.
func (g *Gen) NewV7() (UUID, error) {
var u UUID
/* https://www.ietf.org/archive/id/draft-peabody-dispatch-new-uuid-format-04.html#name-uuid-version-7
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| unix_ts_ms |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| unix_ts_ms | ver | rand_a |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|var| rand_b |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| rand_b |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ */
ms, clockSeq, err := g.getClockSequence(true)
if err != nil {
return Nil, err
}
//UUIDv7 features a 48 bit timestamp. First 32bit (4bytes) represents seconds since 1970, followed by 2 bytes for the ms granularity.
u[0] = byte(ms >> 40) //1-6 bytes: big-endian unsigned number of Unix epoch timestamp
u[1] = byte(ms >> 32)
u[2] = byte(ms >> 24)
u[3] = byte(ms >> 16)
u[4] = byte(ms >> 8)
u[5] = byte(ms)
//support batching by using a monotonic pseudo-random sequence
//The 6th byte contains the version and partially rand_a data.
//We will lose the most significant bites from the clockSeq (with SetVersion), but it is ok, we need the least significant that contains the counter to ensure the monotonic property
binary.BigEndian.PutUint16(u[6:8], clockSeq) // set rand_a with clock seq which is random and monotonic
//override first 4bits of u[6].
u.SetVersion(V7)
//set rand_b 64bits of pseudo-random bits (first 2 will be overridden)
if _, err = io.ReadFull(g.rand, u[8:16]); err != nil {
return Nil, err
}
//override first 2 bits of byte[8] for the variant
u.SetVariant(VariantRFC4122)
return u, nil
}
// Returns the hardware address.
func (g *Gen) getHardwareAddr() ([]byte, error) {
var err error
g.hardwareAddrOnce.Do(func() {
var hwAddr net.HardwareAddr
if hwAddr, err = g.hwAddrFunc(); err == nil {
copy(g.hardwareAddr[:], hwAddr)
return
}
// Initialize hardwareAddr randomly in case
// of real network interfaces absence.
if _, err = io.ReadFull(g.rand, g.hardwareAddr[:]); err != nil {
return
}
// Set multicast bit as recommended by RFC-4122
g.hardwareAddr[0] |= 0x01
})
if err != nil {
return []byte{}, err
}
return g.hardwareAddr[:], nil
}
// Returns the difference between UUID epoch (October 15, 1582)
// and current time in 100-nanosecond intervals.
func (g *Gen) getEpoch() uint64 {
return epochStart + uint64(g.epochFunc().UnixNano()/100)
}
// Returns the UUID based on the hashing of the namespace UUID and name.
func newFromHash(h hash.Hash, ns UUID, name string) UUID {
u := UUID{}
h.Write(ns[:])
h.Write([]byte(name))
copy(u[:], h.Sum(nil))
return u
}
var netInterfaces = net.Interfaces
// Returns the hardware address.
func defaultHWAddrFunc() (net.HardwareAddr, error) {
ifaces, err := netInterfaces()
if err != nil {
return []byte{}, err
}
for _, iface := range ifaces {
if len(iface.HardwareAddr) >= 6 {
return iface.HardwareAddr, nil
}
}
return []byte{}, fmt.Errorf("uuid: no HW address found")
}

116
vendor/github.com/gofrs/uuid/sql.go generated vendored Normal file
View file

@ -0,0 +1,116 @@
// Copyright (C) 2013-2018 by Maxim Bublis <b@codemonkey.ru>
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the
// "Software"), to deal in the Software without restriction, including
// without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to
// permit persons to whom the Software is furnished to do so, subject to
// the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package uuid
import (
"database/sql"
"database/sql/driver"
"fmt"
)
var _ driver.Valuer = UUID{}
var _ sql.Scanner = (*UUID)(nil)
// Value implements the driver.Valuer interface.
func (u UUID) Value() (driver.Value, error) {
return u.String(), nil
}
// Scan implements the sql.Scanner interface.
// A 16-byte slice will be handled by UnmarshalBinary, while
// a longer byte slice or a string will be handled by UnmarshalText.
func (u *UUID) Scan(src interface{}) error {
switch src := src.(type) {
case UUID: // support gorm convert from UUID to NullUUID
*u = src
return nil
case []byte:
if len(src) == Size {
return u.UnmarshalBinary(src)
}
return u.UnmarshalText(src)
case string:
uu, err := FromString(src)
*u = uu
return err
}
return fmt.Errorf("uuid: cannot convert %T to UUID", src)
}
// NullUUID can be used with the standard sql package to represent a
// UUID value that can be NULL in the database.
type NullUUID struct {
UUID UUID
Valid bool
}
// Value implements the driver.Valuer interface.
func (u NullUUID) Value() (driver.Value, error) {
if !u.Valid {
return nil, nil
}
// Delegate to UUID Value function
return u.UUID.Value()
}
// Scan implements the sql.Scanner interface.
func (u *NullUUID) Scan(src interface{}) error {
if src == nil {
u.UUID, u.Valid = Nil, false
return nil
}
// Delegate to UUID Scan function
u.Valid = true
return u.UUID.Scan(src)
}
var nullJSON = []byte("null")
// MarshalJSON marshals the NullUUID as null or the nested UUID
func (u NullUUID) MarshalJSON() ([]byte, error) {
if !u.Valid {
return nullJSON, nil
}
var buf [38]byte
buf[0] = '"'
encodeCanonical(buf[1:37], u.UUID)
buf[37] = '"'
return buf[:], nil
}
// UnmarshalJSON unmarshals a NullUUID
func (u *NullUUID) UnmarshalJSON(b []byte) error {
if string(b) == "null" {
u.UUID, u.Valid = Nil, false
return nil
}
if n := len(b); n >= 2 && b[0] == '"' {
b = b[1 : n-1]
}
err := u.UUID.UnmarshalText(b)
u.Valid = (err == nil)
return err
}

285
vendor/github.com/gofrs/uuid/uuid.go generated vendored Normal file
View file

@ -0,0 +1,285 @@
// Copyright (C) 2013-2018 by Maxim Bublis <b@codemonkey.ru>
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the
// "Software"), to deal in the Software without restriction, including
// without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to
// permit persons to whom the Software is furnished to do so, subject to
// the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
// Package uuid provides implementations of the Universally Unique Identifier
// (UUID), as specified in RFC-4122 and the Peabody RFC Draft (revision 03).
//
// RFC-4122[1] provides the specification for versions 1, 3, 4, and 5. The
// Peabody UUID RFC Draft[2] provides the specification for the new k-sortable
// UUIDs, versions 6 and 7.
//
// DCE 1.1[3] provides the specification for version 2, but version 2 support
// was removed from this package in v4 due to some concerns with the
// specification itself. Reading the spec, it seems that it would result in
// generating UUIDs that aren't very unique. In having read the spec it seemed
// that our implementation did not meet the spec. It also seems to be at-odds
// with RFC 4122, meaning we would need quite a bit of special code to support
// it. Lastly, there were no Version 2 implementations that we could find to
// ensure we were understanding the specification correctly.
//
// [1] https://tools.ietf.org/html/rfc4122
// [2] https://datatracker.ietf.org/doc/html/draft-peabody-dispatch-new-uuid-format-03
// [3] http://pubs.opengroup.org/onlinepubs/9696989899/chap5.htm#tagcjh_08_02_01_01
package uuid
import (
"encoding/binary"
"encoding/hex"
"fmt"
"time"
)
// Size of a UUID in bytes.
const Size = 16
// UUID is an array type to represent the value of a UUID, as defined in RFC-4122.
type UUID [Size]byte
// UUID versions.
const (
_ byte = iota
V1 // Version 1 (date-time and MAC address)
_ // Version 2 (date-time and MAC address, DCE security version) [removed]
V3 // Version 3 (namespace name-based)
V4 // Version 4 (random)
V5 // Version 5 (namespace name-based)
V6 // Version 6 (k-sortable timestamp and random data, field-compatible with v1) [peabody draft]
V7 // Version 7 (k-sortable timestamp and random data) [peabody draft]
_ // Version 8 (k-sortable timestamp, meant for custom implementations) [peabody draft] [not implemented]
)
// UUID layout variants.
const (
VariantNCS byte = iota
VariantRFC4122
VariantMicrosoft
VariantFuture
)
// UUID DCE domains.
const (
DomainPerson = iota
DomainGroup
DomainOrg
)
// Timestamp is the count of 100-nanosecond intervals since 00:00:00.00,
// 15 October 1582 within a V1 UUID. This type has no meaning for other
// UUID versions since they don't have an embedded timestamp.
type Timestamp uint64
const _100nsPerSecond = 10000000
// Time returns the UTC time.Time representation of a Timestamp
func (t Timestamp) Time() (time.Time, error) {
secs := uint64(t) / _100nsPerSecond
nsecs := 100 * (uint64(t) % _100nsPerSecond)
return time.Unix(int64(secs)-(epochStart/_100nsPerSecond), int64(nsecs)), nil
}
// TimestampFromV1 returns the Timestamp embedded within a V1 UUID.
// Returns an error if the UUID is any version other than 1.
func TimestampFromV1(u UUID) (Timestamp, error) {
if u.Version() != 1 {
err := fmt.Errorf("uuid: %s is version %d, not version 1", u, u.Version())
return 0, err
}
low := binary.BigEndian.Uint32(u[0:4])
mid := binary.BigEndian.Uint16(u[4:6])
hi := binary.BigEndian.Uint16(u[6:8]) & 0xfff
return Timestamp(uint64(low) + (uint64(mid) << 32) + (uint64(hi) << 48)), nil
}
// TimestampFromV6 returns the Timestamp embedded within a V6 UUID. This
// function returns an error if the UUID is any version other than 6.
//
// This is implemented based on revision 03 of the Peabody UUID draft, and may
// be subject to change pending further revisions. Until the final specification
// revision is finished, changes required to implement updates to the spec will
// not be considered a breaking change. They will happen as a minor version
// releases until the spec is final.
func TimestampFromV6(u UUID) (Timestamp, error) {
if u.Version() != 6 {
return 0, fmt.Errorf("uuid: %s is version %d, not version 6", u, u.Version())
}
hi := binary.BigEndian.Uint32(u[0:4])
mid := binary.BigEndian.Uint16(u[4:6])
low := binary.BigEndian.Uint16(u[6:8]) & 0xfff
return Timestamp(uint64(low) + (uint64(mid) << 12) + (uint64(hi) << 28)), nil
}
// Nil is the nil UUID, as specified in RFC-4122, that has all 128 bits set to
// zero.
var Nil = UUID{}
// Predefined namespace UUIDs.
var (
NamespaceDNS = Must(FromString("6ba7b810-9dad-11d1-80b4-00c04fd430c8"))
NamespaceURL = Must(FromString("6ba7b811-9dad-11d1-80b4-00c04fd430c8"))
NamespaceOID = Must(FromString("6ba7b812-9dad-11d1-80b4-00c04fd430c8"))
NamespaceX500 = Must(FromString("6ba7b814-9dad-11d1-80b4-00c04fd430c8"))
)
// IsNil returns if the UUID is equal to the nil UUID
func (u UUID) IsNil() bool {
return u == Nil
}
// Version returns the algorithm version used to generate the UUID.
func (u UUID) Version() byte {
return u[6] >> 4
}
// Variant returns the UUID layout variant.
func (u UUID) Variant() byte {
switch {
case (u[8] >> 7) == 0x00:
return VariantNCS
case (u[8] >> 6) == 0x02:
return VariantRFC4122
case (u[8] >> 5) == 0x06:
return VariantMicrosoft
case (u[8] >> 5) == 0x07:
fallthrough
default:
return VariantFuture
}
}
// Bytes returns a byte slice representation of the UUID.
func (u UUID) Bytes() []byte {
return u[:]
}
// encodeCanonical encodes the canonical RFC-4122 form of UUID u into the
// first 36 bytes dst.
func encodeCanonical(dst []byte, u UUID) {
const hextable = "0123456789abcdef"
dst[8] = '-'
dst[13] = '-'
dst[18] = '-'
dst[23] = '-'
for i, x := range [16]byte{
0, 2, 4, 6,
9, 11,
14, 16,
19, 21,
24, 26, 28, 30, 32, 34,
} {
c := u[i]
dst[x] = hextable[c>>4]
dst[x+1] = hextable[c&0x0f]
}
}
// String returns a canonical RFC-4122 string representation of the UUID:
// xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx.
func (u UUID) String() string {
var buf [36]byte
encodeCanonical(buf[:], u)
return string(buf[:])
}
// Format implements fmt.Formatter for UUID values.
//
// The behavior is as follows:
// The 'x' and 'X' verbs output only the hex digits of the UUID, using a-f for 'x' and A-F for 'X'.
// The 'v', '+v', 's' and 'q' verbs return the canonical RFC-4122 string representation.
// The 'S' verb returns the RFC-4122 format, but with capital hex digits.
// The '#v' verb returns the "Go syntax" representation, which is a 16 byte array initializer.
// All other verbs not handled directly by the fmt package (like '%p') are unsupported and will return
// "%!verb(uuid.UUID=value)" as recommended by the fmt package.
func (u UUID) Format(f fmt.State, c rune) {
if c == 'v' && f.Flag('#') {
fmt.Fprintf(f, "%#v", [Size]byte(u))
return
}
switch c {
case 'x', 'X':
b := make([]byte, 32)
hex.Encode(b, u[:])
if c == 'X' {
toUpperHex(b)
}
_, _ = f.Write(b)
case 'v', 's', 'S':
b, _ := u.MarshalText()
if c == 'S' {
toUpperHex(b)
}
_, _ = f.Write(b)
case 'q':
b := make([]byte, 38)
b[0] = '"'
encodeCanonical(b[1:], u)
b[37] = '"'
_, _ = f.Write(b)
default:
// invalid/unsupported format verb
fmt.Fprintf(f, "%%!%c(uuid.UUID=%s)", c, u.String())
}
}
func toUpperHex(b []byte) {
for i, c := range b {
if 'a' <= c && c <= 'f' {
b[i] = c - ('a' - 'A')
}
}
}
// SetVersion sets the version bits.
func (u *UUID) SetVersion(v byte) {
u[6] = (u[6] & 0x0f) | (v << 4)
}
// SetVariant sets the variant bits.
func (u *UUID) SetVariant(v byte) {
switch v {
case VariantNCS:
u[8] = (u[8]&(0xff>>1) | (0x00 << 7))
case VariantRFC4122:
u[8] = (u[8]&(0xff>>2) | (0x02 << 6))
case VariantMicrosoft:
u[8] = (u[8]&(0xff>>3) | (0x06 << 5))
case VariantFuture:
fallthrough
default:
u[8] = (u[8]&(0xff>>3) | (0x07 << 5))
}
}
// Must is a helper that wraps a call to a function returning (UUID, error)
// and panics if the error is non-nil. It is intended for use in variable
// initializations such as
//
// var packageUUID = uuid.Must(uuid.FromString("123e4567-e89b-12d3-a456-426655440000"))
func Must(u UUID, err error) UUID {
if err != nil {
panic(err)
}
return u
}

3
vendor/modules.txt vendored
View file

@ -176,6 +176,9 @@ github.com/gobwas/pool/pbytes
## explicit; go 1.15
github.com/gobwas/ws
github.com/gobwas/ws/wsutil
# github.com/gofrs/uuid v4.4.0+incompatible
## explicit
github.com/gofrs/uuid
# github.com/golang-jwt/jwt/v4 v4.0.0
## explicit; go 1.15
github.com/golang-jwt/jwt/v4