diff --git a/beets/dbcore/query.py b/beets/dbcore/query.py index cd891148e..330913706 100644 --- a/beets/dbcore/query.py +++ b/beets/dbcore/query.py @@ -18,7 +18,7 @@ from __future__ import (division, absolute_import, print_function, unicode_literals) import re -from operator import attrgetter +from operator import attrgetter, mul from beets import util from datetime import datetime, timedelta @@ -73,6 +73,12 @@ class Query(object): """ raise NotImplementedError + def __eq__(self, other): + return type(self) == type(other) + + def __hash__(self): + return 0 + class FieldQuery(Query): """An abstract query that searches in a specific field for a @@ -106,6 +112,13 @@ class FieldQuery(Query): def match(self, item): return self.value_match(self.pattern, item.get(self.field)) + def __eq__(self, other): + return super(FieldQuery, self).__eq__(other) and \ + self.field == other.field and self.pattern == other.pattern + + def __hash__(self): + return hash((self.field, hash(self.pattern))) + class MatchQuery(FieldQuery): """A query that looks for exact matches in an item field.""" @@ -120,8 +133,7 @@ class MatchQuery(FieldQuery): class NoneQuery(FieldQuery): def __init__(self, field, fast=True): - self.field = field - self.fast = fast + super(NoneQuery, self).__init__(field, None, fast) def col_clause(self): return self.field + " IS NULL", () @@ -177,8 +189,8 @@ class RegexpQuery(StringFieldQuery): Raises InvalidQueryError when the pattern is not a valid regular expression. """ - def __init__(self, field, pattern, false=True): - super(RegexpQuery, self).__init__(field, pattern, false) + def __init__(self, field, pattern, fast=True): + super(RegexpQuery, self).__init__(field, pattern, fast) try: self.pattern = re.compile(self.pattern) except re.error as exc: @@ -337,6 +349,16 @@ class CollectionQuery(Query): clause = (' ' + joiner + ' ').join(clause_parts) return clause, subvals + def __eq__(self, other): + return super(CollectionQuery, self).__eq__(other) and \ + self.subqueries == other.subqueries + + def __hash__(self): + """Since subqueries are mutable, this object should not be hashable. + However and for conveniencies purposes, it can be hashed. + """ + return reduce(mul, map(hash, self.subqueries), 1) + class AnyFieldQuery(CollectionQuery): """A query that matches if a given FieldQuery subclass matches in @@ -362,6 +384,13 @@ class AnyFieldQuery(CollectionQuery): return True return False + def __eq__(self, other): + return super(AnyFieldQuery, self).__eq__(other) and \ + self.query_class == other.query_class + + def __hash__(self): + return hash((self.pattern, tuple(self.fields), self.query_class)) + class MutableCollectionQuery(CollectionQuery): """A collection query whose subqueries may be modified after the @@ -596,6 +625,12 @@ class Sort(object): """ return False + def __hash__(self): + return 0 + + def __eq__(self, other): + return type(self) == type(other) + class MultipleSort(Sort): """Sort that encapsulates multiple sub-sorts. @@ -657,6 +692,13 @@ class MultipleSort(Sort): def __repr__(self): return u'MultipleSort({0})'.format(repr(self.sorts)) + def __hash__(self): + return hash(tuple(self.sorts)) + + def __eq__(self, other): + return super(MultipleSort, self).__eq__(other) and \ + self.sorts == other.sorts + class FieldSort(Sort): """An abstract sort criterion that orders by a specific field (of @@ -680,6 +722,14 @@ class FieldSort(Sort): '+' if self.ascending else '-', ) + def __hash__(self): + return hash((self.field, self.ascending)) + + def __eq__(self, other): + return super(FieldSort, self).__eq__(other) and \ + self.field == other.field and \ + self.ascending == other.ascending + class FixedFieldSort(FieldSort): """Sort object to sort on a fixed field. @@ -701,3 +751,15 @@ class NullSort(Sort): """No sorting. Leave results unsorted.""" def sort(items): return items + + def __nonzero__(self): + return self.__bool__() + + def __bool__(self): + return False + + def __eq__(self, other): + return type(self) == type(other) or other is None + + def __hash__(self): + return 0 diff --git a/beets/dbcore/queryparse.py b/beets/dbcore/queryparse.py index 6628bebf0..1dcf9c4b3 100644 --- a/beets/dbcore/queryparse.py +++ b/beets/dbcore/queryparse.py @@ -152,12 +152,14 @@ def sort_from_strings(model_cls, sort_parts): """Create a `Sort` from a list of sort criteria (strings). """ if not sort_parts: - return query.NullSort() + sort = query.NullSort() + elif len(sort_parts) == 1: + sort = construct_sort_part(model_cls, sort_parts[0]) else: sort = query.MultipleSort() for part in sort_parts: sort.add_sort(construct_sort_part(model_cls, part)) - return sort + return sort def parse_sorted_query(model_cls, parts, prefixes={}, diff --git a/beets/library.py b/beets/library.py index 139cdfec0..132df4f52 100644 --- a/beets/library.py +++ b/beets/library.py @@ -270,15 +270,15 @@ class LibModel(dbcore.Model): def store(self): super(LibModel, self).store() - plugins.send('database_change', lib=self._db) + plugins.send('database_change', lib=self._db, model=self) def remove(self): super(LibModel, self).remove() - plugins.send('database_change', lib=self._db) + plugins.send('database_change', lib=self._db, model=self) def add(self, lib=None): super(LibModel, self).add(lib) - plugins.send('database_change', lib=self._db) + plugins.send('database_change', lib=self._db, model=self) def __format__(self, spec): if not spec: diff --git a/beets/plugins.py b/beets/plugins.py index bee4d9f32..c642e9318 100755 --- a/beets/plugins.py +++ b/beets/plugins.py @@ -17,6 +17,7 @@ from __future__ import (division, absolute_import, print_function, unicode_literals) +import inspect import traceback import re from collections import defaultdict @@ -100,26 +101,37 @@ class BeetsPlugin(object): `self.import_stages`. Wrapping provides some bookkeeping for the plugin: specifically, the logging level is adjusted to WARNING. """ - return [self._set_log_level(logging.WARNING, import_stage) + return [self._set_log_level_and_params(logging.WARNING, import_stage) for import_stage in self.import_stages] - def _set_log_level(self, base_log_level, func): + def _set_log_level_and_params(self, base_log_level, func): """Wrap `func` to temporarily set this plugin's logger level to `base_log_level` + config options (and restore it to its previous - value after the function returns). + value after the function returns). Also determines which params may not + be sent for backwards-compatibility. - Note that that value may not be NOTSET, e.g. if a plugin import stage - triggers an event that is listened this very same plugin + Note that the log level value may not be NOTSET, e.g. if a plugin + import stage triggers an event that is listened this very same plugin. """ + argspec = inspect.getargspec(func) + @wraps(func) def wrapper(*args, **kwargs): old_log_level = self._log.level verbosity = beets.config['verbose'].get(int) log_level = max(logging.DEBUG, base_log_level - 10 * verbosity) self._log.setLevel(log_level) - try: - return func(*args, **kwargs) + try: + return func(*args, **kwargs) + except TypeError as exc: + if exc.args[0].startswith(func.__name__): + # caused by 'func' and not stuff internal to 'func' + kwargs = dict((arg, val) for arg, val in kwargs.items() + if arg in argspec.args) + return func(*args, **kwargs) + else: + raise finally: self._log.setLevel(old_log_level) return wrapper @@ -186,7 +198,7 @@ class BeetsPlugin(object): def register_listener(self, event, func): """Add a function as a listener for the specified event. """ - wrapped_func = self._set_log_level(logging.WARNING, func) + wrapped_func = self._set_log_level_and_params(logging.WARNING, func) cls = self.__class__ if cls.listeners is None or cls._raw_listeners is None: diff --git a/beetsplug/mpdupdate.py b/beetsplug/mpdupdate.py index 96141c567..dfe402497 100644 --- a/beetsplug/mpdupdate.py +++ b/beetsplug/mpdupdate.py @@ -80,7 +80,7 @@ class MPDUpdatePlugin(BeetsPlugin): self.register_listener('database_change', self.db_change) - def db_change(self, lib): + def db_change(self, lib, model): self.register_listener('cli_exit', self.update) def update(self, lib): diff --git a/beetsplug/plexupdate.py b/beetsplug/plexupdate.py index 9781e300e..5aa096486 100644 --- a/beetsplug/plexupdate.py +++ b/beetsplug/plexupdate.py @@ -55,7 +55,7 @@ class PlexUpdate(BeetsPlugin): self.register_listener('database_change', self.listen_for_db_change) - def listen_for_db_change(self, lib): + def listen_for_db_change(self, lib, model): """Listens for beets db change and register the update for the end""" self.register_listener('cli_exit', self.update) diff --git a/beetsplug/smartplaylist.py b/beetsplug/smartplaylist.py index e2c256b2b..34fa94979 100644 --- a/beetsplug/smartplaylist.py +++ b/beetsplug/smartplaylist.py @@ -21,30 +21,14 @@ from __future__ import (division, absolute_import, print_function, from beets.plugins import BeetsPlugin from beets import ui from beets.util import mkdirall, normpath, syspath +from beets.library import Item, Album, parse_query_string +from beets.dbcore import OrQuery +from beets.dbcore.query import MultipleSort import os -def _items_for_query(lib, queries, album): - """Get the matching items for a list of queries. - - `queries` can either be a single string or a list of strings. In the - latter case, the results from each query are concatenated. `album` - indicates whether the queries are item-level or album-level. - """ - if isinstance(queries, basestring): - queries = [queries] - if album: - for query in queries: - for album in lib.albums(query): - for item in album.items(): - yield item - else: - for query in queries: - for item in lib.items(query): - yield item - - class SmartPlaylistPlugin(BeetsPlugin): + def __init__(self): super(SmartPlaylistPlugin, self).__init__() self.config.add({ @@ -54,42 +38,139 @@ class SmartPlaylistPlugin(BeetsPlugin): 'playlists': [] }) + self._matched_playlists = None + self._unmatched_playlists = None + if self.config['auto']: self.register_listener('database_change', self.db_change) def commands(self): - def update(lib, opts, args): - self.update_playlists(lib) spl_update = ui.Subcommand('splupdate', - help='update the smart playlists') - spl_update.func = update + help='update the smart playlists. Playlist ' + 'names may be passed as arguments.') + spl_update.func = self.update_cmd return [spl_update] - def db_change(self, lib): - self.register_listener('cli_exit', self.update_playlists) + def update_cmd(self, lib, opts, args): + self.build_queries() + if args: + args = set(ui.decargs(args)) + for a in list(args): + if not a.endswith(".m3u"): + args.add("{0}.m3u".format(a)) + + playlists = set((name, q, a_q) + for name, q, a_q in self._unmatched_playlists + if name in args) + if not playlists: + raise ui.UserError('No playlist matching any of {0} ' + 'found'.format([name for name, _, _ in + self._unmatched_playlists])) + + self._matched_playlists = playlists + self._unmatched_playlists -= playlists + else: + self._matched_playlists = self._unmatched_playlists + + self.update_playlists(lib) + + def build_queries(self): + """ + Instanciate queries for the playlists. + + Each playlist has 2 queries: one or items one for albums, each with a + sort. We must also remember its name. _unmatched_playlists is a set of + tuples (name, (q, q_sort), (album_q, album_q_sort)). + + sort may be any sort, or NullSort, or None. None and NullSort are + equivalent and both eval to False. + More precisely + - it will be NullSort when a playlist query ('query' or 'album_query') + is a single item or a list with 1 element + - it will be None when there are multiple items i a query + """ + self._unmatched_playlists = set() + self._matched_playlists = set() + + for playlist in self.config['playlists'].get(list): + playlist_data = (playlist['name'],) + for key, Model in (('query', Item), ('album_query', Album)): + qs = playlist.get(key) + if qs is None: + query_and_sort = None, None + elif isinstance(qs, basestring): + query_and_sort = parse_query_string(qs, Model) + elif len(qs) == 1: + query_and_sort = parse_query_string(qs[0], Model) + else: + # multiple queries and sorts + queries, sorts = zip(*(parse_query_string(q, Model) + for q in qs)) + query = OrQuery(queries) + final_sorts = [] + for s in sorts: + if s: + if isinstance(s, MultipleSort): + final_sorts += s.sorts + else: + final_sorts.append(s) + if not final_sorts: + sort = None + elif len(final_sorts) == 1: + sort, = final_sorts + else: + sort = MultipleSort(final_sorts) + query_and_sort = query, sort + + playlist_data += (query_and_sort,) + + self._unmatched_playlists.add(playlist_data) + + def db_change(self, lib, model): + if self._unmatched_playlists is None: + self.build_queries() + + for playlist in self._unmatched_playlists: + n, (q, _), (a_q, _) = playlist + if a_q and isinstance(model, Album): + matches = a_q.match(model) + elif q and isinstance(model, Item): + matches = q.match(model) or q.match(model.get_album()) + else: + matches = False + + if matches: + self._log.debug("{0} will be updated because of {1}", n, model) + self._matched_playlists.add(playlist) + self.register_listener('cli_exit', self.update_playlists) + + self._unmatched_playlists -= self._matched_playlists def update_playlists(self, lib): - self._log.info("Updating smart playlists...") - playlists = self.config['playlists'].get(list) + self._log.info("Updating {0} smart playlists...", + len(self._matched_playlists)) + playlist_dir = self.config['playlist_dir'].as_filename() relative_to = self.config['relative_to'].get() if relative_to: relative_to = normpath(relative_to) - for playlist in playlists: - self._log.debug(u"Creating playlist {0[name]}", playlist) + for playlist in self._matched_playlists: + name, (query, q_sort), (album_query, a_q_sort) = playlist + self._log.debug(u"Creating playlist {0}", name) items = [] - if 'album_query' in playlist: - items.extend(_items_for_query(lib, playlist['album_query'], - True)) - if 'query' in playlist: - items.extend(_items_for_query(lib, playlist['query'], False)) + + if query: + items.extend(lib.items(query, q_sort)) + if album_query: + for album in lib.albums(album_query, a_q_sort): + items.extend(album.items()) m3us = {} # As we allow tags in the m3u names, we'll need to iterate through # the items and generate the correct m3u file names. for item in items: - m3u_name = item.evaluate_template(playlist['name'], True) + m3u_name = item.evaluate_template(name, True) if m3u_name not in m3us: m3us[m3u_name] = [] item_path = item.path @@ -104,4 +185,4 @@ class SmartPlaylistPlugin(BeetsPlugin): with open(syspath(m3u_path), 'w') as f: for path in m3us[m3u]: f.write(path + b'\n') - self._log.info("{0} playlists updated", len(playlists)) + self._log.info("{0} playlists updated", len(self._matched_playlists)) diff --git a/docs/changelog.rst b/docs/changelog.rst index b222e3ad8..f3e71e9dd 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -6,9 +6,14 @@ Changelog Features: +* :doc:`/plugins/smartplaylist`: detect for each playlist if it needs to be + regenated, instead of systematically regenerating all of them after a + database modification. +* :doc:`/plugins/smartplaylist`: the ``splupdate`` command can now take + additinal parameters: names of the playlists to regenerate. * Beets now accept top-level options ``--format-item`` and ``--format-album`` before any subcommand to control how items and albums are displayed. - :bug:`1271`: + :bug:`1271` * :doc:`/plugins/replaygain`: There is a new backend for the `bs1770gain`_ tool. Thanks to :user:`jmwatte`. :bug:`1343` * There are now multiple levels of verbosity. On the command line, you can @@ -128,6 +133,8 @@ Fixes: For developers: +* The ``database_change`` event now sends the item or album that is subject to + a change in the db. * the ``OptionParser`` is now a ``CommonOptionsParser`` that offers facilities for adding usual options (``--album``, ``--path`` and ``--format``). See :ref:`add_subcommands`. :bug:`1271` diff --git a/docs/dev/plugins.rst b/docs/dev/plugins.rst index 0c1f7017f..1d610f53b 100644 --- a/docs/dev/plugins.rst +++ b/docs/dev/plugins.rst @@ -203,7 +203,7 @@ The events currently available are: Library object. Parameter: ``lib``. * *database_change*: a modification has been made to the library database. The - change might not be committed yet. Parameter: ``lib``. + change might not be committed yet. Parameters: ``lib`` and ``model``. * *cli_exit*: called just before the ``beet`` command-line program exits. Parameter: ``lib``. diff --git a/docs/plugins/smartplaylist.rst b/docs/plugins/smartplaylist.rst index bc39e581e..2f691c4fe 100644 --- a/docs/plugins/smartplaylist.rst +++ b/docs/plugins/smartplaylist.rst @@ -44,6 +44,18 @@ You can also gather the results of several queries by putting them in a list. - name: 'BeatlesUniverse.m3u' query: ['artist:beatles', 'genre:"beatles cover"'] +Note that since beets query syntax is in effect, you can also use sorting +directives:: + + - name: 'Chronological Beatles' + query: 'artist:Beatles year+' + - name: 'Mixed Rock' + query: ['artist:Beatles year+', 'artist:"Led Zeppelin" bitrate+'] + +The former case behaves as expected, however please note that in the latter the +sorts will be merged: ``year+ bitrate+`` will apply to both the Beatles and Led +Zeppelin. If that bothers you, please get in touch. + For querying albums instead of items (mainly useful with extensible fields), use the ``album_query`` field. ``query`` and ``album_query`` can be used at the same time. The following example gathers single items but also items belonging @@ -53,13 +65,16 @@ to albums that have a ``for_travel`` extensible field set to 1:: album_query: 'for_travel:1' query: 'for_travel:1' -By default, all playlists are automatically regenerated at the end of the -session if the library database was changed. To force regeneration, you can -invoke it manually from the command line:: +By default, each playlist is automatically regenerated at the end of the +session if an item or album it matches changed in the library database. To +force regeneration, you can invoke it manually from the command line:: $ beet splupdate -which will generate your new smart playlists. +This will regenerate all smart playlists. You can also specify which ones you +want to regenerate:: + + $ beet splupdate BeatlesUniverse.m3u MyTravelPlaylist You can also use this plugin together with the :doc:`mpdupdate`, in order to automatically notify MPD of the playlist change, by adding ``mpdupdate`` to diff --git a/test/test_dbcore.py b/test/test_dbcore.py index dffe9ae75..39867ceb0 100644 --- a/test/test_dbcore.py +++ b/test/test_dbcore.py @@ -449,6 +449,7 @@ class SortFromStringsTest(unittest.TestCase): def test_zero_parts(self): s = self.sfs([]) self.assertIsInstance(s, dbcore.query.NullSort) + self.assertEqual(s, dbcore.query.NullSort()) def test_one_parts(self): s = self.sfs(['field+']) @@ -461,17 +462,17 @@ class SortFromStringsTest(unittest.TestCase): def test_fixed_field_sort(self): s = self.sfs(['field_one+']) - self.assertIsInstance(s, dbcore.query.MultipleSort) - self.assertIsInstance(s.sorts[0], dbcore.query.FixedFieldSort) + self.assertIsInstance(s, dbcore.query.FixedFieldSort) + self.assertEqual(s, dbcore.query.FixedFieldSort('field_one')) def test_flex_field_sort(self): s = self.sfs(['flex_field+']) - self.assertIsInstance(s, dbcore.query.MultipleSort) - self.assertIsInstance(s.sorts[0], dbcore.query.SlowFieldSort) + self.assertIsInstance(s, dbcore.query.SlowFieldSort) + self.assertEqual(s, dbcore.query.SlowFieldSort('flex_field')) def test_special_sort(self): s = self.sfs(['some_sort+']) - self.assertIsInstance(s.sorts[0], TestSort) + self.assertIsInstance(s, TestSort) class ResultsIteratorTest(unittest.TestCase): diff --git a/test/test_plugins.py b/test/test_plugins.py index c9c5be502..2e8bedca1 100644 --- a/test/test_plugins.py +++ b/test/test_plugins.py @@ -16,8 +16,9 @@ from __future__ import (division, absolute_import, print_function, unicode_literals) import os -from mock import patch +from mock import patch, Mock import shutil +import itertools from beets.importer import SingletonImportTask, SentinelImportTask, \ ArchiveImportTask @@ -57,7 +58,6 @@ class ItemTypesTest(unittest.TestCase, TestHelper): def setUp(self): self.setup_plugin_loader() - self.setup_beets() def tearDown(self): self.teardown_plugin_loader() @@ -309,6 +309,96 @@ class ListenersTest(unittest.TestCase, TestHelper): self.assertEqual(DummyPlugin._raw_listeners['cli_exit'], [d.dummy, d2.dummy]) + @patch('beets.plugins.find_plugins') + @patch('beets.plugins.inspect') + def test_events_called(self, mock_inspect, mock_find_plugins): + mock_inspect.getargspec.return_value = None + + class DummyPlugin(plugins.BeetsPlugin): + def __init__(self): + super(DummyPlugin, self).__init__() + self.foo = Mock(__name__=b'foo') + self.register_listener('event_foo', self.foo) + self.bar = Mock(__name__=b'bar') + self.register_listener('event_bar', self.bar) + + d = DummyPlugin() + mock_find_plugins.return_value = d, + + plugins.send('event') + d.foo.assert_has_calls([]) + d.bar.assert_has_calls([]) + + plugins.send('event_foo', var="tagada") + d.foo.assert_called_once_with(var="tagada") + d.bar.assert_has_calls([]) + + @patch('beets.plugins.find_plugins') + def test_listener_params(self, mock_find_plugins): + test = self + + class DummyPlugin(plugins.BeetsPlugin): + def __init__(self): + super(DummyPlugin, self).__init__() + for i in itertools.count(1): + try: + meth = getattr(self, 'dummy{0}'.format(i)) + except AttributeError: + break + self.register_listener('event{0}'.format(i), meth) + + def dummy1(self, foo): + test.assertEqual(foo, 5) + + def dummy2(self, foo=None): + test.assertEqual(foo, 5) + + def dummy3(self): + # argument cut off + pass + + def dummy4(self, bar=None): + # argument cut off + pass + + def dummy5(self, bar): + test.assertFalse(True) + + # more complex exmaples + + def dummy6(self, foo, bar=None): + test.assertEqual(foo, 5) + test.assertEqual(bar, None) + + def dummy7(self, foo, **kwargs): + test.assertEqual(foo, 5) + test.assertEqual(kwargs, {}) + + def dummy8(self, foo, bar, **kwargs): + test.assertFalse(True) + + def dummy9(self, **kwargs): + test.assertEqual(kwargs, {"foo": 5}) + + d = DummyPlugin() + mock_find_plugins.return_value = d, + + plugins.send('event1', foo=5) + plugins.send('event2', foo=5) + plugins.send('event3', foo=5) + plugins.send('event4', foo=5) + + with self.assertRaises(TypeError): + plugins.send('event5', foo=5) + + plugins.send('event6', foo=5) + plugins.send('event7', foo=5) + + with self.assertRaises(TypeError): + plugins.send('event8', foo=5) + + plugins.send('event9', foo=5) + def suite(): return unittest.TestLoader().loadTestsFromName(__name__) diff --git a/test/test_query.py b/test/test_query.py index d512e02b8..6d8d744fe 100644 --- a/test/test_query.py +++ b/test/test_query.py @@ -60,6 +60,16 @@ class AnyFieldQueryTest(_common.LibTestCase): dbcore.query.SubstringQuery) self.assertEqual(self.lib.items(q).get(), None) + def test_eq(self): + q1 = dbcore.query.AnyFieldQuery('foo', ['bar'], + dbcore.query.SubstringQuery) + q2 = dbcore.query.AnyFieldQuery('foo', ['bar'], + dbcore.query.SubstringQuery) + self.assertEqual(q1, q2) + + q2.query_class = None + self.assertNotEqual(q1, q2) + class AssertsMixin(object): def assert_items_matched(self, results, titles): @@ -344,6 +354,16 @@ class MatchTest(_common.TestCase): def test_open_range(self): dbcore.query.NumericQuery('bitrate', '100000..') + def test_eq(self): + q1 = dbcore.query.MatchQuery('foo', 'bar') + q2 = dbcore.query.MatchQuery('foo', 'bar') + q3 = dbcore.query.MatchQuery('foo', 'baz') + q4 = dbcore.query.StringFieldQuery('foo', 'bar') + self.assertEqual(q1, q2) + self.assertNotEqual(q1, q3) + self.assertNotEqual(q1, q4) + self.assertNotEqual(q3, q4) + class PathQueryTest(_common.LibTestCase, TestHelper, AssertsMixin): def setUp(self): diff --git a/test/test_smartplaylist.py b/test/test_smartplaylist.py new file mode 100644 index 000000000..4af1cfeb6 --- /dev/null +++ b/test/test_smartplaylist.py @@ -0,0 +1,214 @@ +# This file is part of beets. +# Copyright 2015, Bruno Cauet. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. + +from __future__ import (division, absolute_import, print_function, + unicode_literals) + +from os import path, remove +from tempfile import mkdtemp +from shutil import rmtree + +from mock import Mock, MagicMock + +from beetsplug.smartplaylist import SmartPlaylistPlugin +from beets.library import Item, Album, parse_query_string +from beets.dbcore import OrQuery +from beets.dbcore.query import NullSort, MultipleSort, FixedFieldSort +from beets.util import syspath +from beets.ui import UserError +from beets import config + +from test._common import unittest +from test.helper import TestHelper + + +class SmartPlaylistTest(unittest.TestCase): + def test_build_queries(self): + spl = SmartPlaylistPlugin() + self.assertEqual(spl._matched_playlists, None) + self.assertEqual(spl._unmatched_playlists, None) + + config['smartplaylist']['playlists'].set([]) + spl.build_queries() + self.assertEqual(spl._matched_playlists, set()) + self.assertEqual(spl._unmatched_playlists, set()) + + config['smartplaylist']['playlists'].set([ + {'name': 'foo', + 'query': 'FOO foo'}, + {'name': 'bar', + 'album_query': ['BAR bar1', 'BAR bar2']}, + {'name': 'baz', + 'query': 'BAZ baz', + 'album_query': 'BAZ baz'} + ]) + spl.build_queries() + self.assertEqual(spl._matched_playlists, set()) + foo_foo = parse_query_string('FOO foo', Item) + baz_baz = parse_query_string('BAZ baz', Item) + baz_baz2 = parse_query_string('BAZ baz', Album) + bar_bar = OrQuery((parse_query_string('BAR bar1', Album)[0], + parse_query_string('BAR bar2', Album)[0])) + self.assertEqual(spl._unmatched_playlists, set([ + ('foo', foo_foo, (None, None)), + ('baz', baz_baz, baz_baz2), + ('bar', (None, None), (bar_bar, None)), + ])) + + def test_build_queries_with_sorts(self): + spl = SmartPlaylistPlugin() + config['smartplaylist']['playlists'].set([ + {'name': 'no_sort', 'query': 'foo'}, + {'name': 'one_sort', 'query': 'foo year+'}, + {'name': 'only_empty_sorts', 'query': ['foo', 'bar']}, + {'name': 'one_non_empty_sort', 'query': ['foo year+', 'bar']}, + {'name': 'multiple_sorts', 'query': ['foo year+', 'bar genre-']}, + {'name': 'mixed', 'query': ['foo year+', 'bar', 'baz genre+ id-']} + ]) + + spl.build_queries() + sorts = dict((name, sort) + for name, (_, sort), _ in spl._unmatched_playlists) + + asseq = self.assertEqual # less cluttered code + S = FixedFieldSort # short cut since we're only dealing with this + asseq(sorts["no_sort"], NullSort()) + asseq(sorts["one_sort"], S('year')) + asseq(sorts["only_empty_sorts"], None) + asseq(sorts["one_non_empty_sort"], S('year')) + asseq(sorts["multiple_sorts"], + MultipleSort([S('year'), S('genre', False)])) + asseq(sorts["mixed"], + MultipleSort([S('year'), S('genre'), S('id', False)])) + + def test_db_changes(self): + spl = SmartPlaylistPlugin() + + i1 = MagicMock(Item) + i2 = MagicMock(Item) + a = MagicMock(Album) + i1.get_album.return_value = a + + q1 = Mock() + q1.matches.side_effect = {i1: False, i2: False}.__getitem__ + a_q1 = Mock() + a_q1.matches.side_effect = {a: True}.__getitem__ + q2 = Mock() + q2.matches.side_effect = {i1: False, i2: True}.__getitem__ + + pl1 = '1', (q1, None), (a_q1, None) + pl2 = '2', (None, None), (a_q1, None) + pl3 = '3', (q2, None), (None, None) + + spl._unmatched_playlists = set([pl1, pl2, pl3]) + spl._matched_playlists = set() + spl.db_change(None, i1) + self.assertEqual(spl._unmatched_playlists, set([pl2])) + self.assertEqual(spl._matched_playlists, set([pl1, pl3])) + + spl._unmatched_playlists = set([pl1, pl2, pl3]) + spl._matched_playlists = set() + spl.db_change(None, i2) + self.assertEqual(spl._unmatched_playlists, set([pl2])) + self.assertEqual(spl._matched_playlists, set([pl1, pl3])) + + spl._unmatched_playlists = set([pl1, pl2, pl3]) + spl._matched_playlists = set() + spl.db_change(None, a) + self.assertEqual(spl._unmatched_playlists, set([pl3])) + self.assertEqual(spl._matched_playlists, set([pl1, pl2])) + spl.db_change(None, i2) + self.assertEqual(spl._unmatched_playlists, set()) + self.assertEqual(spl._matched_playlists, set([pl1, pl2, pl3])) + + def test_playlist_update(self): + spl = SmartPlaylistPlugin() + + i = Mock(path='/tagada.mp3') + i.evaluate_template.side_effect = lambda x, _: x + q = Mock() + a_q = Mock() + lib = Mock() + lib.items.return_value = [i] + lib.albums.return_value = [] + pl = 'my_playlist.m3u', (q, None), (a_q, None) + spl._matched_playlists = [pl] + + dir = mkdtemp() + config['smartplaylist']['relative_to'] = False + config['smartplaylist']['playlist_dir'] = dir + try: + spl.update_playlists(lib) + except Exception: + rmtree(dir) + raise + + lib.items.assert_called_once_with(q, None) + lib.albums.assert_called_once_with(a_q, None) + + m3u_filepath = path.join(dir, pl[0]) + self.assertTrue(path.exists(m3u_filepath)) + with open(syspath(m3u_filepath), 'r') as f: + content = f.read() + rmtree(dir) + + self.assertEqual(content, "/tagada.mp3\n") + + +class SmartPlaylistCLITest(unittest.TestCase, TestHelper): + def setUp(self): + self.setup_beets() + + self.item = self.add_item() + config['smartplaylist']['playlists'].set([ + {'name': 'my_playlist.m3u', + 'query': self.item.title}, + {'name': 'all.m3u', + 'query': ''} + ]) + config['smartplaylist']['playlist_dir'].set(self.temp_dir) + self.load_plugins('smartplaylist') + + def tearDown(self): + self.unload_plugins() + self.teardown_beets() + + def test_splupdate(self): + with self.assertRaises(UserError): + self.run_with_output('splupdate', 'tagada') + + self.run_with_output('splupdate', 'my_playlist') + m3u_path = path.join(self.temp_dir, 'my_playlist.m3u') + self.assertTrue(path.exists(m3u_path)) + with open(m3u_path, 'r') as f: + self.assertEqual(f.read(), self.item.path + b"\n") + remove(m3u_path) + + self.run_with_output('splupdate', 'my_playlist.m3u') + with open(m3u_path, 'r') as f: + self.assertEqual(f.read(), self.item.path + b"\n") + remove(m3u_path) + + self.run_with_output('splupdate') + for name in ('my_playlist.m3u', 'all.m3u'): + with open(path.join(self.temp_dir, name), 'r') as f: + self.assertEqual(f.read(), self.item.path + b"\n") + + +def suite(): + return unittest.TestLoader().loadTestsFromName(__name__) + + +if __name__ == b'__main__': + unittest.main(defaultTest='suite')