diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index 397dbcedf..acd131be2 100755 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -901,19 +901,8 @@ class Database: data is written in a transaction. """ - # Check whether parental directories exist. - def _path_checker(self, path): - if not isinstance(path, bytes) and path == ':memory:': # in memory db - return - newpath = os.path.dirname(path) - if not os.path.isdir(newpath): - from beets.ui.commands import database_dir_creation - if database_dir_creation(newpath): - os.makedirs(newpath) - def __init__(self, path, timeout=5.0): self.path = path - self._path_checker(path) self.timeout = timeout self._connections = {} diff --git a/beets/ui/__init__.py b/beets/ui/__init__.py index 121cb5dc0..b724a963a 100644 --- a/beets/ui/__init__.py +++ b/beets/ui/__init__.py @@ -1206,11 +1206,24 @@ def _configure(options): util.displayable_path(config.config_dir())) return config +# Check whether parental directories exist. + + +def _check_db_directory_exists(path): + if path == b':memory:': # in memory db + return + newpath = os.path.dirname(path) + if not os.path.isdir(newpath): + from beets.ui.commands import database_dir_creation + if database_dir_creation(newpath): + os.makedirs(newpath) + def _open_library(config): """Create a new library instance from the configuration. """ dbpath = util.bytestring_path(config['library'].as_filename()) + _check_db_directory_exists(dbpath) try: lib = library.Library( dbpath, diff --git a/beets/ui/commands.py b/beets/ui/commands.py index 394a2831a..1261b1776 100755 --- a/beets/ui/commands.py +++ b/beets/ui/commands.py @@ -1404,7 +1404,8 @@ default_commands.append(version_cmd) def database_dir_creation(path): # Ask the user for a choice. - return ui.input_yn("The database directory {} does not exists, create it (Y/n)?" + return ui.input_yn("The database directory {} does not \ + exists, create it (Y/n)?" .format(displayable_path(path))) diff --git a/test/test_dbcore.py b/test/test_dbcore.py index 95c52196f..80d85c3bb 100644 --- a/test/test_dbcore.py +++ b/test/test_dbcore.py @@ -19,11 +19,8 @@ import os import shutil import sqlite3 import unittest -from random import random -from unittest import mock from test import _common -from test.helper import control_stdin from beets import dbcore from tempfile import mkstemp @@ -763,25 +760,6 @@ class ResultsIteratorTest(unittest.TestCase): ModelFixture1, dbcore.query.FalseQuery()).get()) -class ParentalDirCreation(_common.TestCase): - def test_create_yes(self): - non_exist_path = _common.util.py3_path(os.path.join( - self.temp_dir, b'nonexist', str(random()).encode())) - with control_stdin('y'): - dbcore.Database(non_exist_path) - - def test_create_no(self): - non_exist_path_parent = _common.util.py3_path( - os.path.join(self.temp_dir, b'nonexist')) - non_exist_path = _common.util.py3_path(os.path.join( - non_exist_path_parent.encode(), str(random()).encode())) - with control_stdin('n'): - dbcore.Database(non_exist_path) - if os.path.exists(non_exist_path_parent): - shutil.rmtree(non_exist_path_parent) - raise OSError("Should not create dir") - - def suite(): return unittest.TestLoader().loadTestsFromName(__name__) diff --git a/test/test_ui_init.py b/test/test_ui_init.py index bb9a922a5..9f9487a6a 100644 --- a/test/test_ui_init.py +++ b/test/test_ui_init.py @@ -15,11 +15,16 @@ """Test module for file ui/__init__.py """ - +import os +import shutil import unittest -from test import _common +from random import random +from copy import deepcopy from beets import ui +from test import _common +from test.helper import control_stdin +from beets import config class InputMethodsTest(_common.TestCase): @@ -121,8 +126,39 @@ class InitTest(_common.LibTestCase): self.assertEqual(h, ui.human_seconds(i)) +class ParentalDirCreation(_common.TestCase): + def test_create_yes(self): + non_exist_path = _common.util.py3_path(os.path.join( + self.temp_dir, b'nonexist', str(random()).encode())) + # Deepcopy instead of recovering because exceptions might + # occcur; wish I can use a golang defer here. + test_config = deepcopy(config) + test_config['library'] = non_exist_path + with control_stdin('y'): + ui._open_library(test_config) + + def test_create_no(self): + non_exist_path_parent = _common.util.py3_path( + os.path.join(self.temp_dir, b'nonexist')) + non_exist_path = _common.util.py3_path(os.path.join( + non_exist_path_parent.encode(), str(random()).encode())) + test_config = deepcopy(config) + test_config['library'] = non_exist_path + + with control_stdin('n'): + try: + ui._open_library(test_config) + except ui.UserError: + if os.path.exists(non_exist_path_parent): + shutil.rmtree(non_exist_path_parent) + raise OSError("Parent directories should not be created.") + else: + raise OSError("Parent directories should not be created.") + + def suite(): return unittest.TestLoader().loadTestsFromName(__name__) + if __name__ == '__main__': unittest.main(defaultTest='suite')