refactor: everything is like a plugin query (#214)

The initial idea for this refactor was motivated by the need to make
PluginQuery.match() have the same method signature as the match() methods on
other queries. That is, it needed to take an *item*, not the pattern and
value. (The pattern is supplied when the query is constructed.) So it made
sense to move the value-to-pattern code to a class method.

But then I realized that all the other FieldQuery subclasses needed to do
essentially the same thing. So I eliminated PluginQuery altogether and
refactored FieldQuery to subsume its functionality. I then changed all the
other FieldQuery subclasses to conform to the same pattern.

This has the side effect of allowing different kinds of queries (even
non-field queries) down the road.
This commit is contained in:
Adrian Sampson 2013-03-13 22:57:20 -07:00
parent 40b49ac786
commit f005ec2de0
4 changed files with 113 additions and 160 deletions

View file

@ -469,58 +469,32 @@ class Query(object):
class FieldQuery(Query):
"""An abstract query that searches in a specific field for a
pattern.
pattern. Subclasses must provide a `value_match` class method, which
determines whether a certain pattern string matches a certain value
string. They may then either override the `clause` method to use
native SQLite functionality or get registered to use a callback into
Python.
"""
def __init__(self, field, pattern):
self.field = field
self.pattern = pattern
class MatchQuery(FieldQuery):
"""A query that looks for exact matches in an item field."""
def clause(self):
pattern = self.pattern
if self.field == 'path' and isinstance(pattern, str):
pattern = buffer(pattern)
return self.field + " = ?", [pattern]
@classmethod
def value_match(cls, pattern, value):
"""Determine whether the value matches the pattern. Both
arguments are strings.
"""
raise NotImplementedError()
@classmethod
def _raw_value_match(cls, pattern, value):
"""Determine whether the value matches the pattern. The value
may have any type.
"""
return cls.value_match(pattern, util.as_string(value))
def match(self, item):
return self.pattern == getattr(item, self.field)
class SubstringQuery(FieldQuery):
"""A query that matches a substring in a specific item field."""
def clause(self):
search = '%' + (self.pattern.replace('\\','\\\\').replace('%','\\%')
.replace('_','\\_')) + '%'
clause = self.field + " like ? escape '\\'"
subvals = [search]
return clause, subvals
def match(self, item):
value = util.as_string(getattr(item, self.field))
return self.pattern.lower() in value.lower()
class RegexpQuery(FieldQuery):
"""A query that matches a regular expression in a specific item field."""
def __init__(self, field, pattern):
super(RegexpQuery, self).__init__(field, pattern)
self.regexp = re.compile(pattern)
def clause(self):
clause = self.field + " REGEXP ?"
subvals = [self.pattern]
return clause, subvals
def match(self, item):
value = util.as_string(getattr(item, self.field))
return self.regexp.search(value) is not None
class PluginQuery(FieldQuery):
"""The base class to add queries using beets plugins. Plugins can add
special queries by defining a subclass of PluginQuery and overriding
the match method.
"""
def __init__(self, field, pattern):
super(PluginQuery, self).__init__(field, pattern)
return self._raw_value_match(self.pattern, getattr(item, self.field))
def clause(self):
# Invoke the registered SQLite function.
@ -531,9 +505,54 @@ class PluginQuery(FieldQuery):
@classmethod
def register(cls, conn):
"""Register this query's matching function with the SQLite
connection.
connection. This method should only be invoked when the query
type chooses not to override `clause`.
"""
conn.create_function(cls.__name__, 2, cls(None, None).match)
conn.create_function(cls.__name__, 2, cls._raw_value_match)
class MatchQuery(FieldQuery):
    """A field query that succeeds only when the field's value is
    exactly equal to the pattern.
    """

    def clause(self):
        """Build the SQL equality test for this query.

        Returns a pair: the clause text and the list of values to
        substitute for its placeholder. A `str` pattern on the `path`
        field is wrapped in `buffer` before it is substituted.
        """
        value = self.pattern
        is_path_string = self.field == 'path' and isinstance(value, str)
        if is_path_string:
            value = buffer(value)
        sql = self.field + " = ?"
        return sql, [value]

    # The "raw" hook is overridden (rather than `value_match`) so that
    # the pattern and the value are compared as objects, before any
    # conversion to strings takes place.
    @classmethod
    def _raw_value_match(cls, pattern, value):
        """Return True when `value` equals `pattern`; `value` may have
        any type.
        """
        return pattern == value
class SubstringQuery(FieldQuery):
    """A field query that succeeds when the pattern occurs anywhere
    inside the field's value.
    """

    def clause(self):
        """Build the SQL `like` test for this query.

        Returns a pair: the clause text and the list of values to
        substitute for its placeholder.
        """
        # Backslash is declared as the escape character in the clause,
        # so it has to be doubled first; only then are the `like`
        # wildcards (`%` and `_`) escaped.
        escaped = self.pattern.replace('\\', '\\\\')
        escaped = escaped.replace('%', '\\%')
        escaped = escaped.replace('_', '\\_')
        sql = self.field + " like ? escape '\\'"
        return sql, ['%' + escaped + '%']

    @classmethod
    def value_match(cls, pattern, value):
        """Return True when `pattern` is contained in `value`, ignoring
        case. Both arguments are strings.
        """
        needle = pattern.lower()
        haystack = value.lower()
        return needle in haystack
class RegexpQuery(FieldQuery):
    """A query that matches a regular expression in a specific item
    field.
    """

    def __init__(self, field, pattern):
        """Store the field and pattern and compile the pattern once.

        Raises `re.error` if `pattern` is not a valid regular
        expression.
        """
        super(RegexpQuery, self).__init__(field, pattern)
        self.regexp = re.compile(pattern)

    def clause(self):
        """Return the SQL `REGEXP` test and the list of values to
        substitute for its placeholder.
        """
        clause = self.field + " REGEXP ?"
        subvals = [self.pattern]
        return clause, subvals

    @classmethod
    def value_match(cls, pattern, value):
        """Return True when the regular expression `pattern` matches
        anywhere in `value`. Both arguments are strings.
        """
        # Return a real boolean rather than the match object: callers
        # get the same True/False contract as the other `value_match`
        # implementations, and a match object is not a value that can
        # be handed back to SQLite from a registered function.
        return re.search(pattern, value) is not None
class BooleanQuery(MatchQuery):
"""Matches a boolean field. Pattern should either be a boolean or a
@ -605,9 +624,8 @@ class CollectionQuery(Query):
example, the colon prefix denotes a regular expression query.
The function returns a tuple of `(key, value, cls)`. `key` may
be None, indicating that any field may be matched. `cls` is
either a subclass of `PluginQuery` or `None` indicating a
"normal" query.
be None, indicating that any field may be matched. `cls` is a
subclass of `FieldQuery`.
For instance,
parse_query('stapler') == (None, 'stapler', SubstringQuery)
@ -631,7 +649,7 @@ class CollectionQuery(Query):
for pre, query_class in prefixes.items():
if term.startswith(pre):
return key, term[len(pre):], query_class
return key, term, None # None means a normal query.
return key, term, SubstringQuery # The default query type.
@classmethod
def from_strings(cls, query_parts, default_fields=None,
@ -656,11 +674,7 @@ class CollectionQuery(Query):
subqueries.append(PathQuery(pattern))
else:
# Match any field.
if query_class:
subq = AnyPluginQuery(pattern, default_fields,
cls=query_class)
else:
subq = AnySubstringQuery(pattern, default_fields)
subq = AnyFieldQuery(pattern, default_fields, query_class)
subqueries.append(subq)
# A boolean field.
@ -673,10 +687,7 @@ class CollectionQuery(Query):
# Other (recognized) field.
elif key.lower() in all_keys:
if query_class:
subqueries.append(query_class(key.lower(), pattern))
else:
subqueries.append(SubstringQuery(key.lower(), pattern))
subqueries.append(query_class(key.lower(), pattern))
# Singleton query (not a real field).
elif key.lower() == 'singleton':
@ -704,62 +715,28 @@ class CollectionQuery(Query):
return cls.from_strings(parts, default_fields=default_fields,
all_keys=all_keys)
class AnySubstringQuery(CollectionQuery):
"""A query that matches a substring in any of a list of metadata
fields.
class AnyFieldQuery(CollectionQuery):
"""A query that matches if a given FieldQuery subclass matches in
any field. The individual field query class is provided to the
constructor.
"""
def __init__(self, pattern, fields=None):
"""Create a query for pattern over the sequence of fields
given. If no fields are given, all available fields are
used.
"""
self.pattern = pattern
self.fields = fields or ITEM_KEYS_WRITABLE
subqueries = []
for field in self.fields:
subqueries.append(SubstringQuery(field, pattern))
super(AnySubstringQuery, self).__init__(subqueries)
def clause(self):
return self.clause_with_joiner('or')
def match(self, item):
for fld in self.fields:
try:
val = getattr(item, fld)
except KeyError:
continue
if isinstance(val, basestring) and \
self.pattern.lower() in val.lower():
return True
return False
class AnyPluginQuery(CollectionQuery):
"""A query that dispatch the matching function to the match method of
the cls provided to the contstructor using a list of metadata fields.
"""
def __init__(self, pattern, fields=None, cls=PluginQuery):
subqueries = []
def __init__(self, pattern, fields, cls):
self.pattern = pattern
self.fields = fields
self.query_class = cls
subqueries = []
for field in self.fields:
subqueries.append(cls(field, pattern))
super(AnyPluginQuery, self).__init__(subqueries)
super(AnyFieldQuery, self).__init__(subqueries)
def clause(self):
return self.clause_with_joiner('or')
def match(self, item):
for field in self.fields:
try:
val = getattr(item, field)
except KeyError:
continue
if isinstance(val, basestring):
for subq in self.subqueries:
if subq.match(self.pattern, val):
return True
for subq in self.subqueries:
if subq.match(item):
return True
return False
class MutableCollectionQuery(CollectionQuery):

View file

@ -16,40 +16,31 @@
"""
from beets.plugins import BeetsPlugin
from beets.library import PluginQuery
from beets import util
from beets.library import FieldQuery
import beets
from beets.util import confit
import difflib
class FuzzyQuery(PluginQuery):
def __init__(self, field, pattern):
super(FuzzyQuery, self).__init__(field, pattern)
try:
self.threshold = beets.config['fuzzy']['threshold'].as_number()
except confit.NotFoundError:
self.threshold = 0.7
def match(self, pattern, val):
if pattern is None:
return False
val = util.as_string(val)
class FuzzyQuery(FieldQuery):
@classmethod
def value_match(self, pattern, val):
# smartcase
if pattern.islower():
val = val.lower()
queryMatcher = difflib.SequenceMatcher(None, pattern, val)
return queryMatcher.quick_ratio() >= self.threshold
threshold = beets.config['fuzzy']['threshold'].as_number()
return queryMatcher.quick_ratio() >= threshold
class FuzzyPlugin(BeetsPlugin):
def __init__(self):
super(FuzzyPlugin, self).__init__()
self.config.add({
'prefix': '~',
'threshold': 0.7,
})
super(FuzzyPlugin, self).__init__(self)
def queries(self):
try:
prefix = beets.config['fuzzy']['prefix'].get(basestring)
except confit.NotFoundError:
prefix = '~'
prefix = beets.config['fuzzy']['prefix'].get(basestring)
return {prefix: FuzzyQuery}

View file

@ -334,9 +334,11 @@ You can add new kinds of queries to beets' :doc:`query syntax
supports regular expression queries, which are indicated by a colon
prefix---plugins can do the same.
To do so, define a subclass of the ``PluginQuery`` type from the
``beets.library`` module. Then, in the ``queries`` method of your plugin
class, return a dictionary mapping prefix strings to query classes.
To do so, define a subclass of the ``FieldQuery`` type from the
``beets.library`` module. In this subclass, you should override the
``value_match`` class method. (Remember the ``@classmethod`` decorator!) Then,
in the ``queries`` method of your plugin class, return a dictionary mapping
prefix strings to query classes.
The following example plugin declares a query using the ``@`` prefix. So the
plugin will be called if we issue a command like ``beet ls @something`` or
@ -346,7 +348,8 @@ plugin will be called if we issue a command like ``beet ls @something`` or
from beets.library import FieldQuery
class ExampleQuery(FieldQuery):
def match(self, pattern, val):
@classmethod
def value_match(self, pattern, val):
return True # This will just match everything.
class ExamplePlugin(BeetsPlugin):

View file

@ -27,17 +27,17 @@ some_item = _common.item()
class QueryParseTest(unittest.TestCase):
def test_one_basic_term(self):
q = 'test'
r = (None, 'test', None)
r = (None, 'test', beets.library.SubstringQuery)
self.assertEqual(pqp(q), r)
def test_one_keyed_term(self):
q = 'test:val'
r = ('test', 'val', None)
r = ('test', 'val', beets.library.SubstringQuery)
self.assertEqual(pqp(q), r)
def test_colon_at_end(self):
q = 'test:'
r = (None, 'test:', None)
r = (None, 'test:', beets.library.SubstringQuery)
self.assertEqual(pqp(q), r)
def test_one_basic_regexp(self):
@ -52,7 +52,7 @@ class QueryParseTest(unittest.TestCase):
def test_escaped_colon(self):
q = r'test\:val'
r = (None, 'test:val', None)
r = (None, 'test:val', beets.library.SubstringQuery)
self.assertEqual(pqp(q), r)
def test_escaped_colon_in_regexp(self):
@ -60,42 +60,24 @@ class QueryParseTest(unittest.TestCase):
r = (None, 'test:regexp', beets.library.RegexpQuery)
self.assertEqual(pqp(q), r)
class AnySubstringQueryTest(unittest.TestCase):
class AnyFieldQueryTest(unittest.TestCase):
def setUp(self):
self.lib = beets.library.Library(':memory:')
self.lib.add(some_item)
def test_no_restriction(self):
q = beets.library.AnySubstringQuery('title')
q = beets.library.AnyFieldQuery('title', beets.library.ITEM_KEYS,
beets.library.SubstringQuery)
self.assertEqual(self.lib.items(q).next().title, 'the title')
def test_restriction_completeness(self):
q = beets.library.AnySubstringQuery('title', ['title'])
q = beets.library.AnyFieldQuery('title', ['title'],
beets.library.SubstringQuery)
self.assertEqual(self.lib.items(q).next().title, 'the title')
def test_restriction_soundness(self):
q = beets.library.AnySubstringQuery('title', ['artist'])
self.assertRaises(StopIteration, self.lib.items(q).next)
class AnyRegexpQueryTest(unittest.TestCase):
def setUp(self):
self.lib = beets.library.Library(':memory:')
self.lib.add(some_item)
def test_no_restriction(self):
q = beets.library.AnyRegexpQuery(r'^the ti')
self.assertEqual(self.lib.items(q).next().title, 'the title')
def test_restriction_completeness(self):
q = beets.library.AnyRegexpQuery(r'^the ti', ['title'])
self.assertEqual(self.lib.items(q).next().title, 'the title')
def test_restriction_soundness(self):
q = beets.library.AnyRegexpQuery(r'^the ti', ['artist'])
self.assertRaises(StopIteration, self.lib.items(q).next)
def test_restriction_soundness_2(self):
q = beets.library.AnyRegexpQuery(r'the ti$', ['title'])
q = beets.library.AnyFieldQuery('title', ['artist'],
beets.library.SubstringQuery)
self.assertRaises(StopIteration, self.lib.items(q).next)
# Convenient asserts for matching items.