diff --git a/beets/importer.py b/beets/importer.py
index b30e6399b..2bdb16669 100644
--- a/beets/importer.py
+++ b/beets/importer.py
@@ -12,11 +12,12 @@
 # The above copyright notice and this permission notice shall be
 # included in all copies or substantial portions of the Software.

-
 """Provides the basic, interface-agnostic workflow for importing and
 autotagging music files.
 """

+from __future__ import annotations
+
 import itertools
 import os
 import pickle
@@ -25,9 +26,10 @@ import shutil
 import time
 from bisect import bisect_left, insort
 from collections import defaultdict
-from contextlib import contextmanager
+from dataclasses import dataclass
 from enum import Enum
 from tempfile import mkdtemp
+from typing import Callable, Iterable, Sequence

 import mediafile

@@ -49,8 +51,7 @@ action = Enum("action", ["SKIP", "ASIS", "TRACKS", "APPLY", "ALBUMS", "RETAG"])

 QUEUE_SIZE = 128
 SINGLE_ARTIST_THRESH = 0.25
-PROGRESS_KEY = "tagprogress"
-HISTORY_KEY = "taghistory"
+
 # Usually flexible attributes are preserved (i.e., not updated) during
 # reimports. The following two lists (globally) change this behaviour for
 # certain fields. To alter these lists only when a specific plugin is in use,
@@ -73,6 +74,10 @@ REIMPORT_FRESH_FIELDS_ITEM = list(REIMPORT_FRESH_FIELDS_ALBUM)
 # Global logger.
 log = logging.getLogger("beets")

+# Kept here for now to allow an easy replacement later on,
+# once we can move to a PathLike.
+PathBytes = bytes
+

 class ImportAbortError(Exception):
     """Raised when the user aborts the tagging operation."""
@@ -80,117 +85,115 @@ class ImportAbortError(Exception):
     pass


-# Utilities.
+@dataclass
+class ImportState:
+    """The persistent state of the importer, stored in the state file.
+
+    The state file is read when the class is instantiated. To ensure
+    that changes are written back to disk, use the context manager
+    protocol.

-def _open_state():
-    """Reads the state file, returning a dictionary."""
-    try:
-        with open(config["statefile"].as_filename(), "rb") as f:
-            return pickle.load(f)
-    except Exception as exc:
-        # The `pickle` module can emit all sorts of exceptions during
-        # unpickling, including ImportError. We use a catch-all
-        # exception to avoid enumerating them all (the docs don't even have a
-        # full list!).
-        log.debug("state file could not be read: {0}", exc)
-        return {}
+    Tagprogress allows long tagging tasks to be resumed when they pause
+    (or crash). Taghistory is the "incremental" import log: it keeps
+    track of all directories that were ever imported, which allows the
+    importer to only import new stuff.

-def _save_state(state):
-    """Writes the state dictionary out to disk."""
-    try:
-        with open(config["statefile"].as_filename(), "wb") as f:
-            pickle.dump(state, f)
-    except OSError as exc:
-        log.error("state file could not be written: {0}", exc)
+    Usage
+    -----
+    ```
+    # Readonly
+    progress = ImportState().tagprogress

-
-# Utilities for reading and writing the beets progress file, which
-# allows long tagging tasks to be resumed when they pause (or crash).
-
-
-def progress_read():
-    state = _open_state()
-    return state.setdefault(PROGRESS_KEY, {})
-
-
-@contextmanager
-def progress_write():
-    state = _open_state()
-    progress = state.setdefault(PROGRESS_KEY, {})
-    yield progress
-    _save_state(state)
-
-
-def progress_add(toppath, *paths):
-    """Record that the files under all of the `paths` have been imported
-    under `toppath`.
+    # Read and write
+    with ImportState() as state:
+        state.tagprogress[toppath] = [path]
+    ```
+    """
-    with progress_write() as state:
-        imported = state.setdefault(toppath, [])
-        for path in paths:
-            # Normally `progress_add` will be called with the path
-            # argument increasing. This is because of the ordering in
-            # `albums_in_dir`. We take advantage of that to make the
-            # code faster
-            if imported and imported[len(imported) - 1] <= path:
-                imported.append(path)
-            else:
-                insort(imported, path)
+    tagprogress: dict[PathBytes, list[PathBytes]]
+    taghistory: set[tuple[PathBytes, ...]]
+    path: PathBytes

-def progress_element(toppath, path):
-    """Return whether `path` has been imported in `toppath`."""
-    state = progress_read()
-    if toppath not in state:
-        return False
-    imported = state[toppath]
-    i = bisect_left(imported, path)
-    return i != len(imported) and imported[i] == path
+    def __init__(self, readonly: bool = False, path: PathBytes | None = None):
+        self.path = path or os.fsencode(config["statefile"].as_filename())
+        self.tagprogress = {}
+        self.taghistory = set()
+        self._open()

+    def __enter__(self):
+        return self

-def has_progress(toppath):
-    """Return `True` if there exist paths that have already been
-    imported under `toppath`.
-    """
-    state = progress_read()
-    return toppath in state
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self._save()

+    def _open(self):
+        try:
+            with open(self.path, "rb") as f:
+                state = pickle.load(f)
+                # Read the states
+                self.tagprogress = state.get("tagprogress", {})
+                self.taghistory = state.get("taghistory", set())
+        except Exception as exc:
+            # The `pickle` module can emit all sorts of exceptions during
+            # unpickling, including ImportError. We use a catch-all
+            # exception to avoid enumerating them all (the docs don't even have a
+            # full list!).
+            log.debug("state file could not be read: {0}", exc)

-def progress_reset(toppath):
-    with progress_write() as state:
-        if toppath in state:
-            del state[toppath]
+    def _save(self):
+        try:
+            with open(self.path, "wb") as f:
+                pickle.dump(
+                    {
+                        "tagprogress": self.tagprogress,
+                        "taghistory": self.taghistory,
+                    },
+                    f,
+                )
+        except OSError as exc:
+            log.error("state file could not be written: {0}", exc)

+    # -------------------------------- Tagprogress ------------------------------- #

-# Similarly, utilities for manipulating the "incremental" import log.
-# This keeps track of all directories that were ever imported, which
-# allows the importer to only import new stuff.
+    def progress_add(self, toppath: PathBytes, *paths: PathBytes):
+        """Record that the files under all of the `paths` have been imported
+        under `toppath`.
+        """
+        with self as state:
+            imported = state.tagprogress.setdefault(toppath, [])
+            for path in paths:
+                # `progress_add` is normally called with increasing path
+                # arguments (because of the ordering in `albums_in_dir`);
+                # take advantage of that to skip the bisect when possible.
+                if imported and imported[-1] <= path:
+                    imported.append(path)
+                else:
+                    insort(imported, path)

+    def progress_has_element(self, toppath: PathBytes, path: PathBytes) -> bool:
+        """Return whether `path` has been imported in `toppath`."""
+        imported = self.tagprogress.get(toppath, [])
+        i = bisect_left(imported, path)
+        return i != len(imported) and imported[i] == path

-def history_add(paths):
-    """Indicate that the import of the album in `paths` is completed and
-    should not be repeated in incremental imports.
-    """
-    state = _open_state()
-    if HISTORY_KEY not in state:
-        state[HISTORY_KEY] = set()
+    def progress_has(self, toppath: PathBytes) -> bool:
+        """Return `True` if there exist paths that have already been
+        imported under `toppath`.
+        """
+        return toppath in self.tagprogress

-    state[HISTORY_KEY].add(tuple(paths))
+    def progress_reset(self, toppath: PathBytes | None):
+        """Reset the progress for `toppath`."""
+        with self as state:
+            if toppath in state.tagprogress:
+                del state.tagprogress[toppath]

-    _save_state(state)
+    # -------------------------------- Taghistory -------------------------------- #

-
-def history_get():
-    """Get the set of completed path tuples in incremental imports."""
-    state = _open_state()
-    if HISTORY_KEY not in state:
-        return set()
-    return state[HISTORY_KEY]
-
-
-# Abstract session class.
+    def history_add(self, paths: list[PathBytes]):
+        """Indicate that the import of the album in `paths` is completed
+        and should not be repeated in incremental imports.
+        """
+        with self as state:
+            state.taghistory.add(tuple(paths))


 class ImportSession:
@@ -198,24 +201,46 @@ class ImportSession:
     communicate with the user or otherwise make decisions.
     """

-    def __init__(self, lib, loghandler, paths, query):
-        """Create a session. `lib` is a Library object. `loghandler` is a
-        logging.Handler. Either `paths` or `query` is non-null and indicates
-        the source of files to be imported.
+    logger: logging.Logger
+    paths: list[PathBytes]
+    lib: library.Library
+
+    _is_resuming: dict[bytes, bool]
+    _merged_items: set[PathBytes]
+    _merged_dirs: set[PathBytes]
+
+    def __init__(
+        self,
+        lib: library.Library,
+        loghandler: logging.Handler | None,
+        paths: Sequence[PathBytes] | None,
+        query: dbcore.Query | None,
+    ):
+        """Create a session.
+
+        Parameters
+        ----------
+        lib : library.Library
+            The library instance to which items will be imported.
+        loghandler : logging.Handler or None
+            A logging handler to use for the session's logger. If None, a
+            NullHandler will be used.
+        paths : Sequence[PathBytes] or None
+            The paths to be imported.
+        query : dbcore.Query or None
+            A query to filter items for import.
         """
         self.lib = lib
         self.logger = self._setup_logging(loghandler)
-        self.paths = paths
        self.query = query
         self._is_resuming = {}
         self._merged_items = set()
         self._merged_dirs = set()

         # Normalize the paths.
-        if self.paths:
-            self.paths = list(map(normpath, self.paths))
+        self.paths = list(map(normpath, paths or []))

-    def _setup_logging(self, loghandler):
+    def _setup_logging(self, loghandler: logging.Handler | None):
         logger = logging.getLogger(__name__)
         logger.propagate = False
         if not loghandler:
@@ -275,13 +300,13 @@ class ImportSession:

         self.want_resume = config["resume"].as_choice([True, False, "ask"])

-    def tag_log(self, status, paths):
+    def tag_log(self, status, paths: Sequence[PathBytes]):
         """Log a message about a given album to the importer log. The status
         should reflect the reason the album couldn't be tagged.
         """
         self.logger.info("{0} {1}", status, displayable_path(paths))

-    def log_choice(self, task, duplicate=False):
+    def log_choice(self, task: ImportTask, duplicate=False):
         """Logs the task's current choice if it should be logged. If
         ``duplicate``, then this is a secondary choice after a duplicate was
         detected and a decision was made.
@@ -302,16 +327,16 @@ class ImportSession: elif task.choice_flag is action.SKIP: self.tag_log("skip", paths) - def should_resume(self, path): + def should_resume(self, path: PathBytes): raise NotImplementedError - def choose_match(self, task): + def choose_match(self, task: ImportTask): raise NotImplementedError - def resolve_duplicate(self, task, found_duplicates): + def resolve_duplicate(self, task: ImportTask, found_duplicates): raise NotImplementedError - def choose_item(self, task): + def choose_item(self, task: ImportTask): raise NotImplementedError def run(self): @@ -366,12 +391,12 @@ class ImportSession: # Incremental and resumed imports - def already_imported(self, toppath, paths): + def already_imported(self, toppath: PathBytes, paths: Sequence[PathBytes]): """Returns true if the files belonging to this task have already been imported in a previous session. """ if self.is_resuming(toppath) and all( - [progress_element(toppath, p) for p in paths] + [ImportState().progress_has_element(toppath, p) for p in paths] ): return True if self.config["incremental"] and tuple(paths) in self.history_dirs: @@ -379,13 +404,16 @@ class ImportSession: return False + _history_dirs = None + @property - def history_dirs(self): - if not hasattr(self, "_history_dirs"): - self._history_dirs = history_get() + def history_dirs(self) -> set[tuple[PathBytes, ...]]: + # FIXME: This could be simplified to a cached property + if self._history_dirs is None: + self._history_dirs = ImportState().taghistory return self._history_dirs - def already_merged(self, paths): + def already_merged(self, paths: Sequence[PathBytes]): """Returns true if all the paths being imported were part of a merge during previous tasks. """ @@ -394,7 +422,7 @@ class ImportSession: return False return True - def mark_merged(self, paths): + def mark_merged(self, paths: Sequence[PathBytes]): """Mark paths and directories as merged for future reimport tasks.""" self._merged_items.update(paths) dirs = { @@ -403,20 +431,20 @@ class ImportSession: } self._merged_dirs.update(dirs) - def is_resuming(self, toppath): + def is_resuming(self, toppath: PathBytes): """Return `True` if user wants to resume import of this path. You have to call `ask_resume` first to determine the return value. """ return self._is_resuming.get(toppath, False) - def ask_resume(self, toppath): + def ask_resume(self, toppath: PathBytes): """If import of `toppath` was aborted in an earlier session, ask user if they want to resume the import. Determines the return value of `is_resuming(toppath)`. """ - if self.want_resume and has_progress(toppath): + if self.want_resume and ImportState().progress_has(toppath): # Either accept immediately or prompt for input to decide. if self.want_resume is True or self.should_resume(toppath): log.warning( @@ -426,7 +454,7 @@ class ImportSession: self._is_resuming[toppath] = True else: # Clear progress; we're starting from the top. - progress_reset(toppath) + ImportState().progress_reset(toppath) # The importer task class. @@ -438,7 +466,16 @@ class BaseImportTask: Tasks flow through the importer pipeline. Each stage can update them.""" - def __init__(self, toppath, paths, items): + toppath: PathBytes | None + paths: list[PathBytes] + items: list[library.Item] + + def __init__( + self, + toppath: PathBytes | None, + paths: Iterable[PathBytes] | None, + items: Iterable[library.Item] | None, + ): """Create a task. 
The primary fields that define a task are:

        * `toppath`: The user-specified base directory that contains the
@@ -456,8 +493,8 @@ class BaseImportTask:
         These fields should not change after initialization.
         """
         self.toppath = toppath
-        self.paths = paths
-        self.items = items
+        self.paths = list(paths) if paths is not None else []
+        self.items = list(items) if items is not None else []


 class ImportTask(BaseImportTask):
@@ -492,24 +529,39 @@ class ImportTask(BaseImportTask):
     system.
     """

-    def __init__(self, toppath, paths, items):
+    choice_flag: action | None = None
+    match: autotag.AlbumMatch | autotag.TrackMatch | None = None
+
+    # Current metadata of the task's items, before any tagging is applied
+    cur_album: str | None = None
+    cur_artist: str | None = None
+    candidates: Sequence[autotag.AlbumMatch | autotag.TrackMatch] = []
+
+    def __init__(
+        self,
+        toppath: PathBytes | None,
+        paths: Iterable[PathBytes] | None,
+        items: Iterable[library.Item] | None,
+    ):
         super().__init__(toppath, paths, items)
-        self.choice_flag = None
-        self.cur_album = None
-        self.cur_artist = None
-        self.candidates = []
         self.rec = None
         self.should_remove_duplicates = False
         self.should_merge_duplicates = False
         self.is_album = True
         self.search_ids = []  # user-supplied candidate IDs.

-    def set_choice(self, choice):
+    def set_choice(
+        self, choice: action | autotag.AlbumMatch | autotag.TrackMatch
+    ):
         """Given an AlbumMatch or TrackMatch object or an action constant,
         indicates that an action has been selected for this task.
+
+        `AlbumMatch` and `TrackMatch` are implemented as tuples, so we
+        can't use `isinstance` to check for them.
         """
         # Not part of the task structure:
         assert choice != action.APPLY  # Only used internally.
+
         if choice in (
             action.SKIP,
             action.ASIS,
@@ -517,23 +569,23 @@ class ImportTask(BaseImportTask):
             action.ALBUMS,
             action.RETAG,
         ):
-            self.choice_flag = choice
+            # TODO: redesign to tighten the type
+            self.choice_flag = choice  # type: ignore[assignment]
             self.match = None
         else:
             self.choice_flag = action.APPLY  # Implicit choice.
-            self.match = choice
+            self.match = choice  # type: ignore[assignment]

     def save_progress(self):
         """Updates the progress state to indicate that this album has
         finished.
         """
         if self.toppath:
-            progress_add(self.toppath, *self.paths)
+            ImportState().progress_add(self.toppath, *self.paths)

     def save_history(self):
         """Save the directory in the history for incremental imports."""
-        if self.paths:
-            history_add(self.paths)
+        ImportState().history_add(self.paths)

     # Logical decisions.
@@ -556,7 +608,7 @@ class ImportTask(BaseImportTask): if self.choice_flag in (action.ASIS, action.RETAG): likelies, consensus = autotag.current_metadata(self.items) return likelies - elif self.choice_flag is action.APPLY: + elif self.choice_flag is action.APPLY and self.match: return self.match.info.copy() assert False @@ -568,7 +620,9 @@ class ImportTask(BaseImportTask): """ if self.choice_flag in (action.ASIS, action.RETAG): return list(self.items) - elif self.choice_flag == action.APPLY: + elif self.choice_flag == action.APPLY and isinstance( + self.match, autotag.AlbumMatch + ): return list(self.match.mapping.keys()) else: assert False @@ -581,13 +635,13 @@ class ImportTask(BaseImportTask): autotag.apply_metadata(self.match.info, self.match.mapping) - def duplicate_items(self, lib): + def duplicate_items(self, lib: library.Library): duplicate_items = [] for album in self.find_duplicates(lib): duplicate_items += album.items() return duplicate_items - def remove_duplicates(self, lib): + def remove_duplicates(self, lib: library.Library): duplicate_items = self.duplicate_items(lib) log.debug("removing {0} old duplicated items", len(duplicate_items)) for item in duplicate_items: @@ -599,7 +653,7 @@ class ImportTask(BaseImportTask): util.remove(item.path) util.prune_dirs(os.path.dirname(item.path), lib.directory) - def set_fields(self, lib): + def set_fields(self, lib: library.Library): """Sets the fields given at CLI or configuration to the specified values, for both the album and all its items. """ @@ -620,7 +674,7 @@ class ImportTask(BaseImportTask): item.store() self.album.store() - def finalize(self, session): + def finalize(self, session: ImportSession): """Save progress, clean up files, and emit plugin event.""" # Update progress. if session.want_resume: @@ -654,7 +708,7 @@ class ImportTask(BaseImportTask): for old_path in self.old_paths: # Only delete files that were actually copied. if old_path not in new_paths: - util.remove(syspath(old_path), False) + util.remove(old_path, False) self.prune(old_path) # When moving, prune empty directories containing the original files. @@ -662,10 +716,10 @@ class ImportTask(BaseImportTask): for old_path in self.old_paths: self.prune(old_path) - def _emit_imported(self, lib): + def _emit_imported(self, lib: library.Library): plugins.send("album_imported", lib=lib, album=self.album) - def handle_created(self, session): + def handle_created(self, session: ImportSession): """Send the `import_task_created` event for this task. Return a list of tasks that should continue through the pipeline. By default, this is a list containing only the task itself, but plugins can replace the task @@ -692,7 +746,7 @@ class ImportTask(BaseImportTask): self.candidates = prop.candidates self.rec = prop.recommendation - def find_duplicates(self, lib): + def find_duplicates(self, lib: library.Library): """Return a list of albums from `lib` with the same artist and album name as the task. """ @@ -706,7 +760,9 @@ class ImportTask(BaseImportTask): # Construct a query to find duplicates with this metadata. We # use a temporary Album object to generate any computed fields. tmp_album = library.Album(lib, **info) - keys = config["import"]["duplicate_keys"]["album"].as_str_seq() + keys: list[str] = config["import"]["duplicate_keys"][ + "album" + ].as_str_seq() dup_query = tmp_album.duplicates_query(keys) # Don't count albums with the same files as duplicates. 
@@ -764,19 +820,25 @@ class ImportTask(BaseImportTask): for item in self.items: item.update(changes) - def manipulate_files(self, operation=None, write=False, session=None): + def manipulate_files( + self, + session: ImportSession, + operation: MoveOperation | None = None, + write=False, + ): """Copy, move, link, hardlink or reflink (depending on `operation`) the files as well as write metadata. `operation` should be an instance of `util.MoveOperation`. If `write` is `True` metadata is written to the files. + # TODO: Introduce a MoveOperation.NONE or SKIP """ items = self.imported_items() # Save the original paths of all items for deletion and pruning # in the next step (finalization). - self.old_paths = [item.path for item in items] + self.old_paths: list[PathBytes] = [item.path for item in items] for item in items: if operation is not None: # In copy and link modes, treat re-imports specially: @@ -806,7 +868,7 @@ class ImportTask(BaseImportTask): plugins.send("import_task_files", session=session, task=self) - def add(self, lib): + def add(self, lib: library.Library): """Add the items as an album to the library and remove replaced items.""" self.align_album_level_fields() with lib.transaction(): @@ -814,7 +876,9 @@ class ImportTask(BaseImportTask): self.remove_replaced(lib) self.album = lib.add_album(self.imported_items()) - if self.choice_flag == action.APPLY: + if self.choice_flag == action.APPLY and isinstance( + self.match, autotag.AlbumMatch + ): # Copy album flexible fields to the DB # TODO: change the flow so we create the `Album` object earlier, # and we can move this into `self.apply_metadata`, just like @@ -824,12 +888,12 @@ class ImportTask(BaseImportTask): self.reimport_metadata(lib) - def record_replaced(self, lib): + def record_replaced(self, lib: library.Library): """Records the replaced items and albums in the `replaced_items` and `replaced_albums` dictionaries. """ self.replaced_items = defaultdict(list) - self.replaced_albums = defaultdict(list) + self.replaced_albums: dict[PathBytes, library.Album] = defaultdict() replaced_album_ids = set() for item in self.imported_items(): dup_items = list( @@ -847,7 +911,7 @@ class ImportTask(BaseImportTask): replaced_album_ids.add(dup_item.album_id) self.replaced_albums[replaced_album.path] = replaced_album - def reimport_metadata(self, lib): + def reimport_metadata(self, lib: library.Library): """For reimports, preserves metadata for reimported items and albums. """ @@ -980,7 +1044,7 @@ class ImportTask(BaseImportTask): class SingletonImportTask(ImportTask): """ImportTask for a single track that is not associated to an album.""" - def __init__(self, toppath, item): + def __init__(self, toppath: PathBytes | None, item: library.Item): super().__init__(toppath, [item.path], [item]) self.item = item self.is_album = False @@ -1022,7 +1086,9 @@ class SingletonImportTask(ImportTask): # Query for existing items using the same metadata. We use a # temporary `Item` object to generate any computed fields. 
tmp_item = library.Item(lib, **info) - keys = config["import"]["duplicate_keys"]["item"].as_str_seq() + keys: list[str] = config["import"]["duplicate_keys"][ + "item" + ].as_str_seq() dup_query = tmp_item.duplicates_query(keys) found_items = [] @@ -1044,7 +1110,7 @@ class SingletonImportTask(ImportTask): def infer_album_fields(self): raise NotImplementedError - def choose_match(self, session): + def choose_match(self, session: ImportSession): """Ask the session which match should apply and apply it.""" choice = session.choose_item(self) self.set_choice(choice) @@ -1092,23 +1158,24 @@ class SentinelImportTask(ImportTask): pass def save_progress(self): - if self.paths is None: + if not self.paths: # "Done" sentinel. - progress_reset(self.toppath) - else: + ImportState().progress_reset(self.toppath) + elif self.toppath: # "Directory progress" sentinel for singletons - progress_add(self.toppath, *self.paths) + super().save_progress() - def skip(self): + @property + def skip(self) -> bool: return True def set_choice(self, choice): raise NotImplementedError - def cleanup(self, **kwargs): + def cleanup(self, copy=False, delete=False, move=False): pass - def _emit_imported(self, session): + def _emit_imported(self, lib): pass @@ -1152,7 +1219,7 @@ class ArchiveImportTask(SentinelImportTask): implements the same interface as `tarfile.TarFile`. """ if not hasattr(cls, "_handlers"): - cls._handlers = [] + cls._handlers: list[tuple[Callable, ...]] = [] from zipfile import ZipFile, is_zipfile cls._handlers.append((is_zipfile, ZipFile)) @@ -1174,9 +1241,9 @@ class ArchiveImportTask(SentinelImportTask): return cls._handlers - def cleanup(self, **kwargs): + def cleanup(self, copy=False, delete=False, move=False): """Removes the temporary directory the archive was extracted to.""" - if self.extracted: + if self.extracted and self.toppath: log.debug( "Removing extracted directory: {0}", displayable_path(self.toppath), @@ -1187,10 +1254,13 @@ class ArchiveImportTask(SentinelImportTask): """Extracts the archive to a temporary directory and sets `toppath` to that directory. """ + assert self.toppath is not None, "toppath must be set" + for path_test, handler_class in self.handlers(): if path_test(os.fsdecode(self.toppath)): break - + else: + raise ValueError(f"No handler found for archive: {self.toppath}") extract_to = mkdtemp() archive = handler_class(os.fsdecode(self.toppath), mode="r") try: @@ -1219,7 +1289,7 @@ class ImportTaskFactory: indicated by a path. """ - def __init__(self, toppath, session): + def __init__(self, toppath: PathBytes, session: ImportSession): """Create a new task factory. `toppath` is the user-specified path to search for music to @@ -1246,6 +1316,7 @@ class ImportTaskFactory: extracted data. """ # Check whether this is an archive. + archive_task: ArchiveImportTask | None = None if self.is_archive: archive_task = self.unarchive() if not archive_task: @@ -1267,12 +1338,9 @@ class ImportTaskFactory: # it is finished. This is usually just a SentinelImportTask, but # for archive imports, send the archive task instead (to remove # the extracted directory). - if self.is_archive: - yield archive_task - else: - yield self.sentinel() + yield archive_task or self.sentinel() - def _create(self, task): + def _create(self, task: ImportTask | None): """Handle a new task to be emitted by the factory. 
Emit the `import_task_created` event and increment the @@ -1305,7 +1373,7 @@ class ImportTaskFactory: for dirs, paths in albums_in_dir(self.toppath): yield dirs, paths - def singleton(self, path): + def singleton(self, path: PathBytes): """Return a `SingletonImportTask` for the music file.""" if self.session.already_imported(self.toppath, [path]): log.debug( @@ -1320,14 +1388,12 @@ class ImportTaskFactory: else: return None - def album(self, paths, dirs=None): + def album(self, paths: Iterable[PathBytes], dirs=None): """Return a `ImportTask` with all media files from paths. `dirs` is a list of parent directories used to record already imported albums. """ - if not paths: - return None if dirs is None: dirs = list({os.path.dirname(p) for p in paths}) @@ -1339,15 +1405,16 @@ class ImportTaskFactory: self.skipped += 1 return None - items = map(self.read_item, paths) - items = [item for item in items if item] + items: list[library.Item] = [ + item for item in map(self.read_item, paths) if item + ] - if items: + if len(items) > 0: return ImportTask(self.toppath, dirs, items) else: return None - def sentinel(self, paths=None): + def sentinel(self, paths: Iterable[PathBytes] | None = None): """Return a `SentinelImportTask` indicating the end of a top-level directory import. """ @@ -1382,7 +1449,7 @@ class ImportTaskFactory: log.debug("Archive extracted to: {0}", self.toppath) return archive_task - def read_item(self, path): + def read_item(self, path: PathBytes): """Return an `Item` read from the path. If an item cannot be read, return `None` instead and log an @@ -1425,12 +1492,13 @@ def _extend_pipeline(tasks, *stages): # Full-album pipeline stages. -def read_tasks(session): +def read_tasks(session: ImportSession): """A generator yielding all the albums (as ImportTask objects) found in the user-specified list of paths. In the case of a singleton import, yields single-item tasks instead. """ skipped = 0 + for toppath in session.paths: # Check whether we need to resume the import. session.ask_resume(toppath) @@ -1448,7 +1516,7 @@ def read_tasks(session): log.info("Skipped {0} paths.", skipped) -def query_tasks(session): +def query_tasks(session: ImportSession): """A generator that works as a drop-in-replacement for read_tasks. Instead of finding files from the filesystem, a query is used to match items from the library. @@ -1478,7 +1546,7 @@ def query_tasks(session): @pipeline.mutator_stage -def lookup_candidates(session, task): +def lookup_candidates(session: ImportSession, task: ImportTask): """A coroutine for performing the initial MusicBrainz lookup for an album. It accepts lists of Items and yields (items, cur_artist, cur_album, candidates, rec) tuples. If no match @@ -1500,7 +1568,7 @@ def lookup_candidates(session, task): @pipeline.stage -def user_query(session, task): +def user_query(session: ImportSession, task: ImportTask): """A coroutine for interfacing with the user about the tagging process. @@ -1571,7 +1639,7 @@ def user_query(session, task): return task -def resolve_duplicates(session, task): +def resolve_duplicates(session: ImportSession, task: ImportTask): """Check if a task conflicts with items or albums already imported and ask the session to resolve this. """ @@ -1614,7 +1682,7 @@ def resolve_duplicates(session, task): @pipeline.mutator_stage -def import_asis(session, task): +def import_asis(session: ImportSession, task: ImportTask): """Select the `action.ASIS` choice for all tasks. 
This stage replaces the initial_lookup and user_query stages @@ -1628,7 +1696,7 @@ def import_asis(session, task): apply_choice(session, task) -def apply_choice(session, task): +def apply_choice(session: ImportSession, task: ImportTask): """Apply the task's choice to the Album or Item it contains and add it to the library. """ @@ -1652,7 +1720,11 @@ def apply_choice(session, task): @pipeline.mutator_stage -def plugin_stage(session, func, task): +def plugin_stage( + session: ImportSession, + func: Callable[[ImportSession, ImportTask], None], + task: ImportTask, +): """A coroutine (pipeline stage) that calls the given function with each non-skipped import task. These stages occur between applying metadata changes and moving/copying/writing files. @@ -1669,7 +1741,7 @@ def plugin_stage(session, func, task): @pipeline.stage -def manipulate_files(session, task): +def manipulate_files(session: ImportSession, task: ImportTask): """A coroutine (pipeline stage) that performs necessary file manipulations *after* items have been added to the library and finalizes each task. @@ -1694,9 +1766,9 @@ def manipulate_files(session, task): operation = None task.manipulate_files( - operation, - write=session.config["write"], session=session, + operation=operation, + write=session.config["write"], ) # Progress, cleanup, and event. @@ -1704,7 +1776,7 @@ def manipulate_files(session, task): @pipeline.stage -def log_files(session, task): +def log_files(session: ImportSession, task: ImportTask): """A coroutine (pipeline stage) to log each file to be imported.""" if isinstance(task, SingletonImportTask): log.info("Singleton: {0}", displayable_path(task.item["path"])) @@ -1714,7 +1786,7 @@ def log_files(session, task): log.info(" {0}", displayable_path(item["path"])) -def group_albums(session): +def group_albums(session: ImportSession): """A pipeline stage that groups the items of each task into albums using their metadata. @@ -1731,10 +1803,10 @@ def group_albums(session): if task.skip: continue tasks = [] - sorted_items = sorted(task.items, key=group) + sorted_items: list[library.Item] = sorted(task.items, key=group) for _, items in itertools.groupby(sorted_items, group): - items = list(items) - task = ImportTask(task.toppath, [i.path for i in items], items) + l_items = list(items) + task = ImportTask(task.toppath, [i.path for i in l_items], l_items) tasks += task.handle_created(session) tasks.append(SentinelImportTask(task.toppath, task.paths)) @@ -1753,15 +1825,15 @@ def is_subdir_of_any_in_list(path, dirs): return any(d in ancestors for d in dirs) -def albums_in_dir(path): +def albums_in_dir(path: PathBytes): """Recursively searches the given directory and returns an iterable of (paths, items) where paths is a list of directories and items is a list of Items that is probably an album. Specifically, any folder containing any media files is an album. 
""" collapse_pat = collapse_paths = collapse_items = None - ignore = config["ignore"].as_str_seq() - ignore_hidden = config["ignore_hidden"].get(bool) + ignore: list[str] = config["ignore"].as_str_seq() + ignore_hidden: bool = config["ignore_hidden"].get(bool) for root, dirs, files in sorted_walk( path, ignore=ignore, ignore_hidden=ignore_hidden, logger=log diff --git a/beets/util/__init__.py b/beets/util/__init__.py index ac4e3bc3c..b882ed626 100644 --- a/beets/util/__init__.py +++ b/beets/util/__init__.py @@ -198,8 +198,8 @@ def ancestry(path: AnyStr) -> list[AnyStr]: def sorted_walk( - path: AnyStr, - ignore: Sequence[bytes] = (), + path: PathLike, + ignore: Sequence[PathLike] = (), ignore_hidden: bool = False, logger: Logger | None = None, ) -> Iterator[tuple[bytes, Sequence[bytes], Sequence[bytes]]]: @@ -210,7 +210,9 @@ def sorted_walk( """ # Make sure the paths aren't Unicode strings. bytes_path = bytestring_path(path) - ignore = [bytestring_path(i) for i in ignore] + ignore_bytes = [ # rename prevents mypy variable shadowing issue + bytestring_path(i) for i in ignore + ] # Get all the directories and files at this level. try: @@ -230,7 +232,7 @@ def sorted_walk( # Skip ignored filenames. skip = False - for pat in ignore: + for pat in ignore_bytes: if fnmatch.fnmatch(base, pat): if logger: logger.debug( @@ -257,7 +259,7 @@ def sorted_walk( # Recurse into directories. for base in dirs: cur = os.path.join(bytes_path, base) - yield from sorted_walk(cur, ignore, ignore_hidden, logger) + yield from sorted_walk(cur, ignore_bytes, ignore_hidden, logger) def path_as_posix(path: bytes) -> bytes: @@ -297,8 +299,8 @@ def fnmatch_all(names: Sequence[bytes], patterns: Sequence[bytes]) -> bool: def prune_dirs( - path: bytes, - root: bytes | None = None, + path: PathLike, + root: PathLike | None = None, clutter: Sequence[str] = (".DS_Store", "Thumbs.db"), ): """If path is an empty directory, then remove it. Recursively remove @@ -419,12 +421,13 @@ PATH_SEP: bytes = bytestring_path(os.sep) def displayable_path( - path: BytesOrStr | tuple[BytesOrStr, ...], separator: str = "; " + path: PathLike | Iterable[PathLike], separator: str = "; " ) -> str: """Attempts to decode a bytestring path to a unicode object for the purpose of displaying it to the user. If the `path` argument is a list or a tuple, the elements are joined with `separator`. """ + if isinstance(path, (list, tuple)): return separator.join(displayable_path(p) for p in path) elif isinstance(path, str): @@ -472,7 +475,7 @@ def samefile(p1: bytes, p2: bytes) -> bool: return False -def remove(path: bytes, soft: bool = True): +def remove(path: PathLike, soft: bool = True): """Remove the file. If `soft`, then no error will be raised if the file does not exist. """ diff --git a/beets/util/pipeline.py b/beets/util/pipeline.py index d23b1bd10..8f7be8194 100644 --- a/beets/util/pipeline.py +++ b/beets/util/pipeline.py @@ -31,9 +31,14 @@ To do so, pass an iterable of coroutines to the Pipeline constructor in place of any single coroutine. """ +from __future__ import annotations + import queue import sys from threading import Lock, Thread +from typing import Callable, Generator + +from typing_extensions import TypeVar, TypeVarTuple, Unpack BUBBLE = "__PIPELINE_BUBBLE__" POISON = "__PIPELINE_POISON__" @@ -149,7 +154,22 @@ def multiple(messages): return MultiMessage(messages) -def stage(func): +A = TypeVarTuple("A") # Arguments of a function (omitting the task) +T = TypeVar("T") # Type of the task +# Normally these are concatenated i.e. 
(*args, task)
+
+# Return type of the function (this should normally be the task, but
+# sadly we can't enforce that with the current stage functions without
+# a refactor)
+R = TypeVar("R")
+
+
+def stage(
+    func: Callable[
+        [Unpack[A], T],
+        R | None,
+    ],
+):
     """Decorate a function to become a simple stage.

     >>> @stage
@@ -163,8 +183,8 @@ def stage(func):
     [3, 4, 5]
     """

-    def coro(*args):
-        task = None
+    def coro(*args: Unpack[A]) -> Generator[R | T | None, T, None]:
+        task: R | T | None = None
         while True:
             task = yield task
             task = func(*(args + (task,)))
@@ -172,7 +192,7 @@ def stage(func):
     return coro


-def mutator_stage(func):
+def mutator_stage(func: Callable[[Unpack[A], T], R]):
     """Decorate a function that manipulates items in a coroutine to
     become a simple stage.

@@ -187,7 +207,7 @@ def mutator_stage(func):
     [{'x': True}, {'a': False, 'x': True}]
     """

-    def coro(*args):
+    def coro(*args: Unpack[A]) -> Generator[T | None, T, None]:
         task = None
         while True:
             task = yield task
diff --git a/docs/changelog.rst b/docs/changelog.rst
index feb73bc58..ecf1c01d3 100644
--- a/docs/changelog.rst
+++ b/docs/changelog.rst
@@ -86,6 +86,8 @@ Other changes:
   wrong (outdated) commit. Now the tag is created in the same workflow step
   right after committing the version update.
   :bug:`5539`
+* Added some type hints: ``ImportSession`` and ``Pipeline`` now have type
+  annotations, which should improve usability for new developers.
 * :doc:`/plugins/smartplaylist`: URL-encode additional item `fields` within
   generated EXTM3U playlists instead of JSON-encoding them.
diff --git a/test/plugins/test_lyrics.py b/test/plugins/test_lyrics.py
index c6d48c3bd..a3c640109 100644
--- a/test/plugins/test_lyrics.py
+++ b/test/plugins/test_lyrics.py
@@ -14,6 +14,8 @@

 """Tests for the 'lyrics' plugin."""

+import importlib.util
+import os
 import re
 from functools import partial
 from http import HTTPStatus
@@ -26,6 +28,11 @@ from beetsplug import lyrics

 from .lyrics_pages import LyricsPage, lyrics_pages

+github_ci = os.environ.get("GITHUB_ACTIONS") == "true"
+if not github_ci and not importlib.util.find_spec("langdetect"):
+    pytest.skip("langdetect isn't available", allow_module_level=True)
+
+
 PHRASE_BY_TITLE = {
     "Lady Madonna": "friday night arrives without a suitcase",
     "Jazz'n'blues": "as i check my balance i kiss the screen",
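
For reviewers, a minimal usage sketch (not part of the patch) of the new `ImportState` API that replaces the removed module-level helpers. The `toppath` and `track` values are hypothetical byte paths; constructing `ImportState()` reads the state file named in the beets configuration.

```python
# Sketch only: exercises the ImportState API added in beets/importer.py above.
# ImportState() reads the configured state file on construction; writes are
# persisted when the context manager exits (the helper methods wrap the same
# `with self` pattern and therefore save on each call).
import os

from beets.importer import ImportState

toppath = os.fsencode("/music/inbox")        # hypothetical paths
track = os.fsencode("/music/inbox/a/1.mp3")

# Read-only access: just read the attributes of a fresh instance.
state = ImportState()
resumable = state.progress_has(toppath)
seen = state.progress_has_element(toppath, track)

# Read and write: mutate attributes inside the context manager so that
# __exit__ calls _save() and the pickled state file is rewritten.
with ImportState() as state:
    state.taghistory.add((toppath,))

# The helper methods save for you; each call re-opens the state file.
ImportState().progress_add(toppath, track)  # record `track` as imported
ImportState().progress_reset(toppath)       # drop progress for `toppath`
```

One trade-off visible here: every `ImportState()` construction re-reads the state file and every context-manager exit rewrites it, mirroring what the removed `_open_state`/`_save_state` pair did on each helper call.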