diff --git a/gqlgen.yml b/gqlgen.yml index b23b36ca9..981eb1d21 100644 --- a/gqlgen.yml +++ b/gqlgen.yml @@ -126,7 +126,7 @@ models: ScraperSource: model: github.com/stashapp/stash/pkg/scraper.Source RoleEnum: - model: github.com/stashapp/stash/pkg/user.RoleEnum + model: github.com/stashapp/stash/pkg/models.RoleEnum IdentifySourceInput: model: github.com/stashapp/stash/internal/identify.Source IdentifyFieldOptionsInput: diff --git a/graphql/schema/schema.graphql b/graphql/schema/schema.graphql index 12ac5a472..7a175a637 100644 --- a/graphql/schema/schema.graphql +++ b/graphql/schema/schema.graphql @@ -269,6 +269,7 @@ type Query { allStudios: [Studio!]! @deprecated(reason: "Use findStudios instead") allMovies: [Movie!]! @deprecated(reason: "Use findGroups instead") + users: [User!]! @hasRole(role: ADMIN) """Returns currently authenticated user""" me: User @@ -593,6 +594,7 @@ type Mutation { userCreate(input: UserCreateInput!): User @hasRole(role: ADMIN) userUpdate(input: UserUpdateInput!): User @hasRole(role: ADMIN) userDestroy(input: UserDestroyInput!): Boolean! @hasRole(role: ADMIN) + changeUserPassword(input: ChangeUserPasswordInput!): Boolean! @hasRole(role: ADMIN) changePassword(input: UserChangePasswordInput!): Boolean! } diff --git a/graphql/schema/types/user.graphql b/graphql/schema/types/user.graphql index 6226f061b..92802ad89 100644 --- a/graphql/schema/types/user.graphql +++ b/graphql/schema/types/user.graphql @@ -9,7 +9,10 @@ directive @isUserOwner on FIELD_DEFINITION type User { name: String! - """Should not be visible to other users""" + """ + If the user has no roles, they are considered locked and cannot log in. + Should not be visible to other users + """ roles: [RoleEnum!] @isUserOwner """Should not be visible to other users""" api_key: String @isUserOwner @@ -23,19 +26,22 @@ input UserCreateInput { } input UserUpdateInput { - name: String - """Password in plain text""" - password: String - roles: [RoleEnum!] + existingName: String! + name: String! + roles: [RoleEnum!]! } input UserDestroyInput { - id: ID! + name: String! } input UserChangePasswordInput { """Password in plain text""" - existing_password: String - new_password: String! - reset_key: String + existingPassword: String! + newPassword: String! +} + +input ChangeUserPasswordInput { + name: String! + newPassword: String! } \ No newline at end of file diff --git a/internal/api/authentication.go b/internal/api/authentication.go index 1099a5702..aac3b2550 100644 --- a/internal/api/authentication.go +++ b/internal/api/authentication.go @@ -12,8 +12,8 @@ import ( "github.com/stashapp/stash/internal/manager" "github.com/stashapp/stash/internal/manager/config" "github.com/stashapp/stash/pkg/logger" + "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/session" - "github.com/stashapp/stash/pkg/user" ) const ( @@ -32,16 +32,17 @@ func allowUnauthenticated(r *http.Request) bool { } type UserGetter interface { - GetUser(ctx context.Context, username string) (*user.User, error) + GetUser(ctx context.Context, username string) (*models.User, error) } func authenticateHandler(g UserGetter) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c := config.GetInstance() + s := c.UserStore // error if external access tripwire activated - if accessErr := session.CheckExternalAccessTripwire(c); accessErr != nil { + if accessErr := session.CheckExternalAccessTripwire(s, c); accessErr != nil { http.Error(w, tripwireActivatedErrMsg, http.StatusForbidden) return } @@ -59,7 +60,9 @@ func authenticateHandler(g UserGetter) func(http.Handler) http.Handler { return } - if err := session.CheckAllowPublicWithoutAuth(c, r); err != nil { + ctx := r.Context() + + if err := session.CheckAllowPublicWithoutAuth(s, c, r); err != nil { var accessErr session.ExternalAccessError if errors.As(err, &accessErr) { session.LogExternalAccessError(accessErr) @@ -77,11 +80,23 @@ func authenticateHandler(g UserGetter) func(http.Handler) http.Handler { return } - ctx := r.Context() + var u *models.User + if userID != "" { + u, err = g.GetUser(ctx, userID) + if err != nil { + // if we can't get the user object, we just return a forbidden error + logger.Errorf("Error getting user object: %v", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + if u == nil { + logger.Errorf("[User] cookie user %q not found", userID) + } + } - if hc, _ := c.HasCredentials(ctx); hc { + if hc := s.LoginRequired(ctx); hc { // authentication is required - if userID == "" && !allowUnauthenticated(r) { + if u == nil && !allowUnauthenticated(r) { // if graphql or a non-webpage was requested, we just return a forbidden error ext := path.Ext(r.URL.Path) if r.URL.Path == gqlEndpoint || (ext != "" && ext != ".html") { @@ -108,16 +123,8 @@ func authenticateHandler(g UserGetter) func(http.Handler) http.Handler { } } - if userID != "" { + if u != nil { // set the user object in the context - u, err := g.GetUser(ctx, userID) - if err != nil { - // if we can't get the user object, we just return a forbidden error - logger.Errorf("Error getting user object: %v", err) - w.WriteHeader(http.StatusInternalServerError) - return - } - ctx = session.SetCurrentUser(ctx, *u) } diff --git a/internal/api/directives.go b/internal/api/directives.go index 9e839bb06..a5d8341e8 100644 --- a/internal/api/directives.go +++ b/internal/api/directives.go @@ -4,11 +4,11 @@ import ( "context" "github.com/99designs/gqlgen/graphql" + "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/session" - "github.com/stashapp/stash/pkg/user" ) -func HasRoleDirective(ctx context.Context, obj interface{}, next graphql.Resolver, role user.RoleEnum) (interface{}, error) { +func HasRoleDirective(ctx context.Context, obj interface{}, next graphql.Resolver, role models.RoleEnum) (interface{}, error) { currentUser := session.GetCurrentUser(ctx) // if there is no current user, this is an anonymous request @@ -17,7 +17,29 @@ func HasRoleDirective(ctx context.Context, obj interface{}, next graphql.Resolve return next(ctx) } - if currentUser != nil && !user.IsRole(currentUser.Roles, role) { + if currentUser != nil && !currentUser.Roles.HasRole(role) { + return nil, session.ErrUnauthorized + } + + return next(ctx) +} + +func IsUserOwnerDirective(ctx context.Context, obj any, next graphql.Resolver) (res any, err error) { + currentUser := session.GetCurrentUser(ctx) + + // if there is no current user, this is an anonymous request + // we should not end up here unless there are no credentials required + if currentUser == nil { + return next(ctx) + } + + // get the user from the object + userObj, ok := obj.(*models.User) + if !ok { + return nil, session.ErrUnauthorized + } + + if currentUser.Username != userObj.Username { return nil, session.ErrUnauthorized } diff --git a/internal/api/resolver.go b/internal/api/resolver.go index 061d0e1a9..509a2f170 100644 --- a/internal/api/resolver.go +++ b/internal/api/resolver.go @@ -37,6 +37,7 @@ type Resolver struct { imageService manager.ImageService galleryService manager.GalleryService groupService manager.GroupService + userService manager.UserService hookExecutor hookExecutor } @@ -110,6 +111,9 @@ func (r *Resolver) Plugin() PluginResolver { func (r *Resolver) ConfigResult() ConfigResultResolver { return &configResultResolver{r} } +func (r *Resolver) User() UserResolver { + return &userResolver{r} +} type mutationResolver struct{ *Resolver } type queryResolver struct{ *Resolver } @@ -136,6 +140,7 @@ type folderResolver struct{ *Resolver } type savedFilterResolver struct{ *Resolver } type pluginResolver struct{ *Resolver } type configResultResolver struct{ *Resolver } +type userResolver struct{ *Resolver } func (r *Resolver) withTxn(ctx context.Context, fn func(ctx context.Context) error) error { return r.repository.WithTxn(ctx, fn) diff --git a/internal/api/resolver_model_user.go b/internal/api/resolver_model_user.go new file mode 100644 index 000000000..11919cf64 --- /dev/null +++ b/internal/api/resolver_model_user.go @@ -0,0 +1,23 @@ +package api + +import ( + "context" + + "github.com/stashapp/stash/pkg/models" +) + +func (r *userResolver) Name(ctx context.Context, obj *models.User) (string, error) { + return obj.Username, nil +} + +func (r *userResolver) Roles(ctx context.Context, obj *models.User) ([]models.RoleEnum, error) { + ret := make([]models.RoleEnum, len(obj.Roles)) + for i, role := range obj.Roles { + ret[i] = models.RoleEnum(role) + } + return ret, nil +} + +func (r *userResolver) APIKey(ctx context.Context, obj *models.User) (*string, error) { + return nil, nil +} diff --git a/internal/api/resolver_mutation_configure.go b/internal/api/resolver_mutation_configure.go index 67ace9552..23b61c208 100644 --- a/internal/api/resolver_mutation_configure.go +++ b/internal/api/resolver_mutation_configure.go @@ -320,7 +320,7 @@ func (r *mutationResolver) ConfigureGeneral(ctx context.Context, input ConfigGen if input.Password != nil { // bit of a hack - check if the passed in password is the same as the stored hash // and only set if they are different - currentPWHash, _ := c.GetPasswordHash(ctx, c.GetUsername()) + currentPWHash := c.GetPasswordHash() if *input.Password != currentPWHash { if *input.Password == "" { diff --git a/internal/api/resolver_mutation_user.go b/internal/api/resolver_mutation_user.go new file mode 100644 index 000000000..cc48ceef3 --- /dev/null +++ b/internal/api/resolver_mutation_user.go @@ -0,0 +1,62 @@ +package api + +import ( + "context" + + "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/session" +) + +func (r *mutationResolver) UserCreate(ctx context.Context, input UserCreateInput) (*models.User, error) { + err := r.userService.CreateUser(ctx, models.User{ + Username: input.Name, + Roles: models.Roles(input.Roles), + }, input.Password) + if err != nil { + return nil, err + } + + return r.userService.GetUser(ctx, input.Name) +} + +func (r *mutationResolver) UserUpdate(ctx context.Context, input UserUpdateInput) (*models.User, error) { + err := r.userService.UpdateUser(ctx, input.ExistingName, models.User{ + Username: input.Name, + Roles: models.Roles(input.Roles), + }) + if err != nil { + return nil, err + } + + return r.userService.GetUser(ctx, input.Name) +} + +func (r *mutationResolver) UserDestroy(ctx context.Context, input UserDestroyInput) (bool, error) { + err := r.userService.DeleteUser(ctx, input.Name) + if err != nil { + return false, err + } + + return true, nil +} + +func (r *mutationResolver) ChangePassword(ctx context.Context, input UserChangePasswordInput) (bool, error) { + // get current user + u := session.GetCurrentUser(ctx) + + err := r.userService.ChangePassword(ctx, u.Username, input.ExistingPassword, input.NewPassword) + if err != nil { + return false, err + } + + return true, nil +} + +func (r *mutationResolver) ChangeUserPassword(ctx context.Context, input ChangeUserPasswordInput) (bool, error) { + err := r.userService.ChangeUserPassword(ctx, input.Name, input.NewPassword) + if err != nil { + return false, err + } + + return true, nil +} diff --git a/internal/api/resolver_query_configuration.go b/internal/api/resolver_query_configuration.go index 11727ae36..bc76212eb 100644 --- a/internal/api/resolver_query_configuration.go +++ b/internal/api/resolver_query_configuration.go @@ -78,9 +78,6 @@ func makeConfigGeneralResult() *ConfigGeneralResult { customPerformerImageLocation := config.GetCustomPerformerImageLocation() - username := config.GetUsername() - pwHash, _ := config.GetPasswordHash(context.Background(), username) - return &ConfigGeneralResult{ Stashes: config.GetStashPaths(), DatabasePath: config.GetDatabasePath(), @@ -113,7 +110,7 @@ func makeConfigGeneralResult() *ConfigGeneralResult { GalleryCoverRegex: config.GetGalleryCoverRegex(), APIKey: config.GetAPIKey(), Username: config.GetUsername(), - Password: pwHash, + Password: config.GetPasswordHash(), MaxSessionAge: config.GetMaxSessionAge(), LogFile: &logFile, LogOut: config.GetLogOut(), diff --git a/internal/api/resolver_query_user.go b/internal/api/resolver_query_user.go new file mode 100644 index 000000000..813e595cd --- /dev/null +++ b/internal/api/resolver_query_user.go @@ -0,0 +1,17 @@ +package api + +import ( + "context" + + "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/session" +) + +func (r *queryResolver) Users(ctx context.Context) ([]*models.User, error) { + return r.userService.AllUsers(ctx) +} + +func (r *queryResolver) Me(ctx context.Context) (*models.User, error) { + // get current user + return session.GetCurrentUser(ctx), nil +} diff --git a/internal/api/server.go b/internal/api/server.go index 8ea1d700c..dbc20e346 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -164,19 +164,22 @@ func Initialize() (*Server, error) { imageService := mgr.ImageService galleryService := mgr.GalleryService groupService := mgr.GroupService + userService := mgr.UserService resolver := &Resolver{ repository: repo, sceneService: sceneService, imageService: imageService, galleryService: galleryService, groupService: groupService, + userService: userService, hookExecutor: pluginCache, } gqlCfg := Config{ Resolvers: resolver, Directives: DirectiveRoot{ - HasRole: HasRoleDirective, + HasRole: HasRoleDirective, + IsUserOwner: IsUserOwnerDirective, }, } @@ -236,9 +239,11 @@ func Initialize() (*Server, error) { staticLoginUI := statigz.FileServer(ui.LoginUIBox.(fs.ReadDirFS)) - r.Get(loginEndpoint, handleLogin()) - r.Post(loginEndpoint, handleLoginPost()) - r.Get(logoutEndpoint, handleLogout()) + sessionStore := mgr.SessionStore + + r.Get(loginEndpoint, handleLogin(userService)) + r.Post(loginEndpoint, handleLoginPost(sessionStore)) + r.Get(logoutEndpoint, handleLogout(sessionStore)) r.Get(loginLocaleEndpoint, handleLoginLocale(cfg)) r.HandleFunc(loginEndpoint+"/*", func(w http.ResponseWriter, r *http.Request) { r.URL.Path = strings.TrimPrefix(r.URL.Path, loginEndpoint) diff --git a/internal/api/session.go b/internal/api/session.go index f779fe992..e66e00b0d 100644 --- a/internal/api/session.go +++ b/internal/api/session.go @@ -103,11 +103,11 @@ func getLoginLocale(lang string) ([]byte, error) { return data, nil } -func handleLogin() http.HandlerFunc { +func handleLogin(s manager.UserService) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { returnURL := r.URL.Query().Get(returnURLParam) - if hc, _ := config.GetInstance().HasCredentials(r.Context()); !hc { + if hc := s.LoginRequired(r.Context()); !hc { if returnURL != "" { http.Redirect(w, r, returnURL, http.StatusFound) } else { @@ -121,9 +121,9 @@ func handleLogin() http.HandlerFunc { } } -func handleLoginPost() http.HandlerFunc { +func handleLoginPost(s *session.Store) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - err := manager.GetInstance().SessionStore.Login(w, r) + err := s.Login(w, r) if err != nil { // always log the error logger.Errorf("Error logging in: %v from IP: %s", err, r.RemoteAddr) @@ -146,7 +146,7 @@ func handleLoginPost() http.HandlerFunc { } } -func handleLogout() http.HandlerFunc { +func handleLogout(s *session.Store) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if err := manager.GetInstance().SessionStore.Logout(w, r); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -155,7 +155,7 @@ func handleLogout() http.HandlerFunc { // redirect to the login page if credentials are required prefix := getProxyPrefix(r) - if hc, _ := config.GetInstance().HasCredentials(r.Context()); hc { + if hc := s.LoginRequired(r.Context()); hc { http.Redirect(w, r, prefix+loginEndpoint, http.StatusFound) } else { http.Redirect(w, r, prefix+"/", http.StatusFound) diff --git a/internal/manager/config/config.go b/internal/manager/config/config.go index a5b9da425..e1c35fd6b 100644 --- a/internal/manager/config/config.go +++ b/internal/manager/config/config.go @@ -1,7 +1,6 @@ package config import ( - "context" "fmt" "net/url" "os" @@ -15,8 +14,6 @@ import ( "sync" // "github.com/sasha-s/go-deadlock" // if you have deadlock issues - "golang.org/x/crypto/bcrypt" - "github.com/knadh/koanf/parsers/yaml" "github.com/knadh/koanf/providers/file" "github.com/knadh/koanf/v2" @@ -28,7 +25,6 @@ import ( "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models/paths" "github.com/stashapp/stash/pkg/sliceutil" - "github.com/stashapp/stash/pkg/user" "github.com/stashapp/stash/pkg/utils" ) @@ -41,9 +37,8 @@ const ( BlobsPath = "blobs_path" Downloads = "downloads" ApiKey = "api_key" - Username = "username" - Password = "password" - MaxSessionAge = "max_session_age" + + MaxSessionAge = "max_session_age" // SFWContentMode mode config key SFWContentMode = "sfw_content_mode" @@ -334,6 +329,9 @@ type Config struct { // configUpdates chan int certFile string keyFile string + + UserStore *UserStore + sync.RWMutex // deadlock.RWMutex // for deadlock testing/issues } @@ -431,6 +429,9 @@ func (i *Config) SetInterface(key string, value interface{}) { i.Lock() defer i.Unlock() + i.setInterfaceNoLock(key, value) +} +func (i *Config) setInterfaceNoLock(key string, value interface{}) { i.set(key, value) } @@ -480,6 +481,10 @@ func (i *Config) Write() error { i.Lock() defer i.Unlock() + return i.writeNoLock() +} + +func (i *Config) writeNoLock() error { data, err := i.marshal() if err != nil { return err @@ -1078,78 +1083,6 @@ func (i *Config) GetAPIKey() string { return i.getString(ApiKey) } -func (i *Config) GetUsername() string { - return i.getString(Username) -} - -func (i *Config) GetUser(ctx context.Context, username string) (*user.User, error) { - // TODO - temp - if username == "read" { - return &user.User{ - Username: username, - Roles: []user.RoleEnum{user.RoleEnumRead}, - }, nil - } - - u := i.GetUsername() - if u != username { - return nil, user.ErrUserNotFound - } - - return &user.User{ - Username: u, - Roles: []user.RoleEnum{user.RoleEnumAdmin}, - }, nil -} - -func (i *Config) GetPasswordHash(ctx context.Context, username string) (string, error) { - u := i.GetUsername() - if u != username { - return "", user.ErrUserNotFound - } - - return i.getString(Password), nil -} - -func (i *Config) GetCredentials() (string, string) { - if hc, _ := i.HasCredentials(context.Background()); hc { - return i.getString(Username), i.getString(Password) - } - - return "", "" -} - -func (i *Config) HasCredentials(ctx context.Context) (bool, error) { - username := i.getString(Username) - pwHash := i.getString(Password) - - return username != "" && pwHash != "", nil -} - -func hashPassword(password string) string { - hash, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.MinCost) - - return string(hash) -} - -func (i *Config) ValidateCredentials(username string, password string) bool { - if hc, _ := i.HasCredentials(context.Background()); !hc { - // don't need to authenticate if no credentials saved - return true - } - - // TODO - temp - if username == "read" { - return password == "read" - } - - authUser, authPWHash := i.GetCredentials() - - err := bcrypt.CompareHashAndPassword([]byte(authPWHash), []byte(password)) - - return username == authUser && err == nil -} - func stashBoxValidate(str string) bool { u, err := url.Parse(str) return err == nil && u.Scheme != "" && u.Host != "" && strings.HasSuffix(u.Path, "/graphql") diff --git a/internal/manager/config/config_concurrency_test.go b/internal/manager/config/config_concurrency_test.go index fd9b067c7..55774be61 100644 --- a/internal/manager/config/config_concurrency_test.go +++ b/internal/manager/config/config_concurrency_test.go @@ -22,8 +22,6 @@ func TestConcurrentConfigAccess(t *testing.T) { t.Errorf("Failure setting initial configuration in worker %v iteration %v: %v", wk, l, err) } - i.HasCredentials() - i.ValidateCredentials("", "") i.GetConfigFile() i.GetConfigPath() i.GetDefaultDatabaseFilePath() @@ -75,7 +73,6 @@ func TestConcurrentConfigAccess(t *testing.T) { i.SetInterface(ApiKey, i.GetAPIKey()) i.SetInterface(Username, i.GetUsername()) i.SetInterface(Password, i.GetPasswordHash()) - i.GetCredentials() i.SetInterface(MaxSessionAge, i.GetMaxSessionAge()) i.SetInterface(CustomServedFolders, i.GetCustomServedFolders()) i.SetInterface(LegacyCustomUILocation, i.GetUILocation()) diff --git a/internal/manager/config/init.go b/internal/manager/config/init.go index 840b50b70..e0dee9735 100644 --- a/internal/manager/config/init.go +++ b/internal/manager/config/init.go @@ -62,6 +62,9 @@ func Initialize() (*Config, error) { main: koanf.New("."), overrides: koanf.New("."), } + cfg.UserStore = &UserStore{ + Config: cfg, + } cfg.initOverrides() @@ -96,6 +99,10 @@ func Initialize() (*Config, error) { } } + if err := cfg.UserStore.loadUsers(); err != nil { + return nil, fmt.Errorf("failed to load users: %v", err) + } + instance = cfg return instance, nil } diff --git a/internal/manager/config/users.go b/internal/manager/config/users.go new file mode 100644 index 000000000..960d8ec4b --- /dev/null +++ b/internal/manager/config/users.go @@ -0,0 +1,227 @@ +package config + +import ( + "context" + "fmt" + + "github.com/stashapp/stash/pkg/models" + "golang.org/x/crypto/bcrypt" +) + +const ( + Username = "username" + Password = "password" + Users = "users" + Roles = "roles" +) + +type StoredUser struct { + Username string `json:"username" koanf:"username"` + PasswordHash string `json:"passwordhash" koanf:"passwordhash"` + Roles []models.RoleEnum `json:"roles" koanf:"roles"` +} + +type UserStore struct { + *Config + + cachedUsers map[string]StoredUser +} + +func (s *Config) GetUsername() string { + return s.getString(Username) +} + +func (i *Config) GetPasswordHash() string { + return i.getString(Password) +} + +func (s *UserStore) legacyUser() *StoredUser { + un := s.getString(Username) + pwHash := s.getString(Password) + + if un != "" && pwHash != "" { + return &StoredUser{ + Username: un, + PasswordHash: pwHash, + Roles: []models.RoleEnum{models.RoleEnumAdmin}, + } + } + + return nil +} + +func (s *UserStore) loadUsers() error { + // done outside lock to avoid deadlock + legacyUser := s.legacyUser() + + s.RLock() + defer s.RUnlock() + + var ret []*StoredUser + err := s.unmarshalKey(Users, &ret) + if err != nil { + return err + } + + // add legacy username + if legacyUser != nil { + ret = append(ret, legacyUser) + } + + s.cachedUsers = make(map[string]StoredUser) + for _, u := range ret { + s.cachedUsers[u.Username] = *u + } + + return nil +} + +func (s *UserStore) convertUser(su StoredUser) *models.User { + return &models.User{ + Username: su.Username, + Roles: su.Roles, + } +} + +func (s *UserStore) getUser(username string) *StoredUser { + u, ok := s.cachedUsers[username] + if !ok { + return nil + } + + return &u +} + +func (s *UserStore) GetUser(ctx context.Context, username string) (*models.User, error) { + s.RLock() + defer s.RUnlock() + + u := s.getUser(username) + if u == nil { + return nil, nil + } + + return s.convertUser(*u), nil +} + +func (s *UserStore) AllUsers(ctx context.Context) ([]*models.User, error) { + var users []*models.User + + s.RLock() + defer s.RUnlock() + + for _, su := range s.cachedUsers { + users = append(users, s.convertUser(su)) + } + + return users, nil +} + +func (s *UserStore) LoginRequired(ctx context.Context) bool { + return len(s.cachedUsers) > 0 +} + +func hashPassword(password string) string { + hash, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.MinCost) + + return string(hash) +} + +func (s *UserStore) ValidateCredentials(ctx context.Context, username string, password string) bool { + u := s.getUser(username) + if u == nil { + return false + } + + err := bcrypt.CompareHashAndPassword([]byte(u.PasswordHash), []byte(password)) + + return err == nil +} + +func (s *UserStore) saveUsers() error { + // convert to list + users := make([]StoredUser, 0, len(s.cachedUsers)) + for _, u := range s.cachedUsers { + users = append(users, u) + } + + s.setInterfaceNoLock(Users, users) + return s.writeNoLock() +} + +func (s *UserStore) ChangeUserPassword(ctx context.Context, username string, newPassword string) error { + s.Lock() + defer s.Unlock() + + u := s.getUser(username) + if u == nil { + return fmt.Errorf("user not found") + } + + newHash := hashPassword(newPassword) + + updatedUser := *u + updatedUser.PasswordHash = newHash + s.cachedUsers[username] = updatedUser + + return s.saveUsers() +} + +func (s *UserStore) CreateUser(ctx context.Context, u models.User, password string) error { + s.Lock() + defer s.Unlock() + + existingUser := s.getUser(u.Username) + if existingUser != nil { + return fmt.Errorf("user already exists") + } + + newUser := StoredUser{ + Username: u.Username, + PasswordHash: hashPassword(password), + Roles: u.Roles, + } + + s.cachedUsers[u.Username] = newUser + + return s.saveUsers() +} + +func (s *UserStore) ReplaceUser(ctx context.Context, username string, updated models.User) error { + s.Lock() + defer s.Unlock() + + existingUser := s.getUser(username) + if existingUser == nil { + return fmt.Errorf("user not found") + } + + updatedUser := StoredUser{ + Username: updated.Username, + PasswordHash: existingUser.PasswordHash, + Roles: updated.Roles, + } + + // if username changed, remove old entry + if username != updated.Username { + delete(s.cachedUsers, username) + } + + s.cachedUsers[updated.Username] = updatedUser + + return s.saveUsers() +} + +func (s *UserStore) DeleteUser(ctx context.Context, username string) error { + s.Lock() + defer s.Unlock() + + existingUser := s.getUser(username) + if existingUser == nil { + return fmt.Errorf("user not found") + } + + delete(s.cachedUsers, username) + + return s.saveUsers() +} diff --git a/internal/manager/init.go b/internal/manager/init.go index 2f0f60a42..e1cfb3bbf 100644 --- a/internal/manager/init.go +++ b/internal/manager/init.go @@ -110,8 +110,8 @@ func Initialize(cfg *config.Config, l *log.Logger) (*Manager, error) { scanSubs: &subscriptionManager{}, } - instance.UserService = &user.Service{ - Store: cfg, + mgr.UserService = &user.Service{ + Store: cfg.UserStore, } if !cfg.IsNewSystem() { @@ -135,7 +135,7 @@ func Initialize(cfg *config.Config, l *log.Logger) (*Manager, error) { // create temporary session store - this will be re-initialised // after config is complete - mgr.SessionStore = session.NewStore(cfg, cfg) + mgr.SessionStore = session.NewStore(cfg, instance.UserService) logger.Warnf("config file %snot found. Assuming new system...", cfgFile) } @@ -194,7 +194,7 @@ func initJobManager(cfg *config.Config) *job.Manager { func (s *Manager) postInit(ctx context.Context) error { s.RefreshConfig() - s.SessionStore = session.NewStore(s.Config, s.Config) + s.SessionStore = session.NewStore(s.Config, s.UserService) s.PluginCache.RegisterSessionStore(s.SessionStore) s.RefreshPluginCache() @@ -256,7 +256,7 @@ func (s *Manager) postInit(ctx context.Context) error { } func (s *Manager) checkSecurityTripwire() { - if err := session.CheckExternalAccessTripwire(s.Config); err != nil { + if err := session.CheckExternalAccessTripwire(s.Config.UserStore, s.Config); err != nil { session.LogExternalAccessError(*err) } } diff --git a/internal/manager/repository.go b/internal/manager/repository.go index d017dc23b..c71b51daa 100644 --- a/internal/manager/repository.go +++ b/internal/manager/repository.go @@ -7,7 +7,7 @@ import ( "github.com/stashapp/stash/pkg/image" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/scene" - "github.com/stashapp/stash/pkg/user" + "github.com/stashapp/stash/pkg/session" ) type SceneService interface { @@ -49,5 +49,14 @@ type GroupService interface { } type UserService interface { - GetUser(ctx context.Context, username string) (*user.User, error) + session.Authenticator + AllUsers(ctx context.Context) ([]*models.User, error) + GetUser(ctx context.Context, username string) (*models.User, error) + LoginRequired(ctx context.Context) bool + + CreateUser(ctx context.Context, u models.User, password string) error + UpdateUser(ctx context.Context, username string, updated models.User) error + ChangePassword(ctx context.Context, username, existingPassword, newPassword string) error + ChangeUserPassword(ctx context.Context, username string, newPassword string) error + DeleteUser(ctx context.Context, username string) error } diff --git a/pkg/models/model_user.go b/pkg/models/model_user.go new file mode 100644 index 000000000..d885382db --- /dev/null +++ b/pkg/models/model_user.go @@ -0,0 +1,12 @@ +package models + +type User struct { + Username string + Roles Roles +} + +type UserInput struct { + Username string + Roles Roles + Password string +} diff --git a/pkg/user/role.go b/pkg/models/role.go similarity index 82% rename from pkg/user/role.go rename to pkg/models/role.go index e411cfdee..52a162a63 100644 --- a/pkg/user/role.go +++ b/pkg/models/role.go @@ -1,4 +1,4 @@ -package user +package models import ( "fmt" @@ -58,15 +58,13 @@ func (e RoleEnum) MarshalGQL(w io.Writer) { fmt.Fprint(w, strconv.Quote(e.String())) } -func IsRole(assignedRoles []RoleEnum, requiredRole RoleEnum) bool { - valid := false +type Roles []RoleEnum - for _, role := range assignedRoles { - if role.Implies(requiredRole) { - valid = true - break +func (r Roles) HasRole(role RoleEnum) bool { + for _, r := range r { + if r.Implies(role) { + return true } } - - return valid + return false } diff --git a/pkg/session/authentication.go b/pkg/session/authentication.go index 8fb39c099..eb6878957 100644 --- a/pkg/session/authentication.go +++ b/pkg/session/authentication.go @@ -16,8 +16,8 @@ func (e ExternalAccessError) Error() string { return fmt.Sprintf("stash accessed from external IP %s", net.IP(e).String()) } -func CheckAllowPublicWithoutAuth(c ExternalAccessConfig, r *http.Request) error { - if hc, _ := c.HasCredentials(context.Background()); !hc && !c.GetDangerousAllowPublicWithoutAuth() && !c.IsNewSystem() { +func CheckAllowPublicWithoutAuth(s CredentialStore, c ExternalAccessConfig, r *http.Request) error { + if hc := s.LoginRequired(context.Background()); !hc && !c.GetDangerousAllowPublicWithoutAuth() && !c.IsNewSystem() { requestIPString, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { return fmt.Errorf("error parsing remote host (%s): %w", r.RemoteAddr, err) @@ -60,8 +60,8 @@ func CheckAllowPublicWithoutAuth(c ExternalAccessConfig, r *http.Request) error return nil } -func CheckExternalAccessTripwire(c ExternalAccessConfig) *ExternalAccessError { - if hc, _ := c.HasCredentials(context.Background()); !hc && !c.GetDangerousAllowPublicWithoutAuth() { +func CheckExternalAccessTripwire(s CredentialStore, c ExternalAccessConfig) *ExternalAccessError { + if hc := s.LoginRequired(context.Background()); !hc && !c.GetDangerousAllowPublicWithoutAuth() { if remoteIP := c.GetSecurityTripwireAccessedFromPublicInternet(); remoteIP != "" { err := ExternalAccessError(net.ParseIP(remoteIP)) return &err diff --git a/pkg/session/config.go b/pkg/session/config.go index 372fc3a6a..98a92cd97 100644 --- a/pkg/session/config.go +++ b/pkg/session/config.go @@ -3,12 +3,15 @@ package session import "context" type ExternalAccessConfig interface { - HasCredentials(ctx context.Context) (bool, error) GetDangerousAllowPublicWithoutAuth() bool GetSecurityTripwireAccessedFromPublicInternet() string IsNewSystem() bool } +type CredentialStore interface { + LoginRequired(ctx context.Context) bool +} + type SessionConfig interface { GetUsername() string GetAPIKey() string diff --git a/pkg/session/session.go b/pkg/session/session.go index 997a71198..50de79fbb 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -8,7 +8,7 @@ import ( "github.com/gorilla/sessions" "github.com/stashapp/stash/pkg/logger" - "github.com/stashapp/stash/pkg/user" + "github.com/stashapp/stash/pkg/models" ) type key int @@ -46,7 +46,8 @@ func (e InvalidCredentialsError) Error() string { var ErrUnauthorized = errors.New("unauthorized") type Authenticator interface { - ValidateCredentials(username string, password string) bool + LoginRequired(ctx context.Context) bool + ValidateCredentials(ctx context.Context, username string, password string) error } type Store struct { @@ -68,6 +69,10 @@ func NewStore(c SessionConfig, a Authenticator) *Store { return ret } +func (s *Store) LoginRequired(ctx context.Context) bool { + return s.authenticator.LoginRequired(ctx) +} + func (s *Store) Login(w http.ResponseWriter, r *http.Request) error { // ignore error - we want a new session regardless newSession, _ := s.sessionStore.Get(r, cookieName) @@ -76,16 +81,16 @@ func (s *Store) Login(w http.ResponseWriter, r *http.Request) error { password := r.FormValue(passwordFormKey) // authenticate the user - if !s.authenticator.ValidateCredentials(username, password) { + err := s.authenticator.ValidateCredentials(r.Context(), username, password) + if err != nil { return &InvalidCredentialsError{Username: username} } - // since we only have one user, don't leak the name - logger.Info("User logged in") + logger.Infof("User %s logged in", username) newSession.Values[userIDKey] = username - err := newSession.Save(r, w) + err = newSession.Save(r, w) if err != nil { return err } @@ -99,6 +104,8 @@ func (s *Store) Logout(w http.ResponseWriter, r *http.Request) error { return err } + userID, _ := session.Values[userIDKey].(string) + delete(session.Values, userIDKey) session.Options.MaxAge = -1 @@ -107,8 +114,7 @@ func (s *Store) Logout(w http.ResponseWriter, r *http.Request) error { return err } - // since we only have one user, don't leak the name - logger.Infof("User logged out") + logger.Infof("User %s logged out", userID) return nil } @@ -138,15 +144,15 @@ func (s *Store) GetSessionUserID(w http.ResponseWriter, r *http.Request) (string return "", nil } -func SetCurrentUser(ctx context.Context, u user.User) context.Context { +func SetCurrentUser(ctx context.Context, u models.User) context.Context { return context.WithValue(ctx, contextUser, u) } // GetCurrentUser gets the current user id from the provided context -func GetCurrentUser(ctx context.Context) *user.User { +func GetCurrentUser(ctx context.Context) *models.User { userCtxVal := ctx.Value(contextUser) if userCtxVal != nil { - currentUser := userCtxVal.(user.User) + currentUser := userCtxVal.(models.User) return ¤tUser } @@ -164,6 +170,7 @@ func (s *Store) Authenticate(w http.ResponseWriter, r *http.Request) (userID str apiKey = r.URL.Query().Get(ApiKeyParameter) } + // FIXME - handle this if apiKey != "" { // match against configured API and set userID to the // configured username. In future, we'll want to diff --git a/pkg/user/authenticate.go b/pkg/user/authenticate.go index d3aa48f4a..a00006b65 100644 --- a/pkg/user/authenticate.go +++ b/pkg/user/authenticate.go @@ -1,44 +1 @@ package user - -import ( - "context" - "errors" - "fmt" - - "golang.org/x/crypto/bcrypt" -) - -var ( - ErrPasswordMismatch = fmt.Errorf("password mismatch") - ErrUserNotFound = fmt.Errorf("user not found") -) - -func (s *Service) ValidateCredentials(ctx context.Context, username string, password string) error { - hc, err := s.Store.HasCredentials(ctx) - if err != nil { - return fmt.Errorf("error checking if credentials exist: %w", err) - } - - if !hc { - // don't need to authenticate if no credentials saved - return nil - } - - authPWHash, err := s.Store.GetPasswordHash(ctx, username) - if err != nil { - if errors.Is(err, ErrUserNotFound) { - return err - } - return fmt.Errorf("error getting password hash: %w", err) - } - - if err := bcrypt.CompareHashAndPassword([]byte(authPWHash), []byte(password)); err != nil { - if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) { - return ErrPasswordMismatch - } - - return fmt.Errorf("error comparing password hash: %w", err) - } - - return nil -} diff --git a/pkg/user/service.go b/pkg/user/service.go index d918fed0c..295a11eb4 100644 --- a/pkg/user/service.go +++ b/pkg/user/service.go @@ -2,18 +2,266 @@ package user import ( "context" + "errors" + "fmt" + "slices" + "strings" + + "github.com/stashapp/stash/pkg/logger" + "github.com/stashapp/stash/pkg/models" ) -type Store interface { - GetUser(ctx context.Context, username string) (*User, error) - HasCredentials(ctx context.Context) (bool, error) - GetPasswordHash(ctx context.Context, username string) (string, error) +var ( + ErrUserNotExist = errors.New("user not found") + ErrEmptyUsername = errors.New("empty username") + ErrUsernameHasWhitespace = errors.New("username has leading or trailing whitespace") + ErrDeleteLastAdminUser = errors.New("final admin user cannot be deleted") + ErrInternalError = errors.New("internal error") + ErrAccessDenied = errors.New("access denied") + ErrCurrentPasswordIncorrect = errors.New("current password incorrect") + ErrUserAlreadyExists = errors.New("user with that username already exists") +) + +type UserSource interface { + AllUsers(ctx context.Context) ([]*models.User, error) + GetUser(ctx context.Context, username string) (*models.User, error) + ValidateCredentials(ctx context.Context, username string, password string) bool + + CreateUser(ctx context.Context, u models.User, password string) error + ReplaceUser(ctx context.Context, username string, updated models.User) error + ChangeUserPassword(ctx context.Context, username string, newPassword string) error + DeleteUser(ctx context.Context, username string) error } type Service struct { - Store Store + Store UserSource } -func (s *Service) GetUser(ctx context.Context, username string) (*User, error) { +func (s *Service) LoginRequired(ctx context.Context) bool { + u, _ := s.Store.AllUsers(ctx) + return len(u) > 0 +} + +func (s *Service) GetUser(ctx context.Context, username string) (*models.User, error) { return s.Store.GetUser(ctx, username) } + +func (s *Service) AllUsers(ctx context.Context) ([]*models.User, error) { + return s.Store.AllUsers(ctx) +} + +func (s *Service) ValidateCredentials(ctx context.Context, username string, password string) error { + // ensure user is not locked + u, err := s.GetUser(ctx, username) + if err != nil { + logger.Errorf("error getting user for credential validation: %v", err) + return ErrInternalError + } + + if u == nil { + logger.Infof("[login attempt] user %s not found during credential validation", username) + return ErrAccessDenied + } + + if len(u.Roles) == 0 { + logger.Infof("[login attempt] user %s is locked", username) + return ErrAccessDenied + } + + if !s.Store.ValidateCredentials(ctx, username, password) { + logger.Infof("[login attempt] invalid credentials for user %s", username) + return ErrAccessDenied + } + return nil +} + +func (s *Service) validateUsername(username string) error { + if username == "" { + return ErrEmptyUsername + } + + // username must not have leading or trailing whitespace + trimmed := strings.TrimSpace(username) + + if trimmed != username { + return ErrUsernameHasWhitespace + } + + return nil +} + +func (s *Service) validatePassword(password string) error { + if password == "" { + return errors.New("password cannot be empty") + } + + // add more password validation as needed + + return nil +} + +func (s *Service) CreateUser(ctx context.Context, u models.User, password string) error { + // validate input + // ensure username is valid + if err := s.validateUsername(u.Username); err != nil { + return err + } + + // check if user exists + existingUser, err := s.GetUser(ctx, u.Username) + if err != nil { + return fmt.Errorf("error checking existing users: %w", err) + } + + if existingUser != nil { + return ErrUserAlreadyExists + } + + // validate password + if err := s.validatePassword(password); err != nil { + return err + } + + // if this is the first user, make them an admin + users, err := s.AllUsers(ctx) + if err != nil { + return fmt.Errorf("error getting existing users: %w", err) + } + + if len(users) == 0 && !u.Roles.HasRole(models.RoleEnumAdmin) { + return errors.New("the first user must be an admin") + } + + // create user in store + if err := s.Store.CreateUser(ctx, u, password); err != nil { + return fmt.Errorf("error creating user: %w", err) + } + + logger.Infof("[user] created %q", u.Username) + + return nil +} + +func (s *Service) UpdateUser(ctx context.Context, username string, updated models.User) error { + // validate input + // check if user exists + existingUser, err := s.GetUser(ctx, username) + if err != nil { + return fmt.Errorf("error getting existing user: %w", err) + } + + if existingUser == nil { + return ErrUserNotExist + } + + existingRoles := existingUser.Roles + + // ensure username is valid + if username != updated.Username { + if err := s.validateUsername(updated.Username); err != nil { + return err + } + + // ensure new username doesn't already exist + otherUser, err := s.GetUser(ctx, updated.Username) + if err != nil { + return fmt.Errorf("error checking existing user: %w", err) + } + + if otherUser != nil { + return ErrUserAlreadyExists + } + } + + // update user in store + if err := s.Store.ReplaceUser(ctx, username, updated); err != nil { + return fmt.Errorf("error updating user: %w", err) + } + + if username != updated.Username { + logger.Infof("[user] updated name %q -> %q", username, updated.Username) + } + + if !slices.Equal(existingRoles, updated.Roles) { + logger.Infof("[user] updated roles for user %q", updated.Username) + } + + return nil +} + +func (s *Service) ChangePassword(ctx context.Context, username, currentPassword, newPassword string) error { + // validate current credentials + if err := s.ValidateCredentials(ctx, username, currentPassword); err != nil { + logger.Infof("[user] failed password change attempt for %q: incorrect current password", username) + return ErrCurrentPasswordIncorrect + } + + return s.ChangeUserPassword(ctx, username, newPassword) +} + +func (s *Service) ChangeUserPassword(ctx context.Context, username, newPassword string) error { + // check if user exists + existingUser, err := s.GetUser(ctx, username) + if err != nil { + return fmt.Errorf("error getting existing user: %w", err) + } + + if existingUser == nil { + return ErrUserNotExist + } + + // validate new password + if err := s.validatePassword(newPassword); err != nil { + return err + } + + // change password in store + if err := s.Store.ChangeUserPassword(ctx, username, newPassword); err != nil { + return fmt.Errorf("error changing user password: %w", err) + } + + logger.Infof("[user] changed password for %q", username) + + return nil +} + +func (s *Service) DeleteUser(ctx context.Context, username string) error { + // check if user exists + existingUser, err := s.GetUser(ctx, username) + if err != nil { + return fmt.Errorf("error getting existing user: %w", err) + } + + if existingUser == nil { + return ErrUserNotExist + } + + // don't allow deleting last admin user + if existingUser.Roles.HasRole(models.RoleEnumAdmin) { + users, err := s.AllUsers(ctx) + if err != nil { + return fmt.Errorf("error getting all users: %w", err) + } + + hasAdmin := false + for _, u := range users { + if u.Username != username && u.Roles.HasRole(models.RoleEnumAdmin) { + hasAdmin = true + break + } + } + + if !hasAdmin { + return ErrDeleteLastAdminUser + } + } + + // delete user from store + if err := s.Store.DeleteUser(ctx, username); err != nil { + return fmt.Errorf("error deleting user: %w", err) + } + + logger.Infof("[user] deleted %q", username) + + return nil +} diff --git a/pkg/user/user.go b/pkg/user/user.go deleted file mode 100644 index 7971b5f47..000000000 --- a/pkg/user/user.go +++ /dev/null @@ -1,6 +0,0 @@ -package user - -type User struct { - Username string - Roles []RoleEnum -}