diff --git a/beets/library.py b/beets/library.py index 20d22dbbf..0209e6600 100644 --- a/beets/library.py +++ b/beets/library.py @@ -73,8 +73,6 @@ log.addHandler(logging.StreamHandler()) #### exceptions #### -class LibraryError(Exception): - pass class InvalidFieldError(Exception): pass @@ -616,7 +614,8 @@ class Library(BaseLibrary): """A music library using an SQLite database as a metadata store.""" def __init__(self, path='library.blb', directory='~/Music', - path_format='$artist/$album/$track $title'): + path_format='$artist/$album/$track $title', + fields=item_fields): self.path = path self.directory = directory self.path_format = path_format @@ -625,13 +624,39 @@ class Library(BaseLibrary): self.conn.row_factory = sqlite3.Row # this way we can access our SELECT results like dictionaries - self._setup() + self._setup(fields) - def _setup(self): - """Set up the schema of the library file.""" - setup_sql = 'CREATE TABLE IF NOT EXISTS items (' - setup_sql += ', '.join([' '.join(f) for f in item_fields]) - setup_sql += ');' + def _setup(self, fields): + """Set up the schema of the library file. fields is a list + of (name, type) pairs indicating all the fields that should + be present in the table. Columns are added if necessary. + """ + # Get current schema. + cur = self.conn.cursor() + cur.execute('PRAGMA table_info(items)') + current_fields = set([row[1] for row in cur]) + + field_names = set([f[0] for f in fields]) + if current_fields.issuperset(field_names): + # Table exists and has all the required columns. + return + + if not current_fields: + # No table exists. + setup_sql = 'CREATE TABLE items (' + setup_sql += ', '.join([' '.join(f) for f in fields]) + setup_sql += ');' + + else: + # Table exists but is missing fields. + for fname in field_names - current_fields: + for field in fields: + if field[0] == fname: + break + else: + assert False + setup_sql = 'ALTER TABLE items ADD COLUMN ' + \ + ' '.join(field) + ';' self.conn.executescript(setup_sql) self.conn.commit() diff --git a/test/test_db.py b/test/test_db.py index 3c8a8039a..3751ce7f2 100644 --- a/test/test_db.py +++ b/test/test_db.py @@ -111,7 +111,6 @@ class AddTest(unittest.TestCase): 'where composer="the composer"').fetchone()['grouping'] self.assertEqual(new_grouping, self.i.grouping) - class RemoveTest(unittest.TestCase): def setUp(self): self.lib = lib() @@ -207,7 +206,52 @@ class DestinationTest(unittest.TestCase): self.i.path = 'something.extn' dest = self.lib.destination(self.i) self.assertEqual(dest[-5:], '.extn') + +class MigrationTest(unittest.TestCase): + """Tests the ability to change the database schema between + versions. + """ + def setUp(self): + # Three different "schema versions". + self.older_fields = [('field_one', 'int')] + self.old_fields = self.older_fields + [('field_two', 'int')] + self.new_fields = self.old_fields + [('field_three', 'int')] + # Set up a library with old_fields. + self.libfile = os.path.join('rsrc', 'templib.blb') + old_lib = beets.library.Library(self.libfile, fields=self.old_fields) + # Add an item to the old library. + old_lib.conn.execute( + 'insert into items (field_one, field_two) values (4, 2)' + ) + old_lib.save() + del old_lib + + def tearDown(self): + os.unlink(self.libfile) + + def test_open_with_same_fields_leaves_untouched(self): + new_lib = beets.library.Library(self.libfile, fields=self.old_fields) + c = new_lib.conn.cursor() + c.execute("select * from items") + row = c.fetchone() + self.assertEqual(len(row), len(self.old_fields)) + + def test_open_with_new_field_adds_column(self): + new_lib = beets.library.Library(self.libfile, fields=self.new_fields) + c = new_lib.conn.cursor() + c.execute("select * from items") + row = c.fetchone() + self.assertEqual(len(row), len(self.new_fields)) + + def test_open_with_fewer_fields_leaves_untouched(self): + new_lib = beets.library.Library(self.libfile, fields=self.older_fields) + c = new_lib.conn.cursor() + c.execute("select * from items") + row = c.fetchone() + self.assertEqual(len(row), len(self.old_fields)) + + def suite(): return unittest.TestLoader().loadTestsFromName(__name__)