mirror of
https://github.com/beetbox/beets.git
synced 2025-12-06 16:42:42 +01:00
In preparation for #660, where we will allow MediaFile to expose None values when tags are missing (and consume None to remove tags). This makes it possible to hide nullness in the rest of beets by translating None to a suitable zero-ish value on field assignment. Types can of course opt out of this to preserve a distinct null value. We do this now for the album_id field, which needs to be null to indicate singletons. Type.normalize() also enables more sophisticated translations (e.g., an integer field could round off float values assigned into it) in the future.
273 lines
8 KiB
Python
273 lines
8 KiB
Python
# This file is part of beets.
|
|
# Copyright 2014, Adrian Sampson.
|
|
#
|
|
# Permission is hereby granted, free of charge, to any person obtaining
|
|
# a copy of this software and associated documentation files (the
|
|
# "Software"), to deal in the Software without restriction, including
|
|
# without limitation the rights to use, copy, modify, merge, publish,
|
|
# distribute, sublicense, and/or sell copies of the Software, and to
|
|
# permit persons to whom the Software is furnished to do so, subject to
|
|
# the following conditions:
|
|
#
|
|
# The above copyright notice and this permission notice shall be
|
|
# included in all copies or substantial portions of the Software.
|
|
|
|
"""Tests for the DBCore database abstraction.
|
|
"""
|
|
import os
|
|
import sqlite3
|
|
|
|
import _common
|
|
from _common import unittest
|
|
from beets import dbcore
|
|
|
|
|
|
# Fixture: concrete database and model classes. For migration tests, we
|
|
# have multiple models with different numbers of fields.
|
|
|
|
class TestModel1(dbcore.Model):
|
|
_table = 'test'
|
|
_flex_table = 'testflex'
|
|
_fields = {
|
|
'id': dbcore.types.Id(),
|
|
'field_one': dbcore.types.Integer(),
|
|
}
|
|
|
|
@classmethod
|
|
def _getters(cls):
|
|
return {}
|
|
|
|
def _template_funcs(self):
|
|
return {}
|
|
|
|
class TestDatabase1(dbcore.Database):
|
|
_models = (TestModel1,)
|
|
pass
|
|
|
|
class TestModel2(TestModel1):
|
|
_fields = {
|
|
'id': dbcore.types.Id(),
|
|
'field_one': dbcore.types.Integer(),
|
|
'field_two': dbcore.types.Integer(),
|
|
}
|
|
|
|
class TestDatabase2(dbcore.Database):
|
|
_models = (TestModel2,)
|
|
pass
|
|
|
|
class TestModel3(TestModel1):
|
|
_fields = {
|
|
'id': dbcore.types.Id(),
|
|
'field_one': dbcore.types.Integer(),
|
|
'field_two': dbcore.types.Integer(),
|
|
'field_three': dbcore.types.Integer(),
|
|
}
|
|
|
|
class TestDatabase3(dbcore.Database):
|
|
_models = (TestModel3,)
|
|
pass
|
|
|
|
class TestModel4(TestModel1):
|
|
_fields = {
|
|
'id': dbcore.types.Id(),
|
|
'field_one': dbcore.types.Integer(),
|
|
'field_two': dbcore.types.Integer(),
|
|
'field_three': dbcore.types.Integer(),
|
|
'field_four': dbcore.types.Integer(),
|
|
}
|
|
|
|
class TestDatabase4(dbcore.Database):
|
|
_models = (TestModel4,)
|
|
pass
|
|
|
|
class AnotherTestModel(TestModel1):
|
|
_table = 'another'
|
|
_flex_table = 'anotherflex'
|
|
_fields = {
|
|
'id': dbcore.types.Id(),
|
|
'foo': dbcore.types.Integer(),
|
|
}
|
|
|
|
class TestDatabaseTwoModels(dbcore.Database):
|
|
_models = (TestModel2, AnotherTestModel)
|
|
pass
|
|
|
|
|
|
class MigrationTest(_common.TestCase):
|
|
"""Tests the ability to change the database schema between
|
|
versions.
|
|
"""
|
|
def setUp(self):
|
|
super(MigrationTest, self).setUp()
|
|
|
|
# Set up a database with the two-field schema.
|
|
self.libfile = os.path.join(self.temp_dir, 'temp.db')
|
|
old_lib = TestDatabase2(self.libfile)
|
|
|
|
# Add an item to the old library.
|
|
old_lib._connection().execute(
|
|
'insert into test (field_one, field_two) values (4, 2)'
|
|
)
|
|
old_lib._connection().commit()
|
|
del old_lib
|
|
|
|
def test_open_with_same_fields_leaves_untouched(self):
|
|
new_lib = TestDatabase2(self.libfile)
|
|
c = new_lib._connection().cursor()
|
|
c.execute("select * from test")
|
|
row = c.fetchone()
|
|
self.assertEqual(len(row.keys()), len(TestModel2._fields))
|
|
|
|
def test_open_with_new_field_adds_column(self):
|
|
new_lib = TestDatabase3(self.libfile)
|
|
c = new_lib._connection().cursor()
|
|
c.execute("select * from test")
|
|
row = c.fetchone()
|
|
self.assertEqual(len(row.keys()), len(TestModel3._fields))
|
|
|
|
def test_open_with_fewer_fields_leaves_untouched(self):
|
|
new_lib = TestDatabase1(self.libfile)
|
|
c = new_lib._connection().cursor()
|
|
c.execute("select * from test")
|
|
row = c.fetchone()
|
|
self.assertEqual(len(row.keys()), len(TestModel2._fields))
|
|
|
|
def test_open_with_multiple_new_fields(self):
|
|
new_lib = TestDatabase4(self.libfile)
|
|
c = new_lib._connection().cursor()
|
|
c.execute("select * from test")
|
|
row = c.fetchone()
|
|
self.assertEqual(len(row.keys()), len(TestModel4._fields))
|
|
|
|
def test_extra_model_adds_table(self):
|
|
new_lib = TestDatabaseTwoModels(self.libfile)
|
|
try:
|
|
new_lib._connection().execute("select * from another")
|
|
except sqlite3.OperationalError:
|
|
self.fail("select failed")
|
|
|
|
|
|
class ModelTest(_common.TestCase):
|
|
def setUp(self):
|
|
super(ModelTest, self).setUp()
|
|
dbfile = os.path.join(self.temp_dir, 'temp.db')
|
|
self.db = TestDatabase1(dbfile)
|
|
|
|
def test_add_model(self):
|
|
model = TestModel1()
|
|
model.add(self.db)
|
|
rows = self.db._connection().execute('select * from test').fetchall()
|
|
self.assertEqual(len(rows), 1)
|
|
|
|
def test_store_fixed_field(self):
|
|
model = TestModel1()
|
|
model.add(self.db)
|
|
model.field_one = 123
|
|
model.store()
|
|
row = self.db._connection().execute('select * from test').fetchone()
|
|
self.assertEqual(row['field_one'], 123)
|
|
|
|
def test_retrieve_by_id(self):
|
|
model = TestModel1()
|
|
model.add(self.db)
|
|
other_model = self.db._get(TestModel1, model.id)
|
|
self.assertEqual(model.id, other_model.id)
|
|
|
|
def test_store_and_retrieve_flexattr(self):
|
|
model = TestModel1()
|
|
model.add(self.db)
|
|
model.foo = 'bar'
|
|
model.store()
|
|
|
|
other_model = self.db._get(TestModel1, model.id)
|
|
self.assertEqual(other_model.foo, 'bar')
|
|
|
|
def test_delete_flexattr(self):
|
|
model = TestModel1()
|
|
model['foo'] = 'bar'
|
|
self.assertTrue('foo' in model)
|
|
del model['foo']
|
|
self.assertFalse('foo' in model)
|
|
|
|
def test_delete_flexattr_via_dot(self):
|
|
model = TestModel1()
|
|
model['foo'] = 'bar'
|
|
self.assertTrue('foo' in model)
|
|
del model.foo
|
|
self.assertFalse('foo' in model)
|
|
|
|
def test_delete_flexattr_persists(self):
|
|
model = TestModel1()
|
|
model.add(self.db)
|
|
model.foo = 'bar'
|
|
model.store()
|
|
|
|
model = self.db._get(TestModel1, model.id)
|
|
del model['foo']
|
|
model.store()
|
|
|
|
model = self.db._get(TestModel1, model.id)
|
|
self.assertFalse('foo' in model)
|
|
|
|
def test_delete_non_existent_attribute(self):
|
|
model = TestModel1()
|
|
with self.assertRaises(KeyError):
|
|
del model['foo']
|
|
|
|
def test_delete_fixed_attribute(self):
|
|
model = TestModel1()
|
|
with self.assertRaises(KeyError):
|
|
del model['field_one']
|
|
|
|
def test_null_value_normalization_by_type(self):
|
|
model = TestModel1()
|
|
model.field_one = None
|
|
self.assertEqual(model.field_one, 0)
|
|
|
|
def test_null_value_stays_none_for_untyped_field(self):
|
|
model = TestModel1()
|
|
model.foo = None
|
|
self.assertEqual(model.foo, None)
|
|
|
|
|
|
class FormatTest(_common.TestCase):
|
|
def test_format_fixed_field(self):
|
|
model = TestModel1()
|
|
model.field_one = u'caf\xe9'
|
|
value = model._get_formatted('field_one')
|
|
self.assertEqual(value, u'caf\xe9')
|
|
|
|
def test_format_flex_field(self):
|
|
model = TestModel1()
|
|
model.other_field = u'caf\xe9'
|
|
value = model._get_formatted('other_field')
|
|
self.assertEqual(value, u'caf\xe9')
|
|
|
|
def test_format_flex_field_bytes(self):
|
|
model = TestModel1()
|
|
model.other_field = u'caf\xe9'.encode('utf8')
|
|
value = model._get_formatted('other_field')
|
|
self.assertTrue(isinstance(value, unicode))
|
|
self.assertEqual(value, u'caf\xe9')
|
|
|
|
def test_format_unset_field(self):
|
|
model = TestModel1()
|
|
value = model._get_formatted('other_field')
|
|
self.assertEqual(value, u'')
|
|
|
|
|
|
class ParseTest(_common.TestCase):
|
|
def test_parse_fixed_field(self):
|
|
value = TestModel1._parse('field_one', u'2')
|
|
self.assertEqual(value, 2)
|
|
|
|
def test_parse_untyped_field(self):
|
|
value = TestModel1._parse('field_nine', u'2')
|
|
self.assertEqual(value, u'2')
|
|
|
|
|
|
def suite():
|
|
return unittest.TestLoader().loadTestsFromName(__name__)
|
|
|
|
if __name__ == '__main__':
|
|
unittest.main(defaultTest='suite')
|