mirror of
https://github.com/stashapp/stash.git
synced 2026-02-07 16:05:47 +01:00
Add UserService
This commit is contained in:
parent
51d7bf272e
commit
d906a66fec
16 changed files with 150 additions and 24 deletions
|
|
@ -1,6 +1,7 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
|
|
@ -12,6 +13,7 @@ import (
|
|||
"github.com/stashapp/stash/internal/manager/config"
|
||||
"github.com/stashapp/stash/pkg/logger"
|
||||
"github.com/stashapp/stash/pkg/session"
|
||||
"github.com/stashapp/stash/pkg/user"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
@ -29,7 +31,11 @@ func allowUnauthenticated(r *http.Request) bool {
|
|||
return strings.HasPrefix(r.URL.Path, loginEndpoint) || r.URL.Path == logoutEndpoint || r.URL.Path == "/css" || strings.HasPrefix(r.URL.Path, "/assets")
|
||||
}
|
||||
|
||||
func authenticateHandler() func(http.Handler) http.Handler {
|
||||
type UserGetter interface {
|
||||
GetUser(ctx context.Context, username string) (*user.User, error)
|
||||
}
|
||||
|
||||
func authenticateHandler(g UserGetter) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
c := config.GetInstance()
|
||||
|
|
@ -73,7 +79,7 @@ func authenticateHandler() func(http.Handler) http.Handler {
|
|||
|
||||
ctx := r.Context()
|
||||
|
||||
if c.HasCredentials() {
|
||||
if hc, _ := c.HasCredentials(ctx); hc {
|
||||
// authentication is required
|
||||
if userID == "" && !allowUnauthenticated(r) {
|
||||
// if graphql or a non-webpage was requested, we just return a forbidden error
|
||||
|
|
@ -102,7 +108,18 @@ func authenticateHandler() func(http.Handler) http.Handler {
|
|||
}
|
||||
}
|
||||
|
||||
ctx = session.SetCurrentUserID(ctx, userID)
|
||||
if userID != "" {
|
||||
// set the user object in the context
|
||||
u, err := g.GetUser(ctx, userID)
|
||||
if err != nil {
|
||||
// if we can't get the user object, we just return a forbidden error
|
||||
logger.Errorf("Error getting user object: %v", err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
ctx = session.SetCurrentUser(ctx, *u)
|
||||
}
|
||||
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
|
|
|
|||
|
|
@ -320,7 +320,7 @@ func (r *mutationResolver) ConfigureGeneral(ctx context.Context, input ConfigGen
|
|||
if input.Password != nil {
|
||||
// bit of a hack - check if the passed in password is the same as the stored hash
|
||||
// and only set if they are different
|
||||
currentPWHash := c.GetPasswordHash()
|
||||
currentPWHash, _ := c.GetPasswordHash(ctx, c.GetUsername())
|
||||
|
||||
if *input.Password != currentPWHash {
|
||||
if *input.Password == "" {
|
||||
|
|
|
|||
|
|
@ -78,6 +78,9 @@ func makeConfigGeneralResult() *ConfigGeneralResult {
|
|||
|
||||
customPerformerImageLocation := config.GetCustomPerformerImageLocation()
|
||||
|
||||
username := config.GetUsername()
|
||||
pwHash, _ := config.GetPasswordHash(context.Background(), username)
|
||||
|
||||
return &ConfigGeneralResult{
|
||||
Stashes: config.GetStashPaths(),
|
||||
DatabasePath: config.GetDatabasePath(),
|
||||
|
|
@ -110,7 +113,7 @@ func makeConfigGeneralResult() *ConfigGeneralResult {
|
|||
GalleryCoverRegex: config.GetGalleryCoverRegex(),
|
||||
APIKey: config.GetAPIKey(),
|
||||
Username: config.GetUsername(),
|
||||
Password: config.GetPasswordHash(),
|
||||
Password: pwHash,
|
||||
MaxSessionAge: config.GetMaxSessionAge(),
|
||||
LogFile: &logFile,
|
||||
LogOut: config.GetLogOut(),
|
||||
|
|
|
|||
|
|
@ -122,9 +122,11 @@ func Initialize() (*Server, error) {
|
|||
manager: mgr,
|
||||
}
|
||||
|
||||
userStore := manager.GetInstance().UserService
|
||||
|
||||
r.Use(middleware.Heartbeat("/healthz"))
|
||||
r.Use(cors.AllowAll().Handler)
|
||||
r.Use(authenticateHandler())
|
||||
r.Use(authenticateHandler(userStore))
|
||||
visitedPluginHandler := mgr.SessionStore.VisitedPluginHandler()
|
||||
r.Use(visitedPluginHandler)
|
||||
|
||||
|
|
|
|||
|
|
@ -107,7 +107,7 @@ func handleLogin() http.HandlerFunc {
|
|||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
returnURL := r.URL.Query().Get(returnURLParam)
|
||||
|
||||
if !config.GetInstance().HasCredentials() {
|
||||
if hc, _ := config.GetInstance().HasCredentials(r.Context()); !hc {
|
||||
if returnURL != "" {
|
||||
http.Redirect(w, r, returnURL, http.StatusFound)
|
||||
} else {
|
||||
|
|
@ -155,7 +155,7 @@ func handleLogout() http.HandlerFunc {
|
|||
|
||||
// redirect to the login page if credentials are required
|
||||
prefix := getProxyPrefix(r)
|
||||
if config.GetInstance().HasCredentials() {
|
||||
if hc, _ := config.GetInstance().HasCredentials(r.Context()); hc {
|
||||
http.Redirect(w, r, prefix+loginEndpoint, http.StatusFound)
|
||||
} else {
|
||||
http.Redirect(w, r, prefix+"/", http.StatusFound)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
|
|
@ -27,6 +28,7 @@ import (
|
|||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/models/paths"
|
||||
"github.com/stashapp/stash/pkg/sliceutil"
|
||||
"github.com/stashapp/stash/pkg/user"
|
||||
"github.com/stashapp/stash/pkg/utils"
|
||||
)
|
||||
|
||||
|
|
@ -1080,23 +1082,39 @@ func (i *Config) GetUsername() string {
|
|||
return i.getString(Username)
|
||||
}
|
||||
|
||||
func (i *Config) GetPasswordHash() string {
|
||||
return i.getString(Password)
|
||||
func (i *Config) GetUser(ctx context.Context, username string) (*user.User, error) {
|
||||
u := i.GetUsername()
|
||||
if u != username {
|
||||
return nil, user.ErrUserNotFound
|
||||
}
|
||||
|
||||
return &user.User{
|
||||
Username: u,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (i *Config) GetPasswordHash(ctx context.Context, username string) (string, error) {
|
||||
u := i.GetUsername()
|
||||
if u != username {
|
||||
return "", user.ErrUserNotFound
|
||||
}
|
||||
|
||||
return i.getString(Password), nil
|
||||
}
|
||||
|
||||
func (i *Config) GetCredentials() (string, string) {
|
||||
if i.HasCredentials() {
|
||||
if hc, _ := i.HasCredentials(context.Background()); hc {
|
||||
return i.getString(Username), i.getString(Password)
|
||||
}
|
||||
|
||||
return "", ""
|
||||
}
|
||||
|
||||
func (i *Config) HasCredentials() bool {
|
||||
func (i *Config) HasCredentials(ctx context.Context) (bool, error) {
|
||||
username := i.getString(Username)
|
||||
pwHash := i.getString(Password)
|
||||
|
||||
return username != "" && pwHash != ""
|
||||
return username != "" && pwHash != "", nil
|
||||
}
|
||||
|
||||
func hashPassword(password string) string {
|
||||
|
|
@ -1106,7 +1124,7 @@ func hashPassword(password string) string {
|
|||
}
|
||||
|
||||
func (i *Config) ValidateCredentials(username string, password string) bool {
|
||||
if !i.HasCredentials() {
|
||||
if hc, _ := i.HasCredentials(context.Background()); !hc {
|
||||
// don't need to authenticate if no credentials saved
|
||||
return true
|
||||
}
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ import (
|
|||
"github.com/stashapp/stash/pkg/scraper"
|
||||
"github.com/stashapp/stash/pkg/session"
|
||||
"github.com/stashapp/stash/pkg/sqlite"
|
||||
"github.com/stashapp/stash/pkg/user"
|
||||
"github.com/stashapp/stash/pkg/utils"
|
||||
"github.com/stashapp/stash/ui"
|
||||
)
|
||||
|
|
@ -109,6 +110,10 @@ func Initialize(cfg *config.Config, l *log.Logger) (*Manager, error) {
|
|||
scanSubs: &subscriptionManager{},
|
||||
}
|
||||
|
||||
instance.UserService = &user.Service{
|
||||
Store: cfg,
|
||||
}
|
||||
|
||||
if !cfg.IsNewSystem() {
|
||||
logger.Infof("using config file: %s", cfg.GetConfigFile())
|
||||
|
||||
|
|
@ -189,7 +194,7 @@ func initJobManager(cfg *config.Config) *job.Manager {
|
|||
func (s *Manager) postInit(ctx context.Context) error {
|
||||
s.RefreshConfig()
|
||||
|
||||
s.SessionStore = session.NewStore(s.Config)
|
||||
s.SessionStore = session.NewStore(s.Config, s.Config)
|
||||
s.PluginCache.RegisterSessionStore(s.SessionStore)
|
||||
|
||||
s.RefreshPluginCache()
|
||||
|
|
|
|||
|
|
@ -67,6 +67,7 @@ type Manager struct {
|
|||
ImageService ImageService
|
||||
GalleryService GalleryService
|
||||
GroupService GroupService
|
||||
UserService UserService
|
||||
|
||||
scanSubs *subscriptionManager
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import (
|
|||
"github.com/stashapp/stash/pkg/image"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/scene"
|
||||
"github.com/stashapp/stash/pkg/user"
|
||||
)
|
||||
|
||||
type SceneService interface {
|
||||
|
|
@ -46,3 +47,7 @@ type GroupService interface {
|
|||
RemoveSubGroups(ctx context.Context, groupID int, subGroupIDs []int) error
|
||||
ReorderSubGroups(ctx context.Context, groupID int, subGroupIDs []int, insertPointID int, insertAfter bool) error
|
||||
}
|
||||
|
||||
type UserService interface {
|
||||
GetUser(ctx context.Context, username string) (*user.User, error)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
|
|
@ -16,7 +17,7 @@ func (e ExternalAccessError) Error() string {
|
|||
}
|
||||
|
||||
func CheckAllowPublicWithoutAuth(c ExternalAccessConfig, r *http.Request) error {
|
||||
if !c.HasCredentials() && !c.GetDangerousAllowPublicWithoutAuth() && !c.IsNewSystem() {
|
||||
if hc, _ := c.HasCredentials(context.Background()); !hc && !c.GetDangerousAllowPublicWithoutAuth() && !c.IsNewSystem() {
|
||||
requestIPString, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing remote host (%s): %w", r.RemoteAddr, err)
|
||||
|
|
@ -60,7 +61,7 @@ func CheckAllowPublicWithoutAuth(c ExternalAccessConfig, r *http.Request) error
|
|||
}
|
||||
|
||||
func CheckExternalAccessTripwire(c ExternalAccessConfig) *ExternalAccessError {
|
||||
if !c.HasCredentials() && !c.GetDangerousAllowPublicWithoutAuth() {
|
||||
if hc, _ := c.HasCredentials(context.Background()); !hc && !c.GetDangerousAllowPublicWithoutAuth() {
|
||||
if remoteIP := c.GetSecurityTripwireAccessedFromPublicInternet(); remoteIP != "" {
|
||||
err := ExternalAccessError(net.ParseIP(remoteIP))
|
||||
return &err
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
package session
|
||||
|
||||
import "context"
|
||||
|
||||
type ExternalAccessConfig interface {
|
||||
HasCredentials() bool
|
||||
HasCredentials(ctx context.Context) (bool, error)
|
||||
GetDangerousAllowPublicWithoutAuth() bool
|
||||
GetSecurityTripwireAccessedFromPublicInternet() string
|
||||
IsNewSystem() bool
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ func setVisitedPluginHooks(ctx context.Context, visitedPlugins []VisitedPluginHo
|
|||
}
|
||||
|
||||
func (s *Store) MakePluginCookie(ctx context.Context) *http.Cookie {
|
||||
currentUser := GetCurrentUserID(ctx)
|
||||
currentUser := GetCurrentUser(ctx)
|
||||
visitedPlugins := GetVisitedPluginHooks(ctx)
|
||||
|
||||
session := sessions.NewSession(s.sessionStore, cookieName)
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import (
|
|||
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/stashapp/stash/pkg/logger"
|
||||
"github.com/stashapp/stash/pkg/user"
|
||||
)
|
||||
|
||||
type key int
|
||||
|
|
@ -137,15 +138,15 @@ func (s *Store) GetSessionUserID(w http.ResponseWriter, r *http.Request) (string
|
|||
return "", nil
|
||||
}
|
||||
|
||||
func SetCurrentUserID(ctx context.Context, userID string) context.Context {
|
||||
return context.WithValue(ctx, contextUser, userID)
|
||||
func SetCurrentUser(ctx context.Context, u user.User) context.Context {
|
||||
return context.WithValue(ctx, contextUser, u)
|
||||
}
|
||||
|
||||
// GetCurrentUserID gets the current user id from the provided context
|
||||
func GetCurrentUserID(ctx context.Context) *string {
|
||||
// GetCurrentUser gets the current user id from the provided context
|
||||
func GetCurrentUser(ctx context.Context) *user.User {
|
||||
userCtxVal := ctx.Value(contextUser)
|
||||
if userCtxVal != nil {
|
||||
currentUser := userCtxVal.(string)
|
||||
currentUser := userCtxVal.(user.User)
|
||||
return ¤tUser
|
||||
}
|
||||
|
||||
|
|
|
|||
44
pkg/user/authenticate.go
Normal file
44
pkg/user/authenticate.go
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrPasswordMismatch = fmt.Errorf("password mismatch")
|
||||
ErrUserNotFound = fmt.Errorf("user not found")
|
||||
)
|
||||
|
||||
func (s *Service) ValidateCredentials(ctx context.Context, username string, password string) error {
|
||||
hc, err := s.Store.HasCredentials(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error checking if credentials exist: %w", err)
|
||||
}
|
||||
|
||||
if !hc {
|
||||
// don't need to authenticate if no credentials saved
|
||||
return nil
|
||||
}
|
||||
|
||||
authPWHash, err := s.Store.GetPasswordHash(ctx, username)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("error getting password hash: %w", err)
|
||||
}
|
||||
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(authPWHash), []byte(password)); err != nil {
|
||||
if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) {
|
||||
return ErrPasswordMismatch
|
||||
}
|
||||
|
||||
return fmt.Errorf("error comparing password hash: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
19
pkg/user/service.go
Normal file
19
pkg/user/service.go
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type Store interface {
|
||||
GetUser(ctx context.Context, username string) (*User, error)
|
||||
HasCredentials(ctx context.Context) (bool, error)
|
||||
GetPasswordHash(ctx context.Context, username string) (string, error)
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
Store Store
|
||||
}
|
||||
|
||||
func (s *Service) GetUser(ctx context.Context, username string) (*User, error) {
|
||||
return s.Store.GetUser(ctx, username)
|
||||
}
|
||||
8
pkg/user/user.go
Normal file
8
pkg/user/user.go
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
package user
|
||||
|
||||
type RoleEnum string
|
||||
|
||||
type User struct {
|
||||
Username string
|
||||
Roles []RoleEnum
|
||||
}
|
||||
Loading…
Reference in a new issue