mirror of
https://github.com/mickael-kerjean/filestash
synced 2025-12-06 16:32:31 +01:00
62 lines
1.7 KiB
Go
62 lines
1.7 KiB
Go
package httperr
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
)
|
|
|
|
// onErrorIndexType is a package-private type used for the context key below.
// Because the type is unexported, a key of this type cannot compare equal to
// a context key created by any other package.
type onErrorIndexType int

const (
	// onErrorIndex is the context key under which Middleware.ServeHTTP stores
	// the error-reporting callback that ReportError looks up.
	onErrorIndex onErrorIndexType = 0
)
|
|
|
|
// Middleware is an http.Handler that sits in front of Handler and routes
// request failures to the OnError callback. It captures errors that are
// returned from a HandlerFunc or reported via ReportError. In addition, when
// the wrapped handler produces a status code >= 400 and no error has been
// reported for the request, the response is captured and handed to OnError
// as a Response.
type Middleware struct {
	// OnError is called when a request fails with an error. Returning nil
	// means the callback has fully handled the error. Returning a non-nil
	// error causes that error to be written to the client with Write().
	//
	// When OnError is nil the response writer is not wrapped and errors
	// reported through ReportError are ignored by this middleware.
	OnError func(w http.ResponseWriter, r *http.Request, err error) error

	// Handler is the next handler in the chain.
	Handler http.Handler
}
|
|
|
|
func (m Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
var unwrappedWriter = w
|
|
var wrappedWriter *basicWriter
|
|
if m.OnError != nil {
|
|
wrappedWriter, w = wrapWriter(w)
|
|
}
|
|
|
|
var didCallOnError bool
|
|
r = r.WithContext(context.WithValue(r.Context(), onErrorIndex, func(err error) {
|
|
if m.OnError != nil {
|
|
didCallOnError = true
|
|
handlerErr := m.OnError(unwrappedWriter, r, err)
|
|
if handlerErr != nil {
|
|
Write(unwrappedWriter, r, handlerErr)
|
|
}
|
|
}
|
|
}))
|
|
|
|
m.Handler.ServeHTTP(w, r)
|
|
|
|
if wrappedWriter != nil && wrappedWriter.statusCode >= 400 && !didCallOnError {
|
|
err := Response(*wrappedWriter.copy)
|
|
handlerErr := m.OnError(unwrappedWriter, r, err)
|
|
if handlerErr != nil {
|
|
Write(unwrappedWriter, r, handlerErr)
|
|
}
|
|
}
|
|
}
|
|
|
|
// ReportError reports the error to the function given in
|
|
// OnError.
|
|
func ReportError(r *http.Request, err error) {
|
|
if v := r.Context().Value(onErrorIndex); v != nil {
|
|
v.(func(error))(err)
|
|
}
|
|
}
|