diff --git a/beetsplug/advancedrewrite.py b/beetsplug/advancedrewrite.py index fbb455314..6b7fad1a2 100644 --- a/beetsplug/advancedrewrite.py +++ b/beetsplug/advancedrewrite.py @@ -14,18 +14,40 @@ """Plugin to rewrite fields based on a given query.""" +import re import shlex from collections import defaultdict import confuse -from beets import ui from beets.dbcore import AndQuery, query_from_strings +from beets.dbcore.types import MULTI_VALUE_DSV from beets.library import Album, Item from beets.plugins import BeetsPlugin +from beets.ui import UserError -def rewriter(field, rules): +def simple_rewriter(field, rules): + """Template field function factory. + + Create a template field function that rewrites the given field + with the given rewriting rules. + ``rules`` must be a list of (pattern, replacement) pairs. + """ + + def fieldfunc(item): + value = item._values_fixed[field] + for pattern, replacement in rules: + if pattern.match(value.lower()): + # Rewrite activated. + return replacement + # Not activated; return original value. + return value + + return fieldfunc + + +def advanced_rewriter(field, rules): """Template field function factory. Create a template field function that rewrites the given field @@ -53,40 +75,115 @@ class AdvancedRewritePlugin(BeetsPlugin): super().__init__() template = confuse.Sequence( - { - "match": str, - "field": str, - "replacement": str, - } + confuse.OneOf( + [ + confuse.MappingValues(str), + { + "match": str, + "replacements": confuse.MappingValues( + confuse.OneOf([str, confuse.Sequence(str)]), + ), + }, + ] + ) ) # Gather all the rewrite rules for each field. - rules = defaultdict(list) + simple_rules = defaultdict(list) + advanced_rules = defaultdict(list) for rule in self.config.get(template): - query = query_from_strings( - AndQuery, - Item, - prefixes={}, - query_parts=shlex.split(rule["match"]), - ) - fieldname = rule["field"] - replacement = rule["replacement"] - if fieldname not in Item._fields: - raise ui.UserError( - "invalid field name (%s) in rewriter" % fieldname + if "match" not in rule: + # Simple syntax + if len(rule) != 1: + raise UserError( + "Simple rewrites must have only one rule, " + "but found multiple entries. " + "Did you forget to prepend a dash (-)?" + ) + key, value = next(iter(rule.items())) + try: + fieldname, pattern = key.split(None, 1) + except ValueError: + raise UserError( + f"Invalid simple rewrite specification {key}" + ) + if fieldname not in Item._fields: + raise UserError( + f"invalid field name {fieldname} in rewriter" + ) + self._log.debug( + f"adding simple rewrite '{pattern}' → '{value}' " + f"for field {fieldname}" ) - self._log.debug( - "adding template field {0} → {1}", fieldname, replacement - ) - rules[fieldname].append((query, replacement)) - if fieldname == "artist": - # Special case for the artist field: apply the same - # rewrite for "albumartist" as well. - rules["albumartist"].append((query, replacement)) + pattern = re.compile(pattern.lower()) + simple_rules[fieldname].append((pattern, value)) + if fieldname == "artist": + # Special case for the artist field: apply the same + # rewrite for "albumartist" as well. + simple_rules["albumartist"].append((pattern, value)) + else: + # Advanced syntax + match = rule["match"] + replacements = rule["replacements"] + if len(replacements) == 0: + raise UserError( + "Advanced rewrites must have at least one replacement" + ) + query = query_from_strings( + AndQuery, + Item, + prefixes={}, + query_parts=shlex.split(match), + ) + for fieldname, replacement in replacements.items(): + if fieldname not in Item._fields: + raise UserError( + f"Invalid field name {fieldname} in rewriter" + ) + self._log.debug( + f"adding advanced rewrite to '{replacement}' " + f"for field {fieldname}" + ) + if isinstance(replacement, list): + if Item._fields[fieldname] is not MULTI_VALUE_DSV: + raise UserError( + f"Field {fieldname} is not a multi-valued field " + f"but a list was given: {', '.join(replacement)}" + ) + elif isinstance(replacement, str): + if Item._fields[fieldname] is MULTI_VALUE_DSV: + replacement = list(replacement) + else: + raise UserError( + f"Invalid type of replacement {replacement} " + f"for field {fieldname}" + ) + + advanced_rules[fieldname].append((query, replacement)) + # Special case for the artist(s) field: + # apply the same rewrite for "albumartist(s)" as well. + if fieldname == "artist": + advanced_rules["albumartist"].append( + (query, replacement) + ) + elif fieldname == "artists": + advanced_rules["albumartists"].append( + (query, replacement) + ) + elif fieldname == "artist_sort": + advanced_rules["albumartist_sort"].append( + (query, replacement) + ) # Replace each template field with the new rewriter function. - for fieldname, fieldrules in rules.items(): - getter = rewriter(fieldname, fieldrules) + for fieldname, fieldrules in simple_rules.items(): + getter = simple_rewriter(fieldname, fieldrules) + self.template_fields[fieldname] = getter + if fieldname in Album._fields: + self.album_template_fields[fieldname] = getter + + for fieldname, fieldrules in advanced_rules.items(): + getter = advanced_rewriter(fieldname, fieldrules) self.template_fields[fieldname] = getter if fieldname in Album._fields: self.album_template_fields[fieldname] = getter diff --git a/test/plugins/test_advancedrewrite.py b/test/plugins/test_advancedrewrite.py new file mode 100644 index 000000000..74d2e5db0 --- /dev/null +++ b/test/plugins/test_advancedrewrite.py @@ -0,0 +1,142 @@ +# This file is part of beets. +# Copyright 2023, Max Rumpf. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. + +"""Test the advancedrewrite plugin for various configurations. +""" + +import unittest +from test.helper import TestHelper + +from beets.ui import UserError + +PLUGIN_NAME = "advancedrewrite" + + +class AdvancedRewritePluginTest(unittest.TestCase, TestHelper): + def setUp(self): + self.setup_beets() + + def tearDown(self): + self.unload_plugins() + self.teardown_beets() + + def test_simple_rewrite_example(self): + self.config[PLUGIN_NAME] = [ + {"artist ODD EYE CIRCLE": "이달의 소녀 오드아이써클"}, + ] + self.load_plugins(PLUGIN_NAME) + + item = self.add_item( + title="Uncover", + artist="ODD EYE CIRCLE", + albumartist="ODD EYE CIRCLE", + album="Mix & Match", + ) + + self.assertEqual(item.artist, "이달의 소녀 오드아이써클") + + def test_advanced_rewrite_example(self): + self.config[PLUGIN_NAME] = [ + { + "match": "mb_artistid:dec0f331-cb08-4c8e-9c9f-aeb1f0f6d88c year:..2022", + "replacements": { + "artist": "이달의 소녀 오드아이써클", + "artist_sort": "LOONA / ODD EYE CIRCLE", + }, + }, + ] + self.load_plugins(PLUGIN_NAME) + + item_a = self.add_item( + title="Uncover", + artist="ODD EYE CIRCLE", + albumartist="ODD EYE CIRCLE", + artist_sort="ODD EYE CIRCLE", + albumartist_sort="ODD EYE CIRCLE", + album="Mix & Match", + mb_artistid="dec0f331-cb08-4c8e-9c9f-aeb1f0f6d88c", + year=2017, + ) + item_b = self.add_item( + title="Air Force One", + artist="ODD EYE CIRCLE", + albumartist="ODD EYE CIRCLE", + artist_sort="ODD EYE CIRCLE", + albumartist_sort="ODD EYE CIRCLE", + album="ODD EYE CIRCLE ", + mb_artistid="dec0f331-cb08-4c8e-9c9f-aeb1f0f6d88c", + year=2023, + ) + + # Assert that all replacements were applied to item_a + self.assertEqual("이달의 소녀 오드아이써클", item_a.artist) + self.assertEqual("LOONA / ODD EYE CIRCLE", item_a.artist_sort) + self.assertEqual("LOONA / ODD EYE CIRCLE", item_a.albumartist_sort) + + # Assert that no replacements were applied to item_b + self.assertEqual("ODD EYE CIRCLE", item_b.artist) + + def test_advanced_rewrite_example_with_multi_valued_field(self): + self.config[PLUGIN_NAME] = [ + { + "match": "artist:배유빈 feat. 김미현", + "replacements": { + "artists": ["유빈", "미미"], + }, + }, + ] + self.load_plugins(PLUGIN_NAME) + + item = self.add_item( + artist="배유빈 feat. 김미현", + artists=["배유빈", "김미현"], + ) + + self.assertEqual(item.artists, ["유빈", "미미"]) + + def test_fail_when_replacements_empty(self): + self.config[PLUGIN_NAME] = [ + { + "match": "artist:A", + "replacements": {}, + }, + ] + with self.assertRaises( + UserError, + msg="Advanced rewrites must have at least one replacement", + ): + self.load_plugins(PLUGIN_NAME) + + def test_fail_when_rewriting_single_valued_field_with_list(self): + self.config[PLUGIN_NAME] = [ + { + "match": "artist:'A & B'", + "replacements": { + "artist": ["C", "D"], + }, + }, + ] + with self.assertRaises( + UserError, + msg="Field artist is not a multi-valued field but a list was given: C, D", + ): + self.load_plugins(PLUGIN_NAME) + + +def suite(): + return unittest.TestLoader().loadTestsFromName(__name__) + + +if __name__ == "__main__": + unittest.main(defaultTest="suite")