mirror of
https://github.com/stashapp/stash.git
synced 2026-04-19 13:31:15 +02:00
Fix: Custom Field Filtering (#6614)
* add tests * Refactor queryBuilder: split args into per-clause fields
This commit is contained in:
parent
c7e1c3da69
commit
c874bd560e
7 changed files with 89 additions and 18 deletions
|
|
@ -1129,7 +1129,7 @@ func (h *relatedFilterHandler) handle(ctx context.Context, f *filterBuilder) {
|
|||
return
|
||||
}
|
||||
|
||||
f.addWhere(fmt.Sprintf("%s IN ("+subQuery.toSQL(false)+")", h.relatedIDCol), subQuery.args...)
|
||||
f.addWhere(fmt.Sprintf("%s IN ("+subQuery.toSQL(false)+")", h.relatedIDCol), subQuery.allArgs()...)
|
||||
}
|
||||
|
||||
type phashDistanceCriterionHandler struct {
|
||||
|
|
|
|||
|
|
@ -975,7 +975,7 @@ func (qb *FileStore) queryGroupedFields(ctx context.Context, options models.File
|
|||
Megapixels float64
|
||||
Size int64
|
||||
}{}
|
||||
if err := qb.repository.queryStruct(ctx, aggregateQuery.toSQL(includeSortPagination), query.args, &out); err != nil {
|
||||
if err := qb.repository.queryStruct(ctx, aggregateQuery.toSQL(includeSortPagination), query.allArgs(), &out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -600,7 +600,7 @@ func (qb *FolderStore) queryGroupedFields(ctx context.Context, options models.Fo
|
|||
Megapixels float64
|
||||
Size int64
|
||||
}{}
|
||||
if err := qb.repository.queryStruct(ctx, aggregateQuery.toSQL(includeSortPagination), query.args, &out); err != nil {
|
||||
if err := qb.repository.queryStruct(ctx, aggregateQuery.toSQL(includeSortPagination), query.allArgs(), &out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -926,7 +926,7 @@ func (qb *ImageStore) queryGroupedFields(ctx context.Context, options models.Ima
|
|||
Megapixels null.Float
|
||||
Size null.Float
|
||||
}{}
|
||||
if err := imageRepository.queryStruct(ctx, aggregateQuery.toSQL(includeSortPagination), query.args, &out); err != nil {
|
||||
if err := imageRepository.queryStruct(ctx, aggregateQuery.toSQL(includeSortPagination), query.allArgs(), &out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -17,13 +17,26 @@ type queryBuilder struct {
|
|||
joins joins
|
||||
whereClauses []string
|
||||
havingClauses []string
|
||||
args []interface{}
|
||||
withClauses []string
|
||||
recursiveWith bool
|
||||
|
||||
withArgs []interface{}
|
||||
joinArgs []interface{}
|
||||
whereArgs []interface{}
|
||||
havingArgs []interface{}
|
||||
|
||||
sortAndPagination string
|
||||
}
|
||||
|
||||
func (qb queryBuilder) allArgs() []interface{} {
|
||||
var args []interface{}
|
||||
args = append(args, qb.withArgs...)
|
||||
args = append(args, qb.joinArgs...)
|
||||
args = append(args, qb.whereArgs...)
|
||||
args = append(args, qb.havingArgs...)
|
||||
return args
|
||||
}
|
||||
|
||||
func (qb queryBuilder) body(includeSortPagination bool) string {
|
||||
return fmt.Sprintf("SELECT %s FROM %s%s", strings.Join(qb.columns, ", "), qb.from, qb.joins.toSQL(includeSortPagination))
|
||||
}
|
||||
|
|
@ -55,13 +68,13 @@ func (qb queryBuilder) toSQL(includeSortPagination bool) string {
|
|||
func (qb queryBuilder) findIDs(ctx context.Context) ([]int, error) {
|
||||
const includeSortPagination = true
|
||||
sql := qb.toSQL(includeSortPagination)
|
||||
return qb.repository.runIdsQuery(ctx, sql, qb.args)
|
||||
return qb.repository.runIdsQuery(ctx, sql, qb.allArgs())
|
||||
}
|
||||
|
||||
func (qb queryBuilder) executeFind(ctx context.Context) ([]int, int, error) {
|
||||
const includeSortPagination = true
|
||||
body := qb.body(includeSortPagination)
|
||||
return qb.repository.executeFindQuery(ctx, body, qb.args, qb.sortAndPagination, qb.whereClauses, qb.havingClauses, qb.withClauses, qb.recursiveWith)
|
||||
return qb.repository.executeFindQuery(ctx, body, qb.allArgs(), qb.sortAndPagination, qb.whereClauses, qb.havingClauses, qb.withClauses, qb.recursiveWith)
|
||||
}
|
||||
|
||||
func (qb queryBuilder) executeCount(ctx context.Context) (int, error) {
|
||||
|
|
@ -79,7 +92,7 @@ func (qb queryBuilder) executeCount(ctx context.Context) (int, error) {
|
|||
|
||||
body = qb.repository.buildQueryBody(body, qb.whereClauses, qb.havingClauses)
|
||||
countQuery := withClause + qb.repository.buildCountQuery(body)
|
||||
return qb.repository.runCountQuery(ctx, countQuery, qb.args)
|
||||
return qb.repository.runCountQuery(ctx, countQuery, qb.allArgs())
|
||||
}
|
||||
|
||||
func (qb *queryBuilder) addWhere(clauses ...string) {
|
||||
|
|
@ -109,7 +122,11 @@ func (qb *queryBuilder) addWith(recursive bool, clauses ...string) {
|
|||
}
|
||||
|
||||
func (qb *queryBuilder) addArg(args ...interface{}) {
|
||||
qb.args = append(qb.args, args...)
|
||||
qb.whereArgs = append(qb.whereArgs, args...)
|
||||
}
|
||||
|
||||
func (qb *queryBuilder) addHavingArg(args ...interface{}) {
|
||||
qb.havingArgs = append(qb.havingArgs, args...)
|
||||
}
|
||||
|
||||
func (qb *queryBuilder) hasJoin(alias string) bool {
|
||||
|
|
@ -148,7 +165,7 @@ func (qb *queryBuilder) joinSort(table, as, onClause string) {
|
|||
func (qb *queryBuilder) addJoins(joins ...join) {
|
||||
for _, j := range joins {
|
||||
if qb.joins.addUnique(j) {
|
||||
qb.args = append(qb.args, j.args...)
|
||||
qb.joinArgs = append(qb.joinArgs, j.args...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -163,20 +180,16 @@ func (qb *queryBuilder) addFilter(f *filterBuilder) error {
|
|||
if len(clause) > 0 {
|
||||
qb.addWith(f.recursiveWith, clause)
|
||||
}
|
||||
|
||||
if len(args) > 0 {
|
||||
// WITH clause always comes first and thus precedes alk args
|
||||
qb.args = append(args, qb.args...)
|
||||
qb.withArgs = append(qb.withArgs, args...)
|
||||
}
|
||||
|
||||
// add joins here to insert args
|
||||
qb.addJoins(f.getAllJoins()...)
|
||||
|
||||
clause, args = f.generateWhereClauses()
|
||||
if len(clause) > 0 {
|
||||
qb.addWhere(clause)
|
||||
}
|
||||
|
||||
if len(args) > 0 {
|
||||
qb.addArg(args...)
|
||||
}
|
||||
|
|
@ -185,9 +198,8 @@ func (qb *queryBuilder) addFilter(f *filterBuilder) error {
|
|||
if len(clause) > 0 {
|
||||
qb.addHaving(clause)
|
||||
}
|
||||
|
||||
if len(args) > 0 {
|
||||
qb.addArg(args...)
|
||||
qb.addHavingArg(args...)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
|||
|
|
@ -1097,7 +1097,7 @@ func (qb *SceneStore) queryGroupedFields(ctx context.Context, options models.Sce
|
|||
Duration null.Float
|
||||
Size null.Float
|
||||
}{}
|
||||
if err := sceneRepository.queryStruct(ctx, aggregateQuery.toSQL(includeSortPagination), query.args, &out); err != nil {
|
||||
if err := sceneRepository.queryStruct(ctx, aggregateQuery.toSQL(includeSortPagination), query.allArgs(), &out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1889,6 +1889,65 @@ func TestTagQueryCustomFields(t *testing.T) {
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Test combining text search (findFilter.Q) with custom field filters.
|
||||
// This verifies that positional args are bound in the correct order
|
||||
// when JOINs (from custom fields) and WHERE (from text search) both
|
||||
// have parameterized placeholders.
|
||||
runWithRollbackTxn(t, "equals with text search", func(t *testing.T, ctx context.Context) {
|
||||
assert := assert.New(t)
|
||||
|
||||
tagName := getTagStringValue(tagIdxWithGallery, "Name")
|
||||
q := tagName
|
||||
findFilter := &models.FindFilterType{Q: &q}
|
||||
|
||||
tagFilter := &models.TagFilterType{
|
||||
CustomFields: []models.CustomFieldCriterionInput{
|
||||
{
|
||||
Field: "string",
|
||||
Modifier: models.CriterionModifierEquals,
|
||||
Value: []any{getTagStringValue(tagIdxWithGallery, "custom")},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tags, _, err := db.Tag.Query(ctx, tagFilter, findFilter)
|
||||
if err != nil {
|
||||
t.Errorf("TagStore.Query() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ids := tagsToIDs(tags)
|
||||
assert.Contains(ids, tagIDs[tagIdxWithGallery])
|
||||
assert.Len(tags, 1)
|
||||
})
|
||||
|
||||
runWithRollbackTxn(t, "is_null with text search", func(t *testing.T, ctx context.Context) {
|
||||
assert := assert.New(t)
|
||||
|
||||
tagName := getTagStringValue(tagIdxWithGallery, "Name")
|
||||
q := tagName
|
||||
findFilter := &models.FindFilterType{Q: &q}
|
||||
|
||||
tagFilter := &models.TagFilterType{
|
||||
CustomFields: []models.CustomFieldCriterionInput{
|
||||
{
|
||||
Field: "not existing",
|
||||
Modifier: models.CriterionModifierIsNull,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tags, _, err := db.Tag.Query(ctx, tagFilter, findFilter)
|
||||
if err != nil {
|
||||
t.Errorf("TagStore.Query() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ids := tagsToIDs(tags)
|
||||
assert.Contains(ids, tagIDs[tagIdxWithGallery])
|
||||
assert.Len(tags, 1)
|
||||
})
|
||||
}
|
||||
|
||||
// TODO Destroy
|
||||
|
|
|
|||
Loading…
Reference in a new issue