diff --git a/beets/util/enumeration.py b/beets/util/enumeration.py index 531b406bb..70c0ab818 100644 --- a/beets/util/enumeration.py +++ b/beets/util/enumeration.py @@ -42,7 +42,8 @@ class IndexableEnumMeta(EnumMeta): def __getitem__(obj, x): if isinstance(x, int): return obj._value2member_map_[x] - return super(IndexableEnumMeta, EnumMeta).__getitem__(obj, x) + #import code; code.interact(local=locals()) + return super(IndexableEnumMeta, obj).__getitem__(x) class IndexableEnum(Enum): """ diff --git a/test/test_util.py b/test/test_util.py new file mode 100644 index 000000000..9ec569e62 --- /dev/null +++ b/test/test_util.py @@ -0,0 +1,44 @@ +# This file is part of beets. +# Copyright 2013, Adrian Sampson. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. + +"""Tests for utils.""" + +import _common +from _common import unittest +from beets.util.enumeration import OrderedEnum, IndexableEnum + +class EnumTest(_common.TestCase): + """ + Test Enum Subclasses defined in beets.util.enumeration + """ + def test_ordered_enum(self): + OrderedEnumTest = OrderedEnum('OrderedEnumTest', ['a', 'b', 'c']) + self.assertLess(OrderedEnumTest.a, OrderedEnumTest.b) + self.assertLess(OrderedEnumTest.a, OrderedEnumTest.c) + self.assertLess(OrderedEnumTest.b, OrderedEnumTest.c) + + def test_indexable_enum(self): + values = ['a', 'b', 'c'] + IndexableEnumTest = IndexableEnum('IndexableEnumTest', values) + for v in values: + self.assertEqual(IndexableEnumTest[IndexableEnumTest[v].value], + IndexableEnumTest[v]) + + +def suite(): + return unittest.TestLoader().loadTestsFromName(__name__) + +if __name__ == '__main__': + unittest.main(defaultTest='suite') +