Add UserService

This commit is contained in:
WithoutPants 2023-05-29 10:09:08 +10:00
parent 51d7bf272e
commit d906a66fec
16 changed files with 150 additions and 24 deletions

View file

@ -1,6 +1,7 @@
package api
import (
"context"
"errors"
"net"
"net/http"
@ -12,6 +13,7 @@ import (
"github.com/stashapp/stash/internal/manager/config"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/session"
"github.com/stashapp/stash/pkg/user"
)
const (
@ -29,7 +31,11 @@ func allowUnauthenticated(r *http.Request) bool {
return strings.HasPrefix(r.URL.Path, loginEndpoint) || r.URL.Path == logoutEndpoint || r.URL.Path == "/css" || strings.HasPrefix(r.URL.Path, "/assets")
}
func authenticateHandler() func(http.Handler) http.Handler {
type UserGetter interface {
GetUser(ctx context.Context, username string) (*user.User, error)
}
func authenticateHandler(g UserGetter) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c := config.GetInstance()
@ -73,7 +79,7 @@ func authenticateHandler() func(http.Handler) http.Handler {
ctx := r.Context()
if c.HasCredentials() {
if hc, _ := c.HasCredentials(ctx); hc {
// authentication is required
if userID == "" && !allowUnauthenticated(r) {
// if graphql or a non-webpage was requested, we just return a forbidden error
@ -102,7 +108,18 @@ func authenticateHandler() func(http.Handler) http.Handler {
}
}
ctx = session.SetCurrentUserID(ctx, userID)
if userID != "" {
// set the user object in the context
u, err := g.GetUser(ctx, userID)
if err != nil {
// if we can't get the user object, we just return a forbidden error
logger.Errorf("Error getting user object: %v", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
ctx = session.SetCurrentUser(ctx, *u)
}
r = r.WithContext(ctx)

View file

@ -320,7 +320,7 @@ func (r *mutationResolver) ConfigureGeneral(ctx context.Context, input ConfigGen
if input.Password != nil {
// bit of a hack - check if the passed in password is the same as the stored hash
// and only set if they are different
currentPWHash := c.GetPasswordHash()
currentPWHash, _ := c.GetPasswordHash(ctx, c.GetUsername())
if *input.Password != currentPWHash {
if *input.Password == "" {

View file

@ -78,6 +78,9 @@ func makeConfigGeneralResult() *ConfigGeneralResult {
customPerformerImageLocation := config.GetCustomPerformerImageLocation()
username := config.GetUsername()
pwHash, _ := config.GetPasswordHash(context.Background(), username)
return &ConfigGeneralResult{
Stashes: config.GetStashPaths(),
DatabasePath: config.GetDatabasePath(),
@ -110,7 +113,7 @@ func makeConfigGeneralResult() *ConfigGeneralResult {
GalleryCoverRegex: config.GetGalleryCoverRegex(),
APIKey: config.GetAPIKey(),
Username: config.GetUsername(),
Password: config.GetPasswordHash(),
Password: pwHash,
MaxSessionAge: config.GetMaxSessionAge(),
LogFile: &logFile,
LogOut: config.GetLogOut(),

View file

@ -122,9 +122,11 @@ func Initialize() (*Server, error) {
manager: mgr,
}
userStore := manager.GetInstance().UserService
r.Use(middleware.Heartbeat("/healthz"))
r.Use(cors.AllowAll().Handler)
r.Use(authenticateHandler())
r.Use(authenticateHandler(userStore))
visitedPluginHandler := mgr.SessionStore.VisitedPluginHandler()
r.Use(visitedPluginHandler)

View file

@ -107,7 +107,7 @@ func handleLogin() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
returnURL := r.URL.Query().Get(returnURLParam)
if !config.GetInstance().HasCredentials() {
if hc, _ := config.GetInstance().HasCredentials(r.Context()); !hc {
if returnURL != "" {
http.Redirect(w, r, returnURL, http.StatusFound)
} else {
@ -155,7 +155,7 @@ func handleLogout() http.HandlerFunc {
// redirect to the login page if credentials are required
prefix := getProxyPrefix(r)
if config.GetInstance().HasCredentials() {
if hc, _ := config.GetInstance().HasCredentials(r.Context()); hc {
http.Redirect(w, r, prefix+loginEndpoint, http.StatusFound)
} else {
http.Redirect(w, r, prefix+"/", http.StatusFound)

View file

@ -1,6 +1,7 @@
package config
import (
"context"
"fmt"
"net/url"
"os"
@ -27,6 +28,7 @@ import (
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/paths"
"github.com/stashapp/stash/pkg/sliceutil"
"github.com/stashapp/stash/pkg/user"
"github.com/stashapp/stash/pkg/utils"
)
@ -1080,23 +1082,39 @@ func (i *Config) GetUsername() string {
return i.getString(Username)
}
func (i *Config) GetPasswordHash() string {
return i.getString(Password)
func (i *Config) GetUser(ctx context.Context, username string) (*user.User, error) {
u := i.GetUsername()
if u != username {
return nil, user.ErrUserNotFound
}
return &user.User{
Username: u,
}, nil
}
func (i *Config) GetPasswordHash(ctx context.Context, username string) (string, error) {
u := i.GetUsername()
if u != username {
return "", user.ErrUserNotFound
}
return i.getString(Password), nil
}
func (i *Config) GetCredentials() (string, string) {
if i.HasCredentials() {
if hc, _ := i.HasCredentials(context.Background()); hc {
return i.getString(Username), i.getString(Password)
}
return "", ""
}
func (i *Config) HasCredentials() bool {
func (i *Config) HasCredentials(ctx context.Context) (bool, error) {
username := i.getString(Username)
pwHash := i.getString(Password)
return username != "" && pwHash != ""
return username != "" && pwHash != "", nil
}
func hashPassword(password string) string {
@ -1106,7 +1124,7 @@ func hashPassword(password string) string {
}
func (i *Config) ValidateCredentials(username string, password string) bool {
if !i.HasCredentials() {
if hc, _ := i.HasCredentials(context.Background()); !hc {
// don't need to authenticate if no credentials saved
return true
}

View file

@ -27,6 +27,7 @@ import (
"github.com/stashapp/stash/pkg/scraper"
"github.com/stashapp/stash/pkg/session"
"github.com/stashapp/stash/pkg/sqlite"
"github.com/stashapp/stash/pkg/user"
"github.com/stashapp/stash/pkg/utils"
"github.com/stashapp/stash/ui"
)
@ -109,6 +110,10 @@ func Initialize(cfg *config.Config, l *log.Logger) (*Manager, error) {
scanSubs: &subscriptionManager{},
}
instance.UserService = &user.Service{
Store: cfg,
}
if !cfg.IsNewSystem() {
logger.Infof("using config file: %s", cfg.GetConfigFile())
@ -189,7 +194,7 @@ func initJobManager(cfg *config.Config) *job.Manager {
func (s *Manager) postInit(ctx context.Context) error {
s.RefreshConfig()
s.SessionStore = session.NewStore(s.Config)
s.SessionStore = session.NewStore(s.Config, s.Config)
s.PluginCache.RegisterSessionStore(s.SessionStore)
s.RefreshPluginCache()

View file

@ -67,6 +67,7 @@ type Manager struct {
ImageService ImageService
GalleryService GalleryService
GroupService GroupService
UserService UserService
scanSubs *subscriptionManager
}

View file

@ -7,6 +7,7 @@ import (
"github.com/stashapp/stash/pkg/image"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scene"
"github.com/stashapp/stash/pkg/user"
)
type SceneService interface {
@ -46,3 +47,7 @@ type GroupService interface {
RemoveSubGroups(ctx context.Context, groupID int, subGroupIDs []int) error
ReorderSubGroups(ctx context.Context, groupID int, subGroupIDs []int, insertPointID int, insertAfter bool) error
}
type UserService interface {
GetUser(ctx context.Context, username string) (*user.User, error)
}

View file

@ -1,6 +1,7 @@
package session
import (
"context"
"fmt"
"net"
"net/http"
@ -16,7 +17,7 @@ func (e ExternalAccessError) Error() string {
}
func CheckAllowPublicWithoutAuth(c ExternalAccessConfig, r *http.Request) error {
if !c.HasCredentials() && !c.GetDangerousAllowPublicWithoutAuth() && !c.IsNewSystem() {
if hc, _ := c.HasCredentials(context.Background()); !hc && !c.GetDangerousAllowPublicWithoutAuth() && !c.IsNewSystem() {
requestIPString, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return fmt.Errorf("error parsing remote host (%s): %w", r.RemoteAddr, err)
@ -60,7 +61,7 @@ func CheckAllowPublicWithoutAuth(c ExternalAccessConfig, r *http.Request) error
}
func CheckExternalAccessTripwire(c ExternalAccessConfig) *ExternalAccessError {
if !c.HasCredentials() && !c.GetDangerousAllowPublicWithoutAuth() {
if hc, _ := c.HasCredentials(context.Background()); !hc && !c.GetDangerousAllowPublicWithoutAuth() {
if remoteIP := c.GetSecurityTripwireAccessedFromPublicInternet(); remoteIP != "" {
err := ExternalAccessError(net.ParseIP(remoteIP))
return &err

View file

@ -1,7 +1,9 @@
package session
import "context"
type ExternalAccessConfig interface {
HasCredentials() bool
HasCredentials(ctx context.Context) (bool, error)
GetDangerousAllowPublicWithoutAuth() bool
GetSecurityTripwireAccessedFromPublicInternet() string
IsNewSystem() bool

View file

@ -61,7 +61,7 @@ func setVisitedPluginHooks(ctx context.Context, visitedPlugins []VisitedPluginHo
}
func (s *Store) MakePluginCookie(ctx context.Context) *http.Cookie {
currentUser := GetCurrentUserID(ctx)
currentUser := GetCurrentUser(ctx)
visitedPlugins := GetVisitedPluginHooks(ctx)
session := sessions.NewSession(s.sessionStore, cookieName)

View file

@ -8,6 +8,7 @@ import (
"github.com/gorilla/sessions"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/user"
)
type key int
@ -137,15 +138,15 @@ func (s *Store) GetSessionUserID(w http.ResponseWriter, r *http.Request) (string
return "", nil
}
func SetCurrentUserID(ctx context.Context, userID string) context.Context {
return context.WithValue(ctx, contextUser, userID)
func SetCurrentUser(ctx context.Context, u user.User) context.Context {
return context.WithValue(ctx, contextUser, u)
}
// GetCurrentUserID gets the current user id from the provided context
func GetCurrentUserID(ctx context.Context) *string {
// GetCurrentUser gets the current user id from the provided context
func GetCurrentUser(ctx context.Context) *user.User {
userCtxVal := ctx.Value(contextUser)
if userCtxVal != nil {
currentUser := userCtxVal.(string)
currentUser := userCtxVal.(user.User)
return &currentUser
}

44
pkg/user/authenticate.go Normal file
View file

@ -0,0 +1,44 @@
package user
import (
"context"
"errors"
"fmt"
"golang.org/x/crypto/bcrypt"
)
var (
ErrPasswordMismatch = fmt.Errorf("password mismatch")
ErrUserNotFound = fmt.Errorf("user not found")
)
func (s *Service) ValidateCredentials(ctx context.Context, username string, password string) error {
hc, err := s.Store.HasCredentials(ctx)
if err != nil {
return fmt.Errorf("error checking if credentials exist: %w", err)
}
if !hc {
// don't need to authenticate if no credentials saved
return nil
}
authPWHash, err := s.Store.GetPasswordHash(ctx, username)
if err != nil {
if errors.Is(err, ErrUserNotFound) {
return err
}
return fmt.Errorf("error getting password hash: %w", err)
}
if err := bcrypt.CompareHashAndPassword([]byte(authPWHash), []byte(password)); err != nil {
if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) {
return ErrPasswordMismatch
}
return fmt.Errorf("error comparing password hash: %w", err)
}
return nil
}

19
pkg/user/service.go Normal file
View file

@ -0,0 +1,19 @@
package user
import (
"context"
)
type Store interface {
GetUser(ctx context.Context, username string) (*User, error)
HasCredentials(ctx context.Context) (bool, error)
GetPasswordHash(ctx context.Context, username string) (string, error)
}
type Service struct {
Store Store
}
func (s *Service) GetUser(ctx context.Context, username string) (*User, error) {
return s.Store.GetUser(ctx, username)
}

8
pkg/user/user.go Normal file
View file

@ -0,0 +1,8 @@
package user
type RoleEnum string
type User struct {
Username string
Roles []RoleEnum
}