From 8937978d5f607885609d6809aca23d84cee063db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0ar=C5=ABnas=20Nejus?= Date: Sat, 31 May 2025 18:57:09 +0100 Subject: [PATCH] Refactor PathQuery and add docs --- beets/dbcore/query.py | 90 ++++++++++++++++++++++--------------------- test/test_query.py | 15 ++++---- 2 files changed, 55 insertions(+), 50 deletions(-) diff --git a/beets/dbcore/query.py b/beets/dbcore/query.py index c814c5966..9cff082a3 100644 --- a/beets/dbcore/query.py +++ b/beets/dbcore/query.py @@ -22,7 +22,7 @@ import unicodedata from abc import ABC, abstractmethod from collections.abc import Iterator, MutableSequence, Sequence from datetime import datetime, timedelta -from functools import reduce +from functools import cached_property, reduce from operator import mul, or_ from re import Pattern from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union @@ -30,8 +30,7 @@ from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union from beets import util if TYPE_CHECKING: - from beets.dbcore import Model - from beets.dbcore.db import AnyModel + from beets.dbcore.db import AnyModel, Model P = TypeVar("P", default=Any) else: @@ -283,13 +282,11 @@ class PathQuery(FieldQuery[bytes]): and case-sensitive otherwise. """ - def __init__(self, field, pattern, fast=True): + def __init__(self, field: str, pattern: bytes, fast: bool = True) -> None: """Create a path query. `pattern` must be a path, either to a file or a directory. """ - super().__init__(field, pattern, fast) - path = util.normpath(pattern) # Case sensitivity depends on the filesystem that the query path is located on. @@ -304,50 +301,57 @@ class PathQuery(FieldQuery[bytes]): # from `col_clause()` do the same thing. path = path.lower() - # Match the path as a single file. - self.file_path = path - # As a directory (prefix). 
- self.dir_path = os.path.join(path, b"") + super().__init__(field, path, fast) - @classmethod - def is_path_query(cls, query_part): + @cached_property + def dir_path(self) -> bytes: + return os.path.join(self.pattern, b"") + + @staticmethod + def is_path_query(query_part: str) -> bool: """Try to guess whether a unicode query part is a path query. - Condition: separator precedes colon and the file exists. + The path query must + 1. precede the colon in the query, if a colon is present + 2. contain either ``os.sep`` or ``os.altsep`` (Windows) + 3. refer to a path that exists on the filesystem. """ - colon = query_part.find(":") - if colon != -1: - query_part = query_part[:colon] + query_part = query_part.split(":")[0] - # Test both `sep` and `altsep` (i.e., both slash and backslash on - # Windows). - if not ( - os.sep in query_part or (os.altsep and os.altsep in query_part) - ): - return False - - return os.path.exists(util.syspath(util.normpath(query_part))) - - def match(self, item): - path = item.path if self.case_sensitive else item.path.lower() - return (path == self.file_path) or path.startswith(self.dir_path) - - def col_clause(self): - file_blob = BLOB_TYPE(self.file_path) - dir_blob = BLOB_TYPE(self.dir_path) - - if self.case_sensitive: - query_part = "({0} = ?) || (substr({0}, 1, ?) = ?)" - else: - query_part = "(BYTELOWER({0}) = BYTELOWER(?)) || \ - (substr(BYTELOWER({0}), 1, ?) = BYTELOWER(?))" - - return query_part.format(self.field), ( - file_blob, - len(dir_blob), - dir_blob, + return ( + # make sure the query part contains a path separator + bool(set(query_part) & {os.sep, os.altsep}) + and os.path.exists(util.normpath(query_part)) ) + def match(self, obj: Model) -> bool: + """Check whether a model object's path matches this query. + + Performs either an exact match against the pattern or checks if the path + starts with the given directory path. Case sensitivity depends on the + filesystem of the query path, as determined during initialization. 
+ """ + path = obj.path if self.case_sensitive else obj.path.lower() + return (path == self.pattern) or path.startswith(self.dir_path) + + def col_clause(self) -> tuple[str, Sequence[SQLiteType]]: + """Generate an SQL clause that implements path matching in the database. + + Returns a tuple of SQL clause string and parameter values list that matches + paths either exactly or by directory prefix. Handles case sensitivity + appropriately using BYTELOWER for case-insensitive matches. + """ + if self.case_sensitive: + left, right = self.field, "?" + else: + left, right = f"BYTELOWER({self.field})", "BYTELOWER(?)" + + return f"({left} = {right}) || (substr({left}, 1, ?) = {right})", [ + BLOB_TYPE(self.pattern), + len(dir_blob := BLOB_TYPE(self.dir_path)), + dir_blob, + ] + def __repr__(self) -> str: return ( f"{self.__class__.__name__}({self.field!r}, {self.pattern!r}, " diff --git a/test/test_query.py b/test/test_query.py index a8646f1bb..776bfd6f6 100644 --- a/test/test_query.py +++ b/test/test_query.py @@ -880,7 +880,7 @@ class TestPathQuery: @pytest.fixture(scope="class") def lib(self, helper): - helper.add_item(path=b"/a/b/c.mp3", title="path item") + helper.add_item(path=b"/aaa/bb/c.mp3", title="path item") helper.add_item(path=b"/x/y/z.mp3", title="another item") helper.add_item(path=b"/c/_/title.mp3", title="with underscore") helper.add_item(path=b"/c/%/title.mp3", title="with percent") @@ -892,12 +892,13 @@ class TestPathQuery: @pytest.mark.parametrize( "q, expected_titles", [ - _p("path:/a/b/c.mp3", ["path item"], id="exact-match"), - _p("path:/a", ["path item"], id="parent-dir-no-slash"), - _p("path:/a/", ["path item"], id="parent-dir-with-slash"), + _p("path:/aaa/bb/c.mp3", ["path item"], id="exact-match"), + _p("path:/aaa", ["path item"], id="parent-dir-no-slash"), + _p("path:/aaa/", ["path item"], id="parent-dir-with-slash"), + _p("path:/aa", [], id="no-match-does-not-match-parent-dir"), _p("path:/xyzzy/", [], id="no-match"), _p("path:/b/", [], 
id="fragment-no-match"), - _p("path:/x/../a/b", ["path item"], id="non-normalized"), + _p("path:/x/../aaa/bb", ["path item"], id="non-normalized"), _p("path::c\\.mp3$", ["path item"], id="regex"), _p("path:/c/_", ["with underscore"], id="underscore-escaped"), _p("path:/c/%", ["with percent"], id="percent-escaped"), @@ -913,8 +914,8 @@ class TestPathQuery: @pytest.mark.parametrize( "q, expected_titles", [ - _p("/a/b", ["path item"], id="slashed-query"), - _p("/a/b , /a/b", ["path item"], id="path-in-or-query"), + _p("/aaa/bb", ["path item"], id="slashed-query"), + _p("/aaa/bb , /aaa", ["path item"], id="path-in-or-query"), _p("c.mp3", [], id="no-slash-no-match"), _p("title:/a/b", [], id="slash-with-explicit-field-no-match"), ],