generateur_v3/backend/api/routes/room/routes.py

556 lines
22 KiB
Python

import inspect
import logging
from typing import Any, Callable, Dict, List, Optional

from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect, status, Query
from fastapi.exceptions import HTTPException
from jose import jwt, exceptions
from pydantic import BaseModel
from pydantic import validate_arguments
from pydantic.error_wrappers import ValidationError
from sqlmodel import Session
from sqlmodel import select

from database.room.crud import create_room_db, delete_member, get_member, get_member_from_token, get_member_from_reconnect_code, connect_member, disconnect_member, create_anonymous_member, create_anonymous_waiter, create_user_member, create_user_waiter, get_waiter, accept_waiter, leave_room, refuse_waiter, check_room
from config import ALGORITHM, SECRET_KEY
from database.auth.crud import get_user_from_clientId_db
from database.auth.models import User
from database.db import get_session
from database.room.models import Anonymous, MemberSerializer, MemberWithRelations, MemberRead, Room, RoomCreate, RoomRead, Member
from services.auth import get_current_user_optional
# Router for the room endpoints of this module: POST /room and the
# /ws/room/{room_id} WebSocket.
router = APIRouter(tags=["room"])
class RoomAndMember(BaseModel):
    """Response body of ``POST /room``: the new room and the member created with it."""
    room: RoomRead
    member: MemberRead
class Waiter(BaseModel):
    """Public view of a member that is still waiting to be admitted to a room."""
    username: str
    # The waiting member's ``id_code``; admins send it back in ``accept`` / ``refuse``.
    waiter_id: str
def serialize_member(member: Member) -> MemberRead | Waiter:
    """Convert a ``Member`` row into the representation sent to clients.

    Members still waiting for admission become a ``Waiter`` whose
    ``waiter_id`` is the member's ``id_code``; every other member becomes a
    ``MemberRead``. A value is always returned: callers immediately call
    ``.dict()`` on the result.

    Args:
        member: row to serialize. The username is read from ``member.user``
            or, when that is not set, from ``member.anonymous``.
    """
    member_obj = member.user or member.anonymous
    if member.waiting:
        return Waiter(username=member_obj.username, waiter_id=member.id_code)
    # User objects may have no ``reconnect_code`` attribute, hence the "" fallback.
    return MemberRead(
        username=member_obj.username,
        reconnect_code=getattr(member_obj, "reconnect_code", ""),
        isUser=member.user_id is not None,
        isAdmin=member.is_admin,
        id_code=member.id_code,
    )
@router.post('/room', response_model=RoomAndMember)
def create_room(
    room: RoomCreate,
    username: Optional[str] = Query(default=None, max_length=20),
    user: User | None = Depends(get_current_user_optional),
    db: Session = Depends(get_session),
):
    """Create a room and return it with the member created alongside it.

    ``username`` (at most 20 characters) and the optional authenticated
    ``user`` are forwarded to ``create_room_db``. The member it returns
    (presumably the room's creator/admin — TODO confirm) is serialized with
    ``serialize_member``.
    """
    created = create_room_db(room=room, user=user, username=username, db=db)
    new_room = created['room']
    creator = serialize_member(created['member'])
    return {'room': new_room, "member": creator}
class ConnectionManager:
    """Group raw WebSockets under string keys and fan JSON messages out to a group.

    NOTE(review): not referenced in the visible part of this module; room
    sockets go through ``RoomManager`` instead.
    """

    def __init__(self):
        # group key -> sockets, in insertion order and without duplicates.
        self.active_connections: Dict[str, List[WebSocket]] = {}

    def add(self, group: str, ws: WebSocket):
        """Register ``ws`` in ``group``; adding it twice is a no-op."""
        connections = self.active_connections.setdefault(group, [])
        if ws not in connections:
            connections.append(ws)

    def remove(self, group: str, ws: WebSocket):
        """Unregister ``ws`` from ``group``; unknown groups or sockets are ignored."""
        connections = self.active_connections.get(group)
        if connections and ws in connections:
            connections.remove(ws)

    async def broadcast(self, message, group: str, exclude: Optional[List[WebSocket]] = None):
        """Send ``message`` as JSON to every socket of ``group`` not listed in ``exclude``."""
        excluded = exclude or []
        # Iterate over a snapshot: the group can change while a send is awaited.
        for connection in list(self.active_connections.get(group, [])):
            if connection not in excluded:
                await connection.send_json(message)
def make_event_decorator(eventsDict):
    """Build a decorator factory that registers handlers in ``eventsDict``.

    The factory is used as ``@factory("name", conditions=[...])``, or as
    ``@factory(["a", "b"], ...)`` to register one handler under several
    names. Each registry entry is a dict with:

    * ``"func"``: the decorated function (returned unchanged);
    * ``"conditions"``: predicates or literal bools gating the handler;
    * ``"model"``: the pydantic model that ``validate_arguments`` builds from
      the function's signature, used to validate call arguments.

    Args:
        eventsDict: registry mutated in place. Registering an existing name
            replaces the previous entry.
    """
    def _(name: str | List, conditions: List[Callable | bool] | None = None):
        # A single string is one event name; any other iterable holds several.
        names = [name] if isinstance(name, str) else list(name)
        # Private copy: no list object is shared between registrations through
        # a default argument, and later mutation by the caller has no effect.
        gate = list(conditions) if conditions is not None else []

        def add_event(func):
            model = validate_arguments(func).model
            for event_name in names:
                eventsDict[event_name] = {"func": func, "conditions": gate, "model": model}
            return func
        return add_event
    return _
class Event(BaseModel):
    """Shape of one entry of the ``Consumer.events`` / ``Consumer.sendings`` registries.

    NOTE(review): in the visible code this is only used as a type annotation;
    ``make_event_decorator`` stores plain dicts with these three keys, not
    ``Event`` instances.
    """
    # Handler to run; it is called with the consumer as ``self``.
    func: Callable
    # Plain or async functions taking the consumer, or literal bools; all must
    # hold for the handler to run.
    conditions: List[Callable | bool]
    # NOTE(review): the registries hold the model *class* returned by
    # ``validate_arguments(func).model``, not an instance — annotation looks off.
    model: BaseModel
def dict_model(model: BaseModel, exclude: List[str]):
    """Return a shallow ``{field: value}`` dict of ``model`` without the ``exclude`` fields.

    Iterating a pydantic model yields ``(field_name, value)`` pairs; values are
    kept as-is (nested models are not converted).
    """
    return {field_name: field_value for field_name, field_value in model if field_name not in exclude}
def dict_all(obj: Any):
    """Turn an outgoing payload into plain dicts suitable for ``send_json``.

    For a mapping, every value is converted as follows:

    * nested dicts are converted recursively;
    * ``str``/``bytes``/``bytearray``/``list``/``tuple`` values are kept as-is
      (``dict()`` would turn ``""`` or ``[]`` into ``{}`` and a list of
      two-character strings into a dict);
    * anything ``dict()`` accepts, such as a pydantic model (which iterates as
      ``(field, value)`` pairs), is converted with ``dict()``, one level deep;
    * any other value (numbers, ``None``, ...) is kept as-is.

    A non-mapping ``obj`` is passed straight to ``dict()``.
    """
    if not isinstance(obj, dict):
        return dict(obj)
    converted = {}
    for key, value in obj.items():
        if isinstance(value, dict):
            converted[key] = dict_all(value)
        elif isinstance(value, (str, bytes, bytearray, list, tuple)):
            converted[key] = value
        else:
            try:
                converted[key] = dict(value)
            except (TypeError, ValueError):
                # Not a mapping or a sequence of pairs: send the value unchanged.
                converted[key] = value
    return converted
class Consumer:
    """Base WebSocket consumer that dispatches typed JSON events to registered handlers.

    Incoming messages ``{"type": <name>, "data": {...}}`` are routed by
    :meth:`receive` to the coroutine registered with
    ``@Consumer.event(<name>)``. Outgoing messages go through :meth:`send`:
    when a function is registered with ``@Consumer.sending(<name>)``, its
    conditions decide whether this consumer gets the message, and the
    function may transform the payload or veto delivery by raising.
    Subclasses override :meth:`connect` / :meth:`disconnect` as lifecycle hooks.
    """

    # NOTE: both registries are class attributes shared by Consumer and all of
    # its subclasses; ``@Consumer.event`` / ``@Consumer.sending`` used in a
    # subclass body write into these same dicts.
    events: Dict[str, Event] = {}
    sendings: Dict[str, Any] = {}
    event = make_event_decorator(events)
    sending = make_event_decorator(sendings)

    # Extra fields of pydantic's ``validate_arguments`` models that are not
    # real handler parameters and must not be forwarded to handlers.
    _MODEL_INTERNAL_FIELDS = ("v__duplicate_kwargs", "args", "kwargs")

    def __init__(self, ws: WebSocket):
        self.ws: WebSocket = ws

    async def connect(self):
        """Lifecycle hook run once before the receive loop; no-op by default."""

    async def validation_error_handler(self, e: ValidationError):
        """Report each invalid field of an incoming payload back to the client."""
        errors = e.errors()
        await self.ws.send_json({"type": "error", "data": {"detail": [{ers['loc'][-1]: ers['msg']} for ers in errors]}})

    async def _conditions_met(self, conditions) -> bool:
        """Return True when every registered condition holds for this consumer.

        All conditions are evaluated, in order, without short-circuiting, so
        side effects of a condition (such as an error message it sends) still
        happen. A condition is an async or plain callable taking the consumer,
        or a literal bool; anything else counts as satisfied.
        """
        results = []
        for condition in conditions or []:
            if inspect.iscoroutinefunction(condition):
                results.append(await condition(self))
            elif isinstance(condition, bool):
                results.append(condition)
            elif callable(condition):
                results.append(condition(self))
            else:
                results.append(True)
        return all(results)

    async def send(self, payload):
        """Deliver ``payload`` (``{"type": ..., "data": ...}``) to this consumer's socket.

        A message whose type has a sending handler is validated against the
        handler's signature, gated by its conditions and transformed by it;
        every other message is sent unchanged.
        """
        message_type = payload.get('type', None)
        event_wrapper = None if message_type is None else self.sendings.get(message_type, None)
        if event_wrapper is None:
            await self.ws.send_json(payload)
            return
        handler = event_wrapper.get('func')
        is_valid = await self._conditions_met(event_wrapper.get('conditions'))
        if handler is None or not is_valid:
            # This consumer is not allowed to see the message.
            return
        model = event_wrapper.get("model")
        data = payload.get('data') or {}
        try:
            validated_payload = model(self=self, **data)
        except ValidationError:
            logging.getLogger(__name__).warning(
                "Invalid outgoing %r payload: %r", message_type, data, exc_info=True)
            await self.ws.send_json({"type": "error", "data": {"msg": "Oops there was an error"}})
            return
        arguments = dict_model(validated_payload, exclude=[*self._MODEL_INTERNAL_FIELDS, "self"])
        try:
            parsed_payload = handler(self, **arguments)
            await self.ws.send_json({'type': message_type, "data": dict_all(parsed_payload)})
        except Exception:
            # Sending handlers raise on purpose to veto delivery to this
            # consumer. Delivery is best-effort: a failure here must not abort
            # a broadcast to the other consumers of the group.
            logging.getLogger(__name__).debug(
                "Message %r not delivered to %r", message_type, self, exc_info=True)

    async def receive(self, data):
        """Dispatch one incoming message to its registered event handler.

        Messages that are not JSON objects, whose ``type`` is missing or
        unknown, whose ``data`` is not an object, or whose conditions fail are
        ignored. Invalid arguments are reported via
        :meth:`validation_error_handler`.
        """
        if not isinstance(data, dict):
            return
        event_name = data.get('type', None)
        if not isinstance(event_name, str):
            return
        event_wrapper = self.events.get(event_name, None)
        if event_wrapper is None:
            return
        handler = event_wrapper.get('func')
        is_valid = await self._conditions_met(event_wrapper.get('conditions'))
        if handler is None or not is_valid:
            return
        payload = data.get('data') or {}
        if not isinstance(payload, dict):
            return
        model = event_wrapper.get("model")
        try:
            validated_payload = model(self=self, **payload)
        except ValidationError as e:
            await self.validation_error_handler(e)
            return
        # ``self`` is kept: handlers are the raw functions and receive the
        # consumer as the keyword argument ``self``.
        await handler(**{k: v for k, v in validated_payload.dict().items() if k not in self._MODEL_INTERNAL_FIELDS})

    async def disconnect(self):
        """Lifecycle hook run once the receive loop ends; no-op by default."""

    async def run(self):
        """Accept the connection, process messages until it closes, then clean up.

        :meth:`disconnect` always runs when the loop ends, including on an
        unexpected error (e.g. an undecodable frame), which is then re-raised.
        """
        await self.connect()
        try:
            while True:
                data = await self.ws.receive_json()
                await self.receive(data)
        except WebSocketDisconnect:
            pass
        finally:
            await self.disconnect()
class ConsumerManager:
    """Group ``Consumer`` objects under string keys and fan messages out to a group.

    Messages are delivered through each consumer's ``send``, so per-recipient
    sending handlers apply. NOTE(review): not instantiated in the visible part
    of this module.
    """

    def __init__(self):
        # group key -> consumers, in insertion order and without duplicates.
        self.active_connections: Dict[str, List[Consumer]] = {}

    def add(self, group: str, ws: Consumer):
        """Register consumer ``ws`` in ``group``; adding it twice is a no-op."""
        connections = self.active_connections.setdefault(group, [])
        if ws not in connections:
            connections.append(ws)

    def remove(self, group: str, ws: Consumer):
        """Unregister ``ws`` from ``group``; unknown groups or consumers are ignored."""
        connections = self.active_connections.get(group)
        if connections and ws in connections:
            connections.remove(ws)

    async def broadcast(self, message, group: str, exclude: Optional[List[Consumer]] = None):
        """Send ``message`` to every consumer of ``group`` not listed in ``exclude``.

        Consumers are visited in registration order. A snapshot is iterated
        because a consumer's handler may remove it from the group mid-loop.
        """
        excluded = exclude or []
        for connection in list(self.active_connections.get(group, [])):
            if connection not in excluded:
                await connection.send(message)
class Token(BaseModel):
    """Body wrapping a single JWT string.

    NOTE(review): not referenced in the visible part of this module.
    """
    token: str
def get_user_from_token(token: str, db: Session):
    """Resolve the user identified by a JWT access token.

    Args:
        token: encoded JWT, signed with ``SECRET_KEY`` using ``ALGORITHM``.
        db: database session used for the user lookup.

    Returns:
        ``False`` when the token has expired; ``None`` when the token is
        otherwise invalid (bad signature, malformed, rejected claims) or has
        no ``sub`` claim; otherwise the result of
        ``get_user_from_clientId_db`` for the ``sub`` claim. Callers treat
        ``None`` as "user not found".
    """
    try:
        decoded = jwt.decode(token=token, key=SECRET_KEY,
                             algorithms=[ALGORITHM])
    except exceptions.ExpiredSignatureError:
        return False
    except exceptions.JWTError:
        # Any other JWT failure means the token cannot identify a user; do not
        # let it escape into the caller's WebSocket loop.
        return None
    clientId = decoded.get('sub')
    if clientId is None:
        return None
    return get_user_from_clientId_db(clientId=clientId, db=db)
def check_same_member(member: Member, memberR: MemberRead):
    """Tell whether ``memberR`` describes the same person as ``member``.

    Registered users match on username; anonymous members match on reconnect
    code. The result is truthy/falsy rather than strictly a bool: when the
    usernames do not match and ``member.anonymous`` is falsy, that falsy
    value itself is returned.
    """
    same_user = member.user is not None and memberR.username == member.user.username
    if same_user:
        return same_user
    anonymous = member.anonymous
    if not anonymous:
        return anonymous
    return memberR.reconnect_code == anonymous.reconnect_code
class RoomConsumer(Consumer):
    """Consumer for one client WebSocket connected to a single room.

    The socket starts without a member (``self.member is None``). The client
    then sends ``login`` (existing member or waiter) or ``join`` (becomes a
    member, or a waiter when the room is private). Fan-out between the
    sockets of a room goes through the shared ``RoomManager``; the
    ``@Consumer.sending`` handlers below decide what each recipient gets.
    """

    def __init__(self, ws: WebSocket, room: Room, manager: "RoomManager", db: Session):
        super().__init__(ws)
        self.room = room
        self.manager = manager
        self.db = db
        # Member row (admitted member or waiter) this socket acts as, once known.
        self.member = None

    # ----- WS utilities -----

    async def connect(self):
        """Accept the WebSocket handshake."""
        await self.ws.accept()

    async def direct_send(self, type: str, payload: Any):
        """Send a message to this socket only, bypassing sending handlers."""
        await self.ws.send_json({'type': type, "data": payload})

    async def send_to_admin(self, type: str, payload: Any, exclude: bool = False):
        """Send a message to the admin consumers of the room (``exclude`` is unused)."""
        await self.manager.send_to_admin(self.room.id, {'type': type, "data": payload})

    async def send_to(self, type: str, payload: Any, member_id, exclude: bool = False):
        """Send a message to the consumers acting as member ``member_id`` (``exclude`` is unused)."""
        await self.manager.send_to(self.room.id, member_id, {'type': type, "data": payload})

    async def broadcast(self, type, payload, exclude=False):
        """Send a message to every consumer of the room, skipping this one if ``exclude``."""
        await self.manager.broadcast({"type": type, "data": payload}, self.room.id, exclude=[self] if exclude else [])

    def add_to_group(self):
        """Start receiving the room's broadcasts."""
        self.manager.add(self.room.id, self)

    def _member_payload(self) -> dict:
        """Serialized form of ``self.member`` as sent to clients."""
        return serialize_member(self.member).dict()

    async def connect_self(self):
        """Mark ``self.member`` connected in the DB and tell the rest of the room."""
        if isinstance(self.member, Member):
            connect_member(self.member, self.db)
            await self.broadcast(type="connect", payload={"member": self._member_payload()}, exclude=True)

    async def disconnect_self(self):
        """Mark ``self.member`` disconnected in the DB and tell the room."""
        if isinstance(self.member, Member):
            disconnect_member(self.member, self.db)
            if self.member.waiting is False:
                await self.broadcast(type="disconnect", payload={"member": self._member_payload()})
            else:
                await self.broadcast(type="disconnect_waiter", payload={"waiter": self._member_payload()})

    async def _log_in(self, member):
        """Bind this socket to an existing ``member``, announce it, and confirm to the client."""
        self.member = member
        await self.connect_self()
        self.add_to_group()
        await self.direct_send(type="loggedIn", payload={"member": self._member_payload()})

    async def _wait_for_admin(self, waiter):
        """Bind this socket to a new ``waiter`` and notify the client and the admins."""
        self.member = waiter
        self.add_to_group()
        waiter_payload = self._member_payload()
        await self.direct_send(type="waiting", payload={"waiter": waiter_payload})
        await self.send_to_admin(type="waiter", payload={"waiter": waiter_payload})

    async def _announce_join(self, member):
        """Bind this socket to a newly created ``member`` and announce it to the room."""
        self.member = member
        self.add_to_group()
        member_payload = self._member_payload()
        await self.broadcast(type="joined", payload={"member": member_payload}, exclude=True)
        await self.direct_send(type="accepted", payload={"member": member_payload})

    # ----- Conditions (defined before the handlers that reference them) -----

    async def isAdminReceive(self):
        """Condition for admin-only events; reports an error to non-admins."""
        is_admin = self.member is not None and self.member.is_admin == True  # noqa: E712
        if not is_admin:
            await self.direct_send(type="error", payload={"msg": "Vous n'avez pas la permission de faire ca"})
            return False
        return True

    def isAdmin(self):
        """Silent condition: this socket acts as an admin."""
        return self.member is not None and self.member.is_admin == True  # noqa: E712

    async def hasRoom(self):
        """Condition: this socket acts as some member or waiter; reports an error otherwise."""
        if self.member is None:
            await self.direct_send(type="error", payload={"msg": "Vous n'êtes connecté à aucune salle"})
        return self.member

    def isMember(self):
        """Silent condition: this socket acts as an admitted member."""
        return self.member is not None and self.member.waiting == False  # noqa: E712

    def isWaiter(self):
        """Silent condition: this socket acts as a waiter."""
        return self.member is not None and self.member.waiting == True  # noqa: E712

    # ----- Received events -----

    @Consumer.event('login')
    async def login(self, token: str | None = None, reconnect_code: str | None = None):
        """Log in as an existing member of this room, by token or reconnect code.

        ``token`` takes precedence when both are given.
        """
        if token is not None:
            member = get_member_from_token(token, self.room.id, self.db)
            if member is False:
                await self.direct_send(type="error", payload={"msg": "Token expired"})
                return
        elif reconnect_code is not None:
            member = get_member_from_reconnect_code(
                reconnect_code, self.room.id, db=self.db)
        else:
            await self.direct_send(type="error", payload={"msg": "Veuillez spécifier une méthode de connection"})
            return
        if member is None:
            await self.direct_send(type="error", payload={"msg": "Utilisateur introuvable dans cette salle"})
            return
        await self._log_in(member)

    @Consumer.event('join')
    async def join(self, token: str | None = None, username: str | None = None):
        """Join the room as a registered user (``token``) or anonymously (``username``).

        In a private room the client becomes a waiter, unless the user's
        waiter row is already admitted, in which case it is logged straight
        in. In a public room it becomes a member immediately.
        """
        is_private = self.room.public == False  # noqa: E712 (a None flag counts as public)
        if token is not None:
            user = get_user_from_token(token, self.db)
            if user is None:
                await self.direct_send(type="error", payload={"msg": "Utilisateur introuvable"})
                return
            if user is False:
                await self.direct_send(type="error", payload={"msg": "Token expired"})
                return
            if is_private:
                waiter = create_user_waiter(user, self.room, self.db)
                if waiter.waiting is False:
                    await self._log_in(waiter)
                    return
                await self._wait_for_admin(waiter)
            else:
                member = create_user_member(user, self.room, self.db)
                if member is None:
                    # NOTE(review): no feedback is sent to the client here.
                    return
                await self._announce_join(member)
        elif username is not None:
            if is_private:
                waiter = create_anonymous_waiter(username, self.room, self.db)
                if waiter is None:
                    await self.direct_send(type="error", payload={"msg": "Nom d'utilisateur invalide ou indisponible"})
                    return
                await self._wait_for_admin(waiter)
            else:
                member = create_anonymous_member(username, self.room, self.db)
                if member is None:
                    await self.direct_send(type="error", payload={"msg": "Nom d'utilisateur indisponible"})
                    return
                await self._announce_join(member)

    @Consumer.event('accept', conditions=[isAdminReceive])
    async def accept(self, waiter_id: str):
        """Admin: admit waiter ``waiter_id`` and announce the new member."""
        # NOTE(review): ``get_waiter`` is not scoped to ``self.room`` — confirm
        # an admin cannot admit waiters of another room.
        waiter = get_waiter(waiter_id, self.db)
        if waiter is None:
            await self.direct_send(type="error", payload={'msg': "Utilisateur introuvable"})
            return
        member = accept_waiter(waiter, self.db)
        member_payload = serialize_member(member).dict()
        await self.send_to(type="accepted", payload={"member": member_payload}, member_id=waiter_id)
        await self.broadcast(type="joined", payload={"member": member_payload})

    @Consumer.event('refuse', conditions=[isAdminReceive])
    async def refuse(self, waiter_id: str):
        """Admin: refuse waiter ``waiter_id`` and notify it."""
        waiter = get_waiter(waiter_id, self.db)
        if waiter is None:
            await self.direct_send(type="error", payload={'msg': "Utilisateur introuvable"})
            return
        refuse_waiter(waiter, self.db)
        await self.send_to(type="refused", payload={'waiter_id': waiter_id}, member_id=waiter_id)

    @Consumer.event('ping_room')
    async def proom(self):
        """Send a ``ping`` to the other sockets of the room."""
        await self.broadcast(type='ping', payload={}, exclude=True)

    @Consumer.event('leave', conditions=[hasRoom])
    async def leave(self):
        """Leave the room (not allowed for admins) and tell everyone."""
        if self.member.is_admin is True:
            await self.direct_send(type="error", payload={"msg": "Vous ne pouvez pas quitter une salle dont vous êtes l'administrateur"})
            return
        member_obj = self._member_payload()
        leave_room(self.member, self.db)
        await self.direct_send(type="successfully_leaved", payload={})
        await self.broadcast(type='leaved', payload={"member": member_obj})
        # A socket without a member must not stay in the room's group: it
        # would keep receiving room traffic and break member lookups.
        self.manager.remove(self.room.id, self)
        self.member = None

    @Consumer.event('ban', conditions=[isAdminReceive])
    async def ban(self, member_id: str):
        """Admin: remove member ``member_id`` (not an admin) from the room."""
        member = get_member(member_id, self.room.id, self.db)
        if member is None:
            await self.direct_send(type="error", payload={"msg": "Utilisateur introuvable"})
            return
        if member.is_admin is True:
            await self.direct_send(type="error", payload={"msg": "Vous ne pouvez pas bannir un administrateur"})
            return
        member_serialized = serialize_member(member)
        leave_room(member, self.db)
        await self.send_to(type="banned", payload={}, member_id=member.id_code)
        await self.broadcast(type="leaved", payload={"member": member_serialized.dict()})

    # ----- Sending events (run once per recipient socket) -----

    @Consumer.sending("joined", conditions=[isMember])
    def joined(self, member: MemberRead):
        """Relay a new member to admitted members; only admins see reconnect codes."""
        if self.member.id_code == member.id_code:
            # The joining member is answered with "accepted"; raising vetoes delivery.
            raise ValueError("")
        if not self.member.is_admin:
            member.reconnect_code = ""
        return {"member": member}

    @Consumer.sending(['waiter', 'disconnect_waiter'], conditions=[isAdmin])
    def waiter(self, waiter: Waiter):
        """Relay waiting-room changes to admins only."""
        return {"waiter": waiter}

    @Consumer.sending(['connect', 'disconnect'])
    def connect_event(self, member: MemberRead):
        """Relay a member's (dis)connection; only admins see reconnect codes."""
        if not self.member.is_admin:
            member.reconnect_code = ""
        return {"member": member}

    @Consumer.sending("refused", conditions=[isWaiter])
    def refused(self, waiter_id: str):
        """Detach a refused waiter's socket from the room before notifying it."""
        self.member = None
        self.manager.remove(self.room.id, self)
        return {"waiter_id": waiter_id}

    @Consumer.sending("banned", conditions=[isMember])
    def banned(self):
        """Detach a banned member's socket from the room before notifying it."""
        self.member = None
        self.manager.remove(self.room.id, self)
        return {}

    @Consumer.sending('ping', conditions=[isMember])
    def ping(self):
        """Relay pings to admitted members only."""
        return {}

    # ----- Lifecycle -----

    async def disconnect(self):
        """Leave the room's group and mark the member disconnected."""
        logging.getLogger(__name__).debug("Room %s: socket closed (member=%r)", self.room.id, self.member)
        self.manager.remove(self.room.id, self)
        await self.disconnect_self()
class RoomManager:
    """In-memory registry of the RoomConsumers connected to each room.

    Consumers are grouped by room id. Messages are delivered through each
    consumer's ``send``, so per-recipient sending handlers apply.
    """

    def __init__(self):
        # room id -> consumers, in insertion order and without duplicates.
        self.active_connections: Dict[str, List[RoomConsumer]] = {}

    def add(self, group: str, ws: RoomConsumer):
        """Register consumer ``ws`` in room ``group``; adding it twice is a no-op."""
        connections = self.active_connections.setdefault(group, [])
        if ws not in connections:
            connections.append(ws)

    def remove(self, group: str, ws: RoomConsumer):
        """Unregister ``ws`` from room ``group``; unknown rooms or consumers are ignored."""
        connections = self.active_connections.get(group)
        if connections and ws in connections:
            connections.remove(ws)

    def _snapshot(self, group: str) -> List[RoomConsumer]:
        """Copy of the room's consumers: sending handlers may remove consumers mid-loop."""
        return list(self.active_connections.get(group, []))

    async def broadcast(self, message, group: str, exclude: Optional[List[RoomConsumer]] = None):
        """Send ``message`` to every consumer of room ``group`` not listed in ``exclude``."""
        excluded = exclude or []
        for connection in self._snapshot(group):
            if connection not in excluded:
                await connection.send(message)

    async def send_to(self, group, id_code, msg):
        """Send ``msg`` to the consumers of room ``group`` acting as member ``id_code``.

        Consumers not (or no longer) bound to a member are skipped.
        """
        recipients = [c for c in self._snapshot(group) if c.member is not None and c.member.id_code == id_code]
        for consumer in recipients:
            await consumer.send(msg)

    async def send_to_admin(self, group, msg):
        """Send ``msg`` to the consumers of room ``group`` acting as an admin."""
        recipients = [c for c in self._snapshot(group) if c.member is not None and c.member.is_admin]
        for consumer in recipients:
            await consumer.send(msg)
# Single in-memory registry of room consumers, shared by every /ws/room/{room_id}
# connection served by this process.
manager = RoomManager()
@router.websocket('/ws/room/{room_id}')
async def room_ws(ws: WebSocket, room: Room | None = Depends(check_room), db: Session = Depends(get_session)):
    """Serve one client of room ``room_id`` until its socket disconnects.

    When ``check_room`` yields no room, the handshake is rejected.
    """
    if room is None:
        # An HTTPException cannot be sent on a WebSocket scope; closing before
        # ``accept()`` makes the server reject the handshake instead.
        await ws.close(code=status.WS_1008_POLICY_VIOLATION)
        return
    consumer = RoomConsumer(ws=ws, room=room, manager=manager, db=db)
    await consumer.run()