tests: temporarily modify global configuration

This commit is contained in:
Adrian Sampson 2012-12-14 13:39:14 -08:00
parent e84a41b550
commit 6f19f466fc
3 changed files with 25 additions and 6 deletions

View file

@ -17,6 +17,8 @@ import time
import sys
import os
import logging
import contextlib
import copy
# Use unittest2 on Python < 2.7.
try:
@ -28,6 +30,7 @@ except ImportError:
sys.path.insert(0, '..')
import beets.library
from beets import importer
import beets
# Suppress logging output.
log = logging.getLogger('beets')
@ -75,6 +78,19 @@ def item():
def import_session(lib, logfile=None, paths=[], query=[]):
return importer.ImportSession(lib, logfile, paths, query)
# Temporary config modifications.
@contextlib.contextmanager
def temp_config():
"""A context manager that saves and restores beets' global
configuration. This allows tests to make temporary modifications
that will then be automatically removed when the context exits.
"""
old_sources = copy.deepcopy(beets.config.sources)
old_overlay = copy.deepcopy(beets.config.overlay)
yield
beets.config.sources = old_sources
beets.config.overlay = old_overlay
# Mock timing.

View file

@ -265,13 +265,15 @@ class ArtImporterTest(unittest.TestCase, _common.ExtraAsserts):
self.assertExists(self.art_file)
def test_delete_original_file(self):
config['import']['delete'] = True
self._fetch_art(True)
with _common.temp_config():
config['import']['delete'] = True
self._fetch_art(True)
self.assertNotExists(self.art_file)
def test_move_original_file(self):
config['import']['move'] = True
self._fetch_art(True)
with _common.temp_config():
config['import']['move'] = True
self._fetch_art(True)
self.assertNotExists(self.art_file)
def test_do_not_delete_original_if_already_in_place(self):

View file

@ -589,8 +589,9 @@ class ShowdiffTest(unittest.TestCase):
self.assertTrue('field' in out)
def test_showdiff_ints_no_color(self):
config['color'] = False
commands._showdiff('field', 2, 3)
with _common.temp_config():
config['color'] = False
commands._showdiff('field', 2, 3)
out = self.io.getoutput()
self.assertTrue('field' in out)