Archipelago/NetUtils.py

53 lines
1.5 KiB
Python

from __future__ import annotations
import asyncio
import logging
import typing
from json import loads, dumps
import websockets
class Node:
    """Base class that keeps a list of endpoints and sends serialized messages to them.

    ``dumper``/``loader`` default to :func:`json.dumps`/:func:`json.loads`;
    they are stored as static methods so subclasses can swap the codec.
    """

    endpoints: typing.List
    dumper = staticmethod(dumps)
    loader = staticmethod(loads)

    def __init__(self):
        self.endpoints = []
        # Strong references to the fire-and-forget tasks created by
        # broadcast_all(). asyncio only keeps weak references to tasks, so an
        # unreferenced task may be garbage-collected before it finishes.
        self._send_tasks: typing.Set[asyncio.Task] = set()

    def broadcast_all(self, msgs):
        """Encode *msgs* once and schedule sending it to every current endpoint.

        Must be called while an event loop is running (it uses
        :func:`asyncio.create_task`). Returns immediately; the sends run in
        the background.
        """
        encoded = self.dumper(msgs)
        tasks = getattr(self, "_send_tasks", None)
        if tasks is None:
            # A subclass may not have run Node.__init__; create the set lazily.
            tasks = self._send_tasks = set()
        for endpoint in self.endpoints:
            task = asyncio.create_task(self.send_encoded_msgs(endpoint, encoded))
            tasks.add(task)
            # Release the reference once this send has completed.
            task.add_done_callback(tasks.discard)

    async def send_msgs(self, endpoint: Endpoint, msgs):
        """Encode *msgs* with ``dumper`` and send the result to *endpoint*.

        Does nothing if the endpoint's socket is missing, not open, or closed
        (encoding is skipped in that case too). If the connection closes
        during the send, the error is logged and :meth:`disconnect` is
        awaited for the endpoint.
        """
        if not self._can_send(endpoint):
            return
        await self._send_text(endpoint, self.dumper(msgs), "send_msgs")

    async def send_encoded_msgs(self, endpoint: Endpoint, msg: str):
        """Send the already-encoded string *msg* to *endpoint*.

        Skips unusable sockets and handles a closed connection exactly like
        :meth:`send_msgs`.
        """
        if not self._can_send(endpoint):
            return
        await self._send_text(endpoint, msg, "send_encoded_msgs")

    @staticmethod
    def _can_send(endpoint: Endpoint) -> bool:
        """Return True if *endpoint* has a socket that is open and not closed."""
        socket = endpoint.socket
        if not socket:
            return False
        return bool(socket.open) and not socket.closed

    async def _send_text(self, endpoint: Endpoint, text: str, caller: str):
        """Send *text* on *endpoint*'s socket; on ConnectionClosed log and disconnect."""
        try:
            await endpoint.socket.send(text)
        except websockets.ConnectionClosed:
            # *caller* names the public method so the log points at the real entry.
            logging.exception("Exception during %s", caller)
            await self.disconnect(endpoint)

    async def disconnect(self, endpoint):
        """Remove *endpoint* from :attr:`endpoints`; a no-op if it is not present."""
        if endpoint in self.endpoints:
            self.endpoints.remove(endpoint)
class Endpoint:
    """One connected peer, wrapping the websocket that messages are sent on.

    ``Node`` send helpers read ``endpoint.socket``. Subclasses are expected
    to provide their own :meth:`disconnect`.
    """

    socket: websockets.WebSocketServerProtocol

    def __init__(self, socket):
        # Stored unchanged; may be falsy (Node skips sending in that case).
        self.socket = socket

    async def disconnect(self):
        """Close this endpoint; the base implementation is abstract."""
        raise NotImplementedError()