diff --git a/beetsplug/random.py b/beetsplug/random.py index 853e9b14a..b8778eb08 100644 --- a/beetsplug/random.py +++ b/beetsplug/random.py @@ -70,14 +70,11 @@ NOT_FOUND_SENTINEL = object() def _equal_chance_permutation( objs: Sequence[T], field: str = "albumartist", - random_gen: random.Random | None = None, ) -> Iterable[T]: """Generate (lazily) a permutation of the objects where every group with equal values for `field` have an equal chance of appearing in any given position. """ - rand: random.Random = random_gen or random.Random() - # Group the objects by artist so we can sample from them. key = attrgetter(field) @@ -95,12 +92,12 @@ def _equal_chance_permutation( for k, values in groupby(objs, key=get_attr): groups[k] = list(values) # shuffle in category - rand.shuffle(groups[k]) + random.shuffle(groups[k]) # Remove items without the field value. del groups[NOT_FOUND_SENTINEL] while groups: - group = rand.choice(list(groups.keys())) + group = random.choice(list(groups.keys())) yield groups[group].pop() if not groups[group]: del groups[group] @@ -127,7 +124,6 @@ def random_objs( number: int = 1, time_minutes: float | None = None, equal_chance: bool = False, - random_gen: random.Random | None = None, ) -> Iterable[T]: """Get a random subset of items, optionally constrained by time or count. @@ -140,16 +136,14 @@ def random_objs( selected, regardless of how many tracks they have. - random_gen: An optional random generator to use for shuffling. """ - rand: random.Random = random_gen or random.Random() - # Permute the objects either in a straightforward way or an # artist-balanced way. perm: Iterable[T] if equal_chance: - perm = _equal_chance_permutation(objs, random_gen=rand) + perm = _equal_chance_permutation(objs) else: perm = list(objs) - rand.shuffle(perm) + random.shuffle(perm) # Select objects by time our count. if time_minutes: diff --git a/test/plugins/test_random.py b/test/plugins/test_random.py index 80d12379b..f11326d53 100644 --- a/test/plugins/test_random.py +++ b/test/plugins/test_random.py @@ -15,12 +15,12 @@ """Test the beets.random utilities associated with the random plugin.""" import math -from random import Random +import random import pytest from beets.test.helper import TestHelper -from beetsplug import random +from beetsplug.random import _equal_chance_permutation, random_objs @pytest.fixture(scope="class") @@ -33,6 +33,11 @@ def helper(): helper.teardown_beets() +@pytest.fixture(scope="module", autouse=True) +def seed_random(): + random.seed(12345) + + class TestEqualChancePermutation: """Test the _equal_chance_permutation function.""" @@ -47,8 +52,6 @@ class TestEqualChancePermutation: self.items = [self.item1, self.item2] for _ in range(8): self.items.append(helper.create_item(artist=self.artist2)) - self.random_gen = Random() - self.random_gen.seed(12345) def _stats(self, data): mean = sum(data) / len(data) @@ -74,9 +77,7 @@ class TestEqualChancePermutation: positions = [] for _ in range(500): shuffled = list( - random._equal_chance_permutation( - self.items, field=field, random_gen=self.random_gen - ) + _equal_chance_permutation(self.items, field=field) ) positions.append(shuffled.index(self.item1)) # Print a histogram (useful for debugging). @@ -111,7 +112,7 @@ class TestEqualChancePermutation: ): """Test _equal_chance_permutation with empty input.""" result = list( - random._equal_chance_permutation( + _equal_chance_permutation( [helper.create_item(**i) for i in input_items], field ) ) @@ -136,19 +137,16 @@ class TestRandomObjs: helper.create_item(artist=self.artist2, length=240), # 4 minutes helper.create_item(artist=self.artist2, length=300), # 5 minutes ] - self.random_gen = random.Random() def test_random_selection_by_count(self): """Test selecting a specific number of items.""" - selected = list(random.random_objs(self.items, number=2)) + selected = list(random_objs(self.items, number=2)) assert len(selected) == 2 assert all(item in self.items for item in selected) def test_random_selection_by_time(self): """Test selecting items constrained by total time (minutes).""" - selected = list( - random.random_objs(self.items, time_minutes=6) - ) # 6 minutes + selected = list(random_objs(self.items, time_minutes=6)) # 6 minutes total_time = ( sum(item.length for item in selected) / 60 ) # Convert to minutes @@ -162,9 +160,7 @@ class TestRandomObjs: helper.create_item(artist=self.artist1, length=180) ) - selected = list( - random.random_objs(self.items, number=10, equal_chance=True) - ) + selected = list(random_objs(self.items, number=10, equal_chance=True)) artist_counts = {} for item in selected: artist_counts[item.artist] = artist_counts.get(item.artist, 0) + 1 @@ -174,11 +170,11 @@ class TestRandomObjs: def test_empty_input_list(self): """Test behavior with an empty input list.""" - selected = list(random.random_objs([], number=1)) + selected = list(random_objs([], number=1)) assert len(selected) == 0 def test_no_constraints_returns_all(self): """Test that no constraints return all items in random order.""" - selected = list(random.random_objs(self.items, 3)) + selected = list(random_objs(self.items, 3)) assert len(selected) == len(self.items) assert set(selected) == set(self.items)