Updated class fields to allow for easier unit testing

This commit is contained in:
Austin Marino 2019-10-10 19:35:49 -07:00
parent 8ff875bded
commit c31b488e54

View file

@ -42,6 +42,9 @@ class ExportPlugin(BeetsPlugin):
def __init__(self): def __init__(self):
super(ExportPlugin, self).__init__() super(ExportPlugin, self).__init__()
# Used when testing export plugin
self.run_results = None
self.export_format = None
self.config.add({ self.config.add({
'default_format': 'json', 'default_format': 'json',
@ -68,7 +71,7 @@ class ExportPlugin(BeetsPlugin):
'formatting': { 'formatting': {
'ensure_ascii': False, 'ensure_ascii': False,
'indent': 4, 'indent': 4,
'separators': (''), 'separators': ('>'),
'sort_keys': True 'sort_keys': True
} }
} }
@ -105,13 +108,12 @@ class ExportPlugin(BeetsPlugin):
return [cmd] return [cmd]
def run(self, lib, opts, args): def run(self, lib, opts, args):
file_path = opts.output file_path = opts.output
file_mode = 'a' if opts.append else 'w' file_mode = 'a' if opts.append else 'w'
file_format = opts.format file_format = opts.format
format_options = self.config[file_format]['formatting'].get(dict) format_options = self.config[file_format]['formatting'].get(dict)
export_format = ExportFormat.factory( self.export_format = ExportFormat.factory(
file_type=file_format, file_type=file_format,
**{ **{
'file_path': file_path, 'file_path': file_path,
@ -135,7 +137,9 @@ class ExportPlugin(BeetsPlugin):
continue continue
data = key_filter(data) data = key_filter(data)
items += [data] items += [data]
export_format.export(items, **format_options)
self.run_results = items
self.export_format.export(self.run_results, **format_options)
class ExportFormat(object): class ExportFormat(object):
@ -144,6 +148,8 @@ class ExportFormat(object):
self.path = file_path self.path = file_path
self.mode = file_mode self.mode = file_mode
self.encoding = encoding self.encoding = encoding
# Used for testing
self.results = None
@classmethod @classmethod
def factory(cls, file_type, **kwargs): def factory(cls, file_type, **kwargs):
@ -167,11 +173,13 @@ class JsonFormat(ExportFormat):
self.export = self.export_to_file if self.path else self.export_to_terminal self.export = self.export_to_file if self.path else self.export_to_terminal
def export_to_terminal(self, data, **kwargs): def export_to_terminal(self, data, **kwargs):
json.dump(data, sys.stdout, cls=ExportEncoder, **kwargs) r = json.dump(data, sys.stdout, cls=ExportEncoder, **kwargs)
self.results = str(r)
def export_to_file(self, data, **kwargs): def export_to_file(self, data, **kwargs):
with codecs.open(self.path, self.mode, self.encoding) as f: with codecs.open(self.path, self.mode, self.encoding) as f:
json.dump(data, f, cls=ExportEncoder, **kwargs) r = json.dump(data, f, cls=ExportEncoder, **kwargs)
self.results = str(r)
class CSVFormat(ExportFormat): class CSVFormat(ExportFormat):
@ -192,12 +200,15 @@ class CSVFormat(ExportFormat):
writer = csv.DictWriter(sys.stdout, fieldnames=self.header) writer = csv.DictWriter(sys.stdout, fieldnames=self.header)
writer.writeheader() writer.writeheader()
writer.writerows(data) writer.writerows(data)
self.results = str(writer)
def export_to_file(self, data, **kwargs): def export_to_file(self, data, **kwargs):
with codecs.open(self.path, self.mode, self.encoding) as f: with codecs.open(self.path, self.mode, self.encoding) as f:
writer = csv.DictWriter(f, fieldnames=self.header) writer = csv.DictWriter(f, fieldnames=self.header)
writer.writeheader() writer.writeheader()
writer.writerows(data) writer.writerows(data)
self.results = str(writer)
class XMLFormat(ExportFormat): class XMLFormat(ExportFormat):
@ -233,7 +244,9 @@ class XMLFormat(ExportFormat):
def export_to_terminal(self, data, **kwargs): def export_to_terminal(self, data, **kwargs):
print(data) print(data)
self.results = str(data)
def export_to_file(self, data, **kwargs): def export_to_file(self, data, **kwargs):
with codecs.open(self.path, self.mode, self.encoding) as f: with codecs.open(self.path, self.mode, self.encoding) as f:
f.write(data) f.write(data)
self.results = str(data)