from __future__ import annotations import asyncio import json import logging import typing import websockets class Node: endpoints: typing.List def __init__(self): self.endpoints = [] super(Node, self).__init__() def broadcast_all(self, msgs): msgs = json.dumps(msgs) for endpoint in self.endpoints: asyncio.create_task(self.send_json_msgs(endpoint, msgs)) async def send_msgs(self, endpoint: Endpoint, msgs): if not endpoint.socket or not endpoint.socket.open or endpoint.socket.closed: return try: await endpoint.socket.send(json.dumps(msgs)) except websockets.ConnectionClosed: logging.exception("Exception during send_msgs") await self.disconnect(endpoint) async def send_json_msgs(self, endpoint: Endpoint, msg: str): if not endpoint.socket or not endpoint.socket.open or endpoint.socket.closed: return try: await endpoint.socket.send(msg) except websockets.ConnectionClosed: logging.exception("Exception during send_msgs") await self.disconnect(endpoint) async def disconnect(self, endpoint): if endpoint in self.endpoints: self.endpoints.remove(endpoint) class Endpoint: socket: websockets.WebSocketServerProtocol def __init__(self, socket): self.socket = socket async def disconnect(self): raise NotImplementedError