diff --git a/beetsplug/advancedrewrite.py b/beetsplug/advancedrewrite.py index 20f2b7e03..9a5feaaff 100644 --- a/beetsplug/advancedrewrite.py +++ b/beetsplug/advancedrewrite.py @@ -27,37 +27,22 @@ from beets.plugins import BeetsPlugin from beets.ui import UserError -def simple_rewriter(field, rules): +def rewriter(field, simple_rules, advanced_rules): """Template field function factory. Create a template field function that rewrites the given field with the given rewriting rules. - ``rules`` must be a list of (pattern, replacement) pairs. + ``simple_rules`` must be a list of (pattern, replacement) pairs. + ``advanced_rules`` must be a list of (query, replacement) pairs. """ def fieldfunc(item): value = item._values_fixed[field] - for pattern, replacement in rules: + for pattern, replacement in simple_rules: if pattern.match(value.lower()): # Rewrite activated. return replacement - # Not activated; return original value. - return value - - return fieldfunc - - -def advanced_rewriter(field, rules): - """Template field function factory. - - Create a template field function that rewrites the given field - with the given rewriting rules. - ``rules`` must be a list of (query, replacement) pairs. - """ - - def fieldfunc(item): - value = item._values_fixed[field] - for query, replacement in rules: + for query, replacement in advanced_rules: if query.match(item): # Rewrite activated. return replacement @@ -97,8 +82,12 @@ class AdvancedRewritePlugin(BeetsPlugin): } # Gather all the rewrite rules for each field. - simple_rules = defaultdict(list) - advanced_rules = defaultdict(list) + class RulesContainer: + def __init__(self): + self.simple = [] + self.advanced = [] + + rules = defaultdict(RulesContainer) for rule in self.config.get(template): if "match" not in rule: # Simple syntax @@ -124,12 +113,12 @@ class AdvancedRewritePlugin(BeetsPlugin): f"for field {fieldname}" ) pattern = re.compile(pattern.lower()) - simple_rules[fieldname].append((pattern, value)) + rules[fieldname].simple.append((pattern, value)) # Apply the same rewrite to the corresponding album field. if fieldname in corresponding_album_fields: album_fieldname = corresponding_album_fields[fieldname] - simple_rules[album_fieldname].append((pattern, value)) + rules[album_fieldname].simple.append((pattern, value)) else: # Advanced syntax match = rule["match"] @@ -168,24 +157,18 @@ class AdvancedRewritePlugin(BeetsPlugin): f"for field {fieldname}" ) - advanced_rules[fieldname].append((query, replacement)) + rules[fieldname].advanced.append((query, replacement)) # Apply the same rewrite to the corresponding album field. if fieldname in corresponding_album_fields: album_fieldname = corresponding_album_fields[fieldname] - advanced_rules[album_fieldname].append( + rules[album_fieldname].advanced.append( (query, replacement) ) # Replace each template field with the new rewriter function. - for fieldname, fieldrules in simple_rules.items(): - getter = simple_rewriter(fieldname, fieldrules) - self.template_fields[fieldname] = getter - if fieldname in Album._fields: - self.album_template_fields[fieldname] = getter - - for fieldname, fieldrules in advanced_rules.items(): - getter = advanced_rewriter(fieldname, fieldrules) + for fieldname, fieldrules in rules.items(): + getter = rewriter(fieldname, fieldrules.simple, fieldrules.advanced) self.template_fields[fieldname] = getter if fieldname in Album._fields: self.album_template_fields[fieldname] = getter diff --git a/test/plugins/test_advancedrewrite.py b/test/plugins/test_advancedrewrite.py index d21660da6..71f92c4dd 100644 --- a/test/plugins/test_advancedrewrite.py +++ b/test/plugins/test_advancedrewrite.py @@ -133,6 +133,31 @@ class AdvancedRewritePluginTest(unittest.TestCase, TestHelper): ): self.load_plugins(PLUGIN_NAME) + def test_combined_rewrite_example(self): + self.config[PLUGIN_NAME] = [ + {"artist A": "B"}, + { + "match": "album:'C'", + "replacements": { + "artist": "D", + }, + }, + ] + self.load_plugins(PLUGIN_NAME) + + item = self.add_item( + artist="A", + albumartist="A", + ) + self.assertEqual(item.artist, "B") + + item = self.add_item( + artist="C", + albumartist="C", + album="C", + ) + self.assertEqual(item.artist, "D") + def suite(): return unittest.TestLoader().loadTestsFromName(__name__)