mirror of
https://github.com/stashapp/stash.git
synced 2026-04-11 17:40:57 +02:00
Merge c61058c302 into fd480c5a3e
This commit is contained in:
commit
07ab9a7339
8 changed files with 571 additions and 15 deletions
|
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/stashapp/stash/internal/manager/config"
|
||||
"github.com/stashapp/stash/pkg/logger"
|
||||
"github.com/stashapp/stash/pkg/session"
|
||||
"github.com/stashapp/stash/pkg/signedurl"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
@ -29,6 +30,46 @@ func allowUnauthenticated(r *http.Request) bool {
|
|||
return strings.HasPrefix(r.URL.Path, loginEndpoint) || r.URL.Path == logoutEndpoint || r.URL.Path == "/css" || strings.HasPrefix(r.URL.Path, "/assets")
|
||||
}
|
||||
|
||||
// authenticateSignedRequest checks if the request is a valid signed media request.
|
||||
// Returns the matched username and true if valid, or empty string and false otherwise.
|
||||
func authenticateSignedRequest(r *http.Request) (string, bool) {
|
||||
// Only apply to scene stream paths (used by AirPlay/Chromecast devices that can't pass cookies)
|
||||
if !strings.HasPrefix(r.URL.Path, "/scene/") {
|
||||
return "", false
|
||||
}
|
||||
|
||||
c := config.GetInstance()
|
||||
|
||||
// Signed URLs are only relevant when credentials are configured
|
||||
if !c.HasCredentials() {
|
||||
return "", false
|
||||
}
|
||||
|
||||
// Check for signed URL parameters
|
||||
q := r.URL.Query()
|
||||
if q.Get(signedurl.CIDParam) == "" || q.Get(signedurl.ExpiresParam) == "" || q.Get(signedurl.SigParam) == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
// Extract the credential ID and look up the user's signing key.
|
||||
// We need the key before we can verify the signature, since in a
|
||||
// multi-user setup each user could have their own signing key.
|
||||
cid := q.Get(signedurl.CIDParam)
|
||||
username, secret, found := resolveCredentialID(c, cid)
|
||||
if !found {
|
||||
logger.Warnf("signed URL credential ID mismatch")
|
||||
return "", false
|
||||
}
|
||||
|
||||
// Verify the signature using the user's signing key
|
||||
if _, err := signedurl.VerifyURL(r.URL.Path, q, secret); err != nil {
|
||||
logger.Warnf("signed URL verification failed: %v", err)
|
||||
return "", false
|
||||
}
|
||||
|
||||
return username, true
|
||||
}
|
||||
|
||||
func authenticateHandler() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
|
@ -42,6 +83,15 @@ func authenticateHandler() func(http.Handler) http.Handler {
|
|||
|
||||
r = session.SetLocalRequest(r)
|
||||
|
||||
// Check for signed media requests
|
||||
if username, ok := authenticateSignedRequest(r); ok {
|
||||
ctx := r.Context()
|
||||
ctx = session.SetCurrentUserID(ctx, username)
|
||||
r = r.WithContext(ctx)
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
userID, err := manager.GetInstance().SessionStore.Authenticate(w, r)
|
||||
if err != nil {
|
||||
if !errors.Is(err, session.ErrUnauthorized) {
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ import (
|
|||
"github.com/stashapp/stash/internal/api/urlbuilders"
|
||||
"github.com/stashapp/stash/internal/manager"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/session"
|
||||
"github.com/stashapp/stash/pkg/signedurl"
|
||||
)
|
||||
|
||||
func convertVideoFile(f models.File) (*models.VideoFile, error) {
|
||||
|
|
@ -107,15 +109,38 @@ func (r *sceneResolver) Paths(ctx context.Context, obj *models.Scene) (*ScenePat
|
|||
baseURL, _ := ctx.Value(BaseURLCtxKey).(string)
|
||||
config := manager.GetInstance().Config
|
||||
builder := urlbuilders.NewSceneURLBuilder(baseURL, obj)
|
||||
|
||||
var streamPath string
|
||||
var captionBasePath string
|
||||
if config.HasCredentials() {
|
||||
userID := session.GetCurrentUserID(ctx)
|
||||
if userID == nil {
|
||||
return nil, fmt.Errorf("user ID not found")
|
||||
}
|
||||
|
||||
// Sign the stream prefix
|
||||
streamURL := builder.GetStreamURL("")
|
||||
streamURL.RawQuery = signedParams(config, *userID, signedurl.DerivePrefix(streamURL.Path)).Encode()
|
||||
streamPath = streamURL.String()
|
||||
|
||||
// Sign the caption prefix
|
||||
captionBase := builder.GetCaptionURL()
|
||||
captionBasePath = captionBase + "?" + signedParams(config, *userID, "/scene/"+builder.SceneID+"/caption").Encode()
|
||||
} else {
|
||||
apiKey := config.GetAPIKey()
|
||||
streamURL := builder.GetStreamURL(apiKey)
|
||||
streamPath = streamURL.String()
|
||||
captionBasePath = builder.GetCaptionURL()
|
||||
}
|
||||
|
||||
// Web-only formats: use unsigned URLs (rely on cookie authentication)
|
||||
screenshotPath := builder.GetScreenshotURL()
|
||||
previewPath := builder.GetStreamPreviewURL()
|
||||
streamPath := builder.GetStreamURL(config.GetAPIKey()).String()
|
||||
webpPath := builder.GetStreamPreviewImageURL()
|
||||
objHash := obj.GetHash(config.GetVideoFileNamingAlgorithm())
|
||||
vttPath := builder.GetSpriteVTTURL(objHash)
|
||||
spritePath := builder.GetSpriteURL(objHash)
|
||||
funscriptPath := builder.GetFunscriptURL()
|
||||
captionBasePath := builder.GetCaptionURL()
|
||||
interactiveHeatmap := builder.GetInteractiveHeatmapURL()
|
||||
|
||||
return &ScenePathsType{
|
||||
|
|
@ -294,9 +319,25 @@ func (r *sceneResolver) SceneStreams(ctx context.Context, obj *models.Scene) ([]
|
|||
|
||||
baseURL, _ := ctx.Value(BaseURLCtxKey).(string)
|
||||
builder := urlbuilders.NewSceneURLBuilder(baseURL, obj)
|
||||
apiKey := config.GetAPIKey()
|
||||
|
||||
return manager.GetSceneStreamPaths(obj, builder.GetStreamURL(apiKey), config.GetMaxStreamingTranscodeSize())
|
||||
// Build the base stream URL with signing params or apikey
|
||||
streamURL := builder.GetStreamURL("")
|
||||
if config.HasCredentials() {
|
||||
userID := session.GetCurrentUserID(ctx)
|
||||
if userID == nil {
|
||||
return nil, fmt.Errorf("user ID not found")
|
||||
}
|
||||
streamURL.RawQuery = signedParams(config, *userID, signedurl.DerivePrefix(streamURL.Path)).Encode()
|
||||
} else {
|
||||
apiKey := config.GetAPIKey()
|
||||
if apiKey != "" {
|
||||
v := streamURL.Query()
|
||||
v.Set("apikey", apiKey)
|
||||
streamURL.RawQuery = v.Encode()
|
||||
}
|
||||
}
|
||||
|
||||
return manager.GetSceneStreamPaths(obj, streamURL, config.GetMaxStreamingTranscodeSize())
|
||||
}
|
||||
|
||||
func (r *sceneResolver) Interactive(ctx context.Context, obj *models.Scene) (bool, error) {
|
||||
|
|
|
|||
|
|
@ -8,6 +8,8 @@ import (
|
|||
"github.com/stashapp/stash/internal/api/urlbuilders"
|
||||
"github.com/stashapp/stash/internal/manager"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/session"
|
||||
"github.com/stashapp/stash/pkg/signedurl"
|
||||
)
|
||||
|
||||
func (r *queryResolver) SceneStreams(ctx context.Context, id *string) ([]*manager.SceneStreamEndpoint, error) {
|
||||
|
|
@ -39,7 +41,22 @@ func (r *queryResolver) SceneStreams(ctx context.Context, id *string) ([]*manage
|
|||
|
||||
baseURL, _ := ctx.Value(BaseURLCtxKey).(string)
|
||||
builder := urlbuilders.NewSceneURLBuilder(baseURL, scene)
|
||||
apiKey := config.GetAPIKey()
|
||||
|
||||
return manager.GetSceneStreamPaths(scene, builder.GetStreamURL(apiKey), config.GetMaxStreamingTranscodeSize())
|
||||
streamURL := builder.GetStreamURL("")
|
||||
if config.HasCredentials() {
|
||||
userID := session.GetCurrentUserID(ctx)
|
||||
if userID == nil {
|
||||
return nil, fmt.Errorf("user ID not found")
|
||||
}
|
||||
streamURL.RawQuery = signedParams(config, *userID, signedurl.DerivePrefix(streamURL.Path)).Encode()
|
||||
} else {
|
||||
apiKey := config.GetAPIKey()
|
||||
if apiKey != "" {
|
||||
v := streamURL.Query()
|
||||
v.Set("apikey", apiKey)
|
||||
streamURL.RawQuery = v.Encode()
|
||||
}
|
||||
}
|
||||
|
||||
return manager.GetSceneStreamPaths(scene, streamURL, config.GetMaxStreamingTranscodeSize())
|
||||
}
|
||||
|
|
|
|||
32
internal/api/signed_url.go
Normal file
32
internal/api/signed_url.go
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/stashapp/stash/internal/manager/config"
|
||||
"github.com/stashapp/stash/pkg/signedurl"
|
||||
)
|
||||
|
||||
// userSigningKey returns the HMAC signing key for a given user.
|
||||
func userSigningKey(c *config.Config, _ string) []byte {
|
||||
return c.GetJWTSignKey()
|
||||
}
|
||||
|
||||
// signedParams generates signed URL query parameters for the given path prefix and user.
|
||||
func signedParams(c *config.Config, userID string, prefix string) url.Values {
|
||||
secret := userSigningKey(c, userID)
|
||||
cid := signedurl.GenerateCredentialID(secret, userID)
|
||||
expires := time.Now().Add(time.Duration(c.GetSignedURLExpiry()) * time.Second)
|
||||
return signedurl.SignPrefix(prefix, secret, cid, expires)
|
||||
}
|
||||
|
||||
// resolveCredentialID maps a credential ID back to a username and their signing key.
|
||||
func resolveCredentialID(c *config.Config, cid string) (string, []byte, bool) {
|
||||
username := c.GetUsername()
|
||||
secret := userSigningKey(c, username)
|
||||
if signedurl.GenerateCredentialID(secret, username) == cid {
|
||||
return username, secret, true
|
||||
}
|
||||
return "", nil, false
|
||||
}
|
||||
|
|
@ -43,6 +43,9 @@ const (
|
|||
Password = "password"
|
||||
MaxSessionAge = "max_session_age"
|
||||
|
||||
SignedURLExpiry = "signed_url_expiry"
|
||||
signedURLExpiryDefault = 60 * 60 * 24 // 24 hours in seconds
|
||||
|
||||
// SFWContentMode mode config key
|
||||
SFWContentMode = "sfw_content_mode"
|
||||
|
||||
|
|
@ -1229,6 +1232,21 @@ func (i *Config) GetMaxSessionAge() int {
|
|||
return ret
|
||||
}
|
||||
|
||||
// GetSignedURLExpiry gets the expiry time for signed URLs, in seconds.
|
||||
// Defaults to 24 hours to accommodate long video playback sessions.
|
||||
func (i *Config) GetSignedURLExpiry() int {
|
||||
i.RLock()
|
||||
defer i.RUnlock()
|
||||
|
||||
ret := signedURLExpiryDefault
|
||||
v := i.forKey(SignedURLExpiry)
|
||||
if v.Exists(SignedURLExpiry) {
|
||||
ret = v.Int(SignedURLExpiry)
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// GetCustomServedFolders gets the map of custom paths to their applicable
|
||||
// filesystem locations
|
||||
func (i *Config) GetCustomServedFolders() utils.URLMap {
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ import (
|
|||
"github.com/stashapp/stash/pkg/fsutil"
|
||||
"github.com/stashapp/stash/pkg/logger"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/signedurl"
|
||||
"github.com/stashapp/stash/pkg/utils"
|
||||
|
||||
"github.com/zencoder/go-dash/v3/mpd"
|
||||
|
|
@ -433,26 +434,36 @@ func serveHLSManifest(sm *StreamManager, w http.ResponseWriter, r *http.Request,
|
|||
baseURL := prefix + baseUrl.String()
|
||||
|
||||
urlQuery := url.Values{}
|
||||
|
||||
// Forward auth params to segment URLs. API key takes precedence
|
||||
// over signed params since it is explicitly configured by the user.
|
||||
// TODO - this needs to be handled outside of this package
|
||||
apikey := r.URL.Query().Get(apiKeyParamKey)
|
||||
if apikey != "" {
|
||||
urlQuery.Set(apiKeyParamKey, apikey)
|
||||
} else {
|
||||
cid := r.URL.Query().Get(signedurl.CIDParam)
|
||||
expires := r.URL.Query().Get(signedurl.ExpiresParam)
|
||||
sig := r.URL.Query().Get(signedurl.SigParam)
|
||||
if cid != "" && expires != "" && sig != "" {
|
||||
urlQuery.Set(signedurl.CIDParam, cid)
|
||||
urlQuery.Set(signedurl.ExpiresParam, expires)
|
||||
urlQuery.Set(signedurl.SigParam, sig)
|
||||
}
|
||||
}
|
||||
|
||||
if resolution != "" {
|
||||
urlQuery.Set(resolutionParamKey, resolution)
|
||||
}
|
||||
|
||||
// TODO - this needs to be handled outside of this package
|
||||
if apikey != "" {
|
||||
urlQuery.Set(apiKeyParamKey, apikey)
|
||||
}
|
||||
|
||||
urlQueryString := ""
|
||||
segQuery := ""
|
||||
if len(urlQuery) > 0 {
|
||||
urlQueryString = "?" + urlQuery.Encode()
|
||||
segQuery = "?" + urlQuery.Encode()
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
|
||||
fmt.Fprint(&buf, "#EXTM3U\n")
|
||||
|
||||
fmt.Fprint(&buf, "#EXT-X-VERSION:3\n")
|
||||
fmt.Fprint(&buf, "#EXT-X-MEDIA-SEQUENCE:0\n")
|
||||
fmt.Fprintf(&buf, "#EXT-X-TARGETDURATION:%d\n", segmentLength)
|
||||
|
|
@ -468,7 +479,7 @@ func serveHLSManifest(sm *StreamManager, w http.ResponseWriter, r *http.Request,
|
|||
}
|
||||
|
||||
fmt.Fprintf(&buf, "#EXTINF:%f,\n", thisLength)
|
||||
fmt.Fprintf(&buf, "%s/%d.ts%s\n", baseURL, segment, urlQueryString)
|
||||
fmt.Fprintf(&buf, "%s/%d.ts%s\n", baseURL, segment, segQuery)
|
||||
|
||||
leftover -= thisLength
|
||||
segment++
|
||||
|
|
@ -529,10 +540,21 @@ func serveDASHManifest(sm *StreamManager, w http.ResponseWriter, r *http.Request
|
|||
|
||||
urlQuery := url.Values{}
|
||||
|
||||
// Forward auth params to segment URLs. API key takes precedence
|
||||
// over signed params since it is explicitly configured by the user.
|
||||
// TODO - this needs to be handled outside of this package
|
||||
apikey := r.URL.Query().Get(apiKeyParamKey)
|
||||
if apikey != "" {
|
||||
urlQuery.Set(apiKeyParamKey, apikey)
|
||||
} else {
|
||||
cid := r.URL.Query().Get(signedurl.CIDParam)
|
||||
expires := r.URL.Query().Get(signedurl.ExpiresParam)
|
||||
sig := r.URL.Query().Get(signedurl.SigParam)
|
||||
if cid != "" && expires != "" && sig != "" {
|
||||
urlQuery.Set(signedurl.CIDParam, cid)
|
||||
urlQuery.Set(signedurl.ExpiresParam, expires)
|
||||
urlQuery.Set(signedurl.SigParam, sig)
|
||||
}
|
||||
}
|
||||
|
||||
maxTranscodeSize := sm.config.GetMaxStreamingTranscodeSize().GetMaxResolution()
|
||||
|
|
|
|||
105
pkg/signedurl/signedurl.go
Normal file
105
pkg/signedurl/signedurl.go
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
// Package signedurl provides HMAC-signed URLs for media requests from devices that cannot pass cookies (AirPlay, Chromecast).
|
||||
package signedurl
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
CIDParam = "cid"
|
||||
ExpiresParam = "expires"
|
||||
SigParam = "signature"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrMissingParams = errors.New("missing required signed URL parameters")
|
||||
ErrExpiredURL = errors.New("signed URL has expired")
|
||||
ErrInvalidSignature = errors.New("invalid signature")
|
||||
ErrInvalidURL = errors.New("invalid URL")
|
||||
)
|
||||
|
||||
// GenerateCredentialID produces an opaque, deterministic identifier for a user.
|
||||
// It is an HMAC-SHA256 of the username, truncated to 16 hex characters.
|
||||
func GenerateCredentialID(secret []byte, username string) string {
|
||||
h := hmac.New(sha256.New, secret)
|
||||
h.Write([]byte(username))
|
||||
return hex.EncodeToString(h.Sum(nil))[:16]
|
||||
}
|
||||
|
||||
// DerivePrefix extracts the signing prefix from a request path by taking the
|
||||
// first 3 segments and stripping the file extension from the 3rd.
|
||||
func DerivePrefix(path string) string {
|
||||
parts := strings.Split(strings.Trim(path, "/"), "/")
|
||||
if len(parts) < 3 {
|
||||
return "/" + strings.Join(parts, "/")
|
||||
}
|
||||
action := parts[2]
|
||||
if dotIdx := strings.IndexByte(action, '.'); dotIdx >= 0 {
|
||||
action = action[:dotIdx]
|
||||
}
|
||||
return "/" + parts[0] + "/" + parts[1] + "/" + action
|
||||
}
|
||||
|
||||
// makeSignString constructs the canonical string to sign.
|
||||
func makeSignString(prefix string, cid string, expires time.Time) string {
|
||||
return prefix + "?" + CIDParam + "=" + cid + "&" + ExpiresParam + "=" + strconv.FormatInt(expires.Unix(), 10)
|
||||
}
|
||||
|
||||
// SignPrefix signs a path prefix and returns url.Values containing
|
||||
// the cid, expires, and signature parameters. The caller appends
|
||||
// these to any URL whose path falls under the signed prefix.
|
||||
func SignPrefix(prefix string, secret []byte, cid string, expires time.Time) url.Values {
|
||||
signString := makeSignString(prefix, cid, expires)
|
||||
|
||||
h := hmac.New(sha256.New, secret)
|
||||
h.Write([]byte(signString))
|
||||
signature := hex.EncodeToString(h.Sum(nil))
|
||||
|
||||
params := make(url.Values)
|
||||
params.Set(CIDParam, cid)
|
||||
params.Set(ExpiresParam, strconv.FormatInt(expires.Unix(), 10))
|
||||
params.Set(SigParam, signature)
|
||||
return params
|
||||
}
|
||||
|
||||
// VerifyURL verifies a signed URL request. It derives the signing prefix
|
||||
// from the request path, checks expiry, and validates the HMAC signature.
|
||||
// Returns the credential ID on success.
|
||||
func VerifyURL(requestPath string, queryParams url.Values, secret []byte) (string, error) {
|
||||
cid := queryParams.Get(CIDParam)
|
||||
expiresStr := queryParams.Get(ExpiresParam)
|
||||
sig := queryParams.Get(SigParam)
|
||||
|
||||
if cid == "" || expiresStr == "" || sig == "" {
|
||||
return "", ErrMissingParams
|
||||
}
|
||||
|
||||
expires, err := strconv.ParseInt(expiresStr, 10, 64)
|
||||
if err != nil {
|
||||
return "", ErrInvalidURL
|
||||
}
|
||||
|
||||
if time.Now().Unix() > expires {
|
||||
return "", ErrExpiredURL
|
||||
}
|
||||
|
||||
prefix := DerivePrefix(requestPath)
|
||||
signString := makeSignString(prefix, cid, time.Unix(expires, 0))
|
||||
|
||||
h := hmac.New(sha256.New, secret)
|
||||
h.Write([]byte(signString))
|
||||
expectedSig := hex.EncodeToString(h.Sum(nil))
|
||||
|
||||
if !hmac.Equal([]byte(sig), []byte(expectedSig)) {
|
||||
return "", ErrInvalidSignature
|
||||
}
|
||||
|
||||
return cid, nil
|
||||
}
|
||||
271
pkg/signedurl/signedurl_test.go
Normal file
271
pkg/signedurl/signedurl_test.go
Normal file
|
|
@ -0,0 +1,271 @@
|
|||
package signedurl
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestDerivePrefix(t *testing.T) {
|
||||
tests := []struct {
|
||||
path string
|
||||
expected string
|
||||
}{
|
||||
// Scene stream variants
|
||||
{"/scene/1/stream", "/scene/1/stream"},
|
||||
{"/scene/1/stream.mp4", "/scene/1/stream"},
|
||||
{"/scene/1/stream.webm", "/scene/1/stream"},
|
||||
{"/scene/1/stream.mkv", "/scene/1/stream"},
|
||||
{"/scene/1/stream.m3u8", "/scene/1/stream"},
|
||||
{"/scene/1/stream.mpd", "/scene/1/stream"},
|
||||
|
||||
// HLS segments
|
||||
{"/scene/1/stream.m3u8/0.ts", "/scene/1/stream"},
|
||||
{"/scene/1/stream.m3u8/99.ts", "/scene/1/stream"},
|
||||
|
||||
// DASH segments
|
||||
{"/scene/1/stream.mpd/5_v.webm", "/scene/1/stream"},
|
||||
{"/scene/1/stream.mpd/5_a.webm", "/scene/1/stream"},
|
||||
{"/scene/1/stream.mpd/init_v.webm", "/scene/1/stream"},
|
||||
{"/scene/1/stream.mpd/init_a.webm", "/scene/1/stream"},
|
||||
|
||||
// Caption
|
||||
{"/scene/1/caption", "/scene/1/caption"},
|
||||
|
||||
// Image paths
|
||||
{"/image/5/thumbnail", "/image/5/thumbnail"},
|
||||
{"/image/5/image", "/image/5/image"},
|
||||
|
||||
// Gallery paths
|
||||
{"/gallery/3/cover", "/gallery/3/cover"},
|
||||
{"/gallery/3/preview", "/gallery/3/preview"},
|
||||
|
||||
// Short paths
|
||||
{"/scene/1", "/scene/1"},
|
||||
{"/scene", "/scene"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.path, func(t *testing.T) {
|
||||
got := DerivePrefix(tt.path)
|
||||
if got != tt.expected {
|
||||
t.Errorf("DerivePrefix(%q) = %q, want %q", tt.path, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCredentialID(t *testing.T) {
|
||||
secret := []byte("test-secret-key")
|
||||
|
||||
t.Run("deterministic", func(t *testing.T) {
|
||||
cid1 := GenerateCredentialID(secret, "alice")
|
||||
cid2 := GenerateCredentialID(secret, "alice")
|
||||
if cid1 != cid2 {
|
||||
t.Errorf("expected deterministic output, got %q and %q", cid1, cid2)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("different usernames produce different cids", func(t *testing.T) {
|
||||
cid1 := GenerateCredentialID(secret, "alice")
|
||||
cid2 := GenerateCredentialID(secret, "bob")
|
||||
if cid1 == cid2 {
|
||||
t.Errorf("expected different cids for different usernames, both got %q", cid1)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("different secrets produce different cids", func(t *testing.T) {
|
||||
cid1 := GenerateCredentialID([]byte("secret-1"), "alice")
|
||||
cid2 := GenerateCredentialID([]byte("secret-2"), "alice")
|
||||
if cid1 == cid2 {
|
||||
t.Errorf("expected different cids for different secrets, both got %q", cid1)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("length is 16 hex chars", func(t *testing.T) {
|
||||
cid := GenerateCredentialID(secret, "alice")
|
||||
if len(cid) != 16 {
|
||||
t.Errorf("expected length 16, got %d (%q)", len(cid), cid)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSignAndVerifyRoundtrip(t *testing.T) {
|
||||
secret := []byte("test-secret-key")
|
||||
cid := GenerateCredentialID(secret, "alice")
|
||||
expires := time.Now().Add(1 * time.Hour)
|
||||
|
||||
params := SignPrefix("/scene/1/stream", secret, cid, expires)
|
||||
|
||||
gotCID, err := VerifyURL("/scene/1/stream", params, secret)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if gotCID != cid {
|
||||
t.Errorf("expected cid %q, got %q", cid, gotCID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDifferentPathSamePrefixVerifies(t *testing.T) {
|
||||
secret := []byte("test-secret-key")
|
||||
cid := GenerateCredentialID(secret, "alice")
|
||||
expires := time.Now().Add(1 * time.Hour)
|
||||
|
||||
// Sign for the stream prefix
|
||||
params := SignPrefix("/scene/1/stream", secret, cid, expires)
|
||||
|
||||
// Verify with different paths that share the same prefix
|
||||
paths := []string{
|
||||
"/scene/1/stream",
|
||||
"/scene/1/stream.mp4",
|
||||
"/scene/1/stream.m3u8",
|
||||
"/scene/1/stream.m3u8/0.ts",
|
||||
"/scene/1/stream.mpd",
|
||||
"/scene/1/stream.mpd/5_v.webm",
|
||||
"/scene/1/stream.mpd/init_a.webm",
|
||||
}
|
||||
|
||||
for _, path := range paths {
|
||||
t.Run(path, func(t *testing.T) {
|
||||
gotCID, err := VerifyURL(path, params, secret)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error for path %q: %v", path, err)
|
||||
}
|
||||
if gotCID != cid {
|
||||
t.Errorf("expected cid %q, got %q", cid, gotCID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDifferentPrefixFails(t *testing.T) {
|
||||
secret := []byte("test-secret-key")
|
||||
cid := GenerateCredentialID(secret, "alice")
|
||||
expires := time.Now().Add(1 * time.Hour)
|
||||
|
||||
params := SignPrefix("/scene/1/stream", secret, cid, expires)
|
||||
|
||||
// Different scene ID
|
||||
_, err := VerifyURL("/scene/2/stream", params, secret)
|
||||
if !errors.Is(err, ErrInvalidSignature) {
|
||||
t.Errorf("expected ErrInvalidSignature, got %v", err)
|
||||
}
|
||||
|
||||
// Different resource type
|
||||
_, err = VerifyURL("/scene/1/caption", params, secret)
|
||||
if !errors.Is(err, ErrInvalidSignature) {
|
||||
t.Errorf("expected ErrInvalidSignature, got %v", err)
|
||||
}
|
||||
|
||||
// Different entity type
|
||||
_, err = VerifyURL("/image/1/stream", params, secret)
|
||||
if !errors.Is(err, ErrInvalidSignature) {
|
||||
t.Errorf("expected ErrInvalidSignature, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpiredURLFails(t *testing.T) {
|
||||
secret := []byte("test-secret-key")
|
||||
cid := GenerateCredentialID(secret, "alice")
|
||||
expires := time.Now().Add(-1 * time.Hour) // expired 1 hour ago
|
||||
|
||||
params := SignPrefix("/scene/1/stream", secret, cid, expires)
|
||||
|
||||
_, err := VerifyURL("/scene/1/stream", params, secret)
|
||||
if !errors.Is(err, ErrExpiredURL) {
|
||||
t.Errorf("expected ErrExpiredURL, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTamperedSignatureFails(t *testing.T) {
|
||||
secret := []byte("test-secret-key")
|
||||
cid := GenerateCredentialID(secret, "alice")
|
||||
expires := time.Now().Add(1 * time.Hour)
|
||||
|
||||
params := SignPrefix("/scene/1/stream", secret, cid, expires)
|
||||
params.Set(SigParam, "tampered0000000000000000000000000000000000000000000000000000000")
|
||||
|
||||
_, err := VerifyURL("/scene/1/stream", params, secret)
|
||||
if !errors.Is(err, ErrInvalidSignature) {
|
||||
t.Errorf("expected ErrInvalidSignature, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTamperedCIDFails(t *testing.T) {
|
||||
secret := []byte("test-secret-key")
|
||||
cid := GenerateCredentialID(secret, "alice")
|
||||
expires := time.Now().Add(1 * time.Hour)
|
||||
|
||||
params := SignPrefix("/scene/1/stream", secret, cid, expires)
|
||||
params.Set(CIDParam, "tamperedcid12345")
|
||||
|
||||
_, err := VerifyURL("/scene/1/stream", params, secret)
|
||||
if !errors.Is(err, ErrInvalidSignature) {
|
||||
t.Errorf("expected ErrInvalidSignature, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMissingParamsFails(t *testing.T) {
|
||||
secret := []byte("test-secret-key")
|
||||
cid := GenerateCredentialID(secret, "alice")
|
||||
expires := time.Now().Add(1 * time.Hour)
|
||||
|
||||
full := SignPrefix("/scene/1/stream", secret, cid, expires)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
missing string
|
||||
}{
|
||||
{"missing cid", CIDParam},
|
||||
{"missing expires", ExpiresParam},
|
||||
{"missing signature", SigParam},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
params := make(url.Values)
|
||||
for k, v := range full {
|
||||
params[k] = v
|
||||
}
|
||||
params.Del(tt.missing)
|
||||
|
||||
_, err := VerifyURL("/scene/1/stream", params, secret)
|
||||
if !errors.Is(err, ErrMissingParams) {
|
||||
t.Errorf("expected ErrMissingParams, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTamperedExpiresFails(t *testing.T) {
|
||||
secret := []byte("test-secret-key")
|
||||
cid := GenerateCredentialID(secret, "alice")
|
||||
expires := time.Now().Add(1 * time.Hour)
|
||||
|
||||
params := SignPrefix("/scene/1/stream", secret, cid, expires)
|
||||
|
||||
// Attacker extends the expiry by 24 hours
|
||||
tampered := time.Now().Add(25 * time.Hour)
|
||||
params.Set(ExpiresParam, strconv.FormatInt(tampered.Unix(), 10))
|
||||
|
||||
_, err := VerifyURL("/scene/1/stream", params, secret)
|
||||
if !errors.Is(err, ErrInvalidSignature) {
|
||||
t.Errorf("expected ErrInvalidSignature, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrongSecretFails(t *testing.T) {
|
||||
secret := []byte("test-secret-key")
|
||||
wrongSecret := []byte("wrong-secret-key")
|
||||
cid := GenerateCredentialID(secret, "alice")
|
||||
expires := time.Now().Add(1 * time.Hour)
|
||||
|
||||
params := SignPrefix("/scene/1/stream", secret, cid, expires)
|
||||
|
||||
_, err := VerifyURL("/scene/1/stream", params, wrongSecret)
|
||||
if !errors.Is(err, ErrInvalidSignature) {
|
||||
t.Errorf("expected ErrInvalidSignature, got %v", err)
|
||||
}
|
||||
}
|
||||
Loading…
Reference in a new issue