From b4879ef758f4cf4b14325e6d434b828e049b5840 Mon Sep 17 00:00:00 2001 From: WithoutPants <53250216+WithoutPants@users.noreply.github.com> Date: Fri, 23 Jun 2023 11:04:54 +1000 Subject: [PATCH] Fix marker tag filtering (#3846) --- pkg/sqlite/scene_marker.go | 69 ++++++++++---- pkg/sqlite/scene_marker_test.go | 154 +++++++++++++++++++++++--------- pkg/sqlite/setup_test.go | 18 ++-- 3 files changed, 174 insertions(+), 67 deletions(-) diff --git a/pkg/sqlite/scene_marker.go b/pkg/sqlite/scene_marker.go index 5ead4867e..ab0be7117 100644 --- a/pkg/sqlite/scene_marker.go +++ b/pkg/sqlite/scene_marker.go @@ -317,9 +317,11 @@ func sceneMarkerTagIDCriterionHandler(qb *SceneMarkerStore, tagID *string) crite } } -func sceneMarkerTagsCriterionHandler(qb *SceneMarkerStore, tags *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { +func sceneMarkerTagsCriterionHandler(qb *SceneMarkerStore, criterion *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { return func(ctx context.Context, f *filterBuilder) { - if tags != nil { + if criterion != nil { + tags := criterion.CombineExcludes() + if tags.Modifier == models.CriterionModifierIsNull || tags.Modifier == models.CriterionModifierNotNull { var notClause string if tags.Modifier == models.CriterionModifierNotNull { @@ -332,26 +334,59 @@ func sceneMarkerTagsCriterionHandler(qb *SceneMarkerStore, tags *models.Hierarch return } - if len(tags.Value) == 0 { - return - } - valuesClause, err := getHierarchicalValues(ctx, qb.tx, tags.Value, tagTable, "tags_relations", "parent_id", "child_id", tags.Depth) - if err != nil { - f.setError(err) + if tags.Modifier == models.CriterionModifierEquals && tags.Depth != nil && *tags.Depth != 0 { + f.setError(fmt.Errorf("depth is not supported for equals modifier for marker tag filtering")) return } - f.addWith(`marker_tags AS ( -SELECT mt.scene_marker_id, t.column1 AS root_tag_id FROM scene_markers_tags mt -INNER JOIN (` + valuesClause + `) t ON t.column2 = mt.tag_id -UNION -SELECT m.id, t.column1 FROM scene_markers m -INNER JOIN (` + valuesClause + `) t ON t.column2 = m.primary_tag_id -)`) + if len(tags.Value) == 0 && len(tags.Excludes) == 0 { + return + } - f.addLeftJoin("marker_tags", "", "marker_tags.scene_marker_id = scene_markers.id") + if len(tags.Value) > 0 { + valuesClause, err := getHierarchicalValues(ctx, qb.tx, tags.Value, tagTable, "tags_relations", "parent_id", "child_id", tags.Depth) + if err != nil { + f.setError(err) + return + } - addHierarchicalConditionClauses(f, *tags, "marker_tags", "root_tag_id") + f.addWith(`marker_tags AS ( + SELECT mt.scene_marker_id, t.column1 AS root_tag_id FROM scene_markers_tags mt + INNER JOIN (` + valuesClause + `) t ON t.column2 = mt.tag_id + UNION + SELECT m.id, t.column1 FROM scene_markers m + INNER JOIN (` + valuesClause + `) t ON t.column2 = m.primary_tag_id + )`) + + f.addLeftJoin("marker_tags", "", "marker_tags.scene_marker_id = scene_markers.id") + + switch tags.Modifier { + case models.CriterionModifierEquals: + // includes only the provided ids + f.addWhere("marker_tags.root_tag_id IS NOT NULL") + tagsLen := len(tags.Value) + f.addHaving(fmt.Sprintf("count(distinct marker_tags.root_tag_id) IS %d", tagsLen)) + // decrement by one to account for primary tag id + f.addWhere("(SELECT COUNT(*) FROM scene_markers_tags s WHERE s.scene_marker_id = scene_markers.id) = ?", tagsLen-1) + case models.CriterionModifierNotEquals: + f.setError(fmt.Errorf("not equals modifier is not supported for scene marker tags")) + default: + addHierarchicalConditionClauses(f, tags, "marker_tags", "root_tag_id") + } + } + + if len(criterion.Excludes) > 0 { + valuesClause, err := getHierarchicalValues(ctx, dbWrapper{}, tags.Excludes, tagTable, "tags_relations", "parent_id", "child_id", tags.Depth) + if err != nil { + f.setError(err) + return + } + + clause := "scene_markers.id NOT IN (SELECT scene_markers_tags.scene_marker_id FROM scene_markers_tags WHERE scene_markers_tags.tag_id IN (SELECT column2 FROM (%s)))" + f.addWhere(fmt.Sprintf(clause, valuesClause)) + + f.addWhere(fmt.Sprintf("scene_markers.primary_tag_id NOT IN (SELECT column2 FROM (%s))", valuesClause)) + } } } } diff --git a/pkg/sqlite/scene_marker_test.go b/pkg/sqlite/scene_marker_test.go index 723f26f0e..0dd8e249f 100644 --- a/pkg/sqlite/scene_marker_test.go +++ b/pkg/sqlite/scene_marker_test.go @@ -52,7 +52,7 @@ func TestMarkerCountByTagID(t *testing.T) { t.Errorf("error calling CountByTagID: %s", err.Error()) } - assert.Equal(t, 4, markerCount) + assert.Equal(t, 6, markerCount) markerCount, err = mqb.CountByTagID(ctx, tagIDs[tagIdxWithMarkers]) @@ -60,7 +60,7 @@ func TestMarkerCountByTagID(t *testing.T) { t.Errorf("error calling CountByTagID: %s", err.Error()) } - assert.Equal(t, 1, markerCount) + assert.Equal(t, 2, markerCount) markerCount, err = mqb.CountByTagID(ctx, 0) @@ -89,6 +89,40 @@ func TestMarkerQuerySortBySceneUpdated(t *testing.T) { }) } +func verifyIDs(t *testing.T, modifier models.CriterionModifier, values []int, results []int) { + t.Helper() + switch modifier { + case models.CriterionModifierIsNull: + assert.Len(t, results, 0) + case models.CriterionModifierNotNull: + assert.NotEqual(t, 0, len(results)) + case models.CriterionModifierIncludes: + for _, v := range values { + assert.Contains(t, results, v) + } + case models.CriterionModifierExcludes: + for _, v := range values { + assert.NotContains(t, results, v) + } + case models.CriterionModifierEquals: + for _, v := range values { + assert.Contains(t, results, v) + } + assert.Len(t, results, len(values)) + case models.CriterionModifierNotEquals: + foundAll := true + for _, v := range values { + if !intslice.IntInclude(results, v) { + foundAll = false + break + } + } + if foundAll && len(results) == len(values) { + t.Errorf("expected ids not equal to %v - found %v", values, results) + } + } +} + func TestMarkerQueryTags(t *testing.T) { type test struct { name string @@ -97,17 +131,19 @@ func TestMarkerQueryTags(t *testing.T) { } withTxn(func(ctx context.Context) error { - testTags := func(m *models.SceneMarker, markerFilter *models.SceneMarkerFilterType) { + testTags := func(t *testing.T, m *models.SceneMarker, markerFilter *models.SceneMarkerFilterType) { tagIDs, err := db.SceneMarker.GetTagIDs(ctx, m.ID) if err != nil { t.Errorf("error getting marker tag ids: %v", err) } - if markerFilter.Tags.Modifier == models.CriterionModifierIsNull && len(tagIDs) > 0 { - t.Errorf("expected marker %d to have no tags - found %d", m.ID, len(tagIDs)) - } - if markerFilter.Tags.Modifier == models.CriterionModifierNotNull && len(tagIDs) == 0 { - t.Errorf("expected marker %d to have tags - found 0", m.ID) + + // HACK - if modifier isn't null/not null, then add the primary tag id + if markerFilter.Tags.Modifier != models.CriterionModifierIsNull && markerFilter.Tags.Modifier != models.CriterionModifierNotNull { + tagIDs = append(tagIDs, m.PrimaryTagID) } + + values, _ := stringslice.StringSliceToIntSlice(markerFilter.Tags.Value) + verifyIDs(t, markerFilter.Tags.Modifier, values, tagIDs) } cases := []test{ @@ -129,6 +165,71 @@ func TestMarkerQueryTags(t *testing.T) { }, nil, }, + { + "includes", + &models.SceneMarkerFilterType{ + Tags: &models.HierarchicalMultiCriterionInput{ + Modifier: models.CriterionModifierIncludes, + Value: []string{ + strconv.Itoa(tagIDs[tagIdxWithMarkers]), + }, + }, + }, + nil, + }, + { + "includes all", + &models.SceneMarkerFilterType{ + Tags: &models.HierarchicalMultiCriterionInput{ + Modifier: models.CriterionModifierIncludesAll, + Value: []string{ + strconv.Itoa(tagIDs[tagIdxWithMarkers]), + strconv.Itoa(tagIDs[tagIdx2WithMarkers]), + }, + }, + }, + nil, + }, + { + "equals", + &models.SceneMarkerFilterType{ + Tags: &models.HierarchicalMultiCriterionInput{ + Modifier: models.CriterionModifierEquals, + Value: []string{ + strconv.Itoa(tagIDs[tagIdxWithPrimaryMarkers]), + strconv.Itoa(tagIDs[tagIdxWithMarkers]), + strconv.Itoa(tagIDs[tagIdx2WithMarkers]), + }, + }, + }, + nil, + }, + // not equals not supported + // { + // "not equals", + // &models.SceneMarkerFilterType{ + // Tags: &models.HierarchicalMultiCriterionInput{ + // Modifier: models.CriterionModifierNotEquals, + // Value: []string{ + // strconv.Itoa(tagIDs[tagIdx2WithScene]), + // strconv.Itoa(tagIDs[tagIdx3WithScene]), + // }, + // }, + // }, + // nil, + // }, + { + "excludes", + &models.SceneMarkerFilterType{ + Tags: &models.HierarchicalMultiCriterionInput{ + Modifier: models.CriterionModifierIncludes, + Value: []string{ + strconv.Itoa(tagIDs[tagIdx2WithMarkers]), + }, + }, + }, + nil, + }, } for _, tc := range cases { @@ -136,7 +237,7 @@ func TestMarkerQueryTags(t *testing.T) { markers := queryMarkers(ctx, t, db.SceneMarker, tc.markerFilter, tc.findFilter) assert.Greater(t, len(markers), 0) for _, m := range markers { - testTags(m, tc.markerFilter) + testTags(t, m, tc.markerFilter) } }) } @@ -167,40 +268,7 @@ func TestMarkerQuerySceneTags(t *testing.T) { tagIDs := s.TagIDs.List() values, _ := stringslice.StringSliceToIntSlice(markerFilter.SceneTags.Value) - switch markerFilter.SceneTags.Modifier { - case models.CriterionModifierIsNull: - if len(tagIDs) > 0 { - t.Errorf("expected marker %d to have no scene tags - found %d", m.ID, len(tagIDs)) - } - case models.CriterionModifierNotNull: - if len(tagIDs) == 0 { - t.Errorf("expected marker %d to have scene tags - found 0", m.ID) - } - case models.CriterionModifierIncludes: - for _, v := range values { - assert.Contains(t, tagIDs, v) - } - case models.CriterionModifierExcludes: - for _, v := range values { - assert.NotContains(t, tagIDs, v) - } - case models.CriterionModifierEquals: - for _, v := range values { - assert.Contains(t, tagIDs, v) - } - assert.Len(t, tagIDs, len(values)) - case models.CriterionModifierNotEquals: - foundAll := true - for _, v := range values { - if !intslice.IntInclude(tagIDs, v) { - foundAll = false - break - } - } - if foundAll && len(tagIDs) == len(values) { - t.Errorf("expected marker %d to have scene tags not equal to %v - found %v", m.ID, values, tagIDs) - } - } + verifyIDs(t, markerFilter.SceneTags.Modifier, values, tagIDs) } cases := []test{ diff --git a/pkg/sqlite/setup_test.go b/pkg/sqlite/setup_test.go index fa7ebfdca..e5b56efad 100644 --- a/pkg/sqlite/setup_test.go +++ b/pkg/sqlite/setup_test.go @@ -17,6 +17,7 @@ import ( "github.com/stashapp/stash/pkg/file" "github.com/stashapp/stash/pkg/hash/md5" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/sliceutil/intslice" "github.com/stashapp/stash/pkg/sqlite" "github.com/stashapp/stash/pkg/txn" @@ -213,6 +214,7 @@ const ( tagIdxWithGrandChild tagIdxWithParentAndChild tagIdxWithGrandParent + tagIdx2WithMarkers // new indexes above // tags with dup names start from the end tagIdx1WithDupName @@ -400,6 +402,8 @@ var ( markerSpecs = []markerSpec{ {sceneIdxWithMarkers, tagIdxWithPrimaryMarkers, nil}, {sceneIdxWithMarkers, tagIdxWithPrimaryMarkers, []int{tagIdxWithMarkers}}, + {sceneIdxWithMarkers, tagIdxWithPrimaryMarkers, []int{tagIdx2WithMarkers}}, + {sceneIdxWithMarkers, tagIdxWithPrimaryMarkers, []int{tagIdxWithMarkers, tagIdx2WithMarkers}}, {sceneIdxWithMarkerAndTag, tagIdxWithPrimaryMarkers, nil}, {sceneIdxWithMarkerTwoTags, tagIdxWithPrimaryMarkers, nil}, } @@ -1477,15 +1481,15 @@ func getTagSceneCount(id int) int { } func getTagMarkerCount(id int) int { - if id == tagIDs[tagIdxWithPrimaryMarkers] { - return 3 + count := 0 + idx := indexFromID(tagIDs, id) + for _, s := range markerSpecs { + if s.primaryTagIdx == idx || intslice.IntInclude(s.tagIdxs, idx) { + count++ + } } - if id == tagIDs[tagIdxWithMarkers] { - return 1 - } - - return 0 + return count } func getTagImageCount(id int) int {