from __future__ import annotations

import asyncio
import collections
import datetime
import functools
import logging
import multiprocessing
import pickle
import random
import socket
import threading
import time
import typing
import sys

import websockets
from pony.orm import commit, db_session, select

import Utils

from MultiServer import Context, server, auto_shutdown, ServerCommandProcessor, ClientMessageProcessor, load_server_cert
from Utils import restricted_loads, cache_argsless
from .locker import Locker
from .models import Command, GameDataPackage, Room, db


class CustomClientMessageProcessor(ClientMessageProcessor):
    # Client command processor for WebHost rooms; adds the /video command on top of the stock commands.
    ctx: WebHostContext

    def _cmd_video(self, platform: str, user: str):
        """Set a link for your name in the WebHostLib tracker pointing to a video stream.
        Currently, only YouTube and Twitch platforms are supported.
        """
        # NOTE: the docstring above is kept verbatim since command processors may show it as help text.
        platform_key = platform.lower()
        if platform_key.startswith("t"):  # twitch
            service = "Twitch"
            confirmation = f"Registered Twitch Stream https://www.twitch.tv/{user}"
        elif platform_key.startswith("y"):  # youtube
            service = "Youtube"
            confirmation = f"Registered Youtube Stream for {user}"
        else:
            # unknown platform: report failure without touching the video registry
            return False

        # key the stream by (team, slot) so the tracker can attach it to the right player
        self.ctx.video[self.client.team, self.client.slot] = service, user
        self.ctx.save()
        self.output(confirmation)
        return True


# inject
# Swap MultiServer's module-level `client_message_processor` for the WebHost subclass defined above, so
# the /video command is available to clients of rooms hosted by this module (presumably MultiServer
# instantiates `client_message_processor` per connected client -- confirm in MultiServer).
import MultiServer

MultiServer.client_message_processor = CustomClientMessageProcessor
# Only drops the local name binding; the patched module stays loaded in sys.modules.
del MultiServer


class DBCommandProcessor(ServerCommandProcessor):
    """Server command processor used for commands read from the database; output goes to the room logger."""

    def output(self, text: str):
        room_logger = self.ctx.logger
        room_logger.info(text)


class WebHostContext(Context):
    """MultiServer ``Context`` for a single room hosted by the WebHost.

    Game data is taken from a shared ``static_server_data`` dict instead of importing the worlds system,
    and save data as well as server commands are exchanged through the database.
    """
    room_id: int

    def __init__(self, static_server_data: dict, logger: logging.Logger):
        # static server data is used during _load_game_data to load required data,
        # without needing to import worlds system, which takes quite a bit of memory
        self.static_server_data = static_server_data
        super(WebHostContext, self).__init__("", 0, "", "", 1,
                                             40, True, "enabled", "enabled",
                                             "enabled", 0, 2, logger=logger)
        # only needed while the base __init__ runs (presumably it calls _load_game_data -- confirm)
        del self.static_server_data
        # captured so the DB-polling thread can schedule work back onto this loop
        self.main_loop = asyncio.get_running_loop()
        self.video = {}
        self.tags = ["AP", "WebHost"]

    def _load_game_data(self):
        """Copy every entry of the shared static server data onto this context as an attribute.

        ``non_hintable_names`` is wrapped in ``defaultdict(frozenset)`` so games missing from it
        yield an empty blacklist.
        """
        for key, value in self.static_server_data.items():
            # NOTE: attributes are mutable and shared, so they will have to be copied before being modified
            setattr(self, key, value)
        self.non_hintable_names = collections.defaultdict(frozenset, self.non_hintable_names)

    def listen_to_db_commands(self):
        """Poll the database every 5 seconds for this room's ``Command`` rows until ``exit_event`` is set.

        Runs on a background thread (started by ``init_save``). Each command's text is handed to a
        ``DBCommandProcessor`` on ``main_loop`` via ``call_soon_threadsafe``; the row is then deleted.
        """
        cmdprocessor = DBCommandProcessor(self)

        while not self.exit_event.is_set():
            with db_session:
                commands = select(command for command in Command if command.room.id == self.room_id)
                if commands:
                    for command in commands:
                        # hand execution to the event loop thread; this thread only does the DB work
                        self.main_loop.call_soon_threadsafe(cmdprocessor, command.commandtext)
                        command.delete()
                    commit()
            time.sleep(5)

    @db_session
    def load(self, room_id: int):
        """Load the room's multidata and choose the port to host on.

        The room's last recorded port is reused when set, otherwise a random port is picked. Data packages
        whose checksum matches the static data are served from the shared static dicts; custom packages are
        fetched from ``GameDataPackage`` and passed on to ``_load``.

        :param room_id: id of the ``Room`` row to host.
        :return: the result of ``self._load(multidata, game_data_packages, True)``.
        """
        self.room_id = room_id
        room = Room.get(id=room_id)
        if room.last_port:
            self.port = room.last_port
        else:
            self.port = get_random_port()

        multidata = self.decompress(room.seed.multidata)
        game_data_packages = {}

        static_gamespackage = self.gamespackage  # this is shared across all rooms
        static_item_name_groups = self.item_name_groups
        static_location_name_groups = self.location_name_groups
        self.gamespackage = {"Archipelago": static_gamespackage.get("Archipelago", {})}  # this may be modified by _load
        self.item_name_groups = {"Archipelago": static_item_name_groups.get("Archipelago", {})}
        self.location_name_groups = {"Archipelago": static_location_name_groups.get("Archipelago", {})}

        # iterate over a snapshot of the keys: entries are deleted from the datapackage inside the loop
        for game in list(multidata.get("datapackage", {})):
            game_data = multidata["datapackage"][game]
            if "checksum" in game_data:
                if static_gamespackage.get(game, {}).get("checksum") == game_data["checksum"]:
                    # non-custom. remove from multidata and use static data
                    # games package could be dropped from static data once all rooms embed data package
                    del multidata["datapackage"][game]
                else:
                    row = GameDataPackage.get(checksum=game_data["checksum"])
                    if row:  # None if rolled on >= 0.3.9 but uploaded to <= 0.3.8. multidata should be complete
                        game_data_packages[game] = Utils.restricted_loads(row.data)
                        continue
                    else:
                        self.logger.warning(f"Did not find game_data_package for {game}: {game_data['checksum']}")
            # games without a checksum, matching the static data, or missing from the DB use the static entries
            self.gamespackage[game] = static_gamespackage.get(game, {})
            self.item_name_groups[game] = static_item_name_groups.get(game, {})
            self.location_name_groups[game] = static_location_name_groups.get(game, {})

        if not game_data_packages:
            # all static -> use the static dicts directly
            self.gamespackage = static_gamespackage
            self.item_name_groups = static_item_name_groups
            self.location_name_groups = static_location_name_groups
        return self._load(multidata, game_data_packages, True)

    @db_session
    def init_save(self, enabled: bool = True):
        """Restore the room's stored save (if any), start async saving, and start the DB command listener.

        :param enabled: when False, no save is loaded and async saving is not started; the command
            listener thread is started either way.
        """
        self.saving = enabled
        if self.saving:
            savegame_data = Room.get(id=self.room_id).multisave
            if savegame_data:
                # reuse the blob already fetched instead of looking the Room up a second time
                self.set_save(restricted_loads(savegame_data))
            self._start_async_saving(atexit_save=False)
        threading.Thread(target=self.listen_to_db_commands, daemon=True).start()

    @db_session
    def _save(self, exit_save: bool = False) -> bool:
        """Write the current save state to this room's ``Room.multisave``.

        :param exit_save: True for the final save during shutdown; such saves do not touch ``last_activity``.
        :return: always True.
        """
        room = Room.get(id=self.room_id)
        room.multisave = pickle.dumps(self.get_save())
        # saving only occurs on activity, so we can "abuse" this information to mark this as last_activity
        if not exit_save:  # we don't want to count a shutdown as activity, which would restart the server again
            room.last_activity = datetime.datetime.utcnow()
        return True

    def get_save(self) -> dict:
        """Return the base save data extended with the registered video links under the ``"video"`` key."""
        d = super(WebHostContext, self).get_save()
        d["video"] = [(tuple(playerslot), videodata) for playerslot, videodata in self.video.items()]
        return d


def get_random_port():
    """Return a random port from the dynamic/private range 49152-65535 (both ends inclusive)."""
    lowest, highest = 49152, 65535
    return random.randrange(lowest, highest + 1)


@cache_argsless
def get_static_server_data() -> dict:
    """Build the game data that room servers need, so they never have to import the worlds system.

    Imports ``worlds`` here (and only here) and returns a dict with the keys ``non_hintable_names``,
    ``gamespackage`` (the network data package per game, minus the name-group entries),
    ``item_name_groups`` and ``location_name_groups``. The result is cached by ``cache_argsless``.
    """
    import worlds
    world_types = worlds.AutoWorldRegister.world_types
    # group tables are served separately, so they are stripped from each game package
    group_keys = ("item_name_groups", "location_name_groups")

    non_hintable_names = {}
    item_name_groups = {}
    location_name_groups = {}
    for world_name, world in world_types.items():
        non_hintable_names[world_name] = world.hint_blacklist
        item_name_groups[world_name] = world.item_name_groups
        location_name_groups[world_name] = world.location_name_groups

    gamespackage = {}
    for world_name, game_package in worlds.network_data_package["games"].items():
        gamespackage[world_name] = {
            entry_key: entry_value
            for entry_key, entry_value in game_package.items()
            if entry_key not in group_keys
        }

    return {
        "non_hintable_names": non_hintable_names,
        "gamespackage": gamespackage,
        "item_name_groups": item_name_groups,
        "location_name_groups": location_name_groups,
    }


def set_up_logging(room_id) -> logging.Logger:
    """Return the logger for ``room_id``, writing to ``<user_path("logs")>/<room_id>.txt``.

    Handlers already attached to the logger are removed and closed first, then a single file handler
    (append mode, ``utf-8-sig`` encoding, ``[time]: message`` format) is attached at INFO level.
    """
    import os
    room_logger = logging.getLogger(f"RoomLogger {room_id}")

    # Normally there is nothing here; drop anything a previous run of this room may have left attached.
    for stale_handler in list(room_logger.handlers):
        room_logger.removeHandler(stale_handler)
        stale_handler.close()

    log_file = os.path.join(Utils.user_path("logs"), f"{room_id}.txt")
    file_handler = logging.FileHandler(log_file, "a", encoding="utf-8-sig")
    file_handler.setFormatter(logging.Formatter("[%(asctime)s]: %(message)s"))
    room_logger.setLevel(logging.INFO)
    room_logger.addHandler(file_handler)
    return room_logger


def run_server_process(name: str, ponyconfig: dict, static_server_data: dict,
                       cert_file: typing.Optional[str], cert_key_file: typing.Optional[str],
                       host: str, rooms_to_run: multiprocessing.Queue, rooms_shutting_down: multiprocessing.Queue):
    """Entry point of a room-hosting worker process; runs an asyncio event loop forever.

    A daemon ``Starter`` thread takes room ids from ``rooms_to_run`` and schedules ``start_room`` for each
    on the loop. When a room's server has shut down, its id is put on ``rooms_shutting_down``.

    :param name: worker name, used for logging setup and log messages.
    :param ponyconfig: keyword arguments for ``db.bind``.
    :param static_server_data: output of ``get_static_server_data``, shared by every room of this process.
    :param cert_file: optional TLS certificate; when falsy rooms are served without TLS.
    :param cert_key_file: key file passed to ``load_server_cert`` together with ``cert_file``.
    :param host: host name shown in the "Hosting game at" log line.
    :param rooms_to_run: queue of room ids to start.
    :param rooms_shutting_down: queue that receives room ids after their server has stopped.
    :raises Exception: if the worlds system is already imported in this process.
    """
    Utils.init_logging(name)
    try:
        import resource
    except ModuleNotFoundError:
        pass  # unix only module
    else:
        # Each Server is another file handle, so request as many as we can from the system
        file_limit = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
        # set soft limit to hard limit
        resource.setrlimit(resource.RLIMIT_NOFILE, (file_limit, file_limit))
        del resource, file_limit

    # establish DB connection for multidata and multisave
    db.bind(**ponyconfig)
    db.generate_mapping(check_tables=False)

    if "worlds" in sys.modules:
        raise Exception("Worlds system should not be loaded in the custom server.")

    import gc
    ssl_context = load_server_cert(cert_file, cert_key_file) if cert_file else None
    del cert_file, cert_key_file, ponyconfig
    gc.collect()  # free intermediate objects used during setup

    # asyncio.get_event_loop() without a current loop is deprecated (and an error on newer Pythons),
    # so create and install this thread's loop explicitly.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    async def start_room(room_id):
        """Host one room until its auto-shutdown finishes, then report ``room_id`` on ``rooms_shutting_down``."""
        with Locker(f"RoomLocker {room_id}"):
            # Pre-bound so the handlers below still work if setup fails before these are assigned;
            # otherwise an unbound-name error would mask the real failure.
            logger: typing.Optional[logging.Logger] = None
            ctx: typing.Optional[WebHostContext] = None
            try:
                logger = set_up_logging(room_id)
                ctx = WebHostContext(static_server_data, logger)
                ctx.load(room_id)
                ctx.init_save()
                try:
                    ctx.server = websockets.serve(
                        functools.partial(server, ctx=ctx), ctx.host, ctx.port, ssl=ssl_context)

                    await ctx.server
                except OSError:  # likely port in use
                    ctx.server = websockets.serve(
                        functools.partial(server, ctx=ctx), ctx.host, 0, ssl=ssl_context)

                    await ctx.server
                port = 0
                for wssocket in ctx.server.ws_server.sockets:
                    socketname = wssocket.getsockname()
                    if wssocket.family == socket.AF_INET6:
                        # Prefer IPv4, as most users seem to not have working ipv6 support
                        if not port:
                            port = socketname[1]
                    elif wssocket.family == socket.AF_INET:
                        port = socketname[1]
                if port:
                    ctx.logger.info(f'Hosting game at {host}:{port}')
                    with db_session:
                        room = Room.get(id=ctx.room_id)
                        room.last_port = port
                else:
                    # No exception is being handled here, so .exception() would only append "NoneType: None".
                    ctx.logger.error("Could not determine port. Likely hosting failure.")
                with db_session:
                    ctx.auto_shutdown = Room.get(id=room_id).timeout
                if ctx.saving:
                    # picked up by the shutdown handler at the bottom of run_server_process
                    setattr(asyncio.current_task(), "save", lambda: ctx._save(True))
                ctx.shutdown_task = asyncio.create_task(auto_shutdown(ctx, []))
                await ctx.shutdown_task

            except (KeyboardInterrupt, SystemExit):
                if ctx is not None and ctx.saving:
                    ctx._save()
                    setattr(asyncio.current_task(), "save", None)
            except Exception as e:
                with db_session:
                    room = Room.get(id=room_id)
                    room.last_port = -1
                if logger is not None:
                    logger.exception(e)
                else:
                    # the room logger itself could not be created; fall back to the process log
                    logging.exception(e)
                raise
            else:
                if ctx.saving:
                    ctx._save()
                    setattr(asyncio.current_task(), "save", None)
            finally:
                try:
                    if ctx is not None:
                        ctx.save_dirty = False  # make sure the saving thread does not write to DB after final wakeup
                        ctx.exit_event.set()  # make sure the saving thread stops at some point
                    # NOTE: async saving should probably be an async task and could be merged with shutdown_task
                    with db_session:
                        # ensure the Room does not spin up again on its own, minute of safety buffer
                        room = Room.get(id=room_id)
                        room.last_activity = datetime.datetime.utcnow() - \
                                             datetime.timedelta(minutes=1, seconds=room.timeout)
                    logging.info(f"Shutting down room {room_id} on {name}.")
                finally:
                    await asyncio.sleep(5)
                    rooms_shutting_down.put(room_id)

    class Starter(threading.Thread):
        """Daemon thread that turns room ids from ``rooms_to_run`` into ``start_room`` tasks on ``loop``."""
        # NOTE: run_coroutine_threadsafe actually returns concurrent.futures.Future objects
        _tasks: typing.List[asyncio.Future]

        def __init__(self):
            super().__init__()
            self._tasks = []

        def _done(self, task: asyncio.Future):
            self._tasks.remove(task)
            # re-raise the room's exception, if any; concurrent.futures logs exceptions from done callbacks
            task.result()

        def run(self):
            while True:
                next_room = rooms_to_run.get(block=True, timeout=None)
                gc.collect(0)
                task = asyncio.run_coroutine_threadsafe(start_room(next_room), loop)
                self._tasks.append(task)
                task.add_done_callback(self._done)
                logging.info(f"Starting room {next_room} on {name}.")
                del task  # delete reference to task object

    starter = Starter()
    starter.daemon = True
    starter.start()
    try:
        loop.run_forever()
    finally:
        # save all tasks that want to be saved during shutdown
        for task in asyncio.all_tasks(loop):
            save: typing.Optional[typing.Callable[[], typing.Any]] = getattr(task, "save", None)
            if save:
                save()