mirror of
https://github.com/stashapp/stash.git
synced 2025-12-06 08:26:00 +01:00
This means that long queries can be cancelled by navigating to another page. Previously the query would continue to run, impacting future queries.
521 lines
13 KiB
Go
521 lines
13 KiB
Go
package sqlite
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/jmoiron/sqlx"
|
|
|
|
"github.com/stashapp/stash/pkg/models"
|
|
)
|
|
|
|
// idColumn is the column name "id", declared once as a package-level constant.
const idColumn = "id"
|
|
|
|
// repository carries the table and id-column names that the generic query
// helpers defined on it interpolate into their SQL statements.
type repository struct {
	// tableName is the table that statements are built against.
	tableName string
	// idColumn is the column compared against the id argument in WHERE clauses.
	idColumn string
}
|
|
|
|
func (r *repository) getAll(ctx context.Context, id int, f func(rows *sqlx.Rows) error) error {
|
|
stmt := fmt.Sprintf("SELECT * FROM %s WHERE %s = ?", r.tableName, r.idColumn)
|
|
return r.queryFunc(ctx, stmt, []interface{}{id}, false, f)
|
|
}
|
|
|
|
func (r *repository) destroyExisting(ctx context.Context, ids []int) error {
|
|
for _, id := range ids {
|
|
exists, err := r.exists(ctx, id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if !exists {
|
|
return fmt.Errorf("%s %d does not exist in %s", r.idColumn, id, r.tableName)
|
|
}
|
|
}
|
|
|
|
return r.destroy(ctx, ids)
|
|
}
|
|
|
|
func (r *repository) destroy(ctx context.Context, ids []int) error {
|
|
for _, id := range ids {
|
|
stmt := fmt.Sprintf("DELETE FROM %s WHERE %s = ?", r.tableName, r.idColumn)
|
|
if _, err := dbWrapper.Exec(ctx, stmt, id); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *repository) exists(ctx context.Context, id int) (bool, error) {
|
|
stmt := fmt.Sprintf("SELECT %s FROM %s WHERE %s = ? LIMIT 1", r.idColumn, r.tableName, r.idColumn)
|
|
stmt = r.buildCountQuery(stmt)
|
|
|
|
c, err := r.runCountQuery(ctx, stmt, []interface{}{id})
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
return c == 1, nil
|
|
}
|
|
|
|
func (r *repository) buildCountQuery(query string) string {
|
|
return "SELECT COUNT(*) as count FROM (" + query + ") as temp"
|
|
}
|
|
|
|
func (r *repository) runCountQuery(ctx context.Context, query string, args []interface{}) (int, error) {
|
|
result := struct {
|
|
Int int `db:"count"`
|
|
}{0}
|
|
|
|
// Perform query and fetch result
|
|
if err := dbWrapper.Get(ctx, &result, query, args...); err != nil && !errors.Is(err, sql.ErrNoRows) {
|
|
return 0, err
|
|
}
|
|
|
|
return result.Int, nil
|
|
}
|
|
|
|
func (r *repository) runIdsQuery(ctx context.Context, query string, args []interface{}) ([]int, error) {
|
|
var result []struct {
|
|
Int int `db:"id"`
|
|
}
|
|
|
|
if err := dbWrapper.Select(ctx, &result, query, args...); err != nil && !errors.Is(err, sql.ErrNoRows) {
|
|
return []int{}, fmt.Errorf("running query: %s [%v]: %w", query, args, err)
|
|
}
|
|
|
|
vsm := make([]int, len(result))
|
|
for i, v := range result {
|
|
vsm[i] = v.Int
|
|
}
|
|
return vsm, nil
|
|
}
|
|
|
|
func (r *repository) queryFunc(ctx context.Context, query string, args []interface{}, single bool, f func(rows *sqlx.Rows) error) error {
|
|
rows, err := dbWrapper.QueryxContext(ctx, query, args...)
|
|
|
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
|
return err
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
if err := f(rows); err != nil {
|
|
return err
|
|
}
|
|
if single {
|
|
break
|
|
}
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// queryStruct executes a query and scans the result into the provided struct.
|
|
// Unlike the other query methods, this will return an error if no rows are found.
|
|
func (r *repository) queryStruct(ctx context.Context, query string, args []interface{}, out interface{}) error {
|
|
// changed from queryFunc, since it was not logging the performance correctly,
|
|
// since the query doesn't actually execute until Scan is called
|
|
if err := dbWrapper.Get(ctx, out, query, args...); err != nil {
|
|
return fmt.Errorf("executing query: %s [%v]: %w", query, args, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *repository) querySimple(ctx context.Context, query string, args []interface{}, out interface{}) error {
|
|
rows, err := dbWrapper.Queryx(ctx, query, args...)
|
|
|
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
|
return err
|
|
}
|
|
defer rows.Close()
|
|
|
|
if rows.Next() {
|
|
if err := rows.Scan(out); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *repository) buildQueryBody(body string, whereClauses []string, havingClauses []string) string {
|
|
if len(whereClauses) > 0 {
|
|
body = body + " WHERE " + strings.Join(whereClauses, " AND ") // TODO handle AND or OR
|
|
}
|
|
if len(havingClauses) > 0 {
|
|
body = body + " GROUP BY " + r.tableName + ".id "
|
|
body = body + " HAVING " + strings.Join(havingClauses, " AND ") // TODO handle AND or OR
|
|
}
|
|
|
|
return body
|
|
}
|
|
|
|
func (r *repository) executeFindQuery(ctx context.Context, body string, args []interface{}, sortAndPagination string, whereClauses []string, havingClauses []string, withClauses []string, recursiveWith bool) ([]int, int, error) {
|
|
body = r.buildQueryBody(body, whereClauses, havingClauses)
|
|
|
|
withClause := ""
|
|
if len(withClauses) > 0 {
|
|
var recursive string
|
|
if recursiveWith {
|
|
recursive = " RECURSIVE "
|
|
}
|
|
withClause = "WITH " + recursive + strings.Join(withClauses, ", ") + " "
|
|
}
|
|
|
|
countQuery := withClause + r.buildCountQuery(body)
|
|
idsQuery := withClause + body + sortAndPagination
|
|
|
|
// Perform query and fetch result
|
|
var countResult int
|
|
var countErr error
|
|
var idsResult []int
|
|
var idsErr error
|
|
|
|
countResult, countErr = r.runCountQuery(ctx, countQuery, args)
|
|
idsResult, idsErr = r.runIdsQuery(ctx, idsQuery, args)
|
|
|
|
if countErr != nil {
|
|
return nil, 0, fmt.Errorf("error executing count query with SQL: %s, args: %v, error: %s", countQuery, args, countErr.Error())
|
|
}
|
|
if idsErr != nil {
|
|
return nil, 0, fmt.Errorf("error executing find query with SQL: %s, args: %v, error: %s", idsQuery, args, idsErr.Error())
|
|
}
|
|
|
|
return idsResult, countResult, nil
|
|
}
|
|
|
|
func (r *repository) newQuery() queryBuilder {
|
|
return queryBuilder{
|
|
repository: r,
|
|
}
|
|
}
|
|
|
|
func (r *repository) join(j joiner, as string, parentIDCol string) {
|
|
t := r.tableName
|
|
if as != "" {
|
|
t = as
|
|
}
|
|
j.addLeftJoin(r.tableName, as, fmt.Sprintf("%s.%s = %s", t, r.idColumn, parentIDCol))
|
|
}
|
|
|
|
func (r *repository) innerJoin(j joiner, as string, parentIDCol string) {
|
|
t := r.tableName
|
|
if as != "" {
|
|
t = as
|
|
}
|
|
j.addInnerJoin(r.tableName, as, fmt.Sprintf("%s.%s = %s", t, r.idColumn, parentIDCol))
|
|
}
|
|
|
|
// joiner is implemented by types that can accumulate join clauses. Each
// method receives the table to join, an optional alias (as), the ON clause
// text, and any arguments to pass along with that clause.
type joiner interface {
	// addLeftJoin registers a left join.
	addLeftJoin(table, as, onClause string, args ...interface{})
	// addInnerJoin registers an inner join.
	addInnerJoin(table, as, onClause string, args ...interface{})
}
|
|
|
|
type joinRepository struct {
|
|
repository
|
|
fkColumn string
|
|
|
|
// fields for ordering
|
|
foreignTable string
|
|
orderBy string
|
|
}
|
|
|
|
func (r *joinRepository) getIDs(ctx context.Context, id int) ([]int, error) {
|
|
var joinStr string
|
|
if r.foreignTable != "" {
|
|
joinStr = fmt.Sprintf(" INNER JOIN %s ON %[1]s.id = %s.%s", r.foreignTable, r.tableName, r.fkColumn)
|
|
}
|
|
|
|
query := fmt.Sprintf(`SELECT %[2]s.%[1]s as id from %s%s WHERE %s = ?`, r.fkColumn, r.tableName, joinStr, r.idColumn)
|
|
|
|
if r.orderBy != "" {
|
|
query += " ORDER BY " + r.orderBy
|
|
}
|
|
|
|
return r.runIdsQuery(ctx, query, []interface{}{id})
|
|
}
|
|
|
|
func (r *joinRepository) insert(ctx context.Context, id int, foreignIDs ...int) error {
|
|
stmt, err := dbWrapper.Prepare(ctx, fmt.Sprintf("INSERT INTO %s (%s, %s) VALUES (?, ?)", r.tableName, r.idColumn, r.fkColumn))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
defer stmt.Close()
|
|
|
|
for _, fk := range foreignIDs {
|
|
if _, err := dbWrapper.ExecStmt(ctx, stmt, id, fk); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// insertOrIgnore inserts a join into the table, silently failing in the event that a conflict occurs (ie when the join already exists)
|
|
func (r *joinRepository) insertOrIgnore(ctx context.Context, id int, foreignIDs ...int) error {
|
|
stmt, err := dbWrapper.Prepare(ctx, fmt.Sprintf("INSERT INTO %s (%s, %s) VALUES (?, ?) ON CONFLICT (%[2]s, %s) DO NOTHING", r.tableName, r.idColumn, r.fkColumn))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
defer stmt.Close()
|
|
|
|
for _, fk := range foreignIDs {
|
|
if _, err := dbWrapper.ExecStmt(ctx, stmt, id, fk); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *joinRepository) destroyJoins(ctx context.Context, id int, foreignIDs ...int) error {
|
|
stmt := fmt.Sprintf("DELETE FROM %s WHERE %s = ? AND %s IN %s", r.tableName, r.idColumn, r.fkColumn, getInBinding(len(foreignIDs)))
|
|
|
|
args := make([]interface{}, len(foreignIDs)+1)
|
|
args[0] = id
|
|
for i, v := range foreignIDs {
|
|
args[i+1] = v
|
|
}
|
|
|
|
if _, err := dbWrapper.Exec(ctx, stmt, args...); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *joinRepository) replace(ctx context.Context, id int, foreignIDs []int) error {
|
|
if err := r.destroy(ctx, []int{id}); err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, fk := range foreignIDs {
|
|
if err := r.insert(ctx, id, fk); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
type captionRepository struct {
|
|
repository
|
|
}
|
|
|
|
func (r *captionRepository) get(ctx context.Context, id models.FileID) ([]*models.VideoCaption, error) {
|
|
query := fmt.Sprintf("SELECT %s, %s, %s from %s WHERE %s = ?", captionCodeColumn, captionFilenameColumn, captionTypeColumn, r.tableName, r.idColumn)
|
|
var ret []*models.VideoCaption
|
|
err := r.queryFunc(ctx, query, []interface{}{id}, false, func(rows *sqlx.Rows) error {
|
|
var captionCode string
|
|
var captionFilename string
|
|
var captionType string
|
|
|
|
if err := rows.Scan(&captionCode, &captionFilename, &captionType); err != nil {
|
|
return err
|
|
}
|
|
|
|
caption := &models.VideoCaption{
|
|
LanguageCode: captionCode,
|
|
Filename: captionFilename,
|
|
CaptionType: captionType,
|
|
}
|
|
ret = append(ret, caption)
|
|
return nil
|
|
})
|
|
return ret, err
|
|
}
|
|
|
|
func (r *captionRepository) insert(ctx context.Context, id models.FileID, caption *models.VideoCaption) (sql.Result, error) {
|
|
stmt := fmt.Sprintf("INSERT INTO %s (%s, %s, %s, %s) VALUES (?, ?, ?, ?)", r.tableName, r.idColumn, captionCodeColumn, captionFilenameColumn, captionTypeColumn)
|
|
return dbWrapper.Exec(ctx, stmt, id, caption.LanguageCode, caption.Filename, caption.CaptionType)
|
|
}
|
|
|
|
func (r *captionRepository) replace(ctx context.Context, id models.FileID, captions []*models.VideoCaption) error {
|
|
if err := r.destroy(ctx, []int{int(id)}); err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, caption := range captions {
|
|
if _, err := r.insert(ctx, id, caption); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
type stringRepository struct {
|
|
repository
|
|
stringColumn string
|
|
}
|
|
|
|
func (r *stringRepository) get(ctx context.Context, id int) ([]string, error) {
|
|
query := fmt.Sprintf("SELECT %s from %s WHERE %s = ?", r.stringColumn, r.tableName, r.idColumn)
|
|
var ret []string
|
|
err := r.queryFunc(ctx, query, []interface{}{id}, false, func(rows *sqlx.Rows) error {
|
|
var out string
|
|
if err := rows.Scan(&out); err != nil {
|
|
return err
|
|
}
|
|
|
|
ret = append(ret, out)
|
|
return nil
|
|
})
|
|
return ret, err
|
|
}
|
|
|
|
func (r *stringRepository) insert(ctx context.Context, id int, s string) (sql.Result, error) {
|
|
stmt := fmt.Sprintf("INSERT INTO %s (%s, %s) VALUES (?, ?)", r.tableName, r.idColumn, r.stringColumn)
|
|
return dbWrapper.Exec(ctx, stmt, id, s)
|
|
}
|
|
|
|
func (r *stringRepository) replace(ctx context.Context, id int, newStrings []string) error {
|
|
if err := r.destroy(ctx, []int{id}); err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, s := range newStrings {
|
|
if _, err := r.insert(ctx, id, s); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
type stashIDRepository struct {
|
|
repository
|
|
}
|
|
|
|
// stashIDs is a slice of models.StashID.
type stashIDs []models.StashID
|
|
|
|
func (s *stashIDs) Append(o interface{}) {
|
|
*s = append(*s, o.(models.StashID))
|
|
}
|
|
|
|
func (s *stashIDs) New() interface{} {
|
|
return &models.StashID{}
|
|
}
|
|
|
|
func (r *stashIDRepository) get(ctx context.Context, id int) ([]models.StashID, error) {
|
|
query := fmt.Sprintf("SELECT stash_id, endpoint, updated_at from %s WHERE %s = ?", r.tableName, r.idColumn)
|
|
var ret stashIDs
|
|
err := r.queryFunc(ctx, query, []interface{}{id}, false, func(rows *sqlx.Rows) error {
|
|
var v stashIDRow
|
|
if err := rows.StructScan(&v); err != nil {
|
|
return err
|
|
}
|
|
ret.Append(v.resolve())
|
|
return nil
|
|
})
|
|
return ret, err
|
|
}
|
|
|
|
type filesRepository struct {
|
|
repository
|
|
}
|
|
|
|
type relatedFileRow struct {
|
|
ID int `db:"id"`
|
|
FileID models.FileID `db:"file_id"`
|
|
Primary bool `db:"primary"`
|
|
}
|
|
|
|
// idToIndexMap returns a map from each id in ids to its index in the slice.
// If an id occurs more than once, the index of its last occurrence wins.
// The returned map is never nil.
func idToIndexMap(ids []int) map[int]int {
	// the number of entries is bounded by len(ids), so size the map up front
	ret := make(map[int]int, len(ids))
	for i, id := range ids {
		ret[id] = i
	}
	return ret
}
|
|
|
|
func (r *filesRepository) getMany(ctx context.Context, ids []int, primaryOnly bool) ([][]models.FileID, error) {
|
|
var primaryClause string
|
|
if primaryOnly {
|
|
primaryClause = " AND `primary` = 1"
|
|
}
|
|
|
|
query := fmt.Sprintf("SELECT %s as id, file_id, `primary` from %s WHERE %[1]s IN %[3]s%s", r.idColumn, r.tableName, getInBinding(len(ids)), primaryClause)
|
|
|
|
idi := make([]interface{}, len(ids))
|
|
for i, id := range ids {
|
|
idi[i] = id
|
|
}
|
|
|
|
var fileRows []relatedFileRow
|
|
if err := r.queryFunc(ctx, query, idi, false, func(rows *sqlx.Rows) error {
|
|
var f relatedFileRow
|
|
|
|
if err := rows.StructScan(&f); err != nil {
|
|
return err
|
|
}
|
|
|
|
fileRows = append(fileRows, f)
|
|
|
|
return nil
|
|
}); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
ret := make([][]models.FileID, len(ids))
|
|
idToIndex := idToIndexMap(ids)
|
|
|
|
for _, row := range fileRows {
|
|
id := row.ID
|
|
fileID := row.FileID
|
|
|
|
if row.Primary {
|
|
// prepend to list
|
|
ret[idToIndex[id]] = append([]models.FileID{fileID}, ret[idToIndex[id]]...)
|
|
} else {
|
|
ret[idToIndex[id]] = append(ret[idToIndex[id]], row.FileID)
|
|
}
|
|
}
|
|
|
|
return ret, nil
|
|
}
|
|
|
|
func (r *filesRepository) get(ctx context.Context, id int) ([]models.FileID, error) {
|
|
query := fmt.Sprintf("SELECT file_id, `primary` from %s WHERE %s = ?", r.tableName, r.idColumn)
|
|
|
|
type relatedFile struct {
|
|
FileID models.FileID `db:"file_id"`
|
|
Primary bool `db:"primary"`
|
|
}
|
|
|
|
var ret []models.FileID
|
|
if err := r.queryFunc(ctx, query, []interface{}{id}, false, func(rows *sqlx.Rows) error {
|
|
var f relatedFile
|
|
|
|
if err := rows.StructScan(&f); err != nil {
|
|
return err
|
|
}
|
|
|
|
if f.Primary {
|
|
// prepend to list
|
|
ret = append([]models.FileID{f.FileID}, ret...)
|
|
} else {
|
|
ret = append(ret, f.FileID)
|
|
}
|
|
|
|
return nil
|
|
}); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return ret, nil
|
|
}
|