diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index 64e77f814..2aa0081d7 100755 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -35,8 +35,6 @@ import beets from ..util import cached_classproperty, functemplate from . import types from .query import ( - AndQuery, - FieldQuery, FieldQueryType, FieldSort, MatchQuery, @@ -718,33 +716,6 @@ class Model(ABC, Generic[D]): """Set the object's key to a value represented by a string.""" self[key] = self._parse(key, string) - # Convenient queries. - - @classmethod - def field_query( - cls, - field, - pattern, - query_cls: FieldQueryType = MatchQuery, - ) -> FieldQuery: - """Get a `FieldQuery` for this model.""" - return query_cls(field, pattern, field in cls._fields) - - @classmethod - def all_fields_query( - cls: type[Model], - pats: Mapping[str, str], - query_cls: FieldQueryType = MatchQuery, - ): - """Get a query that matches many fields with different patterns. - - `pats` should be a mapping from field names to patterns. The - resulting query is a conjunction ("and") of per-field queries - for all of these field/pattern pairs. - """ - subqueries = [cls.field_query(k, v, query_cls) for k, v in pats.items()] - return AndQuery(subqueries) - # Database controller and supporting interfaces. diff --git a/beets/dbcore/query.py b/beets/dbcore/query.py index 866162c4a..1ff56f101 100644 --- a/beets/dbcore/query.py +++ b/beets/dbcore/query.py @@ -97,6 +97,9 @@ class Query(ABC): """ ... + def __and__(self, other: Query) -> AndQuery: + return AndQuery([self, other]) + def __repr__(self) -> str: return f"{self.__class__.__name__}()" diff --git a/beets/dbcore/queryparse.py b/beets/dbcore/queryparse.py index 289632668..c4ad9c4c9 100644 --- a/beets/dbcore/queryparse.py +++ b/beets/dbcore/queryparse.py @@ -20,15 +20,17 @@ import itertools import re from typing import TYPE_CHECKING -from . import Model, query +from . import query if TYPE_CHECKING: from collections.abc import Collection, Sequence + from ..library import LibModel from .query import FieldQueryType, Sort Prefixes = dict[str, FieldQueryType] + PARSE_QUERY_PART_REGEX = re.compile( # Non-capturing optional segment for the keyword. r"(-|\^)?" # Negation prefixes. @@ -112,7 +114,7 @@ def parse_query_part( def construct_query_part( - model_cls: type[Model], + model_cls: type[LibModel], prefixes: Prefixes, query_part: str, ) -> query.Query: @@ -160,15 +162,7 @@ def construct_query_part( # Field queries get constructed according to the name of the field # they are querying. else: - field = table = key.lower() - if field in model_cls.shared_db_fields: - # This field exists in both tables, so SQLite will encounter - # an OperationalError if we try to query it in a join. - # Using an explicit table name resolves this. - table = f"{model_cls._table}.{field}" - - field_in_db = field in model_cls.all_db_fields - out_query = query_class(table, pattern, field_in_db) + out_query = model_cls.field_query(key.lower(), pattern, query_class) # Apply negation. if negate: @@ -180,7 +174,7 @@ def construct_query_part( # TYPING ERROR def query_from_strings( query_cls: type[query.CollectionQuery], - model_cls: type[Model], + model_cls: type[LibModel], prefixes: Prefixes, query_parts: Collection[str], ) -> query.Query: @@ -197,7 +191,7 @@ def query_from_strings( def construct_sort_part( - model_cls: type[Model], + model_cls: type[LibModel], part: str, case_insensitive: bool = True, ) -> Sort: @@ -228,7 +222,7 @@ def construct_sort_part( def sort_from_strings( - model_cls: type[Model], + model_cls: type[LibModel], sort_parts: Sequence[str], case_insensitive: bool = True, ) -> Sort: @@ -247,7 +241,7 @@ def sort_from_strings( def parse_sorted_query( - model_cls: type[Model], + model_cls: type[LibModel], parts: list[str], prefixes: Prefixes = {}, case_insensitive: bool = True, diff --git a/beets/importer.py b/beets/importer.py index ab2382c9f..b30e6399b 100644 --- a/beets/importer.py +++ b/beets/importer.py @@ -707,9 +707,7 @@ class ImportTask(BaseImportTask): # use a temporary Album object to generate any computed fields. tmp_album = library.Album(lib, **info) keys = config["import"]["duplicate_keys"]["album"].as_str_seq() - dup_query = library.Album.all_fields_query( - {key: tmp_album.get(key) for key in keys} - ) + dup_query = tmp_album.duplicates_query(keys) # Don't count albums with the same files as duplicates. task_paths = {i.path for i in self.items if i} @@ -1025,9 +1023,7 @@ class SingletonImportTask(ImportTask): # temporary `Item` object to generate any computed fields. tmp_item = library.Item(lib, **info) keys = config["import"]["duplicate_keys"]["item"].as_str_seq() - dup_query = library.Album.all_fields_query( - {key: tmp_item.get(key) for key in keys} - ) + dup_query = tmp_item.duplicates_query(keys) found_items = [] for other_item in lib.items(dup_query): diff --git a/beets/library.py b/beets/library.py index 2430f7125..46c1e43aa 100644 --- a/beets/library.py +++ b/beets/library.py @@ -25,6 +25,7 @@ import time import unicodedata from functools import cached_property from pathlib import Path +from typing import TYPE_CHECKING import platformdirs from mediafile import MediaFile, UnreadableFileError @@ -42,6 +43,9 @@ from beets.util import ( ) from beets.util.functemplate import Template, template +if TYPE_CHECKING: + from .dbcore.query import FieldQuery, FieldQueryType + # To use the SQLite "blob" type, it doesn't suffice to provide a byte # string; SQLite treats that as encoded text. Wrapping it in a # `memoryview` tells it that we actually mean non-text data. @@ -375,6 +379,31 @@ class LibModel(dbcore.Model["Library"]): def __bytes__(self): return self.__str__().encode("utf-8") + # Convenient queries. + + @classmethod + def field_query( + cls, field: str, pattern: str, query_cls: FieldQueryType + ) -> FieldQuery: + """Get a `FieldQuery` for the given field on this model.""" + fast = field in cls.all_db_fields + if field in cls.shared_db_fields: + # This field exists in both tables, so SQLite will encounter + # an OperationalError if we try to use it in a query. + # Using an explicit table name resolves this. + field = f"{cls._table}.{field}" + + return query_cls(field, pattern, fast) + + def duplicates_query(self, fields: list[str]) -> dbcore.AndQuery: + """Return a query for entities with same values in the given fields.""" + return dbcore.AndQuery( + [ + self.field_query(f, self.get(f), dbcore.MatchQuery) + for f in fields + ] + ) + class FormattedItemMapping(dbcore.db.FormattedMapping): """Add lookup for album-level fields. @@ -648,6 +677,12 @@ class Item(LibModel): getters["filesize"] = Item.try_filesize # In bytes. return getters + def duplicates_query(self, fields: list[str]) -> dbcore.AndQuery: + """Return a query for entities with same values in the given fields.""" + return super().duplicates_query(fields) & dbcore.query.NoneQuery( + "album_id" + ) + @classmethod def from_path(cls, path): """Create a new item from the media file at the specified path.""" @@ -1866,7 +1901,6 @@ class DefaultTemplateFunctions: Item.all_keys(), # Do nothing for non singletons. lambda i: i.album_id is not None, - initial_subqueries=[dbcore.query.NoneQuery("album_id", True)], ) def _tmpl_unique_memokey(self, name, keys, disam, item_id): @@ -1885,7 +1919,6 @@ class DefaultTemplateFunctions: db_item, item_keys, skip_item, - initial_subqueries=None, ): """Generate a string that is guaranteed to be unique among all items of the same type as "db_item" who share the same set of keys. @@ -1932,15 +1965,7 @@ class DefaultTemplateFunctions: bracket_r = "" # Find matching items to disambiguate with. - subqueries = [] - if initial_subqueries is not None: - subqueries.extend(initial_subqueries) - for key in keys: - value = db_item.get(key, "") - # Use slow queries for flexible attributes. - fast = key in item_keys - subqueries.append(dbcore.MatchQuery(key, value, fast)) - query = dbcore.AndQuery(subqueries) + query = db_item.duplicates_query(keys) ambigous_items = ( self.lib.items(query) if isinstance(db_item, Item) diff --git a/test/test_dbcore.py b/test/test_dbcore.py index ba2b84ad2..03fe02d19 100644 --- a/test/test_dbcore.py +++ b/test/test_dbcore.py @@ -23,6 +23,7 @@ from tempfile import mkstemp import pytest from beets import dbcore +from beets.library import LibModel from beets.test import _common # Fixture: concrete database and model classes. For migration tests, we @@ -44,7 +45,7 @@ class QueryFixture(dbcore.query.FieldQuery): return True -class ModelFixture1(dbcore.Model): +class ModelFixture1(LibModel): _table = "test" _flex_table = "testflex" _fields = {