diff --git a/server/plugin/plg_handler_mcp/handler.go b/server/plugin/plg_handler_mcp/handler.go index 00f51a87..62387fa8 100644 --- a/server/plugin/plg_handler_mcp/handler.go +++ b/server/plugin/plg_handler_mcp/handler.go @@ -36,14 +36,11 @@ func (this *Server) messageHandler(_ *App, w http.ResponseWriter, r *http.Reques } func (this *Server) sseHandler(_ *App, w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - w.WriteHeader(http.StatusBadRequest) - return - } token := ExtractToken(r) if token == "" { Log.Debug("plg_handler_mcp::sse msg=invalid_token") w.Header().Add("Content-Type", "application/json") + w.Header().Add("WWW-Authenticate", "Bearer resource_metadata=\""+this.baseURL(r)+"/.well-known/oauth-protected-resource\"") w.WriteHeader(http.StatusUnauthorized) json.NewEncoder(w).Encode(JSONRPCResponse{ JSONRPC: "2.0", diff --git a/server/plugin/plg_handler_mcp/handler_auth.go b/server/plugin/plg_handler_mcp/handler_auth.go index 1f9aba3d..f5ffc1d1 100644 --- a/server/plugin/plg_handler_mcp/handler_auth.go +++ b/server/plugin/plg_handler_mcp/handler_auth.go @@ -9,7 +9,6 @@ import ( "time" . "github.com/mickael-kerjean/filestash/server/common" - . "github.com/mickael-kerjean/filestash/server/plugin/plg_handler_mcp/utils" ) const ( @@ -25,19 +24,8 @@ func init() { }) } -func (this Server) WellKnownInfoHandler(_ *App, w http.ResponseWriter, r *http.Request) { - WithCors(w) - if r.Method != http.MethodGet && r.Method != http.MethodOptions { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - return - } - - scheme := "https" - host := r.Host - if strings.HasPrefix(host, "localhost") || strings.HasPrefix(host, "127.0.0.1") { - scheme = "http" - } - baseURL := fmt.Sprintf("%s://%s", scheme, host) +func (this Server) WellKnownOAuthAuthorizationServerHandler(_ *App, w http.ResponseWriter, r *http.Request) { + baseURL := this.baseURL(r) w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(map[string]interface{}{ "issuer": baseURL, @@ -55,9 +43,27 @@ func (this Server) WellKnownInfoHandler(_ *App, w http.ResponseWriter, r *http.R }) } -func (this Server) AuthorizeHandler(_ *App, w http.ResponseWriter, r *http.Request) { - WithCors(w) +func (this Server) WellKnownOAuthProtectedResourceHandler(_ *App, w http.ResponseWriter, r *http.Request) { + baseURL := this.baseURL(r) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "resource": baseURL, + "authorization_servers": []string{baseURL}, + "bearer_methods_supported": []string{"header"}, + "scopes_supported": []string{"openid"}, + }) +} +func (this Server) baseURL(r *http.Request) string { + scheme := "https" + host := r.Host + if strings.HasPrefix(host, "localhost") || strings.HasPrefix(host, "127.0.0.1") { + scheme = "http" + } + return fmt.Sprintf("%s://%s", scheme, host) +} + +func (this Server) AuthorizeHandler(_ *App, w http.ResponseWriter, r *http.Request) { responseType := r.URL.Query().Get("response_type") clientID := r.URL.Query().Get("client_id") redirectURI := r.URL.Query().Get("redirect_uri") @@ -80,11 +86,6 @@ func (this Server) AuthorizeHandler(_ *App, w http.ResponseWriter, r *http.Reque } func (this Server) TokenHandler(_ *App, w http.ResponseWriter, r *http.Request) { - WithCors(w) - if r.Method != http.MethodPost && r.Method != http.MethodOptions { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - return - } if err := r.ParseForm(); err != nil { http.Error(w, "Invalid request", http.StatusBadRequest) return @@ -106,11 +107,6 @@ func (this Server) TokenHandler(_ *App, w http.ResponseWriter, r *http.Request) } func (this Server) RegisterHandler(ctx *App, w http.ResponseWriter, r *http.Request) { - WithCors(w) - if r.Method != http.MethodPost && r.Method != http.MethodOptions { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - return - } clientName := regexp.MustCompile("[^a-zA-Z0-9\\-]+").ReplaceAllString( fmt.Sprintf("%s", ctx.Body["client_name"]), "", diff --git a/server/plugin/plg_handler_mcp/index.go b/server/plugin/plg_handler_mcp/index.go index 08a00ab5..56956f7d 100644 --- a/server/plugin/plg_handler_mcp/index.go +++ b/server/plugin/plg_handler_mcp/index.go @@ -6,6 +6,7 @@ import ( . "github.com/mickael-kerjean/filestash/server/common" . "github.com/mickael-kerjean/filestash/server/middleware" . "github.com/mickael-kerjean/filestash/server/plugin/plg_handler_mcp/config" + . "github.com/mickael-kerjean/filestash/server/plugin/plg_handler_mcp/utils" "github.com/gorilla/mux" ) @@ -25,16 +26,20 @@ func init() { return nil } srv := Server{} - m := []Middleware{} - r.HandleFunc("/sse", NewMiddlewareChain(srv.sseHandler, m, *app)) - r.HandleFunc("/messages", NewMiddlewareChain(srv.messageHandler, m, *app)) - r.HandleFunc("/.well-known/oauth-authorization-server", NewMiddlewareChain(srv.WellKnownInfoHandler, m, *app)) - r.HandleFunc("/mcp/authorize", NewMiddlewareChain(srv.AuthorizeHandler, m, *app)) - r.HandleFunc("/mcp/token", NewMiddlewareChain(srv.TokenHandler, m, *app)) + m := []Middleware{WithCORS} + r.HandleFunc("/sse", NewMiddlewareChain(srv.sseHandler, m, *app)).Methods("GET", "OPTIONS") + r.HandleFunc("/messages", NewMiddlewareChain(srv.messageHandler, m, *app)).Methods("POST", "OPTIONS") + r.HandleFunc("/.well-known/oauth-authorization-server", NewMiddlewareChain(srv.WellKnownOAuthAuthorizationServerHandler, m, *app)).Methods("GET", "OPTIONS") + r.HandleFunc("/.well-known/oauth-protected-resource", NewMiddlewareChain(srv.WellKnownOAuthProtectedResourceHandler, m, *app)).Methods("GET", "OPTIONS") + r.HandleFunc("/.well-known/oauth-protected-resource/sse", NewMiddlewareChain(srv.WellKnownOAuthProtectedResourceHandler, m, *app)).Methods("GET", "OPTIONS") + + r.HandleFunc("/mcp/token", NewMiddlewareChain(srv.TokenHandler, m, *app)).Methods("POST") + m = []Middleware{} + r.HandleFunc("/mcp/authorize", NewMiddlewareChain(srv.AuthorizeHandler, m, *app)).Methods("GET") m = []Middleware{BodyParser} - r.HandleFunc("/mcp/register", NewMiddlewareChain(srv.RegisterHandler, m, *app)) + r.HandleFunc("/mcp/register", NewMiddlewareChain(srv.RegisterHandler, m, *app)).Methods("POST") m = []Middleware{SessionStart, LoggedInOnly} - r.HandleFunc("/api/mcp", NewMiddlewareChain(srv.CallbackHandler, m, *app)) + r.HandleFunc("/api/mcp", NewMiddlewareChain(srv.CallbackHandler, m, *app)).Methods("GET") return nil }) } diff --git a/server/plugin/plg_handler_mcp/utils/cors.go b/server/plugin/plg_handler_mcp/utils/cors.go index ee02f213..b48f5b7e 100644 --- a/server/plugin/plg_handler_mcp/utils/cors.go +++ b/server/plugin/plg_handler_mcp/utils/cors.go @@ -2,9 +2,18 @@ package utils import ( "net/http" + + . "github.com/mickael-kerjean/filestash/server/common" ) -func WithCors(w http.ResponseWriter) { - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Headers", "mcp-protocol-version, Content-Type") +func WithCORS(fn HandlerFunc) HandlerFunc { + return HandlerFunc(func(ctx *App, w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Headers", "mcp-protocol-version, Content-Type, Authorization") + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusOK) + return + } + fn(ctx, w, r) + }) }