mirror of
https://github.com/stashapp/stash.git
synced 2025-12-11 10:54:14 +01:00
632 lines
16 KiB
Go
632 lines
16 KiB
Go
package astikit
|
|
|
|
import (
	"bytes"
	"context"
	"crypto/subtle"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"sync"
	"time"
)
|
|
|
|
var (
	// ErrHTTPSenderUnmarshaledError is returned by HTTPSender.SendJSON when the
	// response status code is not 2xx and the response body was successfully
	// decoded into the BodyError option.
	ErrHTTPSenderUnmarshaledError = errors.New("astikit: unmarshaled error")
)
|
|
|
|
// ServeHTTPOptions configures the HTTP server spawned by ServeHTTP.
type ServeHTTPOptions struct {
	// Addr is the address the server listens on (used as http.Server.Addr).
	Addr string
	// Handler serves the incoming requests (used as http.Server.Handler).
	Handler http.Handler
}
|
|
|
|
// ServeHTTP spawns an HTTP server
|
|
func ServeHTTP(w *Worker, o ServeHTTPOptions) {
|
|
// Create server
|
|
s := &http.Server{Addr: o.Addr, Handler: o.Handler}
|
|
|
|
// Execute in a task
|
|
w.NewTask().Do(func() {
|
|
// Log
|
|
w.Logger().Infof("astikit: serving on %s", o.Addr)
|
|
|
|
// Serve
|
|
var done = make(chan error)
|
|
go func() {
|
|
if err := s.ListenAndServe(); err != nil {
|
|
done <- err
|
|
}
|
|
}()
|
|
|
|
// Wait for context or done to be done
|
|
select {
|
|
case <-w.ctx.Done():
|
|
if w.ctx.Err() != context.Canceled {
|
|
w.Logger().Error(fmt.Errorf("astikit: context error: %w", w.ctx.Err()))
|
|
}
|
|
case err := <-done:
|
|
if err != nil {
|
|
w.Logger().Error(fmt.Errorf("astikit: serving failed: %w", err))
|
|
}
|
|
}
|
|
|
|
// Shutdown
|
|
w.Logger().Infof("astikit: shutting down server on %s", o.Addr)
|
|
if err := s.Shutdown(context.Background()); err != nil {
|
|
w.Logger().Error(fmt.Errorf("astikit: shutting down server on %s failed: %w", o.Addr, err))
|
|
}
|
|
})
|
|
}
|
|
|
|
// HTTPClient is the client abstraction used by HTTPSender: anything able to
// execute an *http.Request. *http.Client satisfies it and is the default.
type HTTPClient interface {
	Do(*http.Request) (*http.Response, error)
}
|
|
|
|
// HTTPSender represents an object capable of sending http requests
|
|
type HTTPSender struct {
|
|
client HTTPClient
|
|
l SeverityLogger
|
|
retryFunc HTTPSenderRetryFunc
|
|
retryMax int
|
|
retrySleep time.Duration
|
|
timeout time.Duration
|
|
}
|
|
|
|
// HTTPSenderRetryFunc inspects a response received by HTTPSender. Returning a
// non-nil error asks for the request to be retried, as long as retries are left.
type HTTPSenderRetryFunc func(r *http.Response) error
|
|
|
|
// HTTPSenderOptions represents HTTPSender options
|
|
type HTTPSenderOptions struct {
|
|
Client HTTPClient
|
|
Logger StdLogger
|
|
RetryFunc HTTPSenderRetryFunc
|
|
RetryMax int
|
|
RetrySleep time.Duration
|
|
Timeout time.Duration
|
|
}
|
|
|
|
// NewHTTPSender creates a new HTTP sender
|
|
func NewHTTPSender(o HTTPSenderOptions) (s *HTTPSender) {
|
|
s = &HTTPSender{
|
|
client: o.Client,
|
|
l: AdaptStdLogger(o.Logger),
|
|
retryFunc: o.RetryFunc,
|
|
retryMax: o.RetryMax,
|
|
retrySleep: o.RetrySleep,
|
|
timeout: o.Timeout,
|
|
}
|
|
if s.client == nil {
|
|
s.client = &http.Client{}
|
|
}
|
|
if s.retryFunc == nil {
|
|
s.retryFunc = s.defaultHTTPRetryFunc
|
|
}
|
|
return
|
|
}
|
|
|
|
func (s *HTTPSender) defaultHTTPRetryFunc(resp *http.Response) error {
|
|
if resp.StatusCode >= http.StatusInternalServerError {
|
|
return fmt.Errorf("astikit: invalid status code %d", resp.StatusCode)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Send sends a new *http.Request
|
|
func (s *HTTPSender) Send(req *http.Request) (*http.Response, error) {
|
|
return s.SendWithTimeout(req, s.timeout)
|
|
}
|
|
|
|
// SendWithTimeout sends a new *http.Request with a timeout
|
|
func (s *HTTPSender) SendWithTimeout(req *http.Request, timeout time.Duration) (resp *http.Response, err error) {
|
|
// Set name
|
|
name := req.Method + " request"
|
|
if req.URL != nil {
|
|
name += " to " + req.URL.String()
|
|
}
|
|
|
|
// Timeout
|
|
if timeout > 0 {
|
|
// Create context
|
|
ctx, cancel := context.WithTimeout(req.Context(), timeout)
|
|
defer cancel()
|
|
|
|
// Update request
|
|
req = req.WithContext(ctx)
|
|
|
|
// Update name
|
|
name += " with timeout " + timeout.String()
|
|
}
|
|
|
|
// Loop
|
|
// We start at retryMax + 1 so that it runs at least once even if retryMax == 0
|
|
tries := 0
|
|
for retriesLeft := s.retryMax + 1; retriesLeft > 0; retriesLeft-- {
|
|
// Get request name
|
|
nr := name + " (" + strconv.Itoa(s.retryMax-retriesLeft+2) + "/" + strconv.Itoa(s.retryMax+1) + ")"
|
|
tries++
|
|
|
|
// Send request
|
|
s.l.Debugf("astikit: sending %s", nr)
|
|
if resp, err = s.client.Do(req); err != nil {
|
|
// Retry if error is temporary, stop here otherwise
|
|
if netError, ok := err.(net.Error); !ok || !netError.Temporary() {
|
|
err = fmt.Errorf("astikit: sending %s failed: %w", nr, err)
|
|
return
|
|
}
|
|
} else if err = req.Context().Err(); err != nil {
|
|
err = fmt.Errorf("astikit: request context failed: %w", err)
|
|
return
|
|
} else {
|
|
err = s.retryFunc(resp)
|
|
}
|
|
|
|
// Retry
|
|
if err != nil {
|
|
if retriesLeft > 1 {
|
|
s.l.Errorf("astikit: sending %s failed, sleeping %s and retrying... (%d retries left): %w", nr, s.retrySleep, retriesLeft-1, err)
|
|
time.Sleep(s.retrySleep)
|
|
}
|
|
continue
|
|
}
|
|
|
|
// Return if conditions for retrying were not met
|
|
return
|
|
}
|
|
|
|
// Max retries limit reached
|
|
err = fmt.Errorf("astikit: sending %s failed after %d tries: %w", name, tries, err)
|
|
return
|
|
}
|
|
|
|
// HTTPSendJSONOptions describes a JSON request sent by HTTPSender.SendJSON.
type HTTPSendJSONOptions struct {
	// BodyError, if set, receives the decoded response body when the status
	// code is not 2xx.
	BodyError interface{}
	// BodyIn, if set, is JSON-encoded and used as the request body.
	BodyIn interface{}
	// BodyOut, if set, receives the decoded response body when the status
	// code is 2xx.
	BodyOut interface{}
	// Headers are set on the request.
	Headers map[string]string
	// Method is the request's HTTP method.
	Method string
	// URL is the request's URL.
	URL string
}
|
|
|
|
// SendJSON sends a new JSON HTTP request
|
|
func (s *HTTPSender) SendJSON(o HTTPSendJSONOptions) (err error) {
|
|
// Marshal body in
|
|
var bi io.Reader
|
|
if o.BodyIn != nil {
|
|
bb := &bytes.Buffer{}
|
|
if err = json.NewEncoder(bb).Encode(o.BodyIn); err != nil {
|
|
err = fmt.Errorf("astikit: marshaling body in failed: %w", err)
|
|
return
|
|
}
|
|
bi = bb
|
|
}
|
|
|
|
// Create request
|
|
var req *http.Request
|
|
if req, err = http.NewRequest(o.Method, o.URL, bi); err != nil {
|
|
err = fmt.Errorf("astikit: creating request failed: %w", err)
|
|
return
|
|
}
|
|
|
|
// Add headers
|
|
for k, v := range o.Headers {
|
|
req.Header.Set(k, v)
|
|
}
|
|
|
|
// Send request
|
|
var resp *http.Response
|
|
if resp, err = s.Send(req); err != nil {
|
|
err = fmt.Errorf("astikit: sending request failed: %w", err)
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
// Process status code
|
|
if code := resp.StatusCode; code < 200 || code > 299 {
|
|
// Try unmarshaling error
|
|
if o.BodyError != nil {
|
|
if err2 := json.NewDecoder(resp.Body).Decode(o.BodyError); err2 == nil {
|
|
err = ErrHTTPSenderUnmarshaledError
|
|
return
|
|
}
|
|
}
|
|
|
|
// Default error
|
|
err = fmt.Errorf("astikit: invalid status code %d", code)
|
|
return
|
|
}
|
|
|
|
// Unmarshal body out
|
|
if o.BodyOut != nil {
|
|
if err = json.NewDecoder(resp.Body).Decode(o.BodyOut); err != nil {
|
|
err = fmt.Errorf("astikit: unmarshaling failed: %w", err)
|
|
return
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
// HTTPResponseFunc validates an *http.Response received by HTTPDownloader
// before its body is processed. A non-nil error rejects the response.
type HTTPResponseFunc func(r *http.Response) error
|
|
|
|
// defaultHTTPResponseFunc accepts only 2xx status codes.
func defaultHTTPResponseFunc(resp *http.Response) (err error) {
	switch code := resp.StatusCode; {
	case code >= http.StatusOK && code < http.StatusMultipleChoices:
		return nil
	default:
		return fmt.Errorf("astikit: invalid status code %d", code)
	}
}
|
|
|
|
// HTTPDownloader represents an object capable of downloading several HTTP srcs simultaneously
|
|
// and doing stuff to the results
|
|
type HTTPDownloader struct {
|
|
bp *BufferPool
|
|
l *GoroutineLimiter
|
|
responseFunc HTTPResponseFunc
|
|
s *HTTPSender
|
|
}
|
|
|
|
// HTTPDownloaderOptions represents HTTPDownloader options
|
|
type HTTPDownloaderOptions struct {
|
|
Limiter GoroutineLimiterOptions
|
|
ResponseFunc HTTPResponseFunc
|
|
Sender HTTPSenderOptions
|
|
}
|
|
|
|
// NewHTTPDownloader creates a new HTTPDownloader
|
|
func NewHTTPDownloader(o HTTPDownloaderOptions) (d *HTTPDownloader) {
|
|
d = &HTTPDownloader{
|
|
bp: NewBufferPool(),
|
|
l: NewGoroutineLimiter(o.Limiter),
|
|
responseFunc: o.ResponseFunc,
|
|
s: NewHTTPSender(o.Sender),
|
|
}
|
|
if d.responseFunc == nil {
|
|
d.responseFunc = defaultHTTPResponseFunc
|
|
}
|
|
return
|
|
}
|
|
|
|
// Close closes the downloader properly
|
|
func (d *HTTPDownloader) Close() error {
|
|
return d.l.Close()
|
|
}
|
|
|
|
// HTTPDownloaderSrc describes one source to download.
type HTTPDownloaderSrc struct {
	// Body is the optional request body.
	Body io.Reader
	// Header holds the request headers.
	Header http.Header
	// Method is the request method; an empty value means GET.
	Method string
	// URL is the source's URL.
	URL string
}
|
|
|
|
// It is the responsibility of the caller to call i.Close()
|
|
type httpDownloaderFunc func(ctx context.Context, idx int, i *BufferPoolItem) error
|
|
|
|
func (d *HTTPDownloader) do(ctx context.Context, fn httpDownloaderFunc, idx int, src HTTPDownloaderSrc) (err error) {
|
|
// Defaults
|
|
if src.Method == "" {
|
|
src.Method = http.MethodGet
|
|
}
|
|
|
|
// Create request
|
|
var r *http.Request
|
|
if r, err = http.NewRequestWithContext(ctx, src.Method, src.URL, src.Body); err != nil {
|
|
err = fmt.Errorf("astikit: creating request to %s failed: %w", src.URL, err)
|
|
return
|
|
}
|
|
|
|
// Copy header
|
|
for k := range src.Header {
|
|
r.Header.Set(k, src.Header.Get(k))
|
|
}
|
|
|
|
// Send request
|
|
var resp *http.Response
|
|
if resp, err = d.s.Send(r); err != nil {
|
|
err = fmt.Errorf("astikit: sending request to %s failed: %w", src.URL, err)
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
// Create buffer pool item
|
|
buf := d.bp.New()
|
|
|
|
// Process response
|
|
if err = d.responseFunc(resp); err != nil {
|
|
err = fmt.Errorf("astikit: response for request to %s is invalid: %w", src.URL, err)
|
|
return
|
|
}
|
|
|
|
// Copy body
|
|
if _, err = Copy(ctx, buf, resp.Body); err != nil {
|
|
err = fmt.Errorf("astikit: copying body of %s failed: %w", src.URL, err)
|
|
return
|
|
}
|
|
|
|
// Custom
|
|
if err = fn(ctx, idx, buf); err != nil {
|
|
err = fmt.Errorf("astikit: custom callback on %s failed: %w", src.URL, err)
|
|
return
|
|
}
|
|
return
|
|
}
|
|
|
|
func (d *HTTPDownloader) download(ctx context.Context, srcs []HTTPDownloaderSrc, fn httpDownloaderFunc) (err error) {
|
|
// Nothing to download
|
|
if len(srcs) == 0 {
|
|
return nil
|
|
}
|
|
|
|
// Loop through srcs
|
|
wg := &sync.WaitGroup{}
|
|
wg.Add(len(srcs))
|
|
for idx, src := range srcs {
|
|
func(idx int, src HTTPDownloaderSrc) {
|
|
// Update error with ctx
|
|
if ctx.Err() != nil {
|
|
err = ctx.Err()
|
|
}
|
|
|
|
// Do nothing if error
|
|
if err != nil {
|
|
wg.Done()
|
|
return
|
|
}
|
|
|
|
// Do
|
|
d.l.Do(func() {
|
|
// Task is done
|
|
defer wg.Done()
|
|
|
|
// Do
|
|
if errD := d.do(ctx, fn, idx, src); errD != nil && err == nil {
|
|
err = errD
|
|
return
|
|
}
|
|
})
|
|
}(idx, src)
|
|
}
|
|
|
|
// Wait
|
|
wg.Wait()
|
|
return
|
|
}
|
|
|
|
// DownloadInDirectory downloads in parallel a set of srcs and saves them in a dst directory
|
|
func (d *HTTPDownloader) DownloadInDirectory(ctx context.Context, dst string, srcs ...HTTPDownloaderSrc) error {
|
|
return d.download(ctx, srcs, func(ctx context.Context, idx int, buf *BufferPoolItem) (err error) {
|
|
// Make sure to close buffer
|
|
defer buf.Close()
|
|
|
|
// Make sure destination directory exists
|
|
if err = os.MkdirAll(dst, DefaultDirMode); err != nil {
|
|
err = fmt.Errorf("astikit: mkdirall %s failed: %w", dst, err)
|
|
return
|
|
}
|
|
|
|
// Create destination file
|
|
var f *os.File
|
|
dst := filepath.Join(dst, filepath.Base(srcs[idx].URL))
|
|
if f, err = os.Create(dst); err != nil {
|
|
err = fmt.Errorf("astikit: creating %s failed: %w", dst, err)
|
|
return
|
|
}
|
|
defer f.Close()
|
|
|
|
// Copy buffer
|
|
if _, err = Copy(ctx, f, buf); err != nil {
|
|
err = fmt.Errorf("astikit: copying content to %s failed: %w", dst, err)
|
|
return
|
|
}
|
|
return
|
|
})
|
|
}
|
|
|
|
// DownloadInWriter downloads in parallel a set of srcs and concatenates them in a writer while
|
|
// maintaining the initial order
|
|
func (d *HTTPDownloader) DownloadInWriter(ctx context.Context, dst io.Writer, srcs ...HTTPDownloaderSrc) error {
|
|
// Init
|
|
type chunk struct {
|
|
buf *BufferPoolItem
|
|
idx int
|
|
}
|
|
var cs []chunk
|
|
var m sync.Mutex // Locks cs
|
|
var requiredIdx int
|
|
|
|
// Make sure to close all buffers
|
|
defer func() {
|
|
for _, c := range cs {
|
|
c.buf.Close()
|
|
}
|
|
}()
|
|
|
|
// Download
|
|
return d.download(ctx, srcs, func(ctx context.Context, idx int, buf *BufferPoolItem) (err error) {
|
|
// Lock
|
|
m.Lock()
|
|
defer m.Unlock()
|
|
|
|
// Check where to insert chunk
|
|
var idxInsert = -1
|
|
for idxChunk := 0; idxChunk < len(cs); idxChunk++ {
|
|
if idx < cs[idxChunk].idx {
|
|
idxInsert = idxChunk
|
|
break
|
|
}
|
|
}
|
|
|
|
// Create chunk
|
|
c := chunk{
|
|
buf: buf,
|
|
idx: idx,
|
|
}
|
|
|
|
// Add chunk
|
|
if idxInsert > -1 {
|
|
cs = append(cs[:idxInsert], append([]chunk{c}, cs[idxInsert:]...)...)
|
|
} else {
|
|
cs = append(cs, c)
|
|
}
|
|
|
|
// Loop through chunks
|
|
for idxChunk := 0; idxChunk < len(cs); idxChunk++ {
|
|
// Get chunk
|
|
c := cs[idxChunk]
|
|
|
|
// The chunk should be copied
|
|
if c.idx == requiredIdx {
|
|
// Copy chunk content
|
|
// Do not check error right away since we still want to close the buffer
|
|
// and remove the chunk
|
|
_, err = Copy(ctx, dst, c.buf)
|
|
|
|
// Close buffer
|
|
c.buf.Close()
|
|
|
|
// Remove chunk
|
|
requiredIdx++
|
|
cs = append(cs[:idxChunk], cs[idxChunk+1:]...)
|
|
idxChunk--
|
|
|
|
// Check error
|
|
if err != nil {
|
|
err = fmt.Errorf("astikit: copying chunk #%d to dst failed: %w", c.idx, err)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
return
|
|
})
|
|
}
|
|
|
|
// DownloadInFile downloads in parallel a set of srcs and concatenates them in a dst file while
|
|
// maintaining the initial order
|
|
func (d *HTTPDownloader) DownloadInFile(ctx context.Context, dst string, srcs ...HTTPDownloaderSrc) (err error) {
|
|
// Make sure destination directory exists
|
|
if err = os.MkdirAll(filepath.Dir(dst), DefaultDirMode); err != nil {
|
|
err = fmt.Errorf("astikit: mkdirall %s failed: %w", filepath.Dir(dst), err)
|
|
return
|
|
}
|
|
|
|
// Create destination file
|
|
var f *os.File
|
|
if f, err = os.Create(dst); err != nil {
|
|
err = fmt.Errorf("astikit: creating %s failed: %w", dst, err)
|
|
return
|
|
}
|
|
defer f.Close()
|
|
|
|
// Download in writer
|
|
return d.DownloadInWriter(ctx, f, srcs...)
|
|
}
|
|
|
|
// HTTPMiddleware wraps an http.Handler into another http.Handler.
type HTTPMiddleware func(next http.Handler) http.Handler
|
|
|
|
// ChainHTTPMiddlewares chains HTTP middlewares
|
|
func ChainHTTPMiddlewares(h http.Handler, ms ...HTTPMiddleware) http.Handler {
|
|
return ChainHTTPMiddlewaresWithPrefix(h, []string{}, ms...)
|
|
}
|
|
|
|
// ChainHTTPMiddlewaresWithPrefix chains HTTP middlewares if one of prefixes is present
|
|
func ChainHTTPMiddlewaresWithPrefix(h http.Handler, prefixes []string, ms ...HTTPMiddleware) http.Handler {
|
|
for _, m := range ms {
|
|
if m == nil {
|
|
continue
|
|
}
|
|
if len(prefixes) == 0 {
|
|
h = m(h)
|
|
} else {
|
|
t := h
|
|
h = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
|
for _, prefix := range prefixes {
|
|
if strings.HasPrefix(r.URL.EscapedPath(), prefix) {
|
|
m(t).ServeHTTP(rw, r)
|
|
return
|
|
}
|
|
}
|
|
t.ServeHTTP(rw, r)
|
|
})
|
|
}
|
|
}
|
|
return h
|
|
}
|
|
|
|
// handleHTTPBasicAuth checks r's basic auth credentials against username and
// password.
//
// It returns false when they match. Otherwise (missing or wrong credentials),
// it writes a 401 response asking for credentials and returns true, meaning
// the request has been handled and must not be processed any further.
func handleHTTPBasicAuth(username, password string, rw http.ResponseWriter, r *http.Request) bool {
	u, p, ok := r.BasicAuth()

	// Compare in constant time with respect to content (lengths may still leak),
	// and always compare both fields, so response timing doesn't reveal which
	// part of the credentials was right
	userMatch := subtle.ConstantTimeCompare([]byte(u), []byte(username))
	passMatch := subtle.ConstantTimeCompare([]byte(p), []byte(password))
	if ok && userMatch&passMatch == 1 {
		return false
	}

	// The realm value contains spaces, so it must be a quoted string (RFC 7617)
	rw.Header().Set("WWW-Authenticate", `Basic realm="Please enter your credentials"`)
	rw.WriteHeader(http.StatusUnauthorized)
	return true
}
|
|
|
|
// HTTPMiddlewareBasicAuth adds basic HTTP auth to an HTTP handler
|
|
func HTTPMiddlewareBasicAuth(username, password string) HTTPMiddleware {
|
|
if username == "" && password == "" {
|
|
return nil
|
|
}
|
|
return func(h http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
|
// Handle basic auth
|
|
if handleHTTPBasicAuth(username, password, rw, r) {
|
|
return
|
|
}
|
|
|
|
// Next handler
|
|
h.ServeHTTP(rw, r)
|
|
})
|
|
}
|
|
}
|
|
|
|
// setHTTPContentType sets the Content-Type header of rw's response to
// contentType, replacing any previous value.
func setHTTPContentType(contentType string, rw http.ResponseWriter) {
	headers := rw.Header()
	headers.Set("Content-Type", contentType)
}
|
|
|
|
// HTTPMiddlewareContentType adds a content type to an HTTP handler
|
|
func HTTPMiddlewareContentType(contentType string) HTTPMiddleware {
|
|
return func(h http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
|
// Set content type
|
|
setHTTPContentType(contentType, rw)
|
|
|
|
// Next handler
|
|
h.ServeHTTP(rw, r)
|
|
})
|
|
}
|
|
}
|
|
|
|
// setHTTPHeaders sets every header of vs on rw's response, replacing any
// previous value of the same header.
func setHTTPHeaders(vs map[string]string, rw http.ResponseWriter) {
	for name, value := range vs {
		rw.Header().Set(name, value)
	}
}
|
|
|
|
// HTTPMiddlewareHeaders adds headers to an HTTP handler
|
|
func HTTPMiddlewareHeaders(vs map[string]string) HTTPMiddleware {
|
|
return func(h http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
|
// Set headers
|
|
setHTTPHeaders(vs, rw)
|
|
|
|
// Next handler
|
|
h.ServeHTTP(rw, r)
|
|
})
|
|
}
|
|
}
|
|
|
|
// HTTPMiddlewareCORSHeaders adds CORS headers to an HTTP handler
|
|
func HTTPMiddlewareCORSHeaders() HTTPMiddleware {
|
|
return HTTPMiddlewareHeaders(map[string]string{
|
|
"Access-Control-Allow-Headers": "*",
|
|
"Access-Control-Allow-Methods": "*",
|
|
"Access-Control-Allow-Origin": "*",
|
|
})
|
|
}
|