Settings: implement saving of dict and sequence, add graceful crashing (#1981)
* settings: don't crash when loading an outdated host.yaml * settings: use temp file to not destroy host.yaml on error * settings: implement saving of dicts * settings: simplify dump of dict * settings: add support for sequences also a few more comments * settings: reformat a bit
This commit is contained in:
parent
60586aa284
commit
a77739ba18
107
settings.py
107
settings.py
|
@ -55,6 +55,7 @@ class Group:
|
|||
_dumping: bool = False
|
||||
_has_attr: bool = False
|
||||
_changed: bool = False
|
||||
_dumper: ClassVar[type]
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
try:
|
||||
|
@ -142,13 +143,21 @@ class Group:
|
|||
if k not in self.__dict__:
|
||||
attr = attr.__class__() # make a copy of default
|
||||
setattr(self, k, attr)
|
||||
attr.update(v)
|
||||
if isinstance(v, dict):
|
||||
attr.update(v)
|
||||
else:
|
||||
warnings.warn(f"{self.__class__.__name__}.{k} "
|
||||
f"tried to update Group from {type(v)}")
|
||||
elif isinstance(attr, dict):
|
||||
# update dict
|
||||
if k not in self.__dict__:
|
||||
attr = attr.copy() # make a copy of default
|
||||
setattr(self, k, attr)
|
||||
attr.update(v)
|
||||
if isinstance(v, dict):
|
||||
attr.update(v)
|
||||
else:
|
||||
warnings.warn(f"{self.__class__.__name__}.{k} "
|
||||
f"tried to update dict from {type(v)}")
|
||||
else:
|
||||
# assign value, try to upcast to type hint
|
||||
annotation = self.get_type_hints().get(k, None)
|
||||
|
@ -169,6 +178,10 @@ class Group:
|
|||
# upcast, i.e. int -> IntEnum, str -> Path
|
||||
setattr(self, k, cls.__call__(v))
|
||||
break
|
||||
if issubclass(cls, (tuple, set)) and isinstance(v, list):
|
||||
# convert or upcast from list
|
||||
setattr(self, k, cls.__call__(v))
|
||||
break
|
||||
else:
|
||||
# assign scalar and hope for the best
|
||||
setattr(self, k, v)
|
||||
|
@ -182,23 +195,68 @@ class Group:
|
|||
for name in self if not args or name in args
|
||||
}
|
||||
|
||||
def dump(self, f: TextIO, level: int = 0) -> None:
|
||||
@classmethod
|
||||
def _dump_value(cls, value: Any, f: TextIO, indent: str) -> None:
|
||||
"""Write a single yaml line to f"""
|
||||
from Utils import dump, Dumper as BaseDumper
|
||||
yaml_line: str = dump(value, Dumper=cast(BaseDumper, cls._dumper))
|
||||
assert yaml_line.count("\n") == 1, f"Unexpected input for yaml dumper: {value}"
|
||||
f.write(f"{indent}{yaml_line}")
|
||||
|
||||
@classmethod
|
||||
def _dump_item(cls, name: Optional[str], attr: object, f: TextIO, level: int) -> None:
|
||||
"""Write a group, dict or sequence item to f, where attr can be a scalar or a collection"""
|
||||
|
||||
# lazy construction of yaml Dumper to avoid loading Utils early
|
||||
from Utils import Dumper as BaseDumper
|
||||
from yaml import ScalarNode, MappingNode
|
||||
if not hasattr(cls, "_dumper"):
|
||||
if cls is Group or not hasattr(Group, "_dumper"):
|
||||
class Dumper(BaseDumper):
|
||||
def represent_mapping(self, tag: str, mapping: Any, flow_style: Any = None) -> MappingNode:
|
||||
from yaml import ScalarNode
|
||||
res: MappingNode = super().represent_mapping(tag, mapping, flow_style)
|
||||
pairs = cast(List[Tuple[ScalarNode, Any]], res.value)
|
||||
for k, v in pairs:
|
||||
k.style = None # remove quotes from keys
|
||||
return res
|
||||
|
||||
class Dumper(BaseDumper):
|
||||
def represent_mapping(self, tag: str, mapping: Any, flow_style: Any = None) -> MappingNode:
|
||||
res: MappingNode = super().represent_mapping(tag, mapping, flow_style)
|
||||
pairs = cast(List[Tuple[ScalarNode, Any]], res.value)
|
||||
for k, v in pairs:
|
||||
k.style = None # remove quotes from keys
|
||||
return res
|
||||
def represent_str(self, data: str) -> ScalarNode:
|
||||
# default double quote all strings
|
||||
return self.represent_scalar("tag:yaml.org,2002:str", data, style='"')
|
||||
|
||||
def represent_str(self, data: str) -> ScalarNode:
|
||||
# default double quote all strings
|
||||
return self.represent_scalar("tag:yaml.org,2002:str", data, style='"')
|
||||
Dumper.add_representer(str, Dumper.represent_str)
|
||||
Group._dumper = Dumper
|
||||
if cls is not Group:
|
||||
cls._dumper = Group._dumper
|
||||
|
||||
Dumper.add_representer(str, Dumper.represent_str)
|
||||
indent = " " * level
|
||||
start = f"{indent}-\n" if name is None else f"{indent}{name}:\n"
|
||||
if isinstance(attr, Group):
|
||||
# handle group
|
||||
f.write(start)
|
||||
attr.dump(f, level=level+1)
|
||||
elif isinstance(attr, (list, tuple, set)) and attr:
|
||||
# handle non-empty sequence; empty use one-line [] syntax
|
||||
f.write(start)
|
||||
for value in attr:
|
||||
cls._dump_item(None, value, f, level=level + 1)
|
||||
elif isinstance(attr, dict) and attr:
|
||||
# handle non-empty dict; empty use one-line {} syntax
|
||||
f.write(start)
|
||||
for dict_key, value in attr.items():
|
||||
# not dumping doc string here, since there is no way to upcast it after dumping
|
||||
assert dict_key is not None, "Key None is reserved for sequences"
|
||||
cls._dump_item(dict_key, value, f, level=level + 1)
|
||||
else:
|
||||
# dump scalar or empty sequence or mapping item
|
||||
line = [_to_builtin(attr)] if name is None else {name: _to_builtin(attr)}
|
||||
cls._dump_value(line, f, indent=indent)
|
||||
|
||||
def dump(self, f: TextIO, level: int = 0) -> None:
|
||||
"""Dump Group to stream f at given indentation level"""
|
||||
# There is no easy way to generate extra lines into default yaml output,
|
||||
# so we format part of it by hand using an odd recursion here and in _dump_*.
|
||||
|
||||
self._dumping = True
|
||||
try:
|
||||
|
@ -218,16 +276,7 @@ class Group:
|
|||
attr_cls_origin = typing.get_origin(attr_cls)
|
||||
if attr_cls.__doc__ and attr_cls.__module__ != "builtins":
|
||||
f.write(fmt_doc(attr_cls, level=level) + "\n")
|
||||
indent = ' ' * level
|
||||
if isinstance(attr, Group):
|
||||
f.write(f"{indent}{name}:\n")
|
||||
attr.dump(f, level=level+1)
|
||||
elif isinstance(attr, (dict, list, tuple, set)):
|
||||
# TODO: special handling for dicts and iterables
|
||||
raise NotImplementedError()
|
||||
else:
|
||||
yaml_line = dump({name: _to_builtin(attr)}, Dumper=Dumper)
|
||||
f.write(f"{indent}{yaml_line}")
|
||||
self._dump_item(name, attr, f, level=level)
|
||||
self._changed = False
|
||||
finally:
|
||||
self._dumping = False
|
||||
|
@ -727,9 +776,17 @@ class Settings(Group):
|
|||
def save(self, location: Optional[str] = None) -> None: # as above
|
||||
location = location or self._filename
|
||||
assert location, "No file specified"
|
||||
temp_location = location + ".tmp" # not using tempfile to test expected file access
|
||||
# remove old temps
|
||||
if os.path.exists(temp_location):
|
||||
os.unlink(temp_location)
|
||||
# can't use utf-8-sig because it breaks backward compat: pyyaml on Windows with bytes does not strip the BOM
|
||||
with open(location, "w", encoding="utf-8") as f:
|
||||
with open(temp_location, "w", encoding="utf-8") as f:
|
||||
self.dump(f)
|
||||
# replace old with new
|
||||
if os.path.exists(location):
|
||||
os.unlink(location)
|
||||
os.rename(temp_location, location)
|
||||
self._filename = location
|
||||
|
||||
def dump(self, f: TextIO, level: int = 0) -> None:
|
||||
|
|
Loading…
Reference in New Issue