Opt in beatport plugin to MetadataSourcePlugin

Also enhance type hints and do minor cleanup for the beatport plugin.
This commit is contained in:
Sebastian Mohr 2025-07-07 13:58:40 +02:00
parent a97633dbf6
commit 3eadf17e8f

View file

@ -14,9 +14,19 @@
"""Adds Beatport release and track search support to the autotagger"""
from __future__ import annotations
import json
import re
from datetime import datetime, timedelta
from typing import (
TYPE_CHECKING,
Iterable,
Iterator,
Literal,
Sequence,
overload,
)
import confuse
from requests_oauthlib import OAuth1Session
@ -29,7 +39,13 @@ from requests_oauthlib.oauth1_session import (
import beets
import beets.ui
from beets.autotag.hooks import AlbumInfo, TrackInfo
from beets.plugins import BeetsPlugin, MetadataSourcePlugin, get_distance
from beets.metadata_plugins import MetadataSourcePlugin
if TYPE_CHECKING:
from beets.importer import ImportSession
from beets.library import Item
from ._typing import JSONDict
AUTH_ERRORS = (TokenRequestDenied, TokenMissing, VerifierMissing)
USER_AGENT = f"beets/{beets.__version__} +https://beets.io/"
@ -39,20 +55,6 @@ class BeatportAPIError(Exception):
pass
class BeatportObject:
def __init__(self, data):
self.beatport_id = data["id"]
self.name = str(data["name"])
if "releaseDate" in data:
self.release_date = datetime.strptime(
data["releaseDate"], "%Y-%m-%d"
)
if "artists" in data:
self.artists = [(x["id"], str(x["name"])) for x in data["artists"]]
if "genres" in data:
self.genres = [str(x["name"]) for x in data["genres"]]
class BeatportClient:
_api_base = "https://oauth-api.beatport.com"
@ -77,7 +79,7 @@ class BeatportClient:
)
self.api.headers = {"User-Agent": USER_AGENT}
def get_authorize_url(self):
def get_authorize_url(self) -> str:
"""Generate the URL for the user to authorize the application.
Retrieves a request token from the Beatport API and returns the
@ -99,15 +101,13 @@ class BeatportClient:
self._make_url("/identity/1/oauth/authorize")
)
def get_access_token(self, auth_data):
def get_access_token(self, auth_data: str) -> tuple[str, str]:
"""Obtain the final access token and secret for the API.
:param auth_data: URL-encoded authorization data as displayed at
the authorization url (obtained via
:py:meth:`get_authorize_url`) after signing in
:type auth_data: unicode
:returns: OAuth resource owner key and secret
:rtype: (unicode, unicode) tuple
:returns: OAuth resource owner key and secret as unicode
"""
self.api.parse_authorization_response(
"https://beets.io/auth?" + auth_data
@ -117,20 +117,37 @@ class BeatportClient:
)
return access_data["oauth_token"], access_data["oauth_token_secret"]
def search(self, query, release_type="release", details=True):
@overload
def search(
self,
query: str,
release_type: Literal["release"],
details: bool = True,
) -> Iterator[BeatportRelease]: ...
@overload
def search(
self,
query: str,
release_type: Literal["track"],
details: bool = True,
) -> Iterator[BeatportTrack]: ...
def search(
self,
query: str,
release_type: Literal["release", "track"],
details=True,
) -> Iterator[BeatportRelease | BeatportTrack]:
"""Perform a search of the Beatport catalogue.
:param query: Query string
:param release_type: Type of releases to search for, can be
'release' or 'track'
:param release_type: Type of releases to search for.
:param details: Retrieve additional information about the
search results. Currently this will fetch
the tracklist for releases and do nothing for
tracks
:returns: Search results
:rtype: generator that yields
py:class:`BeatportRelease` or
:py:class:`BeatportTrack`
"""
response = self._get(
"catalog/3/search",
@ -140,20 +157,18 @@ class BeatportClient:
)
for item in response:
if release_type == "release":
release = BeatportRelease(item)
if details:
release = self.get_release(item["id"])
else:
release = BeatportRelease(item)
release.tracks = self.get_release_tracks(item["id"])
yield release
elif release_type == "track":
yield BeatportTrack(item)
def get_release(self, beatport_id):
def get_release(self, beatport_id: str) -> BeatportRelease | None:
"""Get information about a single release.
:param beatport_id: Beatport ID of the release
:returns: The matching release
:rtype: :py:class:`BeatportRelease`
"""
response = self._get("/catalog/3/releases", id=beatport_id)
if response:
@ -162,35 +177,33 @@ class BeatportClient:
return release
return None
def get_release_tracks(self, beatport_id: str) -> list[BeatportTrack]:
    """Fetch the tracks that belong to a release.

    :param beatport_id: Beatport ID of the release
    :returns: One :py:class:`BeatportTrack` per entry returned by the
              ``/catalog/3/tracks`` endpoint for that release
    """
    # A single request is issued with ``perPage=100``; nothing here asks
    # for further pages.
    tracks: list[BeatportTrack] = []
    for payload in self._get(
        "/catalog/3/tracks", releaseId=beatport_id, perPage=100
    ):
        tracks.append(BeatportTrack(payload))
    return tracks
def get_track(self, beatport_id: str) -> BeatportTrack | None:
    """Get information about a single track.

    :param beatport_id: Beatport ID of the track
    :returns: The matching track, or ``None`` when the API returns no
              results for that ID
    """
    response = self._get("/catalog/3/tracks", id=beatport_id)
    # Indexing an empty result list would raise IndexError. Report
    # "not found" as None instead, the same way ``get_release`` does.
    if not response:
        return None
    return BeatportTrack(response[0])
def _make_url(self, endpoint):
def _make_url(self, endpoint: str) -> str:
"""Get complete URL for a given API endpoint."""
if not endpoint.startswith("/"):
endpoint = "/" + endpoint
return self._api_base + endpoint
def _get(self, endpoint, **kwargs):
def _get(self, endpoint: str, **kwargs) -> list[JSONDict]:
"""Perform a GET request on a given API endpoint.
Automatically extracts result data from the response and converts HTTP
@ -211,48 +224,81 @@ class BeatportClient:
return response.json()["results"]
class BeatportRelease(BeatportObject):
def __str__(self):
if len(self.artists) < 4:
artist_str = ", ".join(x[1] for x in self.artists)
class BeatportObject:
    """Fields read from a Beatport API result that releases and tracks share.

    ``beatport_id`` and ``name`` are always set. The remaining attributes
    stay ``None`` when the corresponding key is missing from the payload.
    """

    beatport_id: str
    name: str

    release_date: datetime | None = None

    # Tuples of (artist id, artist name).
    # NOTE(review): artist ids are stored exactly as received; the object's
    # own "id" is given as int in the response, so these may be ints as
    # well -- confirm against the API before relying on ``str`` here.
    artists: list[tuple[str, str]] | None = None

    # Defaults to None like the other optional fields, so reading it never
    # raises AttributeError when the payload has no "genres" key.
    genres: list[str] | None = None

    def __init__(self, data: JSONDict):
        """Populate the shared fields from one API result dictionary.

        :param data: A single result object; must contain ``"id"`` and
                     ``"name"``
        :raises KeyError: if ``"id"`` or ``"name"`` is missing
        :raises ValueError: if ``"releaseDate"`` is present but is not in
                            ``YYYY-MM-DD`` form
        """
        self.beatport_id = str(data["id"])  # given as int in the response
        self.name = str(data["name"])
        if "releaseDate" in data:
            self.release_date = datetime.strptime(
                data["releaseDate"], "%Y-%m-%d"
            )
        if "artists" in data:
            self.artists = [(x["id"], str(x["name"])) for x in data["artists"]]
        if "genres" in data:
            self.genres = [str(x["name"]) for x in data["genres"]]

    def artists_str(self) -> str | None:
        """Return a display string for the artists.

        :returns: ``None`` when no artists were provided, the artist names
                  joined by ``", "`` for fewer than four artists, and
                  ``"Various Artists"`` otherwise
        """
        if self.artists is None:
            return None
        if len(self.artists) < 4:
            return ", ".join(x[1] for x in self.artists)
        return "Various Artists"
class BeatportRelease(BeatportObject):
    """A release as described by one Beatport API result.

    Every optional attribute is always assigned in ``__init__`` (``None``
    when the payload lacks the key), so callers can read them without
    guarding against AttributeError.
    """

    catalog_number: str | None
    label_name: str | None
    category: str | None
    url: str | None
    genre: str | None

    # Stays None until something assigns the release's track list.
    tracks: list[BeatportTrack] | None = None

    def __init__(self, data: JSONDict):
        """Populate the release from one API result dictionary.

        :param data: A single release result; see
                     :py:class:`BeatportObject` for the required keys
        """
        super().__init__(data)

        self.catalog_number = data.get("catalogNumber")
        # ``or {}`` also covers a payload whose "label" is present but null.
        self.label_name = (data.get("label") or {}).get("name")
        self.category = data.get("category")
        self.genre = data.get("genre")

        # Always define ``url``: it is read unconditionally when album
        # info is built, so leaving it unset for a payload without "slug"
        # would raise AttributeError there.
        self.url = None
        if "slug" in data:
            self.url = "https://beatport.com/release/{}/{}".format(
                data["slug"], data["id"]
            )

    def __str__(self) -> str:
        """Return ``<BeatportRelease: ARTISTS - NAME (CATALOG_NUMBER)>``."""
        return "<BeatportRelease: {} - {} ({})>".format(
            self.artists_str(),
            self.name,
            self.catalog_number,
        )
class BeatportTrack(BeatportObject):
def __str__(self):
artist_str = ", ".join(x[1] for x in self.artists)
return "<BeatportTrack: {} - {} ({})>".format(
artist_str, self.name, self.mix_name
)
title: str | None
mix_name: str | None
length: timedelta
url: str | None
track_number: int | None
bpm: str | None
initial_key: str | None
genre: str | None
def __repr__(self):
return str(self).encode("utf-8")
def __init__(self, data):
BeatportObject.__init__(self, data)
def __init__(self, data: JSONDict):
super().__init__(data)
if "title" in data:
self.title = str(data["title"])
if "mixName" in data:
@ -279,8 +325,8 @@ class BeatportTrack(BeatportObject):
self.genre = str(data["genres"][0].get("name"))
class BeatportPlugin(BeetsPlugin):
data_source = "Beatport"
class BeatportPlugin(MetadataSourcePlugin):
_client: BeatportClient | None = None
def __init__(self):
super().__init__()
@ -294,12 +340,19 @@ class BeatportPlugin(BeetsPlugin):
)
self.config["apikey"].redact = True
self.config["apisecret"].redact = True
self.client = None
self.register_listener("import_begin", self.setup)
def setup(self, session=None):
c_key = self.config["apikey"].as_str()
c_secret = self.config["apisecret"].as_str()
@property
def client(self) -> BeatportClient:
    """Return the Beatport API client stored in ``self._client``.

    :raises ValueError: if ``self._client`` is still ``None``, i.e.
                        ``setup()`` has not assigned a client yet
    """
    initialized = self._client
    if initialized is not None:
        return initialized
    raise ValueError("Beatport client not initialized. Call setup() first.")
def setup(self, session: ImportSession):
c_key: str = self.config["apikey"].as_str()
c_secret: str = self.config["apisecret"].as_str()
# Get the OAuth token from a file or log in.
try:
@ -312,9 +365,9 @@ class BeatportPlugin(BeetsPlugin):
token = tokendata["token"]
secret = tokendata["secret"]
self.client = BeatportClient(c_key, c_secret, token, secret)
self._client = BeatportClient(c_key, c_secret, token, secret)
def authenticate(self, c_key, c_secret):
def authenticate(self, c_key: str, c_secret: str) -> tuple[str, str]:
# Get the link for the OAuth page.
auth_client = BeatportClient(c_key, c_secret)
try:
@ -341,44 +394,30 @@ class BeatportPlugin(BeetsPlugin):
return token, secret
def _tokenfile(self) -> str:
    """Return the path of the JSON file used to store the OAuth token."""
    # The "tokenfile" option is read through confuse's ``Filename``
    # template with ``in_app_dir=True``.
    # NOTE(review): presumably this resolves a relative name against the
    # application's config directory -- confirm in the confuse docs.
    path_template = confuse.Filename(in_app_dir=True)
    return self.config["tokenfile"].get(path_template)
def album_distance(self, items, album_info, mapping):
"""Returns the Beatport source weight and the maximum source weight
for albums.
"""
return get_distance(
data_source=self.data_source, info=album_info, config=self.config
)
def track_distance(self, item, track_info):
"""Returns the Beatport source weight and the maximum source weight
for individual tracks.
"""
return get_distance(
data_source=self.data_source, info=track_info, config=self.config
)
def candidates(
    self,
    items: Sequence[Item],
    artist: str,
    album: str,
    va_likely: bool,
) -> Iterator[AlbumInfo]:
    """Yield AlbumInfo objects for Beatport search results.

    The search query is the album title alone when ``va_likely`` is
    true, and ``"<artist> <album>"`` otherwise. A
    :py:class:`BeatportAPIError` raised while results are being produced
    is logged at debug level and ends the iteration.
    """
    query = album if va_likely else f"{artist} {album}"
    try:
        for album_info in self._get_releases(query):
            yield album_info
    except BeatportAPIError as e:
        self._log.debug("API Error: {0} (query: {1})", e, query)
def item_candidates(self, item, artist, title):
"""Returns a list of TrackInfo objects for beatport search results
matching title and artist.
"""
def item_candidates(
self, item: Item, artist: str, title: str
) -> Iterable[TrackInfo]:
query = f"{artist} {title}"
try:
return self._get_tracks(query)
@ -386,13 +425,13 @@ class BeatportPlugin(BeetsPlugin):
self._log.debug("API Error: {0} (query: {1})", e, query)
return []
def album_for_id(self, release_id):
def album_for_id(self, album_id: str):
"""Fetches a release by its Beatport ID and returns an AlbumInfo object
or None if the query is not a valid ID or release is not found.
"""
self._log.debug("Searching for release {0}", release_id)
self._log.debug("Searching for release {0}", album_id)
if not (release_id := self._get_id(release_id)):
if not (release_id := self.extract_release_id(album_id)):
self._log.debug("Not a valid Beatport release ID.")
return None
@ -401,11 +440,12 @@ class BeatportPlugin(BeetsPlugin):
return self._get_album_info(release)
return None
def track_for_id(self, track_id):
def track_for_id(self, track_id: str):
"""Fetches a track by its Beatport ID and returns a TrackInfo object
or None if the track is not a valid Beatport ID or track is not found.
"""
self._log.debug("Searching for track {0}", track_id)
# TODO: move to extractor
match = re.search(r"(^|beatport\.com/track/.+/)(\d+)$", track_id)
if not match:
self._log.debug("Not a valid Beatport track ID.")
@ -415,7 +455,7 @@ class BeatportPlugin(BeetsPlugin):
return self._get_track_info(bp_track)
return None
def _get_releases(self, query):
def _get_releases(self, query: str) -> Iterator[AlbumInfo]:
"""Returns a list of AlbumInfo objects for a beatport search query."""
# Strip non-word characters from query. Things like "!" and "-" can
# cause a query to return no results, even if they match the artist or
@ -425,16 +465,22 @@ class BeatportPlugin(BeetsPlugin):
# Strip medium information from query, Things like "CD1" and "disk 1"
# can also negate an otherwise positive result.
query = re.sub(r"\b(CD|disc)\s*\d+", "", query, flags=re.I)
albums = [self._get_album_info(x) for x in self.client.search(query)]
return albums
for beatport_release in self.client.search(query, "release"):
if beatport_release is None:
continue
yield self._get_album_info(beatport_release)
def _get_album_info(self, release):
def _get_album_info(self, release: BeatportRelease) -> AlbumInfo:
"""Returns an AlbumInfo object for a Beatport Release object."""
va = len(release.artists) > 3
va = release.artists is not None and len(release.artists) > 3
artist, artist_id = self._get_artist(release.artists)
if va:
artist = "Various Artists"
tracks = [self._get_track_info(x) for x in release.tracks]
tracks: list[TrackInfo] = []
if release.tracks is not None:
tracks = [self._get_track_info(x) for x in release.tracks]
release_date = release.release_date
return AlbumInfo(
album=release.name,
@ -445,18 +491,18 @@ class BeatportPlugin(BeetsPlugin):
tracks=tracks,
albumtype=release.category,
va=va,
year=release.release_date.year,
month=release.release_date.month,
day=release.release_date.day,
label=release.label_name,
catalognum=release.catalog_number,
media="Digital",
data_source=self.data_source,
data_url=release.url,
genre=release.genre,
year=release_date.year if release_date else None,
month=release_date.month if release_date else None,
day=release_date.day if release_date else None,
)
def _get_track_info(self, track):
def _get_track_info(self, track: BeatportTrack) -> TrackInfo:
"""Returns a TrackInfo object for a Beatport Track object."""
title = track.name
if track.mix_name != "Original Mix":
@ -482,9 +528,7 @@ class BeatportPlugin(BeetsPlugin):
"""Returns an artist string (all artists) and an artist_id (the main
artist) for a list of Beatport release or track artists.
"""
return MetadataSourcePlugin.get_artist(
artists=artists, id_key=0, name_key=1
)
return self.get_artist(artists=artists, id_key=0, name_key=1)
def _get_tracks(self, query):
"""Returns a list of TrackInfo objects for a Beatport query."""