Toward better context handling (#1835)

* Use the request context

The code uses context.Background() in a flow where there is a
http.Request. Use the requests context instead.

* Use a true context in the plugin example

Let AddTag/RemoveTag take a context and use that context throughout
the example.

* Avoid the use of context.Background

Prefer context.TODO over context.Background deep in the call chain.

This marks the site as something which we need to context-handle
later, and also makes it clear to the reader that the context is
sort-of temporary in the code base.

While here, be consistent in handling the `act` variable in each
branch of the if .. { .. } .. check.

* Prefer context.TODO over context.Background

For the different scraping operations here, there is a context
higher up the call chain, which we ought to use. Mark the call-sites
as TODO for now, so we can come back later on a sweep of which parts
can be context-lifted.

* Thread context upwards

Initialization requires context for transactions. Thread the context
upward the call chain.

At the intialization call, add a context.TODO since we can't break this
yet. The singleton assumption prevents us from pulling it up into main for
now.

* make tasks context-aware

Change the task interface to understand contexts.

Pass the context down in some of the branches where it is needed.

* Make QueryStashBoxScene context-aware

This call naturally sits inside the request-context. Use it.

* Introduce a context in the JS plugin code

This allows us to use a context for HTTP calls inside the system.

Mark the context with a TODO at top level for now.

* Nitpick error formatting

Use %v rather than %s for error interfaces.
Do not begin an error strong with a capital letter.

* Avoid the use of http.Get in FFMPEG download chain

Since http.Get has no context, it isn't possible to break out or have
policy induced. The call will block until the GET completes. Rewrite
to use a http Request and provide a context.

Thread the context through the call chain for now. provide
context.TODO() at the top level of the initialization chain.

* Make getRemoteCDPWSAddress aware of contexts

Eliminate a call to http.Get and replace it with a context-aware
variant.

Push the context upwards in the call chain, but plug it before the
scraper interface so we don't have to rewrite said interface yet.

Plugged with context.TODO()

* Scraper: make the getImage function context-aware

Use a context, and pass it upwards. Plug it with context.TODO()
up the chain before the rewrite gets too much out of hand for now.

Minor tweaks along the way, remove a call to context.Background()
deep in the call chain.

* Make NOTIFY request context-aware

The call sits inside a Request-handler. So it's natural to use the
requests context as the context for the outgoing HTTP request.

* Use a context in the url scraper code

We are sitting in code which has a context, so utilize it for the
request as well.

* Use a context when checking versions

When we check the version of stash on Github, use a context. Thread
the context up to the initialization routine of the HTTP/GraphQL
server and plug it with a context.TODO() for now.

This paves the way for providing a context to the HTTP server code in a
future patch.

* Make utils func ReadImage context-aware

In almost all of the cases, there is a context in the call chain which
is a natural use. This is true for all the GraphQL mutations.

The exception is in task_stash_box_tag, so plug that task with
context.TODO() for now.

* Make stash-box get context-aware

Thread a context through the call chain until we hit the Client API.
Plug it with context.TODO() there for now.

* Enable the noctx linter

The code is now free of any uncontexted HTTP request. This means we
pass the noctx linter, and we can enable it in the code base.
This commit is contained in:
SmallCoccinelle 2021-10-14 06:32:41 +02:00 committed by GitHub
parent 41a1fb8aec
commit 655d3ae969
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
34 changed files with 211 additions and 180 deletions

View file

@ -31,7 +31,7 @@ linters:
# - ifshort # - ifshort
- misspell - misspell
# - nakedret # - nakedret
# - noctx - noctx
# - paralleltest # - paralleltest
- revive - revive
- rowserrcheck - rowserrcheck

View file

@ -1,6 +1,7 @@
package api package api
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -107,12 +108,12 @@ type githubTagResponse struct {
Node_id string Node_id string
} }
func makeGithubRequest(url string, output interface{}) error { func makeGithubRequest(ctx context.Context, url string, output interface{}) error {
client := &http.Client{ client := &http.Client{
Timeout: 3 * time.Second, Timeout: 3 * time.Second,
} }
req, _ := http.NewRequest("GET", url, nil) req, _ := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
req.Header.Add("Accept", apiAcceptHeader) // gh api recommendation , send header with api version req.Header.Add("Accept", apiAcceptHeader) // gh api recommendation , send header with api version
response, err := client.Do(req) response, err := client.Do(req)
@ -147,7 +148,7 @@ func makeGithubRequest(url string, output interface{}) error {
// If running a build from the "master" branch, then the latest full release // If running a build from the "master" branch, then the latest full release
// is used, otherwise it uses the release that is tagged with "latest_develop" // is used, otherwise it uses the release that is tagged with "latest_develop"
// which is the latest pre-release build. // which is the latest pre-release build.
func GetLatestVersion(shortHash bool) (latestVersion string, latestRelease string, err error) { func GetLatestVersion(ctx context.Context, shortHash bool) (latestVersion string, latestRelease string, err error) {
arch := runtime.GOARCH // https://en.wikipedia.org/wiki/Comparison_of_ARM_cores arch := runtime.GOARCH // https://en.wikipedia.org/wiki/Comparison_of_ARM_cores
isARMv7 := cpu.ARM.HasNEON || cpu.ARM.HasVFPv3 || cpu.ARM.HasVFPv3D16 || cpu.ARM.HasVFPv4 // armv6 doesn't support any of these features isARMv7 := cpu.ARM.HasNEON || cpu.ARM.HasVFPv3 || cpu.ARM.HasVFPv3D16 || cpu.ARM.HasVFPv4 // armv6 doesn't support any of these features
@ -180,14 +181,14 @@ func GetLatestVersion(shortHash bool) (latestVersion string, latestRelease strin
} }
release := githubReleasesResponse{} release := githubReleasesResponse{}
err = makeGithubRequest(url, &release) err = makeGithubRequest(ctx, url, &release)
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
if release.Prerelease == usePreRelease { if release.Prerelease == usePreRelease {
latestVersion = getReleaseHash(release, shortHash, usePreRelease) latestVersion = getReleaseHash(ctx, release, shortHash, usePreRelease)
if wantedRelease != "" { if wantedRelease != "" {
for _, asset := range release.Assets { for _, asset := range release.Assets {
@ -205,12 +206,12 @@ func GetLatestVersion(shortHash bool) (latestVersion string, latestRelease strin
return latestVersion, latestRelease, nil return latestVersion, latestRelease, nil
} }
func getReleaseHash(release githubReleasesResponse, shortHash bool, usePreRelease bool) string { func getReleaseHash(ctx context.Context, release githubReleasesResponse, shortHash bool, usePreRelease bool) string {
shaLength := len(release.Target_commitish) shaLength := len(release.Target_commitish)
// the /latest API call doesn't return the hash in target_commitish // the /latest API call doesn't return the hash in target_commitish
// also add sanity check in case Target_commitish is not 40 characters // also add sanity check in case Target_commitish is not 40 characters
if !usePreRelease || shaLength != 40 { if !usePreRelease || shaLength != 40 {
return getShaFromTags(shortHash, release.Tag_name) return getShaFromTags(ctx, shortHash, release.Tag_name)
} }
if shortHash { if shortHash {
@ -225,9 +226,9 @@ func getReleaseHash(release githubReleasesResponse, shortHash bool, usePreReleas
return release.Target_commitish return release.Target_commitish
} }
func printLatestVersion() { func printLatestVersion(ctx context.Context) {
_, githash, _ = GetVersion() _, githash, _ = GetVersion()
latest, _, err := GetLatestVersion(true) latest, _, err := GetLatestVersion(ctx, true)
if err != nil { if err != nil {
logger.Errorf("Couldn't find latest version: %s", err) logger.Errorf("Couldn't find latest version: %s", err)
} else { } else {
@ -241,10 +242,10 @@ func printLatestVersion() {
// get sha from the github api tags endpoint // get sha from the github api tags endpoint
// returns the sha1 hash/shorthash or "" if something's wrong // returns the sha1 hash/shorthash or "" if something's wrong
func getShaFromTags(shortHash bool, name string) string { func getShaFromTags(ctx context.Context, shortHash bool, name string) string {
url := apiTags url := apiTags
tags := []githubTagResponse{} tags := []githubTagResponse{}
err := makeGithubRequest(url, &tags) err := makeGithubRequest(ctx, url, &tags)
if err != nil { if err != nil {
logger.Errorf("Github Tags Api %v", err) logger.Errorf("Github Tags Api %v", err)

View file

@ -160,7 +160,7 @@ func (r *queryResolver) Version(ctx context.Context) (*models.Version, error) {
//Gets latest version (git shorthash commit for now) //Gets latest version (git shorthash commit for now)
func (r *queryResolver) Latestversion(ctx context.Context) (*models.ShortVersion, error) { func (r *queryResolver) Latestversion(ctx context.Context) (*models.ShortVersion, error) {
ver, url, err := GetLatestVersion(true) ver, url, err := GetLatestVersion(ctx, true)
if err == nil { if err == nil {
logger.Infof("Retrieved latest hash: %s", ver) logger.Infof("Retrieved latest hash: %s", ver)
} else { } else {

View file

@ -14,12 +14,12 @@ import (
) )
func (r *mutationResolver) Setup(ctx context.Context, input models.SetupInput) (bool, error) { func (r *mutationResolver) Setup(ctx context.Context, input models.SetupInput) (bool, error) {
err := manager.GetInstance().Setup(input) err := manager.GetInstance().Setup(ctx, input)
return err == nil, err return err == nil, err
} }
func (r *mutationResolver) Migrate(ctx context.Context, input models.MigrateInput) (bool, error) { func (r *mutationResolver) Migrate(ctx context.Context, input models.MigrateInput) (bool, error) {
err := manager.GetInstance().Migrate(input) err := manager.GetInstance().Migrate(ctx, input)
return err == nil, err return err == nil, err
} }

View file

@ -38,7 +38,7 @@ func (r *mutationResolver) MovieCreate(ctx context.Context, input models.MovieCr
// Process the base 64 encoded image string // Process the base 64 encoded image string
if input.FrontImage != nil { if input.FrontImage != nil {
frontimageData, err = utils.ProcessImageInput(*input.FrontImage) frontimageData, err = utils.ProcessImageInput(ctx, *input.FrontImage)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -46,7 +46,7 @@ func (r *mutationResolver) MovieCreate(ctx context.Context, input models.MovieCr
// Process the base 64 encoded image string // Process the base 64 encoded image string
if input.BackImage != nil { if input.BackImage != nil {
backimageData, err = utils.ProcessImageInput(*input.BackImage) backimageData, err = utils.ProcessImageInput(ctx, *input.BackImage)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -139,7 +139,7 @@ func (r *mutationResolver) MovieUpdate(ctx context.Context, input models.MovieUp
var frontimageData []byte var frontimageData []byte
frontImageIncluded := translator.hasField("front_image") frontImageIncluded := translator.hasField("front_image")
if input.FrontImage != nil { if input.FrontImage != nil {
frontimageData, err = utils.ProcessImageInput(*input.FrontImage) frontimageData, err = utils.ProcessImageInput(ctx, *input.FrontImage)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -147,7 +147,7 @@ func (r *mutationResolver) MovieUpdate(ctx context.Context, input models.MovieUp
backImageIncluded := translator.hasField("back_image") backImageIncluded := translator.hasField("back_image")
var backimageData []byte var backimageData []byte
if input.BackImage != nil { if input.BackImage != nil {
backimageData, err = utils.ProcessImageInput(*input.BackImage) backimageData, err = utils.ProcessImageInput(ctx, *input.BackImage)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -202,7 +202,7 @@ func (r *mutationResolver) MovieUpdate(ctx context.Context, input models.MovieUp
// HACK - if front image is null and back image is not null, then set the front image // HACK - if front image is null and back image is not null, then set the front image
// to the default image since we can't have a null front image and a non-null back image // to the default image since we can't have a null front image and a non-null back image
if frontimageData == nil && backimageData != nil { if frontimageData == nil && backimageData != nil {
frontimageData, _ = utils.ProcessImageInput(models.DefaultMovieImage) frontimageData, _ = utils.ProcessImageInput(ctx, models.DefaultMovieImage)
} }
if err := qb.UpdateImages(movie.ID, frontimageData, backimageData); err != nil { if err := qb.UpdateImages(movie.ID, frontimageData, backimageData); err != nil {

View file

@ -32,7 +32,7 @@ func (r *mutationResolver) PerformerCreate(ctx context.Context, input models.Per
var err error var err error
if input.Image != nil { if input.Image != nil {
imageData, err = utils.ProcessImageInput(*input.Image) imageData, err = utils.ProcessImageInput(ctx, *input.Image)
} }
if err != nil { if err != nil {
@ -178,7 +178,7 @@ func (r *mutationResolver) PerformerUpdate(ctx context.Context, input models.Per
var err error var err error
imageIncluded := translator.hasField("image") imageIncluded := translator.hasField("image")
if input.Image != nil { if input.Image != nil {
imageData, err = utils.ProcessImageInput(*input.Image) imageData, err = utils.ProcessImageInput(ctx, *input.Image)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -32,7 +32,7 @@ func (r *mutationResolver) SceneUpdate(ctx context.Context, input models.SceneUp
// Start the transaction and save the scene // Start the transaction and save the scene
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(repo models.Repository) error {
ret, err = r.sceneUpdate(input, translator, repo) ret, err = r.sceneUpdate(ctx, input, translator, repo)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@ -52,7 +52,7 @@ func (r *mutationResolver) ScenesUpdate(ctx context.Context, input []*models.Sce
inputMap: inputMaps[i], inputMap: inputMaps[i],
} }
thisScene, err := r.sceneUpdate(*scene, translator, repo) thisScene, err := r.sceneUpdate(ctx, *scene, translator, repo)
ret = append(ret, thisScene) ret = append(ret, thisScene)
if err != nil { if err != nil {
@ -85,7 +85,7 @@ func (r *mutationResolver) ScenesUpdate(ctx context.Context, input []*models.Sce
return newRet, nil return newRet, nil
} }
func (r *mutationResolver) sceneUpdate(input models.SceneUpdateInput, translator changesetTranslator, repo models.Repository) (*models.Scene, error) { func (r *mutationResolver) sceneUpdate(ctx context.Context, input models.SceneUpdateInput, translator changesetTranslator, repo models.Repository) (*models.Scene, error) {
// Populate scene from the input // Populate scene from the input
sceneID, err := strconv.Atoi(input.ID) sceneID, err := strconv.Atoi(input.ID)
if err != nil { if err != nil {
@ -110,7 +110,7 @@ func (r *mutationResolver) sceneUpdate(input models.SceneUpdateInput, translator
if input.CoverImage != nil && *input.CoverImage != "" { if input.CoverImage != nil && *input.CoverImage != "" {
var err error var err error
coverImageData, err = utils.ProcessImageInput(*input.CoverImage) coverImageData, err = utils.ProcessImageInput(ctx, *input.CoverImage)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -3,10 +3,11 @@ package api
import ( import (
"context" "context"
"database/sql" "database/sql"
"github.com/stashapp/stash/pkg/studio"
"strconv" "strconv"
"time" "time"
"github.com/stashapp/stash/pkg/studio"
"github.com/stashapp/stash/pkg/manager" "github.com/stashapp/stash/pkg/manager"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/plugin" "github.com/stashapp/stash/pkg/plugin"
@ -33,7 +34,7 @@ func (r *mutationResolver) StudioCreate(ctx context.Context, input models.Studio
// Process the base 64 encoded image string // Process the base 64 encoded image string
if input.Image != nil { if input.Image != nil {
imageData, err = utils.ProcessImageInput(*input.Image) imageData, err = utils.ProcessImageInput(ctx, *input.Image)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -129,7 +130,7 @@ func (r *mutationResolver) StudioUpdate(ctx context.Context, input models.Studio
imageIncluded := translator.hasField("image") imageIncluded := translator.hasField("image")
if input.Image != nil { if input.Image != nil {
var err error var err error
imageData, err = utils.ProcessImageInput(*input.Image) imageData, err = utils.ProcessImageInput(ctx, *input.Image)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -36,7 +36,7 @@ func (r *mutationResolver) TagCreate(ctx context.Context, input models.TagCreate
var err error var err error
if input.Image != nil { if input.Image != nil {
imageData, err = utils.ProcessImageInput(*input.Image) imageData, err = utils.ProcessImageInput(ctx, *input.Image)
if err != nil { if err != nil {
return nil, err return nil, err
@ -121,7 +121,7 @@ func (r *mutationResolver) TagUpdate(ctx context.Context, input models.TagUpdate
imageIncluded := translator.hasField("image") imageIncluded := translator.hasField("image")
if input.Image != nil { if input.Image != nil {
imageData, err = utils.ProcessImageInput(*input.Image) imageData, err = utils.ProcessImageInput(ctx, *input.Image)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -123,7 +123,7 @@ func (r *queryResolver) QueryStashBoxScene(ctx context.Context, input models.Sta
} }
if input.Q != nil { if input.Q != nil {
return client.QueryStashBoxScene(*input.Q) return client.QueryStashBoxScene(ctx, *input.Q)
} }
return nil, nil return nil, nil
@ -197,7 +197,7 @@ func (r *queryResolver) ScrapeSingleScene(ctx context.Context, source models.Scr
if input.SceneID != nil { if input.SceneID != nil {
return client.FindStashBoxScenesByFingerprintsFlat([]string{*input.SceneID}) return client.FindStashBoxScenesByFingerprintsFlat([]string{*input.SceneID})
} else if input.Query != nil { } else if input.Query != nil {
return client.QueryStashBoxScene(*input.Query) return client.QueryStashBoxScene(ctx, *input.Query)
} }
return nil, errors.New("scene_id or query must be set") return nil, errors.New("scene_id or query must be set")

View file

@ -240,7 +240,7 @@ func Start(uiBox embed.FS, loginUIBox embed.FS) {
go func() { go func() {
printVersion() printVersion()
printLatestVersion() printLatestVersion(context.TODO())
logger.Infof("stash is listening on " + address) logger.Infof("stash is listening on " + address)
if tlsConfig != nil { if tlsConfig != nil {

View file

@ -165,7 +165,7 @@ func Backup(db *sqlx.DB, backupPath string) error {
var err error var err error
db, err = sqlx.Connect(sqlite3Driver, "file:"+dbPath+"?_fk=true") db, err = sqlx.Connect(sqlite3Driver, "file:"+dbPath+"?_fk=true")
if err != nil { if err != nil {
return fmt.Errorf("Open database %s failed:%s", dbPath, err) return fmt.Errorf("open database %s failed: %v", dbPath, err)
} }
defer db.Close() defer db.Close()
} }
@ -173,7 +173,7 @@ func Backup(db *sqlx.DB, backupPath string) error {
logger.Infof("Backing up database into: %s", backupPath) logger.Infof("Backing up database into: %s", backupPath)
_, err := db.Exec(`VACUUM INTO "` + backupPath + `"`) _, err := db.Exec(`VACUUM INTO "` + backupPath + `"`)
if err != nil { if err != nil {
return fmt.Errorf("vacuum failed: %s", err) return fmt.Errorf("vacuum failed: %v", err)
} }
return nil return nil

View file

@ -415,7 +415,7 @@ func (me *Server) serveIcon(w http.ResponseWriter, r *http.Request) {
} }
var scene *models.Scene var scene *models.Scene
err := me.txnManager.WithReadTxn(context.Background(), func(r models.ReaderRepository) error { err := me.txnManager.WithReadTxn(r.Context(), func(r models.ReaderRepository) error {
idInt, err := strconv.Atoi(sceneId) idInt, err := strconv.Atoi(sceneId)
if err != nil { if err != nil {
return nil return nil
@ -434,7 +434,7 @@ func (me *Server) serveIcon(w http.ResponseWriter, r *http.Request) {
me.sceneServer.ServeScreenshot(scene, w, r) me.sceneServer.ServeScreenshot(scene, w, r)
} }
func (me *Server) contentDirectoryInitialEvent(urls []*url.URL, sid string) { func (me *Server) contentDirectoryInitialEvent(ctx context.Context, urls []*url.URL, sid string) {
body := xmlMarshalOrPanic(upnp.PropertySet{ body := xmlMarshalOrPanic(upnp.PropertySet{
Properties: []upnp.Property{ Properties: []upnp.Property{
{ {
@ -465,7 +465,7 @@ func (me *Server) contentDirectoryInitialEvent(urls []*url.URL, sid string) {
body = append([]byte(`<?xml version="1.0"?>`+"\n"), body...) body = append([]byte(`<?xml version="1.0"?>`+"\n"), body...)
for _, _url := range urls { for _, _url := range urls {
bodyReader := bytes.NewReader(body) bodyReader := bytes.NewReader(body)
req, err := http.NewRequest("NOTIFY", _url.String(), bodyReader) req, err := http.NewRequestWithContext(ctx, "NOTIFY", _url.String(), bodyReader)
if err != nil { if err != nil {
logger.Errorf("Could not create a request to notify %s: %s", _url.String(), err) logger.Errorf("Could not create a request to notify %s: %s", _url.String(), err)
continue continue
@ -526,7 +526,7 @@ func (me *Server) contentDirectoryEventSubHandler(w http.ResponseWriter, r *http
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
go func() { go func() {
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
me.contentDirectoryInitialEvent(urls, sid) me.contentDirectoryInitialEvent(r.Context(), urls, sid)
}() }()
} else if r.Method == "SUBSCRIBE" { } else if r.Method == "SUBSCRIBE" {
http.Error(w, "meh", http.StatusPreconditionFailed) http.Error(w, "meh", http.StatusPreconditionFailed)
@ -554,7 +554,7 @@ func (me *Server) initMux(mux *http.ServeMux) {
mux.HandleFunc(resPath, func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc(resPath, func(w http.ResponseWriter, r *http.Request) {
sceneId := r.URL.Query().Get("scene") sceneId := r.URL.Query().Get("scene")
var scene *models.Scene var scene *models.Scene
err := me.txnManager.WithReadTxn(context.Background(), func(r models.ReaderRepository) error { err := me.txnManager.WithReadTxn(r.Context(), func(r models.ReaderRepository) error {
sceneIdInt, err := strconv.Atoi(sceneId) sceneIdInt, err := strconv.Atoi(sceneId)
if err != nil { if err != nil {
return nil return nil

View file

@ -2,6 +2,7 @@ package ffmpeg
import ( import (
"archive/zip" "archive/zip"
"context"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@ -36,9 +37,9 @@ func GetPaths(paths []string) (string, string) {
return ffmpegPath, ffprobePath return ffmpegPath, ffprobePath
} }
func Download(configDirectory string) error { func Download(ctx context.Context, configDirectory string) error {
for _, url := range getFFMPEGURL() { for _, url := range getFFMPEGURL() {
err := DownloadSingle(configDirectory, url) err := DownloadSingle(ctx, configDirectory, url)
if err != nil { if err != nil {
return err return err
} }
@ -69,7 +70,7 @@ func (r *progressReader) Read(p []byte) (int, error) {
return read, err return read, err
} }
func DownloadSingle(configDirectory, url string) error { func DownloadSingle(ctx context.Context, configDirectory, url string) error {
if url == "" { if url == "" {
return fmt.Errorf("no ffmpeg url for this platform") return fmt.Errorf("no ffmpeg url for this platform")
} }
@ -88,7 +89,12 @@ func DownloadSingle(configDirectory, url string) error {
logger.Infof("Downloading %s...", url) logger.Infof("Downloading %s...", url)
// Make the HTTP request // Make the HTTP request
resp, err := http.Get(url) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return err
}
resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
return err return err
} }

View file

@ -9,12 +9,12 @@ import (
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
) )
func setInitialMD5Config(txnManager models.TransactionManager) { func setInitialMD5Config(ctx context.Context, txnManager models.TransactionManager) {
// if there are no scene files in the database, then default the // if there are no scene files in the database, then default the
// VideoFileNamingAlgorithm config setting to oshash and calculateMD5 to // VideoFileNamingAlgorithm config setting to oshash and calculateMD5 to
// false, otherwise set them to true for backwards compatibility purposes // false, otherwise set them to true for backwards compatibility purposes
var count int var count int
if err := txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { if err := txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
var err error var err error
count, err = r.Scene().Count() count, err = r.Scene().Count()
return err return err

View file

@ -1,6 +1,7 @@
package manager package manager
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"os" "os"
@ -58,6 +59,7 @@ func GetInstance() *singleton {
func Initialize() *singleton { func Initialize() *singleton {
once.Do(func() { once.Do(func() {
ctx := context.TODO()
cfg, err := config.Initialize() cfg, err := config.Initialize()
if err != nil { if err != nil {
@ -93,7 +95,7 @@ func Initialize() *singleton {
if err != nil { if err != nil {
panic(fmt.Sprintf("error initializing configuration: %s", err.Error())) panic(fmt.Sprintf("error initializing configuration: %s", err.Error()))
} else { } else {
if err := instance.PostInit(); err != nil { if err := instance.PostInit(ctx); err != nil {
panic(err) panic(err)
} }
} }
@ -152,6 +154,8 @@ func initProfiling(cpuProfilePath string) {
} }
func initFFMPEG() error { func initFFMPEG() error {
ctx := context.TODO()
// only do this if we have a config file set // only do this if we have a config file set
if instance.Config.GetConfigFile() != "" { if instance.Config.GetConfigFile() != "" {
// use same directory as config path // use same directory as config path
@ -164,7 +168,7 @@ func initFFMPEG() error {
if ffmpegPath == "" || ffprobePath == "" { if ffmpegPath == "" || ffprobePath == "" {
logger.Infof("couldn't find FFMPEG, attempting to download it") logger.Infof("couldn't find FFMPEG, attempting to download it")
if err := ffmpeg.Download(configDirectory); err != nil { if err := ffmpeg.Download(ctx, configDirectory); err != nil {
msg := `Unable to locate / automatically download FFMPEG msg := `Unable to locate / automatically download FFMPEG
Check the readme for download links. Check the readme for download links.
@ -195,7 +199,7 @@ func initLog() {
// PostInit initialises the paths, caches and txnManager after the initial // PostInit initialises the paths, caches and txnManager after the initial
// configuration has been set. Should only be called if the configuration // configuration has been set. Should only be called if the configuration
// is valid. // is valid.
func (s *singleton) PostInit() error { func (s *singleton) PostInit(ctx context.Context) error {
if err := s.Config.SetInitialConfig(); err != nil { if err := s.Config.SetInitialConfig(); err != nil {
logger.Warnf("could not set initial configuration: %v", err) logger.Warnf("could not set initial configuration: %v", err)
} }
@ -235,7 +239,7 @@ func (s *singleton) PostInit() error {
} }
if database.Ready() == nil { if database.Ready() == nil {
s.PostMigrate() s.PostMigrate(ctx)
} }
return nil return nil
@ -295,7 +299,7 @@ func setSetupDefaults(input *models.SetupInput) {
} }
} }
func (s *singleton) Setup(input models.SetupInput) error { func (s *singleton) Setup(ctx context.Context, input models.SetupInput) error {
setSetupDefaults(&input) setSetupDefaults(&input)
// create the config directory if it does not exist // create the config directory if it does not exist
@ -328,7 +332,7 @@ func (s *singleton) Setup(input models.SetupInput) error {
} }
// initialise the database // initialise the database
if err := s.PostInit(); err != nil { if err := s.PostInit(ctx); err != nil {
return fmt.Errorf("error initializing the database: %v", err) return fmt.Errorf("error initializing the database: %v", err)
} }
@ -349,7 +353,7 @@ func (s *singleton) validateFFMPEG() error {
return nil return nil
} }
func (s *singleton) Migrate(input models.MigrateInput) error { func (s *singleton) Migrate(ctx context.Context, input models.MigrateInput) error {
// always backup so that we can roll back to the previous version if // always backup so that we can roll back to the previous version if
// migration fails // migration fails
backupPath := input.BackupPath backupPath := input.BackupPath
@ -377,7 +381,7 @@ func (s *singleton) Migrate(input models.MigrateInput) error {
} }
// perform post-migration operations // perform post-migration operations
s.PostMigrate() s.PostMigrate(ctx)
// if no backup path was provided, then delete the created backup // if no backup path was provided, then delete the created backup
if input.BackupPath == "" { if input.BackupPath == "" {

View file

@ -90,7 +90,7 @@ func (s *singleton) Import(ctx context.Context) (int, error) {
MissingRefBehaviour: models.ImportMissingRefEnumFail, MissingRefBehaviour: models.ImportMissingRefEnumFail,
fileNamingAlgorithm: config.GetVideoFileNamingAlgorithm(), fileNamingAlgorithm: config.GetVideoFileNamingAlgorithm(),
} }
task.Start() task.Start(ctx)
}) })
return s.JobManager.Add(ctx, "Importing...", j), nil return s.JobManager.Add(ctx, "Importing...", j), nil
@ -122,7 +122,7 @@ func (s *singleton) RunSingleTask(ctx context.Context, t Task) int {
wg.Add(1) wg.Add(1)
j := job.MakeJobExec(func(ctx context.Context, progress *job.Progress) { j := job.MakeJobExec(func(ctx context.Context, progress *job.Progress) {
t.Start() t.Start(ctx)
wg.Done() wg.Done()
}) })

View file

@ -1,6 +1,8 @@
package manager package manager
import "context"
// PostMigrate is executed after migrations have been executed. // PostMigrate is executed after migrations have been executed.
func (s *singleton) PostMigrate() { func (s *singleton) PostMigrate(ctx context.Context) {
setInitialMD5Config(s.TxnManager) setInitialMD5Config(ctx, s.TxnManager)
} }

View file

@ -1,6 +1,8 @@
package manager package manager
import "context"
type Task interface { type Task interface {
Start() Start(context.Context)
GetDescription() string GetDescription() string
} }

View file

@ -79,7 +79,7 @@ func (t *ImportTask) GetDescription() string {
return "Importing..." return "Importing..."
} }
func (t *ImportTask) Start() { func (t *ImportTask) Start(ctx context.Context) {
if t.TmpZip != "" { if t.TmpZip != "" {
defer func() { defer func() {
err := utils.RemoveDir(t.BaseDir) err := utils.RemoveDir(t.BaseDir)
@ -126,8 +126,6 @@ func (t *ImportTask) Start() {
} }
} }
ctx := context.TODO()
t.ImportTags(ctx) t.ImportTags(ctx)
t.ImportPerformers(ctx) t.ImportPerformers(ctx)
t.ImportStudios(ctx) t.ImportStudios(ctx)

View file

@ -120,7 +120,7 @@ func (j *ScanJob) Execute(ctx context.Context, progress *job.Progress) {
} }
go func() { go func() {
task.Start() task.Start(ctx)
wg.Done() wg.Done()
progress.Increment() progress.Increment()
}() }()
@ -238,12 +238,12 @@ type ScanTask struct {
CaseSensitiveFs bool CaseSensitiveFs bool
} }
func (t *ScanTask) Start() { func (t *ScanTask) Start(ctx context.Context) {
var s *models.Scene var s *models.Scene
t.progress.ExecuteTask("Scanning "+t.FilePath, func() { t.progress.ExecuteTask("Scanning "+t.FilePath, func() {
if isGallery(t.FilePath) { if isGallery(t.FilePath) {
t.scanGallery() t.scanGallery(ctx)
} else if isVideo(t.FilePath) { } else if isVideo(t.FilePath) {
s = t.scanScene() s = t.scanScene()
} else if isImage(t.FilePath) { } else if isImage(t.FilePath) {
@ -318,12 +318,12 @@ func (t *ScanTask) Start() {
} }
} }
func (t *ScanTask) scanGallery() { func (t *ScanTask) scanGallery(ctx context.Context) {
var g *models.Gallery var g *models.Gallery
images := 0 images := 0
scanImages := false scanImages := false
if err := t.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { if err := t.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
var err error var err error
g, err = r.Gallery().FindByPath(t.FilePath) g, err = r.Gallery().FindByPath(t.FilePath)
@ -976,7 +976,7 @@ func (t *ScanTask) scanZipImages(zipGallery *models.Gallery) {
subTask.zipGallery = zipGallery subTask.zipGallery = zipGallery
// run the subtask and wait for it to complete // run the subtask and wait for it to complete
subTask.Start() subTask.Start(context.TODO())
return nil return nil
}) })
if err != nil { if err != nil {

View file

@ -22,7 +22,7 @@ type StashBoxPerformerTagTask struct {
} }
func (t *StashBoxPerformerTagTask) Start() { func (t *StashBoxPerformerTagTask) Start() {
t.stashBoxPerformerTag() t.stashBoxPerformerTag(context.TODO())
} }
func (t *StashBoxPerformerTagTask) Description() string { func (t *StashBoxPerformerTagTask) Description() string {
@ -36,7 +36,7 @@ func (t *StashBoxPerformerTagTask) Description() string {
return fmt.Sprintf("Tagging performer %s from stash-box", name) return fmt.Sprintf("Tagging performer %s from stash-box", name)
} }
func (t *StashBoxPerformerTagTask) stashBoxPerformerTag() { func (t *StashBoxPerformerTagTask) stashBoxPerformerTag(ctx context.Context) {
var performer *models.ScrapedPerformer var performer *models.ScrapedPerformer
var err error var err error
@ -169,7 +169,7 @@ func (t *StashBoxPerformerTagTask) stashBoxPerformerTag() {
} }
if len(performer.Images) > 0 && !excluded["image"] { if len(performer.Images) > 0 && !excluded["image"] {
image, err := utils.ReadImageFromURL(performer.Images[0]) image, err := utils.ReadImageFromURL(ctx, performer.Images[0])
if err != nil { if err != nil {
return err return err
} }
@ -232,7 +232,7 @@ func (t *StashBoxPerformerTagTask) stashBoxPerformerTag() {
} }
if len(performer.Images) > 0 { if len(performer.Images) > 0 {
image, imageErr := utils.ReadImageFromURL(performer.Images[0]) image, imageErr := utils.ReadImageFromURL(ctx, performer.Images[0])
if imageErr != nil { if imageErr != nil {
return imageErr return imageErr
} }

View file

@ -66,7 +66,7 @@ type SceneUpdateInput struct {
TagIds []graphql.ID `graphql:"tag_ids" json:"tag_ids"` TagIds []graphql.ID `graphql:"tag_ids" json:"tag_ids"`
} }
func getTagID(client *graphql.Client, create bool) (*graphql.ID, error) { func getTagID(ctx context.Context, client *graphql.Client, create bool) (*graphql.ID, error) {
log.Info("Checking if tag exists already") log.Info("Checking if tag exists already")
// see if tag exists already // see if tag exists already
@ -74,7 +74,7 @@ func getTagID(client *graphql.Client, create bool) (*graphql.ID, error) {
AllTags []Tag `graphql:"allTags"` AllTags []Tag `graphql:"allTags"`
} }
err := client.Query(context.Background(), &q, nil) err := client.Query(ctx, &q, nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("Error getting tags: %s\n", err.Error()) return nil, fmt.Errorf("Error getting tags: %s\n", err.Error())
} }
@ -106,7 +106,7 @@ func getTagID(client *graphql.Client, create bool) (*graphql.ID, error) {
log.Info("Creating new tag") log.Info("Creating new tag")
err = client.Mutate(context.Background(), &m, vars) err = client.Mutate(ctx, &m, vars)
if err != nil { if err != nil {
return nil, fmt.Errorf("Error mutating scene: %s\n", err.Error()) return nil, fmt.Errorf("Error mutating scene: %s\n", err.Error())
} }
@ -114,7 +114,7 @@ func getTagID(client *graphql.Client, create bool) (*graphql.ID, error) {
return &m.TagCreate.ID, nil return &m.TagCreate.ID, nil
} }
func findRandomScene(client *graphql.Client) (*Scene, error) { func findRandomScene(ctx context.Context, client *graphql.Client) (*Scene, error) {
// get a random scene // get a random scene
var q struct { var q struct {
FindScenes FindScenesResultType `graphql:"findScenes(filter: $c)"` FindScenes FindScenesResultType `graphql:"findScenes(filter: $c)"`
@ -132,7 +132,7 @@ func findRandomScene(client *graphql.Client) (*Scene, error) {
} }
log.Info("Finding a random scene") log.Info("Finding a random scene")
err := client.Query(context.Background(), &q, vars) err := client.Query(ctx, &q, vars)
if err != nil { if err != nil {
return nil, fmt.Errorf("Error getting random scene: %s\n", err.Error()) return nil, fmt.Errorf("Error getting random scene: %s\n", err.Error())
} }
@ -155,14 +155,14 @@ func addTagId(tagIds []graphql.ID, tagId graphql.ID) []graphql.ID {
return tagIds return tagIds
} }
func AddTag(client *graphql.Client) error { func AddTag(ctx context.Context, client *graphql.Client) error {
tagID, err := getTagID(client, true) tagID, err := getTagID(ctx, client, true)
if err != nil { if err != nil {
return err return err
} }
scene, err := findRandomScene(client) scene, err := findRandomScene(ctx, client)
if err != nil { if err != nil {
return err return err
@ -188,7 +188,7 @@ func AddTag(client *graphql.Client) error {
} }
log.Infof("Adding tag to scene %v", scene.ID) log.Infof("Adding tag to scene %v", scene.ID)
err = client.Mutate(context.Background(), &m, vars) err = client.Mutate(ctx, &m, vars)
if err != nil { if err != nil {
return fmt.Errorf("Error mutating scene: %v", err) return fmt.Errorf("Error mutating scene: %v", err)
} }
@ -196,8 +196,8 @@ func AddTag(client *graphql.Client) error {
return nil return nil
} }
func RemoveTag(client *graphql.Client) error { func RemoveTag(ctx context.Context, client *graphql.Client) error {
tagID, err := getTagID(client, false) tagID, err := getTagID(ctx, client, false)
if err != nil { if err != nil {
return err return err
@ -223,7 +223,7 @@ func RemoveTag(client *graphql.Client) error {
log.Info("Destroying tag") log.Info("Destroying tag")
err = client.Mutate(context.Background(), &m, vars) err = client.Mutate(ctx, &m, vars)
if err != nil { if err != nil {
return fmt.Errorf("Error destroying tag: %v", err) return fmt.Errorf("Error destroying tag: %v", err)
} }

View file

@ -1,6 +1,7 @@
package plugin package plugin
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"path/filepath" "path/filepath"
@ -84,7 +85,7 @@ func (t *jsPluginTask) Start() error {
return fmt.Errorf("error adding util API: %w", err) return fmt.Errorf("error adding util API: %w", err)
} }
if err := js.AddGQLAPI(t.vm, t.input.ServerConnection.SessionCookie, t.gqlHandler); err != nil { if err := js.AddGQLAPI(context.TODO(), t.vm, t.input.ServerConnection.SessionCookie, t.gqlHandler); err != nil {
return fmt.Errorf("error adding GraphQL API: %w", err) return fmt.Errorf("error adding GraphQL API: %w", err)
} }

View file

@ -2,6 +2,7 @@ package js
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@ -33,7 +34,7 @@ func throw(vm *otto.Otto, str string) {
panic(value) panic(value)
} }
func gqlRequestFunc(vm *otto.Otto, cookie *http.Cookie, gqlHandler http.Handler) func(call otto.FunctionCall) otto.Value { func gqlRequestFunc(ctx context.Context, vm *otto.Otto, cookie *http.Cookie, gqlHandler http.Handler) func(call otto.FunctionCall) otto.Value {
return func(call otto.FunctionCall) otto.Value { return func(call otto.FunctionCall) otto.Value {
if len(call.ArgumentList) == 0 { if len(call.ArgumentList) == 0 {
throw(vm, "missing argument") throw(vm, "missing argument")
@ -61,7 +62,7 @@ func gqlRequestFunc(vm *otto.Otto, cookie *http.Cookie, gqlHandler http.Handler)
throw(vm, err.Error()) throw(vm, err.Error())
} }
r, err := http.NewRequest("POST", "/graphql", &body) r, err := http.NewRequestWithContext(ctx, "POST", "/graphql", &body)
if err != nil { if err != nil {
throw(vm, "could not make request") throw(vm, "could not make request")
} }
@ -103,9 +104,9 @@ func gqlRequestFunc(vm *otto.Otto, cookie *http.Cookie, gqlHandler http.Handler)
} }
} }
func AddGQLAPI(vm *otto.Otto, cookie *http.Cookie, gqlHandler http.Handler) error { func AddGQLAPI(ctx context.Context, vm *otto.Otto, cookie *http.Cookie, gqlHandler http.Handler) error {
gql, _ := vm.Object("({})") gql, _ := vm.Object("({})")
if err := gql.Set("Do", gqlRequestFunc(vm, cookie, gqlHandler)); err != nil { if err := gql.Set("Do", gqlRequestFunc(ctx, vm, cookie, gqlHandler)); err != nil {
return fmt.Errorf("unable to set GraphQL Do function: %w", err) return fmt.Errorf("unable to set GraphQL Do function: %w", err)
} }

View file

@ -1,6 +1,8 @@
package scraper package scraper
import "github.com/stashapp/stash/pkg/models" import (
"github.com/stashapp/stash/pkg/models"
)
type scraperAction string type scraperAction string

View file

@ -1,6 +1,7 @@
package scraper package scraper
import ( import (
"context"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"io" "io"
@ -16,13 +17,13 @@ import (
// configurable at some point. // configurable at some point.
const imageGetTimeout = time.Second * 30 const imageGetTimeout = time.Second * 30
func setPerformerImage(p *models.ScrapedPerformer, globalConfig GlobalConfig) error { func setPerformerImage(ctx context.Context, p *models.ScrapedPerformer, globalConfig GlobalConfig) error {
if p == nil || p.Image == nil || !strings.HasPrefix(*p.Image, "http") { if p == nil || p.Image == nil || !strings.HasPrefix(*p.Image, "http") {
// nothing to do // nothing to do
return nil return nil
} }
img, err := getImage(*p.Image, globalConfig) img, err := getImage(ctx, *p.Image, globalConfig)
if err != nil { if err != nil {
return err return err
} }
@ -34,14 +35,14 @@ func setPerformerImage(p *models.ScrapedPerformer, globalConfig GlobalConfig) er
return nil return nil
} }
func setSceneImage(s *models.ScrapedScene, globalConfig GlobalConfig) error { func setSceneImage(ctx context.Context, s *models.ScrapedScene, globalConfig GlobalConfig) error {
// don't try to get the image if it doesn't appear to be a URL // don't try to get the image if it doesn't appear to be a URL
if s == nil || s.Image == nil || !strings.HasPrefix(*s.Image, "http") { if s == nil || s.Image == nil || !strings.HasPrefix(*s.Image, "http") {
// nothing to do // nothing to do
return nil return nil
} }
img, err := getImage(*s.Image, globalConfig) img, err := getImage(ctx, *s.Image, globalConfig)
if err != nil { if err != nil {
return err return err
} }
@ -51,14 +52,14 @@ func setSceneImage(s *models.ScrapedScene, globalConfig GlobalConfig) error {
return nil return nil
} }
func setMovieFrontImage(m *models.ScrapedMovie, globalConfig GlobalConfig) error { func setMovieFrontImage(ctx context.Context, m *models.ScrapedMovie, globalConfig GlobalConfig) error {
// don't try to get the image if it doesn't appear to be a URL // don't try to get the image if it doesn't appear to be a URL
if m == nil || m.FrontImage == nil || !strings.HasPrefix(*m.FrontImage, "http") { if m == nil || m.FrontImage == nil || !strings.HasPrefix(*m.FrontImage, "http") {
// nothing to do // nothing to do
return nil return nil
} }
img, err := getImage(*m.FrontImage, globalConfig) img, err := getImage(ctx, *m.FrontImage, globalConfig)
if err != nil { if err != nil {
return err return err
} }
@ -68,14 +69,14 @@ func setMovieFrontImage(m *models.ScrapedMovie, globalConfig GlobalConfig) error
return nil return nil
} }
func setMovieBackImage(m *models.ScrapedMovie, globalConfig GlobalConfig) error { func setMovieBackImage(ctx context.Context, m *models.ScrapedMovie, globalConfig GlobalConfig) error {
// don't try to get the image if it doesn't appear to be a URL // don't try to get the image if it doesn't appear to be a URL
if m == nil || m.BackImage == nil || !strings.HasPrefix(*m.BackImage, "http") { if m == nil || m.BackImage == nil || !strings.HasPrefix(*m.BackImage, "http") {
// nothing to do // nothing to do
return nil return nil
} }
img, err := getImage(*m.BackImage, globalConfig) img, err := getImage(ctx, *m.BackImage, globalConfig)
if err != nil { if err != nil {
return err return err
} }
@ -85,14 +86,14 @@ func setMovieBackImage(m *models.ScrapedMovie, globalConfig GlobalConfig) error
return nil return nil
} }
func getImage(url string, globalConfig GlobalConfig) (*string, error) { func getImage(ctx context.Context, url string, globalConfig GlobalConfig) (*string, error) {
client := &http.Client{ client := &http.Client{
Transport: &http.Transport{ // ignore insecure certificates Transport: &http.Transport{ // ignore insecure certificates
TLSClientConfig: &tls.Config{InsecureSkipVerify: !globalConfig.GetScraperCertCheck()}}, TLSClientConfig: &tls.Config{InsecureSkipVerify: !globalConfig.GetScraperCertCheck()}},
Timeout: imageGetTimeout, Timeout: imageGetTimeout,
} }
req, err := http.NewRequest("GET", url, nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -136,10 +137,10 @@ func getImage(url string, globalConfig GlobalConfig) (*string, error) {
return &img, nil return &img, nil
} }
func getStashPerformerImage(stashURL string, performerID string, globalConfig GlobalConfig) (*string, error) { func getStashPerformerImage(ctx context.Context, stashURL string, performerID string, globalConfig GlobalConfig) (*string, error) {
return getImage(stashURL+"/performer/"+performerID+"/image", globalConfig) return getImage(ctx, stashURL+"/performer/"+performerID+"/image", globalConfig)
} }
func getStashSceneImage(stashURL string, sceneID string, globalConfig GlobalConfig) (*string, error) { func getStashSceneImage(ctx context.Context, stashURL string, sceneID string, globalConfig GlobalConfig) (*string, error) {
return getImage(stashURL+"/scene/"+sceneID+"/screenshot", globalConfig) return getImage(ctx, stashURL+"/scene/"+sceneID+"/screenshot", globalConfig)
} }

View file

@ -1,6 +1,7 @@
package scraper package scraper
import ( import (
"context"
"errors" "errors"
"io" "io"
"net/url" "net/url"
@ -31,14 +32,14 @@ func (s *jsonScraper) getJsonScraper() *mappedScraper {
return s.config.JsonScrapers[s.scraper.Scraper] return s.config.JsonScrapers[s.scraper.Scraper]
} }
func (s *jsonScraper) scrapeURL(url string) (string, *mappedScraper, error) { func (s *jsonScraper) scrapeURL(ctx context.Context, url string) (string, *mappedScraper, error) {
scraper := s.getJsonScraper() scraper := s.getJsonScraper()
if scraper == nil { if scraper == nil {
return "", nil, errors.New("json scraper with name " + s.scraper.Scraper + " not found in config") return "", nil, errors.New("json scraper with name " + s.scraper.Scraper + " not found in config")
} }
doc, err := s.loadURL(url) doc, err := s.loadURL(ctx, url)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
@ -47,8 +48,8 @@ func (s *jsonScraper) scrapeURL(url string) (string, *mappedScraper, error) {
return doc, scraper, nil return doc, scraper, nil
} }
func (s *jsonScraper) loadURL(url string) (string, error) { func (s *jsonScraper) loadURL(ctx context.Context, url string) (string, error) {
r, err := loadURL(url, s.config, s.globalConfig) r, err := loadURL(ctx, url, s.config, s.globalConfig)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -72,7 +73,7 @@ func (s *jsonScraper) loadURL(url string) (string, error) {
func (s *jsonScraper) scrapePerformerByURL(url string) (*models.ScrapedPerformer, error) { func (s *jsonScraper) scrapePerformerByURL(url string) (*models.ScrapedPerformer, error) {
u := replaceURL(url, s.scraper) // allow a URL Replace for performer by URL queries u := replaceURL(url, s.scraper) // allow a URL Replace for performer by URL queries
doc, scraper, err := s.scrapeURL(u) doc, scraper, err := s.scrapeURL(context.TODO(), u)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -83,7 +84,7 @@ func (s *jsonScraper) scrapePerformerByURL(url string) (*models.ScrapedPerformer
func (s *jsonScraper) scrapeSceneByURL(url string) (*models.ScrapedScene, error) { func (s *jsonScraper) scrapeSceneByURL(url string) (*models.ScrapedScene, error) {
u := replaceURL(url, s.scraper) // allow a URL Replace for scene by URL queries u := replaceURL(url, s.scraper) // allow a URL Replace for scene by URL queries
doc, scraper, err := s.scrapeURL(u) doc, scraper, err := s.scrapeURL(context.TODO(), u)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -94,7 +95,7 @@ func (s *jsonScraper) scrapeSceneByURL(url string) (*models.ScrapedScene, error)
func (s *jsonScraper) scrapeGalleryByURL(url string) (*models.ScrapedGallery, error) { func (s *jsonScraper) scrapeGalleryByURL(url string) (*models.ScrapedGallery, error) {
u := replaceURL(url, s.scraper) // allow a URL Replace for gallery by URL queries u := replaceURL(url, s.scraper) // allow a URL Replace for gallery by URL queries
doc, scraper, err := s.scrapeURL(u) doc, scraper, err := s.scrapeURL(context.TODO(), u)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -105,7 +106,7 @@ func (s *jsonScraper) scrapeGalleryByURL(url string) (*models.ScrapedGallery, er
func (s *jsonScraper) scrapeMovieByURL(url string) (*models.ScrapedMovie, error) { func (s *jsonScraper) scrapeMovieByURL(url string) (*models.ScrapedMovie, error) {
u := replaceURL(url, s.scraper) // allow a URL Replace for movie by URL queries u := replaceURL(url, s.scraper) // allow a URL Replace for movie by URL queries
doc, scraper, err := s.scrapeURL(u) doc, scraper, err := s.scrapeURL(context.TODO(), u)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -129,7 +130,7 @@ func (s *jsonScraper) scrapePerformersByName(name string) ([]*models.ScrapedPerf
url := s.scraper.QueryURL url := s.scraper.QueryURL
url = strings.Replace(url, placeholder, escapedName, -1) url = strings.Replace(url, placeholder, escapedName, -1)
doc, err := s.loadURL(url) doc, err := s.loadURL(context.TODO(), url)
if err != nil { if err != nil {
return nil, err return nil, err
@ -158,7 +159,7 @@ func (s *jsonScraper) scrapeScenesByName(name string) ([]*models.ScrapedScene, e
url := s.scraper.QueryURL url := s.scraper.QueryURL
url = strings.Replace(url, placeholder, escapedName, -1) url = strings.Replace(url, placeholder, escapedName, -1)
doc, err := s.loadURL(url) doc, err := s.loadURL(context.TODO(), url)
if err != nil { if err != nil {
return nil, err return nil, err
@ -182,7 +183,7 @@ func (s *jsonScraper) scrapeSceneByScene(scene *models.Scene) (*models.ScrapedSc
return nil, errors.New("json scraper with name " + s.scraper.Scraper + " not found in config") return nil, errors.New("json scraper with name " + s.scraper.Scraper + " not found in config")
} }
doc, err := s.loadURL(url) doc, err := s.loadURL(context.TODO(), url)
if err != nil { if err != nil {
return nil, err return nil, err
@ -206,7 +207,7 @@ func (s *jsonScraper) scrapeSceneByFragment(scene models.ScrapedSceneInput) (*mo
return nil, errors.New("xpath scraper with name " + s.scraper.Scraper + " not found in config") return nil, errors.New("xpath scraper with name " + s.scraper.Scraper + " not found in config")
} }
doc, err := s.loadURL(url) doc, err := s.loadURL(context.TODO(), url)
if err != nil { if err != nil {
return nil, err return nil, err
@ -230,7 +231,7 @@ func (s *jsonScraper) scrapeGalleryByGallery(gallery *models.Gallery) (*models.S
return nil, errors.New("json scraper with name " + s.scraper.Scraper + " not found in config") return nil, errors.New("json scraper with name " + s.scraper.Scraper + " not found in config")
} }
doc, err := s.loadURL(url) doc, err := s.loadURL(context.TODO(), url)
if err != nil { if err != nil {
return nil, err return nil, err
@ -278,7 +279,7 @@ func (q *jsonQuery) runQuery(selector string) []string {
} }
func (q *jsonQuery) subScrape(value string) mappedQuery { func (q *jsonQuery) subScrape(value string) mappedQuery {
doc, err := q.scraper.loadURL(value) doc, err := q.scraper.loadURL(context.TODO(), value)
if err != nil { if err != nil {
logger.Warnf("Error getting URL '%s' for sub-scraper: %s", value, err.Error()) logger.Warnf("Error getting URL '%s' for sub-scraper: %s", value, err.Error())

View file

@ -202,7 +202,7 @@ func (c Cache) ScrapePerformer(scraperID string, scrapedPerformer models.Scraped
} }
if ret != nil { if ret != nil {
err = c.postScrapePerformer(ret) err = c.postScrapePerformer(context.TODO(), ret)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -226,7 +226,7 @@ func (c Cache) ScrapePerformerURL(url string) (*models.ScrapedPerformer, error)
} }
if ret != nil { if ret != nil {
err = c.postScrapePerformer(ret) err = c.postScrapePerformer(context.TODO(), ret)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -239,8 +239,8 @@ func (c Cache) ScrapePerformerURL(url string) (*models.ScrapedPerformer, error)
return nil, nil return nil, nil
} }
func (c Cache) postScrapePerformer(ret *models.ScrapedPerformer) error { func (c Cache) postScrapePerformer(ctx context.Context, ret *models.ScrapedPerformer) error {
if err := c.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { if err := c.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
tqb := r.Tag() tqb := r.Tag()
tags, err := postProcessTags(tqb, ret.Tags) tags, err := postProcessTags(tqb, ret.Tags)
@ -255,7 +255,7 @@ func (c Cache) postScrapePerformer(ret *models.ScrapedPerformer) error {
} }
// post-process - set the image if applicable // post-process - set the image if applicable
if err := setPerformerImage(ret, c.globalConfig); err != nil { if err := setPerformerImage(ctx, ret, c.globalConfig); err != nil {
logger.Warnf("Could not set image using URL %s: %s", *ret.Image, err.Error()) logger.Warnf("Could not set image using URL %s: %s", *ret.Image, err.Error())
} }
@ -280,8 +280,8 @@ func (c Cache) postScrapeScenePerformer(ret *models.ScrapedPerformer) error {
return nil return nil
} }
func (c Cache) postScrapeScene(ret *models.ScrapedScene) error { func (c Cache) postScrapeScene(ctx context.Context, ret *models.ScrapedScene) error {
if err := c.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { if err := c.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
pqb := r.Performer() pqb := r.Performer()
mqb := r.Movie() mqb := r.Movie()
tqb := r.Tag() tqb := r.Tag()
@ -323,8 +323,8 @@ func (c Cache) postScrapeScene(ret *models.ScrapedScene) error {
} }
// post-process - set the image if applicable // post-process - set the image if applicable
if err := setSceneImage(ret, c.globalConfig); err != nil { if err := setSceneImage(ctx, ret, c.globalConfig); err != nil {
logger.Warnf("Could not set image using URL %s: %s", *ret.Image, err.Error()) logger.Warnf("Could not set image using URL %s: %v", *ret.Image, err)
} }
return nil return nil
@ -382,7 +382,7 @@ func (c Cache) ScrapeScene(scraperID string, sceneID int) (*models.ScrapedScene,
} }
if ret != nil { if ret != nil {
err = c.postScrapeScene(ret) err = c.postScrapeScene(context.TODO(), ret)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -419,7 +419,7 @@ func (c Cache) ScrapeSceneFragment(scraperID string, scene models.ScrapedSceneIn
} }
if ret != nil { if ret != nil {
err = c.postScrapeScene(ret) err = c.postScrapeScene(context.TODO(), ret)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -443,7 +443,7 @@ func (c Cache) ScrapeSceneURL(url string) (*models.ScrapedScene, error) {
return nil, err return nil, err
} }
err = c.postScrapeScene(ret) err = c.postScrapeScene(context.TODO(), ret)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -551,10 +551,10 @@ func (c Cache) ScrapeMovieURL(url string) (*models.ScrapedMovie, error) {
} }
// post-process - set the image if applicable // post-process - set the image if applicable
if err := setMovieFrontImage(ret, c.globalConfig); err != nil { if err := setMovieFrontImage(context.TODO(), ret, c.globalConfig); err != nil {
logger.Warnf("Could not set front image using URL %s: %s", *ret.FrontImage, err.Error()) logger.Warnf("Could not set front image using URL %s: %s", *ret.FrontImage, err.Error())
} }
if err := setMovieBackImage(ret, c.globalConfig); err != nil { if err := setMovieBackImage(context.TODO(), ret, c.globalConfig); err != nil {
logger.Warnf("Could not set back image using URL %s: %s", *ret.BackImage, err.Error()) logger.Warnf("Could not set back image using URL %s: %s", *ret.BackImage, err.Error())
} }

View file

@ -69,7 +69,7 @@ func (s *stashScraper) scrapePerformersByName(name string) ([]*models.ScrapedPer
}, },
} }
err := client.Query(context.Background(), &q, vars) err := client.Query(context.TODO(), &q, vars)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -125,7 +125,7 @@ func (s *stashScraper) scrapePerformerByFragment(scrapedPerformer models.Scraped
"f": performerID, "f": performerID,
} }
err := client.Query(context.Background(), &q, vars) err := client.Query(context.TODO(), &q, vars)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -138,7 +138,7 @@ func (s *stashScraper) scrapePerformerByFragment(scrapedPerformer models.Scraped
} }
// get the performer image directly // get the performer image directly
ret.Image, err = getStashPerformerImage(s.config.StashServer.URL, performerID, s.globalConfig) ret.Image, err = getStashPerformerImage(context.TODO(), s.config.StashServer.URL, performerID, s.globalConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -164,7 +164,7 @@ func (s *stashScraper) scrapedStashSceneToScrapedScene(scene *scrapedSceneStash)
} }
// get the performer image directly // get the performer image directly
ret.Image, err = getStashSceneImage(s.config.StashServer.URL, scene.ID, s.globalConfig) ret.Image, err = getStashSceneImage(context.TODO(), s.config.StashServer.URL, scene.ID, s.globalConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -190,7 +190,7 @@ func (s *stashScraper) scrapeScenesByName(name string) ([]*models.ScrapedScene,
}, },
} }
err := client.Query(context.Background(), &q, vars) err := client.Query(context.TODO(), &q, vars)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -240,7 +240,7 @@ func (s *stashScraper) scrapeSceneByScene(scene *models.Scene) (*models.ScrapedS
} }
client := s.getStashClient() client := s.getStashClient()
if err := client.Query(context.Background(), &q, vars); err != nil { if err := client.Query(context.TODO(), &q, vars); err != nil {
return nil, err return nil, err
} }
@ -251,7 +251,7 @@ func (s *stashScraper) scrapeSceneByScene(scene *models.Scene) (*models.ScrapedS
} }
// get the performer image directly // get the performer image directly
ret.Image, err = getStashSceneImage(s.config.StashServer.URL, q.FindScene.ID, s.globalConfig) ret.Image, err = getStashSceneImage(context.TODO(), s.config.StashServer.URL, q.FindScene.ID, s.globalConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -293,7 +293,7 @@ func (s *stashScraper) scrapeGalleryByGallery(gallery *models.Gallery) (*models.
} }
client := s.getStashClient() client := s.getStashClient()
if err := client.Query(context.Background(), &q, vars); err != nil { if err := client.Query(context.TODO(), &q, vars); err != nil {
return nil, err return nil, err
} }

View file

@ -45,8 +45,8 @@ func NewClient(box models.StashBox, txnManager models.TransactionManager) *Clien
} }
// QueryStashBoxScene queries stash-box for scenes using a query string. // QueryStashBoxScene queries stash-box for scenes using a query string.
func (c Client) QueryStashBoxScene(queryStr string) ([]*models.ScrapedScene, error) { func (c Client) QueryStashBoxScene(ctx context.Context, queryStr string) ([]*models.ScrapedScene, error) {
scenes, err := c.client.SearchScene(context.TODO(), queryStr) scenes, err := c.client.SearchScene(ctx, queryStr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -55,7 +55,7 @@ func (c Client) QueryStashBoxScene(queryStr string) ([]*models.ScrapedScene, err
var ret []*models.ScrapedScene var ret []*models.ScrapedScene
for _, s := range sceneFragments { for _, s := range sceneFragments {
ss, err := sceneFragmentToScrapedScene(c.txnManager, s) ss, err := sceneFragmentToScrapedScene(context.TODO(), c.txnManager, s)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -69,6 +69,8 @@ func (c Client) QueryStashBoxScene(queryStr string) ([]*models.ScrapedScene, err
// scene's MD5/OSHASH checksum, or PHash, and returns results in the same order // scene's MD5/OSHASH checksum, or PHash, and returns results in the same order
// as the input slice. // as the input slice.
func (c Client) FindStashBoxScenesByFingerprints(sceneIDs []string) ([][]*models.ScrapedScene, error) { func (c Client) FindStashBoxScenesByFingerprints(sceneIDs []string) ([][]*models.ScrapedScene, error) {
ctx := context.TODO()
ids, err := utils.StringSliceToIntSlice(sceneIDs) ids, err := utils.StringSliceToIntSlice(sceneIDs)
if err != nil { if err != nil {
return nil, err return nil, err
@ -78,7 +80,7 @@ func (c Client) FindStashBoxScenesByFingerprints(sceneIDs []string) ([][]*models
// map fingerprints to their scene index // map fingerprints to their scene index
fpToScene := make(map[string][]int) fpToScene := make(map[string][]int)
if err := c.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { if err := c.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
qb := r.Scene() qb := r.Scene()
for index, sceneID := range ids { for index, sceneID := range ids {
@ -113,7 +115,7 @@ func (c Client) FindStashBoxScenesByFingerprints(sceneIDs []string) ([][]*models
return nil, err return nil, err
} }
allScenes, err := c.findStashBoxScenesByFingerprints(fingerprints) allScenes, err := c.findStashBoxScenesByFingerprints(ctx, fingerprints)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -139,6 +141,8 @@ func (c Client) FindStashBoxScenesByFingerprints(sceneIDs []string) ([][]*models
// FindStashBoxScenesByFingerprintsFlat queries stash-box for scenes using every // FindStashBoxScenesByFingerprintsFlat queries stash-box for scenes using every
// scene's MD5/OSHASH checksum, or PHash, and returns results a flat slice. // scene's MD5/OSHASH checksum, or PHash, and returns results a flat slice.
func (c Client) FindStashBoxScenesByFingerprintsFlat(sceneIDs []string) ([]*models.ScrapedScene, error) { func (c Client) FindStashBoxScenesByFingerprintsFlat(sceneIDs []string) ([]*models.ScrapedScene, error) {
ctx := context.TODO()
ids, err := utils.StringSliceToIntSlice(sceneIDs) ids, err := utils.StringSliceToIntSlice(sceneIDs)
if err != nil { if err != nil {
return nil, err return nil, err
@ -146,7 +150,7 @@ func (c Client) FindStashBoxScenesByFingerprintsFlat(sceneIDs []string) ([]*mode
var fingerprints []string var fingerprints []string
if err := c.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { if err := c.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
qb := r.Scene() qb := r.Scene()
for _, sceneID := range ids { for _, sceneID := range ids {
@ -178,17 +182,17 @@ func (c Client) FindStashBoxScenesByFingerprintsFlat(sceneIDs []string) ([]*mode
return nil, err return nil, err
} }
return c.findStashBoxScenesByFingerprints(fingerprints) return c.findStashBoxScenesByFingerprints(ctx, fingerprints)
} }
func (c Client) findStashBoxScenesByFingerprints(fingerprints []string) ([]*models.ScrapedScene, error) { func (c Client) findStashBoxScenesByFingerprints(ctx context.Context, fingerprints []string) ([]*models.ScrapedScene, error) {
var ret []*models.ScrapedScene var ret []*models.ScrapedScene
for i := 0; i < len(fingerprints); i += 100 { for i := 0; i < len(fingerprints); i += 100 {
end := i + 100 end := i + 100
if end > len(fingerprints) { if end > len(fingerprints) {
end = len(fingerprints) end = len(fingerprints)
} }
scenes, err := c.client.FindScenesByFingerprints(context.TODO(), fingerprints[i:end]) scenes, err := c.client.FindScenesByFingerprints(ctx, fingerprints[i:end])
if err != nil { if err != nil {
return nil, err return nil, err
@ -197,7 +201,7 @@ func (c Client) findStashBoxScenesByFingerprints(fingerprints []string) ([]*mode
sceneFragments := scenes.FindScenesByFingerprints sceneFragments := scenes.FindScenesByFingerprints
for _, s := range sceneFragments { for _, s := range sceneFragments {
ss, err := sceneFragmentToScrapedScene(c.txnManager, s) ss, err := sceneFragmentToScrapedScene(ctx, c.txnManager, s)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -504,12 +508,12 @@ func formatBodyModifications(m []*graphql.BodyModificationFragment) *string {
return &ret return &ret
} }
func fetchImage(url string) (*string, error) { func fetchImage(ctx context.Context, url string) (*string, error) {
client := &http.Client{ client := &http.Client{
Timeout: imageGetTimeout, Timeout: imageGetTimeout,
} }
req, err := http.NewRequest("GET", url, nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -590,8 +594,8 @@ func performerFragmentToScrapedScenePerformer(p graphql.PerformerFragment) *mode
return sp return sp
} }
func getFirstImage(images []*graphql.ImageFragment) *string { func getFirstImage(ctx context.Context, images []*graphql.ImageFragment) *string {
ret, err := fetchImage(images[0].URL) ret, err := fetchImage(ctx, images[0].URL)
if err != nil { if err != nil {
logger.Warnf("Error fetching image %s: %s", images[0].URL, err.Error()) logger.Warnf("Error fetching image %s: %s", images[0].URL, err.Error())
} }
@ -612,7 +616,7 @@ func getFingerprints(scene *graphql.SceneFragment) []*models.StashBoxFingerprint
return fingerprints return fingerprints
} }
func sceneFragmentToScrapedScene(txnManager models.TransactionManager, s *graphql.SceneFragment) (*models.ScrapedScene, error) { func sceneFragmentToScrapedScene(ctx context.Context, txnManager models.TransactionManager, s *graphql.SceneFragment) (*models.ScrapedScene, error) {
stashID := s.ID stashID := s.ID
ss := &models.ScrapedScene{ ss := &models.ScrapedScene{
Title: s.Title, Title: s.Title,
@ -629,10 +633,10 @@ func sceneFragmentToScrapedScene(txnManager models.TransactionManager, s *graphq
if len(s.Images) > 0 { if len(s.Images) > 0 {
// TODO - #454 code sorts images by aspect ratio according to a wanted // TODO - #454 code sorts images by aspect ratio according to a wanted
// orientation. I'm just grabbing the first for now // orientation. I'm just grabbing the first for now
ss.Image = getFirstImage(s.Images) ss.Image = getFirstImage(ctx, s.Images)
} }
if err := txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { if err := txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
pqb := r.Performer() pqb := r.Performer()
tqb := r.Tag() tqb := r.Tag()

View file

@ -28,11 +28,11 @@ import (
const scrapeGetTimeout = time.Second * 60 const scrapeGetTimeout = time.Second * 60
const scrapeDefaultSleep = time.Second * 2 const scrapeDefaultSleep = time.Second * 2
func loadURL(url string, scraperConfig config, globalConfig GlobalConfig) (io.Reader, error) { func loadURL(ctx context.Context, url string, scraperConfig config, globalConfig GlobalConfig) (io.Reader, error) {
driverOptions := scraperConfig.DriverOptions driverOptions := scraperConfig.DriverOptions
if driverOptions != nil && driverOptions.UseCDP { if driverOptions != nil && driverOptions.UseCDP {
// get the page using chrome dp // get the page using chrome dp
return urlFromCDP(url, *driverOptions, globalConfig) return urlFromCDP(ctx, url, *driverOptions, globalConfig)
} }
// get the page using http.Client // get the page using http.Client
@ -62,7 +62,7 @@ func loadURL(url string, scraperConfig config, globalConfig GlobalConfig) (io.Re
Jar: jar, Jar: jar,
} }
req, err := http.NewRequest("GET", url, nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -105,7 +105,7 @@ func loadURL(url string, scraperConfig config, globalConfig GlobalConfig) (io.Re
// func urlFromCDP uses chrome cdp and DOM to load and process the url // func urlFromCDP uses chrome cdp and DOM to load and process the url
// if remote is set as true in the scraperConfig it will try to use localhost:9222 // if remote is set as true in the scraperConfig it will try to use localhost:9222
// else it will look for google-chrome in path // else it will look for google-chrome in path
func urlFromCDP(url string, driverOptions scraperDriverOptions, globalConfig GlobalConfig) (io.Reader, error) { func urlFromCDP(ctx context.Context, url string, driverOptions scraperDriverOptions, globalConfig GlobalConfig) (io.Reader, error) {
if !driverOptions.UseCDP { if !driverOptions.UseCDP {
return nil, fmt.Errorf("url shouldn't be fetched through CDP") return nil, fmt.Errorf("url shouldn't be fetched through CDP")
@ -117,7 +117,7 @@ func urlFromCDP(url string, driverOptions scraperDriverOptions, globalConfig Glo
sleepDuration = time.Duration(driverOptions.Sleep) * time.Second sleepDuration = time.Duration(driverOptions.Sleep) * time.Second
} }
act := context.Background() act := context.TODO()
// if scraperCDPPath is a remote address, then allocate accordingly // if scraperCDPPath is a remote address, then allocate accordingly
cdpPath := globalConfig.GetScraperCDPPath() cdpPath := globalConfig.GetScraperCDPPath()
@ -130,13 +130,13 @@ func urlFromCDP(url string, driverOptions scraperDriverOptions, globalConfig Glo
// if CDPPath is http(s) then we need to get the websocket URL // if CDPPath is http(s) then we need to get the websocket URL
if isCDPPathHTTP(globalConfig) { if isCDPPathHTTP(globalConfig) {
var err error var err error
remote, err = getRemoteCDPWSAddress(remote) remote, err = getRemoteCDPWSAddress(ctx, remote)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
act, cancelAct = chromedp.NewRemoteAllocator(context.Background(), remote) act, cancelAct = chromedp.NewRemoteAllocator(act, remote)
} else { } else {
// use a temporary user directory for chrome // use a temporary user directory for chrome
dir, err := os.MkdirTemp("", "stash-chromedp") dir, err := os.MkdirTemp("", "stash-chromedp")
@ -218,8 +218,13 @@ func setCDPClicks(driverOptions scraperDriverOptions) chromedp.Tasks {
} }
// getRemoteCDPWSAddress returns the complete remote address that is required to access the cdp instance // getRemoteCDPWSAddress returns the complete remote address that is required to access the cdp instance
func getRemoteCDPWSAddress(address string) (string, error) { func getRemoteCDPWSAddress(ctx context.Context, url string) (string, error) {
resp, err := http.Get(address) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return "", err
}
resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
return "", err return "", err
} }

View file

@ -2,6 +2,7 @@ package scraper
import ( import (
"bytes" "bytes"
"context"
"errors" "errors"
"net/url" "net/url"
"regexp" "regexp"
@ -42,7 +43,7 @@ func (s *xpathScraper) scrapeURL(url string) (*html.Node, *mappedScraper, error)
return nil, nil, errors.New("xpath scraper with name " + s.scraper.Scraper + " not found in config") return nil, nil, errors.New("xpath scraper with name " + s.scraper.Scraper + " not found in config")
} }
doc, err := s.loadURL(url) doc, err := s.loadURL(context.TODO(), url)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -110,7 +111,7 @@ func (s *xpathScraper) scrapePerformersByName(name string) ([]*models.ScrapedPer
url := s.scraper.QueryURL url := s.scraper.QueryURL
url = strings.Replace(url, placeholder, escapedName, -1) url = strings.Replace(url, placeholder, escapedName, -1)
doc, err := s.loadURL(url) doc, err := s.loadURL(context.TODO(), url)
if err != nil { if err != nil {
return nil, err return nil, err
@ -139,7 +140,7 @@ func (s *xpathScraper) scrapeScenesByName(name string) ([]*models.ScrapedScene,
url := s.scraper.QueryURL url := s.scraper.QueryURL
url = strings.Replace(url, placeholder, escapedName, -1) url = strings.Replace(url, placeholder, escapedName, -1)
doc, err := s.loadURL(url) doc, err := s.loadURL(context.TODO(), url)
if err != nil { if err != nil {
return nil, err return nil, err
@ -163,7 +164,7 @@ func (s *xpathScraper) scrapeSceneByScene(scene *models.Scene) (*models.ScrapedS
return nil, errors.New("xpath scraper with name " + s.scraper.Scraper + " not found in config") return nil, errors.New("xpath scraper with name " + s.scraper.Scraper + " not found in config")
} }
doc, err := s.loadURL(url) doc, err := s.loadURL(context.TODO(), url)
if err != nil { if err != nil {
return nil, err return nil, err
@ -187,7 +188,7 @@ func (s *xpathScraper) scrapeSceneByFragment(scene models.ScrapedSceneInput) (*m
return nil, errors.New("xpath scraper with name " + s.scraper.Scraper + " not found in config") return nil, errors.New("xpath scraper with name " + s.scraper.Scraper + " not found in config")
} }
doc, err := s.loadURL(url) doc, err := s.loadURL(context.TODO(), url)
if err != nil { if err != nil {
return nil, err return nil, err
@ -211,7 +212,7 @@ func (s *xpathScraper) scrapeGalleryByGallery(gallery *models.Gallery) (*models.
return nil, errors.New("xpath scraper with name " + s.scraper.Scraper + " not found in config") return nil, errors.New("xpath scraper with name " + s.scraper.Scraper + " not found in config")
} }
doc, err := s.loadURL(url) doc, err := s.loadURL(context.TODO(), url)
if err != nil { if err != nil {
return nil, err return nil, err
@ -225,8 +226,8 @@ func (s *xpathScraper) scrapeGalleryByFragment(gallery models.ScrapedGalleryInpu
return nil, errors.New("scrapeGalleryByFragment not supported for xpath scraper") return nil, errors.New("scrapeGalleryByFragment not supported for xpath scraper")
} }
func (s *xpathScraper) loadURL(url string) (*html.Node, error) { func (s *xpathScraper) loadURL(ctx context.Context, url string) (*html.Node, error) {
r, err := loadURL(url, s.config, s.globalConfig) r, err := loadURL(ctx, url, s.config, s.globalConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -298,7 +299,7 @@ func (q *xpathQuery) nodeText(n *html.Node) string {
} }
func (q *xpathQuery) subScrape(value string) mappedQuery { func (q *xpathQuery) subScrape(value string) mappedQuery {
doc, err := q.scraper.loadURL(value) doc, err := q.scraper.loadURL(context.TODO(), value)
if err != nil { if err != nil {
logger.Warnf("Error getting URL '%s' for sub-scraper: %s", value, err.Error()) logger.Warnf("Error getting URL '%s' for sub-scraper: %s", value, err.Error())

View file

@ -1,6 +1,7 @@
package utils package utils
import ( import (
"context"
"crypto/md5" "crypto/md5"
"crypto/tls" "crypto/tls"
"encoding/base64" "encoding/base64"
@ -20,7 +21,7 @@ const base64RE = `^data:.+\/(.+);base64,(.*)$`
// ProcessImageInput transforms an image string either from a base64 encoded // ProcessImageInput transforms an image string either from a base64 encoded
// string, or from a URL, and returns the image as a byte slice // string, or from a URL, and returns the image as a byte slice
func ProcessImageInput(imageInput string) ([]byte, error) { func ProcessImageInput(ctx context.Context, imageInput string) ([]byte, error) {
regex := regexp.MustCompile(base64RE) regex := regexp.MustCompile(base64RE)
if regex.MatchString(imageInput) { if regex.MatchString(imageInput) {
_, d, err := ProcessBase64Image(imageInput) _, d, err := ProcessBase64Image(imageInput)
@ -28,11 +29,11 @@ func ProcessImageInput(imageInput string) ([]byte, error) {
} }
// assume input is a URL. Read it. // assume input is a URL. Read it.
return ReadImageFromURL(imageInput) return ReadImageFromURL(ctx, imageInput)
} }
// ReadImageFromURL returns image data from a URL // ReadImageFromURL returns image data from a URL
func ReadImageFromURL(url string) ([]byte, error) { func ReadImageFromURL(ctx context.Context, url string) ([]byte, error) {
client := &http.Client{ client := &http.Client{
Transport: &http.Transport{ // ignore insecure certificates Transport: &http.Transport{ // ignore insecure certificates
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
@ -41,7 +42,7 @@ func ReadImageFromURL(url string) ([]byte, error) {
Timeout: imageGetTimeout, Timeout: imageGetTimeout,
} }
req, err := http.NewRequest("GET", url, nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }