Overall refactor of random plugin. Added length property to albums.

This commit is contained in:
Sebastian Mohr 2025-08-10 22:20:28 +02:00
parent 3941fd31ae
commit fb529d7a73
3 changed files with 171 additions and 153 deletions

View file

@ -616,6 +616,11 @@ class Album(LibModel):
for item in self.items():
item.try_sync(write, move)
@property
def length(self) -> float:
"""Return the total length of all items in this album in seconds."""
return sum(item.length for item in self.items())
class Item(LibModel):
"""Represent a song or track."""

View file

@ -1,17 +1,3 @@
# This file is part of beets.
# Copyright 2016, Philippe Mongeau.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""Get a random song or album from the library."""
from __future__ import annotations
@ -19,26 +5,31 @@ from __future__ import annotations
import random
from itertools import groupby, islice
from operator import attrgetter
from typing import Iterable, Sequence, TypeVar
from typing import TYPE_CHECKING, Any, Iterable, Sequence, TypeVar
from beets.library import Album, Item
from beets.plugins import BeetsPlugin
from beets.ui import Subcommand, print_
if TYPE_CHECKING:
import optparse
def random_func(lib, opts, args):
from beets.library import LibModel, Library
T = TypeVar("T", bound=LibModel)
def random_func(lib: Library, opts: optparse.Values, args: list[str]):
"""Select some random items or albums and print the results."""
# Fetch all the objects matching the query into a list.
if opts.album:
objs = list(lib.albums(args))
else:
objs = list(lib.items(args))
objs = lib.albums(args) if opts.album else lib.items(args)
# Print a random subset.
objs = random_objs(
objs, opts.album, opts.number, opts.time, opts.equal_chance
)
for obj in objs:
for obj in random_objs(
objs=list(objs),
number=opts.number,
time_minutes=opts.time,
equal_chance=opts.equal_chance,
):
print_(format(obj))
@ -73,105 +64,93 @@ class Random(BeetsPlugin):
return [random_cmd]
def _length(obj: Item | Album) -> float:
"""Get the duration of an item or album."""
if isinstance(obj, Album):
return sum(i.length for i in obj.items())
else:
return obj.length
NOT_FOUND_SENTINEL = object()
def _equal_chance_permutation(
objs: Sequence[Item | Album],
objs: Sequence[T],
field: str = "albumartist",
random_gen: random.Random | None = None,
) -> Iterable[Item | Album]:
) -> Iterable[T]:
"""Generate (lazily) a permutation of the objects where every group
with equal values for `field` have an equal chance of appearing in
any given position.
"""
rand = random_gen or random
rand: random.Random = random_gen or random.Random()
# Group the objects by artist so we can sample from them.
key = attrgetter(field)
objs = sorted(objs, key=key)
objs_by_artists = {}
for artist, v in groupby(objs, key):
objs_by_artists[artist] = list(v)
# While we still have artists with music to choose from, pick one
# randomly and pick a track from that artist.
while objs_by_artists:
# Choose an artist and an object for that artist, removing
# this choice from the pool.
artist = rand.choice(list(objs_by_artists.keys()))
objs_from_artist = objs_by_artists[artist]
i = rand.randint(0, len(objs_from_artist) - 1)
yield objs_from_artist.pop(i)
def get_attr(obj: T) -> Any:
try:
return key(obj)
except AttributeError:
return NOT_FOUND_SENTINEL
# Remove the artist if we've used up all of its objects.
if not objs_from_artist:
del objs_by_artists[artist]
groups: dict[Any, list[T]] = {
NOT_FOUND_SENTINEL: [],
}
for k, values in groupby(objs, key=get_attr):
groups[k] = list(values)
# shuffle in category
rand.shuffle(groups[k])
T = TypeVar("T")
def _take(
iter: Iterable[T],
num: int,
) -> list[T]:
"""Return a list containing the first `num` values in `iter` (or
fewer, if the iterable ends early).
"""
return list(islice(iter, num))
# Remove items without the field value.
del groups[NOT_FOUND_SENTINEL]
while groups:
group = rand.choice(list(groups.keys()))
yield groups[group].pop()
if not groups[group]:
del groups[group]
def _take_time(
iter: Iterable[Item | Album],
iter: Iterable[T],
secs: float,
) -> list[Item | Album]:
) -> Iterable[T]:
"""Return a list containing the first values in `iter`, which should
be Item or Album objects, that add up to the given amount of time in
seconds.
"""
out: list[Item | Album] = []
total_time = 0.0
for obj in iter:
length = _length(obj)
length = obj.length
if total_time + length <= secs:
out.append(obj)
yield obj
total_time += length
return out
def random_objs(
objs: Sequence[Item | Album],
number=1,
time: float | None = None,
objs: Sequence[T],
number: int = 1,
time_minutes: float | None = None,
equal_chance: bool = False,
random_gen: random.Random | None = None,
):
"""Get a random subset of the provided `objs`.
) -> Iterable[T]:
"""Get a random subset of items, optionally constrained by time or count.
If `number` is provided, produce that many matches. Otherwise, if
`time` is provided, instead select a list whose total time is close
to that number of minutes. If `equal_chance` is true, give each
artist an equal chance of being included so that artists with more
songs are not represented disproportionately.
Args:
- objs: The sequence of objects to choose from.
- number: The number of objects to select.
- time_minutes: If specified, the total length of selected objects
should not exceed this many minutes.
- equal_chance: If True, each artist has the same chance of being
selected, regardless of how many tracks they have.
- random_gen: An optional random generator to use for shuffling.
"""
rand = random_gen or random
rand: random.Random = random_gen or random.Random()
# Permute the objects either in a straightforward way or an
# artist-balanced way.
perm: Iterable[T]
if equal_chance:
perm = _equal_chance_permutation(objs)
perm = _equal_chance_permutation(objs, random_gen=rand)
else:
perm = list(objs)
rand.shuffle(perm)
# Select objects by time our count.
if time:
return _take_time(perm, time * 60)
if time_minutes:
return _take_time(perm, time_minutes * 60)
else:
return _take(perm, number)
return islice(perm, number)

View file

@ -15,7 +15,6 @@
"""Test the beets.random utilities associated with the random plugin."""
import math
import unittest
from random import Random
import pytest
@ -24,16 +23,30 @@ from beets.test.helper import TestHelper
from beetsplug import random
class RandomTest(TestHelper, unittest.TestCase):
def setUp(self):
self.lib = None
@pytest.fixture(scope="class")
def helper():
helper = TestHelper()
helper.setup_beets()
yield helper
helper.teardown_beets()
class TestEqualChancePermutation:
"""Test the _equal_chance_permutation function."""
@pytest.fixture(autouse=True)
def setup(self, helper):
"""Set up the test environment with items."""
self.lib = helper.lib
self.artist1 = "Artist 1"
self.artist2 = "Artist 2"
self.item1 = self.create_item(artist=self.artist1)
self.item2 = self.create_item(artist=self.artist2)
self.item1 = helper.create_item(artist=self.artist1)
self.item2 = helper.create_item(artist=self.artist2)
self.items = [self.item1, self.item2]
for _ in range(8):
self.items.append(self.create_item(artist=self.artist2))
self.items.append(helper.create_item(artist=self.artist2))
self.random_gen = Random()
self.random_gen.seed(12345)
@ -78,73 +91,94 @@ class RandomTest(TestHelper, unittest.TestCase):
assert len(self.items) // 2 == pytest.approx(median2, abs=1)
assert stdev2 > stdev1
def test_equal_permutation_empty_input(self):
@pytest.mark.parametrize(
"input_items, field, expected",
[
([], "artist", []),
([{"artist": "Artist 1"}], "artist", [{"artist": "Artist 1"}]),
# Missing field should not raise an error, but return empty
([{"artist": "Artist 1"}], "nonexistent", []),
# Multiple items with the same field value
(
[{"artist": "Artist 1"}, {"artist": "Artist 1"}],
"artist",
[{"artist": "Artist 1"}, {"artist": "Artist 1"}],
),
],
)
def test_equal_permutation_items(
self, input_items, field, expected, helper
):
"""Test _equal_chance_permutation with empty input."""
result = list(random._equal_chance_permutation([], "artist"))
assert result == []
def test_equal_permutation_single_item(self):
"""Test _equal_chance_permutation with single item."""
result = list(random._equal_chance_permutation([self.item1], "artist"))
assert result == [self.item1]
def test_equal_permutation_single_artist(self):
"""Test _equal_chance_permutation with items from one artist."""
items = [self.create_item(artist=self.artist1) for _ in range(5)]
result = list(random._equal_chance_permutation(items, "artist"))
assert set(result) == set(items)
assert len(result) == len(items)
def test_random_objs_count(self):
"""Test random_objs with count-based selection."""
result = random.random_objs(
self.items, number=3, random_gen=self.random_gen
result = list(
random._equal_chance_permutation(
[helper.create_item(**i) for i in input_items], field
)
)
assert len(result) == 3
assert all(item in self.items for item in result)
def test_random_objs_time(self):
"""Test random_objs with time-based selection."""
# Total length is 30 + 60 + 8*45 = 450 seconds
# Requesting 120 seconds should return 2-3 items
result = random.random_objs(
self.items,
time=2,
random_gen=self.random_gen, # 2 minutes = 120 sec
for item in expected:
for key, value in item.items():
assert any(getattr(r, key) == value for r in result)
assert len(result) == len(expected)
class TestRandomObjs:
"""Test the random_objs function."""
@pytest.fixture(autouse=True)
def setup(self, helper):
"""Set up the test environment with items."""
self.lib = helper.lib
self.artist1 = "Artist 1"
self.artist2 = "Artist 2"
self.items = [
helper.create_item(artist=self.artist1, length=180), # 3 minutes
helper.create_item(artist=self.artist2, length=240), # 4 minutes
helper.create_item(artist=self.artist2, length=300), # 5 minutes
]
self.random_gen = random.Random()
def test_random_selection_by_count(self):
"""Test selecting a specific number of items."""
selected = list(random.random_objs(self.items, number=2))
assert len(selected) == 2
assert all(item in self.items for item in selected)
def test_random_selection_by_time(self):
"""Test selecting items constrained by total time (minutes)."""
selected = list(
random.random_objs(self.items, time_minutes=6)
) # 6 minutes
total_time = (
sum(item.length for item in selected) / 60
) # Convert to minutes
assert total_time <= 6
def test_equal_chance_permutation(self, helper):
"""Test equal chance permutation ensures balanced artist selection."""
# Add more items to make the test meaningful
for _ in range(5):
self.items.append(
helper.create_item(artist=self.artist1, length=180)
)
selected = list(
random.random_objs(self.items, number=10, equal_chance=True)
)
total_time = sum(item.length for item in result)
assert total_time <= 120
# Check we got at least some items
assert len(result) > 0
artist_counts = {}
for item in selected:
artist_counts[item.artist] = artist_counts.get(item.artist, 0) + 1
def test_random_objs_equal_chance(self):
"""Test random_objs with equal_chance=True."""
# Ensure both artists are represented (not strictly equal due to randomness)
assert len(artist_counts) >= 2
# With equal_chance, artist1 should appear more often in results
def experiment():
"""Run the random_objs function multiple times and collect results."""
results = []
for _ in range(5000):
result = random.random_objs(
[self.item1, self.item2],
number=1,
equal_chance=True,
random_gen=self.random_gen,
)
results.append(result[0].artist)
def test_empty_input_list(self):
"""Test behavior with an empty input list."""
selected = list(random.random_objs([], number=1))
assert len(selected) == 0
# Return ratio
return results.count(self.artist1), results.count(self.artist2)
count_artist1, count_artist2 = experiment()
assert 1 - count_artist1 / count_artist2 < 0.1 # 10% deviation
def test_random_objs_empty_input(self):
"""Test random_objs with empty input."""
result = random.random_objs([], number=3)
assert result == []
def test_random_objs_zero_number(self):
"""Test random_objs with number=0."""
result = random.random_objs(self.items, number=0)
assert result == []
def test_no_constraints_returns_all(self):
"""Test that no constraints return all items in random order."""
selected = list(random.random_objs(self.items, 3))
assert len(selected) == len(self.items)
assert set(selected) == set(self.items)