diff --git a/internal/api/authentication.go b/internal/api/authentication.go index be399d222..aec069529 100644 --- a/internal/api/authentication.go +++ b/internal/api/authentication.go @@ -12,6 +12,7 @@ import ( "github.com/stashapp/stash/internal/manager/config" "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/session" + "github.com/stashapp/stash/pkg/signedurl" ) const ( @@ -29,6 +30,46 @@ func allowUnauthenticated(r *http.Request) bool { return strings.HasPrefix(r.URL.Path, loginEndpoint) || r.URL.Path == logoutEndpoint || r.URL.Path == "/css" || strings.HasPrefix(r.URL.Path, "/assets") } +// authenticateSignedRequest checks if the request is a valid signed media request. +// Returns the matched username and true if valid, or empty string and false otherwise. +func authenticateSignedRequest(r *http.Request) (string, bool) { + // Only apply to scene stream paths (used by AirPlay/Chromecast devices that can't pass cookies) + if !strings.HasPrefix(r.URL.Path, "/scene/") { + return "", false + } + + c := config.GetInstance() + + // Signed URLs are only relevant when credentials are configured + if !c.HasCredentials() { + return "", false + } + + // Check for signed URL parameters + q := r.URL.Query() + if q.Get(signedurl.CIDParam) == "" || q.Get(signedurl.ExpiresParam) == "" || q.Get(signedurl.SigParam) == "" { + return "", false + } + + // Extract the credential ID and look up the user's signing key. + // We need the key before we can verify the signature, since in a + // multi-user setup each user could have their own signing key. + cid := q.Get(signedurl.CIDParam) + username, secret, found := resolveCredentialID(c, cid) + if !found { + logger.Warnf("signed URL credential ID mismatch") + return "", false + } + + // Verify the signature using the user's signing key + if _, err := signedurl.VerifyURL(r.URL.Path, q, secret); err != nil { + logger.Warnf("signed URL verification failed: %v", err) + return "", false + } + + return username, true +} + func authenticateHandler() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -42,6 +83,15 @@ func authenticateHandler() func(http.Handler) http.Handler { r = session.SetLocalRequest(r) + // Check for signed media requests + if username, ok := authenticateSignedRequest(r); ok { + ctx := r.Context() + ctx = session.SetCurrentUserID(ctx, username) + r = r.WithContext(ctx) + next.ServeHTTP(w, r) + return + } + userID, err := manager.GetInstance().SessionStore.Authenticate(w, r) if err != nil { if !errors.Is(err, session.ErrUnauthorized) { diff --git a/internal/api/resolver_model_scene.go b/internal/api/resolver_model_scene.go index 81113d858..fd001108d 100644 --- a/internal/api/resolver_model_scene.go +++ b/internal/api/resolver_model_scene.go @@ -9,6 +9,8 @@ import ( "github.com/stashapp/stash/internal/api/urlbuilders" "github.com/stashapp/stash/internal/manager" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/session" + "github.com/stashapp/stash/pkg/signedurl" ) func convertVideoFile(f models.File) (*models.VideoFile, error) { @@ -107,15 +109,38 @@ func (r *sceneResolver) Paths(ctx context.Context, obj *models.Scene) (*ScenePat baseURL, _ := ctx.Value(BaseURLCtxKey).(string) config := manager.GetInstance().Config builder := urlbuilders.NewSceneURLBuilder(baseURL, obj) + + var streamPath string + var captionBasePath string + if config.HasCredentials() { + userID := session.GetCurrentUserID(ctx) + if userID == nil { + return nil, fmt.Errorf("user ID not found") + } + + // Sign the stream prefix + streamURL := builder.GetStreamURL("") + streamURL.RawQuery = signedParams(config, *userID, signedurl.DerivePrefix(streamURL.Path)).Encode() + streamPath = streamURL.String() + + // Sign the caption prefix + captionBase := builder.GetCaptionURL() + captionBasePath = captionBase + "?" + signedParams(config, *userID, "/scene/"+builder.SceneID+"/caption").Encode() + } else { + apiKey := config.GetAPIKey() + streamURL := builder.GetStreamURL(apiKey) + streamPath = streamURL.String() + captionBasePath = builder.GetCaptionURL() + } + + // Web-only formats: use unsigned URLs (rely on cookie authentication) screenshotPath := builder.GetScreenshotURL() previewPath := builder.GetStreamPreviewURL() - streamPath := builder.GetStreamURL(config.GetAPIKey()).String() webpPath := builder.GetStreamPreviewImageURL() objHash := obj.GetHash(config.GetVideoFileNamingAlgorithm()) vttPath := builder.GetSpriteVTTURL(objHash) spritePath := builder.GetSpriteURL(objHash) funscriptPath := builder.GetFunscriptURL() - captionBasePath := builder.GetCaptionURL() interactiveHeatmap := builder.GetInteractiveHeatmapURL() return &ScenePathsType{ @@ -294,9 +319,25 @@ func (r *sceneResolver) SceneStreams(ctx context.Context, obj *models.Scene) ([] baseURL, _ := ctx.Value(BaseURLCtxKey).(string) builder := urlbuilders.NewSceneURLBuilder(baseURL, obj) - apiKey := config.GetAPIKey() - return manager.GetSceneStreamPaths(obj, builder.GetStreamURL(apiKey), config.GetMaxStreamingTranscodeSize()) + // Build the base stream URL with signing params or apikey + streamURL := builder.GetStreamURL("") + if config.HasCredentials() { + userID := session.GetCurrentUserID(ctx) + if userID == nil { + return nil, fmt.Errorf("user ID not found") + } + streamURL.RawQuery = signedParams(config, *userID, signedurl.DerivePrefix(streamURL.Path)).Encode() + } else { + apiKey := config.GetAPIKey() + if apiKey != "" { + v := streamURL.Query() + v.Set("apikey", apiKey) + streamURL.RawQuery = v.Encode() + } + } + + return manager.GetSceneStreamPaths(obj, streamURL, config.GetMaxStreamingTranscodeSize()) } func (r *sceneResolver) Interactive(ctx context.Context, obj *models.Scene) (bool, error) { diff --git a/internal/api/resolver_query_scene.go b/internal/api/resolver_query_scene.go index 1bb8f0f96..b651cf694 100644 --- a/internal/api/resolver_query_scene.go +++ b/internal/api/resolver_query_scene.go @@ -8,6 +8,8 @@ import ( "github.com/stashapp/stash/internal/api/urlbuilders" "github.com/stashapp/stash/internal/manager" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/session" + "github.com/stashapp/stash/pkg/signedurl" ) func (r *queryResolver) SceneStreams(ctx context.Context, id *string) ([]*manager.SceneStreamEndpoint, error) { @@ -39,7 +41,22 @@ func (r *queryResolver) SceneStreams(ctx context.Context, id *string) ([]*manage baseURL, _ := ctx.Value(BaseURLCtxKey).(string) builder := urlbuilders.NewSceneURLBuilder(baseURL, scene) - apiKey := config.GetAPIKey() - return manager.GetSceneStreamPaths(scene, builder.GetStreamURL(apiKey), config.GetMaxStreamingTranscodeSize()) + streamURL := builder.GetStreamURL("") + if config.HasCredentials() { + userID := session.GetCurrentUserID(ctx) + if userID == nil { + return nil, fmt.Errorf("user ID not found") + } + streamURL.RawQuery = signedParams(config, *userID, signedurl.DerivePrefix(streamURL.Path)).Encode() + } else { + apiKey := config.GetAPIKey() + if apiKey != "" { + v := streamURL.Query() + v.Set("apikey", apiKey) + streamURL.RawQuery = v.Encode() + } + } + + return manager.GetSceneStreamPaths(scene, streamURL, config.GetMaxStreamingTranscodeSize()) } diff --git a/internal/api/signed_url.go b/internal/api/signed_url.go new file mode 100644 index 000000000..38fff90a2 --- /dev/null +++ b/internal/api/signed_url.go @@ -0,0 +1,32 @@ +package api + +import ( + "net/url" + "time" + + "github.com/stashapp/stash/internal/manager/config" + "github.com/stashapp/stash/pkg/signedurl" +) + +// userSigningKey returns the HMAC signing key for a given user. +func userSigningKey(c *config.Config, _ string) []byte { + return c.GetJWTSignKey() +} + +// signedParams generates signed URL query parameters for the given path prefix and user. +func signedParams(c *config.Config, userID string, prefix string) url.Values { + secret := userSigningKey(c, userID) + cid := signedurl.GenerateCredentialID(secret, userID) + expires := time.Now().Add(time.Duration(c.GetSignedURLExpiry()) * time.Second) + return signedurl.SignPrefix(prefix, secret, cid, expires) +} + +// resolveCredentialID maps a credential ID back to a username and their signing key. +func resolveCredentialID(c *config.Config, cid string) (string, []byte, bool) { + username := c.GetUsername() + secret := userSigningKey(c, username) + if signedurl.GenerateCredentialID(secret, username) == cid { + return username, secret, true + } + return "", nil, false +} diff --git a/internal/manager/config/config.go b/internal/manager/config/config.go index 19e263810..28ab531de 100644 --- a/internal/manager/config/config.go +++ b/internal/manager/config/config.go @@ -43,6 +43,9 @@ const ( Password = "password" MaxSessionAge = "max_session_age" + SignedURLExpiry = "signed_url_expiry" + signedURLExpiryDefault = 60 * 60 * 24 // 24 hours in seconds + // SFWContentMode mode config key SFWContentMode = "sfw_content_mode" @@ -1229,6 +1232,21 @@ func (i *Config) GetMaxSessionAge() int { return ret } +// GetSignedURLExpiry gets the expiry time for signed URLs, in seconds. +// Defaults to 24 hours to accommodate long video playback sessions. +func (i *Config) GetSignedURLExpiry() int { + i.RLock() + defer i.RUnlock() + + ret := signedURLExpiryDefault + v := i.forKey(SignedURLExpiry) + if v.Exists(SignedURLExpiry) { + ret = v.Int(SignedURLExpiry) + } + + return ret +} + // GetCustomServedFolders gets the map of custom paths to their applicable // filesystem locations func (i *Config) GetCustomServedFolders() utils.URLMap { diff --git a/pkg/ffmpeg/stream_segmented.go b/pkg/ffmpeg/stream_segmented.go index f35b960ab..b1d545d20 100644 --- a/pkg/ffmpeg/stream_segmented.go +++ b/pkg/ffmpeg/stream_segmented.go @@ -20,6 +20,7 @@ import ( "github.com/stashapp/stash/pkg/fsutil" "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/signedurl" "github.com/stashapp/stash/pkg/utils" "github.com/zencoder/go-dash/v3/mpd" @@ -433,26 +434,36 @@ func serveHLSManifest(sm *StreamManager, w http.ResponseWriter, r *http.Request, baseURL := prefix + baseUrl.String() urlQuery := url.Values{} + + // Forward auth params to segment URLs. API key takes precedence + // over signed params since it is explicitly configured by the user. + // TODO - this needs to be handled outside of this package apikey := r.URL.Query().Get(apiKeyParamKey) + if apikey != "" { + urlQuery.Set(apiKeyParamKey, apikey) + } else { + cid := r.URL.Query().Get(signedurl.CIDParam) + expires := r.URL.Query().Get(signedurl.ExpiresParam) + sig := r.URL.Query().Get(signedurl.SigParam) + if cid != "" && expires != "" && sig != "" { + urlQuery.Set(signedurl.CIDParam, cid) + urlQuery.Set(signedurl.ExpiresParam, expires) + urlQuery.Set(signedurl.SigParam, sig) + } + } if resolution != "" { urlQuery.Set(resolutionParamKey, resolution) } - // TODO - this needs to be handled outside of this package - if apikey != "" { - urlQuery.Set(apiKeyParamKey, apikey) - } - - urlQueryString := "" + segQuery := "" if len(urlQuery) > 0 { - urlQueryString = "?" + urlQuery.Encode() + segQuery = "?" + urlQuery.Encode() } var buf bytes.Buffer fmt.Fprint(&buf, "#EXTM3U\n") - fmt.Fprint(&buf, "#EXT-X-VERSION:3\n") fmt.Fprint(&buf, "#EXT-X-MEDIA-SEQUENCE:0\n") fmt.Fprintf(&buf, "#EXT-X-TARGETDURATION:%d\n", segmentLength) @@ -468,7 +479,7 @@ func serveHLSManifest(sm *StreamManager, w http.ResponseWriter, r *http.Request, } fmt.Fprintf(&buf, "#EXTINF:%f,\n", thisLength) - fmt.Fprintf(&buf, "%s/%d.ts%s\n", baseURL, segment, urlQueryString) + fmt.Fprintf(&buf, "%s/%d.ts%s\n", baseURL, segment, segQuery) leftover -= thisLength segment++ @@ -529,10 +540,21 @@ func serveDASHManifest(sm *StreamManager, w http.ResponseWriter, r *http.Request urlQuery := url.Values{} + // Forward auth params to segment URLs. API key takes precedence + // over signed params since it is explicitly configured by the user. // TODO - this needs to be handled outside of this package apikey := r.URL.Query().Get(apiKeyParamKey) if apikey != "" { urlQuery.Set(apiKeyParamKey, apikey) + } else { + cid := r.URL.Query().Get(signedurl.CIDParam) + expires := r.URL.Query().Get(signedurl.ExpiresParam) + sig := r.URL.Query().Get(signedurl.SigParam) + if cid != "" && expires != "" && sig != "" { + urlQuery.Set(signedurl.CIDParam, cid) + urlQuery.Set(signedurl.ExpiresParam, expires) + urlQuery.Set(signedurl.SigParam, sig) + } } maxTranscodeSize := sm.config.GetMaxStreamingTranscodeSize().GetMaxResolution() diff --git a/pkg/signedurl/signedurl.go b/pkg/signedurl/signedurl.go new file mode 100644 index 000000000..5e8e1ae17 --- /dev/null +++ b/pkg/signedurl/signedurl.go @@ -0,0 +1,105 @@ +// Package signedurl provides HMAC-signed URLs for media requests from devices that cannot pass cookies (AirPlay, Chromecast). +package signedurl + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "errors" + "net/url" + "strconv" + "strings" + "time" +) + +const ( + CIDParam = "cid" + ExpiresParam = "expires" + SigParam = "signature" +) + +var ( + ErrMissingParams = errors.New("missing required signed URL parameters") + ErrExpiredURL = errors.New("signed URL has expired") + ErrInvalidSignature = errors.New("invalid signature") + ErrInvalidURL = errors.New("invalid URL") +) + +// GenerateCredentialID produces an opaque, deterministic identifier for a user. +// It is an HMAC-SHA256 of the username, truncated to 16 hex characters. +func GenerateCredentialID(secret []byte, username string) string { + h := hmac.New(sha256.New, secret) + h.Write([]byte(username)) + return hex.EncodeToString(h.Sum(nil))[:16] +} + +// DerivePrefix extracts the signing prefix from a request path by taking the +// first 3 segments and stripping the file extension from the 3rd. +func DerivePrefix(path string) string { + parts := strings.Split(strings.Trim(path, "/"), "/") + if len(parts) < 3 { + return "/" + strings.Join(parts, "/") + } + action := parts[2] + if dotIdx := strings.IndexByte(action, '.'); dotIdx >= 0 { + action = action[:dotIdx] + } + return "/" + parts[0] + "/" + parts[1] + "/" + action +} + +// makeSignString constructs the canonical string to sign. +func makeSignString(prefix string, cid string, expires time.Time) string { + return prefix + "?" + CIDParam + "=" + cid + "&" + ExpiresParam + "=" + strconv.FormatInt(expires.Unix(), 10) +} + +// SignPrefix signs a path prefix and returns url.Values containing +// the cid, expires, and signature parameters. The caller appends +// these to any URL whose path falls under the signed prefix. +func SignPrefix(prefix string, secret []byte, cid string, expires time.Time) url.Values { + signString := makeSignString(prefix, cid, expires) + + h := hmac.New(sha256.New, secret) + h.Write([]byte(signString)) + signature := hex.EncodeToString(h.Sum(nil)) + + params := make(url.Values) + params.Set(CIDParam, cid) + params.Set(ExpiresParam, strconv.FormatInt(expires.Unix(), 10)) + params.Set(SigParam, signature) + return params +} + +// VerifyURL verifies a signed URL request. It derives the signing prefix +// from the request path, checks expiry, and validates the HMAC signature. +// Returns the credential ID on success. +func VerifyURL(requestPath string, queryParams url.Values, secret []byte) (string, error) { + cid := queryParams.Get(CIDParam) + expiresStr := queryParams.Get(ExpiresParam) + sig := queryParams.Get(SigParam) + + if cid == "" || expiresStr == "" || sig == "" { + return "", ErrMissingParams + } + + expires, err := strconv.ParseInt(expiresStr, 10, 64) + if err != nil { + return "", ErrInvalidURL + } + + if time.Now().Unix() > expires { + return "", ErrExpiredURL + } + + prefix := DerivePrefix(requestPath) + signString := makeSignString(prefix, cid, time.Unix(expires, 0)) + + h := hmac.New(sha256.New, secret) + h.Write([]byte(signString)) + expectedSig := hex.EncodeToString(h.Sum(nil)) + + if !hmac.Equal([]byte(sig), []byte(expectedSig)) { + return "", ErrInvalidSignature + } + + return cid, nil +} diff --git a/pkg/signedurl/signedurl_test.go b/pkg/signedurl/signedurl_test.go new file mode 100644 index 000000000..f21125067 --- /dev/null +++ b/pkg/signedurl/signedurl_test.go @@ -0,0 +1,271 @@ +package signedurl + +import ( + "errors" + "net/url" + "strconv" + "testing" + "time" +) + +func TestDerivePrefix(t *testing.T) { + tests := []struct { + path string + expected string + }{ + // Scene stream variants + {"/scene/1/stream", "/scene/1/stream"}, + {"/scene/1/stream.mp4", "/scene/1/stream"}, + {"/scene/1/stream.webm", "/scene/1/stream"}, + {"/scene/1/stream.mkv", "/scene/1/stream"}, + {"/scene/1/stream.m3u8", "/scene/1/stream"}, + {"/scene/1/stream.mpd", "/scene/1/stream"}, + + // HLS segments + {"/scene/1/stream.m3u8/0.ts", "/scene/1/stream"}, + {"/scene/1/stream.m3u8/99.ts", "/scene/1/stream"}, + + // DASH segments + {"/scene/1/stream.mpd/5_v.webm", "/scene/1/stream"}, + {"/scene/1/stream.mpd/5_a.webm", "/scene/1/stream"}, + {"/scene/1/stream.mpd/init_v.webm", "/scene/1/stream"}, + {"/scene/1/stream.mpd/init_a.webm", "/scene/1/stream"}, + + // Caption + {"/scene/1/caption", "/scene/1/caption"}, + + // Image paths + {"/image/5/thumbnail", "/image/5/thumbnail"}, + {"/image/5/image", "/image/5/image"}, + + // Gallery paths + {"/gallery/3/cover", "/gallery/3/cover"}, + {"/gallery/3/preview", "/gallery/3/preview"}, + + // Short paths + {"/scene/1", "/scene/1"}, + {"/scene", "/scene"}, + } + + for _, tt := range tests { + t.Run(tt.path, func(t *testing.T) { + got := DerivePrefix(tt.path) + if got != tt.expected { + t.Errorf("DerivePrefix(%q) = %q, want %q", tt.path, got, tt.expected) + } + }) + } +} + +func TestGenerateCredentialID(t *testing.T) { + secret := []byte("test-secret-key") + + t.Run("deterministic", func(t *testing.T) { + cid1 := GenerateCredentialID(secret, "alice") + cid2 := GenerateCredentialID(secret, "alice") + if cid1 != cid2 { + t.Errorf("expected deterministic output, got %q and %q", cid1, cid2) + } + }) + + t.Run("different usernames produce different cids", func(t *testing.T) { + cid1 := GenerateCredentialID(secret, "alice") + cid2 := GenerateCredentialID(secret, "bob") + if cid1 == cid2 { + t.Errorf("expected different cids for different usernames, both got %q", cid1) + } + }) + + t.Run("different secrets produce different cids", func(t *testing.T) { + cid1 := GenerateCredentialID([]byte("secret-1"), "alice") + cid2 := GenerateCredentialID([]byte("secret-2"), "alice") + if cid1 == cid2 { + t.Errorf("expected different cids for different secrets, both got %q", cid1) + } + }) + + t.Run("length is 16 hex chars", func(t *testing.T) { + cid := GenerateCredentialID(secret, "alice") + if len(cid) != 16 { + t.Errorf("expected length 16, got %d (%q)", len(cid), cid) + } + }) +} + +func TestSignAndVerifyRoundtrip(t *testing.T) { + secret := []byte("test-secret-key") + cid := GenerateCredentialID(secret, "alice") + expires := time.Now().Add(1 * time.Hour) + + params := SignPrefix("/scene/1/stream", secret, cid, expires) + + gotCID, err := VerifyURL("/scene/1/stream", params, secret) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if gotCID != cid { + t.Errorf("expected cid %q, got %q", cid, gotCID) + } +} + +func TestDifferentPathSamePrefixVerifies(t *testing.T) { + secret := []byte("test-secret-key") + cid := GenerateCredentialID(secret, "alice") + expires := time.Now().Add(1 * time.Hour) + + // Sign for the stream prefix + params := SignPrefix("/scene/1/stream", secret, cid, expires) + + // Verify with different paths that share the same prefix + paths := []string{ + "/scene/1/stream", + "/scene/1/stream.mp4", + "/scene/1/stream.m3u8", + "/scene/1/stream.m3u8/0.ts", + "/scene/1/stream.mpd", + "/scene/1/stream.mpd/5_v.webm", + "/scene/1/stream.mpd/init_a.webm", + } + + for _, path := range paths { + t.Run(path, func(t *testing.T) { + gotCID, err := VerifyURL(path, params, secret) + if err != nil { + t.Fatalf("unexpected error for path %q: %v", path, err) + } + if gotCID != cid { + t.Errorf("expected cid %q, got %q", cid, gotCID) + } + }) + } +} + +func TestDifferentPrefixFails(t *testing.T) { + secret := []byte("test-secret-key") + cid := GenerateCredentialID(secret, "alice") + expires := time.Now().Add(1 * time.Hour) + + params := SignPrefix("/scene/1/stream", secret, cid, expires) + + // Different scene ID + _, err := VerifyURL("/scene/2/stream", params, secret) + if !errors.Is(err, ErrInvalidSignature) { + t.Errorf("expected ErrInvalidSignature, got %v", err) + } + + // Different resource type + _, err = VerifyURL("/scene/1/caption", params, secret) + if !errors.Is(err, ErrInvalidSignature) { + t.Errorf("expected ErrInvalidSignature, got %v", err) + } + + // Different entity type + _, err = VerifyURL("/image/1/stream", params, secret) + if !errors.Is(err, ErrInvalidSignature) { + t.Errorf("expected ErrInvalidSignature, got %v", err) + } +} + +func TestExpiredURLFails(t *testing.T) { + secret := []byte("test-secret-key") + cid := GenerateCredentialID(secret, "alice") + expires := time.Now().Add(-1 * time.Hour) // expired 1 hour ago + + params := SignPrefix("/scene/1/stream", secret, cid, expires) + + _, err := VerifyURL("/scene/1/stream", params, secret) + if !errors.Is(err, ErrExpiredURL) { + t.Errorf("expected ErrExpiredURL, got %v", err) + } +} + +func TestTamperedSignatureFails(t *testing.T) { + secret := []byte("test-secret-key") + cid := GenerateCredentialID(secret, "alice") + expires := time.Now().Add(1 * time.Hour) + + params := SignPrefix("/scene/1/stream", secret, cid, expires) + params.Set(SigParam, "tampered0000000000000000000000000000000000000000000000000000000") + + _, err := VerifyURL("/scene/1/stream", params, secret) + if !errors.Is(err, ErrInvalidSignature) { + t.Errorf("expected ErrInvalidSignature, got %v", err) + } +} + +func TestTamperedCIDFails(t *testing.T) { + secret := []byte("test-secret-key") + cid := GenerateCredentialID(secret, "alice") + expires := time.Now().Add(1 * time.Hour) + + params := SignPrefix("/scene/1/stream", secret, cid, expires) + params.Set(CIDParam, "tamperedcid12345") + + _, err := VerifyURL("/scene/1/stream", params, secret) + if !errors.Is(err, ErrInvalidSignature) { + t.Errorf("expected ErrInvalidSignature, got %v", err) + } +} + +func TestMissingParamsFails(t *testing.T) { + secret := []byte("test-secret-key") + cid := GenerateCredentialID(secret, "alice") + expires := time.Now().Add(1 * time.Hour) + + full := SignPrefix("/scene/1/stream", secret, cid, expires) + + tests := []struct { + name string + missing string + }{ + {"missing cid", CIDParam}, + {"missing expires", ExpiresParam}, + {"missing signature", SigParam}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + params := make(url.Values) + for k, v := range full { + params[k] = v + } + params.Del(tt.missing) + + _, err := VerifyURL("/scene/1/stream", params, secret) + if !errors.Is(err, ErrMissingParams) { + t.Errorf("expected ErrMissingParams, got %v", err) + } + }) + } +} + +func TestTamperedExpiresFails(t *testing.T) { + secret := []byte("test-secret-key") + cid := GenerateCredentialID(secret, "alice") + expires := time.Now().Add(1 * time.Hour) + + params := SignPrefix("/scene/1/stream", secret, cid, expires) + + // Attacker extends the expiry by 24 hours + tampered := time.Now().Add(25 * time.Hour) + params.Set(ExpiresParam, strconv.FormatInt(tampered.Unix(), 10)) + + _, err := VerifyURL("/scene/1/stream", params, secret) + if !errors.Is(err, ErrInvalidSignature) { + t.Errorf("expected ErrInvalidSignature, got %v", err) + } +} + +func TestWrongSecretFails(t *testing.T) { + secret := []byte("test-secret-key") + wrongSecret := []byte("wrong-secret-key") + cid := GenerateCredentialID(secret, "alice") + expires := time.Now().Add(1 * time.Hour) + + params := SignPrefix("/scene/1/stream", secret, cid, expires) + + _, err := VerifyURL("/scene/1/stream", params, wrongSecret) + if !errors.Is(err, ErrInvalidSignature) { + t.Errorf("expected ErrInvalidSignature, got %v", err) + } +}