Fix most types in beets.util.__init__ (#5223)

This PR is part 1 of the work on #5215, which fixes typing issues in the
`beets.util.__init__` module.

It addresses the simple-to-fix issues, which account for most of the issues in this module.
This commit is contained in:
Šarūnas Nejus 2024-09-18 13:40:11 +01:00 committed by GitHub
commit d3c62968d5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -31,40 +31,40 @@ from collections import Counter
from contextlib import suppress
from enum import Enum
from importlib import import_module
from logging import Logger
from multiprocessing.pool import ThreadPool
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
AnyStr,
Callable,
Generator,
Iterable,
List,
MutableSequence,
Iterator,
NamedTuple,
Optional,
Pattern,
Sequence,
Tuple,
TypeVar,
Union,
)
from unidecode import unidecode
from beets.util import hidden
if TYPE_CHECKING:
from logging import Logger
if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias
from unidecode import unidecode
from beets.util import hidden
MAX_FILENAME_LENGTH = 200
WINDOWS_MAGIC_PREFIX = "\\\\?\\"
T = TypeVar("T")
Bytes_or_String: TypeAlias = Union[str, bytes]
PathLike = Union[str, bytes, Path]
BytesOrStr = Union[str, bytes]
PathLike = Union[BytesOrStr, Path]
Replacements: TypeAlias = "Sequence[tuple[Pattern[str], str]]"
class HumanReadableException(Exception):
@ -168,12 +168,12 @@ def normpath(path: bytes) -> bytes:
"""Provide the canonical form of the path suitable for storing in
the database.
"""
path = syspath(path, prefix=False)
path = os.path.normpath(os.path.abspath(os.path.expanduser(path)))
return bytestring_path(path)
str_path = syspath(path, prefix=False)
str_path = os.path.normpath(os.path.abspath(os.path.expanduser(str_path)))
return bytestring_path(str_path)
def ancestry(path: bytes) -> List[str]:
def ancestry(path: AnyStr) -> list[AnyStr]:
"""Return a list consisting of path's parent directory, its
grandparent, and so on. For instance:
@ -182,7 +182,7 @@ def ancestry(path: bytes) -> List[str]:
The argument should *not* be the result of a call to `syspath`.
"""
out = []
out: list[AnyStr] = []
last_path = None
while path:
path = os.path.dirname(path)
@ -199,34 +199,34 @@ def ancestry(path: bytes) -> List[str]:
def sorted_walk(
path: AnyStr,
ignore: Sequence = (),
ignore: Sequence[bytes] = (),
ignore_hidden: bool = False,
logger: Optional[Logger] = None,
) -> Generator[Tuple, None, None]:
logger: Logger | None = None,
) -> Iterator[tuple[bytes, Sequence[bytes], Sequence[bytes]]]:
"""Like `os.walk`, but yields things in case-insensitive sorted,
breadth-first order. Directory and file names matching any glob
pattern in `ignore` are skipped. If `logger` is provided, then
warning messages are logged there when a directory cannot be listed.
"""
# Make sure the paths aren't Unicode strings.
path = bytestring_path(path)
bytes_path = bytestring_path(path)
ignore = [bytestring_path(i) for i in ignore]
# Get all the directories and files at this level.
try:
contents = os.listdir(syspath(path))
contents = os.listdir(syspath(bytes_path))
except OSError as exc:
if logger:
logger.warning(
"could not list directory {}: {}".format(
displayable_path(path), exc.strerror
displayable_path(bytes_path), exc.strerror
)
)
return
dirs = []
files = []
for base in contents:
base = bytestring_path(base)
for str_base in contents:
base = bytestring_path(str_base)
# Skip ignored filenames.
skip = False
@ -234,7 +234,7 @@ def sorted_walk(
if fnmatch.fnmatch(base, pat):
if logger:
logger.debug(
"ignoring {} due to ignore rule {}".format(base, pat)
"ignoring '{}' due to ignore rule '{}'", base, pat
)
skip = True
break
@ -242,7 +242,7 @@ def sorted_walk(
continue
# Add to output as either a file or a directory.
cur = os.path.join(path, base)
cur = os.path.join(bytes_path, base)
if (ignore_hidden and not hidden.is_hidden(cur)) or not ignore_hidden:
if os.path.isdir(syspath(cur)):
dirs.append(base)
@ -252,12 +252,11 @@ def sorted_walk(
# Sort lists (case-insensitive) and yield the current level.
dirs.sort(key=bytes.lower)
files.sort(key=bytes.lower)
yield (path, dirs, files)
yield (bytes_path, dirs, files)
# Recurse into directories.
for base in dirs:
cur = os.path.join(path, base)
# yield from sorted_walk(...)
cur = os.path.join(bytes_path, base)
yield from sorted_walk(cur, ignore, ignore_hidden, logger)
@ -298,8 +297,8 @@ def fnmatch_all(names: Sequence[bytes], patterns: Sequence[bytes]) -> bool:
def prune_dirs(
path: str,
root: Optional[Bytes_or_String] = None,
path: bytes,
root: bytes | None = None,
clutter: Sequence[str] = (".DS_Store", "Thumbs.db"),
):
"""If path is an empty directory, then remove it. Recursively remove
@ -310,41 +309,41 @@ def prune_dirs(
(i.e., no recursive removal).
"""
path = normpath(path)
if root is not None:
root = normpath(root)
root = normpath(root) if root else None
ancestors = ancestry(path)
if root is None:
# Only remove the top directory.
ancestors = []
elif root in ancestors:
# Only remove directories below the root.
# Only remove directories below the root.
ancestors = ancestors[ancestors.index(root) + 1 :]
else:
# Remove nothing.
return
bytes_clutter = [bytestring_path(c) for c in clutter]
# Traverse upward from path.
ancestors.append(path)
ancestors.reverse()
for directory in ancestors:
directory = syspath(directory)
str_directory = syspath(directory)
if not os.path.exists(directory):
# Directory gone already.
continue
clutter: List[bytes] = [bytestring_path(c) for c in clutter]
match_paths = [bytestring_path(d) for d in os.listdir(directory)]
match_paths = [bytestring_path(d) for d in os.listdir(str_directory)]
try:
if fnmatch_all(match_paths, clutter):
if fnmatch_all(match_paths, bytes_clutter):
# Directory contains only clutter (or nothing).
shutil.rmtree(directory)
shutil.rmtree(str_directory)
else:
break
except OSError:
break
def components(path: AnyStr) -> MutableSequence[AnyStr]:
def components(path: AnyStr) -> list[AnyStr]:
"""Return a list of the path components in path. For instance:
>>> components(b'/a/b/c')
@ -420,8 +419,7 @@ PATH_SEP: bytes = bytestring_path(os.sep)
def displayable_path(
path: Union[bytes, str, Tuple[Union[bytes, str], ...]],
separator: str = "; ",
path: BytesOrStr | tuple[BytesOrStr, ...], separator: str = "; "
) -> str:
"""Attempts to decode a bytestring path to a unicode object for the
purpose of displaying it to the user. If the `path` argument is a
@ -468,20 +466,25 @@ def samefile(p1: bytes, p2: bytes) -> bool:
"""Safer equality for paths."""
if p1 == p2:
return True
return shutil._samefile(syspath(p1), syspath(p2))
with suppress(OSError):
return os.path.samefile(syspath(p1), syspath(p2))
return False
def remove(path: Optional[bytes], soft: bool = True):
def remove(path: bytes, soft: bool = True):
"""Remove the file. If `soft`, then no error will be raised if the
file does not exist.
"""
path = syspath(path)
if not path or (soft and not os.path.exists(path)):
str_path = syspath(path)
if not str_path or (soft and not os.path.exists(str_path)):
return
try:
os.remove(path)
os.remove(str_path)
except OSError as exc:
raise FilesystemError(exc, "delete", (path,), traceback.format_exc())
raise FilesystemError(
exc, "delete", (str_path,), traceback.format_exc()
)
def copy(path: bytes, dest: bytes, replace: bool = False):
@ -492,14 +495,16 @@ def copy(path: bytes, dest: bytes, replace: bool = False):
"""
if samefile(path, dest):
return
path = syspath(path)
dest = syspath(dest)
if not replace and os.path.exists(dest):
raise FilesystemError("file exists", "copy", (path, dest))
str_path = syspath(path)
str_dest = syspath(dest)
if not replace and os.path.exists(str_dest):
raise FilesystemError("file exists", "copy", (str_path, str_dest))
try:
shutil.copyfile(path, dest)
shutil.copyfile(str_path, str_dest)
except OSError as exc:
raise FilesystemError(exc, "copy", (path, dest), traceback.format_exc())
raise FilesystemError(
exc, "copy", (str_path, str_dest), traceback.format_exc()
)
def move(path: bytes, dest: bytes, replace: bool = False):
@ -534,22 +539,28 @@ def move(path: bytes, dest: bytes, replace: bool = False):
)
try:
with open(syspath(path), "rb") as f:
shutil.copyfileobj(f, tmp)
# mypy bug:
# - https://github.com/python/mypy/issues/15031
# - https://github.com/python/mypy/issues/14943
# Fix not yet released:
# - https://github.com/python/mypy/pull/14975
shutil.copyfileobj(f, tmp) # type: ignore[misc]
finally:
tmp.close()
# Move the copied file into place.
tmp_filename = tmp.name
try:
os.replace(tmp.name, syspath(dest))
tmp = None
os.replace(tmp_filename, syspath(dest))
tmp_filename = ""
os.remove(syspath(path))
except OSError as exc:
raise FilesystemError(
exc, "move", (path, dest), traceback.format_exc()
)
finally:
if tmp is not None:
os.remove(tmp)
if tmp_filename:
os.remove(tmp_filename)
def link(path: bytes, dest: bytes, replace: bool = False):
@ -673,7 +684,7 @@ def unique_path(path: bytes) -> bytes:
# Unix. They are forbidden here because they cause problems on Samba
# shares, which are sufficiently common as to cause frequent problems.
# https://msdn.microsoft.com/en-us/library/windows/desktop/aa365247.aspx
CHAR_REPLACE: List[Tuple[Pattern, str]] = [
CHAR_REPLACE = [
(re.compile(r"[\\/]"), "_"), # / and \ -- forbidden everywhere.
(re.compile(r"^\."), "_"), # Leading dot (hidden files on Unix).
(re.compile(r"[\x00-\x1f]"), ""), # Control characters.
@ -683,10 +694,7 @@ CHAR_REPLACE: List[Tuple[Pattern, str]] = [
]
def sanitize_path(
path: str,
replacements: Optional[Sequence[Sequence[Union[Pattern, str]]]] = None,
) -> str:
def sanitize_path(path: str, replacements: Replacements | None = None) -> str:
"""Takes a path (as a Unicode string) and makes sure that it is
legal. Returns a new path. Only works with fragments; won't work
reliably on Windows when a path begins with a drive letter. Path
@ -726,11 +734,11 @@ def truncate_path(path: AnyStr, length: int = MAX_FILENAME_LENGTH) -> AnyStr:
def _legalize_stage(
path: str,
replacements: Optional[Sequence[Sequence[Union[Pattern, str]]]],
replacements: Replacements | None,
length: int,
extension: str,
fragment: bool,
) -> Tuple[Bytes_or_String, bool]:
) -> tuple[BytesOrStr, bool]:
"""Perform a single round of path legalization steps
(sanitation/replacement, encoding from Unicode to bytes,
extension-appending, and truncation). Return the path (Unicode if
@ -756,11 +764,11 @@ def _legalize_stage(
def legalize_path(
path: str,
replacements: Optional[Sequence[Sequence[Union[Pattern, str]]]],
replacements: Replacements | None,
length: int,
extension: bytes,
fragment: bool,
) -> Tuple[Union[Bytes_or_String, bool]]:
) -> tuple[BytesOrStr, bool]:
"""Given a path-like Unicode string, produce a legal path. Return
the path and a flag indicating whether some replacements had to be
ignored (see below).
@ -827,7 +835,7 @@ def as_string(value: Any) -> str:
return str(value)
def plurality(objs: Sequence[T]) -> T:
def plurality(objs: Sequence[T]) -> tuple[T, int]:
"""Given a sequence of hashable objects, returns the object that
is most common in the set and its number of appearances. The
sequence must contain at least one object.
@ -838,7 +846,7 @@ def plurality(objs: Sequence[T]) -> T:
return c.most_common(1)[0]
def convert_command_args(args: List[bytes]) -> List[str]:
def convert_command_args(args: list[BytesOrStr]) -> list[str]:
"""Convert command arguments, which may either be `bytes` or `str`
objects, to uniformly surrogate-escaped strings."""
assert isinstance(args, list)
@ -857,10 +865,7 @@ class CommandOutput(NamedTuple):
stderr: bytes
def command_output(
cmd: List[Bytes_or_String],
shell: bool = False,
) -> CommandOutput:
def command_output(cmd: list[BytesOrStr], shell: bool = False) -> CommandOutput:
"""Runs the command and returns its output after it has exited.
Returns a CommandOutput. The attributes ``stdout`` and ``stderr`` contain
@ -878,7 +883,7 @@ def command_output(
This replaces `subprocess.check_output` which can have problems if lots of
output is sent to stderr.
"""
cmd = convert_command_args(cmd)
converted_cmd = convert_command_args(cmd)
devnull = subprocess.DEVNULL
@ -894,13 +899,13 @@ def command_output(
if proc.returncode:
raise subprocess.CalledProcessError(
returncode=proc.returncode,
cmd=" ".join(map(str, cmd)),
cmd=" ".join(converted_cmd),
output=stdout + stderr,
)
return CommandOutput(stdout, stderr)
def max_filename_length(path: AnyStr, limit=MAX_FILENAME_LENGTH) -> int:
def max_filename_length(path: BytesOrStr, limit=MAX_FILENAME_LENGTH) -> int:
"""Attempt to determine the maximum filename length for the
filesystem containing `path`. If the value is greater than `limit`,
then `limit` is used instead (to prevent errors when a filesystem
@ -1040,7 +1045,7 @@ def asciify_path(path: str, sep_replace: str) -> str:
# if this platform has an os.altsep, change it to os.sep.
if os.altsep:
path = path.replace(os.altsep, os.sep)
path_components: List[Bytes_or_String] = path.split(os.sep)
path_components: list[str] = path.split(os.sep)
for index, item in enumerate(path_components):
path_components[index] = unidecode(item).replace(os.sep, sep_replace)
if os.altsep:
@ -1050,7 +1055,7 @@ def asciify_path(path: str, sep_replace: str) -> str:
return os.sep.join(path_components)
def par_map(transform: Callable, items: Iterable):
def par_map(transform: Callable[[T], Any], items: Sequence[T]) -> None:
"""Apply the function `transform` to all the elements in the
iterable `items`, like `map(transform, items)` but with no return
value.