fix code review comments

This commit is contained in:
alicezou 2022-03-29 21:24:13 -04:00
parent 67e778fec6
commit 2886296c86
5 changed files with 53 additions and 36 deletions

View file

@ -901,19 +901,8 @@ class Database:
data is written in a transaction.
"""
# Check whether parental directories exist.
def _path_checker(self, path):
if not isinstance(path, bytes) and path == ':memory:': # in memory db
return
newpath = os.path.dirname(path)
if not os.path.isdir(newpath):
from beets.ui.commands import database_dir_creation
if database_dir_creation(newpath):
os.makedirs(newpath)
def __init__(self, path, timeout=5.0):
self.path = path
self._path_checker(path)
self.timeout = timeout
self._connections = {}

View file

@ -1206,11 +1206,24 @@ def _configure(options):
util.displayable_path(config.config_dir()))
return config
# Check whether parental directories exist.
def _check_db_directory_exists(path):
if path == b':memory:': # in memory db
return
newpath = os.path.dirname(path)
if not os.path.isdir(newpath):
from beets.ui.commands import database_dir_creation
if database_dir_creation(newpath):
os.makedirs(newpath)
def _open_library(config):
"""Create a new library instance from the configuration.
"""
dbpath = util.bytestring_path(config['library'].as_filename())
_check_db_directory_exists(dbpath)
try:
lib = library.Library(
dbpath,

View file

@ -1404,7 +1404,8 @@ default_commands.append(version_cmd)
def database_dir_creation(path):
# Ask the user for a choice.
return ui.input_yn("The database directory {} does not exists, create it (Y/n)?"
return ui.input_yn("The database directory {} does not \
exists, create it (Y/n)?"
.format(displayable_path(path)))

View file

@ -19,11 +19,8 @@ import os
import shutil
import sqlite3
import unittest
from random import random
from unittest import mock
from test import _common
from test.helper import control_stdin
from beets import dbcore
from tempfile import mkstemp
@ -763,25 +760,6 @@ class ResultsIteratorTest(unittest.TestCase):
ModelFixture1, dbcore.query.FalseQuery()).get())
class ParentalDirCreation(_common.TestCase):
def test_create_yes(self):
non_exist_path = _common.util.py3_path(os.path.join(
self.temp_dir, b'nonexist', str(random()).encode()))
with control_stdin('y'):
dbcore.Database(non_exist_path)
def test_create_no(self):
non_exist_path_parent = _common.util.py3_path(
os.path.join(self.temp_dir, b'nonexist'))
non_exist_path = _common.util.py3_path(os.path.join(
non_exist_path_parent.encode(), str(random()).encode()))
with control_stdin('n'):
dbcore.Database(non_exist_path)
if os.path.exists(non_exist_path_parent):
shutil.rmtree(non_exist_path_parent)
raise OSError("Should not create dir")
def suite():
return unittest.TestLoader().loadTestsFromName(__name__)

View file

@ -15,11 +15,16 @@
"""Test module for file ui/__init__.py
"""
import os
import shutil
import unittest
from test import _common
from random import random
from copy import deepcopy
from beets import ui
from test import _common
from test.helper import control_stdin
from beets import config
class InputMethodsTest(_common.TestCase):
@ -121,8 +126,39 @@ class InitTest(_common.LibTestCase):
self.assertEqual(h, ui.human_seconds(i))
class ParentalDirCreation(_common.TestCase):
def test_create_yes(self):
non_exist_path = _common.util.py3_path(os.path.join(
self.temp_dir, b'nonexist', str(random()).encode()))
# Deepcopy instead of recovering because exceptions might
# occcur; wish I can use a golang defer here.
test_config = deepcopy(config)
test_config['library'] = non_exist_path
with control_stdin('y'):
ui._open_library(test_config)
def test_create_no(self):
non_exist_path_parent = _common.util.py3_path(
os.path.join(self.temp_dir, b'nonexist'))
non_exist_path = _common.util.py3_path(os.path.join(
non_exist_path_parent.encode(), str(random()).encode()))
test_config = deepcopy(config)
test_config['library'] = non_exist_path
with control_stdin('n'):
try:
ui._open_library(test_config)
except ui.UserError:
if os.path.exists(non_exist_path_parent):
shutil.rmtree(non_exist_path_parent)
raise OSError("Parent directories should not be created.")
else:
raise OSError("Parent directories should not be created.")
def suite():
return unittest.TestLoader().loadTestsFromName(__name__)
if __name__ == '__main__':
unittest.main(defaultTest='suite')