Core: clean up Utils.py
* fix import order * lazy import shutil * lazy import jellyfish (also speed-up by 0.8%, probably because of inlining) * yaml: * explicitely call Loader UnsafeLoader * use CDumper, twice as fast * stop leaking leak imported names load and load_all * open_file: use absolute path * replace quotes in touched code * add some typing in touched code * stringify type hinting for non-imports * %s/.format -> f * freeze safe_builtins * remove double-caching in get_options() * get rid of some warnings
This commit is contained in:
		
							parent
							
								
									b8ca41b45f
								
							
						
					
					
						commit
						b702ae482b
					
				
							
								
								
									
										141
									
								
								Utils.py
								
								
								
								
							
							
						
						
									
										141
									
								
								Utils.py
								
								
								
								
							| 
						 | 
				
			
			@ -1,6 +1,5 @@
 | 
			
		|||
from __future__ import annotations
 | 
			
		||||
 | 
			
		||||
import shutil
 | 
			
		||||
import typing
 | 
			
		||||
import builtins
 | 
			
		||||
import os
 | 
			
		||||
| 
						 | 
				
			
			@ -13,11 +12,18 @@ import collections
 | 
			
		|||
import importlib
 | 
			
		||||
import logging
 | 
			
		||||
import decimal
 | 
			
		||||
from yaml import load, load_all, dump, SafeLoader
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
    from yaml import CLoader as UnsafeLoader
 | 
			
		||||
    from yaml import CDumper as Dumper
 | 
			
		||||
except ImportError:
 | 
			
		||||
    from yaml import Loader as UnsafeLoader
 | 
			
		||||
    from yaml import Dumper
 | 
			
		||||
 | 
			
		||||
if typing.TYPE_CHECKING:
 | 
			
		||||
    from tkinter import Tk
 | 
			
		||||
else:
 | 
			
		||||
    Tk = typing.Any
 | 
			
		||||
    import tkinter
 | 
			
		||||
    import pathlib
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def tuplize_version(version: str) -> Version:
 | 
			
		||||
| 
						 | 
				
			
			@ -33,18 +39,10 @@ class Version(typing.NamedTuple):
 | 
			
		|||
__version__ = "0.3.4"
 | 
			
		||||
version_tuple = tuplize_version(__version__)
 | 
			
		||||
 | 
			
		||||
is_linux = sys.platform.startswith('linux')
 | 
			
		||||
is_macos = sys.platform == 'darwin'
 | 
			
		||||
is_linux = sys.platform.startswith("linux")
 | 
			
		||||
is_macos = sys.platform == "darwin"
 | 
			
		||||
is_windows = sys.platform in ("win32", "cygwin", "msys")
 | 
			
		||||
 | 
			
		||||
import jellyfish
 | 
			
		||||
from yaml import load, load_all, dump, SafeLoader
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
    from yaml import CLoader as Loader
 | 
			
		||||
except ImportError:
 | 
			
		||||
    from yaml import Loader
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def int16_as_bytes(value: int) -> typing.List[int]:
 | 
			
		||||
    value = value & 0xFFFF
 | 
			
		||||
| 
						 | 
				
			
			@ -125,17 +123,18 @@ def home_path(*path: str) -> str:
 | 
			
		|||
 | 
			
		||||
def user_path(*path: str) -> str:
 | 
			
		||||
    """Returns either local_path or home_path based on write permissions."""
 | 
			
		||||
    if hasattr(user_path, 'cached_path'):
 | 
			
		||||
    if hasattr(user_path, "cached_path"):
 | 
			
		||||
        pass
 | 
			
		||||
    elif os.access(local_path(), os.W_OK):
 | 
			
		||||
        user_path.cached_path = local_path()
 | 
			
		||||
    else:
 | 
			
		||||
        user_path.cached_path = home_path()
 | 
			
		||||
        # populate home from local - TODO: upgrade feature
 | 
			
		||||
        if user_path.cached_path != local_path() and not os.path.exists(user_path('host.yaml')):
 | 
			
		||||
            for dn in ('Players', 'data/sprites'):
 | 
			
		||||
        if user_path.cached_path != local_path() and not os.path.exists(user_path("host.yaml")):
 | 
			
		||||
            import shutil
 | 
			
		||||
            for dn in ("Players", "data/sprites"):
 | 
			
		||||
                shutil.copytree(local_path(dn), user_path(dn), dirs_exist_ok=True)
 | 
			
		||||
            for fn in ('manifest.json', 'host.yaml'):
 | 
			
		||||
            for fn in ("manifest.json", "host.yaml"):
 | 
			
		||||
                shutil.copy2(local_path(fn), user_path(fn))
 | 
			
		||||
 | 
			
		||||
    return os.path.join(user_path.cached_path, *path)
 | 
			
		||||
| 
						 | 
				
			
			@ -150,11 +149,12 @@ def output_path(*path: str):
 | 
			
		|||
    return path
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def open_file(filename):
 | 
			
		||||
    if sys.platform == 'win32':
 | 
			
		||||
def open_file(filename: typing.Union[str, "pathlib.Path"]) -> None:
 | 
			
		||||
    if is_windows:
 | 
			
		||||
        os.startfile(filename)
 | 
			
		||||
    else:
 | 
			
		||||
        open_command = 'open' if sys.platform == 'darwin' else 'xdg-open'
 | 
			
		||||
        from shutil import which
 | 
			
		||||
        open_command = which("open") if is_macos else (which("xdg-open") or which("gnome-open") or which("kde-open"))
 | 
			
		||||
        subprocess.call([open_command, filename])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -173,7 +173,9 @@ class UniqueKeyLoader(SafeLoader):
 | 
			
		|||
 | 
			
		||||
parse_yaml = functools.partial(load, Loader=UniqueKeyLoader)
 | 
			
		||||
parse_yamls = functools.partial(load_all, Loader=UniqueKeyLoader)
 | 
			
		||||
unsafe_parse_yaml = functools.partial(load, Loader=Loader)
 | 
			
		||||
unsafe_parse_yaml = functools.partial(load, Loader=UnsafeLoader)
 | 
			
		||||
 | 
			
		||||
del load, load_all  # should not be used. don't leak their names
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_cert_none_ssl_context():
 | 
			
		||||
| 
						 | 
				
			
			@ -191,11 +193,12 @@ def get_public_ipv4() -> str:
 | 
			
		|||
    ip = socket.gethostbyname(socket.gethostname())
 | 
			
		||||
    ctx = get_cert_none_ssl_context()
 | 
			
		||||
    try:
 | 
			
		||||
        ip = urllib.request.urlopen('https://checkip.amazonaws.com/', context=ctx).read().decode('utf8').strip()
 | 
			
		||||
        ip = urllib.request.urlopen("https://checkip.amazonaws.com/", context=ctx).read().decode("utf8").strip()
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        # noinspection PyBroadException
 | 
			
		||||
        try:
 | 
			
		||||
            ip = urllib.request.urlopen('https://v4.ident.me', context=ctx).read().decode('utf8').strip()
 | 
			
		||||
        except:
 | 
			
		||||
            ip = urllib.request.urlopen("https://v4.ident.me", context=ctx).read().decode("utf8").strip()
 | 
			
		||||
        except Exception:
 | 
			
		||||
            logging.exception(e)
 | 
			
		||||
            pass  # we could be offline, in a local game, so no point in erroring out
 | 
			
		||||
    return ip
 | 
			
		||||
| 
						 | 
				
			
			@ -208,7 +211,7 @@ def get_public_ipv6() -> str:
 | 
			
		|||
    ip = socket.gethostbyname(socket.gethostname())
 | 
			
		||||
    ctx = get_cert_none_ssl_context()
 | 
			
		||||
    try:
 | 
			
		||||
        ip = urllib.request.urlopen('https://v6.ident.me', context=ctx).read().decode('utf8').strip()
 | 
			
		||||
        ip = urllib.request.urlopen("https://v6.ident.me", context=ctx).read().decode("utf8").strip()
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        logging.exception(e)
 | 
			
		||||
        pass  # we could be offline, in a local game, or ipv6 may not be available
 | 
			
		||||
| 
						 | 
				
			
			@ -309,23 +312,19 @@ def update_options(src: dict, dest: dict, filename: str, keys: list) -> dict:
 | 
			
		|||
 | 
			
		||||
@cache_argsless
 | 
			
		||||
def get_options() -> dict:
 | 
			
		||||
    if not hasattr(get_options, "options"):
 | 
			
		||||
        filenames = ("options.yaml", "host.yaml")
 | 
			
		||||
        locations = []
 | 
			
		||||
        if os.path.join(os.getcwd()) != local_path():
 | 
			
		||||
            locations += filenames  # use files from cwd only if it's not the local_path
 | 
			
		||||
        locations += [user_path(filename) for filename in filenames]
 | 
			
		||||
    filenames = ("options.yaml", "host.yaml")
 | 
			
		||||
    locations = []
 | 
			
		||||
    if os.path.join(os.getcwd()) != local_path():
 | 
			
		||||
        locations += filenames  # use files from cwd only if it's not the local_path
 | 
			
		||||
    locations += [user_path(filename) for filename in filenames]
 | 
			
		||||
 | 
			
		||||
        for location in locations:
 | 
			
		||||
            if os.path.exists(location):
 | 
			
		||||
                with open(location) as f:
 | 
			
		||||
                    options = parse_yaml(f.read())
 | 
			
		||||
    for location in locations:
 | 
			
		||||
        if os.path.exists(location):
 | 
			
		||||
            with open(location) as f:
 | 
			
		||||
                options = parse_yaml(f.read())
 | 
			
		||||
            return update_options(get_default_options(), options, location, list())
 | 
			
		||||
 | 
			
		||||
                get_options.options = update_options(get_default_options(), options, location, list())
 | 
			
		||||
                break
 | 
			
		||||
        else:
 | 
			
		||||
            raise FileNotFoundError(f"Could not find {filenames[1]} to load options.")
 | 
			
		||||
    return get_options.options
 | 
			
		||||
    raise FileNotFoundError(f"Could not find {filenames[1]} to load options.")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def persistent_store(category: str, key: typing.Any, value: typing.Any):
 | 
			
		||||
| 
						 | 
				
			
			@ -334,10 +333,10 @@ def persistent_store(category: str, key: typing.Any, value: typing.Any):
 | 
			
		|||
    category = storage.setdefault(category, {})
 | 
			
		||||
    category[key] = value
 | 
			
		||||
    with open(path, "wt") as f:
 | 
			
		||||
        f.write(dump(storage))
 | 
			
		||||
        f.write(dump(storage, Dumper=Dumper))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def persistent_load() -> typing.Dict[dict]:
 | 
			
		||||
def persistent_load() -> typing.Dict[str, dict]:
 | 
			
		||||
    storage = getattr(persistent_load, "storage", None)
 | 
			
		||||
    if storage:
 | 
			
		||||
        return storage
 | 
			
		||||
| 
						 | 
				
			
			@ -355,8 +354,8 @@ def persistent_load() -> typing.Dict[dict]:
 | 
			
		|||
    return storage
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_adjuster_settings(gameName: str):
 | 
			
		||||
    adjuster_settings = persistent_load().get("adjuster", {}).get(gameName, {})
 | 
			
		||||
def get_adjuster_settings(game_name: str):
 | 
			
		||||
    adjuster_settings = persistent_load().get("adjuster", {}).get(game_name, {})
 | 
			
		||||
    return adjuster_settings
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -372,10 +371,10 @@ def get_unique_identifier():
 | 
			
		|||
    return uuid
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
safe_builtins = {
 | 
			
		||||
safe_builtins = frozenset((
 | 
			
		||||
    'set',
 | 
			
		||||
    'frozenset',
 | 
			
		||||
}
 | 
			
		||||
))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RestrictedUnpickler(pickle.Unpickler):
 | 
			
		||||
| 
						 | 
				
			
			@ -403,8 +402,7 @@ class RestrictedUnpickler(pickle.Unpickler):
 | 
			
		|||
            if issubclass(obj, self.options_module.Option):
 | 
			
		||||
                return obj
 | 
			
		||||
        # Forbid everything else.
 | 
			
		||||
        raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
 | 
			
		||||
                                     (module, name))
 | 
			
		||||
        raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def restricted_loads(s):
 | 
			
		||||
| 
						 | 
				
			
			@ -483,11 +481,11 @@ def stream_input(stream, queue):
 | 
			
		|||
    return thread
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def tkinter_center_window(window: Tk):
 | 
			
		||||
def tkinter_center_window(window: "tkinter.Tk") -> None:
 | 
			
		||||
    window.update()
 | 
			
		||||
    xPos = int(window.winfo_screenwidth() / 2 - window.winfo_reqwidth() / 2)
 | 
			
		||||
    yPos = int(window.winfo_screenheight() / 2 - window.winfo_reqheight() / 2)
 | 
			
		||||
    window.geometry("+{}+{}".format(xPos, yPos))
 | 
			
		||||
    x = int(window.winfo_screenwidth() / 2 - window.winfo_reqwidth() / 2)
 | 
			
		||||
    y = int(window.winfo_screenheight() / 2 - window.winfo_reqheight() / 2)
 | 
			
		||||
    window.geometry(f"+{x}+{y}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class VersionException(Exception):
 | 
			
		||||
| 
						 | 
				
			
			@ -516,13 +514,14 @@ def format_SI_prefix(value, power=1000, power_labels=("", "k", "M", "G", "T", "P
 | 
			
		|||
    return f"{value.quantize(decimal.Decimal('1.00'))} {chaining_prefix(n, power_labels)}"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_fuzzy_ratio(word1: str, word2: str) -> float:
 | 
			
		||||
    return (1 - jellyfish.damerau_levenshtein_distance(word1.lower(), word2.lower())
 | 
			
		||||
            / max(len(word1), len(word2)))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_fuzzy_results(input_word: str, wordlist: typing.Sequence[str], limit: typing.Optional[int] = None) \
 | 
			
		||||
        -> typing.List[typing.Tuple[str, int]]:
 | 
			
		||||
    import jellyfish
 | 
			
		||||
 | 
			
		||||
    def get_fuzzy_ratio(word1: str, word2: str) -> float:
 | 
			
		||||
        return (1 - jellyfish.damerau_levenshtein_distance(word1.lower(), word2.lower())
 | 
			
		||||
                / max(len(word1), len(word2)))
 | 
			
		||||
 | 
			
		||||
    limit: int = limit if limit else len(wordlist)
 | 
			
		||||
    return list(
 | 
			
		||||
        map(
 | 
			
		||||
| 
						 | 
				
			
			@ -540,18 +539,19 @@ def get_fuzzy_results(input_word: str, wordlist: typing.Sequence[str], limit: ty
 | 
			
		|||
def open_filename(title: str, filetypes: typing.Sequence[typing.Tuple[str, typing.Sequence[str]]]) \
 | 
			
		||||
        -> typing.Optional[str]:
 | 
			
		||||
    def run(*args: str):
 | 
			
		||||
        return subprocess.run(args, capture_output=True, text=True).stdout.split('\n', 1)[0] or None
 | 
			
		||||
        return subprocess.run(args, capture_output=True, text=True).stdout.split("\n", 1)[0] or None
 | 
			
		||||
 | 
			
		||||
    if is_linux:
 | 
			
		||||
        # prefer native dialog
 | 
			
		||||
        kdialog = shutil.which('kdialog')
 | 
			
		||||
        from shutil import which
 | 
			
		||||
        kdialog = which("kdialog")
 | 
			
		||||
        if kdialog:
 | 
			
		||||
            k_filters = '|'.join((f'{text} (*{" *".join(ext)})' for (text, ext) in filetypes))
 | 
			
		||||
            return run(kdialog, f'--title={title}', '--getopenfilename', '.', k_filters)
 | 
			
		||||
        zenity = shutil.which('zenity')
 | 
			
		||||
            return run(kdialog, f"--title={title}", "--getopenfilename", ".", k_filters)
 | 
			
		||||
        zenity = which("zenity")
 | 
			
		||||
        if zenity:
 | 
			
		||||
            z_filters = (f'--file-filter={text} ({", ".join(ext)}) | *{" *".join(ext)}' for (text, ext) in filetypes)
 | 
			
		||||
            return run(zenity, f'--title={title}', '--file-selection', *z_filters)
 | 
			
		||||
            return run(zenity, f"--title={title}", "--file-selection", *z_filters)
 | 
			
		||||
 | 
			
		||||
    # fall back to tk
 | 
			
		||||
    try:
 | 
			
		||||
| 
						 | 
				
			
			@ -569,10 +569,10 @@ def open_filename(title: str, filetypes: typing.Sequence[typing.Tuple[str, typin
 | 
			
		|||
 | 
			
		||||
def messagebox(title: str, text: str, error: bool = False) -> None:
 | 
			
		||||
    def run(*args: str):
 | 
			
		||||
        return subprocess.run(args, capture_output=True, text=True).stdout.split('\n', 1)[0] or None
 | 
			
		||||
        return subprocess.run(args, capture_output=True, text=True).stdout.split("\n", 1)[0] or None
 | 
			
		||||
 | 
			
		||||
    def is_kivy_running():
 | 
			
		||||
        if 'kivy' in sys.modules:
 | 
			
		||||
        if "kivy" in sys.modules:
 | 
			
		||||
            from kivy.app import App
 | 
			
		||||
            return App.get_running_app() is not None
 | 
			
		||||
        return False
 | 
			
		||||
| 
						 | 
				
			
			@ -582,14 +582,15 @@ def messagebox(title: str, text: str, error: bool = False) -> None:
 | 
			
		|||
        MessageBox(title, text, error).open()
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
    if is_linux and not 'tkinter' in sys.modules:
 | 
			
		||||
    if is_linux and "tkinter" not in sys.modules:
 | 
			
		||||
        # prefer native dialog
 | 
			
		||||
        kdialog = shutil.which('kdialog')
 | 
			
		||||
        from shutil import which
 | 
			
		||||
        kdialog = which("kdialog")
 | 
			
		||||
        if kdialog:
 | 
			
		||||
            return run(kdialog, f'--title={title}', '--error' if error else '--msgbox', text)
 | 
			
		||||
        zenity = shutil.which('zenity')
 | 
			
		||||
            return run(kdialog, f"--title={title}", "--error" if error else "--msgbox", text)
 | 
			
		||||
        zenity = which("zenity")
 | 
			
		||||
        if zenity:
 | 
			
		||||
            return run(zenity, f'--title={title}', f'--text={text}', '--error' if error else '--info')
 | 
			
		||||
            return run(zenity, f"--title={title}", f"--text={text}", "--error" if error else "--info")
 | 
			
		||||
 | 
			
		||||
    # fall back to tk
 | 
			
		||||
    try:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue