Mirror of https://github.com/beetbox/beets.git, synced 2026-01-09 17:33:51 +01:00
Unify query construction logic
Unify query creation logic from

- queryparse.py:construct_query_part,
- Model.field_query,
- DefaultTemplateFunctions._tmpl_unique

into a single implementation under the `LibModel.field_query` class method. This method should be used for query resolution for model fields.
This commit is contained in:
parent f4097410eb
commit 69faa58bab
6 changed files with 52 additions and 62 deletions
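For orientation, a minimal usage sketch of the unified entry point (not part of the commit; the field name, pattern, and the `lib` library object are illustrative assumptions):

    from beets import library
    from beets.dbcore.query import MatchQuery

    # LibModel.field_query is now the single place that decides whether a
    # field query can run in SQL ("fast") and that table-qualifies fields
    # shared by the item and album tables.
    query = library.Item.field_query("artist", "Nirvana", MatchQuery)
    matches = lib.items(query)  # `lib` is an open beets Library (assumed)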
@@ -35,8 +35,6 @@ import beets
 from ..util import cached_classproperty, functemplate
 from . import types
 from .query import (
     AndQuery,
-    FieldQuery,
-    FieldQueryType,
     FieldSort,
     MatchQuery,
@@ -718,33 +716,6 @@ class Model(ABC, Generic[D]):
         """Set the object's key to a value represented by a string."""
         self[key] = self._parse(key, string)
-
-    # Convenient queries.
-
-    @classmethod
-    def field_query(
-        cls,
-        field,
-        pattern,
-        query_cls: FieldQueryType = MatchQuery,
-    ) -> FieldQuery:
-        """Get a `FieldQuery` for this model."""
-        return query_cls(field, pattern, field in cls._fields)
-
-    @classmethod
-    def all_fields_query(
-        cls: type[Model],
-        pats: Mapping[str, str],
-        query_cls: FieldQueryType = MatchQuery,
-    ):
-        """Get a query that matches many fields with different patterns.
-
-        `pats` should be a mapping from field names to patterns. The
-        resulting query is a conjunction ("and") of per-field queries
-        for all of these field/pattern pairs.
-        """
-        subqueries = [cls.field_query(k, v, query_cls) for k, v in pats.items()]
-        return AndQuery(subqueries)
-
-
 # Database controller and supporting interfaces.
@@ -97,6 +97,9 @@ class Query(ABC):
         """
         ...

+    def __and__(self, other: Query) -> AndQuery:
+        return AndQuery([self, other])
+
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}()"
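The new `__and__` makes queries composable with the `&` operator. A brief illustrative sketch (field values are placeholders):

    from beets.dbcore.query import MatchQuery, NoneQuery

    # Combining two queries yields an AndQuery that matches only rows
    # satisfying both conditions.
    combined = MatchQuery("albumartist", "Nirvana") & NoneQuery("album_id")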
@@ -20,15 +20,17 @@ import itertools
 import re
 from typing import TYPE_CHECKING

-from . import Model, query
+from . import query

 if TYPE_CHECKING:
     from collections.abc import Collection, Sequence
+
+    from ..library import LibModel
     from .query import FieldQueryType, Sort

     Prefixes = dict[str, FieldQueryType]


 PARSE_QUERY_PART_REGEX = re.compile(
     # Non-capturing optional segment for the keyword.
     r"(-|\^)?"  # Negation prefixes.
@@ -112,7 +114,7 @@ def parse_query_part(


 def construct_query_part(
-    model_cls: type[Model],
+    model_cls: type[LibModel],
     prefixes: Prefixes,
     query_part: str,
 ) -> query.Query:
@@ -160,15 +162,7 @@ def construct_query_part(
     # Field queries get constructed according to the name of the field
     # they are querying.
     else:
-        field = table = key.lower()
-        if field in model_cls.shared_db_fields:
-            # This field exists in both tables, so SQLite will encounter
-            # an OperationalError if we try to query it in a join.
-            # Using an explicit table name resolves this.
-            table = f"{model_cls._table}.{field}"
-
-        field_in_db = field in model_cls.all_db_fields
-        out_query = query_class(table, pattern, field_in_db)
+        out_query = model_cls.field_query(key.lower(), pattern, query_class)

     # Apply negation.
     if negate:
@@ -180,7 +174,7 @@ def construct_query_part(
 # TYPING ERROR
 def query_from_strings(
     query_cls: type[query.CollectionQuery],
-    model_cls: type[Model],
+    model_cls: type[LibModel],
     prefixes: Prefixes,
     query_parts: Collection[str],
 ) -> query.Query:
@@ -197,7 +191,7 @@ def query_from_strings(


 def construct_sort_part(
-    model_cls: type[Model],
+    model_cls: type[LibModel],
     part: str,
     case_insensitive: bool = True,
 ) -> Sort:
@@ -228,7 +222,7 @@ def construct_sort_part(


 def sort_from_strings(
-    model_cls: type[Model],
+    model_cls: type[LibModel],
     sort_parts: Sequence[str],
     case_insensitive: bool = True,
 ) -> Sort:
@@ -247,7 +241,7 @@ def sort_from_strings(


 def parse_sorted_query(
-    model_cls: type[Model],
+    model_cls: type[LibModel],
     parts: list[str],
     prefixes: Prefixes = {},
     case_insensitive: bool = True,
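With these signature changes every queryparse helper expects a `LibModel` subclass and routes per-field resolution through its `field_query`. A hedged sketch of the call path (the query strings are placeholders and `lib` is assumed to be an open library):

    from beets import library
    from beets.dbcore import queryparse

    # parse_sorted_query builds the combined query and sort for a list of
    # string parts; each field part is resolved via LibModel.field_query.
    q, sort = queryparse.parse_sorted_query(
        library.Album, ["albumartist:Nirvana", "year:1993"]
    )
    albums = lib.albums(q)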
@@ -707,9 +707,7 @@ class ImportTask(BaseImportTask):
         # use a temporary Album object to generate any computed fields.
         tmp_album = library.Album(lib, **info)
         keys = config["import"]["duplicate_keys"]["album"].as_str_seq()
-        dup_query = library.Album.all_fields_query(
-            {key: tmp_album.get(key) for key in keys}
-        )
+        dup_query = tmp_album.duplicates_query(keys)

         # Don't count albums with the same files as duplicates.
         task_paths = {i.path for i in self.items if i}
@@ -1025,9 +1023,7 @@ class SingletonImportTask(ImportTask):
         # temporary `Item` object to generate any computed fields.
         tmp_item = library.Item(lib, **info)
         keys = config["import"]["duplicate_keys"]["item"].as_str_seq()
-        dup_query = library.Album.all_fields_query(
-            {key: tmp_item.get(key) for key in keys}
-        )
+        dup_query = tmp_item.duplicates_query(keys)

         found_items = []
         for other_item in lib.items(dup_query):
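In both import paths the duplicate check now reads as a single call on the temporary model object. A hedged illustration of the equivalent standalone usage (the key list is an example, not the configured default):

    # Find albums in the library sharing these field values with tmp_album.
    dup_query = tmp_album.duplicates_query(["albumartist", "album"])
    duplicates = lib.albums(dup_query)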
@@ -25,6 +25,7 @@ import time
 import unicodedata
 from functools import cached_property
 from pathlib import Path
+from typing import TYPE_CHECKING

 import platformdirs
 from mediafile import MediaFile, UnreadableFileError
@@ -42,6 +43,9 @@ from beets.util import (
 )
 from beets.util.functemplate import Template, template

+if TYPE_CHECKING:
+    from .dbcore.query import FieldQuery, FieldQueryType
+
 # To use the SQLite "blob" type, it doesn't suffice to provide a byte
 # string; SQLite treats that as encoded text. Wrapping it in a
 # `memoryview` tells it that we actually mean non-text data.
@@ -375,6 +379,31 @@ class LibModel(dbcore.Model["Library"]):
     def __bytes__(self):
         return self.__str__().encode("utf-8")

+    # Convenient queries.
+
+    @classmethod
+    def field_query(
+        cls, field: str, pattern: str, query_cls: FieldQueryType
+    ) -> FieldQuery:
+        """Get a `FieldQuery` for the given field on this model."""
+        fast = field in cls.all_db_fields
+        if field in cls.shared_db_fields:
+            # This field exists in both tables, so SQLite will encounter
+            # an OperationalError if we try to use it in a query.
+            # Using an explicit table name resolves this.
+            field = f"{cls._table}.{field}"
+
+        return query_cls(field, pattern, fast)
+
+    def duplicates_query(self, fields: list[str]) -> dbcore.AndQuery:
+        """Return a query for entities with same values in the given fields."""
+        return dbcore.AndQuery(
+            [
+                self.field_query(f, self.get(f), dbcore.MatchQuery)
+                for f in fields
+            ]
+        )
+

 class FormattedItemMapping(dbcore.db.FormattedMapping):
     """Add lookup for album-level fields.
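The two new methods above hold the logic that was previously duplicated across db.py, queryparse.py, and the template functions. A hedged usage sketch (field names, patterns, and the `lib` object are illustrative):

    from beets import dbcore, library

    # field_query handles the fast/slow decision and table qualification:
    # "genre" exists on both the items and albums tables, so querying it
    # through Album emits "albums.genre" to avoid an ambiguous column in
    # joined SQL.
    q = library.Album.field_query("genre", "Rock", dbcore.MatchQuery)
    rock_albums = lib.albums(q)

    # duplicates_query builds an AndQuery of per-field MatchQuery objects
    # from the object's own current values.
    an_album = lib.albums().get()  # any existing album, for illustration
    dups = lib.albums(an_album.duplicates_query(["albumartist", "album"]))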
@@ -648,6 +677,12 @@ class Item(LibModel):
         getters["filesize"] = Item.try_filesize  # In bytes.
         return getters

+    def duplicates_query(self, fields: list[str]) -> dbcore.AndQuery:
+        """Return a query for entities with same values in the given fields."""
+        return super().duplicates_query(fields) & dbcore.query.NoneQuery(
+            "album_id"
+        )
+
     @classmethod
     def from_path(cls, path):
         """Create a new item from the media file at the specified path."""
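Item narrows the inherited behaviour: the `&` composition (enabled by the new `Query.__and__`) appends a `NoneQuery("album_id")`, so item-level duplicate checks only consider singletons. A short hedged sketch (the key list and the `item`/`lib` objects are illustrative):

    # Only items without an album (album_id is NULL) can match.
    q = item.duplicates_query(["artist", "title"])
    singleton_dups = lib.items(q)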
@@ -1866,7 +1901,6 @@ class DefaultTemplateFunctions:
             Item.all_keys(),
             # Do nothing for non singletons.
             lambda i: i.album_id is not None,
-            initial_subqueries=[dbcore.query.NoneQuery("album_id", True)],
         )

     def _tmpl_unique_memokey(self, name, keys, disam, item_id):
@@ -1885,7 +1919,6 @@ class DefaultTemplateFunctions:
         db_item,
         item_keys,
         skip_item,
-        initial_subqueries=None,
     ):
         """Generate a string that is guaranteed to be unique among all items of
         the same type as "db_item" who share the same set of keys.
@@ -1932,15 +1965,7 @@
             bracket_r = ""

         # Find matching items to disambiguate with.
-        subqueries = []
-        if initial_subqueries is not None:
-            subqueries.extend(initial_subqueries)
-        for key in keys:
-            value = db_item.get(key, "")
-            # Use slow queries for flexible attributes.
-            fast = key in item_keys
-            subqueries.append(dbcore.MatchQuery(key, value, fast))
-        query = dbcore.AndQuery(subqueries)
+        query = db_item.duplicates_query(keys)
         ambigous_items = (
             self.lib.items(query)
             if isinstance(db_item, Item)
@@ -23,6 +23,7 @@ from tempfile import mkstemp
 import pytest

 from beets import dbcore
+from beets.library import LibModel
 from beets.test import _common

 # Fixture: concrete database and model classes. For migration tests, we
@@ -44,7 +45,7 @@ class QueryFixture(dbcore.query.FieldQuery):
         return True


-class ModelFixture1(dbcore.Model):
+class ModelFixture1(LibModel):
     _table = "test"
     _flex_table = "testflex"
     _fields = {