diff --git a/beetsplug/beatport.py b/beetsplug/beatport.py index 20147b5cc..72828a96a 100644 --- a/beetsplug/beatport.py +++ b/beetsplug/beatport.py @@ -14,9 +14,19 @@ """Adds Beatport release and track search support to the autotagger""" +from __future__ import annotations + import json import re from datetime import datetime, timedelta +from typing import ( + TYPE_CHECKING, + Iterable, + Iterator, + Literal, + Sequence, + overload, +) import confuse from requests_oauthlib import OAuth1Session @@ -29,7 +39,13 @@ from requests_oauthlib.oauth1_session import ( import beets import beets.ui from beets.autotag.hooks import AlbumInfo, TrackInfo -from beets.plugins import BeetsPlugin, MetadataSourcePlugin, get_distance +from beets.metadata_plugins import MetadataSourcePlugin + +if TYPE_CHECKING: + from beets.importer import ImportSession + from beets.library import Item + + from ._typing import JSONDict AUTH_ERRORS = (TokenRequestDenied, TokenMissing, VerifierMissing) USER_AGENT = f"beets/{beets.__version__} +https://beets.io/" @@ -39,20 +55,6 @@ class BeatportAPIError(Exception): pass -class BeatportObject: - def __init__(self, data): - self.beatport_id = data["id"] - self.name = str(data["name"]) - if "releaseDate" in data: - self.release_date = datetime.strptime( - data["releaseDate"], "%Y-%m-%d" - ) - if "artists" in data: - self.artists = [(x["id"], str(x["name"])) for x in data["artists"]] - if "genres" in data: - self.genres = [str(x["name"]) for x in data["genres"]] - - class BeatportClient: _api_base = "https://oauth-api.beatport.com" @@ -77,7 +79,7 @@ class BeatportClient: ) self.api.headers = {"User-Agent": USER_AGENT} - def get_authorize_url(self): + def get_authorize_url(self) -> str: """Generate the URL for the user to authorize the application. Retrieves a request token from the Beatport API and returns the @@ -99,15 +101,13 @@ class BeatportClient: self._make_url("/identity/1/oauth/authorize") ) - def get_access_token(self, auth_data): + def get_access_token(self, auth_data: str) -> tuple[str, str]: """Obtain the final access token and secret for the API. :param auth_data: URL-encoded authorization data as displayed at the authorization url (obtained via :py:meth:`get_authorize_url`) after signing in - :type auth_data: unicode - :returns: OAuth resource owner key and secret - :rtype: (unicode, unicode) tuple + :returns: OAuth resource owner key and secret as unicode """ self.api.parse_authorization_response( "https://beets.io/auth?" + auth_data @@ -117,20 +117,37 @@ class BeatportClient: ) return access_data["oauth_token"], access_data["oauth_token_secret"] - def search(self, query, release_type="release", details=True): + @overload + def search( + self, + query: str, + release_type: Literal["release"], + details: bool = True, + ) -> Iterator[BeatportRelease]: ... + + @overload + def search( + self, + query: str, + release_type: Literal["track"], + details: bool = True, + ) -> Iterator[BeatportTrack]: ... + + def search( + self, + query: str, + release_type: Literal["release", "track"], + details=True, + ) -> Iterator[BeatportRelease | BeatportTrack]: """Perform a search of the Beatport catalogue. :param query: Query string - :param release_type: Type of releases to search for, can be - 'release' or 'track' + :param release_type: Type of releases to search for. :param details: Retrieve additional information about the search results. Currently this will fetch the tracklist for releases and do nothing for tracks :returns: Search results - :rtype: generator that yields - py:class:`BeatportRelease` or - :py:class:`BeatportTrack` """ response = self._get( "catalog/3/search", @@ -140,20 +157,18 @@ class BeatportClient: ) for item in response: if release_type == "release": + release = BeatportRelease(item) if details: - release = self.get_release(item["id"]) - else: - release = BeatportRelease(item) + release.tracks = self.get_release_tracks(item["id"]) yield release elif release_type == "track": yield BeatportTrack(item) - def get_release(self, beatport_id): + def get_release(self, beatport_id: str) -> BeatportRelease | None: """Get information about a single release. :param beatport_id: Beatport ID of the release :returns: The matching release - :rtype: :py:class:`BeatportRelease` """ response = self._get("/catalog/3/releases", id=beatport_id) if response: @@ -162,35 +177,33 @@ class BeatportClient: return release return None - def get_release_tracks(self, beatport_id): + def get_release_tracks(self, beatport_id: str) -> list[BeatportTrack]: """Get all tracks for a given release. :param beatport_id: Beatport ID of the release :returns: Tracks in the matching release - :rtype: list of :py:class:`BeatportTrack` """ response = self._get( "/catalog/3/tracks", releaseId=beatport_id, perPage=100 ) return [BeatportTrack(t) for t in response] - def get_track(self, beatport_id): + def get_track(self, beatport_id: str) -> BeatportTrack: """Get information about a single track. :param beatport_id: Beatport ID of the track :returns: The matching track - :rtype: :py:class:`BeatportTrack` """ response = self._get("/catalog/3/tracks", id=beatport_id) return BeatportTrack(response[0]) - def _make_url(self, endpoint): + def _make_url(self, endpoint: str) -> str: """Get complete URL for a given API endpoint.""" if not endpoint.startswith("/"): endpoint = "/" + endpoint return self._api_base + endpoint - def _get(self, endpoint, **kwargs): + def _get(self, endpoint: str, **kwargs) -> list[JSONDict]: """Perform a GET request on a given API endpoint. Automatically extracts result data from the response and converts HTTP @@ -211,48 +224,81 @@ class BeatportClient: return response.json()["results"] -class BeatportRelease(BeatportObject): - def __str__(self): - if len(self.artists) < 4: - artist_str = ", ".join(x[1] for x in self.artists) +class BeatportObject: + beatport_id: str + name: str + + release_date: datetime | None = None + + artists: list[tuple[str, str]] | None = None + # tuple of artist id and artist name + + def __init__(self, data: JSONDict): + self.beatport_id = str(data["id"]) # given as int in the response + self.name = str(data["name"]) + if "releaseDate" in data: + self.release_date = datetime.strptime( + data["releaseDate"], "%Y-%m-%d" + ) + if "artists" in data: + self.artists = [(x["id"], str(x["name"])) for x in data["artists"]] + if "genres" in data: + self.genres = [str(x["name"]) for x in data["genres"]] + + def artists_str(self) -> str | None: + if self.artists is not None: + if len(self.artists) < 4: + artist_str = ", ".join(x[1] for x in self.artists) + else: + artist_str = "Various Artists" else: - artist_str = "Various Artists" - return "".format( - artist_str, - self.name, - self.catalog_number, - ) + artist_str = None - def __repr__(self): - return str(self).encode("utf-8") + return artist_str + + +class BeatportRelease(BeatportObject): + catalog_number: str | None + label_name: str | None + category: str | None + url: str | None + genre: str | None + + tracks: list[BeatportTrack] | None = None + + def __init__(self, data: JSONDict): + super().__init__(data) + + self.catalog_number = data.get("catalogNumber") + self.label_name = data.get("label", {}).get("name") + self.category = data.get("category") + self.genre = data.get("genre") - def __init__(self, data): - BeatportObject.__init__(self, data) - if "catalogNumber" in data: - self.catalog_number = data["catalogNumber"] - if "label" in data: - self.label_name = data["label"]["name"] - if "category" in data: - self.category = data["category"] if "slug" in data: self.url = "https://beatport.com/release/{}/{}".format( data["slug"], data["id"] ) - self.genre = data.get("genre") + + def __str__(self) -> str: + return "".format( + self.artists_str(), + self.name, + self.catalog_number, + ) class BeatportTrack(BeatportObject): - def __str__(self): - artist_str = ", ".join(x[1] for x in self.artists) - return "".format( - artist_str, self.name, self.mix_name - ) + title: str | None + mix_name: str | None + length: timedelta + url: str | None + track_number: int | None + bpm: str | None + initial_key: str | None + genre: str | None - def __repr__(self): - return str(self).encode("utf-8") - - def __init__(self, data): - BeatportObject.__init__(self, data) + def __init__(self, data: JSONDict): + super().__init__(data) if "title" in data: self.title = str(data["title"]) if "mixName" in data: @@ -279,8 +325,8 @@ class BeatportTrack(BeatportObject): self.genre = str(data["genres"][0].get("name")) -class BeatportPlugin(BeetsPlugin): - data_source = "Beatport" +class BeatportPlugin(MetadataSourcePlugin): + _client: BeatportClient | None = None def __init__(self): super().__init__() @@ -294,12 +340,19 @@ class BeatportPlugin(BeetsPlugin): ) self.config["apikey"].redact = True self.config["apisecret"].redact = True - self.client = None self.register_listener("import_begin", self.setup) - def setup(self, session=None): - c_key = self.config["apikey"].as_str() - c_secret = self.config["apisecret"].as_str() + @property + def client(self) -> BeatportClient: + if self._client is None: + raise ValueError( + "Beatport client not initialized. Call setup() first." + ) + return self._client + + def setup(self, session: ImportSession): + c_key: str = self.config["apikey"].as_str() + c_secret: str = self.config["apisecret"].as_str() # Get the OAuth token from a file or log in. try: @@ -312,9 +365,9 @@ class BeatportPlugin(BeetsPlugin): token = tokendata["token"] secret = tokendata["secret"] - self.client = BeatportClient(c_key, c_secret, token, secret) + self._client = BeatportClient(c_key, c_secret, token, secret) - def authenticate(self, c_key, c_secret): + def authenticate(self, c_key: str, c_secret: str) -> tuple[str, str]: # Get the link for the OAuth page. auth_client = BeatportClient(c_key, c_secret) try: @@ -341,44 +394,30 @@ class BeatportPlugin(BeetsPlugin): return token, secret - def _tokenfile(self): + def _tokenfile(self) -> str: """Get the path to the JSON file for storing the OAuth token.""" return self.config["tokenfile"].get(confuse.Filename(in_app_dir=True)) - def album_distance(self, items, album_info, mapping): - """Returns the Beatport source weight and the maximum source weight - for albums. - """ - return get_distance( - data_source=self.data_source, info=album_info, config=self.config - ) - - def track_distance(self, item, track_info): - """Returns the Beatport source weight and the maximum source weight - for individual tracks. - """ - return get_distance( - data_source=self.data_source, info=track_info, config=self.config - ) - - def candidates(self, items, artist, release, va_likely): - """Returns a list of AlbumInfo objects for beatport search results - matching release and artist (if not various). - """ + def candidates( + self, + items: Sequence[Item], + artist: str, + album: str, + va_likely: bool, + ) -> Iterator[AlbumInfo]: if va_likely: - query = release + query = album else: - query = f"{artist} {release}" + query = f"{artist} {album}" try: - return self._get_releases(query) + yield from self._get_releases(query) except BeatportAPIError as e: self._log.debug("API Error: {0} (query: {1})", e, query) - return [] + return - def item_candidates(self, item, artist, title): - """Returns a list of TrackInfo objects for beatport search results - matching title and artist. - """ + def item_candidates( + self, item: Item, artist: str, title: str + ) -> Iterable[TrackInfo]: query = f"{artist} {title}" try: return self._get_tracks(query) @@ -386,13 +425,13 @@ class BeatportPlugin(BeetsPlugin): self._log.debug("API Error: {0} (query: {1})", e, query) return [] - def album_for_id(self, release_id): + def album_for_id(self, album_id: str): """Fetches a release by its Beatport ID and returns an AlbumInfo object or None if the query is not a valid ID or release is not found. """ - self._log.debug("Searching for release {0}", release_id) + self._log.debug("Searching for release {0}", album_id) - if not (release_id := self._get_id(release_id)): + if not (release_id := self.extract_release_id(album_id)): self._log.debug("Not a valid Beatport release ID.") return None @@ -401,11 +440,12 @@ class BeatportPlugin(BeetsPlugin): return self._get_album_info(release) return None - def track_for_id(self, track_id): + def track_for_id(self, track_id: str): """Fetches a track by its Beatport ID and returns a TrackInfo object or None if the track is not a valid Beatport ID or track is not found. """ self._log.debug("Searching for track {0}", track_id) + # TODO: move to extractor match = re.search(r"(^|beatport\.com/track/.+/)(\d+)$", track_id) if not match: self._log.debug("Not a valid Beatport track ID.") @@ -415,7 +455,7 @@ class BeatportPlugin(BeetsPlugin): return self._get_track_info(bp_track) return None - def _get_releases(self, query): + def _get_releases(self, query: str) -> Iterator[AlbumInfo]: """Returns a list of AlbumInfo objects for a beatport search query.""" # Strip non-word characters from query. Things like "!" and "-" can # cause a query to return no results, even if they match the artist or @@ -425,16 +465,22 @@ class BeatportPlugin(BeetsPlugin): # Strip medium information from query, Things like "CD1" and "disk 1" # can also negate an otherwise positive result. query = re.sub(r"\b(CD|disc)\s*\d+", "", query, flags=re.I) - albums = [self._get_album_info(x) for x in self.client.search(query)] - return albums + for beatport_release in self.client.search(query, "release"): + if beatport_release is None: + continue + yield self._get_album_info(beatport_release) - def _get_album_info(self, release): + def _get_album_info(self, release: BeatportRelease) -> AlbumInfo: """Returns an AlbumInfo object for a Beatport Release object.""" - va = len(release.artists) > 3 + va = release.artists is not None and len(release.artists) > 3 artist, artist_id = self._get_artist(release.artists) if va: artist = "Various Artists" - tracks = [self._get_track_info(x) for x in release.tracks] + tracks: list[TrackInfo] = [] + if release.tracks is not None: + tracks = [self._get_track_info(x) for x in release.tracks] + + release_date = release.release_date return AlbumInfo( album=release.name, @@ -445,18 +491,18 @@ class BeatportPlugin(BeetsPlugin): tracks=tracks, albumtype=release.category, va=va, - year=release.release_date.year, - month=release.release_date.month, - day=release.release_date.day, label=release.label_name, catalognum=release.catalog_number, media="Digital", data_source=self.data_source, data_url=release.url, genre=release.genre, + year=release_date.year if release_date else None, + month=release_date.month if release_date else None, + day=release_date.day if release_date else None, ) - def _get_track_info(self, track): + def _get_track_info(self, track: BeatportTrack) -> TrackInfo: """Returns a TrackInfo object for a Beatport Track object.""" title = track.name if track.mix_name != "Original Mix": @@ -482,9 +528,7 @@ class BeatportPlugin(BeetsPlugin): """Returns an artist string (all artists) and an artist_id (the main artist) for a list of Beatport release or track artists. """ - return MetadataSourcePlugin.get_artist( - artists=artists, id_key=0, name_key=1 - ) + return self.get_artist(artists=artists, id_key=0, name_key=1) def _get_tracks(self, query): """Returns a list of TrackInfo objects for a Beatport query."""