Unify query construction logic

Unify query creation logic from
- queryparse.py:construct_query_part,
- Model.field_query,
- DefaultTemplateFunctions._tmpl_unique

to a single implementation under `LibModel.field_query` class method.
This method should be used for query resolution for model fields.
This commit is contained in:
Šarūnas Nejus 2024-05-08 11:36:57 +01:00
parent f4097410eb
commit 69faa58bab
No known key found for this signature in database
GPG key ID: DD28F6704DBE3435
6 changed files with 52 additions and 62 deletions

View file

@ -35,8 +35,6 @@ import beets
from ..util import cached_classproperty, functemplate
from . import types
from .query import (
AndQuery,
FieldQuery,
FieldQueryType,
FieldSort,
MatchQuery,
@ -718,33 +716,6 @@ class Model(ABC, Generic[D]):
"""Set the object's key to a value represented by a string."""
self[key] = self._parse(key, string)
# Convenient queries.
@classmethod
def field_query(
cls,
field,
pattern,
query_cls: FieldQueryType = MatchQuery,
) -> FieldQuery:
"""Get a `FieldQuery` for this model."""
return query_cls(field, pattern, field in cls._fields)
@classmethod
def all_fields_query(
cls: type[Model],
pats: Mapping[str, str],
query_cls: FieldQueryType = MatchQuery,
):
"""Get a query that matches many fields with different patterns.
`pats` should be a mapping from field names to patterns. The
resulting query is a conjunction ("and") of per-field queries
for all of these field/pattern pairs.
"""
subqueries = [cls.field_query(k, v, query_cls) for k, v in pats.items()]
return AndQuery(subqueries)
# Database controller and supporting interfaces.

View file

@ -97,6 +97,9 @@ class Query(ABC):
"""
...
def __and__(self, other: Query) -> AndQuery:
return AndQuery([self, other])
def __repr__(self) -> str:
return f"{self.__class__.__name__}()"

View file

@ -20,15 +20,17 @@ import itertools
import re
from typing import TYPE_CHECKING
from . import Model, query
from . import query
if TYPE_CHECKING:
from collections.abc import Collection, Sequence
from ..library import LibModel
from .query import FieldQueryType, Sort
Prefixes = dict[str, FieldQueryType]
PARSE_QUERY_PART_REGEX = re.compile(
# Non-capturing optional segment for the keyword.
r"(-|\^)?" # Negation prefixes.
@ -112,7 +114,7 @@ def parse_query_part(
def construct_query_part(
model_cls: type[Model],
model_cls: type[LibModel],
prefixes: Prefixes,
query_part: str,
) -> query.Query:
@ -160,15 +162,7 @@ def construct_query_part(
# Field queries get constructed according to the name of the field
# they are querying.
else:
field = table = key.lower()
if field in model_cls.shared_db_fields:
# This field exists in both tables, so SQLite will encounter
# an OperationalError if we try to query it in a join.
# Using an explicit table name resolves this.
table = f"{model_cls._table}.{field}"
field_in_db = field in model_cls.all_db_fields
out_query = query_class(table, pattern, field_in_db)
out_query = model_cls.field_query(key.lower(), pattern, query_class)
# Apply negation.
if negate:
@ -180,7 +174,7 @@ def construct_query_part(
# TYPING ERROR
def query_from_strings(
query_cls: type[query.CollectionQuery],
model_cls: type[Model],
model_cls: type[LibModel],
prefixes: Prefixes,
query_parts: Collection[str],
) -> query.Query:
@ -197,7 +191,7 @@ def query_from_strings(
def construct_sort_part(
model_cls: type[Model],
model_cls: type[LibModel],
part: str,
case_insensitive: bool = True,
) -> Sort:
@ -228,7 +222,7 @@ def construct_sort_part(
def sort_from_strings(
model_cls: type[Model],
model_cls: type[LibModel],
sort_parts: Sequence[str],
case_insensitive: bool = True,
) -> Sort:
@ -247,7 +241,7 @@ def sort_from_strings(
def parse_sorted_query(
model_cls: type[Model],
model_cls: type[LibModel],
parts: list[str],
prefixes: Prefixes = {},
case_insensitive: bool = True,

View file

@ -707,9 +707,7 @@ class ImportTask(BaseImportTask):
# use a temporary Album object to generate any computed fields.
tmp_album = library.Album(lib, **info)
keys = config["import"]["duplicate_keys"]["album"].as_str_seq()
dup_query = library.Album.all_fields_query(
{key: tmp_album.get(key) for key in keys}
)
dup_query = tmp_album.duplicates_query(keys)
# Don't count albums with the same files as duplicates.
task_paths = {i.path for i in self.items if i}
@ -1025,9 +1023,7 @@ class SingletonImportTask(ImportTask):
# temporary `Item` object to generate any computed fields.
tmp_item = library.Item(lib, **info)
keys = config["import"]["duplicate_keys"]["item"].as_str_seq()
dup_query = library.Album.all_fields_query(
{key: tmp_item.get(key) for key in keys}
)
dup_query = tmp_item.duplicates_query(keys)
found_items = []
for other_item in lib.items(dup_query):

View file

@ -25,6 +25,7 @@ import time
import unicodedata
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING
import platformdirs
from mediafile import MediaFile, UnreadableFileError
@ -42,6 +43,9 @@ from beets.util import (
)
from beets.util.functemplate import Template, template
if TYPE_CHECKING:
from .dbcore.query import FieldQuery, FieldQueryType
# To use the SQLite "blob" type, it doesn't suffice to provide a byte
# string; SQLite treats that as encoded text. Wrapping it in a
# `memoryview` tells it that we actually mean non-text data.
@ -375,6 +379,31 @@ class LibModel(dbcore.Model["Library"]):
def __bytes__(self):
return self.__str__().encode("utf-8")
# Convenient queries.
@classmethod
def field_query(
cls, field: str, pattern: str, query_cls: FieldQueryType
) -> FieldQuery:
"""Get a `FieldQuery` for the given field on this model."""
fast = field in cls.all_db_fields
if field in cls.shared_db_fields:
# This field exists in both tables, so SQLite will encounter
# an OperationalError if we try to use it in a query.
# Using an explicit table name resolves this.
field = f"{cls._table}.{field}"
return query_cls(field, pattern, fast)
def duplicates_query(self, fields: list[str]) -> dbcore.AndQuery:
"""Return a query for entities with same values in the given fields."""
return dbcore.AndQuery(
[
self.field_query(f, self.get(f), dbcore.MatchQuery)
for f in fields
]
)
class FormattedItemMapping(dbcore.db.FormattedMapping):
"""Add lookup for album-level fields.
@ -648,6 +677,12 @@ class Item(LibModel):
getters["filesize"] = Item.try_filesize # In bytes.
return getters
def duplicates_query(self, fields: list[str]) -> dbcore.AndQuery:
"""Return a query for entities with same values in the given fields."""
return super().duplicates_query(fields) & dbcore.query.NoneQuery(
"album_id"
)
@classmethod
def from_path(cls, path):
"""Create a new item from the media file at the specified path."""
@ -1866,7 +1901,6 @@ class DefaultTemplateFunctions:
Item.all_keys(),
# Do nothing for non singletons.
lambda i: i.album_id is not None,
initial_subqueries=[dbcore.query.NoneQuery("album_id", True)],
)
def _tmpl_unique_memokey(self, name, keys, disam, item_id):
@ -1885,7 +1919,6 @@ class DefaultTemplateFunctions:
db_item,
item_keys,
skip_item,
initial_subqueries=None,
):
"""Generate a string that is guaranteed to be unique among all items of
the same type as "db_item" who share the same set of keys.
@ -1932,15 +1965,7 @@ class DefaultTemplateFunctions:
bracket_r = ""
# Find matching items to disambiguate with.
subqueries = []
if initial_subqueries is not None:
subqueries.extend(initial_subqueries)
for key in keys:
value = db_item.get(key, "")
# Use slow queries for flexible attributes.
fast = key in item_keys
subqueries.append(dbcore.MatchQuery(key, value, fast))
query = dbcore.AndQuery(subqueries)
query = db_item.duplicates_query(keys)
ambigous_items = (
self.lib.items(query)
if isinstance(db_item, Item)

View file

@ -23,6 +23,7 @@ from tempfile import mkstemp
import pytest
from beets import dbcore
from beets.library import LibModel
from beets.test import _common
# Fixture: concrete database and model classes. For migration tests, we
@ -44,7 +45,7 @@ class QueryFixture(dbcore.query.FieldQuery):
return True
class ModelFixture1(dbcore.Model):
class ModelFixture1(LibModel):
_table = "test"
_flex_table = "testflex"
_fields = {