233 lines
		
	
	
		
			7.8 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			233 lines
		
	
	
		
			7.8 KiB
		
	
	
	
		
			Python
		
	
	
	
| import asyncio
 | |
| import Utils
 | |
| import websockets
 | |
| import functools
 | |
| from copy import deepcopy
 | |
| from typing import List, Any, Iterable
 | |
| from NetUtils import decode, encode, JSONtoTextParser, JSONMessagePart, NetworkItem
 | |
| from MultiServer import Endpoint
 | |
| from CommonClient import CommonContext, gui_enabled, ClientCommandProcessor, logger, get_base_parser
 | |
| 
 | |
# When True, every message sent to or received from the game proxy socket
# is written to the log via logger.info.
DEBUG = False
 | |
| 
 | |
| 
 | |
class AHITJSONToTextParser(JSONtoTextParser):
    """Text parser used for messages forwarded to the game.

    Color nodes are rendered exactly like plain text nodes.
    """

    def _handle_color(self, node: JSONMessagePart):
        # No colors for the in-game text
        plain = self._handle_text(node)
        return plain
 | |
| 
 | |
| 
 | |
# Command processor adding an `ahit` command that logs the game connection state.
class AHITCommandProcessor(ClientCommandProcessor):
    def _cmd_ahit(self):
        """Check AHIT Connection State"""
        # Only an AHITContext exposes get_ahit_status(); do nothing otherwise.
        if not isinstance(self.ctx, AHITContext):
            return

        logger.info(f"AHIT Status: {self.ctx.get_ahit_status()}")
 | |
| 
 | |
| 
 | |
class AHITContext(CommonContext):
    """Client context that relays packets between the server and the game.

    Packets received from the server are stored, already encoded, in
    ``server_msgs``; ``proxy_loop`` later flushes that queue to the game's
    websocket, which is wrapped in ``endpoint`` by ``proxy``.
    """

    command_processor = AHITCommandProcessor
    game = "A Hat in Time"

    def __init__(self, server_address, password):
        super().__init__(server_address, password)
        # websockets server awaitable and the proxy_loop task; both are
        # assigned by launch().
        self.proxy = None
        self.proxy_task = None
        self.gamejsontotext = AHITJSONToTextParser(self)
        self.autoreconnect_task = None
        # Endpoint wrapping the game's websocket; set in proxy().
        self.endpoint = None
        self.items_handling = 0b111
        # Encoded copies of the last RoomInfo / Connected packets so they can
        # be replayed to the game when it connects.
        self.room_info = None
        self.connected_msg = None
        self.game_connected = False
        # True when the game connected before RoomInfo could be sent to it.
        self.awaiting_info = False
        self.full_inventory: List[Any] = []
        # Encoded messages waiting to be forwarded to the game.
        self.server_msgs: List[Any] = []

    async def server_auth(self, password_requested: bool = False):
        """Authenticate with the server, then send the Connect packet."""
        if password_requested and not self.password:
            await super(AHITContext, self).server_auth(password_requested)

        await self.get_username()
        await self.send_connect()

    def get_ahit_status(self) -> str:
        """Return a human-readable description of the game connection state."""
        if not self.is_proxy_connected():
            return "Not connected to A Hat in Time"

        return "Connected to A Hat in Time"

    async def send_msgs_proxy(self, msgs: str) -> bool:
        """Send an already-encoded message string to the game.

        `msgs` is passed to the socket unchanged; every caller in this module
        hands in the result of ``encode(...)``.

        Returns False without sending when there is no open game socket,
        True once the message has been sent.
        """
        if not self.endpoint or not self.endpoint.socket.open or self.endpoint.socket.closed:
            return False

        if DEBUG:
            logger.info(f"Outgoing message: {msgs}")

        await self.endpoint.socket.send(msgs)
        return True

    async def disconnect(self, allow_autoreconnect: bool = False):
        """Disconnect from the server (delegates to the base class)."""
        await super().disconnect(allow_autoreconnect)

    async def disconnect_proxy(self):
        """Close the game socket if it is still open, then await the proxy task."""
        if self.endpoint and not self.endpoint.socket.closed:
            await self.endpoint.socket.close()
        if self.proxy_task is not None:
            # NOTE(review): launch() assigns proxy_task to the proxy_loop task,
            # which only finishes once exit_event is set, so this await
            # presumably blocks until shutdown - confirm that is intended.
            await self.proxy_task

    def is_connected(self) -> bool:
        """Return True when a server connection exists and its socket is open."""
        # bool() so that None / an Endpoint object never leaks out of a
        # method that is declared to return bool.
        return bool(self.server and self.server.socket.open)

    def is_proxy_connected(self) -> bool:
        """Return True when the game is attached and its socket is open."""
        return bool(self.endpoint and self.endpoint.socket.open)

    def on_print_json(self, args: dict):
        """Queue a color-free copy of the text for the game and show it locally."""
        # deepcopy: the same data is handed to the UI / console parser below.
        text = self.gamejsontotext(deepcopy(args["data"]))
        msg = {"cmd": "PrintJSON", "data": [{"text": text}], "type": "Chat"}
        self.server_msgs.append(encode([msg]))

        if self.ui:
            self.ui.print_json(args["data"])
        else:
            text = self.jsontotextparser(args["data"])
            logger.info(text)

    def update_items(self):
        """Queue the whole tracked inventory for the game as one ReceivedItems packet."""
        # just to be safe - we might still have an inventory from a different room
        if not self.is_connected():
            return

        self.server_msgs.append(encode([{"cmd": "ReceivedItems", "index": 0, "items": self.full_inventory}]))

    def on_package(self, cmd: str, args: dict):
        """Track server packets and queue them for the game.

        ``Connected`` and ``RoomInfo`` are cached in encoded form;
        ``ReceivedItems`` also updates ``full_inventory``; ``PrintJSON`` is
        not queued here; every other command is queued unchanged.
        """
        if cmd == "Connected":
            self.connected_msg = encode([args])
            if self.awaiting_info:
                # The game connected before RoomInfo could be sent to it:
                # queue the cached RoomInfo and the inventory now.
                self.server_msgs.append(self.room_info)
                self.update_items()
                self.awaiting_info = False

        elif cmd == "ReceivedItems":
            # Index 0 replaces the tracked inventory instead of extending it.
            if args["index"] == 0:
                self.full_inventory.clear()

            for item in args["items"]:
                self.full_inventory.append(NetworkItem(*item))

            self.server_msgs.append(encode([args]))

        elif cmd == "RoomInfo":
            self.seed_name = args["seed_name"]
            self.room_info = encode([args])

        elif cmd != "PrintJSON":
            self.server_msgs.append(encode([args]))

    def run_gui(self):
        """Create the client window and start its async run task."""
        from kvui import GameManager

        class AHITManager(GameManager):
            logging_pairs = [
                ("Client", "Archipelago")
            ]
            base_title = "Archipelago A Hat in Time Client"

        self.ui = AHITManager(self)
        self.ui_task = asyncio.create_task(self.ui.async_run(), name="UI")
 | |
| 
 | |
| 
 | |
async def proxy(websocket, path: str = "/", ctx: AHITContext = None):
    """Serve one websocket connection coming from the game.

    The socket is stored on `ctx` as its endpoint. `Connect` packets are
    validated (game name, optional seed name) and answered from the cached
    Connected message; every other packet is forwarded through
    `ctx.send_msgs`. The proxy connection is always torn down on exit.
    """
    ctx.endpoint = Endpoint(websocket)
    try:
        await on_client_connected(ctx)

        if not ctx.is_proxy_connected():
            return

        async for raw in websocket:
            if DEBUG:
                logger.info(f"Incoming message: {raw}")

            for packet in decode(raw):
                if packet["cmd"] != "Connect":
                    # Ordinary packet: relay it while the game is still attached.
                    if not ctx.is_proxy_connected():
                        break

                    await ctx.send_msgs([packet])
                    continue

                # Proxy is connecting, make sure it is valid
                if packet["game"] != "A Hat in Time":
                    logger.info("Aborting proxy connection: game is not A Hat in Time")
                    await ctx.disconnect_proxy()
                    break

                if ctx.seed_name:
                    save_seed = packet.get("seed_name", "")
                    if save_seed != "" and save_seed != ctx.seed_name:
                        logger.info("Aborting proxy connection: seed mismatch from save file")
                        logger.info(f"Expected: {ctx.seed_name}, got: {save_seed}")
                        notice = encode([{"cmd": "PrintJSON",
                                          "data": [{"text": "Connection aborted - save file to seed mismatch"}]}])
                        await ctx.send_msgs_proxy(notice)
                        await ctx.disconnect_proxy()
                        break

                # Replay the cached Connected packet and queue the inventory.
                if ctx.connected_msg and ctx.is_connected():
                    await ctx.send_msgs_proxy(ctx.connected_msg)
                    ctx.update_items()

    except websockets.WebSocketException:
        # Websocket-level errors are deliberately not logged.
        pass
    except Exception as e:
        logger.exception(e)
    finally:
        await ctx.disconnect_proxy()
 | |
| 
 | |
| 
 | |
async def on_client_connected(ctx: AHITContext):
    """Send the cached RoomInfo to the game, or flag that it is still owed.

    When no RoomInfo is cached or the server is not connected,
    `ctx.awaiting_info` is set to True instead of sending anything.
    """
    if not (ctx.room_info and ctx.is_connected()):
        ctx.awaiting_info = True
        return

    await ctx.send_msgs_proxy(ctx.room_info)
 | |
| 
 | |
| 
 | |
async def proxy_loop(ctx: AHITContext):
    """Flush queued server messages to the game until the exit event is set.

    Every 0.1 seconds each entry of `ctx.server_msgs` is passed to
    `ctx.send_msgs_proxy` and the queue is then emptied. Any exception is
    logged and ends the loop.
    """
    try:
        while not ctx.exit_event.is_set():
            if ctx.server_msgs:
                for queued in ctx.server_msgs:
                    await ctx.send_msgs_proxy(queued)

                # The queue is emptied whether or not each send succeeded.
                ctx.server_msgs.clear()
            await asyncio.sleep(0.1)
    except Exception as e:
        logger.exception(e)
        logger.info("Aborting AHIT Proxy Client due to errors")
 | |
| 
 | |
| 
 | |
def launch():
    """Entry point: set up logging and run the proxy client until exit."""

    async def main():
        # Build the context from the command-line arguments.
        args = get_base_parser().parse_args()
        ctx = AHITContext(args.connect, args.password)

        logger.info("Starting A Hat in Time proxy server")
        # Each incoming game connection is handled by proxy() bound to ctx.
        handler = functools.partial(proxy, ctx=ctx)
        ctx.proxy = websockets.serve(handler, host="localhost", port=11311,
                                     ping_timeout=999999, ping_interval=999999)
        ctx.proxy_task = asyncio.create_task(proxy_loop(ctx), name="ProxyLoop")

        if gui_enabled:
            ctx.run_gui()
        ctx.run_cli()

        # Start the websocket server, then wait for the flush loop and the
        # exit event.
        await ctx.proxy
        await ctx.proxy_task
        await ctx.exit_event.wait()

    Utils.init_logging("AHITClient")

    import colorama
    colorama.init()
    asyncio.run(main())
    colorama.deinit()
 |