settings: safer writing (#3644)

* settings: clean up imports

* settings: try to use atomic rename

* settings: flush, sync and validate new yaml

before replacing the old one

* settings: add test for Settings.save
This commit is contained in:
black-sliver 2024-07-25 09:10:36 +02:00 committed by GitHub
parent deae524e9b
commit 8949e21565
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 40 additions and 7 deletions

View File

@ -3,6 +3,7 @@ Application settings / host.yaml interface using type hints.
This is different from player options. This is different from player options.
""" """
import os
import os.path import os.path
import shutil import shutil
import sys import sys
@ -11,7 +12,6 @@ import warnings
from enum import IntEnum from enum import IntEnum
from threading import Lock from threading import Lock
from typing import cast, Any, BinaryIO, ClassVar, Dict, Iterator, List, Optional, TextIO, Tuple, Union, TypeVar from typing import cast, Any, BinaryIO, ClassVar, Dict, Iterator, List, Optional, TextIO, Tuple, Union, TypeVar
import os
__all__ = [ __all__ = [
"get_settings", "fmt_doc", "no_gui", "get_settings", "fmt_doc", "no_gui",
@ -798,6 +798,7 @@ class Settings(Group):
atexit.register(autosave) atexit.register(autosave)
def save(self, location: Optional[str] = None) -> None: # as above def save(self, location: Optional[str] = None) -> None: # as above
from Utils import parse_yaml
location = location or self._filename location = location or self._filename
assert location, "No file specified" assert location, "No file specified"
temp_location = location + ".tmp" # not using tempfile to test expected file access temp_location = location + ".tmp" # not using tempfile to test expected file access
@ -807,10 +808,18 @@ class Settings(Group):
# can't use utf-8-sig because it breaks backward compat: pyyaml on Windows with bytes does not strip the BOM # can't use utf-8-sig because it breaks backward compat: pyyaml on Windows with bytes does not strip the BOM
with open(temp_location, "w", encoding="utf-8") as f: with open(temp_location, "w", encoding="utf-8") as f:
self.dump(f) self.dump(f)
# replace old with new f.flush()
if os.path.exists(location): if hasattr(os, "fsync"):
os.fsync(f.fileno())
# validate new file is valid yaml
with open(temp_location, encoding="utf-8") as f:
parse_yaml(f.read())
# replace old with new, try atomic operation first
try:
os.rename(temp_location, location)
except (OSError, FileExistsError):
os.unlink(location) os.unlink(location)
os.rename(temp_location, location) os.rename(temp_location, location)
self._filename = location self._filename = location
def dump(self, f: TextIO, level: int = 0) -> None: def dump(self, f: TextIO, level: int = 0) -> None:
@ -832,7 +841,6 @@ def get_settings() -> Settings:
with _lock: # make sure we only have one instance with _lock: # make sure we only have one instance
res = getattr(get_settings, "_cache", None) res = getattr(get_settings, "_cache", None)
if not res: if not res:
import os
from Utils import user_path, local_path from Utils import user_path, local_path
filenames = ("options.yaml", "host.yaml") filenames = ("options.yaml", "host.yaml")
locations: List[str] = [] locations: List[str] = []

View File

@ -1,11 +1,12 @@
import os import os
import os.path
import unittest import unittest
from io import StringIO from io import StringIO
from tempfile import TemporaryFile from tempfile import TemporaryDirectory, TemporaryFile
from typing import Any, Dict, List, cast from typing import Any, Dict, List, cast
import Utils import Utils
from settings import Settings, Group from settings import Group, Settings, ServerOptions
class TestIDs(unittest.TestCase): class TestIDs(unittest.TestCase):
@ -80,3 +81,27 @@ class TestSettingsDumper(unittest.TestCase):
self.assertEqual(value_spaces[2], value_spaces[0]) # start of sub-list self.assertEqual(value_spaces[2], value_spaces[0]) # start of sub-list
self.assertGreater(value_spaces[3], value_spaces[0], self.assertGreater(value_spaces[3], value_spaces[0],
f"{value_lines[3]} should have more indentation than {value_lines[0]} in {lines}") f"{value_lines[3]} should have more indentation than {value_lines[0]} in {lines}")
class TestSettingsSave(unittest.TestCase):
def test_save(self) -> None:
"""Test that saving and updating works"""
with TemporaryDirectory() as d:
filename = os.path.join(d, "host.yaml")
new_release_mode = ServerOptions.ReleaseMode("enabled")
# create default host.yaml
settings = Settings(None)
settings.save(filename)
self.assertTrue(os.path.exists(filename),
"Default settings could not be saved")
self.assertNotEqual(settings.server_options.release_mode, new_release_mode,
"Unexpected default release mode")
# update host.yaml
settings.server_options.release_mode = new_release_mode
settings.save(filename)
self.assertFalse(os.path.exists(filename + ".tmp"),
"Temp file was not removed during save")
# read back host.yaml
settings = Settings(filename)
self.assertEqual(settings.server_options.release_mode, new_release_mode,
"Settings were not overwritten")