Add ability to set temporary music dir context for ipfs

This commit is contained in:
Šarūnas Nejus 2026-03-23 14:12:31 +00:00
parent 5eee28bb5c
commit 2d776a8a22
No known key found for this signature in database
4 changed files with 62 additions and 33 deletions

View file

@ -1,3 +1,4 @@
from contextlib import contextmanager
from contextvars import ContextVar
# Holds the music dir context
@ -12,3 +13,13 @@ def get_music_dir() -> bytes:
def set_music_dir(value: bytes) -> None:
"""Set the current music directory context."""
_music_dir_var.set(value)
@contextmanager
def music_dir(value: bytes):
"""Temporarily bind the active music directory for query parsing."""
token = _music_dir_var.set(value)
try:
yield
finally:
_music_dir_var.reset(token)

View file

@ -1,5 +1,6 @@
from __future__ import annotations
from contextlib import contextmanager
from typing import TYPE_CHECKING
import platformdirs
@ -32,10 +33,12 @@ class Library(dbcore.Database):
directory: str | None = None,
path_formats=((PF_KEY_DEFAULT, "$artist/$album/$track $title"),),
replacements=None,
set_music_dir: bool = True,
):
timeout = beets.config["timeout"].as_number()
self.directory = normpath(directory or platformdirs.user_music_path())
context.set_music_dir(self.directory)
if set_music_dir:
context.set_music_dir(self.directory)
super().__init__(path, timeout=timeout)
@ -45,6 +48,12 @@ class Library(dbcore.Database):
# Used for template substitution performance.
self._memotable: dict[tuple[str, ...], str] = {}
@contextmanager
def music_dir_context(self):
"""Temporarily bind this library's directory to path conversion."""
with context.music_dir(self.directory):
yield self
# Adding objects to the database.
def add(self, obj):
@ -95,10 +104,13 @@ class Library(dbcore.Database):
# Parse the query, if necessary.
try:
parsed_sort = None
if isinstance(query, str):
query, parsed_sort = parse_query_string(query, model_cls)
elif isinstance(query, (list, tuple)):
query, parsed_sort = parse_query_parts(query, model_cls)
# Query parsing needs the library root, but keeping it scoped here
# avoids leaking one Library's directory into another's work.
with context.music_dir(self.directory):
if isinstance(query, str):
query, parsed_sort = parse_query_string(query, model_cls)
elif isinstance(query, (list, tuple)):
query, parsed_sort = parse_query_parts(query, model_cls)
except dbcore.query.InvalidQueryArgumentValueError as exc:
raise dbcore.InvalidQueryError(query, exc)

View file

@ -281,13 +281,16 @@ class IPFSPlugin(BeetsPlugin):
def ipfs_added_albums(self, rlib, tmpname):
"""Returns a new library with only albums/items added to ipfs"""
tmplib = library.Library(tmpname, directory="/ipfs/")
for album in rlib.albums():
try:
if album.ipfs:
self.create_new_album(album, tmplib)
except AttributeError:
pass
tmplib = library.Library(
tmpname, directory="/ipfs/", set_music_dir=False
)
with tmplib.music_dir_context():
for album in rlib.albums():
try:
if album.ipfs:
self.create_new_album(album, tmplib)
except AttributeError:
pass
return tmplib
def create_new_album(self, album, tmplib):

View file

@ -29,27 +29,30 @@ class IPFSPluginTest(PluginTestCase):
test_album = self.mk_test_album()
ipfs = IPFSPlugin()
added_albums = ipfs.ipfs_added_albums(self.lib, self.lib.path)
added_album = added_albums.get_album(1)
assert added_album.ipfs == test_album.ipfs
found = False
want_item = test_album.items()[2]
for check_item in added_album.items():
try:
if check_item.get("ipfs", with_album=False):
ipfs_item = os.fsdecode(os.path.basename(want_item.path))
want_path = util.normpath(
os.path.join("/ipfs", test_album.ipfs, ipfs_item)
)
assert check_item.path == want_path
assert (
check_item.get("ipfs", with_album=False)
== want_item.ipfs
)
assert check_item.title == want_item.title
found = True
except AttributeError:
pass
assert found
with added_albums.music_dir_context():
added_album = added_albums.get_album(1)
assert added_album.ipfs == test_album.ipfs
found = False
want_item = test_album.items()[2]
for check_item in added_album.items():
try:
if check_item.get("ipfs", with_album=False):
ipfs_item = os.fsdecode(
os.path.basename(want_item.path)
)
want_path = util.normpath(
os.path.join("/ipfs", test_album.ipfs, ipfs_item)
)
assert check_item.path == want_path
assert (
check_item.get("ipfs", with_album=False)
== want_item.ipfs
)
assert check_item.title == want_item.title
found = True
except AttributeError:
pass
assert found
def mk_test_album(self):
items = [_common.item() for _ in range(3)]