Replace control_stdin with io.addinput

This commit is contained in:
Šarūnas Nejus 2026-01-16 01:25:14 +00:00
parent ed43387778
commit cbfec8de66
No known key found for this signature in database
11 changed files with 134 additions and 149 deletions

View file

@ -15,8 +15,8 @@
"""This module includes various helpers that provide fixtures, capture
information or mock the environment.
- The `control_stdin` and `capture_stdout` context managers allow one to
interact with the user interface.
- `capture_stdout` context managers allow one to interact with the user
interface.
- `has_program` checks the presence of a command on the system.
@ -84,22 +84,6 @@ def capture_log(logger="beets"):
log.removeHandler(capture)
@contextmanager
def control_stdin(input=None):
"""Sends ``input`` to stdin.
>>> with control_stdin('yes'):
... input()
'yes'
"""
org = sys.stdin
sys.stdin = StringIO(input)
try:
yield sys.stdin
finally:
sys.stdin = org
@contextmanager
def capture_stdout():
"""Save stdout in a StringIO.

View file

@ -32,7 +32,7 @@ import yaml
from beets.test.helper import PluginTestCase
from beets.util import bluelet
bpd = pytest.importorskip("beetsplug.bpd")
bpd = pytest.importorskip("beetsplug.bpd", exc_type=ImportError)
class CommandParseTest(unittest.TestCase):

View file

@ -30,8 +30,8 @@ from beets.test.helper import (
AsIsImporterMixin,
ImportHelper,
PluginTestCase,
IOMixin,
capture_log,
control_stdin,
)
from beetsplug import convert
@ -66,7 +66,7 @@ class ConvertMixin:
return path.read_bytes().endswith(tag.encode("utf-8"))
class ConvertTestCase(ConvertMixin, PluginTestCase):
class ConvertTestCase(IOMixin, ConvertMixin, PluginTestCase):
db_on_disk = True
plugin = "convert"
@ -157,8 +157,8 @@ class ConvertCliTest(ConvertTestCase, ConvertCommand):
}
def test_convert(self):
with control_stdin("y"):
self.run_convert()
self.io.addinput("y")
self.run_convert()
assert self.file_endswith(self.converted_mp3, "mp3")
def test_convert_with_auto_confirmation(self):
@ -166,22 +166,22 @@ class ConvertCliTest(ConvertTestCase, ConvertCommand):
assert self.file_endswith(self.converted_mp3, "mp3")
def test_reject_confirmation(self):
with control_stdin("n"):
self.run_convert()
self.io.addinput("n")
self.run_convert()
assert not self.converted_mp3.exists()
def test_convert_keep_new(self):
assert os.path.splitext(self.item.path)[1] == b".ogg"
with control_stdin("y"):
self.run_convert("--keep-new")
self.io.addinput("y")
self.run_convert("--keep-new")
self.item.load()
assert os.path.splitext(self.item.path)[1] == b".mp3"
def test_format_option(self):
with control_stdin("y"):
self.run_convert("--format", "opus")
self.io.addinput("y")
self.run_convert("--format", "opus")
assert self.file_endswith(self.convert_dest / "converted.ops", "opus")
def test_embed_album_art(self):
@ -192,8 +192,8 @@ class ConvertCliTest(ConvertTestCase, ConvertCommand):
with open(os.path.join(image_path), "rb") as f:
image_data = f.read()
with control_stdin("y"):
self.run_convert()
self.io.addinput("y")
self.run_convert()
mediafile = MediaFile(self.converted_mp3)
assert mediafile.images[0].data == image_data
@ -215,26 +215,26 @@ class ConvertCliTest(ConvertTestCase, ConvertCommand):
def test_no_transcode_when_maxbr_set_high_and_different_formats(self):
self.config["convert"]["max_bitrate"] = 5000
with control_stdin("y"):
self.run_convert()
self.io.addinput("y")
self.run_convert()
assert self.file_endswith(self.converted_mp3, "mp3")
def test_transcode_when_maxbr_set_low_and_different_formats(self):
self.config["convert"]["max_bitrate"] = 5
with control_stdin("y"):
self.run_convert()
self.io.addinput("y")
self.run_convert()
assert self.file_endswith(self.converted_mp3, "mp3")
def test_transcode_when_maxbr_set_to_none_and_different_formats(self):
with control_stdin("y"):
self.run_convert()
self.io.addinput("y")
self.run_convert()
assert self.file_endswith(self.converted_mp3, "mp3")
def test_no_transcode_when_maxbr_set_high_and_same_formats(self):
self.config["convert"]["max_bitrate"] = 5000
self.config["convert"]["format"] = "ogg"
with control_stdin("y"):
self.run_convert()
self.io.addinput("y")
self.run_convert()
assert not self.file_endswith(
self.convert_dest / "converted.ogg", "ogg"
)
@ -243,8 +243,8 @@ class ConvertCliTest(ConvertTestCase, ConvertCommand):
self.config["convert"]["max_bitrate"] = 5000
self.config["convert"]["format"] = "ogg"
with control_stdin("y"):
self.run_convert("--force")
self.io.addinput("y")
self.run_convert("--force")
converted = self.convert_dest / "converted.ogg"
assert self.file_endswith(converted, "ogg")
@ -252,21 +252,21 @@ class ConvertCliTest(ConvertTestCase, ConvertCommand):
def test_transcode_when_maxbr_set_low_and_same_formats(self):
self.config["convert"]["max_bitrate"] = 5
self.config["convert"]["format"] = "ogg"
with control_stdin("y"):
self.run_convert()
self.io.addinput("y")
self.run_convert()
assert self.file_endswith(self.convert_dest / "converted.ogg", "ogg")
def test_transcode_when_maxbr_set_to_none_and_same_formats(self):
self.config["convert"]["format"] = "ogg"
with control_stdin("y"):
self.run_convert()
self.io.addinput("y")
self.run_convert()
assert not self.file_endswith(
self.convert_dest / "converted.ogg", "ogg"
)
def test_playlist(self):
with control_stdin("y"):
self.run_convert("--playlist", "playlist.m3u8")
self.io.addinput("y")
self.run_convert("--playlist", "playlist.m3u8")
assert (self.convert_dest / "playlist.m3u8").exists()
def test_playlist_pretend(self):
@ -282,8 +282,8 @@ class ConvertCliTest(ConvertTestCase, ConvertCommand):
[item] = self.add_item_fixtures(ext="ogg")
with control_stdin("y"):
self.run_convert_path(item, "--format", "opus", "--force")
self.io.addinput("y")
self.run_convert_path(item, "--format", "opus", "--force")
converted = self.convert_dest / "converted.ops"
assert self.file_endswith(converted, "opus")
@ -309,23 +309,23 @@ class NeverConvertLossyFilesTest(ConvertTestCase, ConvertCommand):
def test_transcode_from_lossless(self):
[item] = self.add_item_fixtures(ext="flac")
with control_stdin("y"):
self.run_convert_path(item)
self.io.addinput("y")
self.run_convert_path(item)
converted = self.convert_dest / "converted.mp3"
assert self.file_endswith(converted, "mp3")
def test_transcode_from_lossy(self):
self.config["convert"]["never_convert_lossy_files"] = False
[item] = self.add_item_fixtures(ext="ogg")
with control_stdin("y"):
self.run_convert_path(item)
self.io.addinput("y")
self.run_convert_path(item)
converted = self.convert_dest / "converted.mp3"
assert self.file_endswith(converted, "mp3")
def test_transcode_from_lossy_prevented(self):
[item] = self.add_item_fixtures(ext="ogg")
with control_stdin("y"):
self.run_convert_path(item)
self.io.addinput("y")
self.run_convert_path(item)
converted = self.convert_dest / "converted.ogg"
assert not self.file_endswith(converted, "mp3")
@ -336,8 +336,8 @@ class NeverConvertLossyFilesTest(ConvertTestCase, ConvertCommand):
}
[item] = self.add_item_fixtures(ext="ogg")
with control_stdin("y"):
self.run_convert_path(item, "--format", "opus", "--force")
self.io.addinput("y")
self.run_convert_path(item, "--format", "opus", "--force")
converted = self.convert_dest / "converted.ops"
assert self.file_endswith(converted, "opus")

View file

@ -17,15 +17,16 @@ from typing import ClassVar
from unittest.mock import patch
from beets.dbcore.query import TrueQuery
from beets.importer import Action
from beets.library import Item
from beets.test import _common
from beets.test.helper import (
AutotagImportTestCase,
AutotagStub,
BeetsTestCase,
IOMixin,
PluginMixin,
TerminalImportMixin,
control_stdin,
)
@ -73,7 +74,7 @@ class ModifyFileMocker:
f.write(contents)
class EditMixin(PluginMixin):
class EditMixin(IOMixin, PluginMixin):
"""Helper containing some common functionality used for the Edit tests."""
plugin = "edit"
@ -103,24 +104,26 @@ class EditMixin(PluginMixin):
"""
m = ModifyFileMocker(**modify_file_args)
with patch("beetsplug.edit.edit", side_effect=m.action):
with control_stdin("\n".join(stdin)):
self.importer.run()
for char in stdin:
self.importer.add_choice(char)
self.importer.run()
def run_mocked_command(self, modify_file_args={}, stdin=[], args=[]):
"""Run the edit command, with mocked stdin and yaml writing, and
passing `args` to `run_command`."""
m = ModifyFileMocker(**modify_file_args)
with patch("beetsplug.edit.edit", side_effect=m.action):
with control_stdin("\n".join(stdin)):
self.run_command("edit", *args)
for char in stdin:
self.io.addinput(char)
self.run_command("edit", *args)
@_common.slow_test()
@patch("beets.library.Item.write")
class EditCommandTest(EditMixin, BeetsTestCase):
"""Black box tests for `beetsplug.edit`. Command line interaction is
simulated using `test.helper.control_stdin()`, and yaml editing via an
external editor is simulated using `ModifyFileMocker`.
simulated using mocked stdin, and yaml editing via an external editor is
simulated using `ModifyFileMocker`.
"""
ALBUM_COUNT = 1
@ -412,7 +415,7 @@ class EditDuringImporterNonSingletonTest(EditDuringImporterTestCase):
self.run_mocked_interpreter(
{},
# 1, Apply changes.
["1", "a"],
["1", Action.APPLY],
)
# Retag and edit track titles. On retag, the importer will reset items

View file

@ -19,7 +19,7 @@ import os
import time
from beets import importer, plugins
from beets.test.helper import AutotagImportTestCase, PluginMixin, control_stdin
from beets.test.helper import AutotagImportTestCase, IOMixin, PluginMixin
from beets.util import syspath
from beetsplug.importsource import ImportSourcePlugin
@ -34,7 +34,7 @@ def preserve_plugin_listeners():
ImportSourcePlugin.listeners = _listeners
class ImportSourceTest(PluginMixin, AutotagImportTestCase):
class ImportSourceTest(IOMixin, PluginMixin, AutotagImportTestCase):
plugin = "importsource"
preload_plugin = False
@ -50,31 +50,29 @@ class ImportSourceTest(PluginMixin, AutotagImportTestCase):
self.all_items = self.lib.albums().get().items()
self.item_to_remove = self.all_items[0]
def interact(self, stdin_input: str):
with control_stdin(stdin_input):
self.run_command(
"remove",
f"path:{syspath(self.item_to_remove.path)}",
)
def interact(self, stdin: list[str]):
for char in stdin:
self.io.addinput(char)
self.run_command("remove", f"path:{syspath(self.item_to_remove.path)}")
def test_do_nothing(self):
self.interact("N")
self.interact(["N"])
assert os.path.exists(self.item_to_remove.source_path)
def test_remove_single(self):
self.interact("y\nD")
self.interact(["y", "D"])
assert not os.path.exists(self.item_to_remove.source_path)
def test_remove_all_from_single(self):
self.interact("y\nR\ny")
self.interact(["y", "R", "y"])
for item in self.all_items:
assert not os.path.exists(item.source_path)
def test_stop_suggesting(self):
self.interact("y\nS")
self.interact(["y", "S"])
for item in self.all_items:
assert os.path.exists(item.source_path)

View file

@ -18,7 +18,6 @@ from beets.test.helper import (
PluginMixin,
TerminalImportMixin,
capture_stdout,
control_stdin,
)
@ -35,9 +34,10 @@ class MBSubmitPluginTest(
def test_print_tracks_output(self):
"""Test the output of the "print tracks" choice."""
with capture_stdout() as output:
with control_stdin("\n".join(["p", "s"])):
# Print tracks; Skip
self.importer.run()
self.io.addinput("p")
self.io.addinput("s")
# Print tracks; Skip
self.importer.run()
# Manually build the string for comparing the output.
tracklist = (
@ -50,9 +50,12 @@ class MBSubmitPluginTest(
def test_print_tracks_output_as_tracks(self):
"""Test the output of the "print tracks" choice, as singletons."""
with capture_stdout() as output:
with control_stdin("\n".join(["t", "s", "p", "s"])):
# as Tracks; Skip; Print tracks; Skip
self.importer.run()
self.io.addinput("t")
self.io.addinput("s")
self.io.addinput("p")
self.io.addinput("s")
# as Tracks; Skip; Print tracks; Skip
self.importer.run()
# Manually build the string for comparing the output.
tracklist = (

View file

@ -21,14 +21,14 @@ from unittest.mock import ANY, patch
import pytest
from beets.test.helper import CleanupModulesMixin, PluginTestCase, control_stdin
from beets.test.helper import CleanupModulesMixin, PluginTestCase, IOMixin
from beets.ui import UserError
from beets.util import open_anything
from beetsplug.play import PlayPlugin
@patch("beetsplug.play.util.interactive_open")
class PlayPluginTest(CleanupModulesMixin, PluginTestCase):
class PlayPluginTest(IOMixin, CleanupModulesMixin, PluginTestCase):
modules = (PlayPlugin.__module__,)
plugin = "play"
@ -127,8 +127,8 @@ class PlayPluginTest(CleanupModulesMixin, PluginTestCase):
self.config["play"]["warning_threshold"] = 1
self.add_item(title="another NiceTitle")
with control_stdin("a"):
self.run_command("play", "nice")
self.io.addinput("a")
self.run_command("play", "nice")
open_mock.assert_not_called()
@ -138,12 +138,12 @@ class PlayPluginTest(CleanupModulesMixin, PluginTestCase):
expected_playlist = f"{self.item.filepath}\n{self.other_item.filepath}"
with control_stdin("a"):
self.run_and_assert(
open_mock,
["-y", "NiceTitle"],
expected_playlist=expected_playlist,
)
self.io.addinput("a")
self.run_and_assert(
open_mock,
["-y", "NiceTitle"],
expected_playlist=expected_playlist,
)
def test_command_failed(self, open_mock):
open_mock.side_effect = OSError("some reason")

View file

@ -3,12 +3,12 @@
from mediafile import MediaFile
from beets.library import Item
from beets.test.helper import PluginTestCase, control_stdin
from beets.test.helper import IOMixin, PluginTestCase
from beets.util import syspath
from beetsplug.zero import ZeroPlugin
class ZeroPluginTest(PluginTestCase):
class ZeroPluginTest(IOMixin, PluginTestCase):
plugin = "zero"
preload_plugin = False
@ -102,12 +102,10 @@ class ZeroPluginTest(PluginTestCase):
item.write()
item_id = item.id
with (
self.configure_plugin(
{"fields": ["comments"], "update_database": True, "auto": False}
),
control_stdin("y"),
with self.configure_plugin(
{"fields": ["comments"], "update_database": True, "auto": False}
):
self.io.addinput("y")
self.run_command("zero")
mf = MediaFile(syspath(item.path))
@ -125,16 +123,14 @@ class ZeroPluginTest(PluginTestCase):
item.write()
item_id = item.id
with (
self.configure_plugin(
{
"fields": ["comments"],
"update_database": False,
"auto": False,
}
),
control_stdin("y"),
with self.configure_plugin(
{
"fields": ["comments"],
"update_database": False,
"auto": False,
}
):
self.io.addinput("y")
self.run_command("zero")
mf = MediaFile(syspath(item.path))
@ -187,7 +183,8 @@ class ZeroPluginTest(PluginTestCase):
item_id = item.id
with self.configure_plugin({"fields": []}), control_stdin("y"):
with self.configure_plugin({"fields": []}):
self.io.addinput("y")
self.run_command("zero")
item = self.lib.get_item(item_id)
@ -203,12 +200,10 @@ class ZeroPluginTest(PluginTestCase):
item_id = item.id
with (
self.configure_plugin(
{"fields": ["year"], "keep_fields": ["comments"]}
),
control_stdin("y"),
with self.configure_plugin(
{"fields": ["year"], "keep_fields": ["comments"]}
):
self.io.addinput("y")
self.run_command("zero")
item = self.lib.get_item(item_id)
@ -303,12 +298,10 @@ class ZeroPluginTest(PluginTestCase):
)
item.write()
item_id = item.id
with (
self.configure_plugin(
{"fields": ["comments"], "update_database": True, "auto": False}
),
control_stdin("n"),
with self.configure_plugin(
{"fields": ["comments"], "update_database": True, "auto": False}
):
self.io.addinput("n")
self.run_command("zero")
mf = MediaFile(syspath(item.path))

View file

@ -429,8 +429,9 @@ class PromptChoicesTest(TerminalImportMixin, PluginImportTestCase):
# DummyPlugin.foo() should be called once
with patch.object(DummyPlugin, "foo", autospec=True) as mock_foo:
with helper.control_stdin("\n".join(["f", "s"])):
self.importer.run()
self.io.addinput("f")
self.io.addinput("n")
self.importer.run()
assert mock_foo.call_count == 1
# input_options should be called twice, as foo() returns None
@ -471,8 +472,8 @@ class PromptChoicesTest(TerminalImportMixin, PluginImportTestCase):
)
# DummyPlugin.foo() should be called once
with helper.control_stdin("f\n"):
self.importer.run()
self.io.addinput("f")
self.importer.run()
# input_options should be called once, as foo() returns SKIP
self.mock_input_options.assert_called_once_with(

View file

@ -2,23 +2,24 @@ import unittest
from mediafile import MediaFile
from beets.test.helper import BeetsTestCase, control_stdin
from beets.test.helper import BeetsTestCase, IOMixin
from beets.ui.commands.modify import modify_parse_args
from beets.util import syspath
class ModifyTest(BeetsTestCase):
class ModifyTest(IOMixin, BeetsTestCase):
def setUp(self):
super().setUp()
self.album = self.add_album_fixture()
[self.item] = self.album.items()
def modify_inp(self, inp, *args):
with control_stdin(inp):
self.run_command("modify", *args)
def modify_inp(self, inp: list[str], *args):
for chat in inp:
self.io.addinput(chat)
self.run_command("modify", *args)
def modify(self, *args):
self.modify_inp("y", *args)
self.modify_inp(["y"], *args)
# Item tests
@ -30,14 +31,14 @@ class ModifyTest(BeetsTestCase):
def test_modify_item_abort(self):
item = self.lib.items().get()
title = item.title
self.modify_inp("n", "title=newTitle")
self.modify_inp(["n"], "title=newTitle")
item = self.lib.items().get()
assert item.title == title
def test_modify_item_no_change(self):
title = "Tracktitle"
item = self.add_item_fixture(title=title)
self.modify_inp("y", "title", f"title={title}")
self.modify_inp(["y"], "title", f"title={title}")
item = self.lib.items(title).get()
assert item.title == title
@ -96,7 +97,9 @@ class ModifyTest(BeetsTestCase):
title=f"{title}{i}", artist=original_artist, album=album
)
self.modify_inp(
"s\ny\ny\ny\nn\nn\ny\ny\ny\ny\nn", title, f"artist={new_artist}"
["s", "y", "y", "y", "n", "n", "y", "y", "y", "y", "n"],
title,
f"artist={new_artist}",
)
original_items = self.lib.items(f"artist:{original_artist}")
new_items = self.lib.items(f"artist:{new_artist}")

View file

@ -22,7 +22,7 @@ from random import random
from beets import config, ui
from beets.test import _common
from beets.test.helper import BeetsTestCase, IOMixin, control_stdin
from beets.test.helper import BeetsTestCase, IOMixin
class InputMethodsTest(IOMixin, unittest.TestCase):
@ -85,7 +85,7 @@ class InputMethodsTest(IOMixin, unittest.TestCase):
assert items == ["1", "3"]
class ParentalDirCreation(BeetsTestCase):
class ParentalDirCreation(IOMixin, BeetsTestCase):
def test_create_yes(self):
non_exist_path = _common.os.fsdecode(
os.path.join(self.temp_dir, b"nonexist", str(random()).encode())
@ -94,8 +94,8 @@ class ParentalDirCreation(BeetsTestCase):
# occur; wish I can use a golang defer here.
test_config = deepcopy(config)
test_config["library"] = non_exist_path
with control_stdin("y"):
lib = ui._open_library(test_config)
self.io.addinput("y")
lib = ui._open_library(test_config)
lib._close()
def test_create_no(self):
@ -108,14 +108,14 @@ class ParentalDirCreation(BeetsTestCase):
test_config = deepcopy(config)
test_config["library"] = non_exist_path
with control_stdin("n"):
try:
lib = ui._open_library(test_config)
except ui.UserError:
if os.path.exists(non_exist_path_parent):
shutil.rmtree(non_exist_path_parent)
raise OSError("Parent directories should not be created.")
else:
if lib:
lib._close()
self.io.addinput("n")
try:
lib = ui._open_library(test_config)
except ui.UserError:
if os.path.exists(non_exist_path_parent):
shutil.rmtree(non_exist_path_parent)
raise OSError("Parent directories should not be created.")
else:
if lib:
lib._close()
raise OSError("Parent directories should not be created.")