Fix: Custom Field Filtering (#6614)

* add tests
* Refactor queryBuilder: split args into per-clause fields
This commit is contained in:
Gykes 2026-02-27 16:05:13 -08:00 committed by GitHub
parent c7e1c3da69
commit c874bd560e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 89 additions and 18 deletions

View file

@ -1129,7 +1129,7 @@ func (h *relatedFilterHandler) handle(ctx context.Context, f *filterBuilder) {
return
}
f.addWhere(fmt.Sprintf("%s IN ("+subQuery.toSQL(false)+")", h.relatedIDCol), subQuery.args...)
f.addWhere(fmt.Sprintf("%s IN ("+subQuery.toSQL(false)+")", h.relatedIDCol), subQuery.allArgs()...)
}
type phashDistanceCriterionHandler struct {

View file

@ -975,7 +975,7 @@ func (qb *FileStore) queryGroupedFields(ctx context.Context, options models.File
Megapixels float64
Size int64
}{}
if err := qb.repository.queryStruct(ctx, aggregateQuery.toSQL(includeSortPagination), query.args, &out); err != nil {
if err := qb.repository.queryStruct(ctx, aggregateQuery.toSQL(includeSortPagination), query.allArgs(), &out); err != nil {
return nil, err
}

View file

@ -600,7 +600,7 @@ func (qb *FolderStore) queryGroupedFields(ctx context.Context, options models.Fo
Megapixels float64
Size int64
}{}
if err := qb.repository.queryStruct(ctx, aggregateQuery.toSQL(includeSortPagination), query.args, &out); err != nil {
if err := qb.repository.queryStruct(ctx, aggregateQuery.toSQL(includeSortPagination), query.allArgs(), &out); err != nil {
return nil, err
}

View file

@ -926,7 +926,7 @@ func (qb *ImageStore) queryGroupedFields(ctx context.Context, options models.Ima
Megapixels null.Float
Size null.Float
}{}
if err := imageRepository.queryStruct(ctx, aggregateQuery.toSQL(includeSortPagination), query.args, &out); err != nil {
if err := imageRepository.queryStruct(ctx, aggregateQuery.toSQL(includeSortPagination), query.allArgs(), &out); err != nil {
return nil, err
}

View file

@ -17,13 +17,26 @@ type queryBuilder struct {
joins joins
whereClauses []string
havingClauses []string
args []interface{}
withClauses []string
recursiveWith bool
withArgs []interface{}
joinArgs []interface{}
whereArgs []interface{}
havingArgs []interface{}
sortAndPagination string
}
func (qb queryBuilder) allArgs() []interface{} {
var args []interface{}
args = append(args, qb.withArgs...)
args = append(args, qb.joinArgs...)
args = append(args, qb.whereArgs...)
args = append(args, qb.havingArgs...)
return args
}
func (qb queryBuilder) body(includeSortPagination bool) string {
return fmt.Sprintf("SELECT %s FROM %s%s", strings.Join(qb.columns, ", "), qb.from, qb.joins.toSQL(includeSortPagination))
}
@ -55,13 +68,13 @@ func (qb queryBuilder) toSQL(includeSortPagination bool) string {
func (qb queryBuilder) findIDs(ctx context.Context) ([]int, error) {
const includeSortPagination = true
sql := qb.toSQL(includeSortPagination)
return qb.repository.runIdsQuery(ctx, sql, qb.args)
return qb.repository.runIdsQuery(ctx, sql, qb.allArgs())
}
func (qb queryBuilder) executeFind(ctx context.Context) ([]int, int, error) {
const includeSortPagination = true
body := qb.body(includeSortPagination)
return qb.repository.executeFindQuery(ctx, body, qb.args, qb.sortAndPagination, qb.whereClauses, qb.havingClauses, qb.withClauses, qb.recursiveWith)
return qb.repository.executeFindQuery(ctx, body, qb.allArgs(), qb.sortAndPagination, qb.whereClauses, qb.havingClauses, qb.withClauses, qb.recursiveWith)
}
func (qb queryBuilder) executeCount(ctx context.Context) (int, error) {
@ -79,7 +92,7 @@ func (qb queryBuilder) executeCount(ctx context.Context) (int, error) {
body = qb.repository.buildQueryBody(body, qb.whereClauses, qb.havingClauses)
countQuery := withClause + qb.repository.buildCountQuery(body)
return qb.repository.runCountQuery(ctx, countQuery, qb.args)
return qb.repository.runCountQuery(ctx, countQuery, qb.allArgs())
}
func (qb *queryBuilder) addWhere(clauses ...string) {
@ -109,7 +122,11 @@ func (qb *queryBuilder) addWith(recursive bool, clauses ...string) {
}
func (qb *queryBuilder) addArg(args ...interface{}) {
qb.args = append(qb.args, args...)
qb.whereArgs = append(qb.whereArgs, args...)
}
func (qb *queryBuilder) addHavingArg(args ...interface{}) {
qb.havingArgs = append(qb.havingArgs, args...)
}
func (qb *queryBuilder) hasJoin(alias string) bool {
@ -148,7 +165,7 @@ func (qb *queryBuilder) joinSort(table, as, onClause string) {
func (qb *queryBuilder) addJoins(joins ...join) {
for _, j := range joins {
if qb.joins.addUnique(j) {
qb.args = append(qb.args, j.args...)
qb.joinArgs = append(qb.joinArgs, j.args...)
}
}
}
@ -163,20 +180,16 @@ func (qb *queryBuilder) addFilter(f *filterBuilder) error {
if len(clause) > 0 {
qb.addWith(f.recursiveWith, clause)
}
if len(args) > 0 {
// WITH clause always comes first and thus precedes alk args
qb.args = append(args, qb.args...)
qb.withArgs = append(qb.withArgs, args...)
}
// add joins here to insert args
qb.addJoins(f.getAllJoins()...)
clause, args = f.generateWhereClauses()
if len(clause) > 0 {
qb.addWhere(clause)
}
if len(args) > 0 {
qb.addArg(args...)
}
@ -185,9 +198,8 @@ func (qb *queryBuilder) addFilter(f *filterBuilder) error {
if len(clause) > 0 {
qb.addHaving(clause)
}
if len(args) > 0 {
qb.addArg(args...)
qb.addHavingArg(args...)
}
return nil

View file

@ -1097,7 +1097,7 @@ func (qb *SceneStore) queryGroupedFields(ctx context.Context, options models.Sce
Duration null.Float
Size null.Float
}{}
if err := sceneRepository.queryStruct(ctx, aggregateQuery.toSQL(includeSortPagination), query.args, &out); err != nil {
if err := sceneRepository.queryStruct(ctx, aggregateQuery.toSQL(includeSortPagination), query.allArgs(), &out); err != nil {
return nil, err
}

View file

@ -1889,6 +1889,65 @@ func TestTagQueryCustomFields(t *testing.T) {
}
})
}
// Test combining text search (findFilter.Q) with custom field filters.
// This verifies that positional args are bound in the correct order
// when JOINs (from custom fields) and WHERE (from text search) both
// have parameterized placeholders.
runWithRollbackTxn(t, "equals with text search", func(t *testing.T, ctx context.Context) {
assert := assert.New(t)
tagName := getTagStringValue(tagIdxWithGallery, "Name")
q := tagName
findFilter := &models.FindFilterType{Q: &q}
tagFilter := &models.TagFilterType{
CustomFields: []models.CustomFieldCriterionInput{
{
Field: "string",
Modifier: models.CriterionModifierEquals,
Value: []any{getTagStringValue(tagIdxWithGallery, "custom")},
},
},
}
tags, _, err := db.Tag.Query(ctx, tagFilter, findFilter)
if err != nil {
t.Errorf("TagStore.Query() error = %v", err)
return
}
ids := tagsToIDs(tags)
assert.Contains(ids, tagIDs[tagIdxWithGallery])
assert.Len(tags, 1)
})
runWithRollbackTxn(t, "is_null with text search", func(t *testing.T, ctx context.Context) {
assert := assert.New(t)
tagName := getTagStringValue(tagIdxWithGallery, "Name")
q := tagName
findFilter := &models.FindFilterType{Q: &q}
tagFilter := &models.TagFilterType{
CustomFields: []models.CustomFieldCriterionInput{
{
Field: "not existing",
Modifier: models.CriterionModifierIsNull,
},
},
}
tags, _, err := db.Tag.Query(ctx, tagFilter, findFilter)
if err != nil {
t.Errorf("TagStore.Query() error = %v", err)
return
}
ids := tagsToIDs(tags)
assert.Contains(ids, tagIDs[tagIdxWithGallery])
assert.Len(tags, 1)
})
}
// TODO Destroy