diff --git a/gqlgen.yml b/gqlgen.yml index b949d44dc..981eb1d21 100644 --- a/gqlgen.yml +++ b/gqlgen.yml @@ -125,6 +125,8 @@ models: model: github.com/stashapp/stash/internal/identify.FieldStrategy ScraperSource: model: github.com/stashapp/stash/pkg/scraper.Source + RoleEnum: + model: github.com/stashapp/stash/pkg/models.RoleEnum IdentifySourceInput: model: github.com/stashapp/stash/internal/identify.Source IdentifyFieldOptionsInput: diff --git a/graphql/schema/schema.graphql b/graphql/schema/schema.graphql index 7fda85b24..79e48cd42 100644 --- a/graphql/schema/schema.graphql +++ b/graphql/schema/schema.graphql @@ -269,7 +269,11 @@ type Query { allStudios: [Studio!]! @deprecated(reason: "Use findStudios instead") allMovies: [Movie!]! @deprecated(reason: "Use findGroups instead") - # Get everything with minimal metadata + users: [User!]! @hasRole(role: ADMIN) + """ + Returns currently authenticated user + """ + me: User # Version version: Version! @@ -279,137 +283,169 @@ type Query { } type Mutation { - setup(input: SetupInput!): Boolean! + setup(input: SetupInput!): Boolean! @hasRole(role: ADMIN) "Migrates the schema to the required version. Returns the job ID" - migrate(input: MigrateInput!): ID! + migrate(input: MigrateInput!): ID! @hasRole(role: ADMIN) "Downloads and installs ffmpeg and ffprobe binaries into the configuration directory. Returns the job ID." - downloadFFMpeg: ID! + downloadFFMpeg: ID! @hasRole(role: ADMIN) - sceneCreate(input: SceneCreateInput!): Scene - sceneUpdate(input: SceneUpdateInput!): Scene - sceneMerge(input: SceneMergeInput!): Scene - bulkSceneUpdate(input: BulkSceneUpdateInput!): [Scene!] - sceneDestroy(input: SceneDestroyInput!): Boolean! - scenesDestroy(input: ScenesDestroyInput!): Boolean! - scenesUpdate(input: [SceneUpdateInput!]!): [Scene] + sceneCreate(input: SceneCreateInput!): Scene @hasRole(role: MODIFY) + sceneUpdate(input: SceneUpdateInput!): Scene @hasRole(role: MODIFY) + sceneMerge(input: SceneMergeInput!): Scene @hasRole(role: MODIFY) + bulkSceneUpdate(input: BulkSceneUpdateInput!): [Scene!] @hasRole(role: MODIFY) + sceneDestroy(input: SceneDestroyInput!): Boolean! @hasRole(role: MODIFY) + scenesDestroy(input: ScenesDestroyInput!): Boolean! @hasRole(role: MODIFY) + scenesUpdate(input: [SceneUpdateInput!]!): [Scene] @hasRole(role: MODIFY) "Increments the o-counter for a scene. Returns the new value" - sceneIncrementO(id: ID!): Int! @deprecated(reason: "Use sceneAddO instead") + sceneIncrementO(id: ID!): Int! + @deprecated(reason: "Use sceneAddO instead") + @hasRole(role: MODIFY) "Decrements the o-counter for a scene. Returns the new value" - sceneDecrementO(id: ID!): Int! @deprecated(reason: "Use sceneRemoveO instead") + sceneDecrementO(id: ID!): Int! + @deprecated(reason: "Use sceneRemoveO instead") + @hasRole(role: MODIFY) "Increments the o-counter for a scene. Uses the current time if none provided." sceneAddO(id: ID!, times: [Timestamp!]): HistoryMutationResult! + @hasRole(role: MODIFY) "Decrements the o-counter for a scene, removing the last recorded time if specific time not provided. Returns the new value" sceneDeleteO(id: ID!, times: [Timestamp!]): HistoryMutationResult! + @hasRole(role: MODIFY) "Resets the o-counter for a scene to 0. Returns the new value" - sceneResetO(id: ID!): Int! + sceneResetO(id: ID!): Int! @hasRole(role: MODIFY) "Sets the resume time point (if provided) and adds the provided duration to the scene's play duration" sceneSaveActivity(id: ID!, resume_time: Float, playDuration: Float): Boolean! + @hasRole(role: MODIFY) "Resets the resume time point and play duration" sceneResetActivity( id: ID! reset_resume: Boolean reset_duration: Boolean - ): Boolean! + ): Boolean! @hasRole(role: MODIFY) "Increments the play count for the scene. Returns the new play count value." sceneIncrementPlayCount(id: ID!): Int! @deprecated(reason: "Use sceneAddPlay instead") + @hasRole(role: MODIFY) "Increments the play count for the scene. Uses the current time if none provided." sceneAddPlay(id: ID!, times: [Timestamp!]): HistoryMutationResult! + @hasRole(role: MODIFY) "Decrements the play count for the scene, removing the specific times or the last recorded time if not provided." sceneDeletePlay(id: ID!, times: [Timestamp!]): HistoryMutationResult! + @hasRole(role: MODIFY) "Resets the play count for a scene to 0. Returns the new play count value." - sceneResetPlayCount(id: ID!): Int! + sceneResetPlayCount(id: ID!): Int! @hasRole(role: MODIFY) "Generates screenshot at specified time in seconds. Leave empty to generate default screenshot" - sceneGenerateScreenshot(id: ID!, at: Float): String! + sceneGenerateScreenshot(id: ID!, at: Float): String! @hasRole(role: ADMIN) sceneMarkerCreate(input: SceneMarkerCreateInput!): SceneMarker + @hasRole(role: MODIFY) sceneMarkerUpdate(input: SceneMarkerUpdateInput!): SceneMarker + @hasRole(role: MODIFY) bulkSceneMarkerUpdate(input: BulkSceneMarkerUpdateInput!): [SceneMarker!] - sceneMarkerDestroy(id: ID!): Boolean! - sceneMarkersDestroy(ids: [ID!]!): Boolean! + @hasRole(role: MODIFY) + sceneMarkerDestroy(id: ID!): Boolean! @hasRole(role: MODIFY) + sceneMarkersDestroy(ids: [ID!]!): Boolean! @hasRole(role: MODIFY) - sceneAssignFile(input: AssignSceneFileInput!): Boolean! + sceneAssignFile(input: AssignSceneFileInput!): Boolean! @hasRole(role: MODIFY) - imageUpdate(input: ImageUpdateInput!): Image - bulkImageUpdate(input: BulkImageUpdateInput!): [Image!] - imageDestroy(input: ImageDestroyInput!): Boolean! - imagesDestroy(input: ImagesDestroyInput!): Boolean! - imagesUpdate(input: [ImageUpdateInput!]!): [Image] + imageUpdate(input: ImageUpdateInput!): Image @hasRole(role: MODIFY) + bulkImageUpdate(input: BulkImageUpdateInput!): [Image!] @hasRole(role: MODIFY) + imageDestroy(input: ImageDestroyInput!): Boolean! @hasRole(role: MODIFY) + imagesDestroy(input: ImagesDestroyInput!): Boolean! @hasRole(role: MODIFY) + imagesUpdate(input: [ImageUpdateInput!]!): [Image] @hasRole(role: MODIFY) "Increments the o-counter for an image. Returns the new value" - imageIncrementO(id: ID!): Int! + imageIncrementO(id: ID!): Int! @hasRole(role: MODIFY) "Decrements the o-counter for an image. Returns the new value" - imageDecrementO(id: ID!): Int! + imageDecrementO(id: ID!): Int! @hasRole(role: MODIFY) "Resets the o-counter for a image to 0. Returns the new value" - imageResetO(id: ID!): Int! + imageResetO(id: ID!): Int! @hasRole(role: MODIFY) - galleryCreate(input: GalleryCreateInput!): Gallery - galleryUpdate(input: GalleryUpdateInput!): Gallery + galleryCreate(input: GalleryCreateInput!): Gallery @hasRole(role: MODIFY) + galleryUpdate(input: GalleryUpdateInput!): Gallery @hasRole(role: MODIFY) bulkGalleryUpdate(input: BulkGalleryUpdateInput!): [Gallery!] - galleryDestroy(input: GalleryDestroyInput!): Boolean! + @hasRole(role: MODIFY) + galleryDestroy(input: GalleryDestroyInput!): Boolean! @hasRole(role: MODIFY) galleriesUpdate(input: [GalleryUpdateInput!]!): [Gallery] + @hasRole(role: MODIFY) - addGalleryImages(input: GalleryAddInput!): Boolean! + addGalleryImages(input: GalleryAddInput!): Boolean! @hasRole(role: MODIFY) removeGalleryImages(input: GalleryRemoveInput!): Boolean! - setGalleryCover(input: GallerySetCoverInput!): Boolean! + @hasRole(role: MODIFY) + setGalleryCover(input: GallerySetCoverInput!): Boolean! @hasRole(role: MODIFY) resetGalleryCover(input: GalleryResetCoverInput!): Boolean! + @hasRole(role: MODIFY) galleryChapterCreate(input: GalleryChapterCreateInput!): GalleryChapter + @hasRole(role: MODIFY) galleryChapterUpdate(input: GalleryChapterUpdateInput!): GalleryChapter - galleryChapterDestroy(id: ID!): Boolean! + @hasRole(role: MODIFY) + galleryChapterDestroy(id: ID!): Boolean! @hasRole(role: MODIFY) performerCreate(input: PerformerCreateInput!): Performer + @hasRole(role: MODIFY) performerUpdate(input: PerformerUpdateInput!): Performer + @hasRole(role: MODIFY) performerDestroy(input: PerformerDestroyInput!): Boolean! - performersDestroy(ids: [ID!]!): Boolean! + @hasRole(role: MODIFY) + performersDestroy(ids: [ID!]!): Boolean! @hasRole(role: MODIFY) bulkPerformerUpdate(input: BulkPerformerUpdateInput!): [Performer!] - performerMerge(input: PerformerMergeInput!): Performer! + @hasRole(role: MODIFY) + performerMerge(input: PerformerMergeInput!): Performer! @hasRole(role: MODIFY) - studioCreate(input: StudioCreateInput!): Studio - studioUpdate(input: StudioUpdateInput!): Studio - studioDestroy(input: StudioDestroyInput!): Boolean! - studiosDestroy(ids: [ID!]!): Boolean! + studioCreate(input: StudioCreateInput!): Studio @hasRole(role: MODIFY) + studioUpdate(input: StudioUpdateInput!): Studio @hasRole(role: MODIFY) + studioDestroy(input: StudioDestroyInput!): Boolean! @hasRole(role: MODIFY) + studiosDestroy(ids: [ID!]!): Boolean! @hasRole(role: MODIFY) bulkStudioUpdate(input: BulkStudioUpdateInput!): [Studio!] + @hasRole(role: MODIFY) movieCreate(input: MovieCreateInput!): Movie @deprecated(reason: "Use groupCreate instead") + @hasRole(role: MODIFY) movieUpdate(input: MovieUpdateInput!): Movie @deprecated(reason: "Use groupUpdate instead") + @hasRole(role: MODIFY) movieDestroy(input: MovieDestroyInput!): Boolean! @deprecated(reason: "Use groupDestroy instead") + @hasRole(role: MODIFY) moviesDestroy(ids: [ID!]!): Boolean! @deprecated(reason: "Use groupsDestroy instead") + @hasRole(role: MODIFY) bulkMovieUpdate(input: BulkMovieUpdateInput!): [Movie!] @deprecated(reason: "Use bulkGroupUpdate instead") + @hasRole(role: MODIFY) - groupCreate(input: GroupCreateInput!): Group - groupUpdate(input: GroupUpdateInput!): Group - groupDestroy(input: GroupDestroyInput!): Boolean! - groupsDestroy(ids: [ID!]!): Boolean! - bulkGroupUpdate(input: BulkGroupUpdateInput!): [Group!] + groupCreate(input: GroupCreateInput!): Group @hasRole(role: MODIFY) + groupUpdate(input: GroupUpdateInput!): Group @hasRole(role: MODIFY) + groupDestroy(input: GroupDestroyInput!): Boolean! @hasRole(role: MODIFY) + groupsDestroy(ids: [ID!]!): Boolean! @hasRole(role: MODIFY) + bulkGroupUpdate(input: BulkGroupUpdateInput!): [Group!] @hasRole(role: MODIFY) addGroupSubGroups(input: GroupSubGroupAddInput!): Boolean! + @hasRole(role: MODIFY) removeGroupSubGroups(input: GroupSubGroupRemoveInput!): Boolean! + @hasRole(role: MODIFY) "Reorder sub groups within a group. Returns true if successful." reorderSubGroups(input: ReorderSubGroupsInput!): Boolean! + @hasRole(role: MODIFY) - tagCreate(input: TagCreateInput!): Tag - tagUpdate(input: TagUpdateInput!): Tag - tagDestroy(input: TagDestroyInput!): Boolean! - tagsDestroy(ids: [ID!]!): Boolean! - tagsMerge(input: TagsMergeInput!): Tag - bulkTagUpdate(input: BulkTagUpdateInput!): [Tag!] + tagCreate(input: TagCreateInput!): Tag @hasRole(role: MODIFY) + tagUpdate(input: TagUpdateInput!): Tag @hasRole(role: MODIFY) + tagDestroy(input: TagDestroyInput!): Boolean! @hasRole(role: MODIFY) + tagsDestroy(ids: [ID!]!): Boolean! @hasRole(role: MODIFY) + tagsMerge(input: TagsMergeInput!): Tag @hasRole(role: MODIFY) + bulkTagUpdate(input: BulkTagUpdateInput!): [Tag!] @hasRole(role: MODIFY) """ Moves the given files to the given destination. Returns true if successful. @@ -420,90 +456,98 @@ type Mutation { matches one of the media extensions. Creates folder hierarchy if needed. """ - moveFiles(input: MoveFilesInput!): Boolean! - deleteFiles(ids: [ID!]!): Boolean! + moveFiles(input: MoveFilesInput!): Boolean! @hasRole(role: MODIFY) + deleteFiles(ids: [ID!]!): Boolean! @hasRole(role: MODIFY) "Deletes file entries from the database without deleting the files from the filesystem" - destroyFiles(ids: [ID!]!): Boolean! + destroyFiles(ids: [ID!]!): Boolean! @hasRole(role: MODIFY) fileSetFingerprints(input: FileSetFingerprintsInput!): Boolean! # Saved filters - saveFilter(input: SaveFilterInput!): SavedFilter! + saveFilter(input: SaveFilterInput!): SavedFilter! @hasRole(role: MODIFY) destroySavedFilter(input: DestroyFilterInput!): Boolean! + @hasRole(role: MODIFY) setDefaultFilter(input: SetDefaultFilterInput!): Boolean! @deprecated(reason: "now uses UI config") + @hasRole(role: MODIFY) "Change general configuration options" configureGeneral(input: ConfigGeneralInput!): ConfigGeneralResult! + @hasRole(role: ADMIN) configureInterface(input: ConfigInterfaceInput!): ConfigInterfaceResult! + @hasRole(role: ADMIN) configureDLNA(input: ConfigDLNAInput!): ConfigDLNAResult! + @hasRole(role: ADMIN) configureScraping(input: ConfigScrapingInput!): ConfigScrapingResult! + @hasRole(role: ADMIN) configureDefaults( input: ConfigDefaultSettingsInput! - ): ConfigDefaultSettingsResult! + ): ConfigDefaultSettingsResult! @hasRole(role: ADMIN) "overwrites the entire plugin configuration for the given plugin" - configurePlugin(plugin_id: ID!, input: Map!): Map! + configurePlugin(plugin_id: ID!, input: Map!): Map! @hasRole(role: ADMIN) """ overwrites the UI configuration if input is provided, then the entire UI configuration is replaced if partial is provided, then the partial UI configuration is merged into the existing UI configuration """ - configureUI(input: Map, partial: Map): Map! + configureUI(input: Map, partial: Map): Map! @hasRole(role: ADMIN) """ sets a single UI key value key is a dot separated path to the value """ - configureUISetting(key: String!, value: Any): Map! + configureUISetting(key: String!, value: Any): Map! @hasRole(role: ADMIN) - "Generate and set (or clear) API key" + "Generate and set (or clear) API key for the current user" generateAPIKey(input: GenerateAPIKeyInput!): String! "Returns a link to download the result" - exportObjects(input: ExportObjectsInput!): String + exportObjects(input: ExportObjectsInput!): String @hasRole(role: ADMIN) "Performs an incremental import. Returns the job ID" - importObjects(input: ImportObjectsInput!): ID! + importObjects(input: ImportObjectsInput!): ID! @hasRole(role: ADMIN) "Start an full import. Completely wipes the database and imports from the metadata directory. Returns the job ID" - metadataImport: ID! + metadataImport: ID! @hasRole(role: ADMIN) "Start a full export. Outputs to the metadata directory. Returns the job ID" - metadataExport: ID! + metadataExport: ID! @hasRole(role: ADMIN) "Start a scan. Returns the job ID" - metadataScan(input: ScanMetadataInput!): ID! + metadataScan(input: ScanMetadataInput!): ID! @hasRole(role: ADMIN) "Start generating content. Returns the job ID" - metadataGenerate(input: GenerateMetadataInput!): ID! + metadataGenerate(input: GenerateMetadataInput!): ID! @hasRole(role: ADMIN) "Start auto-tagging. Returns the job ID" - metadataAutoTag(input: AutoTagMetadataInput!): ID! + metadataAutoTag(input: AutoTagMetadataInput!): ID! @hasRole(role: ADMIN) "Clean metadata. Returns the job ID" - metadataClean(input: CleanMetadataInput!): ID! + metadataClean(input: CleanMetadataInput!): ID! @hasRole(role: ADMIN) "Clean generated files. Returns the job ID" - metadataCleanGenerated(input: CleanGeneratedInput!): ID! + metadataCleanGenerated(input: CleanGeneratedInput!): ID! @hasRole(role: ADMIN) "Identifies scenes using scrapers. Returns the job ID" - metadataIdentify(input: IdentifyMetadataInput!): ID! + metadataIdentify(input: IdentifyMetadataInput!): ID! @hasRole(role: ADMIN) "Migrate generated files for the current hash naming" - migrateHashNaming: ID! + migrateHashNaming: ID! @hasRole(role: ADMIN) "Migrates legacy scene screenshot files into the blob storage" migrateSceneScreenshots(input: MigrateSceneScreenshotsInput!): ID! + @hasRole(role: ADMIN) "Migrates blobs from the old storage system to the current one" - migrateBlobs(input: MigrateBlobsInput!): ID! + migrateBlobs(input: MigrateBlobsInput!): ID! @hasRole(role: ADMIN) "Anonymise the database in a separate file. Optionally returns a link to download the database file" anonymiseDatabase(input: AnonymiseDatabaseInput!): String + @hasRole(role: ADMIN) "Optimises the database. Returns the job ID" - optimiseDatabase: ID! + optimiseDatabase: ID! @hasRole(role: ADMIN) "Reload scrapers" - reloadScrapers: Boolean! + reloadScrapers: Boolean! @hasRole(role: ADMIN) """ Enable/disable plugins - enabledMap is a map of plugin IDs to enabled booleans. Plugins not in the map are not affected. """ - setPluginsEnabled(enabledMap: BoolMap!): Boolean! + setPluginsEnabled(enabledMap: BoolMap!): Boolean! @hasRole(role: ADMIN) """ Run a plugin task. @@ -520,15 +564,15 @@ type Mutation { description: String args: [PluginArgInput!] @deprecated(reason: "Use args_map instead") args_map: Map - ): ID! + ): ID! @hasRole(role: MODIFY) """ Runs a plugin operation. The operation is run immediately and does not use the job queue. Returns a map of the result. """ - runPluginOperation(plugin_id: ID!, args: Map): Any + runPluginOperation(plugin_id: ID!, args: Map): Any @hasRole(role: MODIFY) - reloadPlugins: Boolean! + reloadPlugins: Boolean! @hasRole(role: ADMIN) """ Installs the given packages. @@ -537,6 +581,7 @@ type Mutation { Returns the job ID """ installPackages(type: PackageType!, packages: [PackageSpecInput!]!): ID! + @hasRole(role: ADMIN) """ Updates the given packages. If a package is not installed, it will not be installed. @@ -546,48 +591,62 @@ type Mutation { Returns the job ID. """ updatePackages(type: PackageType!, packages: [PackageSpecInput!]): ID! + @hasRole(role: ADMIN) """ Uninstalls the given packages. If an error occurs when uninstalling a package, the job will continue to uninstall the remaining packages. Returns the job ID """ uninstallPackages(type: PackageType!, packages: [PackageSpecInput!]!): ID! + @hasRole(role: ADMIN) - stopJob(job_id: ID!): Boolean! - stopAllJobs: Boolean! + stopJob(job_id: ID!): Boolean! @hasRole(role: ADMIN) + stopAllJobs: Boolean! @hasRole(role: ADMIN) "Submit fingerprints to stash-box instance" submitStashBoxFingerprints( input: StashBoxFingerprintSubmissionInput! - ): Boolean! + ): Boolean! @hasRole(role: MODIFY) "Submit scene as draft to stash-box instance" submitStashBoxSceneDraft(input: StashBoxDraftSubmissionInput!): ID + @hasRole(role: MODIFY) "Submit performer as draft to stash-box instance" submitStashBoxPerformerDraft(input: StashBoxDraftSubmissionInput!): ID + @hasRole(role: MODIFY) "Backup the database. Optionally returns a link to download the database file" - backupDatabase(input: BackupDatabaseInput!): String + backupDatabase(input: BackupDatabaseInput!): String @hasRole(role: ADMIN) "DANGEROUS: Execute an arbitrary SQL statement that returns rows." - querySQL(sql: String!, args: [Any]): SQLQueryResult! + querySQL(sql: String!, args: [Any]): SQLQueryResult! @hasRole(role: ADMIN) "DANGEROUS: Execute an arbitrary SQL statement without returning any rows." - execSQL(sql: String!, args: [Any]): SQLExecResult! + execSQL(sql: String!, args: [Any]): SQLExecResult! @hasRole(role: ADMIN) "Run batch performer tag task. Returns the job ID." stashBoxBatchPerformerTag(input: StashBoxBatchTagInput!): String! + @hasRole(role: ADMIN) "Run batch studio tag task. Returns the job ID." stashBoxBatchStudioTag(input: StashBoxBatchTagInput!): String! + @hasRole(role: ADMIN) "Enables DLNA for an optional duration. Has no effect if DLNA is enabled by default" - enableDLNA(input: EnableDLNAInput!): Boolean! + enableDLNA(input: EnableDLNAInput!): Boolean! @hasRole(role: ADMIN) "Disables DLNA for an optional duration. Has no effect if DLNA is disabled by default" - disableDLNA(input: DisableDLNAInput!): Boolean! + disableDLNA(input: DisableDLNAInput!): Boolean! @hasRole(role: ADMIN) "Enables an IP address for DLNA for an optional duration" - addTempDLNAIP(input: AddTempDLNAIPInput!): Boolean! + addTempDLNAIP(input: AddTempDLNAIPInput!): Boolean! @hasRole(role: ADMIN) "Removes an IP address from the temporary DLNA whitelist" removeTempDLNAIP(input: RemoveTempDLNAIPInput!): Boolean! + @hasRole(role: ADMIN) + + userCreate(input: UserCreateInput!): User @hasRole(role: ADMIN) + userUpdate(input: UserUpdateInput!): User @hasRole(role: ADMIN) + userDestroy(input: UserDestroyInput!): Boolean! @hasRole(role: ADMIN) + changeUserPassword(input: ChangeUserPasswordInput!): Boolean! + @hasRole(role: ADMIN) + changePassword(input: UserChangePasswordInput!): Boolean! } type Subscription { diff --git a/graphql/schema/types/user.graphql b/graphql/schema/types/user.graphql new file mode 100644 index 000000000..7b5829e37 --- /dev/null +++ b/graphql/schema/types/user.graphql @@ -0,0 +1,53 @@ +enum RoleEnum { + ADMIN + READ + MODIFY +} + +directive @hasRole(role: RoleEnum!) on FIELD_DEFINITION +directive @isUserOwner on FIELD_DEFINITION + +type User { + name: String! + """ + If the user has no roles, they are considered locked and cannot log in. + Should not be visible to other users + """ + roles: [RoleEnum!] @isUserOwner + """ + Should not be visible to other users + """ + api_key: String @isUserOwner +} + +input UserCreateInput { + name: String! + """ + Password in plain text + """ + password: String! + roles: [RoleEnum!]! +} + +input UserUpdateInput { + existingName: String! + name: String! + roles: [RoleEnum!]! +} + +input UserDestroyInput { + name: String! +} + +input UserChangePasswordInput { + """ + Password in plain text + """ + existingPassword: String! + newPassword: String! +} + +input ChangeUserPasswordInput { + name: String! + newPassword: String! +} diff --git a/internal/api/authentication.go b/internal/api/authentication.go index 6ad7117a1..0ff2bcca1 100644 --- a/internal/api/authentication.go +++ b/internal/api/authentication.go @@ -1,6 +1,7 @@ package api import ( + "context" "errors" "net" "net/http" @@ -11,7 +12,9 @@ import ( "github.com/stashapp/stash/internal/manager" "github.com/stashapp/stash/internal/manager/config" "github.com/stashapp/stash/pkg/logger" + "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/session" + "github.com/stashapp/stash/pkg/user" ) const ( @@ -29,21 +32,47 @@ func allowUnauthenticated(r *http.Request) bool { return strings.HasPrefix(r.URL.Path, loginEndpoint) || r.URL.Path == logoutEndpoint || r.URL.Path == "/css" || strings.HasPrefix(r.URL.Path, "/assets") } -func authenticateHandler() func(http.Handler) http.Handler { +type UserAuthenticator interface { + AuthenticateByAPIKey(ctx context.Context, apiKey string) (*models.User, error) + AuthenticateUserByID(ctx context.Context, username string) (*models.User, error) +} + +func authenticateHandler(g UserAuthenticator) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c := config.GetInstance() + s := c.UserStore // error if external access tripwire activated - if accessErr := session.CheckExternalAccessTripwire(c); accessErr != nil { + if accessErr := session.CheckExternalAccessTripwire(s, c); accessErr != nil { http.Error(w, tripwireActivatedErrMsg, http.StatusForbidden) return } - userID, err := manager.GetInstance().SessionStore.Authenticate(w, r) + // try to authenticate using api key first + var u *models.User + var err error + ctx := r.Context() + + apiKey := session.GetRequestApiKey(r) + if apiKey != "" { + u, err = g.AuthenticateByAPIKey(ctx, apiKey) + } else { + userID, getErr := manager.GetInstance().SessionStore.GetSessionUserID(w, r) + if getErr != nil { + logger.Errorf("error getting session user ID: %v", getErr) + http.Error(w, "internal server error", http.StatusInternalServerError) + return + } + + if userID != "" { + u, err = g.AuthenticateUserByID(ctx, userID) + } + } + if err != nil { - if !errors.Is(err, session.ErrUnauthorized) { - http.Error(w, err.Error(), http.StatusInternalServerError) + if errors.Is(err, user.ErrInternalError) { + http.Error(w, "internal server error", http.StatusInternalServerError) return } @@ -53,7 +82,7 @@ func authenticateHandler() func(http.Handler) http.Handler { return } - if err := session.CheckAllowPublicWithoutAuth(c, r); err != nil { + if err := session.CheckAllowPublicWithoutAuth(s, c, r); err != nil { var accessErr session.ExternalAccessError if errors.As(err, &accessErr) { session.LogExternalAccessError(accessErr) @@ -71,11 +100,9 @@ func authenticateHandler() func(http.Handler) http.Handler { return } - ctx := r.Context() - - if c.HasCredentials() { + if hc := s.LoginRequired(ctx); hc { // authentication is required - if userID == "" && !allowUnauthenticated(r) { + if u == nil && !allowUnauthenticated(r) { // if graphql or a non-webpage was requested, we just return a forbidden error ext := path.Ext(r.URL.Path) if r.URL.Path == gqlEndpoint || (ext != "" && ext != ".html") { @@ -102,7 +129,10 @@ func authenticateHandler() func(http.Handler) http.Handler { } } - ctx = session.SetCurrentUserID(ctx, userID) + if u != nil { + // set the user object in the context + ctx = session.SetCurrentUser(ctx, *u) + } r = r.WithContext(ctx) diff --git a/internal/api/directives.go b/internal/api/directives.go new file mode 100644 index 000000000..646f194af --- /dev/null +++ b/internal/api/directives.go @@ -0,0 +1,48 @@ +package api + +import ( + "context" + + "github.com/99designs/gqlgen/graphql" + "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/session" +) + +func HasRoleDirective(ctx context.Context, obj interface{}, next graphql.Resolver, role models.RoleEnum) (interface{}, error) { + currentUser := session.GetCurrentUser(ctx) + + // if there is no current user, this is an anonymous request + // we should not end up here unless there are no credentials required + if currentUser == nil { + return next(ctx) + } + + if !currentUser.Roles.HasRole(role) { + return nil, session.ErrUnauthorized + } + + return next(ctx) +} + +func IsUserOwnerDirective(ctx context.Context, obj any, next graphql.Resolver) (res any, err error) { + currentUser := session.GetCurrentUser(ctx) + + // if there is no current user, this is an anonymous request + // we should not end up here unless there are no credentials required + if currentUser == nil { + return next(ctx) + } + + // get the user from the object + userObj, ok := obj.(*models.User) + if !ok { + return nil, session.ErrUnauthorized + } + + // allow admin access + if !currentUser.Roles.HasRole(models.RoleEnumAdmin) && currentUser.Username != userObj.Username { + return nil, session.ErrUnauthorized + } + + return next(ctx) +} diff --git a/internal/api/resolver.go b/internal/api/resolver.go index 061d0e1a9..509a2f170 100644 --- a/internal/api/resolver.go +++ b/internal/api/resolver.go @@ -37,6 +37,7 @@ type Resolver struct { imageService manager.ImageService galleryService manager.GalleryService groupService manager.GroupService + userService manager.UserService hookExecutor hookExecutor } @@ -110,6 +111,9 @@ func (r *Resolver) Plugin() PluginResolver { func (r *Resolver) ConfigResult() ConfigResultResolver { return &configResultResolver{r} } +func (r *Resolver) User() UserResolver { + return &userResolver{r} +} type mutationResolver struct{ *Resolver } type queryResolver struct{ *Resolver } @@ -136,6 +140,7 @@ type folderResolver struct{ *Resolver } type savedFilterResolver struct{ *Resolver } type pluginResolver struct{ *Resolver } type configResultResolver struct{ *Resolver } +type userResolver struct{ *Resolver } func (r *Resolver) withTxn(ctx context.Context, fn func(ctx context.Context) error) error { return r.repository.WithTxn(ctx, fn) diff --git a/internal/api/resolver_model_user.go b/internal/api/resolver_model_user.go new file mode 100644 index 000000000..2c6cc944c --- /dev/null +++ b/internal/api/resolver_model_user.go @@ -0,0 +1,19 @@ +package api + +import ( + "context" + + "github.com/stashapp/stash/pkg/models" +) + +func (r *userResolver) Name(ctx context.Context, obj *models.User) (string, error) { + return obj.Username, nil +} + +func (r *userResolver) Roles(ctx context.Context, obj *models.User) ([]models.RoleEnum, error) { + ret := make([]models.RoleEnum, len(obj.Roles)) + for i, role := range obj.Roles { + ret[i] = models.RoleEnum(role) + } + return ret, nil +} diff --git a/internal/api/resolver_mutation_configure.go b/internal/api/resolver_mutation_configure.go index 23b61c208..1db1c5587 100644 --- a/internal/api/resolver_mutation_configure.go +++ b/internal/api/resolver_mutation_configure.go @@ -637,29 +637,6 @@ func (r *mutationResolver) ConfigureDefaults(ctx context.Context, input ConfigDe return makeConfigDefaultsResult(), nil } -func (r *mutationResolver) GenerateAPIKey(ctx context.Context, input GenerateAPIKeyInput) (string, error) { - c := config.GetInstance() - - var newAPIKey string - if input.Clear == nil || !*input.Clear { - username := c.GetUsername() - if username != "" { - var err error - newAPIKey, err = manager.GenerateAPIKey(username) - if err != nil { - return "", err - } - } - } - - c.SetString(config.ApiKey, newAPIKey) - if err := c.Write(); err != nil { - return newAPIKey, err - } - - return newAPIKey, nil -} - func (r *mutationResolver) ConfigureUI(ctx context.Context, input map[string]interface{}, partial map[string]interface{}) (map[string]interface{}, error) { c := config.GetInstance() diff --git a/internal/api/resolver_mutation_user.go b/internal/api/resolver_mutation_user.go new file mode 100644 index 000000000..dbec308a1 --- /dev/null +++ b/internal/api/resolver_mutation_user.go @@ -0,0 +1,86 @@ +package api + +import ( + "context" + "fmt" + + "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/session" +) + +func (r *mutationResolver) UserCreate(ctx context.Context, input UserCreateInput) (*models.User, error) { + err := r.userService.CreateUser(ctx, models.User{ + Username: input.Name, + Roles: models.Roles(input.Roles), + }, input.Password) + if err != nil { + return nil, err + } + + return r.userService.GetUser(ctx, input.Name) +} + +func (r *mutationResolver) UserUpdate(ctx context.Context, input UserUpdateInput) (*models.User, error) { + err := r.userService.UpdateUser(ctx, input.ExistingName, models.User{ + Username: input.Name, + Roles: models.Roles(input.Roles), + }) + if err != nil { + return nil, err + } + + return r.userService.GetUser(ctx, input.Name) +} + +func (r *mutationResolver) UserDestroy(ctx context.Context, input UserDestroyInput) (bool, error) { + err := r.userService.DeleteUser(ctx, input.Name) + if err != nil { + return false, err + } + + return true, nil +} + +func (r *mutationResolver) ChangePassword(ctx context.Context, input UserChangePasswordInput) (bool, error) { + // get current user + u := session.GetCurrentUser(ctx) + + err := r.userService.ChangePassword(ctx, u.Username, input.ExistingPassword, input.NewPassword) + if err != nil { + return false, err + } + + return true, nil +} + +func (r *mutationResolver) ChangeUserPassword(ctx context.Context, input ChangeUserPasswordInput) (bool, error) { + err := r.userService.ChangeUserPassword(ctx, input.Name, input.NewPassword) + if err != nil { + return false, err + } + + return true, nil +} + +func (r *mutationResolver) GenerateAPIKey(ctx context.Context, input GenerateAPIKeyInput) (string, error) { + u := session.GetCurrentUser(ctx) + + if u == nil { + return "", fmt.Errorf("no current user in context") + } + + if input.Clear != nil && *input.Clear { + err := r.userService.ClearAPIKey(ctx, u.Username) + if err != nil { + return "", err + } + return "", nil + } + + newAPIKey, err := r.userService.GenerateAPIKey(ctx, u.Username) + if err != nil { + return "", err + } + + return newAPIKey, nil +} diff --git a/internal/api/resolver_query_user.go b/internal/api/resolver_query_user.go new file mode 100644 index 000000000..813e595cd --- /dev/null +++ b/internal/api/resolver_query_user.go @@ -0,0 +1,17 @@ +package api + +import ( + "context" + + "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/session" +) + +func (r *queryResolver) Users(ctx context.Context) ([]*models.User, error) { + return r.userService.AllUsers(ctx) +} + +func (r *queryResolver) Me(ctx context.Context) (*models.User, error) { + // get current user + return session.GetCurrentUser(ctx), nil +} diff --git a/internal/api/server.go b/internal/api/server.go index a7516da52..dbc20e346 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -122,9 +122,11 @@ func Initialize() (*Server, error) { manager: mgr, } + userStore := manager.GetInstance().UserService + r.Use(middleware.Heartbeat("/healthz")) r.Use(cors.AllowAll().Handler) - r.Use(authenticateHandler()) + r.Use(authenticateHandler(userStore)) visitedPluginHandler := mgr.SessionStore.VisitedPluginHandler() r.Use(visitedPluginHandler) @@ -162,16 +164,26 @@ func Initialize() (*Server, error) { imageService := mgr.ImageService galleryService := mgr.GalleryService groupService := mgr.GroupService + userService := mgr.UserService resolver := &Resolver{ repository: repo, sceneService: sceneService, imageService: imageService, galleryService: galleryService, groupService: groupService, + userService: userService, hookExecutor: pluginCache, } - gqlSrv := gqlHandler.New(NewExecutableSchema(Config{Resolvers: resolver})) + gqlCfg := Config{ + Resolvers: resolver, + Directives: DirectiveRoot{ + HasRole: HasRoleDirective, + IsUserOwner: IsUserOwnerDirective, + }, + } + + gqlSrv := gqlHandler.New(NewExecutableSchema(gqlCfg)) gqlSrv.SetRecoverFunc(recoverFunc) gqlSrv.AddTransport(gqlTransport.Websocket{ Upgrader: websocket.Upgrader{ @@ -227,9 +239,11 @@ func Initialize() (*Server, error) { staticLoginUI := statigz.FileServer(ui.LoginUIBox.(fs.ReadDirFS)) - r.Get(loginEndpoint, handleLogin()) - r.Post(loginEndpoint, handleLoginPost()) - r.Get(logoutEndpoint, handleLogout()) + sessionStore := mgr.SessionStore + + r.Get(loginEndpoint, handleLogin(userService)) + r.Post(loginEndpoint, handleLoginPost(sessionStore)) + r.Get(logoutEndpoint, handleLogout(sessionStore)) r.Get(loginLocaleEndpoint, handleLoginLocale(cfg)) r.HandleFunc(loginEndpoint+"/*", func(w http.ResponseWriter, r *http.Request) { r.URL.Path = strings.TrimPrefix(r.URL.Path, loginEndpoint) diff --git a/internal/api/session.go b/internal/api/session.go index 5918cdd9b..e66e00b0d 100644 --- a/internal/api/session.go +++ b/internal/api/session.go @@ -103,11 +103,11 @@ func getLoginLocale(lang string) ([]byte, error) { return data, nil } -func handleLogin() http.HandlerFunc { +func handleLogin(s manager.UserService) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { returnURL := r.URL.Query().Get(returnURLParam) - if !config.GetInstance().HasCredentials() { + if hc := s.LoginRequired(r.Context()); !hc { if returnURL != "" { http.Redirect(w, r, returnURL, http.StatusFound) } else { @@ -121,9 +121,9 @@ func handleLogin() http.HandlerFunc { } } -func handleLoginPost() http.HandlerFunc { +func handleLoginPost(s *session.Store) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - err := manager.GetInstance().SessionStore.Login(w, r) + err := s.Login(w, r) if err != nil { // always log the error logger.Errorf("Error logging in: %v from IP: %s", err, r.RemoteAddr) @@ -146,7 +146,7 @@ func handleLoginPost() http.HandlerFunc { } } -func handleLogout() http.HandlerFunc { +func handleLogout(s *session.Store) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if err := manager.GetInstance().SessionStore.Logout(w, r); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -155,7 +155,7 @@ func handleLogout() http.HandlerFunc { // redirect to the login page if credentials are required prefix := getProxyPrefix(r) - if config.GetInstance().HasCredentials() { + if hc := s.LoginRequired(r.Context()); hc { http.Redirect(w, r, prefix+loginEndpoint, http.StatusFound) } else { http.Redirect(w, r, prefix+"/", http.StatusFound) diff --git a/internal/manager/config/config.go b/internal/manager/config/config.go index bb99bdcfc..befc30fb2 100644 --- a/internal/manager/config/config.go +++ b/internal/manager/config/config.go @@ -14,8 +14,6 @@ import ( "sync" // "github.com/sasha-s/go-deadlock" // if you have deadlock issues - "golang.org/x/crypto/bcrypt" - "github.com/knadh/koanf/parsers/yaml" "github.com/knadh/koanf/providers/file" "github.com/knadh/koanf/v2" @@ -39,9 +37,8 @@ const ( BlobsPath = "blobs_path" Downloads = "downloads" ApiKey = "api_key" - Username = "username" - Password = "password" - MaxSessionAge = "max_session_age" + + MaxSessionAge = "max_session_age" // SFWContentMode mode config key SFWContentMode = "sfw_content_mode" @@ -332,6 +329,9 @@ type Config struct { // configUpdates chan int certFile string keyFile string + + UserStore *UserStore + sync.RWMutex // deadlock.RWMutex // for deadlock testing/issues } @@ -345,99 +345,102 @@ func GetInstance() *Config { return instance } -func (i *Config) load(f string) error { - if err := i.main.Load(file.Provider(f), yaml.Parser()); err != nil { +func (s *Config) load(f string) error { + if err := s.main.Load(file.Provider(f), yaml.Parser()); err != nil { return err } - i.filePath = f + s.filePath = f return nil } -func (i *Config) IsNewSystem() bool { - return i.isNewSystem +func (s *Config) IsNewSystem() bool { + return s.isNewSystem } -func (i *Config) SetConfigFile(fn string) { - i.Lock() - defer i.Unlock() - i.filePath = fn +func (s *Config) SetConfigFile(fn string) { + s.Lock() + defer s.Unlock() + s.filePath = fn } -func (i *Config) InitTLS() { - configDirectory := i.GetConfigPath() +func (s *Config) InitTLS() { + configDirectory := s.GetConfigPath() tlsPaths := []string{ configDirectory, paths.GetStashHomeDirectory(), } - i.certFile = i.getString(sslCertPath) - if i.certFile == "" { + s.certFile = s.getString(sslCertPath) + if s.certFile == "" { // Look for default file - i.certFile = fsutil.FindInPaths(tlsPaths, "stash.crt") + s.certFile = fsutil.FindInPaths(tlsPaths, "stash.crt") } - i.keyFile = i.getString(sslKeyPath) - if i.keyFile == "" { + s.keyFile = s.getString(sslKeyPath) + if s.keyFile == "" { // Look for default file - i.keyFile = fsutil.FindInPaths(tlsPaths, "stash.key") + s.keyFile = fsutil.FindInPaths(tlsPaths, "stash.key") } } -func (i *Config) GetTLSFiles() (certFile, keyFile string) { - return i.certFile, i.keyFile +func (s *Config) GetTLSFiles() (certFile, keyFile string) { + return s.certFile, s.keyFile } -func (i *Config) HasTLSConfig() bool { - certFile, keyFile := i.GetTLSFiles() +func (s *Config) HasTLSConfig() bool { + certFile, keyFile := s.GetTLSFiles() return certFile != "" && keyFile != "" } -func (i *Config) GetNoBrowser() bool { - return i.getBool(NoBrowser) +func (s *Config) GetNoBrowser() bool { + return s.getBool(NoBrowser) } -func (i *Config) GetNotificationsEnabled() bool { - return i.getBool(NotificationsEnabled) +func (s *Config) GetNotificationsEnabled() bool { + return s.getBool(NotificationsEnabled) } // GetShowOneTimeMovedNotification shows whether a small notification to inform the user that Stash // will no longer show a terminal window, and instead will be available in the tray, should be shown. // It is true when an existing system is started after upgrading, and set to false forever after it is shown. -func (i *Config) GetShowOneTimeMovedNotification() bool { - return i.getBool(ShowOneTimeMovedNotification) +func (s *Config) GetShowOneTimeMovedNotification() bool { + return s.getBool(ShowOneTimeMovedNotification) } // these methods are intended to ensure type safety (ie no primitive pointers) -func (i *Config) SetBool(key string, value bool) { - i.SetInterface(key, value) +func (s *Config) SetBool(key string, value bool) { + s.SetInterface(key, value) } -func (i *Config) SetString(key string, value string) { - i.SetInterface(key, value) +func (s *Config) SetString(key string, value string) { + s.SetInterface(key, value) } -func (i *Config) SetInt(key string, value int) { - i.SetInterface(key, value) +func (s *Config) SetInt(key string, value int) { + s.SetInterface(key, value) } -func (i *Config) SetFloat(key string, value float64) { - i.SetInterface(key, value) +func (s *Config) SetFloat(key string, value float64) { + s.SetInterface(key, value) } -func (i *Config) SetInterface(key string, value interface{}) { - i.Lock() - defer i.Unlock() +func (s *Config) SetInterface(key string, value interface{}) { + s.Lock() + defer s.Unlock() - i.set(key, value) + s.setInterfaceNoLock(key, value) +} +func (s *Config) setInterfaceNoLock(key string, value interface{}) { + s.set(key, value) } -func (i *Config) set(key string, value interface{}) { +func (s *Config) set(key string, value interface{}) { // assumes lock held // default behaviour for Set is to merge the value // we want to replace it - i.main.Delete(key) + s.main.Delete(key) if value == nil { return @@ -449,52 +452,56 @@ func (i *Config) set(key string, value interface{}) { return } - _ = i.main.Set(key, value) + _ = s.main.Set(key, value) } -func (i *Config) SetDefault(key string, value interface{}) { - i.Lock() - defer i.Unlock() +func (s *Config) SetDefault(key string, value interface{}) { + s.Lock() + defer s.Unlock() - i.setDefault(key, value) + s.setDefault(key, value) } -func (i *Config) setDefault(key string, value interface{}) { - if !i.main.Exists(key) { - i.set(key, value) +func (s *Config) setDefault(key string, value interface{}) { + if !s.main.Exists(key) { + s.set(key, value) } } -func (i *Config) SetPassword(value string) { +func (s *Config) SetPassword(value string) { // if blank, don't bother hashing; we want it to be blank if value == "" { - i.SetString(Password, "") + s.SetString(Password, "") } else { - i.SetString(Password, hashPassword(value)) + s.SetString(Password, hashPassword(value)) } } -func (i *Config) Write() error { - i.Lock() - defer i.Unlock() +func (s *Config) Write() error { + s.Lock() + defer s.Unlock() - data, err := i.marshal() + return s.writeNoLock() +} + +func (s *Config) writeNoLock() error { + data, err := s.marshal() if err != nil { return err } - return os.WriteFile(i.filePath, data, 0640) + return os.WriteFile(s.filePath, data, 0640) } -func (i *Config) Marshal() ([]byte, error) { - i.RLock() - defer i.RUnlock() +func (s *Config) Marshal() ([]byte, error) { + s.RLock() + defer s.RUnlock() - return i.marshal() + return s.marshal() } -func (i *Config) marshal() ([]byte, error) { - return i.main.Marshal(yaml.Parser()) +func (s *Config) marshal() ([]byte, error) { + return s.main.Marshal(yaml.Parser()) } // FileEnvSet returns true if the configuration file environment parameter @@ -504,23 +511,23 @@ func FileEnvSet() bool { } // GetConfigFile returns the full path to the used configuration file. -func (i *Config) GetConfigFile() string { - i.RLock() - defer i.RUnlock() - return i.filePath +func (s *Config) GetConfigFile() string { + s.RLock() + defer s.RUnlock() + return s.filePath } // GetConfigPath returns the path of the directory containing the used // configuration file. -func (i *Config) GetConfigPath() string { - return filepath.Dir(i.GetConfigFile()) +func (s *Config) GetConfigPath() string { + return filepath.Dir(s.GetConfigFile()) } // GetConfigPathAbs returns the path of the directory containing the used // configuration file, resolved to an absolute path. Returns the return value // of GetConfigPath if the path cannot be made into an absolute path. -func (i *Config) GetConfigPathAbs() string { - p := filepath.Dir(i.GetConfigFile()) +func (s *Config) GetConfigPathAbs() string { + p := filepath.Dir(s.GetConfigFile()) ret, _ := filepath.Abs(p) if ret == "" { @@ -532,17 +539,17 @@ func (i *Config) GetConfigPathAbs() string { // GetDefaultDatabaseFilePath returns the default database filename, // which is located in the same directory as the config file. -func (i *Config) GetDefaultDatabaseFilePath() string { - return filepath.Join(i.GetConfigPath(), "stash-go.sqlite") +func (s *Config) GetDefaultDatabaseFilePath() string { + return filepath.Join(s.GetConfigPath(), "stash-go.sqlite") } // forKey returns the Koanf instance that should be used to get the provided // key. Returns the overrides instance if the key exists there, otherwise it // returns the main instance. Assumes read lock held. -func (i *Config) forKey(key string) *koanf.Koanf { - v := i.main - if i.overrides.Exists(key) { - v = i.overrides +func (s *Config) forKey(key string) *koanf.Koanf { + v := s.main + if s.overrides.Exists(key) { + v = s.overrides } return v @@ -550,8 +557,8 @@ func (i *Config) forKey(key string) *koanf.Koanf { // viper returns the viper instance that has the key set. Returns nil // if no instance has the key. Assumes read lock held. -func (i *Config) with(key string) *koanf.Koanf { - v := i.forKey(key) +func (s *Config) with(key string) *koanf.Koanf { + v := s.forKey(key) if v.Exists(key) { return v @@ -560,75 +567,75 @@ func (i *Config) with(key string) *koanf.Koanf { return nil } -func (i *Config) HasOverride(key string) bool { - i.RLock() - defer i.RUnlock() +func (s *Config) HasOverride(key string) bool { + s.RLock() + defer s.RUnlock() - return i.overrides.Exists(key) + return s.overrides.Exists(key) } // These functions wrap the equivalent viper functions, checking the override // instance first, then the main instance. -func (i *Config) unmarshalKey(key string, rawVal interface{}) error { - i.RLock() - defer i.RUnlock() +func (s *Config) unmarshalKey(key string, rawVal interface{}) error { + s.RLock() + defer s.RUnlock() - return i.forKey(key).Unmarshal(key, rawVal) + return s.forKey(key).Unmarshal(key, rawVal) } -func (i *Config) getStringSlice(key string) []string { - i.RLock() - defer i.RUnlock() +func (s *Config) getStringSlice(key string) []string { + s.RLock() + defer s.RUnlock() - return i.forKey(key).Strings(key) + return s.forKey(key).Strings(key) } -func (i *Config) getString(key string) string { - i.RLock() - defer i.RUnlock() +func (s *Config) getString(key string) string { + s.RLock() + defer s.RUnlock() - return i.forKey(key).String(key) + return s.forKey(key).String(key) } -func (i *Config) getBool(key string) bool { - i.RLock() - defer i.RUnlock() +func (s *Config) getBool(key string) bool { + s.RLock() + defer s.RUnlock() - return i.forKey(key).Bool(key) + return s.forKey(key).Bool(key) } -func (i *Config) getBoolDefault(key string, def bool) bool { - i.RLock() - defer i.RUnlock() +func (s *Config) getBoolDefault(key string, def bool) bool { + s.RLock() + defer s.RUnlock() ret := def - v := i.forKey(key) + v := s.forKey(key) if v.Exists(key) { ret = v.Bool(key) } return ret } -func (i *Config) getInt(key string) int { - i.RLock() - defer i.RUnlock() +func (s *Config) getInt(key string) int { + s.RLock() + defer s.RUnlock() - return i.forKey(key).Int(key) + return s.forKey(key).Int(key) } -func (i *Config) getFloat64(key string) float64 { - i.RLock() - defer i.RUnlock() +func (s *Config) getFloat64(key string) float64 { + s.RLock() + defer s.RUnlock() - return i.forKey(key).Float64(key) + return s.forKey(key).Float64(key) } -func (i *Config) getStringMapString(key string) map[string]string { - i.RLock() - defer i.RUnlock() +func (s *Config) getStringMapString(key string) map[string]string { + s.RLock() + defer s.RUnlock() - ret := i.forKey(key).StringMap(key) + ret := s.forKey(key).StringMap(key) // GetStringMapString returns an empty map regardless of whether the // key exists or not. @@ -641,24 +648,24 @@ func (i *Config) getStringMapString(key string) map[string]string { // GetSFW returns true if SFW mode is enabled. // Default performer images are changed to more agnostic images when enabled. -func (i *Config) GetSFWContentMode() bool { - i.RLock() - defer i.RUnlock() - return i.getBool(SFWContentMode) +func (s *Config) GetSFWContentMode() bool { + s.RLock() + defer s.RUnlock() + return s.getBool(SFWContentMode) } // GetStashPaths returns the configured stash library paths. // Works opposite to the usual case - it will return the override // value only if the main value is not set. -func (i *Config) GetStashPaths() StashConfigs { - i.RLock() - defer i.RUnlock() +func (s *Config) GetStashPaths() StashConfigs { + s.RLock() + defer s.RUnlock() var ret StashConfigs - v := i.main + v := s.main if !v.Exists(Stash) { - v = i.overrides + v = s.overrides } if err := v.Unmarshal(Stash, &ret); err != nil || len(ret) == 0 { @@ -676,26 +683,26 @@ func (i *Config) GetStashPaths() StashConfigs { return ret } -func (i *Config) GetCachePath() string { - return i.getString(Cache) +func (s *Config) GetCachePath() string { + return s.getString(Cache) } -func (i *Config) GetGeneratedPath() string { - return i.getString(Generated) +func (s *Config) GetGeneratedPath() string { + return s.getString(Generated) } -func (i *Config) GetBlobsPath() string { - return i.getString(BlobsPath) +func (s *Config) GetBlobsPath() string { + return s.getString(BlobsPath) } // GetExtraBlobsPaths returns extra blobs paths. // For developer/advanced use only. -func (i *Config) GetExtraBlobsPaths() []string { - return i.getStringSlice(ExtraBlobsPaths) +func (s *Config) GetExtraBlobsPaths() []string { + return s.getStringSlice(ExtraBlobsPaths) } -func (i *Config) GetBlobsStorage() BlobsStorageType { - ret := BlobsStorageType(i.getString(BlobsStorage)) +func (s *Config) GetBlobsStorage() BlobsStorageType { + ret := BlobsStorageType(s.getString(BlobsStorage)) if !ret.IsValid() { // default to database storage @@ -706,23 +713,23 @@ func (i *Config) GetBlobsStorage() BlobsStorageType { return ret } -func (i *Config) GetMetadataPath() string { - return i.getString(Metadata) +func (s *Config) GetMetadataPath() string { + return s.getString(Metadata) } -func (i *Config) GetDatabasePath() string { - return i.getString(Database) +func (s *Config) GetDatabasePath() string { + return s.getString(Database) } -func (i *Config) GetBackupDirectoryPath() string { - return i.getString(BackupDirectoryPath) +func (s *Config) GetBackupDirectoryPath() string { + return s.getString(BackupDirectoryPath) } -func (i *Config) GetBackupDirectoryPathOrDefault() string { - ret := i.GetBackupDirectoryPath() +func (s *Config) GetBackupDirectoryPathOrDefault() string { + ret := s.GetBackupDirectoryPath() if ret == "" { // #4915 - default to the same directory as the database - return filepath.Dir(i.GetDatabasePath()) + return filepath.Dir(s.GetDatabasePath()) } return ret @@ -730,69 +737,69 @@ func (i *Config) GetBackupDirectoryPathOrDefault() string { // GetFFMpegPath returns the path to the FFMpeg executable. // If empty, stash will attempt to resolve it from the path. -func (i *Config) GetFFMpegPath() string { - return i.getString(FFMpegPath) +func (s *Config) GetFFMpegPath() string { + return s.getString(FFMpegPath) } // GetFFProbePath returns the path to the FFProbe executable. // If empty, stash will attempt to resolve it from the path. -func (i *Config) GetFFProbePath() string { - return i.getString(FFProbePath) +func (s *Config) GetFFProbePath() string { + return s.getString(FFProbePath) } -func (i *Config) GetJWTSignKey() []byte { - return []byte(i.getString(JWTSignKey)) +func (s *Config) GetJWTSignKey() []byte { + return []byte(s.getString(JWTSignKey)) } -func (i *Config) GetSessionStoreKey() []byte { - return []byte(i.getString(SessionStoreKey)) +func (s *Config) GetSessionStoreKey() []byte { + return []byte(s.getString(SessionStoreKey)) } -func (i *Config) GetDefaultScrapersPath() string { +func (s *Config) GetDefaultScrapersPath() string { // default to the same directory as the config file - fn := filepath.Join(i.GetConfigPath(), "scrapers") + fn := filepath.Join(s.GetConfigPath(), "scrapers") return fn } -func (i *Config) GetExcludes() []string { - return i.getStringSlice(Exclude) +func (s *Config) GetExcludes() []string { + return s.getStringSlice(Exclude) } -func (i *Config) GetImageExcludes() []string { - return i.getStringSlice(ImageExclude) +func (s *Config) GetImageExcludes() []string { + return s.getStringSlice(ImageExclude) } -func (i *Config) GetVideoExtensions() []string { - ret := i.getStringSlice(VideoExtensions) +func (s *Config) GetVideoExtensions() []string { + ret := s.getStringSlice(VideoExtensions) if len(ret) == 0 { ret = defaultVideoExtensions } return ret } -func (i *Config) GetImageExtensions() []string { - ret := i.getStringSlice(ImageExtensions) +func (s *Config) GetImageExtensions() []string { + ret := s.getStringSlice(ImageExtensions) if len(ret) == 0 { ret = defaultImageExtensions } return ret } -func (i *Config) GetGalleryExtensions() []string { - ret := i.getStringSlice(GalleryExtensions) +func (s *Config) GetGalleryExtensions() []string { + ret := s.getStringSlice(GalleryExtensions) if len(ret) == 0 { ret = defaultGalleryExtensions } return ret } -func (i *Config) GetCreateGalleriesFromFolders() bool { - return i.getBool(CreateGalleriesFromFolders) +func (s *Config) GetCreateGalleriesFromFolders() bool { + return s.getBool(CreateGalleriesFromFolders) } -func (i *Config) GetLanguage() string { - ret := i.getString(Language) +func (s *Config) GetLanguage() string { + ret := s.getString(Language) // default to English if ret == "" { @@ -804,14 +811,14 @@ func (i *Config) GetLanguage() string { // IsCalculateMD5 returns true if MD5 checksums should be generated for // scene video files. -func (i *Config) IsCalculateMD5() bool { - return i.getBool(CalculateMD5) +func (s *Config) IsCalculateMD5() bool { + return s.getBool(CalculateMD5) } // GetVideoFileNamingAlgorithm returns what hash algorithm should be used for // naming generated scene video files. -func (i *Config) GetVideoFileNamingAlgorithm() models.HashAlgorithm { - ret := i.getString(VideoFileNamingAlgorithm) +func (s *Config) GetVideoFileNamingAlgorithm() models.HashAlgorithm { + ret := s.getString(VideoFileNamingAlgorithm) // default to oshash if ret == "" { @@ -821,12 +828,12 @@ func (i *Config) GetVideoFileNamingAlgorithm() models.HashAlgorithm { return models.HashAlgorithm(ret) } -func (i *Config) GetSequentialScanning() bool { - return i.getBool(SequentialScanning) +func (s *Config) GetSequentialScanning() bool { + return s.getBool(SequentialScanning) } -func (i *Config) GetGalleryCoverRegex() string { - var regexString = i.getString(GalleryCoverRegex) +func (s *Config) GetGalleryCoverRegex() string { + var regexString = s.getString(GalleryCoverRegex) _, err := regexp.Compile(regexString) if err != nil { @@ -837,57 +844,57 @@ func (i *Config) GetGalleryCoverRegex() string { return regexString } -func (i *Config) GetScrapersPath() string { - return i.getString(ScrapersPath) +func (s *Config) GetScrapersPath() string { + return s.getString(ScrapersPath) } -func (i *Config) GetScraperUserAgent() string { - return i.getString(ScraperUserAgent) +func (s *Config) GetScraperUserAgent() string { + return s.getString(ScraperUserAgent) } // GetScraperCDPPath gets the path to the Chrome executable or remote address // to an instance of Chrome. -func (i *Config) GetScraperCDPPath() string { - return i.getString(ScraperCDPPath) +func (s *Config) GetScraperCDPPath() string { + return s.getString(ScraperCDPPath) } // GetScraperCertCheck returns true if the scraper should check for insecure // certificates when fetching an image or a page. -func (i *Config) GetScraperCertCheck() bool { - return i.getBoolDefault(ScraperCertCheck, true) +func (s *Config) GetScraperCertCheck() bool { + return s.getBoolDefault(ScraperCertCheck, true) } -func (i *Config) GetScraperExcludeTagPatterns() []string { - return i.getStringSlice(ScraperExcludeTagPatterns) +func (s *Config) GetScraperExcludeTagPatterns() []string { + return s.getStringSlice(ScraperExcludeTagPatterns) } -func (i *Config) GetStashBoxes() []*models.StashBox { +func (s *Config) GetStashBoxes() []*models.StashBox { var boxes []*models.StashBox - if err := i.unmarshalKey(StashBoxes, &boxes); err != nil { + if err := s.unmarshalKey(StashBoxes, &boxes); err != nil { logger.Warnf("error in unmarshalkey: %v", err) } return boxes } -func (i *Config) GetDefaultPluginsPath() string { +func (s *Config) GetDefaultPluginsPath() string { // default to the same directory as the config file - fn := filepath.Join(i.GetConfigPath(), "plugins") + fn := filepath.Join(s.GetConfigPath(), "plugins") return fn } -func (i *Config) GetPluginsPath() string { - return i.getString(PluginsPath) +func (s *Config) GetPluginsPath() string { + return s.getString(PluginsPath) } -func (i *Config) GetAllPluginConfiguration() map[string]map[string]interface{} { - i.RLock() - defer i.RUnlock() +func (s *Config) GetAllPluginConfiguration() map[string]map[string]interface{} { + s.RLock() + defer s.RUnlock() ret := make(map[string]map[string]interface{}) - v := i.forKey(PluginsSetting) + v := s.forKey(PluginsSetting) sub := v.Cut(PluginsSetting) if sub == nil { @@ -901,36 +908,36 @@ func (i *Config) GetAllPluginConfiguration() map[string]map[string]interface{} { return ret } -func (i *Config) GetPluginConfiguration(pluginID string) map[string]interface{} { - i.RLock() - defer i.RUnlock() +func (s *Config) GetPluginConfiguration(pluginID string) map[string]interface{} { + s.RLock() + defer s.RUnlock() key := PluginsSettingPrefix + pluginID - return i.forKey(key).Cut(key).Raw() + return s.forKey(key).Cut(key).Raw() } // SetPluginConfiguration sets the configuration for a plugin. // It will overwrite any existing configuration. -func (i *Config) SetPluginConfiguration(pluginID string, v map[string]interface{}) { - i.Lock() - defer i.Unlock() +func (s *Config) SetPluginConfiguration(pluginID string, v map[string]interface{}) { + s.Lock() + defer s.Unlock() key := PluginsSettingPrefix + pluginID - i.set(key, v) + s.set(key, v) } -func (i *Config) GetDisabledPlugins() []string { - return i.getStringSlice(DisabledPlugins) +func (s *Config) GetDisabledPlugins() []string { + return s.getStringSlice(DisabledPlugins) } -func (i *Config) GetPythonPath() string { - return i.getString(PythonPath) +func (s *Config) GetPythonPath() string { + return s.getString(PythonPath) } -func (i *Config) GetHost() string { - ret := i.getString(Host) +func (s *Config) GetHost() string { + ret := s.getString(Host) if ret == "" { ret = hostDefault } @@ -938,8 +945,8 @@ func (i *Config) GetHost() string { return ret } -func (i *Config) GetPort() int { - ret := i.getInt(Port) +func (s *Config) GetPort() int { + ret := s.getInt(Port) if ret == 0 { ret = portDefault } @@ -947,41 +954,41 @@ func (i *Config) GetPort() int { return ret } -func (i *Config) GetThemeColor() string { - return i.getString(ThemeColor) +func (s *Config) GetThemeColor() string { + return s.getString(ThemeColor) } -func (i *Config) GetExternalHost() string { - return i.getString(ExternalHost) +func (s *Config) GetExternalHost() string { + return s.getString(ExternalHost) } // GetPreviewSegmentDuration returns the duration of a single segment in a // scene preview file, in seconds. -func (i *Config) GetPreviewSegmentDuration() float64 { - return i.getFloat64(PreviewSegmentDuration) +func (s *Config) GetPreviewSegmentDuration() float64 { + return s.getFloat64(PreviewSegmentDuration) } // GetParallelTasks returns the number of parallel tasks that should be started // by scan or generate task. -func (i *Config) GetParallelTasks() int { - return i.getInt(ParallelTasks) +func (s *Config) GetParallelTasks() int { + return s.getInt(ParallelTasks) } -func (i *Config) GetParallelTasksWithAutoDetection() int { - parallelTasks := i.getInt(ParallelTasks) +func (s *Config) GetParallelTasksWithAutoDetection() int { + parallelTasks := s.getInt(ParallelTasks) if parallelTasks <= 0 { parallelTasks = (runtime.NumCPU() / 4) + 1 } return parallelTasks } -func (i *Config) GetPreviewAudio() bool { - return i.getBool(PreviewAudio) +func (s *Config) GetPreviewAudio() bool { + return s.getBool(PreviewAudio) } // GetPreviewSegments returns the amount of segments in a scene preview file. -func (i *Config) GetPreviewSegments() int { - return i.getInt(PreviewSegments) +func (s *Config) GetPreviewSegments() int { + return s.getInt(PreviewSegments) } // GetPreviewExcludeStart returns the configuration setting string for @@ -990,8 +997,8 @@ func (i *Config) GetPreviewSegments() int { // of seconds to exclude from the start of the video before it is included // in the preview. If the value is suffixed with a '%' character (for example // '2%'), then it is interpreted as a proportion of the total video duration. -func (i *Config) GetPreviewExcludeStart() string { - return i.getString(PreviewExcludeStart) +func (s *Config) GetPreviewExcludeStart() string { + return s.getString(PreviewExcludeStart) } // GetPreviewExcludeEnd returns the configuration setting string for @@ -999,14 +1006,14 @@ func (i *Config) GetPreviewExcludeStart() string { // is interpreted as the amount of seconds to exclude from the end of the video // when generating previews. If the value is suffixed with a '%' character, // then it is interpreted as a proportion of the total video duration. -func (i *Config) GetPreviewExcludeEnd() string { - return i.getString(PreviewExcludeEnd) +func (s *Config) GetPreviewExcludeEnd() string { + return s.getString(PreviewExcludeEnd) } // GetPreviewPreset returns the preset when generating previews. Defaults to // Slow. -func (i *Config) GetPreviewPreset() models.PreviewPreset { - ret := i.getString(PreviewPreset) +func (s *Config) GetPreviewPreset() models.PreviewPreset { + ret := s.getString(PreviewPreset) // default to slow if ret == "" { @@ -1016,12 +1023,12 @@ func (i *Config) GetPreviewPreset() models.PreviewPreset { return models.PreviewPreset(ret) } -func (i *Config) GetTranscodeHardwareAcceleration() bool { - return i.getBool(TranscodeHardwareAcceleration) +func (s *Config) GetTranscodeHardwareAcceleration() bool { + return s.getBool(TranscodeHardwareAcceleration) } -func (i *Config) GetMaxTranscodeSize() models.StreamingResolutionEnum { - ret := i.getString(MaxTranscodeSize) +func (s *Config) GetMaxTranscodeSize() models.StreamingResolutionEnum { + ret := s.getString(MaxTranscodeSize) // default to original if ret == "" { @@ -1031,8 +1038,8 @@ func (i *Config) GetMaxTranscodeSize() models.StreamingResolutionEnum { return models.StreamingResolutionEnum(ret) } -func (i *Config) GetMaxStreamingTranscodeSize() models.StreamingResolutionEnum { - ret := i.getString(MaxStreamingTranscodeSize) +func (s *Config) GetMaxStreamingTranscodeSize() models.StreamingResolutionEnum { + ret := s.getString(MaxStreamingTranscodeSize) // default to original if ret == "" { @@ -1042,80 +1049,38 @@ func (i *Config) GetMaxStreamingTranscodeSize() models.StreamingResolutionEnum { return models.StreamingResolutionEnum(ret) } -func (i *Config) GetTranscodeInputArgs() []string { - return i.getStringSlice(TranscodeInputArgs) +func (s *Config) GetTranscodeInputArgs() []string { + return s.getStringSlice(TranscodeInputArgs) } -func (i *Config) GetTranscodeOutputArgs() []string { - return i.getStringSlice(TranscodeOutputArgs) +func (s *Config) GetTranscodeOutputArgs() []string { + return s.getStringSlice(TranscodeOutputArgs) } -func (i *Config) GetLiveTranscodeInputArgs() []string { - return i.getStringSlice(LiveTranscodeInputArgs) +func (s *Config) GetLiveTranscodeInputArgs() []string { + return s.getStringSlice(LiveTranscodeInputArgs) } -func (i *Config) GetLiveTranscodeOutputArgs() []string { - return i.getStringSlice(LiveTranscodeOutputArgs) +func (s *Config) GetLiveTranscodeOutputArgs() []string { + return s.getStringSlice(LiveTranscodeOutputArgs) } -func (i *Config) GetDrawFunscriptHeatmapRange() bool { - return i.getBoolDefault(DrawFunscriptHeatmapRange, drawFunscriptHeatmapRangeDefault) +func (s *Config) GetDrawFunscriptHeatmapRange() bool { + return s.getBoolDefault(DrawFunscriptHeatmapRange, drawFunscriptHeatmapRangeDefault) } // IsWriteImageThumbnails returns true if image thumbnails should be written // to disk after generating on the fly. -func (i *Config) IsWriteImageThumbnails() bool { - return i.getBool(WriteImageThumbnails) +func (s *Config) IsWriteImageThumbnails() bool { + return s.getBool(WriteImageThumbnails) } -func (i *Config) IsCreateImageClipsFromVideos() bool { - return i.getBool(CreateImageClipsFromVideos) +func (s *Config) IsCreateImageClipsFromVideos() bool { + return s.getBool(CreateImageClipsFromVideos) } -func (i *Config) GetAPIKey() string { - return i.getString(ApiKey) -} - -func (i *Config) GetUsername() string { - return i.getString(Username) -} - -func (i *Config) GetPasswordHash() string { - return i.getString(Password) -} - -func (i *Config) GetCredentials() (string, string) { - if i.HasCredentials() { - return i.getString(Username), i.getString(Password) - } - - return "", "" -} - -func (i *Config) HasCredentials() bool { - username := i.getString(Username) - pwHash := i.getString(Password) - - return username != "" && pwHash != "" -} - -func hashPassword(password string) string { - hash, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.MinCost) - - return string(hash) -} - -func (i *Config) ValidateCredentials(username string, password string) bool { - if !i.HasCredentials() { - // don't need to authenticate if no credentials saved - return true - } - - authUser, authPWHash := i.GetCredentials() - - err := bcrypt.CompareHashAndPassword([]byte(authPWHash), []byte(password)) - - return username == authUser && err == nil +func (s *Config) GetAPIKey() string { + return s.getString(ApiKey) } func stashBoxValidate(str string) bool { @@ -1130,7 +1095,7 @@ type StashBoxInput struct { MaxRequestsPerMinute int `json:"max_requests_per_minute"` } -func (i *Config) ValidateStashBoxes(boxes []*StashBoxInput) error { +func (s *Config) ValidateStashBoxes(boxes []*StashBoxInput) error { isMulti := len(boxes) > 1 for _, box := range boxes { @@ -1157,12 +1122,12 @@ func (i *Config) ValidateStashBoxes(boxes []*StashBoxInput) error { // GetMaxSessionAge gets the maximum age for session cookies, in seconds. // Session cookie expiry times are refreshed every request. -func (i *Config) GetMaxSessionAge() int { - i.RLock() - defer i.RUnlock() +func (s *Config) GetMaxSessionAge() int { + s.RLock() + defer s.RUnlock() ret := DefaultMaxSessionAge - v := i.forKey(MaxSessionAge) + v := s.forKey(MaxSessionAge) if v.Exists(MaxSessionAge) { ret = v.Int(MaxSessionAge) } @@ -1172,55 +1137,55 @@ func (i *Config) GetMaxSessionAge() int { // GetCustomServedFolders gets the map of custom paths to their applicable // filesystem locations -func (i *Config) GetCustomServedFolders() utils.URLMap { - return i.getStringMapString(CustomServedFolders) +func (s *Config) GetCustomServedFolders() utils.URLMap { + return s.getStringMapString(CustomServedFolders) } -func (i *Config) GetUILocation() string { - if ret := i.getString(UILocation); ret != "" { +func (s *Config) GetUILocation() string { + if ret := s.getString(UILocation); ret != "" { return ret } - return i.getString(LegacyCustomUILocation) + return s.getString(LegacyCustomUILocation) } // Interface options -func (i *Config) GetMenuItems() []string { - i.RLock() - defer i.RUnlock() - v := i.forKey(MenuItems) +func (s *Config) GetMenuItems() []string { + s.RLock() + defer s.RUnlock() + v := s.forKey(MenuItems) if v.Exists(MenuItems) { return v.Strings(MenuItems) } return defaultMenuItems } -func (i *Config) GetSoundOnPreview() bool { - return i.getBool(SoundOnPreview) +func (s *Config) GetSoundOnPreview() bool { + return s.getBool(SoundOnPreview) } -func (i *Config) GetWallShowTitle() bool { - i.RLock() - defer i.RUnlock() +func (s *Config) GetWallShowTitle() bool { + s.RLock() + defer s.RUnlock() ret := defaultWallShowTitle - v := i.forKey(WallShowTitle) + v := s.forKey(WallShowTitle) if v.Exists(WallShowTitle) { ret = v.Bool(WallShowTitle) } return ret } -func (i *Config) GetCustomPerformerImageLocation() string { - return i.getString(CustomPerformerImageLocation) +func (s *Config) GetCustomPerformerImageLocation() string { + return s.getString(CustomPerformerImageLocation) } -func (i *Config) GetWallPlayback() string { - i.RLock() - defer i.RUnlock() +func (s *Config) GetWallPlayback() string { + s.RLock() + defer s.RUnlock() ret := defaultWallPlayback - v := i.forKey(WallPlayback) + v := s.forKey(WallPlayback) if v.Exists(WallPlayback) { ret = v.String(WallPlayback) } @@ -1228,40 +1193,40 @@ func (i *Config) GetWallPlayback() string { return ret } -func (i *Config) GetShowScrubber() bool { - return i.getBoolDefault(ShowScrubber, showScrubberDefault) +func (s *Config) GetShowScrubber() bool { + return s.getBoolDefault(ShowScrubber, showScrubberDefault) } -func (i *Config) GetMaximumLoopDuration() int { - return i.getInt(MaximumLoopDuration) +func (s *Config) GetMaximumLoopDuration() int { + return s.getInt(MaximumLoopDuration) } -func (i *Config) GetAutostartVideo() bool { - return i.getBool(AutostartVideo) +func (s *Config) GetAutostartVideo() bool { + return s.getBool(AutostartVideo) } -func (i *Config) GetAutostartVideoOnPlaySelected() bool { - return i.getBoolDefault(AutostartVideoOnPlaySelected, autostartVideoOnPlaySelectedDefault) +func (s *Config) GetAutostartVideoOnPlaySelected() bool { + return s.getBoolDefault(AutostartVideoOnPlaySelected, autostartVideoOnPlaySelectedDefault) } -func (i *Config) GetContinuePlaylistDefault() bool { - return i.getBool(ContinuePlaylistDefault) +func (s *Config) GetContinuePlaylistDefault() bool { + return s.getBool(ContinuePlaylistDefault) } -func (i *Config) GetShowStudioAsText() bool { - return i.getBool(ShowStudioAsText) +func (s *Config) GetShowStudioAsText() bool { + return s.getBool(ShowStudioAsText) } -func (i *Config) getSlideshowDelay() int { +func (s *Config) getSlideshowDelay() int { // assume have lock ret := defaultImageLightboxSlideshowDelay - v := i.forKey(ImageLightboxSlideshowDelay) + v := s.forKey(ImageLightboxSlideshowDelay) if v.Exists(ImageLightboxSlideshowDelay) { ret = v.Int(ImageLightboxSlideshowDelay) } else { // fallback to old location - v := i.forKey(legacyImageLightboxSlideshowDelay) + v := s.forKey(legacyImageLightboxSlideshowDelay) if v.Exists(legacyImageLightboxSlideshowDelay) { ret = v.Int(legacyImageLightboxSlideshowDelay) } @@ -1270,36 +1235,36 @@ func (i *Config) getSlideshowDelay() int { return ret } -func (i *Config) GetImageLightboxOptions() ConfigImageLightboxResult { - i.RLock() - defer i.RUnlock() +func (s *Config) GetImageLightboxOptions() ConfigImageLightboxResult { + s.RLock() + defer s.RUnlock() - delay := i.getSlideshowDelay() + delay := s.getSlideshowDelay() ret := ConfigImageLightboxResult{ SlideshowDelay: &delay, } - if v := i.with(ImageLightboxDisplayModeKey); v != nil { + if v := s.with(ImageLightboxDisplayModeKey); v != nil { mode := ImageLightboxDisplayMode(v.String(ImageLightboxDisplayModeKey)) ret.DisplayMode = &mode } - if v := i.with(ImageLightboxScaleUp); v != nil { + if v := s.with(ImageLightboxScaleUp); v != nil { value := v.Bool(ImageLightboxScaleUp) ret.ScaleUp = &value } - if v := i.with(ImageLightboxResetZoomOnNav); v != nil { + if v := s.with(ImageLightboxResetZoomOnNav); v != nil { value := v.Bool(ImageLightboxResetZoomOnNav) ret.ResetZoomOnNav = &value } - if v := i.with(ImageLightboxScrollModeKey); v != nil { + if v := s.with(ImageLightboxScrollModeKey); v != nil { mode := ImageLightboxScrollMode(v.String(ImageLightboxScrollModeKey)) ret.ScrollMode = &mode } - if v := i.with(ImageLightboxScrollAttemptsBeforeChange); v != nil { + if v := s.with(ImageLightboxScrollAttemptsBeforeChange); v != nil { ret.ScrollAttemptsBeforeChange = v.Int(ImageLightboxScrollAttemptsBeforeChange) } - if v := i.with(ImageLightboxDisableAnimation); v != nil { + if v := s.with(ImageLightboxDisableAnimation); v != nil { value := v.Bool(ImageLightboxDisableAnimation) ret.DisableAnimation = &value } @@ -1307,27 +1272,27 @@ func (i *Config) GetImageLightboxOptions() ConfigImageLightboxResult { return ret } -func (i *Config) GetDisableDropdownCreate() *ConfigDisableDropdownCreate { +func (s *Config) GetDisableDropdownCreate() *ConfigDisableDropdownCreate { return &ConfigDisableDropdownCreate{ - Performer: i.getBool(DisableDropdownCreatePerformer), - Studio: i.getBool(DisableDropdownCreateStudio), - Tag: i.getBool(DisableDropdownCreateTag), - Movie: i.getBool(DisableDropdownCreateMovie), - Gallery: i.getBool(DisableDropdownCreateGallery), + Performer: s.getBool(DisableDropdownCreatePerformer), + Studio: s.getBool(DisableDropdownCreateStudio), + Tag: s.getBool(DisableDropdownCreateTag), + Movie: s.getBool(DisableDropdownCreateMovie), + Gallery: s.getBool(DisableDropdownCreateGallery), } } -func (i *Config) GetUIConfiguration() map[string]interface{} { - i.RLock() - defer i.RUnlock() +func (s *Config) GetUIConfiguration() map[string]interface{} { + s.RLock() + defer s.RUnlock() - return i.forKey(UI).Cut(UI).Raw() + return s.forKey(UI).Cut(UI).Raw() } // GetMinimumPlayPercent returns the minimum percentage of a video that must be // watched before incrementing the play count. Returns 0 if not configured. -func (i *Config) GetMinimumPlayPercent() int { - uiConfig := i.GetUIConfiguration() +func (s *Config) GetMinimumPlayPercent() int { + uiConfig := s.GetUIConfiguration() if uiConfig == nil { return 0 } @@ -1344,16 +1309,16 @@ func (i *Config) GetMinimumPlayPercent() int { return 0 } -func (i *Config) SetUIConfiguration(v map[string]interface{}) { - i.Lock() - defer i.Unlock() +func (s *Config) SetUIConfiguration(v map[string]interface{}) { + s.Lock() + defer s.Unlock() - i.set(UI, v) + s.set(UI, v) } -func (i *Config) GetCSSPath() string { +func (s *Config) GetCSSPath() string { // use custom.css in the same directory as the config file - configFileUsed := i.GetConfigFile() + configFileUsed := s.GetConfigFile() configDir := filepath.Dir(configFileUsed) fn := filepath.Join(configDir, "custom.css") @@ -1361,8 +1326,8 @@ func (i *Config) GetCSSPath() string { return fn } -func (i *Config) GetCSS() string { - fn := i.GetCSSPath() +func (s *Config) GetCSS() string { + fn := s.GetCSSPath() exists, _ := fsutil.FileExists(fn) if !exists { @@ -1378,10 +1343,10 @@ func (i *Config) GetCSS() string { return string(buf) } -func (i *Config) SetCSS(css string) { - fn := i.GetCSSPath() - i.Lock() - defer i.Unlock() +func (s *Config) SetCSS(css string) { + fn := s.GetCSSPath() + s.Lock() + defer s.Unlock() buf := []byte(css) @@ -1390,13 +1355,13 @@ func (i *Config) SetCSS(css string) { } } -func (i *Config) GetCSSEnabled() bool { - return i.getBool(CSSEnabled) +func (s *Config) GetCSSEnabled() bool { + return s.getBool(CSSEnabled) } -func (i *Config) GetJavascriptPath() string { +func (s *Config) GetJavascriptPath() string { // use custom.js in the same directory as the config file - configFileUsed := i.GetConfigFile() + configFileUsed := s.GetConfigFile() configDir := filepath.Dir(configFileUsed) fn := filepath.Join(configDir, "custom.js") @@ -1404,8 +1369,8 @@ func (i *Config) GetJavascriptPath() string { return fn } -func (i *Config) GetJavascript() string { - fn := i.GetJavascriptPath() +func (s *Config) GetJavascript() string { + fn := s.GetJavascriptPath() exists, _ := fsutil.FileExists(fn) if !exists { @@ -1421,10 +1386,10 @@ func (i *Config) GetJavascript() string { return string(buf) } -func (i *Config) SetJavascript(javascript string) { - fn := i.GetJavascriptPath() - i.Lock() - defer i.Unlock() +func (s *Config) SetJavascript(javascript string) { + fn := s.GetJavascriptPath() + s.Lock() + defer s.Unlock() buf := []byte(javascript) @@ -1433,13 +1398,13 @@ func (i *Config) SetJavascript(javascript string) { } } -func (i *Config) GetJavascriptEnabled() bool { - return i.getBool(JavascriptEnabled) +func (s *Config) GetJavascriptEnabled() bool { + return s.getBool(JavascriptEnabled) } -func (i *Config) GetCustomLocalesPath() string { +func (s *Config) GetCustomLocalesPath() string { // use custom-locales.json in the same directory as the config file - configFileUsed := i.GetConfigFile() + configFileUsed := s.GetConfigFile() configDir := filepath.Dir(configFileUsed) fn := filepath.Join(configDir, "custom-locales.json") @@ -1447,8 +1412,8 @@ func (i *Config) GetCustomLocalesPath() string { return fn } -func (i *Config) GetCustomLocales() string { - fn := i.GetCustomLocalesPath() +func (s *Config) GetCustomLocales() string { + fn := s.GetCustomLocalesPath() exists, _ := fsutil.FileExists(fn) if !exists { @@ -1464,10 +1429,10 @@ func (i *Config) GetCustomLocales() string { return string(buf) } -func (i *Config) SetCustomLocales(customLocales string) { - fn := i.GetCustomLocalesPath() - i.Lock() - defer i.Unlock() +func (s *Config) SetCustomLocales(customLocales string) { + fn := s.GetCustomLocalesPath() + s.Lock() + defer s.Unlock() buf := []byte(customLocales) @@ -1476,52 +1441,52 @@ func (i *Config) SetCustomLocales(customLocales string) { } } -func (i *Config) GetCustomLocalesEnabled() bool { - return i.getBool(CustomLocalesEnabled) +func (s *Config) GetCustomLocalesEnabled() bool { + return s.getBool(CustomLocalesEnabled) } // GetDisableCustomizations returns true if all customizations (plugins, custom CSS, // custom JavaScript, and custom locales) should be disabled. This is useful for // troubleshooting issues without permanently disabling individual customizations. -func (i *Config) GetDisableCustomizations() bool { - return i.getBool(DisableCustomizations) +func (s *Config) GetDisableCustomizations() bool { + return s.getBool(DisableCustomizations) } -func (i *Config) GetHandyKey() string { - return i.getString(HandyKey) +func (s *Config) GetHandyKey() string { + return s.getString(HandyKey) } -func (i *Config) GetFunscriptOffset() int { - return i.getInt(FunscriptOffset) +func (s *Config) GetFunscriptOffset() int { + return s.getInt(FunscriptOffset) } -func (i *Config) GetUseStashHostedFunscript() bool { - return i.getBoolDefault(UseStashHostedFunscript, useStashHostedFunscriptDefault) +func (s *Config) GetUseStashHostedFunscript() bool { + return s.getBoolDefault(UseStashHostedFunscript, useStashHostedFunscriptDefault) } -func (i *Config) GetDeleteFileDefault() bool { - return i.getBool(DeleteFileDefault) +func (s *Config) GetDeleteFileDefault() bool { + return s.getBool(DeleteFileDefault) } -func (i *Config) GetDeleteGeneratedDefault() bool { - return i.getBoolDefault(DeleteGeneratedDefault, deleteGeneratedDefaultDefault) +func (s *Config) GetDeleteGeneratedDefault() bool { + return s.getBoolDefault(DeleteGeneratedDefault, deleteGeneratedDefaultDefault) } -func (i *Config) GetDeleteTrashPath() string { - return i.getString(DeleteTrashPath) +func (s *Config) GetDeleteTrashPath() string { + return s.getString(DeleteTrashPath) } -func (i *Config) SetDeleteTrashPath(value string) { - i.SetString(DeleteTrashPath, value) +func (s *Config) SetDeleteTrashPath(value string) { + s.SetString(DeleteTrashPath, value) } // GetDefaultIdentifySettings returns the default Identify task settings. // Returns nil if the settings could not be unmarshalled, or if it // has not been set. -func (i *Config) GetDefaultIdentifySettings() *identify.Options { - i.RLock() - defer i.RUnlock() - v := i.forKey(DefaultIdentifySettings) +func (s *Config) GetDefaultIdentifySettings() *identify.Options { + s.RLock() + defer s.RUnlock() + v := s.forKey(DefaultIdentifySettings) if v.Exists(DefaultIdentifySettings) && v.Get(DefaultIdentifySettings) != nil { var ret identify.Options @@ -1538,10 +1503,10 @@ func (i *Config) GetDefaultIdentifySettings() *identify.Options { // GetDefaultScanSettings returns the default Scan task settings. // Returns nil if the settings could not be unmarshalled, or if it // has not been set. -func (i *Config) GetDefaultScanSettings() *ScanMetadataOptions { - i.RLock() - defer i.RUnlock() - v := i.forKey(DefaultScanSettings) +func (s *Config) GetDefaultScanSettings() *ScanMetadataOptions { + s.RLock() + defer s.RUnlock() + v := s.forKey(DefaultScanSettings) if v.Exists(DefaultScanSettings) && v.Get(DefaultScanSettings) != nil { var ret ScanMetadataOptions @@ -1557,10 +1522,10 @@ func (i *Config) GetDefaultScanSettings() *ScanMetadataOptions { // GetDefaultAutoTagSettings returns the default Scan task settings. // Returns nil if the settings could not be unmarshalled, or if it // has not been set. -func (i *Config) GetDefaultAutoTagSettings() *AutoTagMetadataOptions { - i.RLock() - defer i.RUnlock() - v := i.forKey(DefaultAutoTagSettings) +func (s *Config) GetDefaultAutoTagSettings() *AutoTagMetadataOptions { + s.RLock() + defer s.RUnlock() + v := s.forKey(DefaultAutoTagSettings) if v.Exists(DefaultAutoTagSettings) { var ret AutoTagMetadataOptions @@ -1576,10 +1541,10 @@ func (i *Config) GetDefaultAutoTagSettings() *AutoTagMetadataOptions { // GetDefaultGenerateSettings returns the default Scan task settings. // Returns nil if the settings could not be unmarshalled, or if it // has not been set. -func (i *Config) GetDefaultGenerateSettings() *models.GenerateMetadataOptions { - i.RLock() - defer i.RUnlock() - v := i.forKey(DefaultGenerateSettings) +func (s *Config) GetDefaultGenerateSettings() *models.GenerateMetadataOptions { + s.RLock() + defer s.RUnlock() + v := s.forKey(DefaultGenerateSettings) if v.Exists(DefaultGenerateSettings) { var ret models.GenerateMetadataOptions @@ -1594,44 +1559,44 @@ func (i *Config) GetDefaultGenerateSettings() *models.GenerateMetadataOptions { // GetDangerousAllowPublicWithoutAuth determines if the security feature is enabled. // See https://discourse.stashapp.cc/t/-/1658 -func (i *Config) GetDangerousAllowPublicWithoutAuth() bool { - return i.getBool(dangerousAllowPublicWithoutAuth) +func (s *Config) GetDangerousAllowPublicWithoutAuth() bool { + return s.getBool(dangerousAllowPublicWithoutAuth) } // GetSecurityTripwireAccessedFromPublicInternet returns a public IP address if stash // has been accessed from the public internet, with no auth enabled, and // DangerousAllowPublicWithoutAuth disabled. Returns an empty string otherwise. -func (i *Config) GetSecurityTripwireAccessedFromPublicInternet() string { - return i.getString(SecurityTripwireAccessedFromPublicInternet) +func (s *Config) GetSecurityTripwireAccessedFromPublicInternet() string { + return s.getString(SecurityTripwireAccessedFromPublicInternet) } // GetDLNAServerName returns the visible name of the DLNA server. If empty, // "stash" will be used. -func (i *Config) GetDLNAServerName() string { - return i.getString(DLNAServerName) +func (s *Config) GetDLNAServerName() string { + return s.getString(DLNAServerName) } // GetDLNADefaultEnabled returns true if the DLNA is enabled by default. -func (i *Config) GetDLNADefaultEnabled() bool { - return i.getBool(DLNADefaultEnabled) +func (s *Config) GetDLNADefaultEnabled() bool { + return s.getBool(DLNADefaultEnabled) } // GetDLNADefaultIPWhitelist returns a list of IP addresses/wildcards that // are allowed to use the DLNA service. -func (i *Config) GetDLNADefaultIPWhitelist() []string { - return i.getStringSlice(DLNADefaultIPWhitelist) +func (s *Config) GetDLNADefaultIPWhitelist() []string { + return s.getStringSlice(DLNADefaultIPWhitelist) } // GetDLNAInterfaces returns a list of interface names to expose DLNA on. If // empty, runs on all interfaces. -func (i *Config) GetDLNAInterfaces() []string { - return i.getStringSlice(DLNAInterfaces) +func (s *Config) GetDLNAInterfaces() []string { + return s.getStringSlice(DLNAInterfaces) } // GetDLNAPort returns the port to run the DLNA server on. If empty, 1338 // will be used. -func (i *Config) GetDLNAPort() int { - ret := i.getInt(DLNAPort) +func (s *Config) GetDLNAPort() int { + ret := s.getInt(DLNAPort) if ret == 0 { ret = DLNAPortDefault } @@ -1639,15 +1604,15 @@ func (i *Config) GetDLNAPort() int { } // GetDLNAPortAsString returns the port to run the DLNA server on as a string. -func (i *Config) GetDLNAPortAsString() string { - return ":" + strconv.Itoa(i.GetDLNAPort()) +func (s *Config) GetDLNAPortAsString() string { + return ":" + strconv.Itoa(s.GetDLNAPort()) } // GetDLNAActivityTrackingEnabled returns true if DLNA activity tracking is enabled. // This uses the same "trackActivity" UI setting that controls frontend play history tracking. // When enabled, scenes played via DLNA will have their play count and duration tracked. -func (i *Config) GetDLNAActivityTrackingEnabled() bool { - uiConfig := i.GetUIConfiguration() +func (s *Config) GetDLNAActivityTrackingEnabled() bool { + uiConfig := s.GetUIConfiguration() if uiConfig == nil { return true // Default to enabled } @@ -1661,8 +1626,8 @@ func (i *Config) GetDLNAActivityTrackingEnabled() bool { // GetVideoSortOrder returns the sort order to display videos. If // empty, videos will be sorted by titles. -func (i *Config) GetVideoSortOrder() string { - ret := i.getString(DLNAVideoSortOrder) +func (s *Config) GetVideoSortOrder() string { + ret := s.getString(DLNAVideoSortOrder) if ret == "" { ret = dlnaVideoSortOrderDefault } @@ -1672,21 +1637,21 @@ func (i *Config) GetVideoSortOrder() string { // GetLogFile returns the filename of the file to output logs to. // An empty string means that file logging will be disabled. -func (i *Config) GetLogFile() string { - return i.getString(LogFile) +func (s *Config) GetLogFile() string { + return s.getString(LogFile) } // GetLogOut returns true if logging should be output to the terminal // in addition to writing to a log file. Logging will be output to the // terminal if file logging is disabled. Defaults to true. -func (i *Config) GetLogOut() bool { - return i.getBoolDefault(LogOut, defaultLogOut) +func (s *Config) GetLogOut() bool { + return s.getBoolDefault(LogOut, defaultLogOut) } // GetLogLevel returns the lowest log level to write to the log. // Should be one of "Debug", "Info", "Warning", "Error" -func (i *Config) GetLogLevel() string { - value := i.getString(LogLevel) +func (s *Config) GetLogLevel() string { + value := s.getString(LogLevel) if value != "Debug" && value != "Info" && value != "Warning" && value != "Error" && value != "Trace" { value = defaultLogLevel } @@ -1696,13 +1661,13 @@ func (i *Config) GetLogLevel() string { // GetLogAccess returns true if http requests should be logged to the terminal. // HTTP requests are not logged to the log file. Defaults to true. -func (i *Config) GetLogAccess() bool { - return i.getBoolDefault(LogAccess, defaultLogAccess) +func (s *Config) GetLogAccess() bool { + return s.getBoolDefault(LogAccess, defaultLogAccess) } // GetLogFileMaxSize returns the maximum size of the log file in megabytes for lumberjack to rotate -func (i *Config) GetLogFileMaxSize() int { - value := i.getInt(LogFileMaxSize) +func (s *Config) GetLogFileMaxSize() int { + value := s.getInt(LogFileMaxSize) if value < 0 { value = defaultLogFileMaxSize } @@ -1711,12 +1676,12 @@ func (i *Config) GetLogFileMaxSize() int { } // Max allowed graphql upload size in megabytes -func (i *Config) GetMaxUploadSize() int64 { - i.RLock() - defer i.RUnlock() +func (s *Config) GetMaxUploadSize() int64 { + s.RLock() + defer s.RUnlock() ret := int64(1024) - v := i.forKey(MaxUploadSize) + v := s.forKey(MaxUploadSize) if v.Exists(MaxUploadSize) { ret = v.Int64(MaxUploadSize) } @@ -1724,10 +1689,10 @@ func (i *Config) GetMaxUploadSize() int64 { } // GetProxy returns the url of a http proxy to be used for all outgoing http calls. -func (i *Config) GetProxy() string { +func (s *Config) GetProxy() string { // Validate format reg := regexp.MustCompile(`^((?:socks5h?|https?):\/\/)(([\P{Cc}]+):([\P{Cc}]+)@)?(([a-zA-Z0-9][a-zA-Z0-9.-]*)(:[0-9]{1,5})?)`) - proxy := i.getString(Proxy) + proxy := s.getString(Proxy) if proxy != "" && reg.MatchString(proxy) { logger.Debug("Proxy is valid, using it") return proxy @@ -1739,34 +1704,34 @@ func (i *Config) GetProxy() string { } // GetProxy returns the url of a http proxy to be used for all outgoing http calls. -func (i *Config) GetNoProxy() string { +func (s *Config) GetNoProxy() string { // NoProxy does not require validation, it is validated by the native Go library sufficiently - return i.getString(NoProxy) + return s.getString(NoProxy) } // ActivatePublicAccessTripwire sets the security_tripwire_accessed_from_public_internet // config field to the provided IP address to indicate that stash has been accessed // from this public IP without authentication. -func (i *Config) ActivatePublicAccessTripwire(requestIP string) error { - i.SetString(SecurityTripwireAccessedFromPublicInternet, requestIP) - return i.Write() +func (s *Config) ActivatePublicAccessTripwire(requestIP string) error { + s.SetString(SecurityTripwireAccessedFromPublicInternet, requestIP) + return s.Write() } -func (i *Config) getPackageSources(key string) []*models.PackageSource { +func (s *Config) getPackageSources(key string) []*models.PackageSource { var sources []*models.PackageSource - if err := i.unmarshalKey(key, &sources); err != nil { + if err := s.unmarshalKey(key, &sources); err != nil { logger.Warnf("error in unmarshalkey: %v", err) } return sources } -func (i *Config) GetPluginPackageSources() []*models.PackageSource { - return i.getPackageSources(PluginPackageSources) +func (s *Config) GetPluginPackageSources() []*models.PackageSource { + return s.getPackageSources(PluginPackageSources) } -func (i *Config) GetScraperPackageSources() []*models.PackageSource { - return i.getPackageSources(ScraperPackageSources) +func (s *Config) GetScraperPackageSources() []*models.PackageSource { + return s.getPackageSources(ScraperPackageSources) } type packagePathGetter struct { @@ -1795,21 +1760,21 @@ func (g packagePathGetter) GetSourcePath(srcURL string) string { return "" } -func (i *Config) GetPluginPackagePathGetter() packagePathGetter { +func (s *Config) GetPluginPackagePathGetter() packagePathGetter { return packagePathGetter{ - getterFn: i.GetPluginPackageSources, + getterFn: s.GetPluginPackageSources, } } -func (i *Config) GetScraperPackagePathGetter() packagePathGetter { +func (s *Config) GetScraperPackagePathGetter() packagePathGetter { return packagePathGetter{ - getterFn: i.GetScraperPackageSources, + getterFn: s.GetScraperPackageSources, } } -func (i *Config) Validate() error { - i.RLock() - defer i.RUnlock() +func (s *Config) Validate() error { + s.RLock() + defer s.RUnlock() mandatoryPaths := []string{ Database, Generated, @@ -1818,7 +1783,7 @@ func (i *Config) Validate() error { var missingFields []string for _, p := range mandatoryPaths { - if !i.forKey(p).Exists(p) || i.forKey(p).String(p) == "" { + if !s.forKey(p).Exists(p) || s.forKey(p).String(p) == "" { missingFields = append(missingFields, p) } } @@ -1829,7 +1794,7 @@ func (i *Config) Validate() error { } } - if i.GetBlobsStorage() == BlobStorageTypeFilesystem && i.forKey(BlobsPath).String(BlobsPath) == "" { + if s.GetBlobsStorage() == BlobStorageTypeFilesystem && s.forKey(BlobsPath).String(BlobsPath) == "" { return MissingConfigError{ missingFields: []string{BlobsPath}, } @@ -1838,63 +1803,63 @@ func (i *Config) Validate() error { return nil } -func (i *Config) setDefaultValues() { +func (s *Config) setDefaultValues() { // read data before write lock scope - defaultDatabaseFilePath := i.GetDefaultDatabaseFilePath() - defaultScrapersPath := i.GetDefaultScrapersPath() - defaultPluginsPath := i.GetDefaultPluginsPath() + defaultDatabaseFilePath := s.GetDefaultDatabaseFilePath() + defaultScrapersPath := s.GetDefaultScrapersPath() + defaultPluginsPath := s.GetDefaultPluginsPath() - i.Lock() - defer i.Unlock() + s.Lock() + defer s.Unlock() // set the default host and port so that these are written to the config // file - i.setDefault(Host, hostDefault) - i.setDefault(Port, portDefault) + s.setDefault(Host, hostDefault) + s.setDefault(Port, portDefault) - i.setDefault(ParallelTasks, parallelTasksDefault) - i.setDefault(SequentialScanning, SequentialScanningDefault) - i.setDefault(PreviewSegmentDuration, previewSegmentDurationDefault) - i.setDefault(PreviewSegments, previewSegmentsDefault) - i.setDefault(PreviewExcludeStart, previewExcludeStartDefault) - i.setDefault(PreviewExcludeEnd, previewExcludeEndDefault) - i.setDefault(PreviewAudio, previewAudioDefault) - i.setDefault(SoundOnPreview, false) + s.setDefault(ParallelTasks, parallelTasksDefault) + s.setDefault(SequentialScanning, SequentialScanningDefault) + s.setDefault(PreviewSegmentDuration, previewSegmentDurationDefault) + s.setDefault(PreviewSegments, previewSegmentsDefault) + s.setDefault(PreviewExcludeStart, previewExcludeStartDefault) + s.setDefault(PreviewExcludeEnd, previewExcludeEndDefault) + s.setDefault(PreviewAudio, previewAudioDefault) + s.setDefault(SoundOnPreview, false) - i.setDefault(ThemeColor, DefaultThemeColor) + s.setDefault(ThemeColor, DefaultThemeColor) - i.setDefault(WriteImageThumbnails, writeImageThumbnailsDefault) - i.setDefault(CreateImageClipsFromVideos, createImageClipsFromVideosDefault) + s.setDefault(WriteImageThumbnails, writeImageThumbnailsDefault) + s.setDefault(CreateImageClipsFromVideos, createImageClipsFromVideosDefault) - i.setDefault(Database, defaultDatabaseFilePath) + s.setDefault(Database, defaultDatabaseFilePath) - i.setDefault(dangerousAllowPublicWithoutAuth, dangerousAllowPublicWithoutAuthDefault) - i.setDefault(SecurityTripwireAccessedFromPublicInternet, securityTripwireAccessedFromPublicInternetDefault) + s.setDefault(dangerousAllowPublicWithoutAuth, dangerousAllowPublicWithoutAuthDefault) + s.setDefault(SecurityTripwireAccessedFromPublicInternet, securityTripwireAccessedFromPublicInternetDefault) // Set generated to the metadata path for backwards compat - i.setDefault(Generated, i.main.String(Metadata)) + s.setDefault(Generated, s.main.String(Metadata)) - i.setDefault(NoBrowser, NoBrowserDefault) - i.setDefault(NotificationsEnabled, NotificationsEnabledDefault) - i.setDefault(ShowOneTimeMovedNotification, ShowOneTimeMovedNotificationDefault) + s.setDefault(NoBrowser, NoBrowserDefault) + s.setDefault(NotificationsEnabled, NotificationsEnabledDefault) + s.setDefault(ShowOneTimeMovedNotification, ShowOneTimeMovedNotificationDefault) // Set default scrapers and plugins paths - i.setDefault(ScrapersPath, defaultScrapersPath) - i.setDefault(PluginsPath, defaultPluginsPath) + s.setDefault(ScrapersPath, defaultScrapersPath) + s.setDefault(PluginsPath, defaultPluginsPath) // Set default gallery cover regex - i.setDefault(GalleryCoverRegex, galleryCoverRegexDefault) + s.setDefault(GalleryCoverRegex, galleryCoverRegexDefault) // Set NoProxy default - i.setDefault(NoProxy, noProxyDefault) + s.setDefault(NoProxy, noProxyDefault) // set default package sources - i.setDefault(PluginPackageSources, []map[string]string{{ + s.setDefault(PluginPackageSources, []map[string]string{{ "name": sourceDefaultName, "url": pluginPackageSourcesDefault, "localpath": sourceDefaultPath, }}) - i.setDefault(ScraperPackageSources, []map[string]string{{ + s.setDefault(ScraperPackageSources, []map[string]string{{ "name": sourceDefaultName, "url": scraperPackageSourcesDefault, "localpath": sourceDefaultPath, @@ -1904,50 +1869,50 @@ func (i *Config) setDefaultValues() { // setExistingSystemDefaults sets config options that are new and unset in an existing install, // but should have a separate default than for brand-new systems, to maintain behavior. // The config file will not be written. -func (i *Config) setExistingSystemDefaults() { - i.Lock() - defer i.Unlock() - if !i.isNewSystem { +func (s *Config) setExistingSystemDefaults() { + s.Lock() + defer s.Unlock() + if !s.isNewSystem { // Existing systems as of the introduction of auto-browser open should retain existing // behavior and not start the browser automatically. - if !i.main.Exists(NoBrowser) { - i.set(NoBrowser, true) + if !s.main.Exists(NoBrowser) { + s.set(NoBrowser, true) } // Existing systems as of the introduction of the taskbar should inform users. - if !i.main.Exists(ShowOneTimeMovedNotification) { - i.set(ShowOneTimeMovedNotification, true) + if !s.main.Exists(ShowOneTimeMovedNotification) { + s.set(ShowOneTimeMovedNotification, true) } } } // SetInitialConfig fills in missing required config fields. The config file will not be written. -func (i *Config) SetInitialConfig() error { +func (s *Config) SetInitialConfig() error { // generate some api keys const apiKeyLength = 32 - if string(i.GetJWTSignKey()) == "" { + if string(s.GetJWTSignKey()) == "" { signKey, err := hash.GenerateRandomKey(apiKeyLength) if err != nil { return fmt.Errorf("error generating JWTSignKey: %w", err) } - i.SetString(JWTSignKey, signKey) + s.SetString(JWTSignKey, signKey) } - if string(i.GetSessionStoreKey()) == "" { + if string(s.GetSessionStoreKey()) == "" { sessionStoreKey, err := hash.GenerateRandomKey(apiKeyLength) if err != nil { return fmt.Errorf("error generating session store key: %w", err) } - i.SetString(SessionStoreKey, sessionStoreKey) + s.SetString(SessionStoreKey, sessionStoreKey) } - i.setDefaultValues() + s.setDefaultValues() return nil } -func (i *Config) FinalizeSetup() { - i.isNewSystem = false +func (s *Config) FinalizeSetup() { + s.isNewSystem = false // i.configUpdates <- 0 } diff --git a/internal/manager/config/config_concurrency_test.go b/internal/manager/config/config_concurrency_test.go index fd9b067c7..55774be61 100644 --- a/internal/manager/config/config_concurrency_test.go +++ b/internal/manager/config/config_concurrency_test.go @@ -22,8 +22,6 @@ func TestConcurrentConfigAccess(t *testing.T) { t.Errorf("Failure setting initial configuration in worker %v iteration %v: %v", wk, l, err) } - i.HasCredentials() - i.ValidateCredentials("", "") i.GetConfigFile() i.GetConfigPath() i.GetDefaultDatabaseFilePath() @@ -75,7 +73,6 @@ func TestConcurrentConfigAccess(t *testing.T) { i.SetInterface(ApiKey, i.GetAPIKey()) i.SetInterface(Username, i.GetUsername()) i.SetInterface(Password, i.GetPasswordHash()) - i.GetCredentials() i.SetInterface(MaxSessionAge, i.GetMaxSessionAge()) i.SetInterface(CustomServedFolders, i.GetCustomServedFolders()) i.SetInterface(LegacyCustomUILocation, i.GetUILocation()) diff --git a/internal/manager/config/init.go b/internal/manager/config/init.go index 840b50b70..2dadda32e 100644 --- a/internal/manager/config/init.go +++ b/internal/manager/config/init.go @@ -62,6 +62,9 @@ func Initialize() (*Config, error) { main: koanf.New("."), overrides: koanf.New("."), } + cfg.UserStore = &UserStore{ + Config: cfg, + } cfg.initOverrides() @@ -96,6 +99,10 @@ func Initialize() (*Config, error) { } } + if err := cfg.UserStore.loadUsers(); err != nil { + return nil, fmt.Errorf("failed to load users: %v", err) + } + instance = cfg return instance, nil } @@ -110,8 +117,8 @@ func InitializeEmpty() *Config { return instance } -func (i *Config) loadFromCommandLine() { - v := i.overrides +func (s *Config) loadFromCommandLine() { + v := s.overrides if err := v.Load(posflag.ProviderWithFlag(pflag.CommandLine, ".", v, func(f *pflag.Flag) (string, interface{}) { // ignore flags that have not been changed @@ -125,8 +132,8 @@ func (i *Config) loadFromCommandLine() { } } -func (i *Config) loadFromEnv() { - v := i.overrides +func (s *Config) loadFromEnv() { + v := s.overrides if err := v.Load(env.ProviderWithValue("STASH_", ".", func(key, value string) (string, interface{}) { key = strings.ToLower(strings.TrimPrefix(key, "STASH_")) @@ -140,12 +147,12 @@ func (i *Config) loadFromEnv() { } } -func (i *Config) initOverrides() { - i.loadFromCommandLine() - i.loadFromEnv() +func (s *Config) initOverrides() { + s.loadFromCommandLine() + s.loadFromEnv() } -func (i *Config) initConfig() error { +func (s *Config) initConfig() error { configFile := "" envConfigFile := os.Getenv("STASH_CONFIG_FILE") @@ -158,8 +165,8 @@ func (i *Config) initConfig() error { if configFile != "" { // if file does not exist, assume it is a new system if exists, _ := fsutil.FileExists(configFile); !exists { - i.isNewSystem = true - i.SetConfigFile(configFile) + s.isNewSystem = true + s.SetConfigFile(configFile) // ensure we can write to the file if err := fsutil.Touch(configFile); err != nil { @@ -172,15 +179,15 @@ func (i *Config) initConfig() error { return nil } else { // load from provided config file - if err := i.loadFirstFromFiles([]string{configFile}); err != nil { + if err := s.loadFirstFromFiles([]string{configFile}); err != nil { return err } } } else { // load from default locations - if err := i.loadFirstFromFiles(defaultConfigLocations); err != nil { + if err := s.loadFirstFromFiles(defaultConfigLocations); err != nil { if errors.Is(err, errConfigNotFound) { - i.isNewSystem = true + s.isNewSystem = true return nil } @@ -191,10 +198,10 @@ func (i *Config) initConfig() error { return nil } -func (i *Config) loadFirstFromFiles(f []string) error { +func (s *Config) loadFirstFromFiles(f []string) error { for _, ff := range f { if exists, _ := fsutil.FileExists(ff); exists { - return i.load(ff) + return s.load(ff) } } diff --git a/internal/manager/config/users.go b/internal/manager/config/users.go new file mode 100644 index 000000000..c3d081a71 --- /dev/null +++ b/internal/manager/config/users.go @@ -0,0 +1,252 @@ +package config + +import ( + "context" + "fmt" + + "github.com/stashapp/stash/pkg/models" + "golang.org/x/crypto/bcrypt" +) + +const ( + Username = "username" + Password = "password" + Users = "users" + Roles = "roles" +) + +type StoredUser struct { + Username string `json:"username" koanf:"username"` + PasswordHash string `json:"passwordhash" koanf:"passwordhash"` + Roles []models.RoleEnum `json:"roles" koanf:"roles"` + ApiKey string `json:"api_key" koanf:"api_key"` +} + +type UserStore struct { + *Config + + cachedUsers map[string]StoredUser +} + +func (s *Config) GetUsername() string { + return s.getString(Username) +} + +func (s *Config) GetPasswordHash() string { + return s.getString(Password) +} + +func (s *UserStore) legacyUser() *StoredUser { + un := s.getString(Username) + pwHash := s.getString(Password) + apiKey := s.getString(ApiKey) + + if un != "" && pwHash != "" { + return &StoredUser{ + Username: un, + PasswordHash: pwHash, + Roles: []models.RoleEnum{models.RoleEnumAdmin}, + ApiKey: apiKey, + } + } + + return nil +} + +func (s *UserStore) loadUsers() error { + // done outside lock to avoid deadlock + legacyUser := s.legacyUser() + + s.RLock() + defer s.RUnlock() + + var ret []*StoredUser + err := s.unmarshalKey(Users, &ret) + if err != nil { + return err + } + + // add legacy username + if legacyUser != nil { + ret = append(ret, legacyUser) + } + + s.cachedUsers = make(map[string]StoredUser) + for _, u := range ret { + s.cachedUsers[u.Username] = *u + } + + return nil +} + +func (s *UserStore) convertUser(su StoredUser) *models.User { + return &models.User{ + Username: su.Username, + Roles: su.Roles, + ApiKey: su.ApiKey, + } +} + +func (s *UserStore) getUser(username string) *StoredUser { + u, ok := s.cachedUsers[username] + if !ok { + return nil + } + + return &u +} + +func (s *UserStore) GetUser(ctx context.Context, username string) (*models.User, error) { + s.RLock() + defer s.RUnlock() + + u := s.getUser(username) + if u == nil { + return nil, nil + } + + return s.convertUser(*u), nil +} + +func (s *UserStore) AllUsers(ctx context.Context) ([]*models.User, error) { + var users []*models.User + + s.RLock() + defer s.RUnlock() + + for _, su := range s.cachedUsers { + users = append(users, s.convertUser(su)) + } + + return users, nil +} + +func (s *UserStore) LoginRequired(ctx context.Context) bool { + return len(s.cachedUsers) > 0 +} + +func hashPassword(password string) string { + hash, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.MinCost) + + return string(hash) +} + +func (s *UserStore) ValidateCredentials(ctx context.Context, username string, password string) bool { + u := s.getUser(username) + if u == nil { + return false + } + + err := bcrypt.CompareHashAndPassword([]byte(u.PasswordHash), []byte(password)) + + return err == nil +} + +func (s *UserStore) saveUsers() error { + // convert to list + users := make([]StoredUser, 0, len(s.cachedUsers)) + for _, u := range s.cachedUsers { + users = append(users, u) + } + + s.setInterfaceNoLock(Users, users) + return s.writeNoLock() +} + +func (s *UserStore) ChangeUserPassword(ctx context.Context, username string, newPassword string) error { + s.Lock() + defer s.Unlock() + + u := s.getUser(username) + if u == nil { + return fmt.Errorf("user not found") + } + + newHash := hashPassword(newPassword) + + updatedUser := *u + updatedUser.PasswordHash = newHash + s.cachedUsers[username] = updatedUser + + return s.saveUsers() +} + +func (s *UserStore) ChangeUserAPIKey(ctx context.Context, username string, newAPIKey string) error { + s.Lock() + defer s.Unlock() + + u := s.getUser(username) + if u == nil { + return fmt.Errorf("user not found") + } + + updatedUser := *u + updatedUser.ApiKey = newAPIKey + s.cachedUsers[username] = updatedUser + + return s.saveUsers() +} + +func (s *UserStore) CreateUser(ctx context.Context, u models.User, password string) error { + s.Lock() + defer s.Unlock() + + existingUser := s.getUser(u.Username) + if existingUser != nil { + return fmt.Errorf("user already exists") + } + + newUser := StoredUser{ + Username: u.Username, + PasswordHash: hashPassword(password), + Roles: u.Roles, + ApiKey: u.ApiKey, + } + + s.cachedUsers[u.Username] = newUser + + return s.saveUsers() +} + +// ReplaceUser replaces an existing user with updated information. +// ApiKey is ignored and not changed by this method. +func (s *UserStore) ReplaceUser(ctx context.Context, username string, updated models.User) error { + s.Lock() + defer s.Unlock() + + existingUser := s.getUser(username) + if existingUser == nil { + return fmt.Errorf("user not found") + } + + updatedUser := StoredUser{ + Username: updated.Username, + PasswordHash: existingUser.PasswordHash, + Roles: updated.Roles, + // don't allow changing apikey with this method + ApiKey: existingUser.ApiKey, + } + + // if username changed, remove old entry + if username != updated.Username { + delete(s.cachedUsers, username) + } + + s.cachedUsers[updated.Username] = updatedUser + + return s.saveUsers() +} + +func (s *UserStore) DeleteUser(ctx context.Context, username string) error { + s.Lock() + defer s.Unlock() + + existingUser := s.getUser(username) + if existingUser == nil { + return fmt.Errorf("user not found") + } + + delete(s.cachedUsers, username) + + return s.saveUsers() +} diff --git a/internal/manager/init.go b/internal/manager/init.go index b4af5eab7..e1cfb3bbf 100644 --- a/internal/manager/init.go +++ b/internal/manager/init.go @@ -27,6 +27,7 @@ import ( "github.com/stashapp/stash/pkg/scraper" "github.com/stashapp/stash/pkg/session" "github.com/stashapp/stash/pkg/sqlite" + "github.com/stashapp/stash/pkg/user" "github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/ui" ) @@ -109,6 +110,10 @@ func Initialize(cfg *config.Config, l *log.Logger) (*Manager, error) { scanSubs: &subscriptionManager{}, } + mgr.UserService = &user.Service{ + Store: cfg.UserStore, + } + if !cfg.IsNewSystem() { logger.Infof("using config file: %s", cfg.GetConfigFile()) @@ -130,7 +135,7 @@ func Initialize(cfg *config.Config, l *log.Logger) (*Manager, error) { // create temporary session store - this will be re-initialised // after config is complete - mgr.SessionStore = session.NewStore(cfg) + mgr.SessionStore = session.NewStore(cfg, instance.UserService) logger.Warnf("config file %snot found. Assuming new system...", cfgFile) } @@ -189,7 +194,7 @@ func initJobManager(cfg *config.Config) *job.Manager { func (s *Manager) postInit(ctx context.Context) error { s.RefreshConfig() - s.SessionStore = session.NewStore(s.Config) + s.SessionStore = session.NewStore(s.Config, s.UserService) s.PluginCache.RegisterSessionStore(s.SessionStore) s.RefreshPluginCache() @@ -251,7 +256,7 @@ func (s *Manager) postInit(ctx context.Context) error { } func (s *Manager) checkSecurityTripwire() { - if err := session.CheckExternalAccessTripwire(s.Config); err != nil { + if err := session.CheckExternalAccessTripwire(s.Config.UserStore, s.Config); err != nil { session.LogExternalAccessError(*err) } } diff --git a/internal/manager/manager.go b/internal/manager/manager.go index f4f3fa636..8b261d889 100644 --- a/internal/manager/manager.go +++ b/internal/manager/manager.go @@ -67,6 +67,7 @@ type Manager struct { ImageService ImageService GalleryService GalleryService GroupService GroupService + UserService UserService scanSubs *subscriptionManager } diff --git a/internal/manager/repository.go b/internal/manager/repository.go index e51e737ee..866a79787 100644 --- a/internal/manager/repository.go +++ b/internal/manager/repository.go @@ -7,6 +7,7 @@ import ( "github.com/stashapp/stash/pkg/image" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/scene" + "github.com/stashapp/stash/pkg/session" ) type SceneService interface { @@ -46,3 +47,20 @@ type GroupService interface { RemoveSubGroups(ctx context.Context, groupID int, subGroupIDs []int) error ReorderSubGroups(ctx context.Context, groupID int, subGroupIDs []int, insertPointID int, insertAfter bool) error } + +type UserService interface { + session.Authenticator + AllUsers(ctx context.Context) ([]*models.User, error) + GetUser(ctx context.Context, username string) (*models.User, error) + LoginRequired(ctx context.Context) bool + AuthenticateByAPIKey(ctx context.Context, apiKey string) (*models.User, error) + AuthenticateUserByID(ctx context.Context, username string) (*models.User, error) + + CreateUser(ctx context.Context, u models.User, password string) error + UpdateUser(ctx context.Context, username string, updated models.User) error + ChangePassword(ctx context.Context, username, existingPassword, newPassword string) error + ChangeUserPassword(ctx context.Context, username string, newPassword string) error + GenerateAPIKey(ctx context.Context, username string) (string, error) + ClearAPIKey(ctx context.Context, username string) error + DeleteUser(ctx context.Context, username string) error +} diff --git a/pkg/models/model_user.go b/pkg/models/model_user.go new file mode 100644 index 000000000..fb683656c --- /dev/null +++ b/pkg/models/model_user.go @@ -0,0 +1,13 @@ +package models + +type User struct { + Username string + Roles Roles + ApiKey string +} + +type UserInput struct { + Username string + Roles Roles + Password string +} diff --git a/pkg/models/role.go b/pkg/models/role.go new file mode 100644 index 000000000..52a162a63 --- /dev/null +++ b/pkg/models/role.go @@ -0,0 +1,70 @@ +package models + +import ( + "fmt" + "io" + "strconv" +) + +type RoleEnum string + +const ( + RoleEnumAdmin RoleEnum = "ADMIN" + RoleEnumRead RoleEnum = "READ" + RoleEnumModify RoleEnum = "MODIFY" +) + +func (e RoleEnum) Implies(other RoleEnum) bool { + // admin has all roles + if e == RoleEnumAdmin { + return true + } + + // until we add a NONE value, all values imply read + if e.IsValid() && other == RoleEnumRead { + return true + } + + // all others only imply themselves + return e == other +} + +func (e RoleEnum) IsValid() bool { + switch e { + case RoleEnumRead, RoleEnumModify, RoleEnumAdmin: + return true + } + return false +} + +func (e RoleEnum) String() string { + return string(e) +} + +func (e *RoleEnum) UnmarshalGQL(v interface{}) error { + str, ok := v.(string) + if !ok { + return fmt.Errorf("enums must be strings") + } + + *e = RoleEnum(str) + if !e.IsValid() { + return fmt.Errorf("%s is not a valid RoleEnum", str) + } + return nil +} + +func (e RoleEnum) MarshalGQL(w io.Writer) { + fmt.Fprint(w, strconv.Quote(e.String())) +} + +type Roles []RoleEnum + +func (r Roles) HasRole(role RoleEnum) bool { + for _, r := range r { + if r.Implies(role) { + return true + } + } + return false +} diff --git a/pkg/session/authentication.go b/pkg/session/authentication.go index 95c41baa5..eb6878957 100644 --- a/pkg/session/authentication.go +++ b/pkg/session/authentication.go @@ -1,6 +1,7 @@ package session import ( + "context" "fmt" "net" "net/http" @@ -15,8 +16,8 @@ func (e ExternalAccessError) Error() string { return fmt.Sprintf("stash accessed from external IP %s", net.IP(e).String()) } -func CheckAllowPublicWithoutAuth(c ExternalAccessConfig, r *http.Request) error { - if !c.HasCredentials() && !c.GetDangerousAllowPublicWithoutAuth() && !c.IsNewSystem() { +func CheckAllowPublicWithoutAuth(s CredentialStore, c ExternalAccessConfig, r *http.Request) error { + if hc := s.LoginRequired(context.Background()); !hc && !c.GetDangerousAllowPublicWithoutAuth() && !c.IsNewSystem() { requestIPString, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { return fmt.Errorf("error parsing remote host (%s): %w", r.RemoteAddr, err) @@ -59,8 +60,8 @@ func CheckAllowPublicWithoutAuth(c ExternalAccessConfig, r *http.Request) error return nil } -func CheckExternalAccessTripwire(c ExternalAccessConfig) *ExternalAccessError { - if !c.HasCredentials() && !c.GetDangerousAllowPublicWithoutAuth() { +func CheckExternalAccessTripwire(s CredentialStore, c ExternalAccessConfig) *ExternalAccessError { + if hc := s.LoginRequired(context.Background()); !hc && !c.GetDangerousAllowPublicWithoutAuth() { if remoteIP := c.GetSecurityTripwireAccessedFromPublicInternet(); remoteIP != "" { err := ExternalAccessError(net.ParseIP(remoteIP)) return &err diff --git a/pkg/session/authentication_test.go b/pkg/session/authentication_test.go index ac6383f24..005e7680f 100644 --- a/pkg/session/authentication_test.go +++ b/pkg/session/authentication_test.go @@ -1,6 +1,7 @@ package session import ( + "context" "errors" "net/http" "testing" @@ -13,7 +14,7 @@ type config struct { securityTripwireAccessedFromPublicInternet string } -func (c *config) HasCredentials() bool { +func (c *config) LoginRequired(ctx context.Context) bool { return c.username != "" && c.password != "" } @@ -34,7 +35,7 @@ func TestCheckAllowPublicWithoutAuth(t *testing.T) { doTest := func(caseIndex int, r *http.Request, expectedErr interface{}) { t.Helper() - err := CheckAllowPublicWithoutAuth(c, r) + err := CheckAllowPublicWithoutAuth(c, c, r) if expectedErr == nil && err == nil { return @@ -120,7 +121,7 @@ func TestCheckAllowPublicWithoutAuth(t *testing.T) { RemoteAddr: remoteAddr, } - err := CheckAllowPublicWithoutAuth(c, r) + err := CheckAllowPublicWithoutAuth(c, c, r) if err == nil { t.Errorf("[%s]: expected error", remoteAddr) continue @@ -137,7 +138,7 @@ func TestCheckAllowPublicWithoutAuth(t *testing.T) { c.username = "admin" c.password = "admin" - if err := CheckAllowPublicWithoutAuth(c, r); err != nil { + if err := CheckAllowPublicWithoutAuth(c, c, r); err != nil { t.Errorf("unexpected error: %v", err) } @@ -146,7 +147,7 @@ func TestCheckAllowPublicWithoutAuth(t *testing.T) { c.dangerousAllowPublicWithoutAuth = true - if err := CheckAllowPublicWithoutAuth(c, r); err != nil { + if err := CheckAllowPublicWithoutAuth(c, c, r); err != nil { t.Errorf("unexpected error: %v", err) } } @@ -160,7 +161,7 @@ func TestCheckExternalAccessTripwire(t *testing.T) { c.username = "admin" c.password = "admin" - if err := CheckExternalAccessTripwire(c); err != nil { + if err := CheckExternalAccessTripwire(c, c); err != nil { t.Errorf("unexpected error %v", err) } @@ -170,19 +171,19 @@ func TestCheckExternalAccessTripwire(t *testing.T) { // HACK - this key isn't publically exposed c.dangerousAllowPublicWithoutAuth = true - if err := CheckExternalAccessTripwire(c); err != nil { + if err := CheckExternalAccessTripwire(c, c); err != nil { t.Errorf("unexpected error %v", err) } c.dangerousAllowPublicWithoutAuth = false - if err := CheckExternalAccessTripwire(c); err == nil { + if err := CheckExternalAccessTripwire(c, c); err == nil { t.Errorf("expected error %v", ExternalAccessError("4.4.4.4")) } c.securityTripwireAccessedFromPublicInternet = "" - if err := CheckExternalAccessTripwire(c); err != nil { + if err := CheckExternalAccessTripwire(c, c); err != nil { t.Errorf("unexpected error %v", err) } } diff --git a/pkg/session/config.go b/pkg/session/config.go index 0bd584c51..98a92cd97 100644 --- a/pkg/session/config.go +++ b/pkg/session/config.go @@ -1,17 +1,21 @@ package session +import "context" + type ExternalAccessConfig interface { - HasCredentials() bool GetDangerousAllowPublicWithoutAuth() bool GetSecurityTripwireAccessedFromPublicInternet() string IsNewSystem() bool } +type CredentialStore interface { + LoginRequired(ctx context.Context) bool +} + type SessionConfig interface { GetUsername() string GetAPIKey() string GetSessionStoreKey() []byte GetMaxSessionAge() int - ValidateCredentials(username string, password string) bool } diff --git a/pkg/session/plugin.go b/pkg/session/plugin.go index 7a57ca4b5..22d988072 100644 --- a/pkg/session/plugin.go +++ b/pkg/session/plugin.go @@ -61,7 +61,7 @@ func setVisitedPluginHooks(ctx context.Context, visitedPlugins []VisitedPluginHo } func (s *Store) MakePluginCookie(ctx context.Context) *http.Cookie { - currentUser := GetCurrentUserID(ctx) + currentUser := GetCurrentUser(ctx) visitedPlugins := GetVisitedPluginHooks(ctx) session := sessions.NewSession(s.sessionStore, cookieName) diff --git a/pkg/session/session.go b/pkg/session/session.go index 66cb39e09..dae7ea44f 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -8,6 +8,7 @@ import ( "github.com/gorilla/sessions" "github.com/stashapp/stash/pkg/logger" + "github.com/stashapp/stash/pkg/models" ) type key int @@ -44,15 +45,22 @@ func (e InvalidCredentialsError) Error() string { var ErrUnauthorized = errors.New("unauthorized") -type Store struct { - sessionStore *sessions.CookieStore - config SessionConfig +type Authenticator interface { + LoginRequired(ctx context.Context) bool + ValidateCredentials(ctx context.Context, username string, password string) error } -func NewStore(c SessionConfig) *Store { +type Store struct { + sessionStore *sessions.CookieStore + authenticator Authenticator + config SessionConfig +} + +func NewStore(c SessionConfig, a Authenticator) *Store { ret := &Store{ - sessionStore: sessions.NewCookieStore(c.GetSessionStoreKey()), - config: c, + sessionStore: sessions.NewCookieStore(c.GetSessionStoreKey()), + config: c, + authenticator: a, } ret.sessionStore.MaxAge(c.GetMaxSessionAge()) @@ -61,6 +69,10 @@ func NewStore(c SessionConfig) *Store { return ret } +func (s *Store) LoginRequired(ctx context.Context) bool { + return s.authenticator.LoginRequired(ctx) +} + func (s *Store) Login(w http.ResponseWriter, r *http.Request) error { // ignore error - we want a new session regardless newSession, _ := s.sessionStore.Get(r, cookieName) @@ -69,16 +81,16 @@ func (s *Store) Login(w http.ResponseWriter, r *http.Request) error { password := r.FormValue(passwordFormKey) // authenticate the user - if !s.config.ValidateCredentials(username, password) { + err := s.authenticator.ValidateCredentials(r.Context(), username, password) + if err != nil { return &InvalidCredentialsError{Username: username} } - // since we only have one user, don't leak the name - logger.Info("User logged in") + logger.Infof("User %s logged in", username) newSession.Values[userIDKey] = username - err := newSession.Save(r, w) + err = newSession.Save(r, w) if err != nil { return err } @@ -92,6 +104,8 @@ func (s *Store) Logout(w http.ResponseWriter, r *http.Request) error { return err } + userID, _ := session.Values[userIDKey].(string) + delete(session.Values, userIDKey) session.Options.MaxAge = -1 @@ -100,8 +114,7 @@ func (s *Store) Logout(w http.ResponseWriter, r *http.Request) error { return err } - // since we only have one user, don't leak the name - logger.Infof("User logged out") + logger.Infof("User %s logged out", userID) return nil } @@ -131,25 +144,22 @@ func (s *Store) GetSessionUserID(w http.ResponseWriter, r *http.Request) (string return "", nil } -func SetCurrentUserID(ctx context.Context, userID string) context.Context { - return context.WithValue(ctx, contextUser, userID) +func SetCurrentUser(ctx context.Context, u models.User) context.Context { + return context.WithValue(ctx, contextUser, u) } -// GetCurrentUserID gets the current user id from the provided context -func GetCurrentUserID(ctx context.Context) *string { +// GetCurrentUser gets the current user id from the provided context +func GetCurrentUser(ctx context.Context) *models.User { userCtxVal := ctx.Value(contextUser) if userCtxVal != nil { - currentUser := userCtxVal.(string) + currentUser := userCtxVal.(models.User) return ¤tUser } return nil } -func (s *Store) Authenticate(w http.ResponseWriter, r *http.Request) (userID string, err error) { - c := s.config - - // translate api key into current user, if present +func GetRequestApiKey(r *http.Request) string { apiKey := r.Header.Get(ApiKeyHeader) // try getting the api key as a query parameter @@ -157,23 +167,5 @@ func (s *Store) Authenticate(w http.ResponseWriter, r *http.Request) (userID str apiKey = r.URL.Query().Get(ApiKeyParameter) } - if apiKey != "" { - // match against configured API and set userID to the - // configured username. In future, we'll want to - // get the username from the key. - if c.GetAPIKey() != apiKey { - return "", ErrUnauthorized - } - - userID = c.GetUsername() - } else { - // handle session - userID, err = s.GetSessionUserID(w, r) - } - - if err != nil { - return "", err - } - - return + return apiKey } diff --git a/internal/manager/apikey.go b/pkg/user/apikey.go similarity index 93% rename from internal/manager/apikey.go rename to pkg/user/apikey.go index 7bd3126fa..26aaae1db 100644 --- a/internal/manager/apikey.go +++ b/pkg/user/apikey.go @@ -1,4 +1,4 @@ -package manager +package user import ( "errors" @@ -17,7 +17,7 @@ type APIKeyClaims struct { jwt.RegisteredClaims } -func GenerateAPIKey(userID string) (string, error) { +func generateAPIKey(userID string) (string, error) { claims := &APIKeyClaims{ UserID: userID, RegisteredClaims: jwt.RegisteredClaims{ diff --git a/pkg/user/authenticate.go b/pkg/user/authenticate.go new file mode 100644 index 000000000..a00006b65 --- /dev/null +++ b/pkg/user/authenticate.go @@ -0,0 +1 @@ +package user diff --git a/pkg/user/service.go b/pkg/user/service.go new file mode 100644 index 000000000..fbfa12fd0 --- /dev/null +++ b/pkg/user/service.go @@ -0,0 +1,397 @@ +package user + +import ( + "context" + "errors" + "fmt" + "slices" + "strings" + + "github.com/stashapp/stash/pkg/logger" + "github.com/stashapp/stash/pkg/models" +) + +var ( + ErrUserNotExist = errors.New("user not found") + ErrEmptyUsername = errors.New("empty username") + ErrUsernameHasWhitespace = errors.New("username has leading or trailing whitespace") + ErrDeleteLastAdminUser = errors.New("final admin user cannot be deleted") + ErrRemoveLastAdminRole = errors.New("final admin role cannot be removed") + ErrInternalError = errors.New("internal error") + ErrAccessDenied = errors.New("access denied") + ErrCurrentPasswordIncorrect = errors.New("current password incorrect") + ErrUserAlreadyExists = errors.New("user with that username already exists") +) + +type UserSource interface { + AllUsers(ctx context.Context) ([]*models.User, error) + GetUser(ctx context.Context, username string) (*models.User, error) + ValidateCredentials(ctx context.Context, username string, password string) bool + + CreateUser(ctx context.Context, u models.User, password string) error + ReplaceUser(ctx context.Context, username string, updated models.User) error + ChangeUserPassword(ctx context.Context, username string, newPassword string) error + ChangeUserAPIKey(ctx context.Context, username string, newAPIKey string) error + DeleteUser(ctx context.Context, username string) error +} + +type Service struct { + Store UserSource +} + +func (s *Service) LoginRequired(ctx context.Context) bool { + u, _ := s.Store.AllUsers(ctx) + return len(u) > 0 +} + +func (s *Service) GetUser(ctx context.Context, username string) (*models.User, error) { + return s.Store.GetUser(ctx, username) +} + +func (s *Service) AllUsers(ctx context.Context) ([]*models.User, error) { + return s.Store.AllUsers(ctx) +} + +func userIsLocked(u *models.User) bool { + return len(u.Roles) == 0 +} + +func (s *Service) ValidateCredentials(ctx context.Context, username string, password string) error { + // ensure user is not locked + u, err := s.GetUser(ctx, username) + if err != nil { + logger.Errorf("error getting user for credential validation: %v", err) + return ErrInternalError + } + + if u == nil { + logger.Infof("[login attempt] user %s not found during credential validation", username) + return ErrAccessDenied + } + + if userIsLocked(u) { + logger.Infof("[login attempt] user %s is locked", username) + return ErrAccessDenied + } + + if !s.Store.ValidateCredentials(ctx, username, password) { + logger.Infof("[login attempt] invalid credentials for user %s", username) + return ErrAccessDenied + } + return nil +} + +// AuthenticateUserByID authenticates a user by their username and returns the user object if successful. +// This is used for session-based authentication. +// It will return an error if the user does not exist or if the user is locked. +func (s *Service) AuthenticateUserByID(ctx context.Context, username string) (*models.User, error) { + u, err := s.GetUser(ctx, username) + if err != nil { + logger.Errorf("error getting user for authentication: %v", err) + return nil, ErrInternalError + } + + if u == nil { + logger.Infof("[authentication] user %s not found", username) + return nil, ErrAccessDenied + } + + if userIsLocked(u) { + logger.Infof("[authentication] user %s is locked", username) + return nil, ErrAccessDenied + } + + return u, nil +} + +func (s *Service) AuthenticateByAPIKey(ctx context.Context, apiKey string) (*models.User, error) { + username, err := GetUserIDFromAPIKey(apiKey) + if err != nil { + logger.Errorf("error getting user ID from api key: %v", err) + return nil, ErrInternalError + } + + user, err := s.GetUser(ctx, username) + if err != nil { + logger.Errorf("error getting user by username: %v", err) + return nil, ErrInternalError + } + + if user == nil { + logger.Infof("[apikey authentication] user %s not found", username) + return nil, ErrAccessDenied + } + + if userIsLocked(user) { + logger.Infof("[apikey authentication] user %s is locked", username) + return nil, ErrAccessDenied + } + + // ensure apikey matches + if user.ApiKey != apiKey { + logger.Infof("[apikey authentication] invalid api key for user %s", username) + return nil, ErrAccessDenied + } + + return user, nil +} + +func (s *Service) validateUsername(username string) error { + if username == "" { + return ErrEmptyUsername + } + + // username must not have leading or trailing whitespace + trimmed := strings.TrimSpace(username) + + if trimmed != username { + return ErrUsernameHasWhitespace + } + + return nil +} + +func (s *Service) validatePassword(password string) error { + if password == "" { + return errors.New("password cannot be empty") + } + + // add more password validation as needed + + return nil +} + +func (s *Service) CreateUser(ctx context.Context, u models.User, password string) error { + // validate input + // ensure username is valid + if err := s.validateUsername(u.Username); err != nil { + return err + } + + // check if user exists + existingUser, err := s.GetUser(ctx, u.Username) + if err != nil { + return fmt.Errorf("error checking existing users: %w", err) + } + + if existingUser != nil { + return ErrUserAlreadyExists + } + + // validate password + if err := s.validatePassword(password); err != nil { + return err + } + + // if this is the first user, make them an admin + users, err := s.AllUsers(ctx) + if err != nil { + return fmt.Errorf("error getting existing users: %w", err) + } + + if len(users) == 0 && !u.Roles.HasRole(models.RoleEnumAdmin) { + return errors.New("the first user must be an admin") + } + + // create user in store + if err := s.Store.CreateUser(ctx, u, password); err != nil { + return fmt.Errorf("error creating user: %w", err) + } + + logger.Infof("[user] created %q", u.Username) + + return nil +} + +func (s *Service) UpdateUser(ctx context.Context, username string, updated models.User) error { + // validate input + // check if user exists + existingUser, err := s.GetUser(ctx, username) + if err != nil { + return fmt.Errorf("error getting existing user: %w", err) + } + + if existingUser == nil { + return ErrUserNotExist + } + + existingRoles := existingUser.Roles + + // ensure username is valid + if username != updated.Username { + if err := s.validateUsername(updated.Username); err != nil { + return err + } + + // ensure new username doesn't already exist + otherUser, err := s.GetUser(ctx, updated.Username) + if err != nil { + return fmt.Errorf("error checking existing user: %w", err) + } + + if otherUser != nil { + return ErrUserAlreadyExists + } + } + + // validate roles + // don't allow removing admin from last admin user + if existingRoles.HasRole(models.RoleEnumAdmin) && !updated.Roles.HasRole(models.RoleEnumAdmin) { + users, err := s.AllUsers(ctx) + if err != nil { + return fmt.Errorf("error getting all users: %w", err) + } + + hasAdmin := false + for _, u := range users { + if u.Username != existingUser.Username && u.Roles.HasRole(models.RoleEnumAdmin) { + hasAdmin = true + break + } + } + + if !hasAdmin { + return ErrRemoveLastAdminRole + } + } + + // update user in store + if err := s.Store.ReplaceUser(ctx, username, updated); err != nil { + return fmt.Errorf("error updating user: %w", err) + } + + if username != updated.Username { + logger.Infof("[user] updated name %q -> %q", username, updated.Username) + } + + if !slices.Equal(existingRoles, updated.Roles) { + logger.Infof("[user] updated roles for user %q", updated.Username) + } + + return nil +} + +func (s *Service) ChangePassword(ctx context.Context, username, currentPassword, newPassword string) error { + // validate current credentials + if err := s.ValidateCredentials(ctx, username, currentPassword); err != nil { + logger.Infof("[user] failed password change attempt for %q: incorrect current password", username) + return ErrCurrentPasswordIncorrect + } + + return s.ChangeUserPassword(ctx, username, newPassword) +} + +func (s *Service) ChangeUserPassword(ctx context.Context, username, newPassword string) error { + // check if user exists + existingUser, err := s.GetUser(ctx, username) + if err != nil { + return fmt.Errorf("error getting existing user: %w", err) + } + + if existingUser == nil { + return ErrUserNotExist + } + + // validate new password + if err := s.validatePassword(newPassword); err != nil { + return err + } + + // change password in store + if err := s.Store.ChangeUserPassword(ctx, username, newPassword); err != nil { + return fmt.Errorf("error changing user password: %w", err) + } + + logger.Infof("[user] changed password for %q", username) + + return nil +} + +func (s *Service) GenerateAPIKey(ctx context.Context, username string) (string, error) { + // check if user exists + existingUser, err := s.GetUser(ctx, username) + if err != nil { + return "", fmt.Errorf("error getting existing user: %w", err) + } + + if existingUser == nil { + return "", ErrUserNotExist + } + + // generate new api key + newAPIKey, err := generateAPIKey(username) + if err != nil { + return "", fmt.Errorf("error generating api key: %w", err) + } + + if err := s.Store.ChangeUserAPIKey(ctx, username, newAPIKey); err != nil { + return "", fmt.Errorf("error updating user with new api key: %w", err) + } + + logger.Infof("[user] generated new API key for %q", username) + + return newAPIKey, nil +} + +func (s *Service) ClearAPIKey(ctx context.Context, username string) error { + // check if user exists + existingUser, err := s.GetUser(ctx, username) + if err != nil { + return fmt.Errorf("error getting existing user: %w", err) + } + + if existingUser == nil { + return ErrUserNotExist + } + + // clear api key + if err := s.Store.ChangeUserAPIKey(ctx, username, ""); err != nil { + return fmt.Errorf("error clearing user api key: %w", err) + } + + logger.Infof("[user] cleared API key for %q", username) + + return nil +} + +func (s *Service) DeleteUser(ctx context.Context, username string) error { + // check if user exists + existingUser, err := s.GetUser(ctx, username) + if err != nil { + return fmt.Errorf("error getting existing user: %w", err) + } + + if existingUser == nil { + return ErrUserNotExist + } + + // don't allow deleting last admin user unless it is the last user + if existingUser.Roles.HasRole(models.RoleEnumAdmin) { + users, err := s.AllUsers(ctx) + if err != nil { + return fmt.Errorf("error getting all users: %w", err) + } + + hasAdmin := false + for _, u := range users { + if u.Username != username && u.Roles.HasRole(models.RoleEnumAdmin) { + hasAdmin = true + break + } + } + + // allow deleting last admin if it is the only user + if !hasAdmin && len(users) > 1 { + return ErrDeleteLastAdminUser + } + } + + // delete user from store + if err := s.Store.DeleteUser(ctx, username); err != nil { + return fmt.Errorf("error deleting user: %w", err) + } + + logger.Infof("[user] deleted %q", username) + + return nil +}