Updated class fields to allow for easier unit testing

This commit is contained in:
Austin Marino 2019-10-10 19:35:49 -07:00
parent 8ff875bded
commit c31b488e54

View file

@ -42,6 +42,9 @@ class ExportPlugin(BeetsPlugin):
def __init__(self):
super(ExportPlugin, self).__init__()
# Used when testing export plugin
self.run_results = None
self.export_format = None
self.config.add({
'default_format': 'json',
@ -68,7 +71,7 @@ class ExportPlugin(BeetsPlugin):
'formatting': {
'ensure_ascii': False,
'indent': 4,
'separators': (''),
'separators': ('>'),
'sort_keys': True
}
}
@ -105,13 +108,12 @@ class ExportPlugin(BeetsPlugin):
return [cmd]
def run(self, lib, opts, args):
file_path = opts.output
file_mode = 'a' if opts.append else 'w'
file_format = opts.format
format_options = self.config[file_format]['formatting'].get(dict)
export_format = ExportFormat.factory(
self.export_format = ExportFormat.factory(
file_type=file_format,
**{
'file_path': file_path,
@ -135,7 +137,9 @@ class ExportPlugin(BeetsPlugin):
continue
data = key_filter(data)
items += [data]
export_format.export(items, **format_options)
self.run_results = items
self.export_format.export(self.run_results, **format_options)
class ExportFormat(object):
@ -144,6 +148,8 @@ class ExportFormat(object):
self.path = file_path
self.mode = file_mode
self.encoding = encoding
# Used for testing
self.results = None
@classmethod
def factory(cls, file_type, **kwargs):
@ -167,11 +173,13 @@ class JsonFormat(ExportFormat):
self.export = self.export_to_file if self.path else self.export_to_terminal
def export_to_terminal(self, data, **kwargs):
json.dump(data, sys.stdout, cls=ExportEncoder, **kwargs)
r = json.dump(data, sys.stdout, cls=ExportEncoder, **kwargs)
self.results = str(r)
def export_to_file(self, data, **kwargs):
with codecs.open(self.path, self.mode, self.encoding) as f:
json.dump(data, f, cls=ExportEncoder, **kwargs)
r = json.dump(data, f, cls=ExportEncoder, **kwargs)
self.results = str(r)
class CSVFormat(ExportFormat):
@ -192,12 +200,15 @@ class CSVFormat(ExportFormat):
writer = csv.DictWriter(sys.stdout, fieldnames=self.header)
writer.writeheader()
writer.writerows(data)
self.results = str(writer)
def export_to_file(self, data, **kwargs):
with codecs.open(self.path, self.mode, self.encoding) as f:
writer = csv.DictWriter(f, fieldnames=self.header)
writer.writeheader()
writer.writerows(data)
self.results = str(writer)
class XMLFormat(ExportFormat):
@ -233,7 +244,9 @@ class XMLFormat(ExportFormat):
def export_to_terminal(self, data, **kwargs):
print(data)
self.results = str(data)
def export_to_file(self, data, **kwargs):
with codecs.open(self.path, self.mode, self.encoding) as f:
f.write(data)
f.write(data)
self.results = str(data)