Refactor PathQuery and add docs

This commit is contained in:
Šarūnas Nejus 2025-05-31 18:57:09 +01:00
parent 45f92ac641
commit 8937978d5f
No known key found for this signature in database
GPG key ID: DD28F6704DBE3435
2 changed files with 55 additions and 50 deletions

View file

@ -22,7 +22,7 @@ import unicodedata
from abc import ABC, abstractmethod
from collections.abc import Iterator, MutableSequence, Sequence
from datetime import datetime, timedelta
from functools import reduce
from functools import cached_property, reduce
from operator import mul, or_
from re import Pattern
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union
@ -30,8 +30,7 @@ from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union
from beets import util
if TYPE_CHECKING:
from beets.dbcore import Model
from beets.dbcore.db import AnyModel
from beets.dbcore.db import AnyModel, Model
P = TypeVar("P", default=Any)
else:
@ -283,13 +282,11 @@ class PathQuery(FieldQuery[bytes]):
and case-sensitive otherwise.
"""
def __init__(self, field, pattern, fast=True):
def __init__(self, field: str, pattern: bytes, fast: bool = True) -> None:
"""Create a path query.
`pattern` must be a path, either to a file or a directory.
"""
super().__init__(field, pattern, fast)
path = util.normpath(pattern)
# Case sensitivity depends on the filesystem that the query path is located on.
@ -304,50 +301,57 @@ class PathQuery(FieldQuery[bytes]):
# from `col_clause()` do the same thing.
path = path.lower()
# Match the path as a single file.
self.file_path = path
# As a directory (prefix).
self.dir_path = os.path.join(path, b"")
super().__init__(field, path, fast)
@classmethod
def is_path_query(cls, query_part):
@cached_property
def dir_path(self) -> bytes:
    """Return the query pattern in directory-prefix form.

    Joining the pattern with an empty final component appends a trailing
    path separator when the pattern does not already end with one. The
    value is computed on first access and then cached on the instance.
    """
    directory_prefix = os.path.join(self.pattern, b"")
    return directory_prefix
@staticmethod
def is_path_query(query_part: str) -> bool:
"""Try to guess whether a unicode query part is a path query.
Condition: separator precedes colon and the file exists.
The path query must
1. precede the colon in the query, if a colon is present
2. contain either ``os.sep`` or ``os.altsep`` (Windows)
3. this path must exist on the filesystem.
"""
colon = query_part.find(":")
if colon != -1:
query_part = query_part[:colon]
query_part = query_part.split(":")[0]
# Test both `sep` and `altsep` (i.e., both slash and backslash on
# Windows).
if not (
os.sep in query_part or (os.altsep and os.altsep in query_part)
):
return False
return os.path.exists(util.syspath(util.normpath(query_part)))
def match(self, item):
path = item.path if self.case_sensitive else item.path.lower()
return (path == self.file_path) or path.startswith(self.dir_path)
def col_clause(self):
file_blob = BLOB_TYPE(self.file_path)
dir_blob = BLOB_TYPE(self.dir_path)
if self.case_sensitive:
query_part = "({0} = ?) || (substr({0}, 1, ?) = ?)"
else:
query_part = "(BYTELOWER({0}) = BYTELOWER(?)) || \
(substr(BYTELOWER({0}), 1, ?) = BYTELOWER(?))"
return query_part.format(self.field), (
file_blob,
len(dir_blob),
dir_blob,
return (
# make sure the query part contains a path separator
bool(set(query_part) & {os.sep, os.altsep})
and os.path.exists(util.normpath(query_part))
)
def match(self, obj: Model) -> bool:
    """Tell whether ``obj.path`` is selected by this query.

    The object's path is lower-cased before comparison when the query is
    not case-sensitive. A path matches when it equals the stored pattern
    exactly, or when it begins with the directory prefix ``dir_path``.
    """
    if self.case_sensitive:
        candidate = obj.path
    else:
        candidate = obj.path.lower()

    # Exact hit on the pattern itself; the directory prefix is not consulted.
    if candidate == self.pattern:
        return True

    # Otherwise the path has to live underneath the pattern as a directory.
    return candidate.startswith(self.dir_path)
def col_clause(self) -> tuple[str, Sequence[SQLiteType]]:
    """Build the SQL clause that performs path matching in the database.

    Returns:
        A pair of the clause text and its positional parameters. The
        clause selects rows whose field equals the pattern, or whose
        leading ``len(dir_path)`` bytes equal ``dir_path``. For
        case-insensitive queries both the column and the parameters are
        wrapped in ``BYTELOWER``.
    """
    if self.case_sensitive:
        left, right = self.field, "?"
    else:
        left, right = f"BYTELOWER({self.field})", "BYTELOWER(?)"

    dir_blob = BLOB_TYPE(self.dir_path)

    # Join the two comparisons with the logical ``OR`` operator. In SQLite
    # ``||`` is string concatenation: it glues the two 0/1 results into text
    # such as '01' and only behaves like a disjunction because that text is
    # coerced back to a number. The outer parentheses keep the clause atomic
    # when a caller embeds it in a larger expression.
    clause = f"(({left} = {right}) OR (substr({left}, 1, ?) = {right}))"
    return clause, [
        BLOB_TYPE(self.pattern),
        len(dir_blob),
        dir_blob,
    ]
def __repr__(self) -> str:
return (
f"{self.__class__.__name__}({self.field!r}, {self.pattern!r}, "

View file

@ -880,7 +880,7 @@ class TestPathQuery:
@pytest.fixture(scope="class")
def lib(self, helper):
helper.add_item(path=b"/a/b/c.mp3", title="path item")
helper.add_item(path=b"/aaa/bb/c.mp3", title="path item")
helper.add_item(path=b"/x/y/z.mp3", title="another item")
helper.add_item(path=b"/c/_/title.mp3", title="with underscore")
helper.add_item(path=b"/c/%/title.mp3", title="with percent")
@ -892,12 +892,13 @@ class TestPathQuery:
@pytest.mark.parametrize(
"q, expected_titles",
[
_p("path:/a/b/c.mp3", ["path item"], id="exact-match"),
_p("path:/a", ["path item"], id="parent-dir-no-slash"),
_p("path:/a/", ["path item"], id="parent-dir-with-slash"),
_p("path:/aaa/bb/c.mp3", ["path item"], id="exact-match"),
_p("path:/aaa", ["path item"], id="parent-dir-no-slash"),
_p("path:/aaa/", ["path item"], id="parent-dir-with-slash"),
_p("path:/aa", [], id="no-match-does-not-match-parent-dir"),
_p("path:/xyzzy/", [], id="no-match"),
_p("path:/b/", [], id="fragment-no-match"),
_p("path:/x/../a/b", ["path item"], id="non-normalized"),
_p("path:/x/../aaa/bb", ["path item"], id="non-normalized"),
_p("path::c\\.mp3$", ["path item"], id="regex"),
_p("path:/c/_", ["with underscore"], id="underscore-escaped"),
_p("path:/c/%", ["with percent"], id="percent-escaped"),
@ -913,8 +914,8 @@ class TestPathQuery:
@pytest.mark.parametrize(
"q, expected_titles",
[
_p("/a/b", ["path item"], id="slashed-query"),
_p("/a/b , /a/b", ["path item"], id="path-in-or-query"),
_p("/aaa/bb", ["path item"], id="slashed-query"),
_p("/aaa/bb , /aaa", ["path item"], id="path-in-or-query"),
_p("c.mp3", [], id="no-slash-no-match"),
_p("title:/a/b", [], id="slash-with-explicit-field-no-match"),
],