Make all calls to the database pass context.

This means that long queries can be cancelled by navigating to another page. Previously the query would continue to run, impacting on future queries.
This commit is contained in:
WithoutPants 2025-12-05 18:18:32 +11:00
parent ff360ba5b1
commit 3e6de00508
2 changed files with 9 additions and 18 deletions

View file

@ -96,7 +96,7 @@ func (r *repository) runIdsQuery(ctx context.Context, query string, args []inter
} }
func (r *repository) queryFunc(ctx context.Context, query string, args []interface{}, single bool, f func(rows *sqlx.Rows) error) error { func (r *repository) queryFunc(ctx context.Context, query string, args []interface{}, single bool, f func(rows *sqlx.Rows) error) error {
rows, err := dbWrapper.Queryx(ctx, query, args...) rows, err := dbWrapper.QueryxContext(ctx, query, args...)
if err != nil && !errors.Is(err, sql.ErrNoRows) { if err != nil && !errors.Is(err, sql.ErrNoRows) {
return err return err

View file

@ -16,8 +16,8 @@ const (
type dbReader interface { type dbReader interface {
Get(dest interface{}, query string, args ...interface{}) error Get(dest interface{}, query string, args ...interface{}) error
Select(dest interface{}, query string, args ...interface{}) error GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
Queryx(query string, args ...interface{}) (*sqlx.Rows, error) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error)
} }
@ -54,7 +54,7 @@ func (*dbWrapperType) Get(ctx context.Context, dest interface{}, query string, a
} }
start := time.Now() start := time.Now()
err = tx.Get(dest, query, args...) err = tx.GetContext(ctx, dest, query, args...)
logSQL(start, query, args...) logSQL(start, query, args...)
return sqlError(err, query, args...) return sqlError(err, query, args...)
@ -67,7 +67,7 @@ func (*dbWrapperType) Select(ctx context.Context, dest interface{}, query string
} }
start := time.Now() start := time.Now()
err = tx.Select(dest, query, args...) err = tx.SelectContext(ctx, dest, query, args...)
logSQL(start, query, args...) logSQL(start, query, args...)
return sqlError(err, query, args...) return sqlError(err, query, args...)
@ -80,23 +80,14 @@ func (*dbWrapperType) Queryx(ctx context.Context, query string, args ...interfac
} }
start := time.Now() start := time.Now()
ret, err := tx.Queryx(query, args...) ret, err := tx.QueryxContext(ctx, query, args...)
logSQL(start, query, args...) logSQL(start, query, args...)
return ret, sqlError(err, query, args...) return ret, sqlError(err, query, args...)
} }
func (*dbWrapperType) QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) { func (*dbWrapperType) QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) {
tx, err := getDBReader(ctx) return dbWrapper.Queryx(ctx, query, args...)
if err != nil {
return nil, sqlError(err, query, args...)
}
start := time.Now()
ret, err := tx.QueryxContext(ctx, query, args...)
logSQL(start, query, args...)
return ret, sqlError(err, query, args...)
} }
func (*dbWrapperType) NamedExec(ctx context.Context, query string, arg interface{}) (sql.Result, error) { func (*dbWrapperType) NamedExec(ctx context.Context, query string, arg interface{}) (sql.Result, error) {
@ -106,7 +97,7 @@ func (*dbWrapperType) NamedExec(ctx context.Context, query string, arg interface
} }
start := time.Now() start := time.Now()
ret, err := tx.NamedExec(query, arg) ret, err := tx.NamedExecContext(ctx, query, arg)
logSQL(start, query, arg) logSQL(start, query, arg)
return ret, sqlError(err, query, arg) return ret, sqlError(err, query, arg)
@ -119,7 +110,7 @@ func (*dbWrapperType) Exec(ctx context.Context, query string, args ...interface{
} }
start := time.Now() start := time.Now()
ret, err := tx.Exec(query, args...) ret, err := tx.ExecContext(ctx, query, args...)
logSQL(start, query, args...) logSQL(start, query, args...)
return ret, sqlError(err, query, args...) return ret, sqlError(err, query, args...)