Import tar archives

Also refactors the importer code to make better use of ArchiveImportTask.
This commit is contained in:
Thomas Scholtes 2014-04-15 17:23:50 +02:00
parent b783097329
commit e3acdd0cc8
2 changed files with 82 additions and 21 deletions

View file

@ -22,7 +22,6 @@ import logging
import pickle
import itertools
from collections import defaultdict
from zipfile import is_zipfile, ZipFile
from tempfile import mkdtemp
import shutil
@ -553,9 +552,60 @@ class ArchiveImportTask(ImportTask):
def __init__(self, toppath):
    """Create a task for the archive located at `toppath`.

    The task starts out flagged as a sentinel and as not yet
    extracted.
    """
    super(ArchiveImportTask, self).__init__(toppath)
    # Nothing has been unpacked yet.
    self.extracted = False
    self.sentinel = True
@classmethod
def is_archive(cls, path):
    """Tell whether `path` is an existing regular file that at least
    one of the registered archive handlers recognizes.
    """
    if os.path.isfile(path):
        # Stops at the first handler whose path test accepts the file.
        return any(recognizes(path) for recognizes, _ in cls.handlers())
    return False
@classmethod
def handlers(cls):
    """Return the list of archive handlers.

    Each handler is a `(path_test, opener)` tuple. `path_test` is a
    function that returns `True` if the given path can be handled by
    `opener`. `opener` is called as `opener(path, mode='r')` and
    returns an object that provides `extractall()` and `close()` like
    `tarfile.TarFile` does.

    The list is built on first use and cached on the class.
    """
    if not hasattr(cls, '_handlers'):
        from zipfile import is_zipfile, ZipFile
        import tarfile
        cls._handlers = [
            (is_zipfile, ZipFile),
            # `is_tarfile` also accepts compressed tarballs, but the
            # bare `TarFile` constructor can only read uncompressed
            # ones. `tarfile.open` detects the compression itself.
            (tarfile.is_tarfile, tarfile.open),
        ]
    return cls._handlers
def cleanup(self):
    """Remove the temporary directory the archive was extracted to.

    Does nothing while `self.extracted` is false, so `toppath` is
    only ever deleted when it points at an extraction directory.
    """
    if self.extracted:
        shutil.rmtree(self.toppath)
def extract(self):
    """Extract the archive to a temporary directory and set `toppath`
    to that directory.

    Raises `IOError` if none of the handlers accepts `toppath`.
    Errors raised while opening or unpacking the archive propagate
    unchanged; in that case `toppath` and `extracted` are left
    untouched and the temporary directory is removed again.
    """
    for path_test, handler_class in self.handlers():
        if path_test(self.toppath):
            break
    else:
        # Without this, the last handler tried would be used blindly.
        raise IOError('no archive handler for {0}'.format(self.toppath))

    # Open the archive before entering the `try` block: if opening
    # fails there is nothing to close, and the original error must
    # reach the caller instead of being masked by a failing `close()`.
    archive = handler_class(self.toppath, mode='r')
    try:
        extract_to = mkdtemp()
        try:
            archive.extractall(extract_to)
        except Exception:
            # Do not leave a half-filled temporary directory behind.
            shutil.rmtree(extract_to, ignore_errors=True)
            raise
    finally:
        archive.close()
    self.extracted = True
    self.toppath = extract_to
# Full-album pipeline stages.
@ -591,8 +641,9 @@ def read_tasks(session):
history_dirs = history_get()
for toppath in session.paths:
extracted = None
if is_zipfile(syspath(toppath)):
# Extract archives
archive_task = None
if ArchiveImportTask.is_archive(syspath(toppath)):
if not (config['import']['move'] or config['import']['copy']):
log.warn("Cannot import archive. Please set "
"the 'move' or 'copy' option.")
@ -600,16 +651,13 @@ def read_tasks(session):
log.debug('extracting archive {0}'
.format(displayable_path(toppath)))
archive_task = ArchiveImportTask(toppath)
try:
extracted = mkdtemp()
zip_file = ZipFile(toppath, mode='r')
zip_file.extractall(extracted)
archive_task.extract()
except IOError as exc:
log.error('extraction failed: {0}'.format(exc))
continue
finally:
zip_file.close()
toppath = extracted
toppath = archive_task.toppath
# Check whether the path is to a file.
if not os.path.isdir(syspath(toppath)):
@ -666,10 +714,10 @@ def read_tasks(session):
# Indicate the directory is finished.
# FIXME hack to delete extracted archives
if extracted is None:
if archive_task is None:
yield ImportTask.done_sentinel(toppath)
else:
yield ArchiveImportTask(extracted)
yield archive_task
# Show skipped directories.
if config['import']['incremental'] and incremental_skipped:

View file

@ -19,6 +19,7 @@ import shutil
import StringIO
from tempfile import mkstemp
from zipfile import ZipFile
from tarfile import TarFile
import _common
from _common import unittest
@ -301,7 +302,7 @@ class NonAutotaggedImportTest(_common.TestCase, ImportHelper):
self.assertNotExists(os.path.join(self.import_dir, 'the_album'))
class ImportArchiveTest(unittest.TestCase, ImportHelper):
class ImportZipTest(unittest.TestCase, ImportHelper):
def setUp(self):
self.setup_beets()
@ -310,7 +311,7 @@ class ImportArchiveTest(unittest.TestCase, ImportHelper):
self.teardown_beets()
def test_import_zip(self):
zip_path = self.create_zip_archive()
zip_path = self.create_archive()
self.assertEqual(len(self.lib.items()), 0)
self.assertEqual(len(self.lib.albums()), 0)
@ -319,14 +320,26 @@ class ImportArchiveTest(unittest.TestCase, ImportHelper):
self.assertEqual(len(self.lib.items()), 1)
self.assertEqual(len(self.lib.albums()), 1)
def create_zip_archive(self):
(handle, zip_path) = mkstemp('.zip', dir=self.temp_dir)
def create_archive(self):
(handle, path) = mkstemp(dir=self.temp_dir)
os.close(handle)
zip_file = ZipFile(zip_path, mode='w')
zip_file.write(os.path.join(_common.RSRC, 'full.mp3'),
'full.mp3')
zip_file.close()
return zip_path
archive = ZipFile(path, mode='w')
archive.write(os.path.join(_common.RSRC, 'full.mp3'),
'full.mp3')
archive.close()
return path
class ImportTarTest(ImportZipTest):
    """Runs the archive import tests inherited from `ImportZipTest`
    against a tar archive instead of a zip archive.
    """

    def create_archive(self):
        """Write a tar archive holding `full.mp3` into the temporary
        directory and return the archive's path.
        """
        fd, tar_path = mkstemp(dir=self.temp_dir)
        # Only the path is needed; TarFile reopens the file itself.
        os.close(fd)
        tarball = TarFile(tar_path, mode='w')
        source = os.path.join(_common.RSRC, 'full.mp3')
        tarball.add(source, 'full.mp3')
        tarball.close()
        return tar_path
class ImportSingletonTest(_common.TestCase, ImportHelper):