335 lines
13 KiB
Python
335 lines
13 KiB
Python
from typing import Any, TYPE_CHECKING, Callable
|
|
|
|
from fastapi.websockets import WebSocket
|
|
from sqlalchemy.orm.exc import StaleDataError
|
|
from sqlmodel import Session
|
|
|
|
from database.auth.crud import get_user_from_token
|
|
from database.room.crud import change_room_name, change_room_status, serialize_member, check_user_in_room, \
|
|
create_anonymous, create_member, get_member, get_member_from_token, get_member_from_reconnect_code, connect_member, \
|
|
disconnect_member, get_waiter, accept_waiter, leave_room, refuse_waiter
|
|
from database.room.models import Room, Member, MemberRead, Waiter
|
|
from services.websocket import Consumer
|
|
|
|
if TYPE_CHECKING:
|
|
from routes.room.routes import RoomManager
|
|
|
|
|
|
class RoomConsumer(Consumer):
    """WebSocket consumer handling one client's session inside a room.

    The consumer starts detached (``member is None``) and becomes attached to a
    ``Member`` through the ``login`` or ``join`` events. Client messages are
    dispatched to the ``@Consumer.event`` handlers below; messages pushed to this
    client through the manager are filtered/transformed by the
    ``@Consumer.sending`` handlers.
    """

    def __init__(self, ws: WebSocket, room: Room | None, manager: "RoomManager", db: Session):
        """Bind the consumer to its socket, target room, group manager and DB session.

        ``room`` may be ``None`` when the requested room does not exist; ``connect``
        reports that case to the client and closes the socket.
        """
        self.room = room
        self.ws = ws
        self.manager = manager
        self.db = db
        # Member attached to this connection; None until a successful login/join.
        self.member = None
        # Meant to be set by the disabled "banned" sending handler; nothing in this
        # class currently sets it to True.
        self.banned = False
        # Parcours groups joined through `sub_parcours`, so `disconnect` can leave
        # them even if they are not (or no longer) listed in `room.parcours`.
        self.subscribed_parcours: set[str] = set()

    async def connect(self):
        """Accept the socket; report and close it when the room was not found.

        Returns:
            True when the connection may proceed, False when it was closed.
        """
        await self.ws.accept()
        if self.room is None:
            await self.send_error("Salle introuvable", code=404)
            await self.ws.close()
            return False
        return True

    # WS Utilities

    async def send(self, payload: Any | Callable):
        """Send ``payload`` to this client.

        A callable payload is first called with this consumer's member, which lets
        one broadcast produce per-recipient data.
        """
        if callable(payload):
            payload = payload(self.member)
        return await super().send(payload)

    async def direct_send(self, type: str, payload: Any, code: int | None = None):
        """Write ``{"type", "data"[, "code"]}`` straight to the socket.

        This goes through ``self.ws.send_json`` directly, not through ``send``.

        Args:
            type: message type sent to the client.
            payload: value placed under ``"data"``.
            code: optional status code, included only when not None.
        """
        message = {"type": type, "data": payload}
        if code is not None:
            message["code"] = code
        # Previously the dict with "code" was built but a different dict was sent,
        # so `code` never reached the client.
        await self.ws.send_json(message)

    async def send_to_admin(self, type: str, payload: Any, exclude: bool = False):
        """Send a ``{"type", "data"}`` message to the admins of this room.

        ``exclude`` is accepted for signature compatibility and is ignored.
        """
        await self.manager.send_to_admin(self.room.id_code, {"type": type, "data": payload})

    async def send_to(self, type: str, payload: Any, member_id, exclude: bool = False):
        """Send a ``{"type", "data"}`` message to the member ``member_id`` of this room.

        ``exclude`` is accepted for signature compatibility and is ignored.
        """
        await self.manager.send_to(self.room.id_code, member_id, {"type": type, "data": payload})

    async def broadcast(self, type, payload, exclude=False):
        """Broadcast a ``{"type", "data"}`` message to the room's group.

        When ``exclude`` is true, this consumer does not receive the message.
        """
        await self.manager.broadcast({"type": type, "data": payload}, self.room.id_code,
                                     exclude=[self] if exclude else [])

    def add_to_group(self):
        """Register this consumer in the room's group on the manager."""
        self.manager.add(self.room.id_code, self)

    @staticmethod
    def _is_active_member_consumer(consumer) -> bool:
        """Broadcast condition: the consumer has an attached, non-waiting member.

        Consumers whose ``member`` is None (e.g. after ``leave``) may still be in the
        group; they are skipped instead of raising ``AttributeError``.
        """
        return consumer.member is not None and consumer.member.waiting is not True

    async def connect_self(self):
        """Mark the member connected and announce it to the room's non-waiting members."""
        if isinstance(self.member, Member):
            connect_member(self.member, self.db)
        await self.manager.broadcast(
            lambda m: {"type": "connect", "data": {
                "member": serialize_member(self.member, admin=m.is_admin, m2=m)}},
            self.room.id_code,
            exclude=[self],
            conditions=[self._is_active_member_consumer],
        )

    async def disconnect_self(self):
        """Leave the room group and announce the departure.

        A non-waiting member's departure is broadcast to the other non-waiting
        members; a waiter's departure is sent only to the admins. Nothing is
        announced when ``disconnect_member`` raises ``StaleDataError``.
        """
        self.manager.remove(self.room.id_code, self)

        if not isinstance(self.member, Member):
            return

        try:
            disconnect_member(self.member, self.db)
        except StaleDataError:
            # The member row no longer matches the DB (presumably removed by
            # leave/ban/refuse). The failed flush leaves the session unusable
            # until it is rolled back.
            self.db.rollback()
            return

        if self.member.waiting is False:
            await self.manager.broadcast(
                lambda m: {"type": "disconnect", "data": {
                    "member": serialize_member(self.member, admin=m.is_admin, m2=m)}},
                self.room.id_code,
                exclude=[self],
                conditions=[self._is_active_member_consumer],
            )
        else:
            await self.send_to_admin(type="disconnect_waiter", payload={"waiter": serialize_member(self.member)})

    async def loginMember(self, member: Member):
        """Attach ``member`` to this connection and confirm with a ``loggedIn`` message.

        NOTE(review): when ``member`` belongs to another room or is still waiting,
        nothing is sent back to the client.
        """
        if member.room_id == self.room.id and member.waiting is False:
            self.member = member
            await self.connect_self()
            self.add_to_group()
            await self.direct_send(type="loggedIn",
                                   payload={"member": {**serialize_member(self.member, private=True, m2=self.member)}})

    async def send_error(self, msg, code: int = 400):
        """Send an ``error`` message carrying ``msg`` and ``code`` in its data."""
        await self.direct_send(type="error", payload={"msg": msg, "code": code})

    # Conditions

    async def isAdminReceive(self):
        """Event condition: allow only admins; otherwise send a permission error."""
        if not self.isAdmin():
            await self.direct_send(type="error", payload={"msg": "Vous n'avez pas la permission de faire ca"})
            return False
        return True

    def isAdmin(self):
        """Return True when the attached member is an admin."""
        return self.member is not None and self.member.is_admin == True

    async def isMember(self):
        """Condition: attached to a non-waiting member.

        Sends an error to the client when no member is attached.
        """
        if self.member is None:
            await self.send_error("Vous n'êtes connecté à aucune salle")
        return self.member is not None and self.member.waiting == False

    def isWaiter(self):
        """Return True when the attached member is still on the waiting list."""
        return self.member is not None and self.member.waiting == True

    # Received Events

    @Consumer.event('login')
    async def login(self, token: str | None = None, reconnect_code: str | None = None):
        """Log back into the room with either an auth ``token`` or a ``reconnect_code``.

        ``token`` takes precedence when both are given.
        """
        if reconnect_code is None and token is None:
            await self.direct_send(type="error", payload={"msg": "Veuillez spécifier une méthode de connection"})
            return

        if token is not None:
            member = get_member_from_token(token, self.room.id, self.db)
            # get_member_from_token signals an expired token with False.
            if member is False:
                await self.send_error("Token expired", code=422)
                return
            if member is None:
                await self.send_error("Utilisateur introuvable dans cette salle", code=401)
                return
        else:
            member = get_member_from_reconnect_code(reconnect_code, self.room.id, db=self.db)
            if member is None:
                await self.send_error("Utilisateur introuvable dans cette salle", code=401)
                return

        await self.loginMember(member)

    @Consumer.event('join')
    async def join(self, token: str | None = None, username: str | None = None):
        """Join the room as an authenticated user (``token``) or anonymously (``username``).

        Private rooms put the new member on the waiting list and notify the admins;
        public rooms accept immediately and announce the new member.
        """
        waiter = None

        if token is not None:
            user = get_user_from_token(token, self.db)
            if user is None:
                await self.send_error("Utilisateur introuvable")
                return
            # get_user_from_token signals an expired token with False.
            if user is False:
                await self.send_error("Token expired")
                return

            # A user already in the room is simply logged back in.
            userInRoom = check_user_in_room(user.id, self.room.id, self.db)
            if userInRoom is not None:
                await self.loginMember(userInRoom)
                return

            waiter = create_member(
                user=user, room=self.room, waiting=self.room.public is False, db=self.db)

        elif username is not None:
            username = username.strip()
            if not (4 <= len(username) <= 15):
                await self.send_error("Nom invalide (4-15 caractères)")
                return

            anonymous = create_anonymous(username, self.room, self.db)
            if anonymous is None:
                await self.send_error("Nom d'utilisateur invalide ou indisponible")
                return

            waiter = create_member(
                anonymous=anonymous, room=self.room, waiting=self.room.public is False, db=self.db)

        if waiter is None:
            await self.send_error("Une erreur est survenue")
            return

        self.member = waiter
        self.add_to_group()

        if self.room.public is False:
            await self.direct_send(type="waiting", payload={"waiter": serialize_member(self.member), "room": {
                "name": self.room.name, "id_code": self.room.id_code}})
            await self.send_to_admin(type="waiter", payload={"waiter": serialize_member(self.member)})
        else:
            await self.manager.broadcast(
                lambda m: {"type": "joined", "data": {"member": serialize_member(self.member, admin=m.is_admin, m2=m)}},
                self.room.id_code)
            await self.direct_send(type="accepted",
                                   payload={"member": serialize_member(self.member, private=True, m2=self.member)})

    @Consumer.event('accept', conditions=[isAdminReceive])
    async def accept(self, waiter_id: str):
        """Admin only: accept a waiter of this room and announce the new member."""
        waiter = get_waiter(waiter_id, self.db)
        # get_waiter is not scoped to a room, so reject waiters of other rooms.
        if waiter is None or waiter.room_id != self.room.id:
            await self.send_error("Utilisateur en liste d'attente introuvable")
            return
        member = accept_waiter(waiter, self.db)
        await self.send_to(type="accepted", payload={"member": serialize_member(member, private=True, m2=member)},
                           member_id=waiter_id)
        await self.manager.broadcast(
            lambda m: {"type": "joined",
                       "data": {"member": serialize_member(member, admin=m.is_admin, m2=m)}},
            self.room.id_code)

    @Consumer.event('refuse', conditions=[isAdminReceive])
    async def refuse(self, waiter_id: str):
        """Admin only: refuse a waiter of this room and notify both sides."""
        waiter = get_waiter(waiter_id, self.db)
        # Same guards as `accept`: unknown waiter or waiter of another room.
        if waiter is None or waiter.room_id != self.room.id:
            await self.send_error("Utilisateur en liste d'attente introuvable")
            return
        refuse_waiter(waiter, self.db)
        await self.send_to(type="refused", payload={'waiter_id': waiter_id}, member_id=waiter_id)
        await self.direct_send(type="successfullyRefused", payload={"waiter_id": waiter_id})

    @Consumer.event('ping_room')
    async def proom(self):
        """Send a ``ping`` to every other consumer of the room."""
        await self.broadcast(type='ping', payload={}, exclude=True)

    @Consumer.event('sub_parcours')
    async def sub_parcours(self, parcours_id: str):
        """Subscribe a non-waiting member to the ``parcours_id`` group.

        NOTE(review): ``parcours_id`` is not checked against this room's parcours.
        """
        if isinstance(self.member, Member) and self.member.waiting == False:
            self.manager.add(parcours_id, self)
            self.subscribed_parcours.add(parcours_id)

    @Consumer.event('unsub_parcours')
    async def unsub_parcours(self, parcours_id: str):
        """Unsubscribe a non-waiting member from the ``parcours_id`` group."""
        if isinstance(self.member, Member) and self.member.waiting == False:
            self.manager.remove(parcours_id, self)
            self.subscribed_parcours.discard(parcours_id)

    @Consumer.event('set_name', conditions=[isAdminReceive])
    async def change_name(self, name: str):
        """Admin only: rename the room (at most 20 characters) and announce it."""
        if len(name) <= 20:
            self.room = change_room_name(self.room, name, self.db)
            await self.broadcast(type="new_name", payload={"name": name})
            return
        await self.send_error('Nom trop long (max 20 character)')

    @Consumer.event('set_visibility', conditions=[isAdminReceive])
    async def change_visibility(self, public: bool):
        """Admin only: make the room public/private and announce the change."""
        self.room = change_room_status(self.room, public, self.db)
        await self.broadcast(type="new_visibility", payload={"public": public})

    async def isConnected(self):
        """Condition: a member is attached; otherwise send an error to the client."""
        if self.member is None:
            await self.direct_send(type="error", payload={"msg": "Vous n'êtes connecté à aucune salle"})
        return self.member is not None

    @Consumer.event('leave', conditions=[isMember])
    async def leave(self):
        """Leave the room (forbidden for admins) and announce it to the room."""
        if self.member.is_admin is True:
            await self.send_error("Vous ne pouvez pas quitter une salle dont vous êtes l'administrateur")
            return
        # Serialize before leave_room, which may invalidate the member row.
        member_obj = serialize_member(self.member)
        leave_room(self.member, self.db)
        self.member = None
        await self.direct_send(type="successfully_leaved", payload={})
        await self.broadcast(type='leaved', payload={"member": member_obj})

    @Consumer.event('ban', conditions=[isAdminReceive])
    async def ban(self, member_id: str):
        """Admin only: remove a non-admin member from the room and announce it."""
        member = get_member(member_id, self.room.id, self.db)
        if member is None:
            await self.send_error("Utilisateur introuvable")
            return
        if member.is_admin is True:
            await self.send_error("Vous ne pouvez pas bannir un administrateur")
            return

        member_serialized = serialize_member(member)
        leave_room(member, self.db)
        await self.send_to(type="banned", payload={}, member_id=member.id_code)
        await self.broadcast(type="leaved", payload={"member": member_serialized})

    # Sending Events

    @Consumer.sending(["joined"], conditions=[isMember])
    def joined(self, member: MemberRead):
        """Forward ``joined`` to other members; hide the reconnect code from non-admins."""
        if self.member.id_code == member.id_code:
            raise ValueError("")  # Prevent from sending event
        if self.member.is_admin == False:
            member.reconnect_code = ""
        return {"member": member}

    @Consumer.sending(['waiter', "disconnect_waiter"], conditions=[isAdmin])
    def waiter(self, waiter: Waiter):
        """Forward waiter notifications to admins only."""
        return {"waiter": waiter}

    @Consumer.sending('accepted')
    def accepted(self, member: MemberRead):
        """Reload the attached member (now accepted) from the DB and forward the event."""
        self.db.refresh(self.member)
        return {"member": member}

    @Consumer.sending("refused", conditions=[isWaiter])
    def refused(self, waiter_id: str):
        """Detach a refused waiter and drop it from the room group."""
        self.member = None
        # The room group is keyed by id_code everywhere else; using `room.id` here
        # left the refused consumer in the group.
        self.manager.remove(self.room.id_code, self)
        return {"waiter_id": waiter_id}

    # NOTE: disabled handler; kept for reference.
    # @Consumer.sending("banned", conditions=[isMember])
    # def banned(self):
    #     self.member = None
    #     self.manager.remove(self.room.id_code, self)
    #     self.banned = True
    #     # await self.ws.close()
    #     return {}

    @Consumer.sending('ping', conditions=[isMember])
    def ping(self):
        """Forward ``ping`` to attached, non-waiting members."""
        return {}

    async def disconnect(self):
        """Leave the room and parcours groups, then mark the member disconnected."""
        # `connect` closes the socket without joining anything when the room is missing.
        if self.room is None:
            return

        self.manager.remove(self.room.id_code, self)

        # Include explicitly subscribed parcours so none leaves a dead consumer behind.
        parcours_ids = {p.id_code for p in self.room.parcours} | self.subscribed_parcours
        for parcours_id in parcours_ids:
            self.manager.remove(parcours_id, self)

        await self.disconnect_self()