stash/vendor/github.com/asticode/go-astikit/http.go
2022-11-07 12:33:15 +11:00

632 lines
16 KiB
Go

package astikit
import (
	"bytes"
	"context"
	"crypto/sha256"
	"crypto/subtle"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"sync"
	"time"
)
var (
	// ErrHTTPSenderUnmarshaledError is returned by HTTPSender.SendJSON when the server
	// replied with a non-2xx status code and the response body was successfully
	// decoded into HTTPSendJSONOptions.BodyError.
	ErrHTTPSenderUnmarshaledError = errors.New("astikit: unmarshaled error")
)
// ServeHTTPOptions holds the settings used by ServeHTTP to build its http.Server.
type ServeHTTPOptions struct {
	// Addr is the address the server listens on, used as http.Server.Addr.
	Addr string
	// Handler serves incoming requests, used as http.Server.Handler.
	Handler http.Handler
}
// ServeHTTP spawns an HTTP server listening on o.Addr inside a new task of w.
//
// The server runs until either w's context is done or ListenAndServe returns. In both
// cases it is then shut down with http.Server.Shutdown. Errors are logged through
// w's logger; a plain context cancelation is not considered an error.
func ServeHTTP(w *Worker, o ServeHTTPOptions) {
	// Create server
	s := &http.Server{Addr: o.Addr, Handler: o.Handler}

	// Execute in a task
	w.NewTask().Do(func() {
		// Log
		w.Logger().Infof("astikit: serving on %s", o.Addr)

		// Serve
		// The channel is buffered so that the serving goroutine can always deliver its
		// result and exit. Otherwise, when the context branch below wins, Shutdown makes
		// ListenAndServe return, and the goroutine would block forever on the send.
		done := make(chan error, 1)
		go func() {
			done <- s.ListenAndServe()
		}()

		// Wait for context or done to be done
		select {
		case <-w.ctx.Done():
			if w.ctx.Err() != context.Canceled {
				w.Logger().Error(fmt.Errorf("astikit: context error: %w", w.ctx.Err()))
			}
		case err := <-done:
			if err != nil {
				w.Logger().Error(fmt.Errorf("astikit: serving failed: %w", err))
			}
		}

		// Shutdown
		w.Logger().Infof("astikit: shutting down server on %s", o.Addr)
		if err := s.Shutdown(context.Background()); err != nil {
			w.Logger().Error(fmt.Errorf("astikit: shutting down server on %s failed: %w", o.Addr, err))
		}
	})
}
// HTTPClient is the minimal client abstraction HTTPSender relies on to perform
// requests. *http.Client satisfies it.
type HTTPClient interface {
	// Do sends req and returns the response received for it.
	Do(req *http.Request) (*http.Response, error)
}
// HTTPSender sends HTTP requests through an HTTPClient, retrying failed attempts and
// optionally bounding each exchange with a timeout.
type HTTPSender struct {
	// client performs the actual requests
	client HTTPClient
	// l logs attempts and retries
	l SeverityLogger
	// retryFunc decides whether a received response must be retried
	retryFunc HTTPSenderRetryFunc
	// retryMax is the number of retries allowed after the first attempt
	retryMax int
	// retrySleep is the pause between two attempts
	retrySleep time.Duration
	// timeout is the timeout used by Send; no timeout is applied when it is <= 0
	timeout time.Duration
}
// HTTPSenderRetryFunc inspects a response received by HTTPSender. Returning a non-nil
// error marks the attempt as failed, so that it is retried while retries are left.
type HTTPSenderRetryFunc func(resp *http.Response) error
// HTTPSenderOptions holds the settings accepted by NewHTTPSender.
type HTTPSenderOptions struct {
	// Client performs the requests; a zero-value *http.Client is used when nil
	Client HTTPClient
	// Logger is adapted with AdaptStdLogger and used to log attempts
	Logger StdLogger
	// RetryFunc decides whether a response must be retried; when nil, responses with
	// a 5xx status code are retried
	RetryFunc HTTPSenderRetryFunc
	// RetryMax is the number of retries allowed after the first attempt
	RetryMax int
	// RetrySleep is the pause between two attempts
	RetrySleep time.Duration
	// Timeout is the timeout applied by Send; disabled when <= 0
	Timeout time.Duration
}
// NewHTTPSender creates a new HTTP sender from o.
//
// Defaults: a zero-value *http.Client when o.Client is nil, and a retry func retrying
// 5xx responses when o.RetryFunc is nil. A negative o.RetryMax is treated as 0.
func NewHTTPSender(o HTTPSenderOptions) (s *HTTPSender) {
	s = &HTTPSender{
		client:     o.Client,
		l:          AdaptStdLogger(o.Logger),
		retryFunc:  o.RetryFunc,
		retryMax:   o.RetryMax,
		retrySleep: o.RetrySleep,
		timeout:    o.Timeout,
	}
	if s.client == nil {
		s.client = &http.Client{}
	}
	if s.retryFunc == nil {
		s.retryFunc = s.defaultHTTPRetryFunc
	}
	// SendWithTimeout makes retryMax + 1 attempts: a negative value would prevent even
	// the first one and yield a meaningless error, so it means "no retries" instead
	if s.retryMax < 0 {
		s.retryMax = 0
	}
	return
}
// defaultHTTPRetryFunc is the HTTPSenderRetryFunc used when none is provided: it
// requests a retry whenever the server answered with a 5xx status code.
func (s *HTTPSender) defaultHTTPRetryFunc(resp *http.Response) error {
	code := resp.StatusCode
	if code < http.StatusInternalServerError {
		return nil
	}
	return fmt.Errorf("astikit: invalid status code %d", code)
}
// Send sends req using the sender's default timeout; see SendWithTimeout.
func (s *HTTPSender) Send(req *http.Request) (resp *http.Response, err error) {
	resp, err = s.SendWithTimeout(req, s.timeout)
	return
}
// SendWithTimeout sends req, making up to retryMax + 1 attempts.
//
// An attempt is retried when the client returns a temporary net.Error, or when
// retryFunc rejects the response. Any other client error, or an error of the request
// context, stops right away. When every attempt failed, the last response (if any) is
// returned along with the error.
//
// When timeout > 0 it bounds the whole exchange, including reading the response body.
// On success, the timeout context is therefore released only when the caller closes
// resp.Body, not when this function returns; returning first would abort the body read.
//
// Retries rewind the request body through req.GetBody when it is set, since the
// previous attempt consumed it.
func (s *HTTPSender) SendWithTimeout(req *http.Request, timeout time.Duration) (resp *http.Response, err error) {
	// Set name
	name := req.Method + " request"
	if req.URL != nil {
		name += " to " + req.URL.String()
	}

	// Timeout
	if timeout > 0 {
		// Create context
		ctx, cancel := context.WithTimeout(req.Context(), timeout)
		defer func() {
			// Canceling the context makes reads of the response body fail. When a
			// response is handed back successfully, cancel only once its body is closed.
			if err == nil && resp != nil && resp.Body != nil {
				resp.Body = &httpSenderCancelOnCloseBody{ReadCloser: resp.Body, cancel: cancel}
				return
			}
			cancel()
		}()

		// Update request
		req = req.WithContext(ctx)

		// Update name
		name += " with timeout " + timeout.String()
	}

	// Loop
	// We start at retryMax + 1 so that it runs at least once even if retryMax == 0
	tries := 0
	for retriesLeft := s.retryMax + 1; retriesLeft > 0; retriesLeft-- {
		// Get request name
		nr := name + " (" + strconv.Itoa(s.retryMax-retriesLeft+2) + "/" + strconv.Itoa(s.retryMax+1) + ")"

		// The previous attempt consumed the body: rewind it on a shallow copy of the
		// request when the request knows how to
		r := req
		if tries > 0 && req.GetBody != nil {
			var body io.ReadCloser
			if body, err = req.GetBody(); err != nil {
				err = fmt.Errorf("astikit: rewinding body of %s failed: %w", nr, err)
				return
			}
			r = req.WithContext(req.Context())
			r.Body = body
		}
		tries++

		// Send request
		s.l.Debugf("astikit: sending %s", nr)
		if resp, err = s.client.Do(r); err != nil {
			// Retry if error is temporary, stop here otherwise
			if netError, ok := err.(net.Error); !ok || !netError.Temporary() {
				err = fmt.Errorf("astikit: sending %s failed: %w", nr, err)
				return
			}
		} else if err = req.Context().Err(); err != nil {
			// The context is done, so the body can't be read anymore: release it
			if resp.Body != nil {
				resp.Body.Close()
			}
			err = fmt.Errorf("astikit: request context failed: %w", err)
			return
		} else {
			err = s.retryFunc(resp)
		}

		// Retry
		if err != nil {
			if retriesLeft > 1 {
				// This response is discarded: close its body so that the underlying
				// connection is released before the next attempt
				if resp != nil && resp.Body != nil {
					resp.Body.Close()
				}
				// %v rather than %w: this is a logging call, not fmt.Errorf
				s.l.Errorf("astikit: sending %s failed, sleeping %s and retrying... (%d retries left): %v", nr, s.retrySleep, retriesLeft-1, err)
				time.Sleep(s.retrySleep)
			}
			continue
		}

		// Return if conditions for retrying were not met
		return
	}

	// Max retries limit reached
	err = fmt.Errorf("astikit: sending %s failed after %d tries: %w", name, tries, err)
	return
}

// httpSenderCancelOnCloseBody wraps a response body returned by SendWithTimeout so
// that its timeout context is canceled only once the body has been closed.
type httpSenderCancelOnCloseBody struct {
	io.ReadCloser
	cancel context.CancelFunc
}

// Close closes the wrapped body, then releases the timeout context.
func (b *httpSenderCancelOnCloseBody) Close() error {
	err := b.ReadCloser.Close()
	b.cancel()
	return err
}
// HTTPSendJSONOptions describes a request sent by HTTPSender.SendJSON.
type HTTPSendJSONOptions struct {
	// BodyError, when set, is JSON-decoded from the body of a non-2xx response
	BodyError interface{}
	// BodyIn, when set, is JSON-encoded and sent as the request body
	BodyIn interface{}
	// BodyOut, when set, is JSON-decoded from the body of a 2xx response
	BodyOut interface{}
	// Headers are set on the request, one value per header name
	Headers map[string]string
	// Method is the HTTP method of the request
	Method string
	// URL is the URL the request is sent to
	URL string
}
// SendJSON sends the JSON request described by o and decodes its JSON response.
//
// o.BodyIn, when set, is JSON-encoded and used as the request body, and o.Headers
// are set on the request. A 2xx response is decoded into o.BodyOut when it is set.
// Any other status code yields an error. If o.BodyError is set and the body decodes
// into it, the error is ErrHTTPSenderUnmarshaledError; otherwise it is an
// "invalid status code" error.
func (s *HTTPSender) SendJSON(o HTTPSendJSONOptions) (err error) {
	// Encode the request body, if any
	var body io.Reader
	if o.BodyIn != nil {
		buf := &bytes.Buffer{}
		if errEnc := json.NewEncoder(buf).Encode(o.BodyIn); errEnc != nil {
			return fmt.Errorf("astikit: marshaling body in failed: %w", errEnc)
		}
		body = buf
	}

	// Build the request
	req, errReq := http.NewRequest(o.Method, o.URL, body)
	if errReq != nil {
		return fmt.Errorf("astikit: creating request failed: %w", errReq)
	}
	for name, value := range o.Headers {
		req.Header.Set(name, value)
	}

	// Send it
	resp, errSend := s.Send(req)
	if errSend != nil {
		return fmt.Errorf("astikit: sending request failed: %w", errSend)
	}
	defer resp.Body.Close()

	// Non-2xx: prefer the caller's error shape when the body matches it
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		if o.BodyError != nil && json.NewDecoder(resp.Body).Decode(o.BodyError) == nil {
			return ErrHTTPSenderUnmarshaledError
		}
		return fmt.Errorf("astikit: invalid status code %d", resp.StatusCode)
	}

	// 2xx: decode the body when the caller wants it
	if o.BodyOut == nil {
		return nil
	}
	if errDec := json.NewDecoder(resp.Body).Decode(o.BodyOut); errDec != nil {
		return fmt.Errorf("astikit: unmarshaling failed: %w", errDec)
	}
	return nil
}
// HTTPResponseFunc is a func that can process an *http.Response. Returning a non-nil
// error rejects the response.
type HTTPResponseFunc func(resp *http.Response) error
// defaultHTTPResponseFunc is the HTTPResponseFunc used by HTTPDownloader when none
// is provided: it rejects any response whose status code is not 2xx.
func defaultHTTPResponseFunc(resp *http.Response) (err error) {
	switch code := resp.StatusCode; {
	case code >= http.StatusOK && code < http.StatusMultipleChoices:
		return nil
	default:
		return fmt.Errorf("astikit: invalid status code %d", code)
	}
}
// HTTPDownloader downloads several HTTP srcs concurrently, with a bounded number of
// downloads in flight, and hands each downloaded body to a custom callback.
type HTTPDownloader struct {
	// bp provides the buffers downloaded bodies are copied into
	bp *BufferPool
	// l bounds the number of concurrent downloads
	l *GoroutineLimiter
	// responseFunc validates each response before its body is copied
	responseFunc HTTPResponseFunc
	// s sends the requests
	s *HTTPSender
}
// HTTPDownloaderOptions holds the settings accepted by NewHTTPDownloader.
type HTTPDownloaderOptions struct {
	// Limiter configures the goroutine limiter bounding concurrent downloads
	Limiter GoroutineLimiterOptions
	// ResponseFunc validates responses; when nil, only 2xx responses are accepted
	ResponseFunc HTTPResponseFunc
	// Sender configures the HTTPSender used to send requests
	Sender HTTPSenderOptions
}
// NewHTTPDownloader creates a new HTTPDownloader from o. When o.ResponseFunc is nil,
// defaultHTTPResponseFunc is used.
func NewHTTPDownloader(o HTTPDownloaderOptions) (d *HTTPDownloader) {
	responseFunc := o.ResponseFunc
	if responseFunc == nil {
		responseFunc = defaultHTTPResponseFunc
	}
	return &HTTPDownloader{
		bp:           NewBufferPool(),
		l:            NewGoroutineLimiter(o.Limiter),
		responseFunc: responseFunc,
		s:            NewHTTPSender(o.Sender),
	}
}
// Close closes the downloader properly by closing its goroutine limiter.
func (d *HTTPDownloader) Close() (err error) {
	err = d.l.Close()
	return
}
// HTTPDownloaderSrc describes one resource to download with an HTTPDownloader.
type HTTPDownloaderSrc struct {
	// Body is the optional request body
	Body io.Reader
	// Header holds the headers added to the request
	Header http.Header
	// Method is the HTTP method; GET is used when empty
	Method string
	// URL is the URL of the resource
	URL string
}
// httpDownloaderFunc is called with each downloaded body. idx is the position of the
// src in the list being downloaded, and i holds the body. The callee owns i and is
// responsible for calling i.Close().
type httpDownloaderFunc func(ctx context.Context, idx int, i *BufferPoolItem) error
// do downloads src and hands its body to fn.
//
// The request is bound to ctx, defaults to GET and carries every value of every
// header in src.Header. The response is validated with d.responseFunc, then its body
// is copied into a buffer from d.bp, which is passed to fn along with idx. Ownership
// of that buffer is transferred to fn, which must close it. If fn is never reached
// after the buffer was created, do closes it itself.
func (d *HTTPDownloader) do(ctx context.Context, fn httpDownloaderFunc, idx int, src HTTPDownloaderSrc) (err error) {
	// Defaults
	if src.Method == "" {
		src.Method = http.MethodGet
	}

	// Create request
	var r *http.Request
	if r, err = http.NewRequestWithContext(ctx, src.Method, src.URL, src.Body); err != nil {
		err = fmt.Errorf("astikit: creating request to %s failed: %w", src.URL, err)
		return
	}

	// Copy header
	// Every value is copied: Header.Get would only keep the first one of each header
	for k, vs := range src.Header {
		for _, v := range vs {
			r.Header.Add(k, v)
		}
	}

	// Send request
	var resp *http.Response
	if resp, err = d.s.Send(r); err != nil {
		err = fmt.Errorf("astikit: sending request to %s failed: %w", src.URL, err)
		return
	}
	defer resp.Body.Close()

	// Process response
	// This happens before grabbing a buffer so that rejected responses don't need one
	if err = d.responseFunc(resp); err != nil {
		err = fmt.Errorf("astikit: response for request to %s is invalid: %w", src.URL, err)
		return
	}

	// Create buffer pool item
	buf := d.bp.New()

	// Copy body
	if _, err = Copy(ctx, buf, resp.Body); err != nil {
		// fn won't be called, so it won't close the buffer: do it here
		buf.Close()
		err = fmt.Errorf("astikit: copying body of %s failed: %w", src.URL, err)
		return
	}

	// Custom
	if err = fn(ctx, idx, buf); err != nil {
		err = fmt.Errorf("astikit: custom callback on %s failed: %w", src.URL, err)
		return
	}
	return
}
// download downloads srcs concurrently through the goroutine limiter. For each one
// it calls d.do, which hands the body to fn along with the src index.
//
// Scheduling stops as soon as ctx is done or a download failed. The function waits
// for every scheduled download before returning the first error recorded, if any.
func (d *HTTPDownloader) download(ctx context.Context, srcs []HTTPDownloaderSrc, fn httpDownloaderFunc) (err error) {
	// Nothing to download
	if len(srcs) == 0 {
		return nil
	}

	// Downloads run concurrently, so err is only accessed under mu until wg.Wait
	// returns
	var mu sync.Mutex
	setErr := func(e error) {
		mu.Lock()
		defer mu.Unlock()
		if err == nil {
			err = e
		}
	}
	failed := func() bool {
		mu.Lock()
		defer mu.Unlock()
		return err != nil
	}

	// Loop through srcs
	wg := &sync.WaitGroup{}
	for idx, src := range srcs {
		// Stop scheduling once the context is done or a download failed
		if errCtx := ctx.Err(); errCtx != nil {
			setErr(errCtx)
			break
		}
		if failed() {
			break
		}

		idx, src := idx, src
		wg.Add(1)
		if errDo := d.l.Do(func() {
			// Task is done
			defer wg.Done()

			// Do
			if errD := d.do(ctx, fn, idx, src); errD != nil {
				setErr(errD)
			}
		}); errDo != nil {
			// The task was not scheduled, so it will never mark itself as done
			wg.Done()
			setErr(fmt.Errorf("astikit: scheduling download of %s failed: %w", src.URL, errDo))
			break
		}
	}

	// Wait
	wg.Wait()
	return
}
// DownloadInDirectory downloads in parallel a set of srcs and saves them in a dst directory
//
// Each body is written to dst/<base of the src URL>, and dst is created if needed.
// NOTE(review): the base name is taken from the raw URL string, so a query string
// ends up in the file name, and srcs sharing a base name target the same file.
func (d *HTTPDownloader) DownloadInDirectory(ctx context.Context, dst string, srcs ...HTTPDownloaderSrc) error {
	return d.download(ctx, srcs, func(ctx context.Context, idx int, buf *BufferPoolItem) (err error) {
		// Make sure to close buffer
		defer buf.Close()

		// Make sure destination directory exists
		if err = os.MkdirAll(dst, DefaultDirMode); err != nil {
			err = fmt.Errorf("astikit: mkdirall %s failed: %w", dst, err)
			return
		}

		// Create destination file
		p := filepath.Join(dst, filepath.Base(srcs[idx].URL))
		var f *os.File
		if f, err = os.Create(p); err != nil {
			err = fmt.Errorf("astikit: creating %s failed: %w", p, err)
			return
		}
		defer func() {
			// Close can report errors of previous writes: don't drop them silently
			if errClose := f.Close(); errClose != nil && err == nil {
				err = fmt.Errorf("astikit: closing %s failed: %w", p, errClose)
			}
		}()

		// Copy buffer
		if _, err = Copy(ctx, f, buf); err != nil {
			err = fmt.Errorf("astikit: copying content to %s failed: %w", p, err)
			return
		}
		return
	})
}
// DownloadInWriter downloads in parallel a set of srcs and concatenates them in a writer while
// maintaining the initial order
//
// Bodies that arrive ahead of their turn are kept pending until every body before
// them has been consumed. Buffers still pending when the download ends (e.g.
// because an earlier src failed) are closed without being written.
func (d *HTTPDownloader) DownloadInWriter(ctx context.Context, dst io.Writer, srcs ...HTTPDownloaderSrc) error {
	type chunk struct {
		buf *BufferPoolItem
		idx int
	}
	var (
		mu      sync.Mutex // guards pending and next
		pending []chunk    // downloaded chunks waiting for their turn, sorted by idx
		next    int        // idx of the next chunk to write to dst
	)

	// Whatever is still pending once the download is over will never be written
	defer func() {
		for _, c := range pending {
			c.buf.Close()
		}
	}()

	return d.download(ctx, srcs, func(ctx context.Context, idx int, buf *BufferPoolItem) (err error) {
		mu.Lock()
		defer mu.Unlock()

		// Insert the chunk while keeping pending sorted by idx
		pos := len(pending)
		for i, c := range pending {
			if idx < c.idx {
				pos = i
				break
			}
		}
		pending = append(pending, chunk{})
		copy(pending[pos+1:], pending[pos:])
		pending[pos] = chunk{buf: buf, idx: idx}

		// Every chunk with an idx lower than next has already been consumed, so the
		// chunk to write, when present, is always at the front
		for len(pending) > 0 && pending[0].idx == next {
			c := pending[0]
			pending = pending[1:]
			next++

			// The buffer is closed whatever the outcome of the copy
			_, err = Copy(ctx, dst, c.buf)
			c.buf.Close()
			if err != nil {
				err = fmt.Errorf("astikit: copying chunk #%d to dst failed: %w", c.idx, err)
				return
			}
		}
		return
	})
}
// DownloadInFile downloads in parallel a set of srcs and concatenates them in a dst file while
// maintaining the initial order
//
// The parent directory of dst is created if needed and dst is created or truncated.
// An error reported when closing dst is returned if no earlier error occurred.
func (d *HTTPDownloader) DownloadInFile(ctx context.Context, dst string, srcs ...HTTPDownloaderSrc) (err error) {
	// Make sure destination directory exists
	dir := filepath.Dir(dst)
	if err = os.MkdirAll(dir, DefaultDirMode); err != nil {
		err = fmt.Errorf("astikit: mkdirall %s failed: %w", dir, err)
		return
	}

	// Create destination file
	var f *os.File
	if f, err = os.Create(dst); err != nil {
		err = fmt.Errorf("astikit: creating %s failed: %w", dst, err)
		return
	}
	defer func() {
		// Close can report errors of previous writes: don't drop them silently
		if errClose := f.Close(); errClose != nil && err == nil {
			err = fmt.Errorf("astikit: closing %s failed: %w", dst, errClose)
		}
	}()

	// Download in writer
	err = d.DownloadInWriter(ctx, f, srcs...)
	return
}
// HTTPMiddleware represents an HTTP middleware
type HTTPMiddleware func(http.Handler) http.Handler

// ChainHTTPMiddlewares chains HTTP middlewares
//
// Middlewares are applied in order, so the last one ends up outermost: it is the
// first to see incoming requests. nil middlewares are skipped.
func ChainHTTPMiddlewares(h http.Handler, ms ...HTTPMiddleware) http.Handler {
	return ChainHTTPMiddlewaresWithPrefix(h, []string{}, ms...)
}

// ChainHTTPMiddlewaresWithPrefix chains HTTP middlewares if one of prefixes is present
//
// When prefixes is empty, every middleware applies to every request. Otherwise a
// middleware only applies to requests whose escaped URL path starts with one of the
// prefixes; other requests skip it and go straight to the next handler. nil
// middlewares are skipped, and each middleware is built once, when the chain is
// created.
func ChainHTTPMiddlewaresWithPrefix(h http.Handler, prefixes []string, ms ...HTTPMiddleware) http.Handler {
	for _, m := range ms {
		if m == nil {
			continue
		}
		if len(prefixes) == 0 {
			h = m(h)
			continue
		}

		// Both handlers are computed here, outside of the closure. This builds the
		// middleware only once instead of on every request. It also avoids capturing
		// the loop variable m, which all iterations share before Go 1.22.
		next := h
		wrapped := m(next)
		h = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
			for _, prefix := range prefixes {
				if strings.HasPrefix(r.URL.EscapedPath(), prefix) {
					wrapped.ServeHTTP(rw, r)
					return
				}
			}
			next.ServeHTTP(rw, r)
		})
	}
	return h
}
// handleHTTPBasicAuth checks the basic auth credentials of r against username and
// password.
//
// When they are missing or don't match, it replies 401 with a WWW-Authenticate
// challenge and returns true, meaning the request has been handled. It returns false
// when the request may proceed.
func handleHTTPBasicAuth(username, password string, rw http.ResponseWriter, r *http.Request) bool {
	u, p, ok := r.BasicAuth()

	// Digests are compared in constant time so that the response timing doesn't reveal
	// how much of the credentials was right. Hashing first also hides their length,
	// and both comparisons always run so timing doesn't reveal which one failed.
	gotUser, gotPass := sha256.Sum256([]byte(u)), sha256.Sum256([]byte(p))
	wantUser, wantPass := sha256.Sum256([]byte(username)), sha256.Sum256([]byte(password))
	match := subtle.ConstantTimeCompare(gotUser[:], wantUser[:]) & subtle.ConstantTimeCompare(gotPass[:], wantPass[:])

	if !ok || match != 1 {
		// The realm value contains spaces, so it must be a quoted string (RFC 7617)
		rw.Header().Set("WWW-Authenticate", `Basic realm="Please enter your credentials"`)
		rw.WriteHeader(http.StatusUnauthorized)
		return true
	}
	return false
}
// HTTPMiddlewareBasicAuth returns a middleware protecting a handler with HTTP basic
// auth using username and password. Requests without matching credentials get a 401
// and never reach the handler. nil is returned when both username and password are
// empty.
func HTTPMiddlewareBasicAuth(username, password string) HTTPMiddleware {
	if username == "" && password == "" {
		return nil
	}
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
			if !handleHTTPBasicAuth(username, password, rw, r) {
				next.ServeHTTP(rw, r)
			}
		})
	}
}
// setHTTPContentType sets the Content-Type header of rw, replacing any previous value.
func setHTTPContentType(contentType string, rw http.ResponseWriter) {
	header := rw.Header()
	header.Set("Content-Type", contentType)
}
// HTTPMiddlewareContentType returns a middleware that sets the Content-Type response
// header to contentType before calling the wrapped handler.
func HTTPMiddlewareContentType(contentType string) HTTPMiddleware {
	return func(next http.Handler) http.Handler {
		serve := func(rw http.ResponseWriter, r *http.Request) {
			setHTTPContentType(contentType, rw)
			next.ServeHTTP(rw, r)
		}
		return http.HandlerFunc(serve)
	}
}
// setHTTPHeaders sets each header of vs on rw, replacing any previous value.
func setHTTPHeaders(vs map[string]string, rw http.ResponseWriter) {
	header := rw.Header()
	for name, value := range vs {
		header.Set(name, value)
	}
}
// HTTPMiddlewareHeaders returns a middleware that sets every header of vs on the
// response before calling the wrapped handler.
func HTTPMiddlewareHeaders(vs map[string]string) HTTPMiddleware {
	return func(next http.Handler) http.Handler {
		serve := func(rw http.ResponseWriter, r *http.Request) {
			setHTTPHeaders(vs, rw)
			next.ServeHTTP(rw, r)
		}
		return http.HandlerFunc(serve)
	}
}
// HTTPMiddlewareCORSHeaders returns a middleware setting permissive CORS response
// headers: any origin, any method and any header are allowed ("*").
func HTTPMiddlewareCORSHeaders() HTTPMiddleware {
	headers := map[string]string{
		"Access-Control-Allow-Origin":  "*",
		"Access-Control-Allow-Methods": "*",
		"Access-Control-Allow-Headers": "*",
	}
	return HTTPMiddlewareHeaders(headers)
}