disconnect on send failure
This commit is contained in:
parent
b676d4131f
commit
3840832f05
|
@ -9,6 +9,7 @@ import zlib
|
||||||
import collections
|
import collections
|
||||||
import typing
|
import typing
|
||||||
import inspect
|
import inspect
|
||||||
|
import weakref
|
||||||
|
|
||||||
import ModuleUpdate
|
import ModuleUpdate
|
||||||
|
|
||||||
|
@ -41,6 +42,13 @@ class Client:
|
||||||
self.tags = []
|
self.tags = []
|
||||||
self.version = [0, 0, 0]
|
self.version = [0, 0, 0]
|
||||||
self.messageprocessor = ClientMessageProcessor(ctx, self)
|
self.messageprocessor = ClientMessageProcessor(ctx, self)
|
||||||
|
self.ctx = weakref.ref(ctx)
|
||||||
|
|
||||||
|
async def disconnect(self):
|
||||||
|
ctx = self.ctx()
|
||||||
|
if ctx:
|
||||||
|
await on_client_disconnected(ctx, self)
|
||||||
|
ctx.clients.remove(self)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def wants_item_notification(self):
|
def wants_item_notification(self):
|
||||||
|
@ -96,20 +104,25 @@ class Context:
|
||||||
f'for {len(received_items)} players')
|
f'for {len(received_items)} players')
|
||||||
|
|
||||||
|
|
||||||
async def send_msgs(websocket, msgs):
|
async def send_msgs(client: Client, msgs):
|
||||||
|
websocket = client.socket
|
||||||
if not websocket or not websocket.open or websocket.closed:
|
if not websocket or not websocket.open or websocket.closed:
|
||||||
return
|
return
|
||||||
await websocket.send(json.dumps(msgs))
|
try:
|
||||||
|
await websocket.send(json.dumps(msgs))
|
||||||
|
except websockets.ConnectionClosed:
|
||||||
|
logging.exception("Exception during send_msgs")
|
||||||
|
await client.disconnect()
|
||||||
|
|
||||||
def broadcast_all(ctx : Context, msgs):
|
def broadcast_all(ctx : Context, msgs):
|
||||||
for client in ctx.clients:
|
for client in ctx.clients:
|
||||||
if client.auth:
|
if client.auth:
|
||||||
asyncio.create_task(send_msgs(client.socket, msgs))
|
asyncio.create_task(send_msgs(client, msgs))
|
||||||
|
|
||||||
def broadcast_team(ctx : Context, team, msgs):
|
def broadcast_team(ctx : Context, team, msgs):
|
||||||
for client in ctx.clients:
|
for client in ctx.clients:
|
||||||
if client.auth and client.team == team:
|
if client.auth and client.team == team:
|
||||||
asyncio.create_task(send_msgs(client.socket, msgs))
|
asyncio.create_task(send_msgs(client, msgs))
|
||||||
|
|
||||||
def notify_all(ctx : Context, text):
|
def notify_all(ctx : Context, text):
|
||||||
logging.info("Notice (all): %s" % text)
|
logging.info("Notice (all): %s" % text)
|
||||||
|
@ -125,7 +138,7 @@ def notify_client(client: Client, text: str):
|
||||||
if not client.auth:
|
if not client.auth:
|
||||||
return
|
return
|
||||||
logging.info("Notice (Player %s in team %d): %s" % (client.name, client.team + 1, text))
|
logging.info("Notice (Player %s in team %d): %s" % (client.name, client.team + 1, text))
|
||||||
asyncio.create_task(send_msgs(client.socket, [['Print', text]]))
|
asyncio.create_task(send_msgs(client, [['Print', text]]))
|
||||||
|
|
||||||
|
|
||||||
# separated out, due to compatibilty between clients
|
# separated out, due to compatibilty between clients
|
||||||
|
@ -140,7 +153,7 @@ def notify_hints(ctx: Context, team: int, hints: typing.List[Utils.Hint]):
|
||||||
payload = cmd
|
payload = cmd
|
||||||
else:
|
else:
|
||||||
payload = texts
|
payload = texts
|
||||||
asyncio.create_task(send_msgs(client.socket, payload))
|
asyncio.create_task(send_msgs(client, payload))
|
||||||
|
|
||||||
async def server(websocket, path, ctx: Context):
|
async def server(websocket, path, ctx: Context):
|
||||||
client = Client(websocket, ctx)
|
client = Client(websocket, ctx)
|
||||||
|
@ -161,11 +174,10 @@ async def server(websocket, path, ctx: Context):
|
||||||
if not isinstance(e, websockets.WebSocketException):
|
if not isinstance(e, websockets.WebSocketException):
|
||||||
logging.exception(e)
|
logging.exception(e)
|
||||||
finally:
|
finally:
|
||||||
await on_client_disconnected(ctx, client)
|
await client.disconnect()
|
||||||
ctx.clients.remove(client)
|
|
||||||
|
|
||||||
async def on_client_connected(ctx: Context, client: Client):
|
async def on_client_connected(ctx: Context, client: Client):
|
||||||
await send_msgs(client.socket, [['RoomInfo', {
|
await send_msgs(client, [['RoomInfo', {
|
||||||
'password': ctx.password is not None,
|
'password': ctx.password is not None,
|
||||||
'players': [(client.team, client.slot, client.name) for client in ctx.clients if client.auth],
|
'players': [(client.team, client.slot, client.name) for client in ctx.clients if client.auth],
|
||||||
# tags are for additional features in the communication.
|
# tags are for additional features in the communication.
|
||||||
|
@ -230,7 +242,8 @@ def send_new_items(ctx: Context):
|
||||||
continue
|
continue
|
||||||
items = get_received_items(ctx, client.team, client.slot)
|
items = get_received_items(ctx, client.team, client.slot)
|
||||||
if len(items) > client.send_index:
|
if len(items) > client.send_index:
|
||||||
asyncio.create_task(send_msgs(client.socket, [['ReceivedItems', (client.send_index, tuplize_received_items(items)[client.send_index:])]]))
|
asyncio.create_task(send_msgs(client, [
|
||||||
|
['ReceivedItems', (client.send_index, tuplize_received_items(items)[client.send_index:])]]))
|
||||||
client.send_index = len(items)
|
client.send_index = len(items)
|
||||||
|
|
||||||
|
|
||||||
|
@ -267,7 +280,7 @@ def register_location_checks(ctx: Context, team: int, slot: int, locations):
|
||||||
for client in ctx.clients:
|
for client in ctx.clients:
|
||||||
if client.team == team and client.wants_item_notification:
|
if client.team == team and client.wants_item_notification:
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
send_msgs(client.socket, [['ItemFound', (target_item, location, slot)]]))
|
send_msgs(client, [['ItemFound', (target_item, location, slot)]]))
|
||||||
ctx.location_checks[team, slot] |= set(locations)
|
ctx.location_checks[team, slot] |= set(locations)
|
||||||
send_new_items(ctx)
|
send_new_items(ctx)
|
||||||
|
|
||||||
|
@ -522,14 +535,14 @@ class ClientMessageProcessor(CommandProcessor):
|
||||||
|
|
||||||
async def process_client_cmd(ctx: Context, client: Client, cmd, args):
|
async def process_client_cmd(ctx: Context, client: Client, cmd, args):
|
||||||
if type(cmd) is not str:
|
if type(cmd) is not str:
|
||||||
await send_msgs(client.socket, [['InvalidCmd']])
|
await send_msgs(client, [['InvalidCmd']])
|
||||||
return
|
return
|
||||||
|
|
||||||
if cmd == 'Connect':
|
if cmd == 'Connect':
|
||||||
if not args or type(args) is not dict or \
|
if not args or type(args) is not dict or \
|
||||||
'password' not in args or type(args['password']) not in [str, type(None)] or \
|
'password' not in args or type(args['password']) not in [str, type(None)] or \
|
||||||
'rom' not in args or type(args['rom']) is not list:
|
'rom' not in args or type(args['rom']) is not list:
|
||||||
await send_msgs(client.socket, [['InvalidArguments', 'Connect']])
|
await send_msgs(client, [['InvalidArguments', 'Connect']])
|
||||||
return
|
return
|
||||||
|
|
||||||
errors = set()
|
errors = set()
|
||||||
|
@ -548,7 +561,7 @@ async def process_client_cmd(ctx: Context, client: Client, cmd, args):
|
||||||
client.slot = slot
|
client.slot = slot
|
||||||
|
|
||||||
if errors:
|
if errors:
|
||||||
await send_msgs(client.socket, [['ConnectionRefused', list(errors)]])
|
await send_msgs(client, [['ConnectionRefused', list(errors)]])
|
||||||
else:
|
else:
|
||||||
client.auth = True
|
client.auth = True
|
||||||
client.version = args.get('version', Client.version)
|
client.version = args.get('version', Client.version)
|
||||||
|
@ -559,7 +572,7 @@ async def process_client_cmd(ctx: Context, client: Client, cmd, args):
|
||||||
if items:
|
if items:
|
||||||
reply.append(['ReceivedItems', (0, tuplize_received_items(items))])
|
reply.append(['ReceivedItems', (0, tuplize_received_items(items))])
|
||||||
client.send_index = len(items)
|
client.send_index = len(items)
|
||||||
await send_msgs(client.socket, reply)
|
await send_msgs(client, reply)
|
||||||
await on_client_joined(ctx, client)
|
await on_client_joined(ctx, client)
|
||||||
|
|
||||||
if client.auth:
|
if client.auth:
|
||||||
|
@ -567,22 +580,22 @@ async def process_client_cmd(ctx: Context, client: Client, cmd, args):
|
||||||
items = get_received_items(ctx, client.team, client.slot)
|
items = get_received_items(ctx, client.team, client.slot)
|
||||||
if items:
|
if items:
|
||||||
client.send_index = len(items)
|
client.send_index = len(items)
|
||||||
await send_msgs(client.socket, [['ReceivedItems', (0, tuplize_received_items(items))]])
|
await send_msgs(client, [['ReceivedItems', (0, tuplize_received_items(items))]])
|
||||||
|
|
||||||
elif cmd == 'LocationChecks':
|
elif cmd == 'LocationChecks':
|
||||||
if type(args) is not list:
|
if type(args) is not list:
|
||||||
await send_msgs(client.socket, [['InvalidArguments', 'LocationChecks']])
|
await send_msgs(client, [['InvalidArguments', 'LocationChecks']])
|
||||||
return
|
return
|
||||||
register_location_checks(ctx, client.team, client.slot, args)
|
register_location_checks(ctx, client.team, client.slot, args)
|
||||||
|
|
||||||
elif cmd == 'LocationScouts':
|
elif cmd == 'LocationScouts':
|
||||||
if type(args) is not list:
|
if type(args) is not list:
|
||||||
await send_msgs(client.socket, [['InvalidArguments', 'LocationScouts']])
|
await send_msgs(client, [['InvalidArguments', 'LocationScouts']])
|
||||||
return
|
return
|
||||||
locs = []
|
locs = []
|
||||||
for location in args:
|
for location in args:
|
||||||
if type(location) is not int or 0 >= location > len(Regions.location_table):
|
if type(location) is not int or 0 >= location > len(Regions.location_table):
|
||||||
await send_msgs(client.socket, [['InvalidArguments', 'LocationScouts']])
|
await send_msgs(client, [['InvalidArguments', 'LocationScouts']])
|
||||||
return
|
return
|
||||||
loc_name = list(Regions.location_table.keys())[location - 1]
|
loc_name = list(Regions.location_table.keys())[location - 1]
|
||||||
target_item, target_player = ctx.locations[(Regions.location_table[loc_name][0], client.slot)]
|
target_item, target_player = ctx.locations[(Regions.location_table[loc_name][0], client.slot)]
|
||||||
|
@ -595,17 +608,17 @@ async def process_client_cmd(ctx: Context, client: Client, cmd, args):
|
||||||
locs.append([loc_name, location, target_item, target_player])
|
locs.append([loc_name, location, target_item, target_player])
|
||||||
|
|
||||||
# logging.info(f"{client.name} in team {client.team+1} scouted {', '.join([l[0] for l in locs])}")
|
# logging.info(f"{client.name} in team {client.team+1} scouted {', '.join([l[0] for l in locs])}")
|
||||||
await send_msgs(client.socket, [['LocationInfo', [l[1:] for l in locs]]])
|
await send_msgs(client, [['LocationInfo', [l[1:] for l in locs]]])
|
||||||
|
|
||||||
elif cmd == 'UpdateTags':
|
elif cmd == 'UpdateTags':
|
||||||
if not args or type(args) is not list:
|
if not args or type(args) is not list:
|
||||||
await send_msgs(client.socket, [['InvalidArguments', 'UpdateTags']])
|
await send_msgs(client, [['InvalidArguments', 'UpdateTags']])
|
||||||
return
|
return
|
||||||
client.tags = args
|
client.tags = args
|
||||||
|
|
||||||
if cmd == 'Say':
|
if cmd == 'Say':
|
||||||
if type(args) is not str or not args.isprintable():
|
if type(args) is not str or not args.isprintable():
|
||||||
await send_msgs(client.socket, [['InvalidArguments', 'Say']])
|
await send_msgs(client, [['InvalidArguments', 'Say']])
|
||||||
return
|
return
|
||||||
|
|
||||||
notify_all(ctx, client.name + ': ' + args)
|
notify_all(ctx, client.name + ': ' + args)
|
||||||
|
|
Loading…
Reference in New Issue