disconnect on send failure
This commit is contained in:
parent
b676d4131f
commit
3840832f05
|
@ -9,6 +9,7 @@ import zlib
|
|||
import collections
|
||||
import typing
|
||||
import inspect
|
||||
import weakref
|
||||
|
||||
import ModuleUpdate
|
||||
|
||||
|
@ -41,6 +42,13 @@ class Client:
|
|||
self.tags = []
|
||||
self.version = [0, 0, 0]
|
||||
self.messageprocessor = ClientMessageProcessor(ctx, self)
|
||||
self.ctx = weakref.ref(ctx)
|
||||
|
||||
async def disconnect(self):
|
||||
ctx = self.ctx()
|
||||
if ctx:
|
||||
await on_client_disconnected(ctx, self)
|
||||
ctx.clients.remove(self)
|
||||
|
||||
@property
|
||||
def wants_item_notification(self):
|
||||
|
@ -96,20 +104,25 @@ class Context:
|
|||
f'for {len(received_items)} players')
|
||||
|
||||
|
||||
async def send_msgs(websocket, msgs):
|
||||
async def send_msgs(client: Client, msgs):
|
||||
websocket = client.socket
|
||||
if not websocket or not websocket.open or websocket.closed:
|
||||
return
|
||||
await websocket.send(json.dumps(msgs))
|
||||
try:
|
||||
await websocket.send(json.dumps(msgs))
|
||||
except websockets.ConnectionClosed:
|
||||
logging.exception("Exception during send_msgs")
|
||||
await client.disconnect()
|
||||
|
||||
def broadcast_all(ctx : Context, msgs):
|
||||
for client in ctx.clients:
|
||||
if client.auth:
|
||||
asyncio.create_task(send_msgs(client.socket, msgs))
|
||||
asyncio.create_task(send_msgs(client, msgs))
|
||||
|
||||
def broadcast_team(ctx : Context, team, msgs):
|
||||
for client in ctx.clients:
|
||||
if client.auth and client.team == team:
|
||||
asyncio.create_task(send_msgs(client.socket, msgs))
|
||||
asyncio.create_task(send_msgs(client, msgs))
|
||||
|
||||
def notify_all(ctx : Context, text):
|
||||
logging.info("Notice (all): %s" % text)
|
||||
|
@ -125,7 +138,7 @@ def notify_client(client: Client, text: str):
|
|||
if not client.auth:
|
||||
return
|
||||
logging.info("Notice (Player %s in team %d): %s" % (client.name, client.team + 1, text))
|
||||
asyncio.create_task(send_msgs(client.socket, [['Print', text]]))
|
||||
asyncio.create_task(send_msgs(client, [['Print', text]]))
|
||||
|
||||
|
||||
# separated out, due to compatibilty between clients
|
||||
|
@ -140,7 +153,7 @@ def notify_hints(ctx: Context, team: int, hints: typing.List[Utils.Hint]):
|
|||
payload = cmd
|
||||
else:
|
||||
payload = texts
|
||||
asyncio.create_task(send_msgs(client.socket, payload))
|
||||
asyncio.create_task(send_msgs(client, payload))
|
||||
|
||||
async def server(websocket, path, ctx: Context):
|
||||
client = Client(websocket, ctx)
|
||||
|
@ -161,11 +174,10 @@ async def server(websocket, path, ctx: Context):
|
|||
if not isinstance(e, websockets.WebSocketException):
|
||||
logging.exception(e)
|
||||
finally:
|
||||
await on_client_disconnected(ctx, client)
|
||||
ctx.clients.remove(client)
|
||||
await client.disconnect()
|
||||
|
||||
async def on_client_connected(ctx: Context, client: Client):
|
||||
await send_msgs(client.socket, [['RoomInfo', {
|
||||
await send_msgs(client, [['RoomInfo', {
|
||||
'password': ctx.password is not None,
|
||||
'players': [(client.team, client.slot, client.name) for client in ctx.clients if client.auth],
|
||||
# tags are for additional features in the communication.
|
||||
|
@ -230,7 +242,8 @@ def send_new_items(ctx: Context):
|
|||
continue
|
||||
items = get_received_items(ctx, client.team, client.slot)
|
||||
if len(items) > client.send_index:
|
||||
asyncio.create_task(send_msgs(client.socket, [['ReceivedItems', (client.send_index, tuplize_received_items(items)[client.send_index:])]]))
|
||||
asyncio.create_task(send_msgs(client, [
|
||||
['ReceivedItems', (client.send_index, tuplize_received_items(items)[client.send_index:])]]))
|
||||
client.send_index = len(items)
|
||||
|
||||
|
||||
|
@ -267,7 +280,7 @@ def register_location_checks(ctx: Context, team: int, slot: int, locations):
|
|||
for client in ctx.clients:
|
||||
if client.team == team and client.wants_item_notification:
|
||||
asyncio.create_task(
|
||||
send_msgs(client.socket, [['ItemFound', (target_item, location, slot)]]))
|
||||
send_msgs(client, [['ItemFound', (target_item, location, slot)]]))
|
||||
ctx.location_checks[team, slot] |= set(locations)
|
||||
send_new_items(ctx)
|
||||
|
||||
|
@ -522,14 +535,14 @@ class ClientMessageProcessor(CommandProcessor):
|
|||
|
||||
async def process_client_cmd(ctx: Context, client: Client, cmd, args):
|
||||
if type(cmd) is not str:
|
||||
await send_msgs(client.socket, [['InvalidCmd']])
|
||||
await send_msgs(client, [['InvalidCmd']])
|
||||
return
|
||||
|
||||
if cmd == 'Connect':
|
||||
if not args or type(args) is not dict or \
|
||||
'password' not in args or type(args['password']) not in [str, type(None)] or \
|
||||
'rom' not in args or type(args['rom']) is not list:
|
||||
await send_msgs(client.socket, [['InvalidArguments', 'Connect']])
|
||||
await send_msgs(client, [['InvalidArguments', 'Connect']])
|
||||
return
|
||||
|
||||
errors = set()
|
||||
|
@ -548,7 +561,7 @@ async def process_client_cmd(ctx: Context, client: Client, cmd, args):
|
|||
client.slot = slot
|
||||
|
||||
if errors:
|
||||
await send_msgs(client.socket, [['ConnectionRefused', list(errors)]])
|
||||
await send_msgs(client, [['ConnectionRefused', list(errors)]])
|
||||
else:
|
||||
client.auth = True
|
||||
client.version = args.get('version', Client.version)
|
||||
|
@ -559,7 +572,7 @@ async def process_client_cmd(ctx: Context, client: Client, cmd, args):
|
|||
if items:
|
||||
reply.append(['ReceivedItems', (0, tuplize_received_items(items))])
|
||||
client.send_index = len(items)
|
||||
await send_msgs(client.socket, reply)
|
||||
await send_msgs(client, reply)
|
||||
await on_client_joined(ctx, client)
|
||||
|
||||
if client.auth:
|
||||
|
@ -567,22 +580,22 @@ async def process_client_cmd(ctx: Context, client: Client, cmd, args):
|
|||
items = get_received_items(ctx, client.team, client.slot)
|
||||
if items:
|
||||
client.send_index = len(items)
|
||||
await send_msgs(client.socket, [['ReceivedItems', (0, tuplize_received_items(items))]])
|
||||
await send_msgs(client, [['ReceivedItems', (0, tuplize_received_items(items))]])
|
||||
|
||||
elif cmd == 'LocationChecks':
|
||||
if type(args) is not list:
|
||||
await send_msgs(client.socket, [['InvalidArguments', 'LocationChecks']])
|
||||
await send_msgs(client, [['InvalidArguments', 'LocationChecks']])
|
||||
return
|
||||
register_location_checks(ctx, client.team, client.slot, args)
|
||||
|
||||
elif cmd == 'LocationScouts':
|
||||
if type(args) is not list:
|
||||
await send_msgs(client.socket, [['InvalidArguments', 'LocationScouts']])
|
||||
await send_msgs(client, [['InvalidArguments', 'LocationScouts']])
|
||||
return
|
||||
locs = []
|
||||
for location in args:
|
||||
if type(location) is not int or 0 >= location > len(Regions.location_table):
|
||||
await send_msgs(client.socket, [['InvalidArguments', 'LocationScouts']])
|
||||
await send_msgs(client, [['InvalidArguments', 'LocationScouts']])
|
||||
return
|
||||
loc_name = list(Regions.location_table.keys())[location - 1]
|
||||
target_item, target_player = ctx.locations[(Regions.location_table[loc_name][0], client.slot)]
|
||||
|
@ -595,17 +608,17 @@ async def process_client_cmd(ctx: Context, client: Client, cmd, args):
|
|||
locs.append([loc_name, location, target_item, target_player])
|
||||
|
||||
# logging.info(f"{client.name} in team {client.team+1} scouted {', '.join([l[0] for l in locs])}")
|
||||
await send_msgs(client.socket, [['LocationInfo', [l[1:] for l in locs]]])
|
||||
await send_msgs(client, [['LocationInfo', [l[1:] for l in locs]]])
|
||||
|
||||
elif cmd == 'UpdateTags':
|
||||
if not args or type(args) is not list:
|
||||
await send_msgs(client.socket, [['InvalidArguments', 'UpdateTags']])
|
||||
await send_msgs(client, [['InvalidArguments', 'UpdateTags']])
|
||||
return
|
||||
client.tags = args
|
||||
|
||||
if cmd == 'Say':
|
||||
if type(args) is not str or not args.isprintable():
|
||||
await send_msgs(client.socket, [['InvalidArguments', 'Say']])
|
||||
await send_msgs(client, [['InvalidArguments', 'Say']])
|
||||
return
|
||||
|
||||
notify_all(ctx, client.name + ': ' + args)
|
||||
|
|
Loading…
Reference in New Issue