filestash/vendor/github.com/crewjam/httperr/middleware.go
2022-01-25 23:26:55 +11:00

62 lines
1.7 KiB
Go

package httperr
import (
"context"
"net/http"
)
// onErrorIndexType is the unexported type of the request-context key under
// which ServeHTTP stores its error callback; because the type is private to
// this package, no other package can construct a colliding key.
type onErrorIndexType int

// onErrorIndex is the context key that ServeHTTP sets and ReportError reads.
const onErrorIndex = onErrorIndexType(0)
// Middleware wraps the provided handler with middleware that captures errors which
// are returned from HandlerFunc, or reported via ReportError, and invokes the provided
// callback to render them. If the handler responds with a status code >= 400, the
// response is captured and passed to OnError as a Response.
//
// Interception only happens when OnError is non-nil; with a nil OnError the request
// is passed to Handler with the original ResponseWriter and ReportError has no effect.
// If ReportError was called while handling a request, the status-code check that runs
// after Handler returns is skipped, so the same failure is not reported a second time.
type Middleware struct {
	// OnError is a function that is called when a request fails with an error. If this function
	// returns nil, then the error is assumed to be handled. If it returns a non-nil error, then
	// that error is written to the client with Write().
	OnError func(w http.ResponseWriter, r *http.Request, err error) error
	// Handler is the next handler; ServeHTTP calls it for every request, so it must be non-nil.
	Handler http.Handler
}
// ServeHTTP implements http.Handler.
//
// When OnError is set, the ResponseWriter handed to Handler is wrapped so the
// status code (and a copy of the response) can be inspected afterwards. A
// callback is stored in the request context under onErrorIndex so that
// ReportError can route errors to OnError. After Handler returns, a status
// code >= 400 is reported to OnError as a Response, unless ReportError was
// already used for this request. Whenever OnError itself returns an error,
// that error is written to the client with Write on the unwrapped writer.
func (m Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	rawWriter := w
	var recorder *basicWriter
	if m.OnError != nil {
		recorder, w = wrapWriter(w)
	}

	// dispatch hands err to OnError and falls back to Write if OnError fails.
	// It captures the variable r, so it sees the context-carrying request
	// assigned below.
	dispatch := func(err error) {
		if handlerErr := m.OnError(rawWriter, r, err); handlerErr != nil {
			Write(rawWriter, r, handlerErr)
		}
	}

	handled := false
	report := func(err error) {
		if m.OnError == nil {
			return
		}
		handled = true
		dispatch(err)
	}
	r = r.WithContext(context.WithValue(r.Context(), onErrorIndex, report))

	m.Handler.ServeHTTP(w, r)

	// Nothing more to do if nothing was recorded, the handler already
	// reported an error, or the response was not a failure.
	if recorder == nil || handled || recorder.statusCode < 400 {
		return
	}
	dispatch(Response(*recorder.copy))
}
// ReportError forwards err to the OnError callback of the Middleware serving r,
// so handlers that do not return an error can still have it rendered.
//
// It does nothing when r's context carries no callback (the request was not
// routed through a Middleware). It also does nothing when the serving
// Middleware's OnError is nil.
func ReportError(r *http.Request, err error) {
	// Comma-ok assertion: a value of any other type under onErrorIndex is
	// ignored instead of triggering a runtime panic.
	report, ok := r.Context().Value(onErrorIndex).(func(error))
	if !ok {
		return
	}
	report(err)
}