From 4aef4fea4e460e735fecfc5973fb111d79e2a4f0 Mon Sep 17 00:00:00 2001 From: Mickael Date: Fri, 5 Sep 2025 13:41:17 +1000 Subject: [PATCH] fix (plg_handler_mcp): maintain mcp - #875 --- server/plugin/plg_handler_mcp/handler.go | 5 +- server/plugin/plg_handler_mcp/handler_auth.go | 93 +++++++++++++++---- server/plugin/plg_handler_mcp/index.go | 21 ++--- 3 files changed, 86 insertions(+), 33 deletions(-) diff --git a/server/plugin/plg_handler_mcp/handler.go b/server/plugin/plg_handler_mcp/handler.go index 6e54cbbb..c1fc06b6 100644 --- a/server/plugin/plg_handler_mcp/handler.go +++ b/server/plugin/plg_handler_mcp/handler.go @@ -16,14 +16,13 @@ import ( "github.com/google/uuid" ) -func (this *Server) messageHandler(w http.ResponseWriter, r *http.Request) { +func (this *Server) messageHandler(_ *App, w http.ResponseWriter, r *http.Request) { sessionID := r.URL.Query().Get("sessionId") if r.Method != http.MethodPost { w.WriteHeader(http.StatusBadRequest) w.Write([]byte("Invalid Request")) return } - request := JSONRPCRequest{} if err := json.NewDecoder(r.Body).Decode(&request); err != nil { w.WriteHeader(http.StatusBadRequest) @@ -34,7 +33,7 @@ func (this *Server) messageHandler(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNoContent) } -func (this *Server) sseHandler(w http.ResponseWriter, r *http.Request) { +func (this *Server) sseHandler(_ *App, w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { w.WriteHeader(http.StatusBadRequest) return diff --git a/server/plugin/plg_handler_mcp/handler_auth.go b/server/plugin/plg_handler_mcp/handler_auth.go index f7970abf..3e187e6b 100644 --- a/server/plugin/plg_handler_mcp/handler_auth.go +++ b/server/plugin/plg_handler_mcp/handler_auth.go @@ -4,6 +4,8 @@ import ( "encoding/json" "fmt" "net/http" + "regexp" + "strings" "time" . "github.com/mickael-kerjean/filestash/server/common" @@ -11,10 +13,23 @@ import ( ) const ( - DEFAULT_TOKEN_EXPIRY = 3600 + DEFAULT_TOKEN_EXPIRY = 3600 + DEFAULT_SECRET_EXPIRY = 30 * 24 * 3600 ) -func (this Server) WellKnownInfoHandler(w http.ResponseWriter, r *http.Request) { +var ( + KEY_FOR_CLIENT_SECRET string + KEY_FOR_CODE string +) + +func init() { + Hooks.Register.Onload(func() { + KEY_FOR_CLIENT_SECRET = Hash("MCP_SECRET_"+SECRET_KEY, len(SECRET_KEY)) + KEY_FOR_CODE = Hash("MCP_CODE_"+SECRET_KEY, len(SECRET_KEY)) + }) +} + +func (this Server) WellKnownInfoHandler(_ *App, w http.ResponseWriter, r *http.Request) { WithCors(w) if r.Method != http.MethodGet && r.Method != http.MethodOptions { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) @@ -23,7 +38,7 @@ func (this Server) WellKnownInfoHandler(w http.ResponseWriter, r *http.Request) scheme := "https" host := r.Host - if host == "localhost" || host == "127.0.0.1" { + if strings.HasPrefix(host, "localhost") || strings.HasPrefix(host, "127.0.0.1") { scheme = "http" } baseURL := fmt.Sprintf("%s://%s", scheme, host) @@ -44,7 +59,7 @@ func (this Server) WellKnownInfoHandler(w http.ResponseWriter, r *http.Request) }) } -func (this Server) AuthorizeHandler(w http.ResponseWriter, r *http.Request) { +func (this Server) AuthorizeHandler(_ *App, w http.ResponseWriter, r *http.Request) { WithCors(w) responseType := r.URL.Query().Get("response_type") @@ -62,11 +77,13 @@ func (this Server) AuthorizeHandler(w http.ResponseWriter, r *http.Request) { http.Error(w, "redirect_uri is required", http.StatusBadRequest) return } - - http.Redirect(w, r, fmt.Sprintf("/login?next=/api/mcp?redirect_uri=%s%%26state=%s", redirectURI, state), http.StatusSeeOther) + http.Redirect(w, r, fmt.Sprintf( + "/login?next=/api/mcp?redirect_uri=%s%%26state=%s%%26client_id=%s", + redirectURI, state, clientID, + ), http.StatusSeeOther) } -func (this Server) TokenHandler(w http.ResponseWriter, r *http.Request) { +func (this Server) TokenHandler(_ *App, w http.ResponseWriter, r *http.Request) { WithCors(w) if r.Method != http.MethodPost && r.Method != http.MethodOptions { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) @@ -80,39 +97,77 @@ func (this Server) TokenHandler(w http.ResponseWriter, r *http.Request) { http.Error(w, "Invalid Grant Type", http.StatusBadRequest) return } + clientID := r.FormValue("client_id") + if r.FormValue("client_secret") != clientSecret(clientID) { + http.Error(w, "Invalid Client Credentials", http.StatusUnauthorized) + return + } + token, err := DecryptString(Hash(KEY_FOR_CODE+clientID, len(SECRET_KEY)), r.FormValue("code")) + if err != nil { + http.Error(w, "Invalid authorization code", http.StatusBadRequest) + return + } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(map[string]interface{}{ - "access_token": r.FormValue("code"), + "access_token": token, "token_type": "Bearer", }) } -func (this Server) RegisterHandler(w http.ResponseWriter, r *http.Request) { +func (this Server) RegisterHandler(ctx *App, w http.ResponseWriter, r *http.Request) { WithCors(w) if r.Method != http.MethodPost && r.Method != http.MethodOptions { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } + clientName := regexp.MustCompile("[^a-zA-Z0-9\\-]+").ReplaceAllString( + fmt.Sprintf("%s", ctx.Body["client_name"]), + "", + ) + clientID := clientName + "." + Hash(clientName+time.Now().String(), 8) w.WriteHeader(http.StatusCreated) w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]interface{}{ - "client_id": "anonymous", - "client_secret": "anonymous", - "client_id_issued_at": time.Now().Unix(), - "client_secret_expires_at": 0, - "client_name": "Untrusted", - "redirect_uris": []string{}, - "grant_types": []string{"authorization_code"}, - "token_endpoint_auth_method": "client_secret_basic", + json.NewEncoder(w).Encode(struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + ClientIDIssuedAt int64 `json:"client_id_issued_at"` + ClientSecretExpiresAt int64 `json:"client_secret_expires_at"` + ClientName string `json:"client_name"` + RedirectURIs []string `json:"redirect_uris"` + GrantTypes []string `json:"grant_types"` + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"` + }{ + ClientID: clientID, + ClientSecret: clientSecret(clientID), + ClientIDIssuedAt: time.Now().Unix(), + ClientSecretExpiresAt: time.Now().Unix() + DEFAULT_SECRET_EXPIRY, + ClientName: clientName, + RedirectURIs: []string{}, + GrantTypes: []string{"authorization_code"}, + TokenEndpointAuthMethod: "client_secret_basic", }) } +func clientSecret(clientID string) string { + return Hash(clientID+KEY_FOR_CLIENT_SECRET, 32) +} + func (this Server) CallbackHandler(ctx *App, res http.ResponseWriter, req *http.Request) { uri := req.URL.Query().Get("redirect_uri") state := req.URL.Query().Get("state") + clientID := req.URL.Query().Get("client_id") if uri == "" { SendErrorResult(res, ErrNotValid) return } - http.Redirect(res, req, fmt.Sprintf(uri+"?code=%s&state=%s", ctx.Authorization, state), http.StatusSeeOther) + code, err := EncryptString(Hash(KEY_FOR_CODE+clientID, len(SECRET_KEY)), ctx.Authorization) + if err != nil { + SendErrorResult(res, ErrNotValid) + return + } + uri += "?code=" + code + if state != "" { + uri += "&state=" + state + } + http.Redirect(res, req, uri, http.StatusSeeOther) } diff --git a/server/plugin/plg_handler_mcp/index.go b/server/plugin/plg_handler_mcp/index.go index 73075223..818b4318 100644 --- a/server/plugin/plg_handler_mcp/index.go +++ b/server/plugin/plg_handler_mcp/index.go @@ -26,17 +26,16 @@ func init() { return nil } srv := Server{} - r.HandleFunc("/sse", srv.sseHandler) - r.HandleFunc("/messages", srv.messageHandler) - r.HandleFunc("/.well-known/oauth-authorization-server", srv.WellKnownInfoHandler) - r.HandleFunc("/mcp/authorize", srv.AuthorizeHandler) - r.HandleFunc("/mcp/token", srv.TokenHandler) - r.HandleFunc("/mcp/register", srv.RegisterHandler) - r.HandleFunc("/api/mcp", NewMiddlewareChain( - srv.CallbackHandler, - []Middleware{SessionStart, LoggedInOnly}, - *app, - )) + m := []Middleware{} + r.HandleFunc("/sse", NewMiddlewareChain(srv.sseHandler, m, *app)) + r.HandleFunc("/messages", NewMiddlewareChain(srv.messageHandler, m, *app)) + r.HandleFunc("/.well-known/oauth-authorization-server", NewMiddlewareChain(srv.WellKnownInfoHandler, m, *app)) + r.HandleFunc("/mcp/authorize", NewMiddlewareChain(srv.AuthorizeHandler, m, *app)) + r.HandleFunc("/mcp/token", NewMiddlewareChain(srv.TokenHandler, m, *app)) + m = []Middleware{BodyParser} + r.HandleFunc("/mcp/register", NewMiddlewareChain(srv.RegisterHandler, m, *app)) + m = []Middleware{SessionStart, LoggedInOnly} + r.HandleFunc("/api/mcp", NewMiddlewareChain(srv.CallbackHandler, m, *app)) return nil }) }