disconnect on send failure

This commit is contained in:
Fabian Dill 2020-04-19 14:05:58 +02:00
parent b676d4131f
commit 3840832f05
1 changed files with 35 additions and 22 deletions

View File

@ -9,6 +9,7 @@ import zlib
import collections import collections
import typing import typing
import inspect import inspect
import weakref
import ModuleUpdate import ModuleUpdate
@ -41,6 +42,13 @@ class Client:
self.tags = [] self.tags = []
self.version = [0, 0, 0] self.version = [0, 0, 0]
self.messageprocessor = ClientMessageProcessor(ctx, self) self.messageprocessor = ClientMessageProcessor(ctx, self)
self.ctx = weakref.ref(ctx)
async def disconnect(self):
ctx = self.ctx()
if ctx:
await on_client_disconnected(ctx, self)
ctx.clients.remove(self)
@property @property
def wants_item_notification(self): def wants_item_notification(self):
@ -96,20 +104,25 @@ class Context:
f'for {len(received_items)} players') f'for {len(received_items)} players')
async def send_msgs(websocket, msgs): async def send_msgs(client: Client, msgs):
websocket = client.socket
if not websocket or not websocket.open or websocket.closed: if not websocket or not websocket.open or websocket.closed:
return return
await websocket.send(json.dumps(msgs)) try:
await websocket.send(json.dumps(msgs))
except websockets.ConnectionClosed:
logging.exception("Exception during send_msgs")
await client.disconnect()
def broadcast_all(ctx : Context, msgs): def broadcast_all(ctx : Context, msgs):
for client in ctx.clients: for client in ctx.clients:
if client.auth: if client.auth:
asyncio.create_task(send_msgs(client.socket, msgs)) asyncio.create_task(send_msgs(client, msgs))
def broadcast_team(ctx : Context, team, msgs): def broadcast_team(ctx : Context, team, msgs):
for client in ctx.clients: for client in ctx.clients:
if client.auth and client.team == team: if client.auth and client.team == team:
asyncio.create_task(send_msgs(client.socket, msgs)) asyncio.create_task(send_msgs(client, msgs))
def notify_all(ctx : Context, text): def notify_all(ctx : Context, text):
logging.info("Notice (all): %s" % text) logging.info("Notice (all): %s" % text)
@ -125,7 +138,7 @@ def notify_client(client: Client, text: str):
if not client.auth: if not client.auth:
return return
logging.info("Notice (Player %s in team %d): %s" % (client.name, client.team + 1, text)) logging.info("Notice (Player %s in team %d): %s" % (client.name, client.team + 1, text))
asyncio.create_task(send_msgs(client.socket, [['Print', text]])) asyncio.create_task(send_msgs(client, [['Print', text]]))
# separated out, due to compatibilty between clients # separated out, due to compatibilty between clients
@ -140,7 +153,7 @@ def notify_hints(ctx: Context, team: int, hints: typing.List[Utils.Hint]):
payload = cmd payload = cmd
else: else:
payload = texts payload = texts
asyncio.create_task(send_msgs(client.socket, payload)) asyncio.create_task(send_msgs(client, payload))
async def server(websocket, path, ctx: Context): async def server(websocket, path, ctx: Context):
client = Client(websocket, ctx) client = Client(websocket, ctx)
@ -161,11 +174,10 @@ async def server(websocket, path, ctx: Context):
if not isinstance(e, websockets.WebSocketException): if not isinstance(e, websockets.WebSocketException):
logging.exception(e) logging.exception(e)
finally: finally:
await on_client_disconnected(ctx, client) await client.disconnect()
ctx.clients.remove(client)
async def on_client_connected(ctx: Context, client: Client): async def on_client_connected(ctx: Context, client: Client):
await send_msgs(client.socket, [['RoomInfo', { await send_msgs(client, [['RoomInfo', {
'password': ctx.password is not None, 'password': ctx.password is not None,
'players': [(client.team, client.slot, client.name) for client in ctx.clients if client.auth], 'players': [(client.team, client.slot, client.name) for client in ctx.clients if client.auth],
# tags are for additional features in the communication. # tags are for additional features in the communication.
@ -230,7 +242,8 @@ def send_new_items(ctx: Context):
continue continue
items = get_received_items(ctx, client.team, client.slot) items = get_received_items(ctx, client.team, client.slot)
if len(items) > client.send_index: if len(items) > client.send_index:
asyncio.create_task(send_msgs(client.socket, [['ReceivedItems', (client.send_index, tuplize_received_items(items)[client.send_index:])]])) asyncio.create_task(send_msgs(client, [
['ReceivedItems', (client.send_index, tuplize_received_items(items)[client.send_index:])]]))
client.send_index = len(items) client.send_index = len(items)
@ -267,7 +280,7 @@ def register_location_checks(ctx: Context, team: int, slot: int, locations):
for client in ctx.clients: for client in ctx.clients:
if client.team == team and client.wants_item_notification: if client.team == team and client.wants_item_notification:
asyncio.create_task( asyncio.create_task(
send_msgs(client.socket, [['ItemFound', (target_item, location, slot)]])) send_msgs(client, [['ItemFound', (target_item, location, slot)]]))
ctx.location_checks[team, slot] |= set(locations) ctx.location_checks[team, slot] |= set(locations)
send_new_items(ctx) send_new_items(ctx)
@ -522,14 +535,14 @@ class ClientMessageProcessor(CommandProcessor):
async def process_client_cmd(ctx: Context, client: Client, cmd, args): async def process_client_cmd(ctx: Context, client: Client, cmd, args):
if type(cmd) is not str: if type(cmd) is not str:
await send_msgs(client.socket, [['InvalidCmd']]) await send_msgs(client, [['InvalidCmd']])
return return
if cmd == 'Connect': if cmd == 'Connect':
if not args or type(args) is not dict or \ if not args or type(args) is not dict or \
'password' not in args or type(args['password']) not in [str, type(None)] or \ 'password' not in args or type(args['password']) not in [str, type(None)] or \
'rom' not in args or type(args['rom']) is not list: 'rom' not in args or type(args['rom']) is not list:
await send_msgs(client.socket, [['InvalidArguments', 'Connect']]) await send_msgs(client, [['InvalidArguments', 'Connect']])
return return
errors = set() errors = set()
@ -548,7 +561,7 @@ async def process_client_cmd(ctx: Context, client: Client, cmd, args):
client.slot = slot client.slot = slot
if errors: if errors:
await send_msgs(client.socket, [['ConnectionRefused', list(errors)]]) await send_msgs(client, [['ConnectionRefused', list(errors)]])
else: else:
client.auth = True client.auth = True
client.version = args.get('version', Client.version) client.version = args.get('version', Client.version)
@ -559,7 +572,7 @@ async def process_client_cmd(ctx: Context, client: Client, cmd, args):
if items: if items:
reply.append(['ReceivedItems', (0, tuplize_received_items(items))]) reply.append(['ReceivedItems', (0, tuplize_received_items(items))])
client.send_index = len(items) client.send_index = len(items)
await send_msgs(client.socket, reply) await send_msgs(client, reply)
await on_client_joined(ctx, client) await on_client_joined(ctx, client)
if client.auth: if client.auth:
@ -567,22 +580,22 @@ async def process_client_cmd(ctx: Context, client: Client, cmd, args):
items = get_received_items(ctx, client.team, client.slot) items = get_received_items(ctx, client.team, client.slot)
if items: if items:
client.send_index = len(items) client.send_index = len(items)
await send_msgs(client.socket, [['ReceivedItems', (0, tuplize_received_items(items))]]) await send_msgs(client, [['ReceivedItems', (0, tuplize_received_items(items))]])
elif cmd == 'LocationChecks': elif cmd == 'LocationChecks':
if type(args) is not list: if type(args) is not list:
await send_msgs(client.socket, [['InvalidArguments', 'LocationChecks']]) await send_msgs(client, [['InvalidArguments', 'LocationChecks']])
return return
register_location_checks(ctx, client.team, client.slot, args) register_location_checks(ctx, client.team, client.slot, args)
elif cmd == 'LocationScouts': elif cmd == 'LocationScouts':
if type(args) is not list: if type(args) is not list:
await send_msgs(client.socket, [['InvalidArguments', 'LocationScouts']]) await send_msgs(client, [['InvalidArguments', 'LocationScouts']])
return return
locs = [] locs = []
for location in args: for location in args:
if type(location) is not int or 0 >= location > len(Regions.location_table): if type(location) is not int or 0 >= location > len(Regions.location_table):
await send_msgs(client.socket, [['InvalidArguments', 'LocationScouts']]) await send_msgs(client, [['InvalidArguments', 'LocationScouts']])
return return
loc_name = list(Regions.location_table.keys())[location - 1] loc_name = list(Regions.location_table.keys())[location - 1]
target_item, target_player = ctx.locations[(Regions.location_table[loc_name][0], client.slot)] target_item, target_player = ctx.locations[(Regions.location_table[loc_name][0], client.slot)]
@ -595,17 +608,17 @@ async def process_client_cmd(ctx: Context, client: Client, cmd, args):
locs.append([loc_name, location, target_item, target_player]) locs.append([loc_name, location, target_item, target_player])
# logging.info(f"{client.name} in team {client.team+1} scouted {', '.join([l[0] for l in locs])}") # logging.info(f"{client.name} in team {client.team+1} scouted {', '.join([l[0] for l in locs])}")
await send_msgs(client.socket, [['LocationInfo', [l[1:] for l in locs]]]) await send_msgs(client, [['LocationInfo', [l[1:] for l in locs]]])
elif cmd == 'UpdateTags': elif cmd == 'UpdateTags':
if not args or type(args) is not list: if not args or type(args) is not list:
await send_msgs(client.socket, [['InvalidArguments', 'UpdateTags']]) await send_msgs(client, [['InvalidArguments', 'UpdateTags']])
return return
client.tags = args client.tags = args
if cmd == 'Say': if cmd == 'Say':
if type(args) is not str or not args.isprintable(): if type(args) is not str or not args.isprintable():
await send_msgs(client.socket, [['InvalidArguments', 'Say']]) await send_msgs(client, [['InvalidArguments', 'Say']])
return return
notify_all(ctx, client.name + ': ' + args) notify_all(ctx, client.name + ': ' + args)