Merge pull request #5210 from snejus/add-in-query-and-remove-named-query

Enforce the same interface across all `...Query` implementations

### Make `PlaylistQuery` a `FieldQuery`
While working on the DB optimization and looking at updates upstream I discovered one query which does not follow the `FieldQuery` interface —`PlaylistQuery`, so I looked into it in more detail and ended up integrating it as a `FieldQuery`.

One special thing about it is that it uses **IN** SQL operator, so I added implementation for this sort of query outside the playlist context, see `InQuery`.

Otherwise, it seems like `PlaylistQuery` is a field query with a special way of resolving values it wants to query. In the future, we may want to consider moving this kind of custom _initialization_ logic away from `init` methods to factory/@classmethod: this should make it more clear that the purpose of such logic is to resolve the data that is required to define a particular `FieldQuery` class fully.


### Remove `NamedQuery` since it is unused

This simplifies query parsing logic in `queryparse.py`. We know that this logic can only receive `FieldQuery` classes thus I adjusted types and removed the logic that handles other cases.

Effectively, this means that the query parsing logic does not need to care whether the query is named by the corresponding DB field. Instead, queries like `SingletonQuery` and `PlaylistQuery` are initialized with the same data as others and take things from there themselves: in this case they translate `singleton` and `playlist` queries to the underlying DB filters.
This commit is contained in:
Šarūnas Nejus 2024-05-01 16:36:52 +01:00 committed by GitHub
commit 34a59f98b3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 57 additions and 73 deletions

View file

@ -1,5 +1,5 @@
[mypy]
allow_any_generics = false
# FIXME: Would be better to actually type the libraries (if under our control),
# or write our own stubs. For now, silence errors
ignore_missing_imports = True
ignore_missing_imports = true

View file

@ -22,7 +22,6 @@ from .query import (
FieldQuery,
InvalidQueryError,
MatchQuery,
NamedQuery,
OrQuery,
Query,
)

View file

@ -12,8 +12,7 @@
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""The central Model and Database constructs for DBCore.
"""
"""The central Model and Database constructs for DBCore."""
from __future__ import annotations
@ -309,7 +308,7 @@ class Model(ABC):
are subclasses of `Sort`.
"""
_queries: Dict[str, Type[Query]] = {}
_queries: Dict[str, Type[FieldQuery]] = {}
"""Named queries that use a field-like `name:value` syntax but which
do not relate to any specific field.
"""
@ -599,8 +598,7 @@ class Model(ABC):
# Deleted flexible attributes.
for key in self._dirty:
tx.mutate(
"DELETE FROM {} "
"WHERE entity_id=? AND key=?".format(self._flex_table),
f"DELETE FROM {self._flex_table} WHERE entity_id=? AND key=?",
(self.id, key),
)

View file

@ -12,8 +12,7 @@
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""The Query type hierarchy for DBCore.
"""
"""The Query type hierarchy for DBCore."""
from __future__ import annotations
@ -116,17 +115,8 @@ class Query(ABC):
return hash(type(self))
class NamedQuery(Query):
"""Non-field query, i.e. the query prefix is not a field but identifies the
query class.
"""
@abstractmethod
def __init__(self, pattern): ...
P = TypeVar("P")
SQLiteType = Union[str, float, int, memoryview]
SQLiteType = Union[str, bytes, float, int, memoryview]
AnySQLiteType = TypeVar("AnySQLiteType", bound=SQLiteType)
@ -155,9 +145,7 @@ class FieldQuery(Query, Generic[P]):
@classmethod
def value_match(cls, pattern: P, value: Any):
"""Determine whether the value matches the pattern. Both
arguments are strings.
"""
"""Determine whether the value matches the pattern."""
raise NotImplementedError()
def match(self, obj: Model) -> bool:
@ -428,6 +416,28 @@ class NumericQuery(FieldQuery):
return "1", ()
class InQuery(Generic[AnySQLiteType], FieldQuery[Sequence[AnySQLiteType]]):
"""Query which matches values in the given set."""
field: str
pattern: Sequence[AnySQLiteType]
fast: bool = True
@property
def subvals(self) -> Sequence[SQLiteType]:
return self.pattern
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
placeholders = ", ".join(["?"] * len(self.subvals))
return f"{self.field} IN ({placeholders})", self.subvals
@classmethod
def value_match(
cls, pattern: Sequence[AnySQLiteType], value: AnySQLiteType
) -> bool:
return value in pattern
class CollectionQuery(Query):
"""An abstract query class that aggregates other queries. Can be
indexed like a list to access the sub-queries.

View file

@ -12,15 +12,14 @@
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""Parsing of strings into DBCore queries.
"""
"""Parsing of strings into DBCore queries."""
import itertools
import re
from typing import Collection, Dict, List, Optional, Sequence, Tuple, Type
from . import Model, query
from .query import Query, Sort
from .query import Sort
PARSE_QUERY_PART_REGEX = re.compile(
# Non-capturing optional segment for the keyword.
@ -36,10 +35,10 @@ PARSE_QUERY_PART_REGEX = re.compile(
def parse_query_part(
part: str,
query_classes: Dict = {},
query_classes: Dict[str, Type[query.FieldQuery]] = {},
prefixes: Dict = {},
default_class: Type[query.SubstringQuery] = query.SubstringQuery,
) -> Tuple[Optional[str], str, Type[query.Query], bool]:
) -> Tuple[Optional[str], str, Type[query.FieldQuery], bool]:
"""Parse a single *query part*, which is a chunk of a complete query
string representing a single criterion.
@ -128,7 +127,7 @@ def construct_query_part(
# Use `model_cls` to build up a map from field (or query) names to
# `Query` classes.
query_classes: Dict[str, Type[Query]] = {}
query_classes: Dict[str, Type[query.FieldQuery]] = {}
for k, t in itertools.chain(
model_cls._fields.items(), model_cls._types.items()
):
@ -143,30 +142,17 @@ def construct_query_part(
# If there's no key (field name) specified, this is a "match
# anything" query.
if key is None:
if issubclass(query_class, query.FieldQuery):
# The query type matches a specific field, but none was
# specified. So we use a version of the query that matches
# any field.
out_query = query.AnyFieldQuery(
pattern, model_cls._search_fields, query_class
)
elif issubclass(query_class, query.NamedQuery):
# Non-field query type.
out_query = query_class(pattern)
else:
assert False, "Unexpected query type"
# The query type matches a specific field, but none was
# specified. So we use a version of the query that matches
# any field.
out_query = query.AnyFieldQuery(
pattern, model_cls._search_fields, query_class
)
# Field queries get constructed according to the name of the field
# they are querying.
elif issubclass(query_class, query.FieldQuery):
key = key.lower()
out_query = query_class(key.lower(), pattern, key in model_cls._fields)
# Non-field (named) query.
elif issubclass(query_class, query.NamedQuery):
out_query = query_class(pattern)
else:
assert False, "Unexpected query type"
out_query = query_class(key.lower(), pattern, key in model_cls._fields)
# Apply negation.
if negate:

View file

@ -427,7 +427,7 @@ def displayable_path(
return path.decode("utf-8", "ignore")
def syspath(path: bytes, prefix: bool = True) -> Bytes_or_String:
def syspath(path: Bytes_or_String, prefix: bool = True) -> Bytes_or_String:
"""Convert a path for use by the operating system. In particular,
paths on Windows must receive a magic prefix and must be converted
to Unicode before they are sent to the OS. To disable the magic

View file

@ -15,17 +15,22 @@
import fnmatch
import os
import tempfile
from typing import Any, Optional, Sequence, Tuple
from typing import Sequence
import beets
from beets.dbcore.query import InQuery
from beets.library import BLOB_TYPE
from beets.util import path_as_posix
class PlaylistQuery(beets.dbcore.NamedQuery):
class PlaylistQuery(InQuery[bytes]):
"""Matches files listed by a playlist file."""
def __init__(self, pattern):
self.pattern = pattern
@property
def subvals(self) -> Sequence[BLOB_TYPE]:
return [BLOB_TYPE(p) for p in self.pattern]
def __init__(self, _, pattern: str, __):
config = beets.config["playlist"]
# Get the full path to the playlist
@ -39,7 +44,7 @@ class PlaylistQuery(beets.dbcore.NamedQuery):
),
)
self.paths = []
paths = []
for playlist_path in playlist_paths:
if not fnmatch.fnmatch(playlist_path, "*.[mM]3[uU]"):
# This is not am M3U playlist, skip this candidate
@ -63,23 +68,14 @@ class PlaylistQuery(beets.dbcore.NamedQuery):
# ignore comments, and extm3u extension
continue
self.paths.append(
paths.append(
beets.util.normpath(
os.path.join(relative_to, line.rstrip())
)
)
f.close()
break
def clause(self) -> Tuple[Optional[str], Sequence[Any]]:
if not self.paths:
# Playlist is empty
return "0", ()
clause = "path IN ({})".format(", ".join("?" for path in self.paths))
return clause, (beets.library.BLOB_TYPE(p) for p in self.paths)
def match(self, item):
return item.path in self.paths
super().__init__("path", paths)
class PlaylistPlugin(beets.plugins.BeetsPlugin):

View file

@ -12,8 +12,7 @@
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""Tests for the DBCore database abstraction.
"""
"""Tests for the DBCore database abstraction."""
import os
import shutil
@ -32,7 +31,7 @@ class SortFixture(dbcore.query.FieldSort):
pass
class QueryFixture(dbcore.query.NamedQuery):
class QueryFixture(dbcore.query.FieldQuery):
def __init__(self, pattern):
self.pattern = pattern
@ -605,10 +604,6 @@ class QueryFromStringsTest(unittest.TestCase):
q = self.qfs([""])
self.assertIsInstance(q.subqueries[0], dbcore.query.TrueQuery)
def test_parse_named_query(self):
q = self.qfs(["some_query:foo"])
self.assertIsInstance(q.subqueries[0], QueryFixture)
class SortFromStringsTest(unittest.TestCase):
def sfs(self, strings):