From 3e6de00508693eb1c4fddb5353df8b91aaf3b5df Mon Sep 17 00:00:00 2001 From: WithoutPants <53250216+WithoutPants@users.noreply.github.com> Date: Fri, 5 Dec 2025 18:18:32 +1100 Subject: [PATCH] Make all calls to the database pass context. This means that long queries can be cancelled by navigating to another page. Previously the query would continue to run, impacting on future queries. --- pkg/sqlite/repository.go | 2 +- pkg/sqlite/tx.go | 25 ++++++++----------------- 2 files changed, 9 insertions(+), 18 deletions(-) diff --git a/pkg/sqlite/repository.go b/pkg/sqlite/repository.go index 92ea10ee0..18d501e3a 100644 --- a/pkg/sqlite/repository.go +++ b/pkg/sqlite/repository.go @@ -96,7 +96,7 @@ func (r *repository) runIdsQuery(ctx context.Context, query string, args []inter } func (r *repository) queryFunc(ctx context.Context, query string, args []interface{}, single bool, f func(rows *sqlx.Rows) error) error { - rows, err := dbWrapper.Queryx(ctx, query, args...) + rows, err := dbWrapper.QueryxContext(ctx, query, args...) if err != nil && !errors.Is(err, sql.ErrNoRows) { return err diff --git a/pkg/sqlite/tx.go b/pkg/sqlite/tx.go index a2e272aa9..b6701dc81 100644 --- a/pkg/sqlite/tx.go +++ b/pkg/sqlite/tx.go @@ -16,8 +16,8 @@ const ( type dbReader interface { Get(dest interface{}, query string, args ...interface{}) error - Select(dest interface{}, query string, args ...interface{}) error - Queryx(query string, args ...interface{}) (*sqlx.Rows, error) + GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error + SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) } @@ -54,7 +54,7 @@ func (*dbWrapperType) Get(ctx context.Context, dest interface{}, query string, a } start := time.Now() - err = tx.Get(dest, query, args...) + err = tx.GetContext(ctx, dest, query, args...) logSQL(start, query, args...) return sqlError(err, query, args...) @@ -67,7 +67,7 @@ func (*dbWrapperType) Select(ctx context.Context, dest interface{}, query string } start := time.Now() - err = tx.Select(dest, query, args...) + err = tx.SelectContext(ctx, dest, query, args...) logSQL(start, query, args...) return sqlError(err, query, args...) @@ -80,23 +80,14 @@ func (*dbWrapperType) Queryx(ctx context.Context, query string, args ...interfac } start := time.Now() - ret, err := tx.Queryx(query, args...) + ret, err := tx.QueryxContext(ctx, query, args...) logSQL(start, query, args...) return ret, sqlError(err, query, args...) } func (*dbWrapperType) QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) { - tx, err := getDBReader(ctx) - if err != nil { - return nil, sqlError(err, query, args...) - } - - start := time.Now() - ret, err := tx.QueryxContext(ctx, query, args...) - logSQL(start, query, args...) - - return ret, sqlError(err, query, args...) + return dbWrapper.Queryx(ctx, query, args...) } func (*dbWrapperType) NamedExec(ctx context.Context, query string, arg interface{}) (sql.Result, error) { @@ -106,7 +97,7 @@ func (*dbWrapperType) NamedExec(ctx context.Context, query string, arg interface } start := time.Now() - ret, err := tx.NamedExec(query, arg) + ret, err := tx.NamedExecContext(ctx, query, arg) logSQL(start, query, arg) return ret, sqlError(err, query, arg) @@ -119,7 +110,7 @@ func (*dbWrapperType) Exec(ctx context.Context, query string, args ...interface{ } start := time.Now() - ret, err := tx.Exec(query, args...) + ret, err := tx.ExecContext(ctx, query, args...) logSQL(start, query, args...) return ret, sqlError(err, query, args...)