Core: clean up Utils.py

* fix import order
* lazy import shutil
* lazy import jellyfish (also speed-up by 0.8%, probably because of inlining)
* yaml:
  * explicitely call Loader UnsafeLoader
  * use CDumper, twice as fast
  * stop leaking leak imported names load and load_all
* open_file: use absolute path
* replace quotes in touched code
* add some typing in touched code
* stringify type hinting for non-imports
* %s/.format -> f
* freeze safe_builtins
* remove double-caching in get_options()
* get rid of some warnings
This commit is contained in:
black-sliver 2022-08-12 00:32:37 +02:00
parent b8ca41b45f
commit b702ae482b
1 changed files with 71 additions and 70 deletions

141
Utils.py
View File

@ -1,6 +1,5 @@
from __future__ import annotations from __future__ import annotations
import shutil
import typing import typing
import builtins import builtins
import os import os
@ -13,11 +12,18 @@ import collections
import importlib import importlib
import logging import logging
import decimal import decimal
from yaml import load, load_all, dump, SafeLoader
try:
from yaml import CLoader as UnsafeLoader
from yaml import CDumper as Dumper
except ImportError:
from yaml import Loader as UnsafeLoader
from yaml import Dumper
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from tkinter import Tk import tkinter
else: import pathlib
Tk = typing.Any
def tuplize_version(version: str) -> Version: def tuplize_version(version: str) -> Version:
@ -33,18 +39,10 @@ class Version(typing.NamedTuple):
__version__ = "0.3.4" __version__ = "0.3.4"
version_tuple = tuplize_version(__version__) version_tuple = tuplize_version(__version__)
is_linux = sys.platform.startswith('linux') is_linux = sys.platform.startswith("linux")
is_macos = sys.platform == 'darwin' is_macos = sys.platform == "darwin"
is_windows = sys.platform in ("win32", "cygwin", "msys") is_windows = sys.platform in ("win32", "cygwin", "msys")
import jellyfish
from yaml import load, load_all, dump, SafeLoader
try:
from yaml import CLoader as Loader
except ImportError:
from yaml import Loader
def int16_as_bytes(value: int) -> typing.List[int]: def int16_as_bytes(value: int) -> typing.List[int]:
value = value & 0xFFFF value = value & 0xFFFF
@ -125,17 +123,18 @@ def home_path(*path: str) -> str:
def user_path(*path: str) -> str: def user_path(*path: str) -> str:
"""Returns either local_path or home_path based on write permissions.""" """Returns either local_path or home_path based on write permissions."""
if hasattr(user_path, 'cached_path'): if hasattr(user_path, "cached_path"):
pass pass
elif os.access(local_path(), os.W_OK): elif os.access(local_path(), os.W_OK):
user_path.cached_path = local_path() user_path.cached_path = local_path()
else: else:
user_path.cached_path = home_path() user_path.cached_path = home_path()
# populate home from local - TODO: upgrade feature # populate home from local - TODO: upgrade feature
if user_path.cached_path != local_path() and not os.path.exists(user_path('host.yaml')): if user_path.cached_path != local_path() and not os.path.exists(user_path("host.yaml")):
for dn in ('Players', 'data/sprites'): import shutil
for dn in ("Players", "data/sprites"):
shutil.copytree(local_path(dn), user_path(dn), dirs_exist_ok=True) shutil.copytree(local_path(dn), user_path(dn), dirs_exist_ok=True)
for fn in ('manifest.json', 'host.yaml'): for fn in ("manifest.json", "host.yaml"):
shutil.copy2(local_path(fn), user_path(fn)) shutil.copy2(local_path(fn), user_path(fn))
return os.path.join(user_path.cached_path, *path) return os.path.join(user_path.cached_path, *path)
@ -150,11 +149,12 @@ def output_path(*path: str):
return path return path
def open_file(filename): def open_file(filename: typing.Union[str, "pathlib.Path"]) -> None:
if sys.platform == 'win32': if is_windows:
os.startfile(filename) os.startfile(filename)
else: else:
open_command = 'open' if sys.platform == 'darwin' else 'xdg-open' from shutil import which
open_command = which("open") if is_macos else (which("xdg-open") or which("gnome-open") or which("kde-open"))
subprocess.call([open_command, filename]) subprocess.call([open_command, filename])
@ -173,7 +173,9 @@ class UniqueKeyLoader(SafeLoader):
parse_yaml = functools.partial(load, Loader=UniqueKeyLoader) parse_yaml = functools.partial(load, Loader=UniqueKeyLoader)
parse_yamls = functools.partial(load_all, Loader=UniqueKeyLoader) parse_yamls = functools.partial(load_all, Loader=UniqueKeyLoader)
unsafe_parse_yaml = functools.partial(load, Loader=Loader) unsafe_parse_yaml = functools.partial(load, Loader=UnsafeLoader)
del load, load_all # should not be used. don't leak their names
def get_cert_none_ssl_context(): def get_cert_none_ssl_context():
@ -191,11 +193,12 @@ def get_public_ipv4() -> str:
ip = socket.gethostbyname(socket.gethostname()) ip = socket.gethostbyname(socket.gethostname())
ctx = get_cert_none_ssl_context() ctx = get_cert_none_ssl_context()
try: try:
ip = urllib.request.urlopen('https://checkip.amazonaws.com/', context=ctx).read().decode('utf8').strip() ip = urllib.request.urlopen("https://checkip.amazonaws.com/", context=ctx).read().decode("utf8").strip()
except Exception as e: except Exception as e:
# noinspection PyBroadException
try: try:
ip = urllib.request.urlopen('https://v4.ident.me', context=ctx).read().decode('utf8').strip() ip = urllib.request.urlopen("https://v4.ident.me", context=ctx).read().decode("utf8").strip()
except: except Exception:
logging.exception(e) logging.exception(e)
pass # we could be offline, in a local game, so no point in erroring out pass # we could be offline, in a local game, so no point in erroring out
return ip return ip
@ -208,7 +211,7 @@ def get_public_ipv6() -> str:
ip = socket.gethostbyname(socket.gethostname()) ip = socket.gethostbyname(socket.gethostname())
ctx = get_cert_none_ssl_context() ctx = get_cert_none_ssl_context()
try: try:
ip = urllib.request.urlopen('https://v6.ident.me', context=ctx).read().decode('utf8').strip() ip = urllib.request.urlopen("https://v6.ident.me", context=ctx).read().decode("utf8").strip()
except Exception as e: except Exception as e:
logging.exception(e) logging.exception(e)
pass # we could be offline, in a local game, or ipv6 may not be available pass # we could be offline, in a local game, or ipv6 may not be available
@ -309,23 +312,19 @@ def update_options(src: dict, dest: dict, filename: str, keys: list) -> dict:
@cache_argsless @cache_argsless
def get_options() -> dict: def get_options() -> dict:
if not hasattr(get_options, "options"): filenames = ("options.yaml", "host.yaml")
filenames = ("options.yaml", "host.yaml") locations = []
locations = [] if os.path.join(os.getcwd()) != local_path():
if os.path.join(os.getcwd()) != local_path(): locations += filenames # use files from cwd only if it's not the local_path
locations += filenames # use files from cwd only if it's not the local_path locations += [user_path(filename) for filename in filenames]
locations += [user_path(filename) for filename in filenames]
for location in locations: for location in locations:
if os.path.exists(location): if os.path.exists(location):
with open(location) as f: with open(location) as f:
options = parse_yaml(f.read()) options = parse_yaml(f.read())
return update_options(get_default_options(), options, location, list())
get_options.options = update_options(get_default_options(), options, location, list()) raise FileNotFoundError(f"Could not find {filenames[1]} to load options.")
break
else:
raise FileNotFoundError(f"Could not find {filenames[1]} to load options.")
return get_options.options
def persistent_store(category: str, key: typing.Any, value: typing.Any): def persistent_store(category: str, key: typing.Any, value: typing.Any):
@ -334,10 +333,10 @@ def persistent_store(category: str, key: typing.Any, value: typing.Any):
category = storage.setdefault(category, {}) category = storage.setdefault(category, {})
category[key] = value category[key] = value
with open(path, "wt") as f: with open(path, "wt") as f:
f.write(dump(storage)) f.write(dump(storage, Dumper=Dumper))
def persistent_load() -> typing.Dict[dict]: def persistent_load() -> typing.Dict[str, dict]:
storage = getattr(persistent_load, "storage", None) storage = getattr(persistent_load, "storage", None)
if storage: if storage:
return storage return storage
@ -355,8 +354,8 @@ def persistent_load() -> typing.Dict[dict]:
return storage return storage
def get_adjuster_settings(gameName: str): def get_adjuster_settings(game_name: str):
adjuster_settings = persistent_load().get("adjuster", {}).get(gameName, {}) adjuster_settings = persistent_load().get("adjuster", {}).get(game_name, {})
return adjuster_settings return adjuster_settings
@ -372,10 +371,10 @@ def get_unique_identifier():
return uuid return uuid
safe_builtins = { safe_builtins = frozenset((
'set', 'set',
'frozenset', 'frozenset',
} ))
class RestrictedUnpickler(pickle.Unpickler): class RestrictedUnpickler(pickle.Unpickler):
@ -403,8 +402,7 @@ class RestrictedUnpickler(pickle.Unpickler):
if issubclass(obj, self.options_module.Option): if issubclass(obj, self.options_module.Option):
return obj return obj
# Forbid everything else. # Forbid everything else.
raise pickle.UnpicklingError("global '%s.%s' is forbidden" % raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")
(module, name))
def restricted_loads(s): def restricted_loads(s):
@ -483,11 +481,11 @@ def stream_input(stream, queue):
return thread return thread
def tkinter_center_window(window: Tk): def tkinter_center_window(window: "tkinter.Tk") -> None:
window.update() window.update()
xPos = int(window.winfo_screenwidth() / 2 - window.winfo_reqwidth() / 2) x = int(window.winfo_screenwidth() / 2 - window.winfo_reqwidth() / 2)
yPos = int(window.winfo_screenheight() / 2 - window.winfo_reqheight() / 2) y = int(window.winfo_screenheight() / 2 - window.winfo_reqheight() / 2)
window.geometry("+{}+{}".format(xPos, yPos)) window.geometry(f"+{x}+{y}")
class VersionException(Exception): class VersionException(Exception):
@ -516,13 +514,14 @@ def format_SI_prefix(value, power=1000, power_labels=("", "k", "M", "G", "T", "P
return f"{value.quantize(decimal.Decimal('1.00'))} {chaining_prefix(n, power_labels)}" return f"{value.quantize(decimal.Decimal('1.00'))} {chaining_prefix(n, power_labels)}"
def get_fuzzy_ratio(word1: str, word2: str) -> float:
return (1 - jellyfish.damerau_levenshtein_distance(word1.lower(), word2.lower())
/ max(len(word1), len(word2)))
def get_fuzzy_results(input_word: str, wordlist: typing.Sequence[str], limit: typing.Optional[int] = None) \ def get_fuzzy_results(input_word: str, wordlist: typing.Sequence[str], limit: typing.Optional[int] = None) \
-> typing.List[typing.Tuple[str, int]]: -> typing.List[typing.Tuple[str, int]]:
import jellyfish
def get_fuzzy_ratio(word1: str, word2: str) -> float:
return (1 - jellyfish.damerau_levenshtein_distance(word1.lower(), word2.lower())
/ max(len(word1), len(word2)))
limit: int = limit if limit else len(wordlist) limit: int = limit if limit else len(wordlist)
return list( return list(
map( map(
@ -540,18 +539,19 @@ def get_fuzzy_results(input_word: str, wordlist: typing.Sequence[str], limit: ty
def open_filename(title: str, filetypes: typing.Sequence[typing.Tuple[str, typing.Sequence[str]]]) \ def open_filename(title: str, filetypes: typing.Sequence[typing.Tuple[str, typing.Sequence[str]]]) \
-> typing.Optional[str]: -> typing.Optional[str]:
def run(*args: str): def run(*args: str):
return subprocess.run(args, capture_output=True, text=True).stdout.split('\n', 1)[0] or None return subprocess.run(args, capture_output=True, text=True).stdout.split("\n", 1)[0] or None
if is_linux: if is_linux:
# prefer native dialog # prefer native dialog
kdialog = shutil.which('kdialog') from shutil import which
kdialog = which("kdialog")
if kdialog: if kdialog:
k_filters = '|'.join((f'{text} (*{" *".join(ext)})' for (text, ext) in filetypes)) k_filters = '|'.join((f'{text} (*{" *".join(ext)})' for (text, ext) in filetypes))
return run(kdialog, f'--title={title}', '--getopenfilename', '.', k_filters) return run(kdialog, f"--title={title}", "--getopenfilename", ".", k_filters)
zenity = shutil.which('zenity') zenity = which("zenity")
if zenity: if zenity:
z_filters = (f'--file-filter={text} ({", ".join(ext)}) | *{" *".join(ext)}' for (text, ext) in filetypes) z_filters = (f'--file-filter={text} ({", ".join(ext)}) | *{" *".join(ext)}' for (text, ext) in filetypes)
return run(zenity, f'--title={title}', '--file-selection', *z_filters) return run(zenity, f"--title={title}", "--file-selection", *z_filters)
# fall back to tk # fall back to tk
try: try:
@ -569,10 +569,10 @@ def open_filename(title: str, filetypes: typing.Sequence[typing.Tuple[str, typin
def messagebox(title: str, text: str, error: bool = False) -> None: def messagebox(title: str, text: str, error: bool = False) -> None:
def run(*args: str): def run(*args: str):
return subprocess.run(args, capture_output=True, text=True).stdout.split('\n', 1)[0] or None return subprocess.run(args, capture_output=True, text=True).stdout.split("\n", 1)[0] or None
def is_kivy_running(): def is_kivy_running():
if 'kivy' in sys.modules: if "kivy" in sys.modules:
from kivy.app import App from kivy.app import App
return App.get_running_app() is not None return App.get_running_app() is not None
return False return False
@ -582,14 +582,15 @@ def messagebox(title: str, text: str, error: bool = False) -> None:
MessageBox(title, text, error).open() MessageBox(title, text, error).open()
return return
if is_linux and not 'tkinter' in sys.modules: if is_linux and "tkinter" not in sys.modules:
# prefer native dialog # prefer native dialog
kdialog = shutil.which('kdialog') from shutil import which
kdialog = which("kdialog")
if kdialog: if kdialog:
return run(kdialog, f'--title={title}', '--error' if error else '--msgbox', text) return run(kdialog, f"--title={title}", "--error" if error else "--msgbox", text)
zenity = shutil.which('zenity') zenity = which("zenity")
if zenity: if zenity:
return run(zenity, f'--title={title}', f'--text={text}', '--error' if error else '--info') return run(zenity, f"--title={title}", f"--text={text}", "--error" if error else "--info")
# fall back to tk # fall back to tk
try: try: