From d053f98e81f710d091adf436d4cb8235b13e1834 Mon Sep 17 00:00:00 2001 From: Adrian Sampson Date: Mon, 26 Dec 2016 16:43:47 -0500 Subject: [PATCH] random: Refactor equal chance logic --- beetsplug/random.py | 64 +++++++++++++++++++++++++++++---------------- 1 file changed, 42 insertions(+), 22 deletions(-) diff --git a/beetsplug/random.py b/beetsplug/random.py index c8e1f3fd4..5e7b038d7 100644 --- a/beetsplug/random.py +++ b/beetsplug/random.py @@ -33,6 +33,46 @@ def _length(obj, album): return obj.length +def _equal_chance_permutation(objs, field='albumartist'): + """Generate (lazily) a permutation of the objects where every group + with equal values for `field` have an equal chance of appearing in + any given position. + """ + # Group the objects by artist so we can sample from them. + key = attrgetter(field) + objs.sort(key=key) + objs_by_artists = {} + for artist, v in groupby(objs, key): + objs_by_artists[artist] = list(v) + + # While we still have artists with music to choose from, pick one + # randomly and pick a track from that artist. + while objs_by_artists: + # Choose an artist and an object for that artist, removing + # this choice from the pool. + artist = random.choice(list(objs_by_artists.keys())) + objs_from_artist = objs_by_artists[artist] + i = random.randint(0, len(objs_from_artist) - 1) + yield objs_from_artist.pop(i) + + # Remove the artist if we've used up all of its objects. + if not objs_from_artist: + del objs_by_artists[artist] + + +def _take(iter, num): + """Return a list containing the first `num` values in `iter` (or + fewer, if the iterable ends early). + """ + out = [] + for val in iter: + out.append(val) + num -= 1 + if num <= 0: + break + return out + + def random_objs(objs, album, number=1, time=None, equal_chance=False): """Get a random subset of the provided `objs`. @@ -44,7 +84,7 @@ def random_objs(objs, album, number=1, time=None, equal_chance=False): """ if time: total_time = 0.0 - time_sec = (time * 60) + time_sec = time * 60 objs_shuffled = objs random.shuffle(objs_shuffled) @@ -61,13 +101,6 @@ def random_objs(objs, album, number=1, time=None, equal_chance=False): pass if equal_chance: - # Group the objects by artist so we can sample from them. - key = attrgetter('albumartist') - objs.sort(key=key) - objs_by_artists = {} - for artist, v in groupby(objs, key): - objs_by_artists[artist] = list(v) - objs = [] if time: for obj in objs_shuffled: @@ -89,21 +122,8 @@ def random_objs(objs, album, number=1, time=None, equal_chance=False): pass else: - for _ in range(number): - # Terminate early if we're out of objects to select. - if not objs_by_artists: - break + return _take(_equal_chance_permutation(objs), number) - # Choose an artist and an object for that artist, removing - # this choice from the pool. - artist = random.choice(list(objs_by_artists.keys())) - objs_from_artist = objs_by_artists[artist] - i = random.randint(0, len(objs_from_artist) - 1) - objs.append(objs_from_artist.pop(i)) - - # Remove the artist if we've used up all of its objects. - if not objs_from_artist: - del objs_by_artists[artist] elif not time: number = min(len(objs), number)