Make all calls to the database pass context.

This means that long queries can be cancelled by navigating to another page. Previously the query would continue to run, impacting on future queries.
This commit is contained in:
WithoutPants 2025-12-05 18:18:32 +11:00
parent ff360ba5b1
commit 3e6de00508
2 changed files with 9 additions and 18 deletions

View file

@ -96,7 +96,7 @@ func (r *repository) runIdsQuery(ctx context.Context, query string, args []inter
}
func (r *repository) queryFunc(ctx context.Context, query string, args []interface{}, single bool, f func(rows *sqlx.Rows) error) error {
rows, err := dbWrapper.Queryx(ctx, query, args...)
rows, err := dbWrapper.QueryxContext(ctx, query, args...)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return err

View file

@ -16,8 +16,8 @@ const (
type dbReader interface {
Get(dest interface{}, query string, args ...interface{}) error
Select(dest interface{}, query string, args ...interface{}) error
Queryx(query string, args ...interface{}) (*sqlx.Rows, error)
GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error)
}
@ -54,7 +54,7 @@ func (*dbWrapperType) Get(ctx context.Context, dest interface{}, query string, a
}
start := time.Now()
err = tx.Get(dest, query, args...)
err = tx.GetContext(ctx, dest, query, args...)
logSQL(start, query, args...)
return sqlError(err, query, args...)
@ -67,7 +67,7 @@ func (*dbWrapperType) Select(ctx context.Context, dest interface{}, query string
}
start := time.Now()
err = tx.Select(dest, query, args...)
err = tx.SelectContext(ctx, dest, query, args...)
logSQL(start, query, args...)
return sqlError(err, query, args...)
@ -80,23 +80,14 @@ func (*dbWrapperType) Queryx(ctx context.Context, query string, args ...interfac
}
start := time.Now()
ret, err := tx.Queryx(query, args...)
ret, err := tx.QueryxContext(ctx, query, args...)
logSQL(start, query, args...)
return ret, sqlError(err, query, args...)
}
func (*dbWrapperType) QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) {
tx, err := getDBReader(ctx)
if err != nil {
return nil, sqlError(err, query, args...)
}
start := time.Now()
ret, err := tx.QueryxContext(ctx, query, args...)
logSQL(start, query, args...)
return ret, sqlError(err, query, args...)
return dbWrapper.Queryx(ctx, query, args...)
}
func (*dbWrapperType) NamedExec(ctx context.Context, query string, arg interface{}) (sql.Result, error) {
@ -106,7 +97,7 @@ func (*dbWrapperType) NamedExec(ctx context.Context, query string, arg interface
}
start := time.Now()
ret, err := tx.NamedExec(query, arg)
ret, err := tx.NamedExecContext(ctx, query, arg)
logSQL(start, query, arg)
return ret, sqlError(err, query, arg)
@ -119,7 +110,7 @@ func (*dbWrapperType) Exec(ctx context.Context, query string, args ...interface{
}
start := time.Now()
ret, err := tx.Exec(query, args...)
ret, err := tx.ExecContext(ctx, query, args...)
logSQL(start, query, args...)
return ret, sqlError(err, query, args...)