mirror of
https://github.com/beetbox/beets.git
synced 2025-12-29 20:12:33 +01:00
Merge pull request #5063 from Maxr1998/fix-advancedrewrite-simple-rules
advancedrewrite: Fix simple rules being overwritten by advanced rules
This commit is contained in:
commit
8720d6413b
2 changed files with 42 additions and 34 deletions
|
|
@ -27,37 +27,22 @@ from beets.plugins import BeetsPlugin
|
|||
from beets.ui import UserError
|
||||
|
||||
|
||||
def simple_rewriter(field, rules):
|
||||
def rewriter(field, simple_rules, advanced_rules):
|
||||
"""Template field function factory.
|
||||
|
||||
Create a template field function that rewrites the given field
|
||||
with the given rewriting rules.
|
||||
``rules`` must be a list of (pattern, replacement) pairs.
|
||||
``simple_rules`` must be a list of (pattern, replacement) pairs.
|
||||
``advanced_rules`` must be a list of (query, replacement) pairs.
|
||||
"""
|
||||
|
||||
def fieldfunc(item):
|
||||
value = item._values_fixed[field]
|
||||
for pattern, replacement in rules:
|
||||
for pattern, replacement in simple_rules:
|
||||
if pattern.match(value.lower()):
|
||||
# Rewrite activated.
|
||||
return replacement
|
||||
# Not activated; return original value.
|
||||
return value
|
||||
|
||||
return fieldfunc
|
||||
|
||||
|
||||
def advanced_rewriter(field, rules):
|
||||
"""Template field function factory.
|
||||
|
||||
Create a template field function that rewrites the given field
|
||||
with the given rewriting rules.
|
||||
``rules`` must be a list of (query, replacement) pairs.
|
||||
"""
|
||||
|
||||
def fieldfunc(item):
|
||||
value = item._values_fixed[field]
|
||||
for query, replacement in rules:
|
||||
for query, replacement in advanced_rules:
|
||||
if query.match(item):
|
||||
# Rewrite activated.
|
||||
return replacement
|
||||
|
|
@ -97,8 +82,12 @@ class AdvancedRewritePlugin(BeetsPlugin):
|
|||
}
|
||||
|
||||
# Gather all the rewrite rules for each field.
|
||||
simple_rules = defaultdict(list)
|
||||
advanced_rules = defaultdict(list)
|
||||
class RulesContainer:
|
||||
def __init__(self):
|
||||
self.simple = []
|
||||
self.advanced = []
|
||||
|
||||
rules = defaultdict(RulesContainer)
|
||||
for rule in self.config.get(template):
|
||||
if "match" not in rule:
|
||||
# Simple syntax
|
||||
|
|
@ -124,12 +113,12 @@ class AdvancedRewritePlugin(BeetsPlugin):
|
|||
f"for field {fieldname}"
|
||||
)
|
||||
pattern = re.compile(pattern.lower())
|
||||
simple_rules[fieldname].append((pattern, value))
|
||||
rules[fieldname].simple.append((pattern, value))
|
||||
|
||||
# Apply the same rewrite to the corresponding album field.
|
||||
if fieldname in corresponding_album_fields:
|
||||
album_fieldname = corresponding_album_fields[fieldname]
|
||||
simple_rules[album_fieldname].append((pattern, value))
|
||||
rules[album_fieldname].simple.append((pattern, value))
|
||||
else:
|
||||
# Advanced syntax
|
||||
match = rule["match"]
|
||||
|
|
@ -168,24 +157,18 @@ class AdvancedRewritePlugin(BeetsPlugin):
|
|||
f"for field {fieldname}"
|
||||
)
|
||||
|
||||
advanced_rules[fieldname].append((query, replacement))
|
||||
rules[fieldname].advanced.append((query, replacement))
|
||||
|
||||
# Apply the same rewrite to the corresponding album field.
|
||||
if fieldname in corresponding_album_fields:
|
||||
album_fieldname = corresponding_album_fields[fieldname]
|
||||
advanced_rules[album_fieldname].append(
|
||||
rules[album_fieldname].advanced.append(
|
||||
(query, replacement)
|
||||
)
|
||||
|
||||
# Replace each template field with the new rewriter function.
|
||||
for fieldname, fieldrules in simple_rules.items():
|
||||
getter = simple_rewriter(fieldname, fieldrules)
|
||||
self.template_fields[fieldname] = getter
|
||||
if fieldname in Album._fields:
|
||||
self.album_template_fields[fieldname] = getter
|
||||
|
||||
for fieldname, fieldrules in advanced_rules.items():
|
||||
getter = advanced_rewriter(fieldname, fieldrules)
|
||||
for fieldname, fieldrules in rules.items():
|
||||
getter = rewriter(fieldname, fieldrules.simple, fieldrules.advanced)
|
||||
self.template_fields[fieldname] = getter
|
||||
if fieldname in Album._fields:
|
||||
self.album_template_fields[fieldname] = getter
|
||||
|
|
|
|||
|
|
@ -133,6 +133,31 @@ class AdvancedRewritePluginTest(unittest.TestCase, TestHelper):
|
|||
):
|
||||
self.load_plugins(PLUGIN_NAME)
|
||||
|
||||
def test_combined_rewrite_example(self):
|
||||
self.config[PLUGIN_NAME] = [
|
||||
{"artist A": "B"},
|
||||
{
|
||||
"match": "album:'C'",
|
||||
"replacements": {
|
||||
"artist": "D",
|
||||
},
|
||||
},
|
||||
]
|
||||
self.load_plugins(PLUGIN_NAME)
|
||||
|
||||
item = self.add_item(
|
||||
artist="A",
|
||||
albumartist="A",
|
||||
)
|
||||
self.assertEqual(item.artist, "B")
|
||||
|
||||
item = self.add_item(
|
||||
artist="C",
|
||||
albumartist="C",
|
||||
album="C",
|
||||
)
|
||||
self.assertEqual(item.artist, "D")
|
||||
|
||||
|
||||
def suite():
|
||||
return unittest.TestLoader().loadTestsFromName(__name__)
|
||||
|
|
|
|||
Loading…
Reference in a new issue