diff --git a/internal/api/authentication.go b/internal/api/authentication.go index 6ad7117a1..1099a5702 100644 --- a/internal/api/authentication.go +++ b/internal/api/authentication.go @@ -1,6 +1,7 @@ package api import ( + "context" "errors" "net" "net/http" @@ -12,6 +13,7 @@ import ( "github.com/stashapp/stash/internal/manager/config" "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/session" + "github.com/stashapp/stash/pkg/user" ) const ( @@ -29,7 +31,11 @@ func allowUnauthenticated(r *http.Request) bool { return strings.HasPrefix(r.URL.Path, loginEndpoint) || r.URL.Path == logoutEndpoint || r.URL.Path == "/css" || strings.HasPrefix(r.URL.Path, "/assets") } -func authenticateHandler() func(http.Handler) http.Handler { +type UserGetter interface { + GetUser(ctx context.Context, username string) (*user.User, error) +} + +func authenticateHandler(g UserGetter) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c := config.GetInstance() @@ -73,7 +79,7 @@ func authenticateHandler() func(http.Handler) http.Handler { ctx := r.Context() - if c.HasCredentials() { + if hc, _ := c.HasCredentials(ctx); hc { // authentication is required if userID == "" && !allowUnauthenticated(r) { // if graphql or a non-webpage was requested, we just return a forbidden error @@ -102,7 +108,18 @@ func authenticateHandler() func(http.Handler) http.Handler { } } - ctx = session.SetCurrentUserID(ctx, userID) + if userID != "" { + // set the user object in the context + u, err := g.GetUser(ctx, userID) + if err != nil { + // if we can't get the user object, we just return a forbidden error + logger.Errorf("Error getting user object: %v", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + + ctx = session.SetCurrentUser(ctx, *u) + } r = r.WithContext(ctx) diff --git a/internal/api/resolver_mutation_configure.go b/internal/api/resolver_mutation_configure.go index 23b61c208..67ace9552 100644 --- a/internal/api/resolver_mutation_configure.go +++ b/internal/api/resolver_mutation_configure.go @@ -320,7 +320,7 @@ func (r *mutationResolver) ConfigureGeneral(ctx context.Context, input ConfigGen if input.Password != nil { // bit of a hack - check if the passed in password is the same as the stored hash // and only set if they are different - currentPWHash := c.GetPasswordHash() + currentPWHash, _ := c.GetPasswordHash(ctx, c.GetUsername()) if *input.Password != currentPWHash { if *input.Password == "" { diff --git a/internal/api/resolver_query_configuration.go b/internal/api/resolver_query_configuration.go index bc76212eb..11727ae36 100644 --- a/internal/api/resolver_query_configuration.go +++ b/internal/api/resolver_query_configuration.go @@ -78,6 +78,9 @@ func makeConfigGeneralResult() *ConfigGeneralResult { customPerformerImageLocation := config.GetCustomPerformerImageLocation() + username := config.GetUsername() + pwHash, _ := config.GetPasswordHash(context.Background(), username) + return &ConfigGeneralResult{ Stashes: config.GetStashPaths(), DatabasePath: config.GetDatabasePath(), @@ -110,7 +113,7 @@ func makeConfigGeneralResult() *ConfigGeneralResult { GalleryCoverRegex: config.GetGalleryCoverRegex(), APIKey: config.GetAPIKey(), Username: config.GetUsername(), - Password: config.GetPasswordHash(), + Password: pwHash, MaxSessionAge: config.GetMaxSessionAge(), LogFile: &logFile, LogOut: config.GetLogOut(), diff --git a/internal/api/server.go b/internal/api/server.go index a7516da52..9537891c3 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -122,9 +122,11 @@ func Initialize() (*Server, error) { manager: mgr, } + userStore := manager.GetInstance().UserService + r.Use(middleware.Heartbeat("/healthz")) r.Use(cors.AllowAll().Handler) - r.Use(authenticateHandler()) + r.Use(authenticateHandler(userStore)) visitedPluginHandler := mgr.SessionStore.VisitedPluginHandler() r.Use(visitedPluginHandler) diff --git a/internal/api/session.go b/internal/api/session.go index 5918cdd9b..f779fe992 100644 --- a/internal/api/session.go +++ b/internal/api/session.go @@ -107,7 +107,7 @@ func handleLogin() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { returnURL := r.URL.Query().Get(returnURLParam) - if !config.GetInstance().HasCredentials() { + if hc, _ := config.GetInstance().HasCredentials(r.Context()); !hc { if returnURL != "" { http.Redirect(w, r, returnURL, http.StatusFound) } else { @@ -155,7 +155,7 @@ func handleLogout() http.HandlerFunc { // redirect to the login page if credentials are required prefix := getProxyPrefix(r) - if config.GetInstance().HasCredentials() { + if hc, _ := config.GetInstance().HasCredentials(r.Context()); hc { http.Redirect(w, r, prefix+loginEndpoint, http.StatusFound) } else { http.Redirect(w, r, prefix+"/", http.StatusFound) diff --git a/internal/manager/config/config.go b/internal/manager/config/config.go index bb99bdcfc..78f84254a 100644 --- a/internal/manager/config/config.go +++ b/internal/manager/config/config.go @@ -1,6 +1,7 @@ package config import ( + "context" "fmt" "net/url" "os" @@ -27,6 +28,7 @@ import ( "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models/paths" "github.com/stashapp/stash/pkg/sliceutil" + "github.com/stashapp/stash/pkg/user" "github.com/stashapp/stash/pkg/utils" ) @@ -1080,23 +1082,39 @@ func (i *Config) GetUsername() string { return i.getString(Username) } -func (i *Config) GetPasswordHash() string { - return i.getString(Password) +func (i *Config) GetUser(ctx context.Context, username string) (*user.User, error) { + u := i.GetUsername() + if u != username { + return nil, user.ErrUserNotFound + } + + return &user.User{ + Username: u, + }, nil +} + +func (i *Config) GetPasswordHash(ctx context.Context, username string) (string, error) { + u := i.GetUsername() + if u != username { + return "", user.ErrUserNotFound + } + + return i.getString(Password), nil } func (i *Config) GetCredentials() (string, string) { - if i.HasCredentials() { + if hc, _ := i.HasCredentials(context.Background()); hc { return i.getString(Username), i.getString(Password) } return "", "" } -func (i *Config) HasCredentials() bool { +func (i *Config) HasCredentials(ctx context.Context) (bool, error) { username := i.getString(Username) pwHash := i.getString(Password) - return username != "" && pwHash != "" + return username != "" && pwHash != "", nil } func hashPassword(password string) string { @@ -1106,7 +1124,7 @@ func hashPassword(password string) string { } func (i *Config) ValidateCredentials(username string, password string) bool { - if !i.HasCredentials() { + if hc, _ := i.HasCredentials(context.Background()); !hc { // don't need to authenticate if no credentials saved return true } diff --git a/internal/manager/init.go b/internal/manager/init.go index 4423b2254..2f0f60a42 100644 --- a/internal/manager/init.go +++ b/internal/manager/init.go @@ -27,6 +27,7 @@ import ( "github.com/stashapp/stash/pkg/scraper" "github.com/stashapp/stash/pkg/session" "github.com/stashapp/stash/pkg/sqlite" + "github.com/stashapp/stash/pkg/user" "github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/ui" ) @@ -109,6 +110,10 @@ func Initialize(cfg *config.Config, l *log.Logger) (*Manager, error) { scanSubs: &subscriptionManager{}, } + instance.UserService = &user.Service{ + Store: cfg, + } + if !cfg.IsNewSystem() { logger.Infof("using config file: %s", cfg.GetConfigFile()) @@ -189,7 +194,7 @@ func initJobManager(cfg *config.Config) *job.Manager { func (s *Manager) postInit(ctx context.Context) error { s.RefreshConfig() - s.SessionStore = session.NewStore(s.Config) + s.SessionStore = session.NewStore(s.Config, s.Config) s.PluginCache.RegisterSessionStore(s.SessionStore) s.RefreshPluginCache() diff --git a/internal/manager/manager.go b/internal/manager/manager.go index f4f3fa636..8b261d889 100644 --- a/internal/manager/manager.go +++ b/internal/manager/manager.go @@ -67,6 +67,7 @@ type Manager struct { ImageService ImageService GalleryService GalleryService GroupService GroupService + UserService UserService scanSubs *subscriptionManager } diff --git a/internal/manager/repository.go b/internal/manager/repository.go index e51e737ee..d017dc23b 100644 --- a/internal/manager/repository.go +++ b/internal/manager/repository.go @@ -7,6 +7,7 @@ import ( "github.com/stashapp/stash/pkg/image" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/scene" + "github.com/stashapp/stash/pkg/user" ) type SceneService interface { @@ -46,3 +47,7 @@ type GroupService interface { RemoveSubGroups(ctx context.Context, groupID int, subGroupIDs []int) error ReorderSubGroups(ctx context.Context, groupID int, subGroupIDs []int, insertPointID int, insertAfter bool) error } + +type UserService interface { + GetUser(ctx context.Context, username string) (*user.User, error) +} diff --git a/pkg/session/authentication.go b/pkg/session/authentication.go index 95c41baa5..8fb39c099 100644 --- a/pkg/session/authentication.go +++ b/pkg/session/authentication.go @@ -1,6 +1,7 @@ package session import ( + "context" "fmt" "net" "net/http" @@ -16,7 +17,7 @@ func (e ExternalAccessError) Error() string { } func CheckAllowPublicWithoutAuth(c ExternalAccessConfig, r *http.Request) error { - if !c.HasCredentials() && !c.GetDangerousAllowPublicWithoutAuth() && !c.IsNewSystem() { + if hc, _ := c.HasCredentials(context.Background()); !hc && !c.GetDangerousAllowPublicWithoutAuth() && !c.IsNewSystem() { requestIPString, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { return fmt.Errorf("error parsing remote host (%s): %w", r.RemoteAddr, err) @@ -60,7 +61,7 @@ func CheckAllowPublicWithoutAuth(c ExternalAccessConfig, r *http.Request) error } func CheckExternalAccessTripwire(c ExternalAccessConfig) *ExternalAccessError { - if !c.HasCredentials() && !c.GetDangerousAllowPublicWithoutAuth() { + if hc, _ := c.HasCredentials(context.Background()); !hc && !c.GetDangerousAllowPublicWithoutAuth() { if remoteIP := c.GetSecurityTripwireAccessedFromPublicInternet(); remoteIP != "" { err := ExternalAccessError(net.ParseIP(remoteIP)) return &err diff --git a/pkg/session/config.go b/pkg/session/config.go index ecdc36bd3..372fc3a6a 100644 --- a/pkg/session/config.go +++ b/pkg/session/config.go @@ -1,7 +1,9 @@ package session +import "context" + type ExternalAccessConfig interface { - HasCredentials() bool + HasCredentials(ctx context.Context) (bool, error) GetDangerousAllowPublicWithoutAuth() bool GetSecurityTripwireAccessedFromPublicInternet() string IsNewSystem() bool diff --git a/pkg/session/plugin.go b/pkg/session/plugin.go index 7a57ca4b5..22d988072 100644 --- a/pkg/session/plugin.go +++ b/pkg/session/plugin.go @@ -61,7 +61,7 @@ func setVisitedPluginHooks(ctx context.Context, visitedPlugins []VisitedPluginHo } func (s *Store) MakePluginCookie(ctx context.Context) *http.Cookie { - currentUser := GetCurrentUserID(ctx) + currentUser := GetCurrentUser(ctx) visitedPlugins := GetVisitedPluginHooks(ctx) session := sessions.NewSession(s.sessionStore, cookieName) diff --git a/pkg/session/session.go b/pkg/session/session.go index a211a1524..997a71198 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -8,6 +8,7 @@ import ( "github.com/gorilla/sessions" "github.com/stashapp/stash/pkg/logger" + "github.com/stashapp/stash/pkg/user" ) type key int @@ -137,15 +138,15 @@ func (s *Store) GetSessionUserID(w http.ResponseWriter, r *http.Request) (string return "", nil } -func SetCurrentUserID(ctx context.Context, userID string) context.Context { - return context.WithValue(ctx, contextUser, userID) +func SetCurrentUser(ctx context.Context, u user.User) context.Context { + return context.WithValue(ctx, contextUser, u) } -// GetCurrentUserID gets the current user id from the provided context -func GetCurrentUserID(ctx context.Context) *string { +// GetCurrentUser gets the current user id from the provided context +func GetCurrentUser(ctx context.Context) *user.User { userCtxVal := ctx.Value(contextUser) if userCtxVal != nil { - currentUser := userCtxVal.(string) + currentUser := userCtxVal.(user.User) return ¤tUser } diff --git a/pkg/user/authenticate.go b/pkg/user/authenticate.go new file mode 100644 index 000000000..d3aa48f4a --- /dev/null +++ b/pkg/user/authenticate.go @@ -0,0 +1,44 @@ +package user + +import ( + "context" + "errors" + "fmt" + + "golang.org/x/crypto/bcrypt" +) + +var ( + ErrPasswordMismatch = fmt.Errorf("password mismatch") + ErrUserNotFound = fmt.Errorf("user not found") +) + +func (s *Service) ValidateCredentials(ctx context.Context, username string, password string) error { + hc, err := s.Store.HasCredentials(ctx) + if err != nil { + return fmt.Errorf("error checking if credentials exist: %w", err) + } + + if !hc { + // don't need to authenticate if no credentials saved + return nil + } + + authPWHash, err := s.Store.GetPasswordHash(ctx, username) + if err != nil { + if errors.Is(err, ErrUserNotFound) { + return err + } + return fmt.Errorf("error getting password hash: %w", err) + } + + if err := bcrypt.CompareHashAndPassword([]byte(authPWHash), []byte(password)); err != nil { + if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) { + return ErrPasswordMismatch + } + + return fmt.Errorf("error comparing password hash: %w", err) + } + + return nil +} diff --git a/pkg/user/service.go b/pkg/user/service.go new file mode 100644 index 000000000..d918fed0c --- /dev/null +++ b/pkg/user/service.go @@ -0,0 +1,19 @@ +package user + +import ( + "context" +) + +type Store interface { + GetUser(ctx context.Context, username string) (*User, error) + HasCredentials(ctx context.Context) (bool, error) + GetPasswordHash(ctx context.Context, username string) (string, error) +} + +type Service struct { + Store Store +} + +func (s *Service) GetUser(ctx context.Context, username string) (*User, error) { + return s.Store.GetUser(ctx, username) +} diff --git a/pkg/user/user.go b/pkg/user/user.go new file mode 100644 index 000000000..aea6dbaeb --- /dev/null +++ b/pkg/user/user.go @@ -0,0 +1,8 @@ +package user + +type RoleEnum string + +type User struct { + Username string + Roles []RoleEnum +}