From 1d965b30d1ef9d348acfb092fe7f01df02b8ed75 Mon Sep 17 00:00:00 2001 From: Adrian Sampson Date: Fri, 27 Mar 2015 22:06:46 -0400 Subject: [PATCH] Add redaction to Confit --- beets/util/confit.py | 183 +++++++++++++++++++++++++++++++++++++------ 1 file changed, 159 insertions(+), 24 deletions(-) diff --git a/beets/util/confit.py b/beets/util/confit.py index 110d00001..8ba51cd2f 100644 --- a/beets/util/confit.py +++ b/beets/util/confit.py @@ -42,6 +42,8 @@ ROOT_NAME = 'root' YAML_TAB_PROBLEM = "found character '\\t' that cannot start any token" +REDACTED_TOMBSTONE = 'REDACTED' + # Utilities. @@ -213,7 +215,8 @@ class ConfigView(object): return '<{}: {}>'.format(self.__class__.__name__, self.name) def __iter__(self): - """Prevent list(config) from using __getitem__ and never halting""" + # Prevent list(config) from using __getitem__ and entering an + # infinite loop. raise TypeError(u"{!r} object is not " u"iterable".format(self.__class__.__name__)) @@ -246,14 +249,17 @@ class ConfigView(object): # just say ``bool(view)`` or use ``view`` in a conditional. def __str__(self): - """Gets the value for this view as a byte string.""" - return bytes(self.get()) + """Get the value for this view as a bytestring. + """ + if PY3: + return self.__unicode__() + else: + return bytes(self.get()) def __unicode__(self): - """Gets the value for this view as a unicode string. (Python 2 - only.) + """Get the value for this view as a Unicode string. """ - return unicode(self.get()) + return STRING(self.get()) def __nonzero__(self): """Gets the value for this view as a boolean. (Python 2 only.) @@ -333,17 +339,23 @@ class ConfigView(object): # Validation and conversion. - def flatten(self): + def flatten(self, redact=False): """Create a hierarchy of OrderedDicts containing the data from this view, recursively reifying all views to get their represented values. + + If `redact` is set, then sensitive values are replaced with + the string "REDACTED". """ od = OrderedDict() for key, view in self.items(): - try: - od[key] = view.flatten() - except ConfigTypeError: - od[key] = view.get() + if redact and view.redact: + od[key] = REDACTED_TOMBSTONE + else: + try: + od[key] = view.flatten(redact=True) + except ConfigTypeError: + od[key] = view.get() return od def get(self, template=None): @@ -375,6 +387,30 @@ class ConfigView(object): def as_str_seq(self): return self.get(StrSeq()) + # Redaction. + + @property + def redact(self): + """Whether the view contains sensitive information and should be + redacted from output. + """ + return () in self.get_redactions() + + @redact.setter + def redact(self, flag): + self.set_redaction((), flag) + + def set_redaction(self, path, flag): + """Add or remove a redaction for a key path, which should be an + iterable of keys. + """ + raise NotImplementedError() + + def get_redactions(self): + """Get the set of currently-redacted sub-key-paths at this view. + """ + raise NotImplementedError() + class RootView(ConfigView): """The base of a view hierarchy. This view keeps track of the @@ -387,6 +423,7 @@ class RootView(ConfigView): """ self.sources = list(sources) self.name = ROOT_NAME + self.redactions = set() def add(self, obj): self.sources.append(ConfigSource.of(obj)) @@ -404,6 +441,15 @@ class RootView(ConfigView): def root(self): return self + def set_redaction(self, path, flag): + if flag: + self.redactions.add(path) + elif path in self.redactions: + self.redactions.remove(path) + + def get_redactions(self): + return self.redactions + class Subview(ConfigView): """A subview accessed via a subscript of a parent view.""" @@ -423,9 +469,12 @@ class Subview(ConfigView): if isinstance(self.key, int): self.name += '#{0}'.format(self.key) elif isinstance(self.key, BASESTRING): - self.name += '{0}'.format(self.key.decode('utf8')) + if isinstance(self.key, bytes): + self.name += self.key.decode('utf8') + else: + self.name += self.key else: - self.name += '{0}'.format(repr(self.key)) + self.name += repr(self.key) def resolve(self): for collection, source in self.parent.resolve(): @@ -455,6 +504,13 @@ class Subview(ConfigView): def root(self): return self.parent.root() + def set_redaction(self, path, flag): + self.parent.set_redaction((self.key,) + path, flag) + + def get_redactions(self): + return (kp[1:] for kp in self.parent.get_redactions() + if kp and kp[0] == self.key) + # Config file paths, including platform-specific paths and in-package # defaults. @@ -469,7 +525,7 @@ def _package_path(name): if loader is None or name == b'__main__': return None - if hasattr(loader, b'get_filename'): + if hasattr(loader, 'get_filename'): filepath = loader.get_filename(name) else: # Fall back to importing the specified module. @@ -489,13 +545,13 @@ def config_dirs(): """ paths = [] - if platform.system() == b'Darwin': + if platform.system() == 'Darwin': paths.append(MAC_DIR) paths.append(UNIX_DIR_FALLBACK) if UNIX_DIR_VAR in os.environ: paths.append(os.environ[UNIX_DIR_VAR]) - elif platform.system() == b'Windows': + elif platform.system() == 'Windows': paths.append(WINDOWS_DIR_FALLBACK) if WINDOWS_DIR_VAR in os.environ: paths.append(os.environ[WINDOWS_DIR_VAR]) @@ -578,7 +634,7 @@ def load_yaml(filename): parsed, a ConfigReadError is raised. """ try: - with open(filename, b'r') as f: + with open(filename, 'r') as f: return yaml.load(f, Loader=Loader) except (IOError, yaml.error.YAMLError) as exc: raise ConfigReadError(filename, exc) @@ -783,7 +839,7 @@ class Configuration(RootView): filename = os.path.abspath(filename) self.set(ConfigSource(load_yaml(filename), filename)) - def dump(self, full=True): + def dump(self, full=True, redact=False): """Dump the Configuration object to a YAML file. The order of the keys is determined from the default @@ -795,13 +851,15 @@ class Configuration(RootView): :type filename: unicode :param full: Dump settings that don't differ from the defaults as well + :param redact: Remove sensitive information (views with the `redact` + flag set) from the output """ if full: - out_dict = self.flatten() + out_dict = self.flatten(redact=redact) else: # Exclude defaults when flattening. sources = [s for s in self.sources if not s.default] - out_dict = RootView(sources).flatten() + out_dict = RootView(sources).flatten(redact=redact) yaml_out = yaml.dump(out_dict, Dumper=Dumper, default_flow_style=None, indent=4, @@ -1012,6 +1070,17 @@ class String(Template): if pattern: self.regex = re.compile(pattern) + def __repr__(self): + args = [] + + if self.default is not REQUIRED: + args.append(repr(self.default)) + + if self.pattern is not None: + args.append('pattern=' + repr(self.pattern)) + + return 'String({0})'.format(', '.join(args)) + def convert(self, value, view): """Check that the value is a string and matches the pattern. """ @@ -1059,6 +1128,67 @@ class Choice(Template): return 'Choice({0!r})'.format(self.choices) +class OneOf(Template): + """A template that permits values complying to one of the given templates. + """ + def __init__(self, allowed, default=REQUIRED): + super(OneOf, self).__init__(default) + self.allowed = list(allowed) + + def __repr__(self): + args = [] + + if self.allowed is not None: + args.append('allowed=' + repr(self.allowed)) + + if self.default is not REQUIRED: + args.append(repr(self.default)) + + return 'OneOf({0})'.format(', '.join(args)) + + def value(self, view, template): + self.template = template + return super(OneOf, self).value(view, template) + + def convert(self, value, view): + """Ensure that the value follows at least one template. + """ + is_mapping = isinstance(self.template, MappingTemplate) + + for candidate in self.allowed: + try: + if is_mapping: + if isinstance(candidate, Filename) and \ + candidate.relative_to: + next_template = candidate.template_with_relatives( + view, + self.template + ) + + next_template.subtemplates[view.key] = as_template( + candidate + ) + else: + next_template = MappingTemplate({view.key: candidate}) + + return view.parent.get(next_template)[view.key] + else: + return view.get(candidate) + except ConfigTemplateError: + raise + except ConfigError: + pass + except ValueError as exc: + raise ConfigTemplateError(exc) + + self.fail( + 'must be one of {0}, not {1}'.format( + repr(self.allowed), repr(value) + ), + view + ) + + class StrSeq(Template): """A template for values that are lists of strings. @@ -1092,13 +1222,13 @@ class StrSeq(Template): view, True) def convert(x): - if isinstance(x, unicode): + if isinstance(x, STRING): return x - elif isinstance(x, BASESTRING): + elif isinstance(x, bytes): return x.decode('utf8', 'ignore') else: self.fail('must be a list of strings', view, True) - return map(convert, value) + return list(map(convert, value)) class Filename(Template): @@ -1114,7 +1244,7 @@ class Filename(Template): """ def __init__(self, default=REQUIRED, cwd=None, relative_to=None, in_app_dir=False): - """ `relative_to` is the name of a sibling value that is + """`relative_to` is the name of a sibling value that is being validated at the same time. `in_app_dir` indicates whether the path should be resolved @@ -1274,6 +1404,11 @@ def as_template(value): return String() elif isinstance(value, BASESTRING): return String(value) + elif isinstance(value, set): + # convert to list to avoid hash related problems + return Choice(list(value)) + elif isinstance(value, list): + return OneOf(value) elif value is float: return Number() elif value is None: