Implement config file based user management

This commit is contained in:
WithoutPants 2026-01-17 23:21:06 +11:00
parent 2a2351fcdc
commit aa7107a242
27 changed files with 751 additions and 213 deletions

View file

@ -126,7 +126,7 @@ models:
ScraperSource:
model: github.com/stashapp/stash/pkg/scraper.Source
RoleEnum:
model: github.com/stashapp/stash/pkg/user.RoleEnum
model: github.com/stashapp/stash/pkg/models.RoleEnum
IdentifySourceInput:
model: github.com/stashapp/stash/internal/identify.Source
IdentifyFieldOptionsInput:

View file

@ -269,6 +269,7 @@ type Query {
allStudios: [Studio!]! @deprecated(reason: "Use findStudios instead")
allMovies: [Movie!]! @deprecated(reason: "Use findGroups instead")
users: [User!]! @hasRole(role: ADMIN)
"""Returns currently authenticated user"""
me: User
@ -593,6 +594,7 @@ type Mutation {
userCreate(input: UserCreateInput!): User @hasRole(role: ADMIN)
userUpdate(input: UserUpdateInput!): User @hasRole(role: ADMIN)
userDestroy(input: UserDestroyInput!): Boolean! @hasRole(role: ADMIN)
changeUserPassword(input: ChangeUserPasswordInput!): Boolean! @hasRole(role: ADMIN)
changePassword(input: UserChangePasswordInput!): Boolean!
}

View file

@ -9,7 +9,10 @@ directive @isUserOwner on FIELD_DEFINITION
type User {
name: String!
"""Should not be visible to other users"""
"""
If the user has no roles, they are considered locked and cannot log in.
Should not be visible to other users
"""
roles: [RoleEnum!] @isUserOwner
"""Should not be visible to other users"""
api_key: String @isUserOwner
@ -23,19 +26,22 @@ input UserCreateInput {
}
input UserUpdateInput {
name: String
"""Password in plain text"""
password: String
roles: [RoleEnum!]
existingName: String!
name: String!
roles: [RoleEnum!]!
}
input UserDestroyInput {
id: ID!
name: String!
}
input UserChangePasswordInput {
"""Password in plain text"""
existing_password: String
new_password: String!
reset_key: String
existingPassword: String!
newPassword: String!
}
input ChangeUserPasswordInput {
name: String!
newPassword: String!
}

View file

@ -12,8 +12,8 @@ import (
"github.com/stashapp/stash/internal/manager"
"github.com/stashapp/stash/internal/manager/config"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/session"
"github.com/stashapp/stash/pkg/user"
)
const (
@ -32,16 +32,17 @@ func allowUnauthenticated(r *http.Request) bool {
}
type UserGetter interface {
GetUser(ctx context.Context, username string) (*user.User, error)
GetUser(ctx context.Context, username string) (*models.User, error)
}
func authenticateHandler(g UserGetter) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c := config.GetInstance()
s := c.UserStore
// error if external access tripwire activated
if accessErr := session.CheckExternalAccessTripwire(c); accessErr != nil {
if accessErr := session.CheckExternalAccessTripwire(s, c); accessErr != nil {
http.Error(w, tripwireActivatedErrMsg, http.StatusForbidden)
return
}
@ -59,7 +60,9 @@ func authenticateHandler(g UserGetter) func(http.Handler) http.Handler {
return
}
if err := session.CheckAllowPublicWithoutAuth(c, r); err != nil {
ctx := r.Context()
if err := session.CheckAllowPublicWithoutAuth(s, c, r); err != nil {
var accessErr session.ExternalAccessError
if errors.As(err, &accessErr) {
session.LogExternalAccessError(accessErr)
@ -77,11 +80,23 @@ func authenticateHandler(g UserGetter) func(http.Handler) http.Handler {
return
}
ctx := r.Context()
var u *models.User
if userID != "" {
u, err = g.GetUser(ctx, userID)
if err != nil {
// if we can't get the user object, we just return a forbidden error
logger.Errorf("Error getting user object: %v", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
if u == nil {
logger.Errorf("[User] cookie user %q not found", userID)
}
}
if hc, _ := c.HasCredentials(ctx); hc {
if hc := s.LoginRequired(ctx); hc {
// authentication is required
if userID == "" && !allowUnauthenticated(r) {
if u == nil && !allowUnauthenticated(r) {
// if graphql or a non-webpage was requested, we just return a forbidden error
ext := path.Ext(r.URL.Path)
if r.URL.Path == gqlEndpoint || (ext != "" && ext != ".html") {
@ -108,16 +123,8 @@ func authenticateHandler(g UserGetter) func(http.Handler) http.Handler {
}
}
if userID != "" {
if u != nil {
// set the user object in the context
u, err := g.GetUser(ctx, userID)
if err != nil {
// if we can't get the user object, we just return a forbidden error
logger.Errorf("Error getting user object: %v", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
ctx = session.SetCurrentUser(ctx, *u)
}

View file

@ -4,11 +4,11 @@ import (
"context"
"github.com/99designs/gqlgen/graphql"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/session"
"github.com/stashapp/stash/pkg/user"
)
func HasRoleDirective(ctx context.Context, obj interface{}, next graphql.Resolver, role user.RoleEnum) (interface{}, error) {
func HasRoleDirective(ctx context.Context, obj interface{}, next graphql.Resolver, role models.RoleEnum) (interface{}, error) {
currentUser := session.GetCurrentUser(ctx)
// if there is no current user, this is an anonymous request
@ -17,7 +17,29 @@ func HasRoleDirective(ctx context.Context, obj interface{}, next graphql.Resolve
return next(ctx)
}
if currentUser != nil && !user.IsRole(currentUser.Roles, role) {
if currentUser != nil && !currentUser.Roles.HasRole(role) {
return nil, session.ErrUnauthorized
}
return next(ctx)
}
func IsUserOwnerDirective(ctx context.Context, obj any, next graphql.Resolver) (res any, err error) {
currentUser := session.GetCurrentUser(ctx)
// if there is no current user, this is an anonymous request
// we should not end up here unless there are no credentials required
if currentUser == nil {
return next(ctx)
}
// get the user from the object
userObj, ok := obj.(*models.User)
if !ok {
return nil, session.ErrUnauthorized
}
if currentUser.Username != userObj.Username {
return nil, session.ErrUnauthorized
}

View file

@ -37,6 +37,7 @@ type Resolver struct {
imageService manager.ImageService
galleryService manager.GalleryService
groupService manager.GroupService
userService manager.UserService
hookExecutor hookExecutor
}
@ -110,6 +111,9 @@ func (r *Resolver) Plugin() PluginResolver {
func (r *Resolver) ConfigResult() ConfigResultResolver {
return &configResultResolver{r}
}
func (r *Resolver) User() UserResolver {
return &userResolver{r}
}
type mutationResolver struct{ *Resolver }
type queryResolver struct{ *Resolver }
@ -136,6 +140,7 @@ type folderResolver struct{ *Resolver }
type savedFilterResolver struct{ *Resolver }
type pluginResolver struct{ *Resolver }
type configResultResolver struct{ *Resolver }
type userResolver struct{ *Resolver }
func (r *Resolver) withTxn(ctx context.Context, fn func(ctx context.Context) error) error {
return r.repository.WithTxn(ctx, fn)

View file

@ -0,0 +1,23 @@
package api
import (
"context"
"github.com/stashapp/stash/pkg/models"
)
func (r *userResolver) Name(ctx context.Context, obj *models.User) (string, error) {
return obj.Username, nil
}
func (r *userResolver) Roles(ctx context.Context, obj *models.User) ([]models.RoleEnum, error) {
ret := make([]models.RoleEnum, len(obj.Roles))
for i, role := range obj.Roles {
ret[i] = models.RoleEnum(role)
}
return ret, nil
}
func (r *userResolver) APIKey(ctx context.Context, obj *models.User) (*string, error) {
return nil, nil
}

View file

@ -320,7 +320,7 @@ func (r *mutationResolver) ConfigureGeneral(ctx context.Context, input ConfigGen
if input.Password != nil {
// bit of a hack - check if the passed in password is the same as the stored hash
// and only set if they are different
currentPWHash, _ := c.GetPasswordHash(ctx, c.GetUsername())
currentPWHash := c.GetPasswordHash()
if *input.Password != currentPWHash {
if *input.Password == "" {

View file

@ -0,0 +1,62 @@
package api
import (
"context"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/session"
)
func (r *mutationResolver) UserCreate(ctx context.Context, input UserCreateInput) (*models.User, error) {
err := r.userService.CreateUser(ctx, models.User{
Username: input.Name,
Roles: models.Roles(input.Roles),
}, input.Password)
if err != nil {
return nil, err
}
return r.userService.GetUser(ctx, input.Name)
}
func (r *mutationResolver) UserUpdate(ctx context.Context, input UserUpdateInput) (*models.User, error) {
err := r.userService.UpdateUser(ctx, input.ExistingName, models.User{
Username: input.Name,
Roles: models.Roles(input.Roles),
})
if err != nil {
return nil, err
}
return r.userService.GetUser(ctx, input.Name)
}
func (r *mutationResolver) UserDestroy(ctx context.Context, input UserDestroyInput) (bool, error) {
err := r.userService.DeleteUser(ctx, input.Name)
if err != nil {
return false, err
}
return true, nil
}
func (r *mutationResolver) ChangePassword(ctx context.Context, input UserChangePasswordInput) (bool, error) {
// get current user
u := session.GetCurrentUser(ctx)
err := r.userService.ChangePassword(ctx, u.Username, input.ExistingPassword, input.NewPassword)
if err != nil {
return false, err
}
return true, nil
}
func (r *mutationResolver) ChangeUserPassword(ctx context.Context, input ChangeUserPasswordInput) (bool, error) {
err := r.userService.ChangeUserPassword(ctx, input.Name, input.NewPassword)
if err != nil {
return false, err
}
return true, nil
}

View file

@ -78,9 +78,6 @@ func makeConfigGeneralResult() *ConfigGeneralResult {
customPerformerImageLocation := config.GetCustomPerformerImageLocation()
username := config.GetUsername()
pwHash, _ := config.GetPasswordHash(context.Background(), username)
return &ConfigGeneralResult{
Stashes: config.GetStashPaths(),
DatabasePath: config.GetDatabasePath(),
@ -113,7 +110,7 @@ func makeConfigGeneralResult() *ConfigGeneralResult {
GalleryCoverRegex: config.GetGalleryCoverRegex(),
APIKey: config.GetAPIKey(),
Username: config.GetUsername(),
Password: pwHash,
Password: config.GetPasswordHash(),
MaxSessionAge: config.GetMaxSessionAge(),
LogFile: &logFile,
LogOut: config.GetLogOut(),

View file

@ -0,0 +1,17 @@
package api
import (
"context"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/session"
)
func (r *queryResolver) Users(ctx context.Context) ([]*models.User, error) {
return r.userService.AllUsers(ctx)
}
func (r *queryResolver) Me(ctx context.Context) (*models.User, error) {
// get current user
return session.GetCurrentUser(ctx), nil
}

View file

@ -164,19 +164,22 @@ func Initialize() (*Server, error) {
imageService := mgr.ImageService
galleryService := mgr.GalleryService
groupService := mgr.GroupService
userService := mgr.UserService
resolver := &Resolver{
repository: repo,
sceneService: sceneService,
imageService: imageService,
galleryService: galleryService,
groupService: groupService,
userService: userService,
hookExecutor: pluginCache,
}
gqlCfg := Config{
Resolvers: resolver,
Directives: DirectiveRoot{
HasRole: HasRoleDirective,
HasRole: HasRoleDirective,
IsUserOwner: IsUserOwnerDirective,
},
}
@ -236,9 +239,11 @@ func Initialize() (*Server, error) {
staticLoginUI := statigz.FileServer(ui.LoginUIBox.(fs.ReadDirFS))
r.Get(loginEndpoint, handleLogin())
r.Post(loginEndpoint, handleLoginPost())
r.Get(logoutEndpoint, handleLogout())
sessionStore := mgr.SessionStore
r.Get(loginEndpoint, handleLogin(userService))
r.Post(loginEndpoint, handleLoginPost(sessionStore))
r.Get(logoutEndpoint, handleLogout(sessionStore))
r.Get(loginLocaleEndpoint, handleLoginLocale(cfg))
r.HandleFunc(loginEndpoint+"/*", func(w http.ResponseWriter, r *http.Request) {
r.URL.Path = strings.TrimPrefix(r.URL.Path, loginEndpoint)

View file

@ -103,11 +103,11 @@ func getLoginLocale(lang string) ([]byte, error) {
return data, nil
}
func handleLogin() http.HandlerFunc {
func handleLogin(s manager.UserService) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
returnURL := r.URL.Query().Get(returnURLParam)
if hc, _ := config.GetInstance().HasCredentials(r.Context()); !hc {
if hc := s.LoginRequired(r.Context()); !hc {
if returnURL != "" {
http.Redirect(w, r, returnURL, http.StatusFound)
} else {
@ -121,9 +121,9 @@ func handleLogin() http.HandlerFunc {
}
}
func handleLoginPost() http.HandlerFunc {
func handleLoginPost(s *session.Store) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
err := manager.GetInstance().SessionStore.Login(w, r)
err := s.Login(w, r)
if err != nil {
// always log the error
logger.Errorf("Error logging in: %v from IP: %s", err, r.RemoteAddr)
@ -146,7 +146,7 @@ func handleLoginPost() http.HandlerFunc {
}
}
func handleLogout() http.HandlerFunc {
func handleLogout(s *session.Store) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if err := manager.GetInstance().SessionStore.Logout(w, r); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
@ -155,7 +155,7 @@ func handleLogout() http.HandlerFunc {
// redirect to the login page if credentials are required
prefix := getProxyPrefix(r)
if hc, _ := config.GetInstance().HasCredentials(r.Context()); hc {
if hc := s.LoginRequired(r.Context()); hc {
http.Redirect(w, r, prefix+loginEndpoint, http.StatusFound)
} else {
http.Redirect(w, r, prefix+"/", http.StatusFound)

View file

@ -1,7 +1,6 @@
package config
import (
"context"
"fmt"
"net/url"
"os"
@ -15,8 +14,6 @@ import (
"sync"
// "github.com/sasha-s/go-deadlock" // if you have deadlock issues
"golang.org/x/crypto/bcrypt"
"github.com/knadh/koanf/parsers/yaml"
"github.com/knadh/koanf/providers/file"
"github.com/knadh/koanf/v2"
@ -28,7 +25,6 @@ import (
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/paths"
"github.com/stashapp/stash/pkg/sliceutil"
"github.com/stashapp/stash/pkg/user"
"github.com/stashapp/stash/pkg/utils"
)
@ -41,9 +37,8 @@ const (
BlobsPath = "blobs_path"
Downloads = "downloads"
ApiKey = "api_key"
Username = "username"
Password = "password"
MaxSessionAge = "max_session_age"
MaxSessionAge = "max_session_age"
// SFWContentMode mode config key
SFWContentMode = "sfw_content_mode"
@ -334,6 +329,9 @@ type Config struct {
// configUpdates chan int
certFile string
keyFile string
UserStore *UserStore
sync.RWMutex
// deadlock.RWMutex // for deadlock testing/issues
}
@ -431,6 +429,9 @@ func (i *Config) SetInterface(key string, value interface{}) {
i.Lock()
defer i.Unlock()
i.setInterfaceNoLock(key, value)
}
func (i *Config) setInterfaceNoLock(key string, value interface{}) {
i.set(key, value)
}
@ -480,6 +481,10 @@ func (i *Config) Write() error {
i.Lock()
defer i.Unlock()
return i.writeNoLock()
}
func (i *Config) writeNoLock() error {
data, err := i.marshal()
if err != nil {
return err
@ -1078,78 +1083,6 @@ func (i *Config) GetAPIKey() string {
return i.getString(ApiKey)
}
func (i *Config) GetUsername() string {
return i.getString(Username)
}
func (i *Config) GetUser(ctx context.Context, username string) (*user.User, error) {
// TODO - temp
if username == "read" {
return &user.User{
Username: username,
Roles: []user.RoleEnum{user.RoleEnumRead},
}, nil
}
u := i.GetUsername()
if u != username {
return nil, user.ErrUserNotFound
}
return &user.User{
Username: u,
Roles: []user.RoleEnum{user.RoleEnumAdmin},
}, nil
}
func (i *Config) GetPasswordHash(ctx context.Context, username string) (string, error) {
u := i.GetUsername()
if u != username {
return "", user.ErrUserNotFound
}
return i.getString(Password), nil
}
func (i *Config) GetCredentials() (string, string) {
if hc, _ := i.HasCredentials(context.Background()); hc {
return i.getString(Username), i.getString(Password)
}
return "", ""
}
func (i *Config) HasCredentials(ctx context.Context) (bool, error) {
username := i.getString(Username)
pwHash := i.getString(Password)
return username != "" && pwHash != "", nil
}
func hashPassword(password string) string {
hash, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.MinCost)
return string(hash)
}
func (i *Config) ValidateCredentials(username string, password string) bool {
if hc, _ := i.HasCredentials(context.Background()); !hc {
// don't need to authenticate if no credentials saved
return true
}
// TODO - temp
if username == "read" {
return password == "read"
}
authUser, authPWHash := i.GetCredentials()
err := bcrypt.CompareHashAndPassword([]byte(authPWHash), []byte(password))
return username == authUser && err == nil
}
func stashBoxValidate(str string) bool {
u, err := url.Parse(str)
return err == nil && u.Scheme != "" && u.Host != "" && strings.HasSuffix(u.Path, "/graphql")

View file

@ -22,8 +22,6 @@ func TestConcurrentConfigAccess(t *testing.T) {
t.Errorf("Failure setting initial configuration in worker %v iteration %v: %v", wk, l, err)
}
i.HasCredentials()
i.ValidateCredentials("", "")
i.GetConfigFile()
i.GetConfigPath()
i.GetDefaultDatabaseFilePath()
@ -75,7 +73,6 @@ func TestConcurrentConfigAccess(t *testing.T) {
i.SetInterface(ApiKey, i.GetAPIKey())
i.SetInterface(Username, i.GetUsername())
i.SetInterface(Password, i.GetPasswordHash())
i.GetCredentials()
i.SetInterface(MaxSessionAge, i.GetMaxSessionAge())
i.SetInterface(CustomServedFolders, i.GetCustomServedFolders())
i.SetInterface(LegacyCustomUILocation, i.GetUILocation())

View file

@ -62,6 +62,9 @@ func Initialize() (*Config, error) {
main: koanf.New("."),
overrides: koanf.New("."),
}
cfg.UserStore = &UserStore{
Config: cfg,
}
cfg.initOverrides()
@ -96,6 +99,10 @@ func Initialize() (*Config, error) {
}
}
if err := cfg.UserStore.loadUsers(); err != nil {
return nil, fmt.Errorf("failed to load users: %v", err)
}
instance = cfg
return instance, nil
}

View file

@ -0,0 +1,227 @@
package config
import (
"context"
"fmt"
"github.com/stashapp/stash/pkg/models"
"golang.org/x/crypto/bcrypt"
)
const (
Username = "username"
Password = "password"
Users = "users"
Roles = "roles"
)
type StoredUser struct {
Username string `json:"username" koanf:"username"`
PasswordHash string `json:"passwordhash" koanf:"passwordhash"`
Roles []models.RoleEnum `json:"roles" koanf:"roles"`
}
type UserStore struct {
*Config
cachedUsers map[string]StoredUser
}
func (s *Config) GetUsername() string {
return s.getString(Username)
}
func (i *Config) GetPasswordHash() string {
return i.getString(Password)
}
func (s *UserStore) legacyUser() *StoredUser {
un := s.getString(Username)
pwHash := s.getString(Password)
if un != "" && pwHash != "" {
return &StoredUser{
Username: un,
PasswordHash: pwHash,
Roles: []models.RoleEnum{models.RoleEnumAdmin},
}
}
return nil
}
func (s *UserStore) loadUsers() error {
// done outside lock to avoid deadlock
legacyUser := s.legacyUser()
s.RLock()
defer s.RUnlock()
var ret []*StoredUser
err := s.unmarshalKey(Users, &ret)
if err != nil {
return err
}
// add legacy username
if legacyUser != nil {
ret = append(ret, legacyUser)
}
s.cachedUsers = make(map[string]StoredUser)
for _, u := range ret {
s.cachedUsers[u.Username] = *u
}
return nil
}
func (s *UserStore) convertUser(su StoredUser) *models.User {
return &models.User{
Username: su.Username,
Roles: su.Roles,
}
}
func (s *UserStore) getUser(username string) *StoredUser {
u, ok := s.cachedUsers[username]
if !ok {
return nil
}
return &u
}
func (s *UserStore) GetUser(ctx context.Context, username string) (*models.User, error) {
s.RLock()
defer s.RUnlock()
u := s.getUser(username)
if u == nil {
return nil, nil
}
return s.convertUser(*u), nil
}
func (s *UserStore) AllUsers(ctx context.Context) ([]*models.User, error) {
var users []*models.User
s.RLock()
defer s.RUnlock()
for _, su := range s.cachedUsers {
users = append(users, s.convertUser(su))
}
return users, nil
}
func (s *UserStore) LoginRequired(ctx context.Context) bool {
return len(s.cachedUsers) > 0
}
func hashPassword(password string) string {
hash, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.MinCost)
return string(hash)
}
func (s *UserStore) ValidateCredentials(ctx context.Context, username string, password string) bool {
u := s.getUser(username)
if u == nil {
return false
}
err := bcrypt.CompareHashAndPassword([]byte(u.PasswordHash), []byte(password))
return err == nil
}
func (s *UserStore) saveUsers() error {
// convert to list
users := make([]StoredUser, 0, len(s.cachedUsers))
for _, u := range s.cachedUsers {
users = append(users, u)
}
s.setInterfaceNoLock(Users, users)
return s.writeNoLock()
}
func (s *UserStore) ChangeUserPassword(ctx context.Context, username string, newPassword string) error {
s.Lock()
defer s.Unlock()
u := s.getUser(username)
if u == nil {
return fmt.Errorf("user not found")
}
newHash := hashPassword(newPassword)
updatedUser := *u
updatedUser.PasswordHash = newHash
s.cachedUsers[username] = updatedUser
return s.saveUsers()
}
func (s *UserStore) CreateUser(ctx context.Context, u models.User, password string) error {
s.Lock()
defer s.Unlock()
existingUser := s.getUser(u.Username)
if existingUser != nil {
return fmt.Errorf("user already exists")
}
newUser := StoredUser{
Username: u.Username,
PasswordHash: hashPassword(password),
Roles: u.Roles,
}
s.cachedUsers[u.Username] = newUser
return s.saveUsers()
}
func (s *UserStore) ReplaceUser(ctx context.Context, username string, updated models.User) error {
s.Lock()
defer s.Unlock()
existingUser := s.getUser(username)
if existingUser == nil {
return fmt.Errorf("user not found")
}
updatedUser := StoredUser{
Username: updated.Username,
PasswordHash: existingUser.PasswordHash,
Roles: updated.Roles,
}
// if username changed, remove old entry
if username != updated.Username {
delete(s.cachedUsers, username)
}
s.cachedUsers[updated.Username] = updatedUser
return s.saveUsers()
}
func (s *UserStore) DeleteUser(ctx context.Context, username string) error {
s.Lock()
defer s.Unlock()
existingUser := s.getUser(username)
if existingUser == nil {
return fmt.Errorf("user not found")
}
delete(s.cachedUsers, username)
return s.saveUsers()
}

View file

@ -110,8 +110,8 @@ func Initialize(cfg *config.Config, l *log.Logger) (*Manager, error) {
scanSubs: &subscriptionManager{},
}
instance.UserService = &user.Service{
Store: cfg,
mgr.UserService = &user.Service{
Store: cfg.UserStore,
}
if !cfg.IsNewSystem() {
@ -135,7 +135,7 @@ func Initialize(cfg *config.Config, l *log.Logger) (*Manager, error) {
// create temporary session store - this will be re-initialised
// after config is complete
mgr.SessionStore = session.NewStore(cfg, cfg)
mgr.SessionStore = session.NewStore(cfg, instance.UserService)
logger.Warnf("config file %snot found. Assuming new system...", cfgFile)
}
@ -194,7 +194,7 @@ func initJobManager(cfg *config.Config) *job.Manager {
func (s *Manager) postInit(ctx context.Context) error {
s.RefreshConfig()
s.SessionStore = session.NewStore(s.Config, s.Config)
s.SessionStore = session.NewStore(s.Config, s.UserService)
s.PluginCache.RegisterSessionStore(s.SessionStore)
s.RefreshPluginCache()
@ -256,7 +256,7 @@ func (s *Manager) postInit(ctx context.Context) error {
}
func (s *Manager) checkSecurityTripwire() {
if err := session.CheckExternalAccessTripwire(s.Config); err != nil {
if err := session.CheckExternalAccessTripwire(s.Config.UserStore, s.Config); err != nil {
session.LogExternalAccessError(*err)
}
}

View file

@ -7,7 +7,7 @@ import (
"github.com/stashapp/stash/pkg/image"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scene"
"github.com/stashapp/stash/pkg/user"
"github.com/stashapp/stash/pkg/session"
)
type SceneService interface {
@ -49,5 +49,14 @@ type GroupService interface {
}
type UserService interface {
GetUser(ctx context.Context, username string) (*user.User, error)
session.Authenticator
AllUsers(ctx context.Context) ([]*models.User, error)
GetUser(ctx context.Context, username string) (*models.User, error)
LoginRequired(ctx context.Context) bool
CreateUser(ctx context.Context, u models.User, password string) error
UpdateUser(ctx context.Context, username string, updated models.User) error
ChangePassword(ctx context.Context, username, existingPassword, newPassword string) error
ChangeUserPassword(ctx context.Context, username string, newPassword string) error
DeleteUser(ctx context.Context, username string) error
}

12
pkg/models/model_user.go Normal file
View file

@ -0,0 +1,12 @@
package models
type User struct {
Username string
Roles Roles
}
type UserInput struct {
Username string
Roles Roles
Password string
}

View file

@ -1,4 +1,4 @@
package user
package models
import (
"fmt"
@ -58,15 +58,13 @@ func (e RoleEnum) MarshalGQL(w io.Writer) {
fmt.Fprint(w, strconv.Quote(e.String()))
}
func IsRole(assignedRoles []RoleEnum, requiredRole RoleEnum) bool {
valid := false
type Roles []RoleEnum
for _, role := range assignedRoles {
if role.Implies(requiredRole) {
valid = true
break
func (r Roles) HasRole(role RoleEnum) bool {
for _, r := range r {
if r.Implies(role) {
return true
}
}
return valid
return false
}

View file

@ -16,8 +16,8 @@ func (e ExternalAccessError) Error() string {
return fmt.Sprintf("stash accessed from external IP %s", net.IP(e).String())
}
func CheckAllowPublicWithoutAuth(c ExternalAccessConfig, r *http.Request) error {
if hc, _ := c.HasCredentials(context.Background()); !hc && !c.GetDangerousAllowPublicWithoutAuth() && !c.IsNewSystem() {
func CheckAllowPublicWithoutAuth(s CredentialStore, c ExternalAccessConfig, r *http.Request) error {
if hc := s.LoginRequired(context.Background()); !hc && !c.GetDangerousAllowPublicWithoutAuth() && !c.IsNewSystem() {
requestIPString, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return fmt.Errorf("error parsing remote host (%s): %w", r.RemoteAddr, err)
@ -60,8 +60,8 @@ func CheckAllowPublicWithoutAuth(c ExternalAccessConfig, r *http.Request) error
return nil
}
func CheckExternalAccessTripwire(c ExternalAccessConfig) *ExternalAccessError {
if hc, _ := c.HasCredentials(context.Background()); !hc && !c.GetDangerousAllowPublicWithoutAuth() {
func CheckExternalAccessTripwire(s CredentialStore, c ExternalAccessConfig) *ExternalAccessError {
if hc := s.LoginRequired(context.Background()); !hc && !c.GetDangerousAllowPublicWithoutAuth() {
if remoteIP := c.GetSecurityTripwireAccessedFromPublicInternet(); remoteIP != "" {
err := ExternalAccessError(net.ParseIP(remoteIP))
return &err

View file

@ -3,12 +3,15 @@ package session
import "context"
type ExternalAccessConfig interface {
HasCredentials(ctx context.Context) (bool, error)
GetDangerousAllowPublicWithoutAuth() bool
GetSecurityTripwireAccessedFromPublicInternet() string
IsNewSystem() bool
}
type CredentialStore interface {
LoginRequired(ctx context.Context) bool
}
type SessionConfig interface {
GetUsername() string
GetAPIKey() string

View file

@ -8,7 +8,7 @@ import (
"github.com/gorilla/sessions"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/user"
"github.com/stashapp/stash/pkg/models"
)
type key int
@ -46,7 +46,8 @@ func (e InvalidCredentialsError) Error() string {
var ErrUnauthorized = errors.New("unauthorized")
type Authenticator interface {
ValidateCredentials(username string, password string) bool
LoginRequired(ctx context.Context) bool
ValidateCredentials(ctx context.Context, username string, password string) error
}
type Store struct {
@ -68,6 +69,10 @@ func NewStore(c SessionConfig, a Authenticator) *Store {
return ret
}
func (s *Store) LoginRequired(ctx context.Context) bool {
return s.authenticator.LoginRequired(ctx)
}
func (s *Store) Login(w http.ResponseWriter, r *http.Request) error {
// ignore error - we want a new session regardless
newSession, _ := s.sessionStore.Get(r, cookieName)
@ -76,16 +81,16 @@ func (s *Store) Login(w http.ResponseWriter, r *http.Request) error {
password := r.FormValue(passwordFormKey)
// authenticate the user
if !s.authenticator.ValidateCredentials(username, password) {
err := s.authenticator.ValidateCredentials(r.Context(), username, password)
if err != nil {
return &InvalidCredentialsError{Username: username}
}
// since we only have one user, don't leak the name
logger.Info("User logged in")
logger.Infof("User %s logged in", username)
newSession.Values[userIDKey] = username
err := newSession.Save(r, w)
err = newSession.Save(r, w)
if err != nil {
return err
}
@ -99,6 +104,8 @@ func (s *Store) Logout(w http.ResponseWriter, r *http.Request) error {
return err
}
userID, _ := session.Values[userIDKey].(string)
delete(session.Values, userIDKey)
session.Options.MaxAge = -1
@ -107,8 +114,7 @@ func (s *Store) Logout(w http.ResponseWriter, r *http.Request) error {
return err
}
// since we only have one user, don't leak the name
logger.Infof("User logged out")
logger.Infof("User %s logged out", userID)
return nil
}
@ -138,15 +144,15 @@ func (s *Store) GetSessionUserID(w http.ResponseWriter, r *http.Request) (string
return "", nil
}
func SetCurrentUser(ctx context.Context, u user.User) context.Context {
func SetCurrentUser(ctx context.Context, u models.User) context.Context {
return context.WithValue(ctx, contextUser, u)
}
// GetCurrentUser gets the current user id from the provided context
func GetCurrentUser(ctx context.Context) *user.User {
func GetCurrentUser(ctx context.Context) *models.User {
userCtxVal := ctx.Value(contextUser)
if userCtxVal != nil {
currentUser := userCtxVal.(user.User)
currentUser := userCtxVal.(models.User)
return &currentUser
}
@ -164,6 +170,7 @@ func (s *Store) Authenticate(w http.ResponseWriter, r *http.Request) (userID str
apiKey = r.URL.Query().Get(ApiKeyParameter)
}
// FIXME - handle this
if apiKey != "" {
// match against configured API and set userID to the
// configured username. In future, we'll want to

View file

@ -1,44 +1 @@
package user
import (
"context"
"errors"
"fmt"
"golang.org/x/crypto/bcrypt"
)
var (
ErrPasswordMismatch = fmt.Errorf("password mismatch")
ErrUserNotFound = fmt.Errorf("user not found")
)
func (s *Service) ValidateCredentials(ctx context.Context, username string, password string) error {
hc, err := s.Store.HasCredentials(ctx)
if err != nil {
return fmt.Errorf("error checking if credentials exist: %w", err)
}
if !hc {
// don't need to authenticate if no credentials saved
return nil
}
authPWHash, err := s.Store.GetPasswordHash(ctx, username)
if err != nil {
if errors.Is(err, ErrUserNotFound) {
return err
}
return fmt.Errorf("error getting password hash: %w", err)
}
if err := bcrypt.CompareHashAndPassword([]byte(authPWHash), []byte(password)); err != nil {
if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) {
return ErrPasswordMismatch
}
return fmt.Errorf("error comparing password hash: %w", err)
}
return nil
}

View file

@ -2,18 +2,266 @@ package user
import (
"context"
"errors"
"fmt"
"slices"
"strings"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
)
type Store interface {
GetUser(ctx context.Context, username string) (*User, error)
HasCredentials(ctx context.Context) (bool, error)
GetPasswordHash(ctx context.Context, username string) (string, error)
var (
ErrUserNotExist = errors.New("user not found")
ErrEmptyUsername = errors.New("empty username")
ErrUsernameHasWhitespace = errors.New("username has leading or trailing whitespace")
ErrDeleteLastAdminUser = errors.New("final admin user cannot be deleted")
ErrInternalError = errors.New("internal error")
ErrAccessDenied = errors.New("access denied")
ErrCurrentPasswordIncorrect = errors.New("current password incorrect")
ErrUserAlreadyExists = errors.New("user with that username already exists")
)
type UserSource interface {
AllUsers(ctx context.Context) ([]*models.User, error)
GetUser(ctx context.Context, username string) (*models.User, error)
ValidateCredentials(ctx context.Context, username string, password string) bool
CreateUser(ctx context.Context, u models.User, password string) error
ReplaceUser(ctx context.Context, username string, updated models.User) error
ChangeUserPassword(ctx context.Context, username string, newPassword string) error
DeleteUser(ctx context.Context, username string) error
}
type Service struct {
Store Store
Store UserSource
}
func (s *Service) GetUser(ctx context.Context, username string) (*User, error) {
func (s *Service) LoginRequired(ctx context.Context) bool {
u, _ := s.Store.AllUsers(ctx)
return len(u) > 0
}
func (s *Service) GetUser(ctx context.Context, username string) (*models.User, error) {
return s.Store.GetUser(ctx, username)
}
func (s *Service) AllUsers(ctx context.Context) ([]*models.User, error) {
return s.Store.AllUsers(ctx)
}
func (s *Service) ValidateCredentials(ctx context.Context, username string, password string) error {
// ensure user is not locked
u, err := s.GetUser(ctx, username)
if err != nil {
logger.Errorf("error getting user for credential validation: %v", err)
return ErrInternalError
}
if u == nil {
logger.Infof("[login attempt] user %s not found during credential validation", username)
return ErrAccessDenied
}
if len(u.Roles) == 0 {
logger.Infof("[login attempt] user %s is locked", username)
return ErrAccessDenied
}
if !s.Store.ValidateCredentials(ctx, username, password) {
logger.Infof("[login attempt] invalid credentials for user %s", username)
return ErrAccessDenied
}
return nil
}
func (s *Service) validateUsername(username string) error {
if username == "" {
return ErrEmptyUsername
}
// username must not have leading or trailing whitespace
trimmed := strings.TrimSpace(username)
if trimmed != username {
return ErrUsernameHasWhitespace
}
return nil
}
func (s *Service) validatePassword(password string) error {
if password == "" {
return errors.New("password cannot be empty")
}
// add more password validation as needed
return nil
}
func (s *Service) CreateUser(ctx context.Context, u models.User, password string) error {
// validate input
// ensure username is valid
if err := s.validateUsername(u.Username); err != nil {
return err
}
// check if user exists
existingUser, err := s.GetUser(ctx, u.Username)
if err != nil {
return fmt.Errorf("error checking existing users: %w", err)
}
if existingUser != nil {
return ErrUserAlreadyExists
}
// validate password
if err := s.validatePassword(password); err != nil {
return err
}
// if this is the first user, make them an admin
users, err := s.AllUsers(ctx)
if err != nil {
return fmt.Errorf("error getting existing users: %w", err)
}
if len(users) == 0 && !u.Roles.HasRole(models.RoleEnumAdmin) {
return errors.New("the first user must be an admin")
}
// create user in store
if err := s.Store.CreateUser(ctx, u, password); err != nil {
return fmt.Errorf("error creating user: %w", err)
}
logger.Infof("[user] created %q", u.Username)
return nil
}
func (s *Service) UpdateUser(ctx context.Context, username string, updated models.User) error {
// validate input
// check if user exists
existingUser, err := s.GetUser(ctx, username)
if err != nil {
return fmt.Errorf("error getting existing user: %w", err)
}
if existingUser == nil {
return ErrUserNotExist
}
existingRoles := existingUser.Roles
// ensure username is valid
if username != updated.Username {
if err := s.validateUsername(updated.Username); err != nil {
return err
}
// ensure new username doesn't already exist
otherUser, err := s.GetUser(ctx, updated.Username)
if err != nil {
return fmt.Errorf("error checking existing user: %w", err)
}
if otherUser != nil {
return ErrUserAlreadyExists
}
}
// update user in store
if err := s.Store.ReplaceUser(ctx, username, updated); err != nil {
return fmt.Errorf("error updating user: %w", err)
}
if username != updated.Username {
logger.Infof("[user] updated name %q -> %q", username, updated.Username)
}
if !slices.Equal(existingRoles, updated.Roles) {
logger.Infof("[user] updated roles for user %q", updated.Username)
}
return nil
}
func (s *Service) ChangePassword(ctx context.Context, username, currentPassword, newPassword string) error {
// validate current credentials
if err := s.ValidateCredentials(ctx, username, currentPassword); err != nil {
logger.Infof("[user] failed password change attempt for %q: incorrect current password", username)
return ErrCurrentPasswordIncorrect
}
return s.ChangeUserPassword(ctx, username, newPassword)
}
func (s *Service) ChangeUserPassword(ctx context.Context, username, newPassword string) error {
// check if user exists
existingUser, err := s.GetUser(ctx, username)
if err != nil {
return fmt.Errorf("error getting existing user: %w", err)
}
if existingUser == nil {
return ErrUserNotExist
}
// validate new password
if err := s.validatePassword(newPassword); err != nil {
return err
}
// change password in store
if err := s.Store.ChangeUserPassword(ctx, username, newPassword); err != nil {
return fmt.Errorf("error changing user password: %w", err)
}
logger.Infof("[user] changed password for %q", username)
return nil
}
func (s *Service) DeleteUser(ctx context.Context, username string) error {
// check if user exists
existingUser, err := s.GetUser(ctx, username)
if err != nil {
return fmt.Errorf("error getting existing user: %w", err)
}
if existingUser == nil {
return ErrUserNotExist
}
// don't allow deleting last admin user
if existingUser.Roles.HasRole(models.RoleEnumAdmin) {
users, err := s.AllUsers(ctx)
if err != nil {
return fmt.Errorf("error getting all users: %w", err)
}
hasAdmin := false
for _, u := range users {
if u.Username != username && u.Roles.HasRole(models.RoleEnumAdmin) {
hasAdmin = true
break
}
}
if !hasAdmin {
return ErrDeleteLastAdminUser
}
}
// delete user from store
if err := s.Store.DeleteUser(ctx, username); err != nil {
return fmt.Errorf("error deleting user: %w", err)
}
logger.Infof("[user] deleted %q", username)
return nil
}

View file

@ -1,6 +0,0 @@
package user
type User struct {
Username string
Roles []RoleEnum
}