from typing import Any, TYPE_CHECKING, Callable

from fastapi.websockets import WebSocket
from sqlalchemy.orm.exc import StaleDataError
from sqlmodel import Session

from database.auth.crud import get_user_from_token
from database.room.crud import change_room_name, change_room_status, serialize_member, check_user_in_room, \
    create_anonymous, create_member, get_member, get_member_from_token, get_member_from_reconnect_code, connect_member, \
    disconnect_member, get_waiter, accept_waiter, leave_room, refuse_waiter
from database.room.models import Room, Member, MemberRead, Waiter, Challenger
from services.websocket import Consumer

if TYPE_CHECKING:
    from routes.room.routes import RoomManager


class RoomConsumer(Consumer):
    """WebSocket consumer bound to one client connection inside one room.

    Methods decorated with ``@Consumer.event`` handle messages received from the
    client; methods decorated with ``@Consumer.sending`` are hooks applied to
    outgoing events of the listed types before they reach this client.
    Group membership (room and parcours groups) is tracked by ``self.manager``.
    """

    def __init__(self, ws: WebSocket, room: Room | None, manager: "RoomManager", db: Session):
        self.room = room
        self.ws = ws
        self.manager = manager
        self.db = db
        # Room member bound to this connection; set by login/join, reset by leave/refused.
        self.member = None
        self.banned = False

    async def connect(self):
        """Accept the socket, closing it again if the room does not exist.

        Returns:
            True when the connection may proceed, False when it was closed.
        """
        await self.ws.accept()
        if self.room is None:
            await self.send_error("Salle introuvable", code=404)
            await self.ws.close()
            return False
        return True

    # WS Utilities

    async def send(self, payload: Any | Callable):
        """Send ``payload`` to this client.

        A callable payload is first called with this connection's member, which
        lets a broadcaster build a per-recipient message.
        """
        if callable(payload):
            payload = payload(self.member)
        return await super().send(payload)

    async def direct_send(self, type: str, payload: Any, code: int | None = None):
        """Send a ``{"type", "data"}`` message straight to this socket.

        Args:
            type: Event type sent to the client.
            payload: JSON-serialisable event data.
            code: Optional status code, added as a top-level ``"code"`` key.
        """
        message = {"type": type, "data": payload}
        if code is not None:
            message["code"] = code
        # Bug fix: the original built this dict but then sent a copy without ``code``.
        await self.ws.send_json(message)

    async def send_to_admin(self, type: str, payload: Any, exclude: bool = False):
        """Send an event to this room's admin through ``manager.send_to_admin``.

        ``exclude`` is accepted for signature compatibility but is not used.
        """
        await self.manager.send_to_admin(self.room.id_code, {'type': type, "data": payload})

    async def send_to(self, type: str, payload: Any, member_id, exclude: bool = False):
        """Send an event to the member ``member_id`` of this room through the manager.

        ``exclude`` is accepted for signature compatibility but is not used.
        """
        await self.manager.send_to(self.room.id_code, member_id, {'type': type, "data": payload})

    async def broadcast(self, type, payload, exclude=False):
        """Broadcast an event to the room group, skipping this connection if ``exclude``."""
        await self.manager.broadcast(
            {"type": type, "data": payload},
            self.room.id_code,
            exclude=[self] if exclude else [],
        )

    def add_to_group(self):
        """Register this connection in the room group."""
        self.manager.add(self.room.id_code, self)

    async def connect_self(self):
        """Mark the bound member as connected and announce it to the other accepted members."""
        if isinstance(self.member, Member):
            connect_member(self.member, self.db)
        await self.manager.broadcast(
            lambda m: {"type": "connect", "data": {
                "member": serialize_member(self.member, admin=m.is_admin, m2=m)}},
            self.room.id_code,
            exclude=[self],
            # Skip waiters and connections that have no member bound (e.g. after leave).
            conditions=[lambda m: m.member is not None and m.member.waiting is not True],
        )

    async def disconnect_self(self):
        """Leave the room group, persist the disconnection and notify the room.

        Accepted members are announced to the other accepted members; waiters
        are only announced to the admin.
        """
        self.manager.remove(self.room.id_code, self)
        if not isinstance(self.member, Member):
            return
        try:
            disconnect_member(self.member, self.db)
        except StaleDataError:
            # The member row no longer matches the database (presumably deleted,
            # e.g. after a ban); there is nobody left to announce.
            return
        if self.member.waiting is False:
            await self.manager.broadcast(
                lambda m: {"type": "disconnect", "data": {
                    "member": serialize_member(self.member, admin=m.is_admin, m2=m)}},
                self.room.id_code,
                exclude=[self],
                conditions=[lambda m: m.member is not None and m.member.waiting is not True],
            )
        else:
            await self.send_to_admin(type="disconnect_waiter", payload={"waiter": serialize_member(self.member)})

    async def loginMember(self, member: Member):
        """Bind ``member`` to this connection if it is an accepted member of this room.

        On success the member is marked connected, announced, added to the room
        group and sent a private ``loggedIn`` event.
        """
        # NOTE(review): a member of another room, or one still waiting, is ignored
        # silently and the client receives no reply — confirm this is intended.
        if member.room_id == self.room.id and member.waiting is False:
            self.member = member
            await self.connect_self()
            self.add_to_group()
            await self.direct_send(type="loggedIn", payload={
                "member": {**serialize_member(self.member, private=True, m2=self.member)}})

    async def send_error(self, msg, code: int = 400):
        """Send an ``error`` event whose payload carries ``msg`` and ``code``."""
        await self.direct_send(type="error", payload={"msg": msg, "code": code})

    def _get_room_waiter(self, waiter_id: str):
        """Return waiter ``waiter_id`` when it belongs to this room, else None."""
        waiter = get_waiter(waiter_id, self.db)
        # Prevent an admin of one room from accepting/refusing another room's waiters.
        # NOTE(review): assumes waiters expose ``room_id`` like Member does; when the
        # attribute is absent the check is skipped.
        if waiter is not None and getattr(waiter, "room_id", self.room.id) != self.room.id:
            return None
        return waiter

    # Conditions

    async def isAdminReceive(self):
        """Event condition: allow only the room admin, otherwise reply with an error."""
        is_admin = self.member is not None and self.member.is_admin is True
        if not is_admin:
            await self.direct_send(type="error", payload={"msg": "Vous n'avez pas la permission de faire ca"})
            return False
        return True

    def isAdmin(self):
        """Sending condition: True when the bound member is the room admin."""
        return self.member is not None and self.member.is_admin is True

    async def isMember(self):
        """Condition: True for an accepted (non-waiting) member.

        Sends an error to the client when no member is bound at all.
        """
        if self.member is None:
            await self.send_error("Vous n'êtes connecté à aucune salle")
            return False
        return self.member.waiting is False

    def isWaiter(self):
        """Sending condition: True when the bound member is still waiting for approval."""
        return self.member is not None and self.member.waiting is True

    # Received Events

    @Consumer.event('login')
    async def login(self, token: str | None = None, reconnect_code: str | None = None):
        """Log an existing room member back in, by member token or reconnect code.

        When both are supplied the token is used.
        """
        if reconnect_code is None and token is None:
            await self.direct_send(type="error", payload={"msg": "Veuillez spécifier une méthode de connection"})
            return
        if token is not None:
            member = get_member_from_token(token, self.room.id, self.db)
            # get_member_from_token signals an expired token with False.
            if member is False:
                await self.send_error("Token expired", code=422)
                return
            if member is None:
                await self.send_error("Utilisateur introuvable dans cette salle", code=401)
                return
        else:
            member = get_member_from_reconnect_code(reconnect_code, self.room.id, db=self.db)
            if member is None:
                await self.send_error("Utilisateur introuvable dans cette salle", code=401)
                return
        await self.loginMember(member)

    @Consumer.event('join')
    async def join(self, token: str | None = None, username: str | None = None):
        """Join the room as an authenticated user (``token``) or as an anonymous ``username``.

        In a private room the new member is put on the waiting list and the admin
        is notified; in a public room the member is accepted immediately.
        """
        if token is not None:
            user = get_user_from_token(token, self.db)
            if user is None:
                await self.send_error("Utilisateur introuvable")
                return
            if user is False:
                await self.send_error("Token expired")
                return
            userInRoom = check_user_in_room(user.id, self.room.id, self.db)
            if userInRoom is not None:
                # Already a member of this room: log in instead of creating a duplicate.
                await self.loginMember(userInRoom)
                return
            waiter = create_member(user=user, room=self.room, waiting=self.room.public is False, db=self.db)
        elif username is not None:
            if len(username) < 4 or len(username) > 15:
                await self.send_error("Nom d'utilisateur invalide ou indisponible")
                return
            anonymous = create_anonymous(username, self.room, self.db)
            if anonymous is None:
                await self.send_error("Nom d'utilisateur invalide ou indisponible")
                return
            waiter = create_member(anonymous=anonymous, room=self.room, waiting=self.room.public is False, db=self.db)
        else:
            # Bug fix: previously fell through and raised UnboundLocalError on ``waiter``.
            await self.send_error("Veuillez spécifier un token ou un nom d'utilisateur")
            return

        self.member = waiter
        self.add_to_group()
        if self.room.public is False:
            await self.direct_send(type="waiting", payload={
                "waiter": serialize_member(self.member),
                "room": {"name": self.room.name, "id_code": self.room.id_code}})
            await self.send_to_admin(type="waiter", payload={"waiter": serialize_member(self.member)})
        else:
            await self.manager.broadcast(
                lambda m: {"type": "joined", "data": {"member": serialize_member(self.member, admin=m.is_admin, m2=m)}},
                self.room.id_code)
            await self.direct_send(type="accepted", payload={
                "member": serialize_member(self.member, private=True, m2=self.member)})

    @Consumer.event('accept', conditions=[isAdminReceive])
    async def accept(self, waiter_id: str):
        """Admin only: accept a waiter of this room and announce the new member."""
        waiter = self._get_room_waiter(waiter_id)
        if waiter is None:
            await self.send_error("Utilisateur en liste d'attente introuvable")
            return
        member = accept_waiter(waiter, self.db)
        await self.send_to(type="accepted", payload={"member": serialize_member(member, private=True, m2=member)},
                           member_id=waiter_id)
        await self.manager.broadcast(
            lambda m: {"type": "joined", "data": {"member": serialize_member(member, admin=m.is_admin, m2=m)}},
            self.room.id_code)

    @Consumer.event('refuse', conditions=[isAdminReceive])
    async def refuse(self, waiter_id: str):
        """Admin only: refuse a waiter of this room and notify both sides."""
        waiter = self._get_room_waiter(waiter_id)
        if waiter is None:
            # Bug fix: refuse_waiter used to be called with None for unknown ids.
            await self.send_error("Utilisateur en liste d'attente introuvable")
            return
        refuse_waiter(waiter, self.db)
        await self.send_to(type="refused", payload={'waiter_id': waiter_id}, member_id=waiter_id)
        await self.direct_send(type="successfullyRefused", payload={"waiter_id": waiter_id})

    @Consumer.event('ping_room')
    async def proom(self):
        """Broadcast a ``ping`` event to every other connection of the room."""
        await self.broadcast(type='ping', payload={}, exclude=True)

    @Consumer.event('sub_parcours')
    async def sub_parcours(self, parcours_id: str):
        """Subscribe an accepted member to the group of parcours ``parcours_id``."""
        if isinstance(self.member, Member) and self.member.waiting is False:
            self.manager.add(parcours_id, self)

    @Consumer.event('unsub_parcours')
    async def unsub_parcours(self, parcours_id: str):
        """Unsubscribe an accepted member from the group of parcours ``parcours_id``."""
        if isinstance(self.member, Member) and self.member.waiting is False:
            self.manager.remove(parcours_id, self)

    @Consumer.event('set_name', conditions=[isAdminReceive])
    async def change_name(self, name: str):
        """Admin only: rename the room (at most 20 characters) and broadcast the new name."""
        # Bug fix: was ``< 20`` although the error message advertises a maximum of 20.
        if len(name) <= 20:
            self.room = change_room_name(self.room, name, self.db)
            await self.broadcast(type="new_name", payload={"name": name})
            return
        await self.send_error('Nom trop long (max 20 character)')

    @Consumer.event('set_visibility', conditions=[isAdminReceive])
    async def change_visibility(self, public: bool):
        """Admin only: make the room public/private and broadcast the change."""
        self.room = change_room_status(self.room, public, self.db)
        await self.broadcast(type="new_visibility", payload={"public": public})

    async def isConnected(self):
        """Condition: True when a member is bound, otherwise reply with an error."""
        if self.member is None:
            await self.direct_send(type="error", payload={"msg": "Vous n'êtes connecté à aucune salle"})
        return self.member is not None

    @Consumer.event('leave', conditions=[isMember])
    async def leave(self):
        """Leave the room (forbidden for the admin) and notify the room."""
        if self.member.is_admin is True:
            await self.send_error("Vous ne pouvez pas quitter une salle dont vous êtes l'administrateur")
            return
        member_obj = serialize_member(self.member)
        leave_room(self.member, self.db)
        self.member = None
        await self.direct_send(type="successfully_leaved", payload={})
        await self.broadcast(type='leaved', payload={"member": member_obj})
        # With no member bound this connection must not stay in the room/parcours
        # groups: broadcast conditions and callable payloads dereference its member.
        self.manager.remove(self.room.id_code, self)
        for parcours in self.room.parcours:
            self.manager.remove(parcours.id_code, self)

    @Consumer.event('ban', conditions=[isAdminReceive])
    async def ban(self, member_id: str):
        """Admin only: remove a non-admin member from the room and notify everyone."""
        member = get_member(member_id, self.room.id, self.db)
        if member is None:
            await self.send_error("Utilisateur introuvable")
            return
        if member.is_admin is True:
            await self.send_error("Vous ne pouvez pas bannir un administrateur")
            return
        member_serialized = serialize_member(member)
        leave_room(member, self.db)
        await self.send_to(type="banned", payload={}, member_id=member.id_code)
        await self.broadcast(type="leaved", payload={"member": member_serialized})

    # Sending Events

    @Consumer.sending(["joined"], conditions=[isMember])
    def joined(self, member: MemberRead):
        """Outgoing ``joined`` hook: hide the reconnect code from non-admin recipients."""
        if self.member.id_code == member.id_code:
            raise ValueError("")  # Prevent from sending event
        # NOTE(review): this mutates the object handed in; if the framework shares
        # one MemberRead between recipients, the admin may receive a blanked code.
        if self.member.is_admin == False:
            member.reconnect_code = ""
        return {"member": member}

    @Consumer.sending(['waiter', "disconnect_waiter"], conditions=[isAdmin])
    def waiter(self, waiter: Waiter):
        """Outgoing waiter hooks: delivered to the admin only."""
        return {"waiter": waiter}

    @Consumer.sending('accepted')
    def accepted(self, member: MemberRead):
        """Outgoing ``accepted`` hook: reload the bound member (its waiting flag changed)."""
        self.db.refresh(self.member)
        return {"member": member}

    @Consumer.sending("refused", conditions=[isWaiter])
    def refused(self, waiter_id: str):
        """Outgoing ``refused`` hook: unbind the waiter and drop it from the room group."""
        self.member = None
        # Bug fix: groups are keyed by ``id_code`` (see add_to_group), not ``id``.
        self.manager.remove(self.room.id_code, self)
        return {"waiter_id": waiter_id}

    @Consumer.sending('ping', conditions=[isMember])
    def ping(self):
        """Outgoing ``ping`` hook: delivered to accepted members only."""
        return {}

    async def disconnect(self):
        """Socket closed: leave the room and parcours groups and announce the disconnection."""
        if self.room is None:
            # connect() rejected this socket before it joined any group.
            return
        self.manager.remove(self.room.id_code, self)
        for parcours in self.room.parcours:
            self.manager.remove(parcours.id_code, self)
        await self.disconnect_self()