diff --git a/pkg/sqlite/repository.go b/pkg/sqlite/repository.go index 92ea10ee0..18d501e3a 100644 --- a/pkg/sqlite/repository.go +++ b/pkg/sqlite/repository.go @@ -96,7 +96,7 @@ func (r *repository) runIdsQuery(ctx context.Context, query string, args []inter } func (r *repository) queryFunc(ctx context.Context, query string, args []interface{}, single bool, f func(rows *sqlx.Rows) error) error { - rows, err := dbWrapper.Queryx(ctx, query, args...) + rows, err := dbWrapper.QueryxContext(ctx, query, args...) if err != nil && !errors.Is(err, sql.ErrNoRows) { return err diff --git a/pkg/sqlite/tx.go b/pkg/sqlite/tx.go index a2e272aa9..b6701dc81 100644 --- a/pkg/sqlite/tx.go +++ b/pkg/sqlite/tx.go @@ -16,8 +16,8 @@ const ( type dbReader interface { Get(dest interface{}, query string, args ...interface{}) error - Select(dest interface{}, query string, args ...interface{}) error - Queryx(query string, args ...interface{}) (*sqlx.Rows, error) + GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error + SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) } @@ -54,7 +54,7 @@ func (*dbWrapperType) Get(ctx context.Context, dest interface{}, query string, a } start := time.Now() - err = tx.Get(dest, query, args...) + err = tx.GetContext(ctx, dest, query, args...) logSQL(start, query, args...) return sqlError(err, query, args...) @@ -67,7 +67,7 @@ func (*dbWrapperType) Select(ctx context.Context, dest interface{}, query string } start := time.Now() - err = tx.Select(dest, query, args...) + err = tx.SelectContext(ctx, dest, query, args...) logSQL(start, query, args...) return sqlError(err, query, args...) @@ -80,23 +80,14 @@ func (*dbWrapperType) Queryx(ctx context.Context, query string, args ...interfac } start := time.Now() - ret, err := tx.Queryx(query, args...) + ret, err := tx.QueryxContext(ctx, query, args...) logSQL(start, query, args...) return ret, sqlError(err, query, args...) } func (*dbWrapperType) QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) { - tx, err := getDBReader(ctx) - if err != nil { - return nil, sqlError(err, query, args...) - } - - start := time.Now() - ret, err := tx.QueryxContext(ctx, query, args...) - logSQL(start, query, args...) - - return ret, sqlError(err, query, args...) + return dbWrapper.Queryx(ctx, query, args...) } func (*dbWrapperType) NamedExec(ctx context.Context, query string, arg interface{}) (sql.Result, error) { @@ -106,7 +97,7 @@ func (*dbWrapperType) NamedExec(ctx context.Context, query string, arg interface } start := time.Now() - ret, err := tx.NamedExec(query, arg) + ret, err := tx.NamedExecContext(ctx, query, arg) logSQL(start, query, arg) return ret, sqlError(err, query, arg) @@ -119,7 +110,7 @@ func (*dbWrapperType) Exec(ctx context.Context, query string, args ...interface{ } start := time.Now() - ret, err := tx.Exec(query, args...) + ret, err := tx.ExecContext(ctx, query, args...) logSQL(start, query, args...) return ret, sqlError(err, query, args...)