diff --git a/beets/plugins.py b/beets/plugins.py index bee4d9f32..c642e9318 100755 --- a/beets/plugins.py +++ b/beets/plugins.py @@ -17,6 +17,7 @@ from __future__ import (division, absolute_import, print_function, unicode_literals) +import inspect import traceback import re from collections import defaultdict @@ -100,26 +101,37 @@ class BeetsPlugin(object): `self.import_stages`. Wrapping provides some bookkeeping for the plugin: specifically, the logging level is adjusted to WARNING. """ - return [self._set_log_level(logging.WARNING, import_stage) + return [self._set_log_level_and_params(logging.WARNING, import_stage) for import_stage in self.import_stages] - def _set_log_level(self, base_log_level, func): + def _set_log_level_and_params(self, base_log_level, func): """Wrap `func` to temporarily set this plugin's logger level to `base_log_level` + config options (and restore it to its previous - value after the function returns). + value after the function returns). Also determines which params may not + be sent for backwards-compatibility. - Note that that value may not be NOTSET, e.g. if a plugin import stage - triggers an event that is listened this very same plugin + Note that the log level value may not be NOTSET, e.g. if a plugin + import stage triggers an event that is listened this very same plugin. """ + argspec = inspect.getargspec(func) + @wraps(func) def wrapper(*args, **kwargs): old_log_level = self._log.level verbosity = beets.config['verbose'].get(int) log_level = max(logging.DEBUG, base_log_level - 10 * verbosity) self._log.setLevel(log_level) - try: - return func(*args, **kwargs) + try: + return func(*args, **kwargs) + except TypeError as exc: + if exc.args[0].startswith(func.__name__): + # caused by 'func' and not stuff internal to 'func' + kwargs = dict((arg, val) for arg, val in kwargs.items() + if arg in argspec.args) + return func(*args, **kwargs) + else: + raise finally: self._log.setLevel(old_log_level) return wrapper @@ -186,7 +198,7 @@ class BeetsPlugin(object): def register_listener(self, event, func): """Add a function as a listener for the specified event. """ - wrapped_func = self._set_log_level(logging.WARNING, func) + wrapped_func = self._set_log_level_and_params(logging.WARNING, func) cls = self.__class__ if cls.listeners is None or cls._raw_listeners is None: diff --git a/test/test_plugins.py b/test/test_plugins.py index c9c5be502..2e8bedca1 100644 --- a/test/test_plugins.py +++ b/test/test_plugins.py @@ -16,8 +16,9 @@ from __future__ import (division, absolute_import, print_function, unicode_literals) import os -from mock import patch +from mock import patch, Mock import shutil +import itertools from beets.importer import SingletonImportTask, SentinelImportTask, \ ArchiveImportTask @@ -57,7 +58,6 @@ class ItemTypesTest(unittest.TestCase, TestHelper): def setUp(self): self.setup_plugin_loader() - self.setup_beets() def tearDown(self): self.teardown_plugin_loader() @@ -309,6 +309,96 @@ class ListenersTest(unittest.TestCase, TestHelper): self.assertEqual(DummyPlugin._raw_listeners['cli_exit'], [d.dummy, d2.dummy]) + @patch('beets.plugins.find_plugins') + @patch('beets.plugins.inspect') + def test_events_called(self, mock_inspect, mock_find_plugins): + mock_inspect.getargspec.return_value = None + + class DummyPlugin(plugins.BeetsPlugin): + def __init__(self): + super(DummyPlugin, self).__init__() + self.foo = Mock(__name__=b'foo') + self.register_listener('event_foo', self.foo) + self.bar = Mock(__name__=b'bar') + self.register_listener('event_bar', self.bar) + + d = DummyPlugin() + mock_find_plugins.return_value = d, + + plugins.send('event') + d.foo.assert_has_calls([]) + d.bar.assert_has_calls([]) + + plugins.send('event_foo', var="tagada") + d.foo.assert_called_once_with(var="tagada") + d.bar.assert_has_calls([]) + + @patch('beets.plugins.find_plugins') + def test_listener_params(self, mock_find_plugins): + test = self + + class DummyPlugin(plugins.BeetsPlugin): + def __init__(self): + super(DummyPlugin, self).__init__() + for i in itertools.count(1): + try: + meth = getattr(self, 'dummy{0}'.format(i)) + except AttributeError: + break + self.register_listener('event{0}'.format(i), meth) + + def dummy1(self, foo): + test.assertEqual(foo, 5) + + def dummy2(self, foo=None): + test.assertEqual(foo, 5) + + def dummy3(self): + # argument cut off + pass + + def dummy4(self, bar=None): + # argument cut off + pass + + def dummy5(self, bar): + test.assertFalse(True) + + # more complex exmaples + + def dummy6(self, foo, bar=None): + test.assertEqual(foo, 5) + test.assertEqual(bar, None) + + def dummy7(self, foo, **kwargs): + test.assertEqual(foo, 5) + test.assertEqual(kwargs, {}) + + def dummy8(self, foo, bar, **kwargs): + test.assertFalse(True) + + def dummy9(self, **kwargs): + test.assertEqual(kwargs, {"foo": 5}) + + d = DummyPlugin() + mock_find_plugins.return_value = d, + + plugins.send('event1', foo=5) + plugins.send('event2', foo=5) + plugins.send('event3', foo=5) + plugins.send('event4', foo=5) + + with self.assertRaises(TypeError): + plugins.send('event5', foo=5) + + plugins.send('event6', foo=5) + plugins.send('event7', foo=5) + + with self.assertRaises(TypeError): + plugins.send('event8', foo=5) + + plugins.send('event9', foo=5) + def suite(): return unittest.TestLoader().loadTestsFromName(__name__)