diff --git a/beets/importer.py b/beets/importer.py index 9d9c7e962..d31fd1a11 100644 --- a/beets/importer.py +++ b/beets/importer.py @@ -22,7 +22,6 @@ import logging import pickle import itertools from collections import defaultdict -from zipfile import is_zipfile, ZipFile from tempfile import mkdtemp import shutil @@ -553,9 +552,60 @@ class ArchiveImportTask(ImportTask): def __init__(self, toppath): super(ArchiveImportTask, self).__init__(toppath) self.sentinel = True + self.extracted = False + + @classmethod + def is_archive(cls, path): + """Returns true if the given path points to an archive that can + be handled. + """ + if not os.path.isfile(path): + return False + + for path_test, _ in cls.handlers(): + if path_test(path): + return True + return False + + @classmethod + def handlers(cls): + """Returns a list of archive handlers. + + Each handler is a `(path_test, ArchiveClass)` tuple. `path_test` + is a function that returns `True` if the given path can be + handled by `ArchiveClass`. `ArchiveClass` is a class that + implements the same interface as `tarfile.TarFile`. + """ + if not hasattr(cls, '_handlers'): + cls._handlers = [] + from zipfile import is_zipfile, ZipFile + cls._handlers.append((is_zipfile, ZipFile)) + from tarfile import is_tarfile, TarFile + cls._handlers.append((is_tarfile, TarFile)) + return cls._handlers def cleanup(self): - shutil.rmtree(self.toppath) + """Removes the temporary directory the archive was extracted to. + """ + if self.extracted: + shutil.rmtree(self.toppath) + + def extract(self): + """Extracts the archive to a temporary directory and sets + `toppath` to that directory. + """ + for path_test, handler_class in self.handlers(): + if path_test(self.toppath): + break + + try: + extract_to = mkdtemp() + archive = handler_class(self.toppath, mode='r') + archive.extractall(extract_to) + finally: + archive.close() + self.extracted = True + self.toppath = extract_to # Full-album pipeline stages. @@ -591,8 +641,9 @@ def read_tasks(session): history_dirs = history_get() for toppath in session.paths: - extracted = None - if is_zipfile(syspath(toppath)): + # Extract archives + archive_task = None + if ArchiveImportTask.is_archive(syspath(toppath)): if not (config['import']['move'] or config['import']['copy']): log.warn("Cannot import archive. Please set " "the 'move' or 'copy' option.") @@ -600,16 +651,13 @@ def read_tasks(session): log.debug('extracting archive {0}' .format(displayable_path(toppath))) + archive_task = ArchiveImportTask(toppath) try: - extracted = mkdtemp() - zip_file = ZipFile(toppath, mode='r') - zip_file.extractall(extracted) + archive_task.extract() except IOError as exc: log.error('extraction failed: {0}'.format(exc)) continue - finally: - zip_file.close() - toppath = extracted + toppath = archive_task.toppath # Check whether the path is to a file. if not os.path.isdir(syspath(toppath)): @@ -666,10 +714,10 @@ def read_tasks(session): # Indicate the directory is finished. # FIXME hack to delete extraced archives - if extracted is None: + if archive_task is None: yield ImportTask.done_sentinel(toppath) else: - yield ArchiveImportTask(extracted) + yield archive_task # Show skipped directories. if config['import']['incremental'] and incremental_skipped: diff --git a/test/test_importer.py b/test/test_importer.py index c8c4dde9b..854b21d4b 100644 --- a/test/test_importer.py +++ b/test/test_importer.py @@ -19,6 +19,7 @@ import shutil import StringIO from tempfile import mkstemp from zipfile import ZipFile +from tarfile import TarFile import _common from _common import unittest @@ -301,7 +302,7 @@ class NonAutotaggedImportTest(_common.TestCase, ImportHelper): self.assertNotExists(os.path.join(self.import_dir, 'the_album')) -class ImportArchiveTest(unittest.TestCase, ImportHelper): +class ImportZipTest(unittest.TestCase, ImportHelper): def setUp(self): self.setup_beets() @@ -310,7 +311,7 @@ class ImportArchiveTest(unittest.TestCase, ImportHelper): self.teardown_beets() def test_import_zip(self): - zip_path = self.create_zip_archive() + zip_path = self.create_archive() self.assertEqual(len(self.lib.items()), 0) self.assertEqual(len(self.lib.albums()), 0) @@ -319,14 +320,26 @@ class ImportArchiveTest(unittest.TestCase, ImportHelper): self.assertEqual(len(self.lib.items()), 1) self.assertEqual(len(self.lib.albums()), 1) - def create_zip_archive(self): - (handle, zip_path) = mkstemp('.zip', dir=self.temp_dir) + def create_archive(self): + (handle, path) = mkstemp(dir=self.temp_dir) os.close(handle) - zip_file = ZipFile(zip_path, mode='w') - zip_file.write(os.path.join(_common.RSRC, 'full.mp3'), - 'full.mp3') - zip_file.close() - return zip_path + archive = ZipFile(path, mode='w') + archive.write(os.path.join(_common.RSRC, 'full.mp3'), + 'full.mp3') + archive.close() + return path + + +class ImportTarTest(ImportZipTest): + + def create_archive(self): + (handle, path) = mkstemp(dir=self.temp_dir) + os.close(handle) + archive = TarFile(path, mode='w') + archive.add(os.path.join(_common.RSRC, 'full.mp3'), + 'full.mp3') + archive.close() + return path class ImportSingletonTest(_common.TestCase, ImportHelper):