mirror of
https://github.com/mickael-kerjean/filestash
synced 2025-12-28 11:16:52 +01:00
fix (plg_handler_mcp): support for state url param
This commit is contained in:
parent
516a861974
commit
f8f26035fc
3 changed files with 81 additions and 33 deletions
|
|
@ -16,14 +16,13 @@ import (
|
|||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func (this *Server) messageHandler(w http.ResponseWriter, r *http.Request) {
|
||||
func (this *Server) messageHandler(_ *App, w http.ResponseWriter, r *http.Request) {
|
||||
sessionID := r.URL.Query().Get("sessionId")
|
||||
if r.Method != http.MethodPost {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("Invalid Request"))
|
||||
return
|
||||
}
|
||||
|
||||
request := JSONRPCRequest{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
|
|
@ -34,7 +33,7 @@ func (this *Server) messageHandler(w http.ResponseWriter, r *http.Request) {
|
|||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (this *Server) sseHandler(w http.ResponseWriter, r *http.Request) {
|
||||
func (this *Server) sseHandler(_ *App, w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import (
|
|||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
. "github.com/mickael-kerjean/filestash/server/common"
|
||||
|
|
@ -11,10 +12,23 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
DEFAULT_TOKEN_EXPIRY = 3600
|
||||
DEFAULT_TOKEN_EXPIRY = 3600
|
||||
DEFAULT_SECRET_EXPIRY = 30 * 24 * 3600
|
||||
)
|
||||
|
||||
func (this Server) WellKnownInfoHandler(w http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
KEY_FOR_CLIENT_SECRET string
|
||||
KEY_FOR_CODE string
|
||||
)
|
||||
|
||||
func init() {
|
||||
Hooks.Register.Onload(func() {
|
||||
KEY_FOR_CLIENT_SECRET = Hash("MCP_SECRET_"+SECRET_KEY, len(SECRET_KEY))
|
||||
KEY_FOR_CODE = Hash("MCP_CODE_"+SECRET_KEY, len(SECRET_KEY))
|
||||
})
|
||||
}
|
||||
|
||||
func (this Server) WellKnownInfoHandler(_ *App, w http.ResponseWriter, r *http.Request) {
|
||||
WithCors(w)
|
||||
if r.Method != http.MethodGet && r.Method != http.MethodOptions {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
|
|
@ -23,7 +37,7 @@ func (this Server) WellKnownInfoHandler(w http.ResponseWriter, r *http.Request)
|
|||
|
||||
scheme := "https"
|
||||
host := r.Host
|
||||
if host == "localhost" || host == "127.0.0.1" {
|
||||
if strings.HasPrefix(host, "localhost") || strings.HasPrefix(host, "127.0.0.1") {
|
||||
scheme = "http"
|
||||
}
|
||||
baseURL := fmt.Sprintf("%s://%s", scheme, host)
|
||||
|
|
@ -44,12 +58,13 @@ func (this Server) WellKnownInfoHandler(w http.ResponseWriter, r *http.Request)
|
|||
})
|
||||
}
|
||||
|
||||
func (this Server) AuthorizeHandler(w http.ResponseWriter, r *http.Request) {
|
||||
func (this Server) AuthorizeHandler(_ *App, w http.ResponseWriter, r *http.Request) {
|
||||
WithCors(w)
|
||||
|
||||
responseType := r.URL.Query().Get("response_type")
|
||||
clientID := r.URL.Query().Get("client_id")
|
||||
redirectURI := r.URL.Query().Get("redirect_uri")
|
||||
state := r.URL.Query().Get("state")
|
||||
|
||||
if responseType != "code" {
|
||||
http.Error(w, "response_type must be 'code'", http.StatusBadRequest)
|
||||
|
|
@ -61,11 +76,13 @@ func (this Server) AuthorizeHandler(w http.ResponseWriter, r *http.Request) {
|
|||
http.Error(w, "redirect_uri is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
http.Redirect(w, r, fmt.Sprintf("/login?next=/api/mcp?redirect_uri=%s", redirectURI), http.StatusSeeOther)
|
||||
http.Redirect(w, r, fmt.Sprintf(
|
||||
"/login?next=/api/mcp?redirect_uri=%s%%26state=%s%%26client_id=%s",
|
||||
redirectURI, state, clientID,
|
||||
), http.StatusSeeOther)
|
||||
}
|
||||
|
||||
func (this Server) TokenHandler(w http.ResponseWriter, r *http.Request) {
|
||||
func (this Server) TokenHandler(_ *App, w http.ResponseWriter, r *http.Request) {
|
||||
WithCors(w)
|
||||
if r.Method != http.MethodPost && r.Method != http.MethodOptions {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
|
|
@ -79,38 +96,71 @@ func (this Server) TokenHandler(w http.ResponseWriter, r *http.Request) {
|
|||
http.Error(w, "Invalid Grant Type", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
clientID := r.FormValue("client_id")
|
||||
if r.FormValue("client_secret") != clientSecret(clientID) {
|
||||
http.Error(w, "Invalid client credentials", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
token, err := DecryptString(Hash(KEY_FOR_CODE+clientID, len(SECRET_KEY)), r.FormValue("code"))
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid authorization code", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"access_token": r.FormValue("code"),
|
||||
"access_token": token,
|
||||
"token_type": "Bearer",
|
||||
})
|
||||
}
|
||||
|
||||
func (this Server) RegisterHandler(w http.ResponseWriter, r *http.Request) {
|
||||
func (this Server) RegisterHandler(ctx *App, w http.ResponseWriter, r *http.Request) {
|
||||
WithCors(w)
|
||||
if r.Method != http.MethodPost && r.Method != http.MethodOptions {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
clientName := fmt.Sprintf("%s", ctx.Body["client_name"])
|
||||
clientID := clientName + "." + Hash(clientName+time.Now().String(), 8)
|
||||
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"client_id": "anonymous",
|
||||
"client_secret": "anonymous",
|
||||
"client_id_issued_at": time.Now().Unix(),
|
||||
"client_secret_expires_at": 0,
|
||||
"client_name": "Untrusted",
|
||||
"redirect_uris": []string{},
|
||||
"grant_types": []string{"authorization_code"},
|
||||
"token_endpoint_auth_method": "client_secret_basic",
|
||||
json.NewEncoder(w).Encode(struct {
|
||||
ClientID string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"`
|
||||
ClientIDIssuedAt int64 `json:"client_id_issued_at"`
|
||||
ClientSecretExpiresAt int64 `json:"client_secret_expires_at"`
|
||||
ClientName string `json:"client_name"`
|
||||
RedirectURIs []string `json:"redirect_uris"`
|
||||
GrantTypes []string `json:"grant_types"`
|
||||
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"`
|
||||
}{
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret(clientID),
|
||||
ClientIDIssuedAt: time.Now().Unix(),
|
||||
ClientSecretExpiresAt: time.Now().Unix() + DEFAULT_SECRET_EXPIRY,
|
||||
ClientName: clientName,
|
||||
RedirectURIs: []string{},
|
||||
GrantTypes: []string{"authorization_code"},
|
||||
TokenEndpointAuthMethod: "client_secret_basic",
|
||||
})
|
||||
}
|
||||
|
||||
func clientSecret(clientID string) string {
|
||||
return Hash(clientID+KEY_FOR_CLIENT_SECRET, 32)
|
||||
}
|
||||
|
||||
func (this Server) CallbackHandler(ctx *App, res http.ResponseWriter, req *http.Request) {
|
||||
uri := req.URL.Query().Get("redirect_uri")
|
||||
state := req.URL.Query().Get("state")
|
||||
clientID := req.URL.Query().Get("client_id")
|
||||
if uri == "" {
|
||||
SendErrorResult(res, ErrNotValid)
|
||||
return
|
||||
}
|
||||
http.Redirect(res, req, fmt.Sprintf(uri+"?code=%s", ctx.Authorization), http.StatusSeeOther)
|
||||
code, err := EncryptString(Hash(KEY_FOR_CODE+clientID, len(SECRET_KEY)), ctx.Authorization)
|
||||
if err != nil {
|
||||
SendErrorResult(res, ErrNotValid)
|
||||
return
|
||||
}
|
||||
http.Redirect(res, req, fmt.Sprintf(uri+"?code=%s&state=%s", code, state), http.StatusSeeOther)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -26,17 +26,16 @@ func init() {
|
|||
return nil
|
||||
}
|
||||
srv := Server{}
|
||||
r.HandleFunc("/sse", srv.sseHandler)
|
||||
r.HandleFunc("/messages", srv.messageHandler)
|
||||
r.HandleFunc("/.well-known/oauth-authorization-server", srv.WellKnownInfoHandler)
|
||||
r.HandleFunc("/mcp/authorize", srv.AuthorizeHandler)
|
||||
r.HandleFunc("/mcp/token", srv.TokenHandler)
|
||||
r.HandleFunc("/mcp/register", srv.RegisterHandler)
|
||||
r.HandleFunc("/api/mcp", NewMiddlewareChain(
|
||||
srv.CallbackHandler,
|
||||
[]Middleware{SessionStart, LoggedInOnly},
|
||||
*app,
|
||||
))
|
||||
m := []Middleware{}
|
||||
r.HandleFunc("/sse", NewMiddlewareChain(srv.sseHandler, m, *app))
|
||||
r.HandleFunc("/messages", NewMiddlewareChain(srv.messageHandler, m, *app))
|
||||
r.HandleFunc("/.well-known/oauth-authorization-server", NewMiddlewareChain(srv.WellKnownInfoHandler, m, *app))
|
||||
r.HandleFunc("/mcp/authorize", NewMiddlewareChain(srv.AuthorizeHandler, m, *app))
|
||||
r.HandleFunc("/mcp/token", NewMiddlewareChain(srv.TokenHandler, m, *app))
|
||||
m = []Middleware{BodyParser}
|
||||
r.HandleFunc("/mcp/register", NewMiddlewareChain(srv.RegisterHandler, m, *app))
|
||||
m = []Middleware{SessionStart, LoggedInOnly}
|
||||
r.HandleFunc("/api/mcp", NewMiddlewareChain(srv.CallbackHandler, m, *app))
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue