disconnect on send failure

This commit is contained in:
Fabian Dill 2020-04-19 14:05:58 +02:00
parent b676d4131f
commit 3840832f05
1 changed files with 35 additions and 22 deletions

View File

@ -9,6 +9,7 @@ import zlib
import collections
import typing
import inspect
import weakref
import ModuleUpdate
@ -41,6 +42,13 @@ class Client:
self.tags = []
self.version = [0, 0, 0]
self.messageprocessor = ClientMessageProcessor(ctx, self)
self.ctx = weakref.ref(ctx)
async def disconnect(self):
ctx = self.ctx()
if ctx:
await on_client_disconnected(ctx, self)
ctx.clients.remove(self)
@property
def wants_item_notification(self):
@ -96,20 +104,25 @@ class Context:
f'for {len(received_items)} players')
async def send_msgs(websocket, msgs):
async def send_msgs(client: Client, msgs):
websocket = client.socket
if not websocket or not websocket.open or websocket.closed:
return
await websocket.send(json.dumps(msgs))
try:
await websocket.send(json.dumps(msgs))
except websockets.ConnectionClosed:
logging.exception("Exception during send_msgs")
await client.disconnect()
def broadcast_all(ctx : Context, msgs):
for client in ctx.clients:
if client.auth:
asyncio.create_task(send_msgs(client.socket, msgs))
asyncio.create_task(send_msgs(client, msgs))
def broadcast_team(ctx : Context, team, msgs):
for client in ctx.clients:
if client.auth and client.team == team:
asyncio.create_task(send_msgs(client.socket, msgs))
asyncio.create_task(send_msgs(client, msgs))
def notify_all(ctx : Context, text):
logging.info("Notice (all): %s" % text)
@ -125,7 +138,7 @@ def notify_client(client: Client, text: str):
if not client.auth:
return
logging.info("Notice (Player %s in team %d): %s" % (client.name, client.team + 1, text))
asyncio.create_task(send_msgs(client.socket, [['Print', text]]))
asyncio.create_task(send_msgs(client, [['Print', text]]))
# separated out, due to compatibilty between clients
@ -140,7 +153,7 @@ def notify_hints(ctx: Context, team: int, hints: typing.List[Utils.Hint]):
payload = cmd
else:
payload = texts
asyncio.create_task(send_msgs(client.socket, payload))
asyncio.create_task(send_msgs(client, payload))
async def server(websocket, path, ctx: Context):
client = Client(websocket, ctx)
@ -161,11 +174,10 @@ async def server(websocket, path, ctx: Context):
if not isinstance(e, websockets.WebSocketException):
logging.exception(e)
finally:
await on_client_disconnected(ctx, client)
ctx.clients.remove(client)
await client.disconnect()
async def on_client_connected(ctx: Context, client: Client):
await send_msgs(client.socket, [['RoomInfo', {
await send_msgs(client, [['RoomInfo', {
'password': ctx.password is not None,
'players': [(client.team, client.slot, client.name) for client in ctx.clients if client.auth],
# tags are for additional features in the communication.
@ -230,7 +242,8 @@ def send_new_items(ctx: Context):
continue
items = get_received_items(ctx, client.team, client.slot)
if len(items) > client.send_index:
asyncio.create_task(send_msgs(client.socket, [['ReceivedItems', (client.send_index, tuplize_received_items(items)[client.send_index:])]]))
asyncio.create_task(send_msgs(client, [
['ReceivedItems', (client.send_index, tuplize_received_items(items)[client.send_index:])]]))
client.send_index = len(items)
@ -267,7 +280,7 @@ def register_location_checks(ctx: Context, team: int, slot: int, locations):
for client in ctx.clients:
if client.team == team and client.wants_item_notification:
asyncio.create_task(
send_msgs(client.socket, [['ItemFound', (target_item, location, slot)]]))
send_msgs(client, [['ItemFound', (target_item, location, slot)]]))
ctx.location_checks[team, slot] |= set(locations)
send_new_items(ctx)
@ -522,14 +535,14 @@ class ClientMessageProcessor(CommandProcessor):
async def process_client_cmd(ctx: Context, client: Client, cmd, args):
if type(cmd) is not str:
await send_msgs(client.socket, [['InvalidCmd']])
await send_msgs(client, [['InvalidCmd']])
return
if cmd == 'Connect':
if not args or type(args) is not dict or \
'password' not in args or type(args['password']) not in [str, type(None)] or \
'rom' not in args or type(args['rom']) is not list:
await send_msgs(client.socket, [['InvalidArguments', 'Connect']])
await send_msgs(client, [['InvalidArguments', 'Connect']])
return
errors = set()
@ -548,7 +561,7 @@ async def process_client_cmd(ctx: Context, client: Client, cmd, args):
client.slot = slot
if errors:
await send_msgs(client.socket, [['ConnectionRefused', list(errors)]])
await send_msgs(client, [['ConnectionRefused', list(errors)]])
else:
client.auth = True
client.version = args.get('version', Client.version)
@ -559,7 +572,7 @@ async def process_client_cmd(ctx: Context, client: Client, cmd, args):
if items:
reply.append(['ReceivedItems', (0, tuplize_received_items(items))])
client.send_index = len(items)
await send_msgs(client.socket, reply)
await send_msgs(client, reply)
await on_client_joined(ctx, client)
if client.auth:
@ -567,22 +580,22 @@ async def process_client_cmd(ctx: Context, client: Client, cmd, args):
items = get_received_items(ctx, client.team, client.slot)
if items:
client.send_index = len(items)
await send_msgs(client.socket, [['ReceivedItems', (0, tuplize_received_items(items))]])
await send_msgs(client, [['ReceivedItems', (0, tuplize_received_items(items))]])
elif cmd == 'LocationChecks':
if type(args) is not list:
await send_msgs(client.socket, [['InvalidArguments', 'LocationChecks']])
await send_msgs(client, [['InvalidArguments', 'LocationChecks']])
return
register_location_checks(ctx, client.team, client.slot, args)
elif cmd == 'LocationScouts':
if type(args) is not list:
await send_msgs(client.socket, [['InvalidArguments', 'LocationScouts']])
await send_msgs(client, [['InvalidArguments', 'LocationScouts']])
return
locs = []
for location in args:
if type(location) is not int or 0 >= location > len(Regions.location_table):
await send_msgs(client.socket, [['InvalidArguments', 'LocationScouts']])
await send_msgs(client, [['InvalidArguments', 'LocationScouts']])
return
loc_name = list(Regions.location_table.keys())[location - 1]
target_item, target_player = ctx.locations[(Regions.location_table[loc_name][0], client.slot)]
@ -595,17 +608,17 @@ async def process_client_cmd(ctx: Context, client: Client, cmd, args):
locs.append([loc_name, location, target_item, target_player])
# logging.info(f"{client.name} in team {client.team+1} scouted {', '.join([l[0] for l in locs])}")
await send_msgs(client.socket, [['LocationInfo', [l[1:] for l in locs]]])
await send_msgs(client, [['LocationInfo', [l[1:] for l in locs]]])
elif cmd == 'UpdateTags':
if not args or type(args) is not list:
await send_msgs(client.socket, [['InvalidArguments', 'UpdateTags']])
await send_msgs(client, [['InvalidArguments', 'UpdateTags']])
return
client.tags = args
if cmd == 'Say':
if type(args) is not str or not args.isprintable():
await send_msgs(client.socket, [['InvalidArguments', 'Say']])
await send_msgs(client, [['InvalidArguments', 'Say']])
return
notify_all(ctx, client.name + ': ' + args)