Removed import state functions in favor of an import state dataclass.

This makes the code more readable, in my opinion, and we now have type hints
for the import state.
Sebastian Mohr 2025-02-01 13:16:04 +01:00
parent a1c0ebdeef
commit 435864cb50
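
For orientation, here is a sketch of how the new `ImportState` API replaces the removed module-level helpers. This example is illustrative rather than part of the commit: the inbox paths are invented, and byte-string paths are assumed since beets stores paths as bytes internally.

```python
from beets.importer import ImportState

# Read-only access (previously progress_read() / history_get()):
progress = ImportState().tagprogress
history = ImportState().taghistory

# Recording progress (previously progress_add(toppath, *paths));
# the method persists the state file itself:
ImportState().progress_add(b"/music/inbox", b"/music/inbox/album/01.mp3")

# Read-modify-write; the context manager saves the state file on exit
# (previously the progress_write() context manager):
with ImportState() as state:
    state.taghistory.add((b"/music/inbox/album",))
```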


@@ -1,18 +1,3 @@
-# This file is part of beets.
-# Copyright 2016, Adrian Sampson.
-#
-# Permission is hereby granted, free of charge, to any person obtaining
-# a copy of this software and associated documentation files (the
-# "Software"), to deal in the Software without restriction, including
-# without limitation the rights to use, copy, modify, merge, publish,
-# distribute, sublicense, and/or sell copies of the Software, and to
-# permit persons to whom the Software is furnished to do so, subject to
-# the following conditions:
-#
-# The above copyright notice and this permission notice shall be
-# included in all copies or substantial portions of the Software.
-
 """Provides the basic, interface-agnostic workflow for importing and
 autotagging music files.
 """
@@ -23,17 +8,20 @@ import pickle
 import re
 import shutil
 import time
+from abc import ABC, abstractmethod
 from bisect import bisect_left, insort
 from collections import defaultdict
 from contextlib import contextmanager
+from dataclasses import dataclass
 from enum import Enum
 from tempfile import mkdtemp
+from typing import Iterable, Sequence

 import mediafile

 from beets import autotag, config, dbcore, library, logging, plugins, util
 from beets.util import (
     MoveOperation,
     PathLike,
     ancestry,
     displayable_path,
     normpath,
@@ -49,8 +37,7 @@ action = Enum("action", ["SKIP", "ASIS", "TRACKS", "APPLY", "ALBUMS", "RETAG"])
 QUEUE_SIZE = 128
 SINGLE_ARTIST_THRESH = 0.25
-PROGRESS_KEY = "tagprogress"
-HISTORY_KEY = "taghistory"

 # Usually flexible attributes are preserved (i.e., not updated) during
 # reimports. The following two lists (globally) change this behaviour for
 # certain fields. To alter these lists only when a specific plugin is in use,
@@ -80,142 +67,163 @@ class ImportAbortError(Exception):
     pass


 # Utilities.


-def _open_state():
-    """Reads the state file, returning a dictionary."""
-    try:
-        with open(config["statefile"].as_filename(), "rb") as f:
-            return pickle.load(f)
-    except Exception as exc:
-        # The `pickle` module can emit all sorts of exceptions during
-        # unpickling, including ImportError. We use a catch-all
-        # exception to avoid enumerating them all (the docs don't even have a
-        # full list!).
-        log.debug("state file could not be read: {0}", exc)
-        return {}
-
-
-def _save_state(state):
-    """Writes the state dictionary out to disk."""
-    try:
-        with open(config["statefile"].as_filename(), "wb") as f:
-            pickle.dump(state, f)
-    except OSError as exc:
-        log.error("state file could not be written: {0}", exc)
-
-
-# Utilities for reading and writing the beets progress file, which
-# allows long tagging tasks to be resumed when they pause (or crash).
-
-
-def progress_read():
-    state = _open_state()
-    return state.setdefault(PROGRESS_KEY, {})
-
-
-@contextmanager
-def progress_write():
-    state = _open_state()
-    progress = state.setdefault(PROGRESS_KEY, {})
-    yield progress
-    _save_state(state)
-
-
-def progress_add(toppath, *paths):
-    """Record that the files under all of the `paths` have been imported
-    under `toppath`.
-    """
-    with progress_write() as state:
-        imported = state.setdefault(toppath, [])
-        for path in paths:
-            # Normally `progress_add` will be called with the path
-            # argument increasing. This is because of the ordering in
-            # `albums_in_dir`. We take advantage of that to make the
-            # code faster
-            if imported and imported[len(imported) - 1] <= path:
-                imported.append(path)
-            else:
-                insort(imported, path)
-
-
-def progress_element(toppath, path):
-    """Return whether `path` has been imported in `toppath`."""
-    state = progress_read()
-    if toppath not in state:
-        return False
-    imported = state[toppath]
-    i = bisect_left(imported, path)
-    return i != len(imported) and imported[i] == path
-
-
-def has_progress(toppath):
-    """Return `True` if there exist paths that have already been
-    imported under `toppath`.
-    """
-    state = progress_read()
-    return toppath in state
-
-
-def progress_reset(toppath):
-    with progress_write() as state:
-        if toppath in state:
-            del state[toppath]
-
-
-# Similarly, utilities for manipulating the "incremental" import log.
-# This keeps track of all directories that were ever imported, which
-# allows the importer to only import new stuff.
-
-
-def history_add(paths):
-    """Indicate that the import of the album in `paths` is completed and
-    should not be repeated in incremental imports.
-    """
-    state = _open_state()
-    if HISTORY_KEY not in state:
-        state[HISTORY_KEY] = set()
-
-    state[HISTORY_KEY].add(tuple(paths))
-    _save_state(state)
-
-
-def history_get():
-    """Get the set of completed path tuples in incremental imports."""
-    state = _open_state()
-    if HISTORY_KEY not in state:
-        return set()
-    return state[HISTORY_KEY]
+@dataclass
+class ImportState:
+    """Representing the progress of an import task.
+
+    Opens the state file on creation of the class. If you want
+    to ensure the state is written to disk, you should use the
+    context manager protocol.
+
+    Tagprogress allows long tagging tasks to be resumed when they pause.
+
+    Taghistory is a utility for manipulating the "incremental" import log.
+    This keeps track of all directories that were ever imported, which
+    allows the importer to only import new stuff.
+
+    Usage
+    -----
+    ```
+    # Readonly
+    progress = ImportState().tagprogress
+
+    # Read and write
+    with ImportState() as state:
+        state.tagprogress["key"] = "value"
+    ```
+    """
+
+    tagprogress: dict
+    taghistory: set
+    path: PathLike
+
+    def __init__(self, readonly=False, path: PathLike | None = None):
+        self.path = path or config["statefile"].as_filename()
+        self._open()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self._save()
+
+    def _open(self):
+        try:
+            with open(self.path, "rb") as f:
+                state = pickle.load(f)
+                # Read the states
+                self.tagprogress = state.get("tagprogress", {})
+                self.taghistory = state.get("taghistory", set())
+        except Exception as exc:
+            # The `pickle` module can emit all sorts of exceptions during
+            # unpickling, including ImportError. We use a catch-all
+            # exception to avoid enumerating them all (the docs don't even have a
+            # full list!).
+            log.debug("state file could not be read: {0}", exc)
+
+    def _save(self):
+        try:
+            with open(self.path, "wb") as f:
+                pickle.dump(
+                    {
+                        "tagprogress": self.tagprogress,
+                        "taghistory": self.taghistory,
+                    },
+                    f,
+                )
+        except OSError as exc:
+            log.error("state file could not be written: {0}", exc)
+
+    # -------------------------------- Tagprogress ------------------------------- #
+
+    def progress_add(self, toppath: PathLike, *paths: list[PathLike]):
+        """Record that the files under all of the `paths` have been imported
+        under `toppath`.
+        """
+        with self as state:
+            imported = state.tagprogress.setdefault(toppath, [])
+            for path in paths:
+                if imported and imported[-1] <= path:
+                    imported.append(path)
+                else:
+                    insort(imported, path)
+
+    def progress_has_element(self, toppath: PathLike, path: PathLike) -> bool:
+        """Return whether `path` has been imported in `toppath`."""
+        imported = self.tagprogress.get(toppath, [])
+        i = bisect_left(imported, path)
+        return i != len(imported) and imported[i] == path
+
+    def progress_has(self, toppath: PathLike) -> bool:
+        """Return `True` if there exist paths that have already been
+        imported under `toppath`.
+        """
+        return toppath in self.tagprogress
+
+    def progress_reset(self, toppath: PathLike):
+        """Reset the progress for `toppath`."""
+        with self as state:
+            if toppath in state.tagprogress:
+                del state.tagprogress[toppath]
+
+    # -------------------------------- Taghistory -------------------------------- #
+
+    def history_add(self, paths: list[PathLike]):
+        """Add the paths to the history."""
+        with self as state:
+            state.taghistory.add(tuple(paths))


 # Abstract session class.


-class ImportSession:
+class ImportSession(ABC):
     """Controls an import action. Subclasses should implement methods to
     communicate with the user or otherwise make decisions.
     """

-    def __init__(self, lib, loghandler, paths, query):
-        """Create a session. `lib` is a Library object. `loghandler` is a
-        logging.Handler. Either `paths` or `query` is non-null and indicates
-        the source of files to be imported.
-        """
+    logger: logging.Logger
+    paths: list[bytes] | None
+    lib: library.Library
+    _is_resuming: dict[bytes, bool]
+    _merged_items: set
+    _merged_dirs: set
+
+    def __init__(
+        self,
+        lib: library.Library,
+        loghandler: logging.Handler | None,
+        paths: Iterable[PathLike] | None,
+        query: dbcore.Query | None,
+    ):
+        """Create a session.
+
+        Parameters
+        ----------
+        lib : library.Library
+            The library instance to which items will be imported.
+        loghandler : logging.Handler or None
+            A logging handler to use for the session's logger. If None, a
+            NullHandler will be used.
+        paths : os.PathLike or None
+            The paths to be imported. If None, no paths are specified.
+        query : dbcore.Query or None
+            A query to filter items for import. If None, no query is applied.
+        """
         self.lib = lib
         self.logger = self._setup_logging(loghandler)
-        self.paths = paths
         self.query = query
         self._is_resuming = {}
         self._merged_items = set()
         self._merged_dirs = set()

         # Normalize the paths.
-        if self.paths:
-            self.paths = list(map(normpath, self.paths))
+        if paths is not None:
+            self.paths = list(map(normpath, paths))
+        else:
+            self.paths = None

-    def _setup_logging(self, loghandler):
+    def _setup_logging(self, loghandler: logging.Handler | None):
         logger = logging.getLogger(__name__)
         logger.propagate = False
         if not loghandler:
@@ -243,9 +251,7 @@ class ImportSession:
             iconfig["incremental"] = False

         if iconfig["reflink"]:
-            iconfig["reflink"] = iconfig["reflink"].as_choice(
-                ["auto", True, False]
-            )
+            iconfig["reflink"] = iconfig["reflink"].as_choice(["auto", True, False])

         # Copy, move, reflink, link, and hardlink are mutually exclusive.
         if iconfig["move"]:
@@ -302,17 +308,21 @@ class ImportSession:
         elif task.choice_flag is action.SKIP:
             self.tag_log("skip", paths)

+    @abstractmethod
     def should_resume(self, path):
-        raise NotImplementedError
+        raise NotImplementedError("Inheriting class must implement `should_resume`")

+    @abstractmethod
     def choose_match(self, task):
-        raise NotImplementedError
+        raise NotImplementedError("Inheriting class must implement `choose_match`")

+    @abstractmethod
     def resolve_duplicate(self, task, found_duplicates):
-        raise NotImplementedError
+        raise NotImplementedError("Inheriting class must implement `resolve_duplicate`")

+    @abstractmethod
     def choose_item(self, task):
-        raise NotImplementedError
+        raise NotImplementedError("Inheriting class must implement `choose_item`")

     def run(self):
         """Run the import task."""
@@ -366,12 +376,13 @@ class ImportSession:

     # Incremental and resumed imports

-    def already_imported(self, toppath, paths):
+    def already_imported(self, toppath: PathLike, paths: Sequence[PathLike]):
         """Returns true if the files belonging to this task have already
         been imported in a previous session.
         """
+        state = ImportState()
         if self.is_resuming(toppath) and all(
-            [progress_element(toppath, p) for p in paths]
+            [state.progress_has_element(toppath, p) for p in paths]
         ):
             return True
         if self.config["incremental"] and tuple(paths) in self.history_dirs:
@@ -379,13 +390,15 @@ class ImportSession:
             return False

+    _history_dirs = None
+
     @property
     def history_dirs(self):
-        if not hasattr(self, "_history_dirs"):
-            self._history_dirs = history_get()
+        if self._history_dirs is None:
+            self._history_dirs = ImportState().taghistory
         return self._history_dirs

-    def already_merged(self, paths):
+    def already_merged(self, paths: Sequence[PathLike]):
         """Returns true if all the paths being imported were part of a merge
         during previous tasks.
         """
@@ -394,7 +407,7 @@ class ImportSession:
             return False
         return True

-    def mark_merged(self, paths):
+    def mark_merged(self, paths: Sequence[PathLike]):
         """Mark paths and directories as merged for future reimport tasks."""
         self._merged_items.update(paths)
         dirs = {
@@ -403,30 +416,31 @@ class ImportSession:
         }
         self._merged_dirs.update(dirs)

-    def is_resuming(self, toppath):
+    def is_resuming(self, toppath: PathLike):
         """Return `True` if user wants to resume import of this path.

         You have to call `ask_resume` first to determine the return value.
         """
-        return self._is_resuming.get(toppath, False)
+        return self._is_resuming.get(normpath(toppath), False)

-    def ask_resume(self, toppath):
+    def ask_resume(self, toppath: PathLike):
         """If import of `toppath` was aborted in an earlier session, ask
         user if they want to resume the import.

         Determines the return value of `is_resuming(toppath)`.
         """
-        if self.want_resume and has_progress(toppath):
+        state = ImportState()
+        if self.want_resume and state.progress_has(toppath):
             # Either accept immediately or prompt for input to decide.
             if self.want_resume is True or self.should_resume(toppath):
                 log.warning(
                     "Resuming interrupted import of {0}",
-                    util.displayable_path(toppath),
+                    util.displayable_path(normpath(toppath)),
                 )
-                self._is_resuming[toppath] = True
+                self._is_resuming[normpath(toppath)] = True
             else:
                 # Clear progress; we're starting from the top.
-                progress_reset(toppath)
+                state.progress_reset(toppath)


 # The importer task class.
@@ -528,12 +542,12 @@ class ImportTask(BaseImportTask):
         finished.
         """
         if self.toppath:
-            progress_add(self.toppath, *self.paths)
+            ImportState().progress_add(self.toppath, *self.paths)

     def save_history(self):
         """Save the directory in the history for incremental imports."""
         if self.paths:
-            history_add(self.paths)
+            ImportState().history_add(self.paths)

     # Logical decisions.
@@ -593,9 +607,7 @@ class ImportTask(BaseImportTask):
         for item in duplicate_items:
             item.remove()
             if lib.directory in util.ancestry(item.path):
-                log.debug(
-                    "deleting duplicate {0}", util.displayable_path(item.path)
-                )
+                log.debug("deleting duplicate {0}", util.displayable_path(item.path))
                 util.remove(item.path)
                 util.prune_dirs(os.path.dirname(item.path), lib.directory)
@@ -627,7 +639,8 @@ class ImportTask(BaseImportTask):
             self.save_progress()
         if session.config["incremental"] and not (
             # Should we skip recording to incremental list?
-            self.skip and session.config["incremental_skip_later"]
+            self.skip
+            and session.config["incremental_skip_later"]
         ):
             self.save_history()
@@ -684,9 +697,7 @@ class ImportTask(BaseImportTask):
        candidate IDs are stored in self.search_ids: if present, the
        initial lookup is restricted to only those IDs.
        """
-        artist, album, prop = autotag.tag_album(
-            self.items, search_ids=self.search_ids
-        )
+        artist, album, prop = autotag.tag_album(self.items, search_ids=self.search_ids)
         self.cur_artist = artist
         self.cur_album = album
         self.candidates = prop.candidates
@@ -737,8 +748,7 @@ class ImportTask(BaseImportTask):
             [i.albumartist or i.artist for i in self.items]
         )
         if freq == len(self.items) or (
-            freq > 1
-            and float(freq) / len(self.items) >= SINGLE_ARTIST_THRESH
+            freq > 1 and float(freq) / len(self.items) >= SINGLE_ARTIST_THRESH
         ):
             # Single-artist album.
             changes["albumartist"] = plur_albumartist
@@ -832,15 +842,10 @@ class ImportTask(BaseImportTask):
         self.replaced_albums = defaultdict(list)
         replaced_album_ids = set()
         for item in self.imported_items():
-            dup_items = list(
-                lib.items(dbcore.query.BytesQuery("path", item.path))
-            )
+            dup_items = list(lib.items(dbcore.query.BytesQuery("path", item.path)))
             self.replaced_items[item] = dup_items
             for dup_item in dup_items:
-                if (
-                    not dup_item.album_id
-                    or dup_item.album_id in replaced_album_ids
-                ):
+                if not dup_item.album_id or dup_item.album_id in replaced_album_ids:
                     continue
                 replaced_album = dup_item._cached_album
                 if replaced_album:
@@ -893,8 +898,7 @@ class ImportTask(BaseImportTask):
                 self.album.artpath = replaced_album.artpath
                 self.album.store()
                 log.debug(
-                    "Reimported album {}. Preserving attribute ['added']. "
-                    "Path: {}",
+                    "Reimported album {}. Preserving attribute ['added']. " "Path: {}",
                     self.album.id,
                     displayable_path(self.album.path),
                 )
@@ -1094,10 +1098,10 @@ class SentinelImportTask(ImportTask):
     def save_progress(self):
         if self.paths is None:
             # "Done" sentinel.
-            progress_reset(self.toppath)
+            ImportState().progress_reset(self.toppath)
         else:
             # "Directory progress" sentinel for singletons
-            progress_add(self.toppath, *self.paths)
+            ImportState().progress_add(self.toppath, *self.paths)

     def skip(self):
         return True
@@ -1308,9 +1312,7 @@ class ImportTaskFactory:
     def singleton(self, path):
         """Return a `SingletonImportTask` for the music file."""
         if self.session.already_imported(self.toppath, [path]):
-            log.debug(
-                "Skipping previously-imported path: {0}", displayable_path(path)
-            )
+            log.debug("Skipping previously-imported path: {0}", displayable_path(path))
             self.skipped += 1
             return None
@@ -1333,9 +1335,7 @@ class ImportTaskFactory:
         dirs = list({os.path.dirname(p) for p in paths})
         if self.session.already_imported(self.toppath, dirs):
-            log.debug(
-                "Skipping previously-imported path: {0}", displayable_path(dirs)
-            )
+            log.debug("Skipping previously-imported path: {0}", displayable_path(dirs))
             self.skipped += 1
             return None
@@ -1364,8 +1364,7 @@ class ImportTaskFactory:
         if not (self.session.config["move"] or self.session.config["copy"]):
             log.warning(
-                "Archive importing requires either "
-                "'copy' or 'move' to be enabled."
+                "Archive importing requires either " "'copy' or 'move' to be enabled."
             )
             return
@@ -1578,9 +1577,7 @@ def resolve_duplicates(session, task):
     if task.choice_flag in (action.ASIS, action.APPLY, action.RETAG):
         found_duplicates = task.find_duplicates(session.lib)
         if found_duplicates:
-            log.debug(
-                "found duplicates: {}".format([o.id for o in found_duplicates])
-            )
+            log.debug("found duplicates: {}".format([o.id for o in found_duplicates]))

             # Get the default action to follow from config.
             duplicate_action = config["import"]["duplicate_action"].as_choice(