MultiServer: remove promp_toolkit

This commit is contained in:
Fabian Dill 2021-11-28 04:06:30 +01:00
parent d768379a8a
commit 7b0b243607
7 changed files with 45 additions and 43 deletions

View File

@ -15,7 +15,7 @@ if __name__ == "__main__":
from MultiServer import CommandProcessor from MultiServer import CommandProcessor
from NetUtils import Endpoint, decode, NetworkItem, encode, JSONtoTextParser, ClientStatus, Permission from NetUtils import Endpoint, decode, NetworkItem, encode, JSONtoTextParser, ClientStatus, Permission
from Utils import Version from Utils import Version, stream_input
from worlds import network_data_package, AutoWorldRegister from worlds import network_data_package, AutoWorldRegister
logger = logging.getLogger("Client") logger = logging.getLogger("Client")
@ -540,18 +540,6 @@ async def process_server_cmd(ctx: CommonContext, args: dict):
ctx.on_package(cmd, args) ctx.on_package(cmd, args)
def stream_input(stream, queue):
def queuer():
text = stream.readline().strip()
if text:
queue.put_nowait(text)
from threading import Thread
thread = Thread(target=queuer, name=f"Stream handler for {stream.name}", daemon=True)
thread.start()
return thread
async def console_loop(ctx: CommonContext): async def console_loop(ctx: CommonContext):
import sys import sys
commandprocessor = ctx.command_processor(ctx) commandprocessor = ctx.command_processor(ctx)
@ -560,7 +548,7 @@ async def console_loop(ctx: CommonContext):
while not ctx.exit_event.is_set(): while not ctx.exit_event.is_set():
try: try:
input_text = await queue.get() input_text = await queue.get()
input_text = input_text.strip() queue.task_done()
if ctx.input_requests > 0: if ctx.input_requests > 0:
ctx.input_requests -= 1 ctx.input_requests -= 1

View File

@ -191,6 +191,7 @@ async def factorio_server_watcher(ctx: FactorioContext):
while not factorio_queue.empty(): while not factorio_queue.empty():
msg = factorio_queue.get() msg = factorio_queue.get()
factorio_queue.task_done()
factorio_server_logger.info(msg) factorio_server_logger.info(msg)
if not ctx.rcon_client and "Starting RCON interface at IP ADDR:" in msg: if not ctx.rcon_client and "Starting RCON interface at IP ADDR:" in msg:
ctx.rcon_client = factorio_rcon.RCONClient("localhost", rcon_port, rcon_password) ctx.rcon_client = factorio_rcon.RCONClient("localhost", rcon_port, rcon_password)

View File

@ -119,7 +119,7 @@ class Context:
self.remaining_mode: str = remaining_mode self.remaining_mode: str = remaining_mode
self.collect_mode: str = collect_mode self.collect_mode: str = collect_mode
self.item_cheat = item_cheat self.item_cheat = item_cheat
self.running = True self.exit_event = asyncio.Event()
self.client_activity_timers: typing.Dict[ self.client_activity_timers: typing.Dict[
team_slot, datetime.datetime] = {} # datetime of last new item check team_slot, datetime.datetime] = {} # datetime of last new item check
self.client_connection_timers: typing.Dict[ self.client_connection_timers: typing.Dict[
@ -336,7 +336,7 @@ class Context:
if not self.auto_saver_thread: if not self.auto_saver_thread:
def save_regularly(): def save_regularly():
import time import time
while self.running: while not self.exit_event.is_set():
time.sleep(self.auto_save_interval) time.sleep(self.auto_save_interval)
if self.save_dirty: if self.save_dirty:
logging.debug("Saving via thread.") logging.debug("Saving via thread.")
@ -1409,7 +1409,7 @@ class ServerCommandProcessor(CommonCommandProcessor):
asyncio.create_task(self.ctx.server.ws_server._close()) asyncio.create_task(self.ctx.server.ws_server._close())
if self.ctx.shutdown_task: if self.ctx.shutdown_task:
self.ctx.shutdown_task.cancel() self.ctx.shutdown_task.cancel()
self.ctx.running = False self.ctx.exit_event.set()
return True return True
@mark_raw @mark_raw
@ -1566,11 +1566,17 @@ class ServerCommandProcessor(CommonCommandProcessor):
async def console(ctx: Context): async def console(ctx: Context):
session = prompt_toolkit.PromptSession() import sys
while ctx.running: queue = asyncio.Queue()
with patch_stdout(): Utils.stream_input(sys.stdin, queue)
input_text = await session.prompt_async() while not ctx.exit_event.is_set():
try: try:
# I don't get why this while loop is needed. Works fine without it on clients,
# but the queue.get() for server never fulfills if the queue is empty when entering the await.
while queue.qsize() == 0:
await asyncio.sleep(0.05)
input_text = await queue.get()
queue.task_done()
ctx.commandprocessor(input_text) ctx.commandprocessor(input_text)
except: except:
import traceback import traceback
@ -1636,10 +1642,10 @@ def parse_args() -> argparse.Namespace:
async def auto_shutdown(ctx, to_cancel=None): async def auto_shutdown(ctx, to_cancel=None):
await asyncio.sleep(ctx.auto_shutdown) await asyncio.sleep(ctx.auto_shutdown)
while ctx.running: while not ctx.exit_event.is_set():
if not ctx.client_activity_timers.values(): if not ctx.client_activity_timers.values():
asyncio.create_task(ctx.server.ws_server._close()) asyncio.create_task(ctx.server.ws_server._close())
ctx.running = False ctx.exit_event.set()
if to_cancel: if to_cancel:
for task in to_cancel: for task in to_cancel:
task.cancel() task.cancel()
@ -1650,7 +1656,7 @@ async def auto_shutdown(ctx, to_cancel=None):
seconds = ctx.auto_shutdown - delta.total_seconds() seconds = ctx.auto_shutdown - delta.total_seconds()
if seconds < 0: if seconds < 0:
asyncio.create_task(ctx.server.ws_server._close()) asyncio.create_task(ctx.server.ws_server._close())
ctx.running = False ctx.exit_event.set()
if to_cancel: if to_cancel:
for task in to_cancel: for task in to_cancel:
task.cancel() task.cancel()
@ -1694,7 +1700,8 @@ async def main(args: argparse.Namespace):
console_task = asyncio.create_task(console(ctx)) console_task = asyncio.create_task(console(ctx))
if ctx.auto_shutdown: if ctx.auto_shutdown:
ctx.shutdown_task = asyncio.create_task(auto_shutdown(ctx, [console_task])) ctx.shutdown_task = asyncio.create_task(auto_shutdown(ctx, [console_task]))
await console_task await ctx.exit_event.wait()
console_task.cancel()
if ctx.shutdown_task: if ctx.shutdown_task:
await ctx.shutdown_task await ctx.shutdown_task

View File

@ -680,11 +680,6 @@ async def snes_disconnect(ctx: Context):
async def snes_autoreconnect(ctx: Context): async def snes_autoreconnect(ctx: Context):
# unfortunately currently broken. See: https://github.com/prompt-toolkit/python-prompt-toolkit/issues/1033
# with prompt_toolkit.shortcuts.ProgressBar() as pb:
# for _ in pb(range(100)):
# await asyncio.sleep(RECONNECT_DELAY/100)
await asyncio.sleep(SNES_RECONNECT_DELAY) await asyncio.sleep(SNES_RECONNECT_DELAY)
if ctx.snes_reconnect_address and ctx.snes_socket is None: if ctx.snes_reconnect_address and ctx.snes_socket is None:
await snes_connect(ctx, ctx.snes_reconnect_address) await snes_connect(ctx, ctx.snes_reconnect_address)

View File

@ -1,6 +1,16 @@
from __future__ import annotations from __future__ import annotations
import typing import typing
import builtins
import os
import subprocess
import sys
import pickle
import functools
import io
import collections
import importlib
import logging
def tuplize_version(version: str) -> Version: def tuplize_version(version: str) -> Version:
@ -16,17 +26,6 @@ class Version(typing.NamedTuple):
__version__ = "0.2.0" __version__ = "0.2.0"
version_tuple = tuplize_version(__version__) version_tuple = tuplize_version(__version__)
import builtins
import os
import subprocess
import sys
import pickle
import functools
import io
import collections
import importlib
import logging
from yaml import load, dump, safe_load from yaml import load, dump, safe_load
try: try:
@ -462,3 +461,16 @@ def init_logging(name: str, loglevel: typing.Union[str, int] = logging.INFO, wri
handle_exception._wrapped = True handle_exception._wrapped = True
sys.excepthook = handle_exception sys.excepthook = handle_exception
def stream_input(stream, queue):
def queuer():
while 1:
text = stream.readline().strip()
if text:
queue.put_nowait(text)
from threading import Thread
thread = Thread(target=queuer, name=f"Stream handler for {stream.name}", daemon=True)
thread.start()
return thread

View File

@ -56,7 +56,7 @@ class WebHostContext(Context):
def listen_to_db_commands(self): def listen_to_db_commands(self):
cmdprocessor = DBCommandProcessor(self) cmdprocessor = DBCommandProcessor(self)
while self.running: while not self.exit_event.is_set():
with db_session: with db_session:
commands = select(command for command in Command if command.room.id == self.room_id) commands = select(command for command in Command if command.room.id == self.room_id)
if commands: if commands:

View File

@ -2,7 +2,6 @@ colorama>=0.4.4
websockets>=10.1 websockets>=10.1
PyYAML>=6.0 PyYAML>=6.0
fuzzywuzzy>=0.18.0 fuzzywuzzy>=0.18.0
prompt_toolkit>=3.0.23
appdirs>=1.4.4 appdirs>=1.4.4
jinja2>=3.0.3 jinja2>=3.0.3
schema>=0.7.4 schema>=0.7.4