mirror of
https://github.com/stashapp/stash.git
synced 2026-02-07 16:05:47 +01:00
Implement config file based user management
This commit is contained in:
parent
2a2351fcdc
commit
aa7107a242
27 changed files with 751 additions and 213 deletions
|
|
@ -126,7 +126,7 @@ models:
|
|||
ScraperSource:
|
||||
model: github.com/stashapp/stash/pkg/scraper.Source
|
||||
RoleEnum:
|
||||
model: github.com/stashapp/stash/pkg/user.RoleEnum
|
||||
model: github.com/stashapp/stash/pkg/models.RoleEnum
|
||||
IdentifySourceInput:
|
||||
model: github.com/stashapp/stash/internal/identify.Source
|
||||
IdentifyFieldOptionsInput:
|
||||
|
|
|
|||
|
|
@ -269,6 +269,7 @@ type Query {
|
|||
allStudios: [Studio!]! @deprecated(reason: "Use findStudios instead")
|
||||
allMovies: [Movie!]! @deprecated(reason: "Use findGroups instead")
|
||||
|
||||
users: [User!]! @hasRole(role: ADMIN)
|
||||
"""Returns currently authenticated user"""
|
||||
me: User
|
||||
|
||||
|
|
@ -593,6 +594,7 @@ type Mutation {
|
|||
userCreate(input: UserCreateInput!): User @hasRole(role: ADMIN)
|
||||
userUpdate(input: UserUpdateInput!): User @hasRole(role: ADMIN)
|
||||
userDestroy(input: UserDestroyInput!): Boolean! @hasRole(role: ADMIN)
|
||||
changeUserPassword(input: ChangeUserPasswordInput!): Boolean! @hasRole(role: ADMIN)
|
||||
changePassword(input: UserChangePasswordInput!): Boolean!
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,10 @@ directive @isUserOwner on FIELD_DEFINITION
|
|||
|
||||
type User {
|
||||
name: String!
|
||||
"""Should not be visible to other users"""
|
||||
"""
|
||||
If the user has no roles, they are considered locked and cannot log in.
|
||||
Should not be visible to other users
|
||||
"""
|
||||
roles: [RoleEnum!] @isUserOwner
|
||||
"""Should not be visible to other users"""
|
||||
api_key: String @isUserOwner
|
||||
|
|
@ -23,19 +26,22 @@ input UserCreateInput {
|
|||
}
|
||||
|
||||
input UserUpdateInput {
|
||||
name: String
|
||||
"""Password in plain text"""
|
||||
password: String
|
||||
roles: [RoleEnum!]
|
||||
existingName: String!
|
||||
name: String!
|
||||
roles: [RoleEnum!]!
|
||||
}
|
||||
|
||||
input UserDestroyInput {
|
||||
id: ID!
|
||||
name: String!
|
||||
}
|
||||
|
||||
input UserChangePasswordInput {
|
||||
"""Password in plain text"""
|
||||
existing_password: String
|
||||
new_password: String!
|
||||
reset_key: String
|
||||
existingPassword: String!
|
||||
newPassword: String!
|
||||
}
|
||||
|
||||
input ChangeUserPasswordInput {
|
||||
name: String!
|
||||
newPassword: String!
|
||||
}
|
||||
|
|
@ -12,8 +12,8 @@ import (
|
|||
"github.com/stashapp/stash/internal/manager"
|
||||
"github.com/stashapp/stash/internal/manager/config"
|
||||
"github.com/stashapp/stash/pkg/logger"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/session"
|
||||
"github.com/stashapp/stash/pkg/user"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
@ -32,16 +32,17 @@ func allowUnauthenticated(r *http.Request) bool {
|
|||
}
|
||||
|
||||
type UserGetter interface {
|
||||
GetUser(ctx context.Context, username string) (*user.User, error)
|
||||
GetUser(ctx context.Context, username string) (*models.User, error)
|
||||
}
|
||||
|
||||
func authenticateHandler(g UserGetter) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
c := config.GetInstance()
|
||||
s := c.UserStore
|
||||
|
||||
// error if external access tripwire activated
|
||||
if accessErr := session.CheckExternalAccessTripwire(c); accessErr != nil {
|
||||
if accessErr := session.CheckExternalAccessTripwire(s, c); accessErr != nil {
|
||||
http.Error(w, tripwireActivatedErrMsg, http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
|
@ -59,7 +60,9 @@ func authenticateHandler(g UserGetter) func(http.Handler) http.Handler {
|
|||
return
|
||||
}
|
||||
|
||||
if err := session.CheckAllowPublicWithoutAuth(c, r); err != nil {
|
||||
ctx := r.Context()
|
||||
|
||||
if err := session.CheckAllowPublicWithoutAuth(s, c, r); err != nil {
|
||||
var accessErr session.ExternalAccessError
|
||||
if errors.As(err, &accessErr) {
|
||||
session.LogExternalAccessError(accessErr)
|
||||
|
|
@ -77,11 +80,23 @@ func authenticateHandler(g UserGetter) func(http.Handler) http.Handler {
|
|||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
var u *models.User
|
||||
if userID != "" {
|
||||
u, err = g.GetUser(ctx, userID)
|
||||
if err != nil {
|
||||
// if we can't get the user object, we just return a forbidden error
|
||||
logger.Errorf("Error getting user object: %v", err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if u == nil {
|
||||
logger.Errorf("[User] cookie user %q not found", userID)
|
||||
}
|
||||
}
|
||||
|
||||
if hc, _ := c.HasCredentials(ctx); hc {
|
||||
if hc := s.LoginRequired(ctx); hc {
|
||||
// authentication is required
|
||||
if userID == "" && !allowUnauthenticated(r) {
|
||||
if u == nil && !allowUnauthenticated(r) {
|
||||
// if graphql or a non-webpage was requested, we just return a forbidden error
|
||||
ext := path.Ext(r.URL.Path)
|
||||
if r.URL.Path == gqlEndpoint || (ext != "" && ext != ".html") {
|
||||
|
|
@ -108,16 +123,8 @@ func authenticateHandler(g UserGetter) func(http.Handler) http.Handler {
|
|||
}
|
||||
}
|
||||
|
||||
if userID != "" {
|
||||
if u != nil {
|
||||
// set the user object in the context
|
||||
u, err := g.GetUser(ctx, userID)
|
||||
if err != nil {
|
||||
// if we can't get the user object, we just return a forbidden error
|
||||
logger.Errorf("Error getting user object: %v", err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
ctx = session.SetCurrentUser(ctx, *u)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -4,11 +4,11 @@ import (
|
|||
"context"
|
||||
|
||||
"github.com/99designs/gqlgen/graphql"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/session"
|
||||
"github.com/stashapp/stash/pkg/user"
|
||||
)
|
||||
|
||||
func HasRoleDirective(ctx context.Context, obj interface{}, next graphql.Resolver, role user.RoleEnum) (interface{}, error) {
|
||||
func HasRoleDirective(ctx context.Context, obj interface{}, next graphql.Resolver, role models.RoleEnum) (interface{}, error) {
|
||||
currentUser := session.GetCurrentUser(ctx)
|
||||
|
||||
// if there is no current user, this is an anonymous request
|
||||
|
|
@ -17,7 +17,29 @@ func HasRoleDirective(ctx context.Context, obj interface{}, next graphql.Resolve
|
|||
return next(ctx)
|
||||
}
|
||||
|
||||
if currentUser != nil && !user.IsRole(currentUser.Roles, role) {
|
||||
if currentUser != nil && !currentUser.Roles.HasRole(role) {
|
||||
return nil, session.ErrUnauthorized
|
||||
}
|
||||
|
||||
return next(ctx)
|
||||
}
|
||||
|
||||
func IsUserOwnerDirective(ctx context.Context, obj any, next graphql.Resolver) (res any, err error) {
|
||||
currentUser := session.GetCurrentUser(ctx)
|
||||
|
||||
// if there is no current user, this is an anonymous request
|
||||
// we should not end up here unless there are no credentials required
|
||||
if currentUser == nil {
|
||||
return next(ctx)
|
||||
}
|
||||
|
||||
// get the user from the object
|
||||
userObj, ok := obj.(*models.User)
|
||||
if !ok {
|
||||
return nil, session.ErrUnauthorized
|
||||
}
|
||||
|
||||
if currentUser.Username != userObj.Username {
|
||||
return nil, session.ErrUnauthorized
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ type Resolver struct {
|
|||
imageService manager.ImageService
|
||||
galleryService manager.GalleryService
|
||||
groupService manager.GroupService
|
||||
userService manager.UserService
|
||||
|
||||
hookExecutor hookExecutor
|
||||
}
|
||||
|
|
@ -110,6 +111,9 @@ func (r *Resolver) Plugin() PluginResolver {
|
|||
func (r *Resolver) ConfigResult() ConfigResultResolver {
|
||||
return &configResultResolver{r}
|
||||
}
|
||||
func (r *Resolver) User() UserResolver {
|
||||
return &userResolver{r}
|
||||
}
|
||||
|
||||
type mutationResolver struct{ *Resolver }
|
||||
type queryResolver struct{ *Resolver }
|
||||
|
|
@ -136,6 +140,7 @@ type folderResolver struct{ *Resolver }
|
|||
type savedFilterResolver struct{ *Resolver }
|
||||
type pluginResolver struct{ *Resolver }
|
||||
type configResultResolver struct{ *Resolver }
|
||||
type userResolver struct{ *Resolver }
|
||||
|
||||
func (r *Resolver) withTxn(ctx context.Context, fn func(ctx context.Context) error) error {
|
||||
return r.repository.WithTxn(ctx, fn)
|
||||
|
|
|
|||
23
internal/api/resolver_model_user.go
Normal file
23
internal/api/resolver_model_user.go
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
)
|
||||
|
||||
func (r *userResolver) Name(ctx context.Context, obj *models.User) (string, error) {
|
||||
return obj.Username, nil
|
||||
}
|
||||
|
||||
func (r *userResolver) Roles(ctx context.Context, obj *models.User) ([]models.RoleEnum, error) {
|
||||
ret := make([]models.RoleEnum, len(obj.Roles))
|
||||
for i, role := range obj.Roles {
|
||||
ret[i] = models.RoleEnum(role)
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (r *userResolver) APIKey(ctx context.Context, obj *models.User) (*string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
|
@ -320,7 +320,7 @@ func (r *mutationResolver) ConfigureGeneral(ctx context.Context, input ConfigGen
|
|||
if input.Password != nil {
|
||||
// bit of a hack - check if the passed in password is the same as the stored hash
|
||||
// and only set if they are different
|
||||
currentPWHash, _ := c.GetPasswordHash(ctx, c.GetUsername())
|
||||
currentPWHash := c.GetPasswordHash()
|
||||
|
||||
if *input.Password != currentPWHash {
|
||||
if *input.Password == "" {
|
||||
|
|
|
|||
62
internal/api/resolver_mutation_user.go
Normal file
62
internal/api/resolver_mutation_user.go
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/session"
|
||||
)
|
||||
|
||||
func (r *mutationResolver) UserCreate(ctx context.Context, input UserCreateInput) (*models.User, error) {
|
||||
err := r.userService.CreateUser(ctx, models.User{
|
||||
Username: input.Name,
|
||||
Roles: models.Roles(input.Roles),
|
||||
}, input.Password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return r.userService.GetUser(ctx, input.Name)
|
||||
}
|
||||
|
||||
func (r *mutationResolver) UserUpdate(ctx context.Context, input UserUpdateInput) (*models.User, error) {
|
||||
err := r.userService.UpdateUser(ctx, input.ExistingName, models.User{
|
||||
Username: input.Name,
|
||||
Roles: models.Roles(input.Roles),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return r.userService.GetUser(ctx, input.Name)
|
||||
}
|
||||
|
||||
func (r *mutationResolver) UserDestroy(ctx context.Context, input UserDestroyInput) (bool, error) {
|
||||
err := r.userService.DeleteUser(ctx, input.Name)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (r *mutationResolver) ChangePassword(ctx context.Context, input UserChangePasswordInput) (bool, error) {
|
||||
// get current user
|
||||
u := session.GetCurrentUser(ctx)
|
||||
|
||||
err := r.userService.ChangePassword(ctx, u.Username, input.ExistingPassword, input.NewPassword)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (r *mutationResolver) ChangeUserPassword(ctx context.Context, input ChangeUserPasswordInput) (bool, error) {
|
||||
err := r.userService.ChangeUserPassword(ctx, input.Name, input.NewPassword)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
|
@ -78,9 +78,6 @@ func makeConfigGeneralResult() *ConfigGeneralResult {
|
|||
|
||||
customPerformerImageLocation := config.GetCustomPerformerImageLocation()
|
||||
|
||||
username := config.GetUsername()
|
||||
pwHash, _ := config.GetPasswordHash(context.Background(), username)
|
||||
|
||||
return &ConfigGeneralResult{
|
||||
Stashes: config.GetStashPaths(),
|
||||
DatabasePath: config.GetDatabasePath(),
|
||||
|
|
@ -113,7 +110,7 @@ func makeConfigGeneralResult() *ConfigGeneralResult {
|
|||
GalleryCoverRegex: config.GetGalleryCoverRegex(),
|
||||
APIKey: config.GetAPIKey(),
|
||||
Username: config.GetUsername(),
|
||||
Password: pwHash,
|
||||
Password: config.GetPasswordHash(),
|
||||
MaxSessionAge: config.GetMaxSessionAge(),
|
||||
LogFile: &logFile,
|
||||
LogOut: config.GetLogOut(),
|
||||
|
|
|
|||
17
internal/api/resolver_query_user.go
Normal file
17
internal/api/resolver_query_user.go
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/session"
|
||||
)
|
||||
|
||||
func (r *queryResolver) Users(ctx context.Context) ([]*models.User, error) {
|
||||
return r.userService.AllUsers(ctx)
|
||||
}
|
||||
|
||||
func (r *queryResolver) Me(ctx context.Context) (*models.User, error) {
|
||||
// get current user
|
||||
return session.GetCurrentUser(ctx), nil
|
||||
}
|
||||
|
|
@ -164,19 +164,22 @@ func Initialize() (*Server, error) {
|
|||
imageService := mgr.ImageService
|
||||
galleryService := mgr.GalleryService
|
||||
groupService := mgr.GroupService
|
||||
userService := mgr.UserService
|
||||
resolver := &Resolver{
|
||||
repository: repo,
|
||||
sceneService: sceneService,
|
||||
imageService: imageService,
|
||||
galleryService: galleryService,
|
||||
groupService: groupService,
|
||||
userService: userService,
|
||||
hookExecutor: pluginCache,
|
||||
}
|
||||
|
||||
gqlCfg := Config{
|
||||
Resolvers: resolver,
|
||||
Directives: DirectiveRoot{
|
||||
HasRole: HasRoleDirective,
|
||||
HasRole: HasRoleDirective,
|
||||
IsUserOwner: IsUserOwnerDirective,
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -236,9 +239,11 @@ func Initialize() (*Server, error) {
|
|||
|
||||
staticLoginUI := statigz.FileServer(ui.LoginUIBox.(fs.ReadDirFS))
|
||||
|
||||
r.Get(loginEndpoint, handleLogin())
|
||||
r.Post(loginEndpoint, handleLoginPost())
|
||||
r.Get(logoutEndpoint, handleLogout())
|
||||
sessionStore := mgr.SessionStore
|
||||
|
||||
r.Get(loginEndpoint, handleLogin(userService))
|
||||
r.Post(loginEndpoint, handleLoginPost(sessionStore))
|
||||
r.Get(logoutEndpoint, handleLogout(sessionStore))
|
||||
r.Get(loginLocaleEndpoint, handleLoginLocale(cfg))
|
||||
r.HandleFunc(loginEndpoint+"/*", func(w http.ResponseWriter, r *http.Request) {
|
||||
r.URL.Path = strings.TrimPrefix(r.URL.Path, loginEndpoint)
|
||||
|
|
|
|||
|
|
@ -103,11 +103,11 @@ func getLoginLocale(lang string) ([]byte, error) {
|
|||
return data, nil
|
||||
}
|
||||
|
||||
func handleLogin() http.HandlerFunc {
|
||||
func handleLogin(s manager.UserService) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
returnURL := r.URL.Query().Get(returnURLParam)
|
||||
|
||||
if hc, _ := config.GetInstance().HasCredentials(r.Context()); !hc {
|
||||
if hc := s.LoginRequired(r.Context()); !hc {
|
||||
if returnURL != "" {
|
||||
http.Redirect(w, r, returnURL, http.StatusFound)
|
||||
} else {
|
||||
|
|
@ -121,9 +121,9 @@ func handleLogin() http.HandlerFunc {
|
|||
}
|
||||
}
|
||||
|
||||
func handleLoginPost() http.HandlerFunc {
|
||||
func handleLoginPost(s *session.Store) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
err := manager.GetInstance().SessionStore.Login(w, r)
|
||||
err := s.Login(w, r)
|
||||
if err != nil {
|
||||
// always log the error
|
||||
logger.Errorf("Error logging in: %v from IP: %s", err, r.RemoteAddr)
|
||||
|
|
@ -146,7 +146,7 @@ func handleLoginPost() http.HandlerFunc {
|
|||
}
|
||||
}
|
||||
|
||||
func handleLogout() http.HandlerFunc {
|
||||
func handleLogout(s *session.Store) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := manager.GetInstance().SessionStore.Logout(w, r); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
|
|
@ -155,7 +155,7 @@ func handleLogout() http.HandlerFunc {
|
|||
|
||||
// redirect to the login page if credentials are required
|
||||
prefix := getProxyPrefix(r)
|
||||
if hc, _ := config.GetInstance().HasCredentials(r.Context()); hc {
|
||||
if hc := s.LoginRequired(r.Context()); hc {
|
||||
http.Redirect(w, r, prefix+loginEndpoint, http.StatusFound)
|
||||
} else {
|
||||
http.Redirect(w, r, prefix+"/", http.StatusFound)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
|
|
@ -15,8 +14,6 @@ import (
|
|||
"sync"
|
||||
// "github.com/sasha-s/go-deadlock" // if you have deadlock issues
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"github.com/knadh/koanf/parsers/yaml"
|
||||
"github.com/knadh/koanf/providers/file"
|
||||
"github.com/knadh/koanf/v2"
|
||||
|
|
@ -28,7 +25,6 @@ import (
|
|||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/models/paths"
|
||||
"github.com/stashapp/stash/pkg/sliceutil"
|
||||
"github.com/stashapp/stash/pkg/user"
|
||||
"github.com/stashapp/stash/pkg/utils"
|
||||
)
|
||||
|
||||
|
|
@ -41,9 +37,8 @@ const (
|
|||
BlobsPath = "blobs_path"
|
||||
Downloads = "downloads"
|
||||
ApiKey = "api_key"
|
||||
Username = "username"
|
||||
Password = "password"
|
||||
MaxSessionAge = "max_session_age"
|
||||
|
||||
MaxSessionAge = "max_session_age"
|
||||
|
||||
// SFWContentMode mode config key
|
||||
SFWContentMode = "sfw_content_mode"
|
||||
|
|
@ -334,6 +329,9 @@ type Config struct {
|
|||
// configUpdates chan int
|
||||
certFile string
|
||||
keyFile string
|
||||
|
||||
UserStore *UserStore
|
||||
|
||||
sync.RWMutex
|
||||
// deadlock.RWMutex // for deadlock testing/issues
|
||||
}
|
||||
|
|
@ -431,6 +429,9 @@ func (i *Config) SetInterface(key string, value interface{}) {
|
|||
i.Lock()
|
||||
defer i.Unlock()
|
||||
|
||||
i.setInterfaceNoLock(key, value)
|
||||
}
|
||||
func (i *Config) setInterfaceNoLock(key string, value interface{}) {
|
||||
i.set(key, value)
|
||||
}
|
||||
|
||||
|
|
@ -480,6 +481,10 @@ func (i *Config) Write() error {
|
|||
i.Lock()
|
||||
defer i.Unlock()
|
||||
|
||||
return i.writeNoLock()
|
||||
}
|
||||
|
||||
func (i *Config) writeNoLock() error {
|
||||
data, err := i.marshal()
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -1078,78 +1083,6 @@ func (i *Config) GetAPIKey() string {
|
|||
return i.getString(ApiKey)
|
||||
}
|
||||
|
||||
func (i *Config) GetUsername() string {
|
||||
return i.getString(Username)
|
||||
}
|
||||
|
||||
func (i *Config) GetUser(ctx context.Context, username string) (*user.User, error) {
|
||||
// TODO - temp
|
||||
if username == "read" {
|
||||
return &user.User{
|
||||
Username: username,
|
||||
Roles: []user.RoleEnum{user.RoleEnumRead},
|
||||
}, nil
|
||||
}
|
||||
|
||||
u := i.GetUsername()
|
||||
if u != username {
|
||||
return nil, user.ErrUserNotFound
|
||||
}
|
||||
|
||||
return &user.User{
|
||||
Username: u,
|
||||
Roles: []user.RoleEnum{user.RoleEnumAdmin},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (i *Config) GetPasswordHash(ctx context.Context, username string) (string, error) {
|
||||
u := i.GetUsername()
|
||||
if u != username {
|
||||
return "", user.ErrUserNotFound
|
||||
}
|
||||
|
||||
return i.getString(Password), nil
|
||||
}
|
||||
|
||||
func (i *Config) GetCredentials() (string, string) {
|
||||
if hc, _ := i.HasCredentials(context.Background()); hc {
|
||||
return i.getString(Username), i.getString(Password)
|
||||
}
|
||||
|
||||
return "", ""
|
||||
}
|
||||
|
||||
func (i *Config) HasCredentials(ctx context.Context) (bool, error) {
|
||||
username := i.getString(Username)
|
||||
pwHash := i.getString(Password)
|
||||
|
||||
return username != "" && pwHash != "", nil
|
||||
}
|
||||
|
||||
func hashPassword(password string) string {
|
||||
hash, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.MinCost)
|
||||
|
||||
return string(hash)
|
||||
}
|
||||
|
||||
func (i *Config) ValidateCredentials(username string, password string) bool {
|
||||
if hc, _ := i.HasCredentials(context.Background()); !hc {
|
||||
// don't need to authenticate if no credentials saved
|
||||
return true
|
||||
}
|
||||
|
||||
// TODO - temp
|
||||
if username == "read" {
|
||||
return password == "read"
|
||||
}
|
||||
|
||||
authUser, authPWHash := i.GetCredentials()
|
||||
|
||||
err := bcrypt.CompareHashAndPassword([]byte(authPWHash), []byte(password))
|
||||
|
||||
return username == authUser && err == nil
|
||||
}
|
||||
|
||||
func stashBoxValidate(str string) bool {
|
||||
u, err := url.Parse(str)
|
||||
return err == nil && u.Scheme != "" && u.Host != "" && strings.HasSuffix(u.Path, "/graphql")
|
||||
|
|
|
|||
|
|
@ -22,8 +22,6 @@ func TestConcurrentConfigAccess(t *testing.T) {
|
|||
t.Errorf("Failure setting initial configuration in worker %v iteration %v: %v", wk, l, err)
|
||||
}
|
||||
|
||||
i.HasCredentials()
|
||||
i.ValidateCredentials("", "")
|
||||
i.GetConfigFile()
|
||||
i.GetConfigPath()
|
||||
i.GetDefaultDatabaseFilePath()
|
||||
|
|
@ -75,7 +73,6 @@ func TestConcurrentConfigAccess(t *testing.T) {
|
|||
i.SetInterface(ApiKey, i.GetAPIKey())
|
||||
i.SetInterface(Username, i.GetUsername())
|
||||
i.SetInterface(Password, i.GetPasswordHash())
|
||||
i.GetCredentials()
|
||||
i.SetInterface(MaxSessionAge, i.GetMaxSessionAge())
|
||||
i.SetInterface(CustomServedFolders, i.GetCustomServedFolders())
|
||||
i.SetInterface(LegacyCustomUILocation, i.GetUILocation())
|
||||
|
|
|
|||
|
|
@ -62,6 +62,9 @@ func Initialize() (*Config, error) {
|
|||
main: koanf.New("."),
|
||||
overrides: koanf.New("."),
|
||||
}
|
||||
cfg.UserStore = &UserStore{
|
||||
Config: cfg,
|
||||
}
|
||||
|
||||
cfg.initOverrides()
|
||||
|
||||
|
|
@ -96,6 +99,10 @@ func Initialize() (*Config, error) {
|
|||
}
|
||||
}
|
||||
|
||||
if err := cfg.UserStore.loadUsers(); err != nil {
|
||||
return nil, fmt.Errorf("failed to load users: %v", err)
|
||||
}
|
||||
|
||||
instance = cfg
|
||||
return instance, nil
|
||||
}
|
||||
|
|
|
|||
227
internal/manager/config/users.go
Normal file
227
internal/manager/config/users.go
Normal file
|
|
@ -0,0 +1,227 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
const (
|
||||
Username = "username"
|
||||
Password = "password"
|
||||
Users = "users"
|
||||
Roles = "roles"
|
||||
)
|
||||
|
||||
type StoredUser struct {
|
||||
Username string `json:"username" koanf:"username"`
|
||||
PasswordHash string `json:"passwordhash" koanf:"passwordhash"`
|
||||
Roles []models.RoleEnum `json:"roles" koanf:"roles"`
|
||||
}
|
||||
|
||||
type UserStore struct {
|
||||
*Config
|
||||
|
||||
cachedUsers map[string]StoredUser
|
||||
}
|
||||
|
||||
func (s *Config) GetUsername() string {
|
||||
return s.getString(Username)
|
||||
}
|
||||
|
||||
func (i *Config) GetPasswordHash() string {
|
||||
return i.getString(Password)
|
||||
}
|
||||
|
||||
func (s *UserStore) legacyUser() *StoredUser {
|
||||
un := s.getString(Username)
|
||||
pwHash := s.getString(Password)
|
||||
|
||||
if un != "" && pwHash != "" {
|
||||
return &StoredUser{
|
||||
Username: un,
|
||||
PasswordHash: pwHash,
|
||||
Roles: []models.RoleEnum{models.RoleEnumAdmin},
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *UserStore) loadUsers() error {
|
||||
// done outside lock to avoid deadlock
|
||||
legacyUser := s.legacyUser()
|
||||
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
|
||||
var ret []*StoredUser
|
||||
err := s.unmarshalKey(Users, &ret)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// add legacy username
|
||||
if legacyUser != nil {
|
||||
ret = append(ret, legacyUser)
|
||||
}
|
||||
|
||||
s.cachedUsers = make(map[string]StoredUser)
|
||||
for _, u := range ret {
|
||||
s.cachedUsers[u.Username] = *u
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *UserStore) convertUser(su StoredUser) *models.User {
|
||||
return &models.User{
|
||||
Username: su.Username,
|
||||
Roles: su.Roles,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UserStore) getUser(username string) *StoredUser {
|
||||
u, ok := s.cachedUsers[username]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &u
|
||||
}
|
||||
|
||||
func (s *UserStore) GetUser(ctx context.Context, username string) (*models.User, error) {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
|
||||
u := s.getUser(username)
|
||||
if u == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return s.convertUser(*u), nil
|
||||
}
|
||||
|
||||
func (s *UserStore) AllUsers(ctx context.Context) ([]*models.User, error) {
|
||||
var users []*models.User
|
||||
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
|
||||
for _, su := range s.cachedUsers {
|
||||
users = append(users, s.convertUser(su))
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (s *UserStore) LoginRequired(ctx context.Context) bool {
|
||||
return len(s.cachedUsers) > 0
|
||||
}
|
||||
|
||||
func hashPassword(password string) string {
|
||||
hash, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.MinCost)
|
||||
|
||||
return string(hash)
|
||||
}
|
||||
|
||||
func (s *UserStore) ValidateCredentials(ctx context.Context, username string, password string) bool {
|
||||
u := s.getUser(username)
|
||||
if u == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
err := bcrypt.CompareHashAndPassword([]byte(u.PasswordHash), []byte(password))
|
||||
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func (s *UserStore) saveUsers() error {
|
||||
// convert to list
|
||||
users := make([]StoredUser, 0, len(s.cachedUsers))
|
||||
for _, u := range s.cachedUsers {
|
||||
users = append(users, u)
|
||||
}
|
||||
|
||||
s.setInterfaceNoLock(Users, users)
|
||||
return s.writeNoLock()
|
||||
}
|
||||
|
||||
func (s *UserStore) ChangeUserPassword(ctx context.Context, username string, newPassword string) error {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
u := s.getUser(username)
|
||||
if u == nil {
|
||||
return fmt.Errorf("user not found")
|
||||
}
|
||||
|
||||
newHash := hashPassword(newPassword)
|
||||
|
||||
updatedUser := *u
|
||||
updatedUser.PasswordHash = newHash
|
||||
s.cachedUsers[username] = updatedUser
|
||||
|
||||
return s.saveUsers()
|
||||
}
|
||||
|
||||
func (s *UserStore) CreateUser(ctx context.Context, u models.User, password string) error {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
existingUser := s.getUser(u.Username)
|
||||
if existingUser != nil {
|
||||
return fmt.Errorf("user already exists")
|
||||
}
|
||||
|
||||
newUser := StoredUser{
|
||||
Username: u.Username,
|
||||
PasswordHash: hashPassword(password),
|
||||
Roles: u.Roles,
|
||||
}
|
||||
|
||||
s.cachedUsers[u.Username] = newUser
|
||||
|
||||
return s.saveUsers()
|
||||
}
|
||||
|
||||
func (s *UserStore) ReplaceUser(ctx context.Context, username string, updated models.User) error {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
existingUser := s.getUser(username)
|
||||
if existingUser == nil {
|
||||
return fmt.Errorf("user not found")
|
||||
}
|
||||
|
||||
updatedUser := StoredUser{
|
||||
Username: updated.Username,
|
||||
PasswordHash: existingUser.PasswordHash,
|
||||
Roles: updated.Roles,
|
||||
}
|
||||
|
||||
// if username changed, remove old entry
|
||||
if username != updated.Username {
|
||||
delete(s.cachedUsers, username)
|
||||
}
|
||||
|
||||
s.cachedUsers[updated.Username] = updatedUser
|
||||
|
||||
return s.saveUsers()
|
||||
}
|
||||
|
||||
func (s *UserStore) DeleteUser(ctx context.Context, username string) error {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
existingUser := s.getUser(username)
|
||||
if existingUser == nil {
|
||||
return fmt.Errorf("user not found")
|
||||
}
|
||||
|
||||
delete(s.cachedUsers, username)
|
||||
|
||||
return s.saveUsers()
|
||||
}
|
||||
|
|
@ -110,8 +110,8 @@ func Initialize(cfg *config.Config, l *log.Logger) (*Manager, error) {
|
|||
scanSubs: &subscriptionManager{},
|
||||
}
|
||||
|
||||
instance.UserService = &user.Service{
|
||||
Store: cfg,
|
||||
mgr.UserService = &user.Service{
|
||||
Store: cfg.UserStore,
|
||||
}
|
||||
|
||||
if !cfg.IsNewSystem() {
|
||||
|
|
@ -135,7 +135,7 @@ func Initialize(cfg *config.Config, l *log.Logger) (*Manager, error) {
|
|||
|
||||
// create temporary session store - this will be re-initialised
|
||||
// after config is complete
|
||||
mgr.SessionStore = session.NewStore(cfg, cfg)
|
||||
mgr.SessionStore = session.NewStore(cfg, instance.UserService)
|
||||
|
||||
logger.Warnf("config file %snot found. Assuming new system...", cfgFile)
|
||||
}
|
||||
|
|
@ -194,7 +194,7 @@ func initJobManager(cfg *config.Config) *job.Manager {
|
|||
func (s *Manager) postInit(ctx context.Context) error {
|
||||
s.RefreshConfig()
|
||||
|
||||
s.SessionStore = session.NewStore(s.Config, s.Config)
|
||||
s.SessionStore = session.NewStore(s.Config, s.UserService)
|
||||
s.PluginCache.RegisterSessionStore(s.SessionStore)
|
||||
|
||||
s.RefreshPluginCache()
|
||||
|
|
@ -256,7 +256,7 @@ func (s *Manager) postInit(ctx context.Context) error {
|
|||
}
|
||||
|
||||
func (s *Manager) checkSecurityTripwire() {
|
||||
if err := session.CheckExternalAccessTripwire(s.Config); err != nil {
|
||||
if err := session.CheckExternalAccessTripwire(s.Config.UserStore, s.Config); err != nil {
|
||||
session.LogExternalAccessError(*err)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import (
|
|||
"github.com/stashapp/stash/pkg/image"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/scene"
|
||||
"github.com/stashapp/stash/pkg/user"
|
||||
"github.com/stashapp/stash/pkg/session"
|
||||
)
|
||||
|
||||
type SceneService interface {
|
||||
|
|
@ -49,5 +49,14 @@ type GroupService interface {
|
|||
}
|
||||
|
||||
type UserService interface {
|
||||
GetUser(ctx context.Context, username string) (*user.User, error)
|
||||
session.Authenticator
|
||||
AllUsers(ctx context.Context) ([]*models.User, error)
|
||||
GetUser(ctx context.Context, username string) (*models.User, error)
|
||||
LoginRequired(ctx context.Context) bool
|
||||
|
||||
CreateUser(ctx context.Context, u models.User, password string) error
|
||||
UpdateUser(ctx context.Context, username string, updated models.User) error
|
||||
ChangePassword(ctx context.Context, username, existingPassword, newPassword string) error
|
||||
ChangeUserPassword(ctx context.Context, username string, newPassword string) error
|
||||
DeleteUser(ctx context.Context, username string) error
|
||||
}
|
||||
|
|
|
|||
12
pkg/models/model_user.go
Normal file
12
pkg/models/model_user.go
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
package models
|
||||
|
||||
type User struct {
|
||||
Username string
|
||||
Roles Roles
|
||||
}
|
||||
|
||||
type UserInput struct {
|
||||
Username string
|
||||
Roles Roles
|
||||
Password string
|
||||
}
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
package user
|
||||
package models
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
|
@ -58,15 +58,13 @@ func (e RoleEnum) MarshalGQL(w io.Writer) {
|
|||
fmt.Fprint(w, strconv.Quote(e.String()))
|
||||
}
|
||||
|
||||
func IsRole(assignedRoles []RoleEnum, requiredRole RoleEnum) bool {
|
||||
valid := false
|
||||
type Roles []RoleEnum
|
||||
|
||||
for _, role := range assignedRoles {
|
||||
if role.Implies(requiredRole) {
|
||||
valid = true
|
||||
break
|
||||
func (r Roles) HasRole(role RoleEnum) bool {
|
||||
for _, r := range r {
|
||||
if r.Implies(role) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return valid
|
||||
return false
|
||||
}
|
||||
|
|
@ -16,8 +16,8 @@ func (e ExternalAccessError) Error() string {
|
|||
return fmt.Sprintf("stash accessed from external IP %s", net.IP(e).String())
|
||||
}
|
||||
|
||||
func CheckAllowPublicWithoutAuth(c ExternalAccessConfig, r *http.Request) error {
|
||||
if hc, _ := c.HasCredentials(context.Background()); !hc && !c.GetDangerousAllowPublicWithoutAuth() && !c.IsNewSystem() {
|
||||
func CheckAllowPublicWithoutAuth(s CredentialStore, c ExternalAccessConfig, r *http.Request) error {
|
||||
if hc := s.LoginRequired(context.Background()); !hc && !c.GetDangerousAllowPublicWithoutAuth() && !c.IsNewSystem() {
|
||||
requestIPString, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing remote host (%s): %w", r.RemoteAddr, err)
|
||||
|
|
@ -60,8 +60,8 @@ func CheckAllowPublicWithoutAuth(c ExternalAccessConfig, r *http.Request) error
|
|||
return nil
|
||||
}
|
||||
|
||||
func CheckExternalAccessTripwire(c ExternalAccessConfig) *ExternalAccessError {
|
||||
if hc, _ := c.HasCredentials(context.Background()); !hc && !c.GetDangerousAllowPublicWithoutAuth() {
|
||||
func CheckExternalAccessTripwire(s CredentialStore, c ExternalAccessConfig) *ExternalAccessError {
|
||||
if hc := s.LoginRequired(context.Background()); !hc && !c.GetDangerousAllowPublicWithoutAuth() {
|
||||
if remoteIP := c.GetSecurityTripwireAccessedFromPublicInternet(); remoteIP != "" {
|
||||
err := ExternalAccessError(net.ParseIP(remoteIP))
|
||||
return &err
|
||||
|
|
|
|||
|
|
@ -3,12 +3,15 @@ package session
|
|||
import "context"
|
||||
|
||||
type ExternalAccessConfig interface {
|
||||
HasCredentials(ctx context.Context) (bool, error)
|
||||
GetDangerousAllowPublicWithoutAuth() bool
|
||||
GetSecurityTripwireAccessedFromPublicInternet() string
|
||||
IsNewSystem() bool
|
||||
}
|
||||
|
||||
type CredentialStore interface {
|
||||
LoginRequired(ctx context.Context) bool
|
||||
}
|
||||
|
||||
type SessionConfig interface {
|
||||
GetUsername() string
|
||||
GetAPIKey() string
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import (
|
|||
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/stashapp/stash/pkg/logger"
|
||||
"github.com/stashapp/stash/pkg/user"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
)
|
||||
|
||||
type key int
|
||||
|
|
@ -46,7 +46,8 @@ func (e InvalidCredentialsError) Error() string {
|
|||
var ErrUnauthorized = errors.New("unauthorized")
|
||||
|
||||
type Authenticator interface {
|
||||
ValidateCredentials(username string, password string) bool
|
||||
LoginRequired(ctx context.Context) bool
|
||||
ValidateCredentials(ctx context.Context, username string, password string) error
|
||||
}
|
||||
|
||||
type Store struct {
|
||||
|
|
@ -68,6 +69,10 @@ func NewStore(c SessionConfig, a Authenticator) *Store {
|
|||
return ret
|
||||
}
|
||||
|
||||
func (s *Store) LoginRequired(ctx context.Context) bool {
|
||||
return s.authenticator.LoginRequired(ctx)
|
||||
}
|
||||
|
||||
func (s *Store) Login(w http.ResponseWriter, r *http.Request) error {
|
||||
// ignore error - we want a new session regardless
|
||||
newSession, _ := s.sessionStore.Get(r, cookieName)
|
||||
|
|
@ -76,16 +81,16 @@ func (s *Store) Login(w http.ResponseWriter, r *http.Request) error {
|
|||
password := r.FormValue(passwordFormKey)
|
||||
|
||||
// authenticate the user
|
||||
if !s.authenticator.ValidateCredentials(username, password) {
|
||||
err := s.authenticator.ValidateCredentials(r.Context(), username, password)
|
||||
if err != nil {
|
||||
return &InvalidCredentialsError{Username: username}
|
||||
}
|
||||
|
||||
// since we only have one user, don't leak the name
|
||||
logger.Info("User logged in")
|
||||
logger.Infof("User %s logged in", username)
|
||||
|
||||
newSession.Values[userIDKey] = username
|
||||
|
||||
err := newSession.Save(r, w)
|
||||
err = newSession.Save(r, w)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -99,6 +104,8 @@ func (s *Store) Logout(w http.ResponseWriter, r *http.Request) error {
|
|||
return err
|
||||
}
|
||||
|
||||
userID, _ := session.Values[userIDKey].(string)
|
||||
|
||||
delete(session.Values, userIDKey)
|
||||
session.Options.MaxAge = -1
|
||||
|
||||
|
|
@ -107,8 +114,7 @@ func (s *Store) Logout(w http.ResponseWriter, r *http.Request) error {
|
|||
return err
|
||||
}
|
||||
|
||||
// since we only have one user, don't leak the name
|
||||
logger.Infof("User logged out")
|
||||
logger.Infof("User %s logged out", userID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -138,15 +144,15 @@ func (s *Store) GetSessionUserID(w http.ResponseWriter, r *http.Request) (string
|
|||
return "", nil
|
||||
}
|
||||
|
||||
func SetCurrentUser(ctx context.Context, u user.User) context.Context {
|
||||
func SetCurrentUser(ctx context.Context, u models.User) context.Context {
|
||||
return context.WithValue(ctx, contextUser, u)
|
||||
}
|
||||
|
||||
// GetCurrentUser gets the current user id from the provided context
|
||||
func GetCurrentUser(ctx context.Context) *user.User {
|
||||
func GetCurrentUser(ctx context.Context) *models.User {
|
||||
userCtxVal := ctx.Value(contextUser)
|
||||
if userCtxVal != nil {
|
||||
currentUser := userCtxVal.(user.User)
|
||||
currentUser := userCtxVal.(models.User)
|
||||
return ¤tUser
|
||||
}
|
||||
|
||||
|
|
@ -164,6 +170,7 @@ func (s *Store) Authenticate(w http.ResponseWriter, r *http.Request) (userID str
|
|||
apiKey = r.URL.Query().Get(ApiKeyParameter)
|
||||
}
|
||||
|
||||
// FIXME - handle this
|
||||
if apiKey != "" {
|
||||
// match against configured API and set userID to the
|
||||
// configured username. In future, we'll want to
|
||||
|
|
|
|||
|
|
@ -1,44 +1 @@
|
|||
package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrPasswordMismatch = fmt.Errorf("password mismatch")
|
||||
ErrUserNotFound = fmt.Errorf("user not found")
|
||||
)
|
||||
|
||||
func (s *Service) ValidateCredentials(ctx context.Context, username string, password string) error {
|
||||
hc, err := s.Store.HasCredentials(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error checking if credentials exist: %w", err)
|
||||
}
|
||||
|
||||
if !hc {
|
||||
// don't need to authenticate if no credentials saved
|
||||
return nil
|
||||
}
|
||||
|
||||
authPWHash, err := s.Store.GetPasswordHash(ctx, username)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("error getting password hash: %w", err)
|
||||
}
|
||||
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(authPWHash), []byte(password)); err != nil {
|
||||
if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) {
|
||||
return ErrPasswordMismatch
|
||||
}
|
||||
|
||||
return fmt.Errorf("error comparing password hash: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,18 +2,266 @@ package user
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/stashapp/stash/pkg/logger"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
)
|
||||
|
||||
type Store interface {
|
||||
GetUser(ctx context.Context, username string) (*User, error)
|
||||
HasCredentials(ctx context.Context) (bool, error)
|
||||
GetPasswordHash(ctx context.Context, username string) (string, error)
|
||||
var (
|
||||
ErrUserNotExist = errors.New("user not found")
|
||||
ErrEmptyUsername = errors.New("empty username")
|
||||
ErrUsernameHasWhitespace = errors.New("username has leading or trailing whitespace")
|
||||
ErrDeleteLastAdminUser = errors.New("final admin user cannot be deleted")
|
||||
ErrInternalError = errors.New("internal error")
|
||||
ErrAccessDenied = errors.New("access denied")
|
||||
ErrCurrentPasswordIncorrect = errors.New("current password incorrect")
|
||||
ErrUserAlreadyExists = errors.New("user with that username already exists")
|
||||
)
|
||||
|
||||
type UserSource interface {
|
||||
AllUsers(ctx context.Context) ([]*models.User, error)
|
||||
GetUser(ctx context.Context, username string) (*models.User, error)
|
||||
ValidateCredentials(ctx context.Context, username string, password string) bool
|
||||
|
||||
CreateUser(ctx context.Context, u models.User, password string) error
|
||||
ReplaceUser(ctx context.Context, username string, updated models.User) error
|
||||
ChangeUserPassword(ctx context.Context, username string, newPassword string) error
|
||||
DeleteUser(ctx context.Context, username string) error
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
Store Store
|
||||
Store UserSource
|
||||
}
|
||||
|
||||
func (s *Service) GetUser(ctx context.Context, username string) (*User, error) {
|
||||
func (s *Service) LoginRequired(ctx context.Context) bool {
|
||||
u, _ := s.Store.AllUsers(ctx)
|
||||
return len(u) > 0
|
||||
}
|
||||
|
||||
func (s *Service) GetUser(ctx context.Context, username string) (*models.User, error) {
|
||||
return s.Store.GetUser(ctx, username)
|
||||
}
|
||||
|
||||
func (s *Service) AllUsers(ctx context.Context) ([]*models.User, error) {
|
||||
return s.Store.AllUsers(ctx)
|
||||
}
|
||||
|
||||
func (s *Service) ValidateCredentials(ctx context.Context, username string, password string) error {
|
||||
// ensure user is not locked
|
||||
u, err := s.GetUser(ctx, username)
|
||||
if err != nil {
|
||||
logger.Errorf("error getting user for credential validation: %v", err)
|
||||
return ErrInternalError
|
||||
}
|
||||
|
||||
if u == nil {
|
||||
logger.Infof("[login attempt] user %s not found during credential validation", username)
|
||||
return ErrAccessDenied
|
||||
}
|
||||
|
||||
if len(u.Roles) == 0 {
|
||||
logger.Infof("[login attempt] user %s is locked", username)
|
||||
return ErrAccessDenied
|
||||
}
|
||||
|
||||
if !s.Store.ValidateCredentials(ctx, username, password) {
|
||||
logger.Infof("[login attempt] invalid credentials for user %s", username)
|
||||
return ErrAccessDenied
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) validateUsername(username string) error {
|
||||
if username == "" {
|
||||
return ErrEmptyUsername
|
||||
}
|
||||
|
||||
// username must not have leading or trailing whitespace
|
||||
trimmed := strings.TrimSpace(username)
|
||||
|
||||
if trimmed != username {
|
||||
return ErrUsernameHasWhitespace
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) validatePassword(password string) error {
|
||||
if password == "" {
|
||||
return errors.New("password cannot be empty")
|
||||
}
|
||||
|
||||
// add more password validation as needed
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) CreateUser(ctx context.Context, u models.User, password string) error {
|
||||
// validate input
|
||||
// ensure username is valid
|
||||
if err := s.validateUsername(u.Username); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// check if user exists
|
||||
existingUser, err := s.GetUser(ctx, u.Username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error checking existing users: %w", err)
|
||||
}
|
||||
|
||||
if existingUser != nil {
|
||||
return ErrUserAlreadyExists
|
||||
}
|
||||
|
||||
// validate password
|
||||
if err := s.validatePassword(password); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// if this is the first user, make them an admin
|
||||
users, err := s.AllUsers(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting existing users: %w", err)
|
||||
}
|
||||
|
||||
if len(users) == 0 && !u.Roles.HasRole(models.RoleEnumAdmin) {
|
||||
return errors.New("the first user must be an admin")
|
||||
}
|
||||
|
||||
// create user in store
|
||||
if err := s.Store.CreateUser(ctx, u, password); err != nil {
|
||||
return fmt.Errorf("error creating user: %w", err)
|
||||
}
|
||||
|
||||
logger.Infof("[user] created %q", u.Username)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) UpdateUser(ctx context.Context, username string, updated models.User) error {
|
||||
// validate input
|
||||
// check if user exists
|
||||
existingUser, err := s.GetUser(ctx, username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting existing user: %w", err)
|
||||
}
|
||||
|
||||
if existingUser == nil {
|
||||
return ErrUserNotExist
|
||||
}
|
||||
|
||||
existingRoles := existingUser.Roles
|
||||
|
||||
// ensure username is valid
|
||||
if username != updated.Username {
|
||||
if err := s.validateUsername(updated.Username); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// ensure new username doesn't already exist
|
||||
otherUser, err := s.GetUser(ctx, updated.Username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error checking existing user: %w", err)
|
||||
}
|
||||
|
||||
if otherUser != nil {
|
||||
return ErrUserAlreadyExists
|
||||
}
|
||||
}
|
||||
|
||||
// update user in store
|
||||
if err := s.Store.ReplaceUser(ctx, username, updated); err != nil {
|
||||
return fmt.Errorf("error updating user: %w", err)
|
||||
}
|
||||
|
||||
if username != updated.Username {
|
||||
logger.Infof("[user] updated name %q -> %q", username, updated.Username)
|
||||
}
|
||||
|
||||
if !slices.Equal(existingRoles, updated.Roles) {
|
||||
logger.Infof("[user] updated roles for user %q", updated.Username)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) ChangePassword(ctx context.Context, username, currentPassword, newPassword string) error {
|
||||
// validate current credentials
|
||||
if err := s.ValidateCredentials(ctx, username, currentPassword); err != nil {
|
||||
logger.Infof("[user] failed password change attempt for %q: incorrect current password", username)
|
||||
return ErrCurrentPasswordIncorrect
|
||||
}
|
||||
|
||||
return s.ChangeUserPassword(ctx, username, newPassword)
|
||||
}
|
||||
|
||||
func (s *Service) ChangeUserPassword(ctx context.Context, username, newPassword string) error {
|
||||
// check if user exists
|
||||
existingUser, err := s.GetUser(ctx, username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting existing user: %w", err)
|
||||
}
|
||||
|
||||
if existingUser == nil {
|
||||
return ErrUserNotExist
|
||||
}
|
||||
|
||||
// validate new password
|
||||
if err := s.validatePassword(newPassword); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// change password in store
|
||||
if err := s.Store.ChangeUserPassword(ctx, username, newPassword); err != nil {
|
||||
return fmt.Errorf("error changing user password: %w", err)
|
||||
}
|
||||
|
||||
logger.Infof("[user] changed password for %q", username)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) DeleteUser(ctx context.Context, username string) error {
|
||||
// check if user exists
|
||||
existingUser, err := s.GetUser(ctx, username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting existing user: %w", err)
|
||||
}
|
||||
|
||||
if existingUser == nil {
|
||||
return ErrUserNotExist
|
||||
}
|
||||
|
||||
// don't allow deleting last admin user
|
||||
if existingUser.Roles.HasRole(models.RoleEnumAdmin) {
|
||||
users, err := s.AllUsers(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting all users: %w", err)
|
||||
}
|
||||
|
||||
hasAdmin := false
|
||||
for _, u := range users {
|
||||
if u.Username != username && u.Roles.HasRole(models.RoleEnumAdmin) {
|
||||
hasAdmin = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasAdmin {
|
||||
return ErrDeleteLastAdminUser
|
||||
}
|
||||
}
|
||||
|
||||
// delete user from store
|
||||
if err := s.Store.DeleteUser(ctx, username); err != nil {
|
||||
return fmt.Errorf("error deleting user: %w", err)
|
||||
}
|
||||
|
||||
logger.Infof("[user] deleted %q", username)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +0,0 @@
|
|||
package user
|
||||
|
||||
type User struct {
|
||||
Username string
|
||||
Roles []RoleEnum
|
||||
}
|
||||
Loading…
Reference in a new issue