diff --git a/backend/api/database/room/crud.py b/backend/api/database/room/crud.py index 8b5f68a..79708d7 100644 --- a/backend/api/database/room/crud.py +++ b/backend/api/database/room/crud.py @@ -1,37 +1,40 @@ from fastapi import Depends from sqlmodel import Session, select, col from database.db import get_session -from database.room.models import Anonymous, Member, Room, RoomCreate +from database.room.models import Anonymous, Member, Room, RoomCreate, Waiter, MemberRead from database.auth.models import User from services.database import generate_unique_code from database.auth.crud import get_user_from_token def check_room(room_id: str, db: Session = Depends(get_session)): - room = db.exec(select(Room).where(Room.id_code == room_id)).first() - return room + room = db.exec(select(Room).where(Room.id_code == room_id)).first() + return room -def create_room_db(*,room: RoomCreate, user: User | None = None, username: str | None = None, db: Session): - id_code = generate_unique_code(Room,s=db) - member_id = generate_unique_code(Member, s=db) - room_obj = Room(**room.dict(exclude_unset=True), id_code=id_code) - if user is not None: - member = Member(user_id=user.id, room=room_obj, is_admin=True, id_code=member_id) - db.add(member) - db.commit() - db.refresh(member) - if username is not None: - reconnect_code = generate_unique_code(Anonymous, s=db, field_name='reconnect_code') - anonymous = Anonymous(username=username, reconnect_code=reconnect_code) - member = Member(anonymous=anonymous, room=room_obj, is_admin=True, id_code=member_id) - db.add(member) - db.commit() - db.refresh(member) - if username is None and user is None: - raise ValueError('Username or user required') - - return {"room": room_obj, "member": member} +def create_room_db(*, room: RoomCreate, user: User | None = None, username: str | None = None, db: Session): + id_code = generate_unique_code(Room, s=db) + member_id = generate_unique_code(Member, s=db) + room_obj = Room(**room.dict(exclude_unset=True), id_code=id_code) + if user is not None: + member = Member(user_id=user.id, room=room_obj, + is_admin=True, id_code=member_id) + db.add(member) + db.commit() + db.refresh(member) + if username is not None: + reconnect_code = generate_unique_code( + Anonymous, s=db, field_name='reconnect_code') + anonymous = Anonymous(username=username, reconnect_code=reconnect_code) + member = Member(anonymous=anonymous, room=room_obj, + is_admin=True, id_code=member_id) + db.add(member) + db.commit() + db.refresh(member) + if username is None and user is None: + raise ValueError('Username or user required') + + return {"room": room_obj, "member": member} def get_member_from_user(user_id: int, room_id: int, db: Session): @@ -70,6 +73,24 @@ def get_anonymous_from_code(reconnect_code: str, db: Session): return anonymous +def create_member(*, room: Room, user: User | None = None, anonymous: Anonymous | None = None, waiting: bool = False, db: Session): + member_id = generate_unique_code(Member, s=db) + member = Member(room=room, user=user, anonymous=anonymous, waiting=waiting, + id_code=member_id) + db.add(member) + db.commit() + db.refresh(member) + return member + + +def get_or_create_member(*, room: Room, user: User | None = None, anonymous: Anonymous | None = None, waiting: bool = False, db: Session): + member = user is not None and get_member_from_user(user.id, room.id, db) + if member is not None and member is not False: + return member + member= create_member(room=room, user=user, anonymous=anonymous, waiting=waiting, db=db) + + + def connect_member(member: Member, db: Session): member.online = True db.add(member) @@ -92,7 +113,6 @@ def disconnect_member(member: Member, db: Session): def validate_username(username: str, room: Room, db: Session = Depends(get_session)): - print('VALIDATE', username) if len(username) > 20: return None members = select(Member.anonymous_id).where( @@ -116,7 +136,21 @@ def create_anonymous_member(username: str, room: Room, db: Session): db.commit() db.refresh(member) return member +def create_anonymous(username: str, room: Room, db: Session): + username = validate_username(username, room, db) + if username is None: + return None + reconnect_code = generate_unique_code( + Anonymous, s=db, field_name="reconnect_code") + anonymous = Anonymous(username=username, reconnect_code=reconnect_code) + db.add(anonymous) + db.commit() + db.refresh(anonymous) + return anonymous +def check_user_in_room(user_id: int, room_id: int, db: Session): + user = db.exec(select(Member).where(Member.user_id==user_id, Member.room_id == room_id)).first() + return user def create_user_member(user: User, room: Room, db: Session): member = get_member_from_user(user.id, room.id, db) @@ -147,30 +181,32 @@ def create_anonymous_waiter(username: str, room: Room, db: Session): return member + + def create_user_waiter(user: User, room: Room, db: Session): member = get_member_from_user(user.id, room.id, db) if member is not None: return member - member_id = generate_unique_code(Member, s=db) - member = Member(room=room, user=user, waiting=True, - id_code=member_id) - db.add(member) - db.commit() - db.refresh(member) + + member = create_member(room=room, user=user, waiting=True, + db=db) return member def get_waiter(waiter_code: str, db: Session): - return db.exec(select(Member).where(Member.id_code == waiter_code)).first() + return db.exec(select(Member).where(Member.id_code == waiter_code, Member.waiting == True)).first() + + +def get_member(id_code: str, room_id: str, db: Session): + return db.exec(select(Member).where(Member.id_code == id_code, Member.room_id == room_id)).first() -def get_member(id_code: str,room_id:str, db: Session): - return db.exec(select(Member).where(Member.id_code == id_code, Member.room_id== room_id)).first() def delete_member(member: Member, db: Session): db.delete(member) db.commit() return None + def accept_waiter(member: Member, db: Session): member.waiting = False member.waiter_code = None @@ -185,7 +221,16 @@ def refuse_waiter(member: Member, db: Session): db.commit() return None + def leave_room(member: Member, db: Session): - db.delete(member) - db.commit() - return None + db.delete(member) + db.commit() + return None + + +def serialize_member(member: Member) -> MemberRead | Waiter: + member_obj = member.user or member.anonymous + if member.waiting == False: + return MemberRead(username=member_obj.username, reconnect_code=getattr(member_obj, "reconnect_code", ""), isUser=member.user_id != None, isAdmin=member.is_admin, id_code=member.id_code).dict() + if member.waiting == True: + return Waiter(username=member_obj.username, waiter_id=member.id_code).dict() diff --git a/backend/api/database/room/models.py b/backend/api/database/room/models.py index 71e2965..0db71af 100644 --- a/backend/api/database/room/models.py +++ b/backend/api/database/room/models.py @@ -1,4 +1,4 @@ -from pydantic import root_validator +from pydantic import root_validator, BaseModel from typing import List, Optional, TYPE_CHECKING from sqlmodel import SQLModel, Field, Relationship @@ -76,6 +76,7 @@ class MemberRead(SQLModel): isUser: bool isAdmin: bool id_code: str + class MemberSerializer(MemberRead): member: MemberWithRelations @@ -90,3 +91,11 @@ class MemberSerializer(MemberRead): return {"username": member_obj.username, "reconnect_code": getattr(member_obj, "reconnect_code", ""), "isAdmin": member.is_admin, "isUser": member.user != None} +class RoomAndMember(BaseModel): + room: RoomRead + member: MemberRead + + +class Waiter(BaseModel): + username: str + waiter_id: str diff --git a/backend/api/main.py b/backend/api/main.py index 1c0317f..b367c29 100644 --- a/backend/api/main.py +++ b/backend/api/main.py @@ -53,8 +53,7 @@ admin = Admin(app, engine) class UserAdmin(ModelView, model=User): column_list = [User.id, User.username] - - + admin.add_view(UserAdmin) @app.on_event("startup") diff --git a/backend/api/routes/room/consumer.py b/backend/api/routes/room/consumer.py new file mode 100644 index 0000000..998222b --- /dev/null +++ b/backend/api/routes/room/consumer.py @@ -0,0 +1,234 @@ +from fastapi.websockets import WebSocket +from services.websocket import Consumer +from typing import Any, TYPE_CHECKING +from database.room.models import Room, Member, MemberRead, Waiter +from sqlmodel import Session +from database.room.crud import serialize_member,check_user_in_room, create_anonymous, create_member, create_room_db, delete_member, get_member, get_member_from_token, get_member_from_reconnect_code, connect_member, disconnect_member, create_anonymous_member, create_anonymous_waiter, create_user_member, create_user_waiter, get_or_create_member, get_waiter, accept_waiter, leave_room, refuse_waiter, check_room +from database.auth.crud import get_user_from_token +if TYPE_CHECKING: + from routes.room.routes import RoomManager + +class RoomConsumer(Consumer): + + def __init__(self, ws: WebSocket, room: Room, manager: "RoomManager", db: Session): + self.room = room + self.ws = ws + self.manager = manager + self.db = db + self.member = None + + # WS Utilities + + async def connect(self): + await self.ws.accept() + + async def direct_send(self, type: str, payload: Any): + await self.ws.send_json({'type': type, "data": payload}) + + async def send_to_admin(self, type: str, payload: Any, exclude: bool = False): + await self.manager.send_to_admin(self.room.id, {'type': type, "data": payload}) + + async def send_to(self, type: str, payload: Any, member_id, exclude: bool = False): + await self.manager.send_to(self.room.id, member_id, {'type': type, "data": payload}) + + async def broadcast(self, type, payload, exclude=False): + await self.manager.broadcast({"type": type, "data": payload}, self.room.id, exclude=[exclude == True and self]) + + def add_to_group(self): + self.manager.add(self.room.id, self) + + async def connect_self(self): + if isinstance(self.member, Member): + connect_member(self.member, self.db) + await self.broadcast(type="connect", payload={"member": serialize_member(self.member)}, exclude=True) + + async def disconnect_self(self): + if isinstance(self.member, Member): + disconnect_member(self.member, self.db) + if self.member.waiting is False: + await self.broadcast(type="disconnect", payload={"member": serialize_member(self.member)}) + else: + await self.send_to_admin(type="disconnect_waiter", payload={"waiter": serialize_member(self.member)}) + + async def loginMember(self, member: Member): + if member.room_id == self.room.id and member.waiting == False: + self.member = member + await self.connect_self() + self.add_to_group() + await self.direct_send(type="loggedIn", payload={"member": serialize_member(self.member)}) + + async def send_error(self, msg): + await self.direct_send(type="error", payload={"msg": msg}) + # Conditions + + async def isAdminReceive(self): + is_admin = self.member is not None and self.member.is_admin == True + if not is_admin: + await self.direct_send(type="error", payload={"msg": "Vous n'avez pas la permission de faire ca"}) + return False + return True + + def isAdmin(self): + return self.member is not None and self.member.is_admin == True + + async def isMember(self): + if self.member is None: + await self.send_error("Vous n'êtes connecté à aucune salle") + return self.member is not None and self.member.waiting == False + + def isWaiter(self): + return self.member is not None and self.member.waiting == True + # Received Events + @Consumer.event('login') + async def login(self, token: str | None = None, reconnect_code: str | None = None): + if reconnect_code is None and token is None: + await self.direct_send(type="error", payload={"msg": "Veuillez spécifier une méthode de connection"}) + return + + if token is not None: + member = get_member_from_token(token, self.room.id, self.db) + if member == False: + await self.send_error("Token expired") + return + if member is None: + await self.send_error("Utilisateur introuvable dans cette salle") + return + + elif reconnect_code is not None: + member = get_member_from_reconnect_code( + reconnect_code, self.room.id, db=self.db) + if member is None: + await self.send_error("Utilisateur introuvable dans cette salle") + return + + await self.loginMember(member) + + @Consumer.event('join') + async def join(self, token: str | None = None, username: str | None = None): + if token is not None: + user = get_user_from_token(token, self.db) + + if user is None: + await self.send_error("Utilisateur introuvable") + return + if user is False: + await self.send_error("Token expired") + return + + userInRoom = check_user_in_room(user.id, self.room.id, self.db) + + if userInRoom is not None: + await self.loginMember(userInRoom) + return + + waiter = create_member( + user=user, room=self.room, waiting=self.room.public is False, db=self.db) + + elif username is not None: + anonymous = create_anonymous(username, self.room, self.db) + if anonymous is None: + await self.send_error("Nom d'utilisateur invalide ou indisponible") + return + + waiter = create_member( + anonymous=anonymous, room=self.room, waiting=self.room.public is False, db=self.db) + + self.member = waiter + self.add_to_group() + + if self.room.public is False: + await self.direct_send(type="waiting", payload={"waiter": serialize_member(self.member)}) + await self.send_to_admin(type="waiter", payload={"waiter": serialize_member(self.member)}) + else: + await self.broadcast(type="joined", payload={"member": serialize_member(self.member)}, exclude=True) + await self.direct_send(type="accepted", payload={"member": serialize_member(self.member)}) + + + @Consumer.event('accept', conditions=[isAdminReceive]) + async def accept(self, waiter_id: str): + waiter = get_waiter(waiter_id, self.db) + if waiter is None: + await self.send_error("Utilisateur en list d'attente introuvable") + return + member = accept_waiter(waiter, self.db) + await self.send_to(type="accepted", payload={"member": serialize_member(member)}, member_id=waiter_id) + await self.broadcast(type="joined", payload={"member": serialize_member(member)}) + + @Consumer.event('refuse', conditions=[isAdminReceive]) + async def accept(self, waiter_id: str): + waiter = get_waiter(waiter_id, self.db) + member = refuse_waiter(waiter, self.db) + await self.send_to(type="refused", payload={'waiter_id': waiter_id}, member_id=waiter_id) + await self.direct_send(type="successfullyRefused", payload= {"waiter_id": waiter_id}) + + @Consumer.event('ping_room') + async def proom(self): + await self.broadcast(type='ping', payload={}, exclude=True) + + async def isConnected(self): + if self.member is None: + await self.direct_send(type="error", payload={"msg": "Vous n'êtes connecté à aucune salle"}) + return self.member is not None + + + @Consumer.event('leave', conditions=[isMember]) + async def leave(self): + if self.member.is_admin is True: + await self.send_error("Vous ne pouvez pas quitter une salle dont vous êtes l'administrateur") + return + member_obj = serialize_member(self.member) + leave_room(self.member, self.db) + + await self.direct_send(type="successfully_leaved", payload={}) + await self.broadcast(type='leaved', payload={"member": member_obj}) + self.member = None + + @Consumer.event('ban', conditions=[isAdminReceive]) + async def ban(self, member_id: str): + member = get_member(member_id, self.room.id, self.db) + if member == None: + await self.send_error("Utilisateur introuvable") + return + if member.is_admin is True: + await self.send_error("Vous ne pouvez pas bannir un administrateur") + return + + member_serialized = serialize_member(member) + leave_room(member, self.db) + await self.send_to(type="banned", payload={}, member_id=member.id_code) + await self.broadcast(type="leaved", payload={"member": member_serialized}) + + # Sending Events + + @Consumer.sending(['connect', "disconnect", "joined"], conditions=[isMember]) + def joined(self, member: MemberRead): + if self.member.id_code == member.id_code: + raise ValueError("") # Prevent from sending event + if self.member.is_admin == False: + member.reconnect_code = "" + return {"member": member} + + @Consumer.sending(['waiter', "disconnect_waiter"], conditions=[isAdmin]) + def waiter(self, waiter: Waiter): + return {"waiter": waiter} + + @Consumer.sending("refused", conditions=[isWaiter]) + def refused(self, waiter_id: str): + self.member = None + self.manager.remove(self.room.id, self) + return {"waiter_id": waiter_id} + + @Consumer.sending("banned", conditions=[isMember]) + def banned(self): + self.member = None + self.manager.remove(self.room.id, self) + return {} + + @Consumer.sending('ping', conditions=[isMember]) + def ping(self): + return {} + + async def disconnect(self): + self.manager.remove(self.room.id, self) + await self.disconnect_self() + diff --git a/backend/api/routes/room/manager.py b/backend/api/routes/room/manager.py new file mode 100644 index 0000000..5bdf883 --- /dev/null +++ b/backend/api/routes/room/manager.py @@ -0,0 +1,39 @@ +from typing import TYPE_CHECKING, Dict, List +if TYPE_CHECKING: + from routes.room.consumer import RoomConsumer + +class RoomManager: + def __init__(self): + self.active_connections: Dict[str, List["RoomConsumer"]] = {} + + def add(self, group: str, member: "RoomConsumer"): + + if group not in self.active_connections: + self.active_connections[group] = [] + if member not in self.active_connections[group]: + self.active_connections[group].append(member) + + def remove(self, group: str, member: "RoomConsumer"): + if group in self.active_connections: + if member in self.active_connections[group]: + self.active_connections[group].remove(member) + + async def broadcast(self, message, group: str, exclude: list["RoomConsumer"] = []): + if group in self.active_connections: + for connection in list(set(self.active_connections[group])): + if connection not in exclude: + await connection.send(message) + + async def send_to(self, group, id_code, msg): + if group in self.active_connections: + members = [c for c in self.active_connections[group] + if c.member.id_code == id_code] + for m in members: + await m.send(msg) + + async def send_to_admin(self, group, msg): + if group in self.active_connections: + members = [c for c in self.active_connections[group] + if c.member.is_admin == True] + for m in members: + await m.send(msg) diff --git a/backend/api/routes/room/routes.py b/backend/api/routes/room/routes.py index 79c6f75..f2dfc15 100644 --- a/backend/api/routes/room/routes.py +++ b/backend/api/routes/room/routes.py @@ -1,551 +1,33 @@ -from database.room.crud import create_room_db, delete_member, get_member, get_member_from_token, get_member_from_reconnect_code, connect_member, disconnect_member, create_anonymous_member, create_anonymous_waiter, create_user_member, create_user_waiter, get_waiter, accept_waiter, leave_room, refuse_waiter, check_room -from pydantic.error_wrappers import ValidationError -from pydantic import validate_arguments -from sqlmodel import select +from database.room.crud import serialize_member,check_user_in_room, create_anonymous, create_member, create_room_db, delete_member, get_member, get_member_from_token, get_member_from_reconnect_code, connect_member, disconnect_member, create_anonymous_member, create_anonymous_waiter, create_user_member, create_user_waiter, get_or_create_member, get_waiter, accept_waiter, leave_room, refuse_waiter, check_room + from pydantic import BaseModel from typing import Any, Callable, Dict, List, Optional -from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect, status, Query +from fastapi import APIRouter, Depends, WebSocket, status, Query from config import ALGORITHM, SECRET_KEY from database.auth.crud import get_user_from_clientId_db from database.auth.models import User from database.db import get_session from sqlmodel import Session -from database.room.models import Anonymous, MemberSerializer, MemberWithRelations, MemberRead, Room, RoomCreate, RoomRead, Member +from database.room.models import Room, RoomCreate, RoomAndMember +from routes.room.consumer import RoomConsumer +from routes.room.manager import RoomManager from services.auth import get_current_user_optional from fastapi.exceptions import HTTPException -from jose import jwt, exceptions -import inspect +from database.auth.crud import get_user_from_token +from services.websocket import Consumer router = APIRouter(tags=["room"]) -class RoomAndMember(BaseModel): - room: RoomRead - member: MemberRead - - -class Waiter(BaseModel): - username: str - waiter_id: str - - -def serialize_member(member: Member) -> MemberRead | Waiter: - member_obj = member.user or member.anonymous - if member.waiting == False: - return MemberRead(username=member_obj.username, reconnect_code=getattr(member_obj, "reconnect_code", ""), isUser=member.user_id != None, isAdmin=member.is_admin, id_code=member.id_code) - if member.waiting == True: - return Waiter(username=member_obj.username, waiter_id=member.id_code) - - @router.post('/room', response_model=RoomAndMember) def create_room(room: RoomCreate, username: Optional[str] = Query(default=None, max_length=20), user: User | None = Depends(get_current_user_optional), db: Session = Depends(get_session)): room_obj = create_room_db(room=room, user=user, username=username, db=db) return {'room': room_obj['room'], "member": serialize_member(room_obj['member'])} -class ConnectionManager: - def __init__(self): - self.active_connections: Dict[str, List[WebSocket]] = {} - - def add(self, group: str, ws: WebSocket): - if group not in self.active_connections: - self.active_connections[group] = [] - - if ws not in self.active_connections[group]: - self.active_connections[group].append(ws) - - def remove(self, group: str, ws: WebSocket): - if group in self.active_connections: - if ws in self.active_connections[group]: - self.active_connections[group].remove(ws) - - async def broadcast(self, message, group: str, exclude: list[WebSocket] = []): - if group in self.active_connections: - for connection in self.active_connections[group]: - if connection not in exclude: - await connection.send_json(message) - - -def make_event_decorator(eventsDict): - def _(name: str | List, conditions: List[Callable | bool] = []): - def add_event(func): - model = validate_arguments(func).model - if type(name) == str: - eventsDict[name] = {"func": func, - "conditions": conditions, "model": model} - if type(name) == list: - for n in name: - eventsDict[n] = {"func": func, - "conditions": conditions, "model": model} - return func - return add_event - return _ - - -class Event(BaseModel): - func: Callable - conditions: List[Callable | bool] - model: BaseModel - - -def dict_model(model: BaseModel, exclude: List[str]): - value = {} - for n, f in model: - if n not in exclude: - value[n] = f - return value - - -def dict_all(obj: Any): - if isinstance(obj, dict): - value = {} - for k, v in obj.items(): - if isinstance(v, dict): - v = dict_all(v) - value[k] = dict(v) - elif isinstance(v, BaseModel): - value[k] = dict(v) - else: - try: - value[k] = dict(v) - except: - value[k] = v - return value - return dict(obj) - - -class Consumer: - events: Dict[str, Event] = {} - sendings: Dict[str, Any] = {} - event = make_event_decorator(events) - sending = make_event_decorator(sendings) - - def __init__(self, ws: WebSocket): - self.ws: WebSocket = ws - #self.events: Dict[str, Callable] = {} - - async def connect(self): - pass - - async def validation_error_handler(self, e: ValidationError): - errors = e.errors() - await self.ws.send_json({"type": "error", "data": {"detail": [{ers['loc'][-1]: ers['msg']} for ers in errors]}}) - - async def send(self, payload): - type = payload.get('type', None) - print('TYPE', type, self.member) - if type is not None: - event_wrapper = self.sendings.get(type, None) - if event_wrapper is not None: - handler = event_wrapper.get('func') - conditions = event_wrapper.get('conditions') - - is_valid = all([(await c(self)) if inspect.iscoroutinefunction(c) else c(self) if inspect.isfunction(c) else c == True if isinstance(c, bool) else True for c in conditions]) - - if handler is not None and is_valid: - model = event_wrapper.get("model") - - data = payload.get('data') or {} - try: - validated_payload = model(self=self, **data) - except ValidationError as e: - print("ERROR", e) - await self.ws.send_json({"type": "error", "data": {"msg": "Oops there was an error"}}) - return - - validated_payload = dict_model(validated_payload, - exclude=["v__duplicate_kwargs", "args", 'kwargs', "self"]) - try: - parsed_payload = handler( - self, **validated_payload) - - await self.ws.send_json({'type': type, "data": dict_all(parsed_payload)}) - return - except Exception as e: - return - return - #print('pls') - await self.ws.send_json(payload) - #print('sent') - - async def receive(self, data): - event = data.get('type', None) - if event is not None: - event_wrapper = self.events.get(event, None) - if event_wrapper is not None: - handler = event_wrapper.get('func') - conditions = event_wrapper.get('conditions') - - is_valid = all([(await c(self)) if inspect.iscoroutinefunction(c) else c(self) if inspect.isfunction(c) else c == True if isinstance(c, bool) else True for c in conditions]) - - if handler is not None and is_valid: - model = event_wrapper.get("model") - - payload = data.get('data') or {} - try: - validated_payload = model(self=self, **payload) - except ValidationError as e: - await self.validation_error_handler(e) - return - - await handler(**{k: v for k, v in validated_payload.dict().items() if k not in ["v__duplicate_kwargs", "args", 'kwargs']}) - - async def disconnect(self): - pass - - async def run(self): - await self.connect() - try: - while True: - data = await self.ws.receive_json() - await self.receive(data) - except WebSocketDisconnect: - await self.disconnect() - - -class ConsumerManager: - def __init__(self): - self.active_connections: Dict[str, List[Consumer]] = {} - - def add(self, group: str, ws: Consumer): - - if group not in self.active_connections: - self.active_connections[group] = [] - #print("adding", ws, self.active_connections[group]) - if ws not in self.active_connections[group]: - #print('ACTUALLY ADDING') - self.active_connections[group].append(ws) - - def remove(self, group: str, ws: Consumer): - if group in self.active_connections: - if ws in self.active_connections[group]: - self.active_connections[group].remove(ws) - - async def broadcast(self, message, group: str, exclude: list[Consumer] = []): - if group in self.active_connections: - #print(self.active_connections[group], exclude) - for connection in list(set(self.active_connections[group])): - if connection not in exclude: - #print('SEND TO', connection, message) - await connection.send(message) - - - - -class Token(BaseModel): - token: str - - -def get_user_from_token(token: str, db: Session): - try: - decoded = jwt.decode(token=token, key=SECRET_KEY, - algorithms=[ALGORITHM]) - except exceptions.ExpiredSignatureError: - return False - - clientId = decoded.get('sub') - return get_user_from_clientId_db(clientId=clientId, db=db) - - -def check_same_member(member: Member, memberR: MemberRead): - return (member.user is not None and memberR.username == member.user.username) or (member.anonymous and memberR.reconnect_code == member.anonymous.reconnect_code) - -class RoomConsumer(Consumer): - - def __init__(self, ws: WebSocket, room: Room, manager: "RoomManager", db: Session): - self.room = room - self.ws = ws - self.manager = manager - self.db = db - self.member = None - # WS Utilities - - async def connect(self): - await self.ws.accept() - - async def direct_send(self, type: str, payload: Any): - await self.ws.send_json({'type': type, "data": payload}) - - async def send_to_admin(self, type: str, payload: Any, exclude: bool = False): - await self.manager.send_to_admin(self.room.id, {'type': type, "data": payload}) - - async def send_to(self, type: str, payload: Any,member_id, exclude: bool = False): - await self.manager.send_to(self.room.id, member_id,{'type': type, "data": payload}) - - async def broadcast(self, type, payload, exclude=False): - await self.manager.broadcast({"type": type, "data": payload}, self.room.id, exclude=[exclude == True and self]) - - def add_to_group(self): - self.manager.add(self.room.id, self) - - async def connect_self(self): - if isinstance(self.member, Member): - connect_member(self.member, self.db) - await self.broadcast(type="connect", payload={"member": serialize_member(self.member).dict()}, exclude=True) - - async def disconnect_self(self): - if isinstance(self.member, Member): - disconnect_member(self.member, self.db) - if self.member.waiting is False: - await self.broadcast(type="disconnect", payload={"member": serialize_member(self.member).dict()}) - else: - await self.broadcast(type="disconnect_waiter", payload={"waiter": serialize_member(self.member).dict()}) - - # DB Utilities - - # Received Events - @Consumer.event('login') - async def login(self, token: str | None = None, reconnect_code: str | None = None): - if token is not None: - member = get_member_from_token(token, self.room.id, self.db) - if member == False: - await self.direct_send(type="error", payload={"msg": "Token expired"}) - return - if member is None: - await self.direct_send(type="error", payload={"msg": "Utilisateur introuvable dans cette salle"}) - return - self.member = member - - await self.connect_self() - self.add_to_group() - await self.direct_send(type="loggedIn", payload={"member": serialize_member(self.member).dict()}) - - elif reconnect_code is not None: - member = get_member_from_reconnect_code( - reconnect_code, self.room.id, db=self.db) - if member is None: - await self.direct_send(type="error", payload={"msg": "Utilisateur introuvable dans cette salle"}) - return - - self.member = member - - await self.connect_self() - self.add_to_group() - await self.direct_send(type="loggedIn", payload={"member": serialize_member(self.member).dict()}) - if reconnect_code is None and token is None: - await self.direct_send(type="error", payload={"msg": "Veuillez spécifier une méthode de connection"}) - - @Consumer.event('join') - async def join(self, token: str | None = None, username: str | None = None): - if self.room.public == False: - if token is not None: - user = get_user_from_token(token, self.db) - - if user is None: - await self.direct_send(type="error", payload={"msg": "Utilisateur introuvable"}) - return - if user is False: - await self.direct_send(type="error", payload={"msg": "Token expired"}) - return - - waiter = create_user_waiter(user, self.room, self.db) - - if waiter.waiting is False: - self.member = waiter - # await self.connect_self() - self.add_to_group() - await self.connect_self() - await self.direct_send(type="loggedIn", payload={"member": serialize_member(self.member).dict()}) - return - - self.member = waiter - self.add_to_group() - await self.direct_send(type="waiting", payload={"waiter": serialize_member(self.member).dict()}) - await self.send_to_admin(type="waiter", payload={"waiter": serialize_member(self.member).dict()}) - - elif username is not None: - waiter = create_anonymous_waiter(username, self.room, self.db) - if waiter is None: - await self.direct_send(type="error", payload={"msg": "Nom d'utilisateur invalide ou indisponible"}) - return - self.member = waiter - self.add_to_group() - await self.direct_send(type="waiting", payload={"waiter": serialize_member(self.member).dict()}) - await self.send_to_admin(type="waiter", payload={"waiter": serialize_member(self.member).dict()}) - else: - if token is not None: - user = get_user_from_token(token, self.db) - if user is None: - await self.direct_send(type="error", payload={"msg": "Utilisateur introuvable"}) - return - if user is False: - await self.direct_send(type="error", payload={"msg": "Token expired"}) - return - - member = create_user_member(user, self.room, self.db) - if member is None: - return - self.member = member - self.add_to_group() - await self.broadcast(type="joined", payload={"member": serialize_member(self.member).dict()}, exclude=True) - await self.direct_send(type="accepted", payload={"member": serialize_member(self.member).dict()}) - elif username is not None: - member = create_anonymous_member(username, self.room, self.db) - if member is None: - await self.direct_send(type="error", data={"msg": "Nom d'utilisateur indisponible"}) - return - self.member = member - self.add_to_group() - - await self.broadcast(type="joined", payload={"member": serialize_member(self.member).dict()}, exclude=True) - await self.direct_send(type="accepted", payload={"member": serialize_member(self.member).dict()}) - - async def isAdminReceive(self): - is_admin = self.member is not None and self.member.is_admin == True - if not is_admin: - await self.direct_send(type="error", payload={"msg": "Vous n'avez pas la permission de faire ca"}) - return False - return True - - def isAdmin(self): - return self.member is not None and self.member.is_admin == True - - @Consumer.event('accept', conditions=[isAdminReceive]) - async def accept(self, waiter_id: str): - waiter = get_waiter(waiter_id, self.db) - if waiter is None: - await self.direct_send(type="error", payload={'msg': "Utilisateur introuvable"}) - return - member = accept_waiter(waiter, self.db) - await self.send_to(type="accepted", payload={"member": serialize_member(member).dict()}, member_id=waiter_id) - await self.broadcast(type="joined", payload={"member": serialize_member(member).dict()}) - - @Consumer.event('refuse', conditions=[isAdminReceive]) - async def accept(self, waiter_id: str): - waiter = get_waiter(waiter_id, self.db) - member = refuse_waiter(waiter, self.db) - await self.send_to(type="refused", payload={'waiter_id': waiter_id}, member_id=waiter_id) - # await self.broadcast(type="joined", payload={"member": serialize_member(member).dict()}) - - @Consumer.event('ping_room') - async def proom(self): - await self.broadcast(type='ping', payload={}, exclude=True) - async def hasRoom(self): - if self.member is None: - await self.direct_send(type="error", payload={"msg": "Vous n'êtes connecté à aucune salle"}) - return self.member - @Consumer.event('leave', conditions=[hasRoom]) - async def leave(self): - if self.member.is_admin is True: - await self.direct_send(type="error", payload={"msg": "Vous ne pouvez pas quitter une salle dont vous êtes l'administrateur"}) - return - member_obj = serialize_member(self.member).dict() - leave_room(self.member, self.db) - await self.direct_send(type="successfully_leaved", payload = {}) - await self.broadcast(type='leaved', payload ={"member": member_obj}) - self.member = None - - @Consumer.event('ban', conditions=[isAdminReceive]) - async def ban(self, member_id: str): - member = get_member(member_id, self.room.id, self.db) - if member == None: - await self.direct_send(type="error", payload={"msg": "Utilisateur introuvable"}) - return - if member.is_admin is True: - await self.direct_send(type="error", payload={"msg": "Vous ne pouvez pas bannir un administrateur"}) - return - member_serialized = serialize_member(member) - - leave_room(member, self.db) - await self.send_to(type="banned", payload={}, member_id=member.id_code) - await self.broadcast(type="leaved", payload = {"member": member_serialized.dict()}) - - def isMember(self): - return self.member is not None and self.member.waiting == False - - # Sending Events - @Consumer.sending("joined", conditions=[isMember]) - def joined(self, member: MemberRead): - if self.member.id_code == member.id_code: - raise ValueError("") - if self.member.is_admin == False: - member.reconnect_code = "" - return {"member": member} - - @Consumer.sending(['waiter', "disconnect_waiter"], conditions=[isAdmin]) - def waiter(self, waiter: Waiter): - return {"waiter": waiter} - - @Consumer.sending(['connect', "disconnect", ""]) - def connect_event(self, member: MemberRead): - if not self.member.is_admin: - member.reconnect_code = "" - return {"member": member} - - @Consumer.sending('disconnect') - def disconnect_event(self, member: MemberRead): - if not self.member.is_admin: - member.reconnect_code = "" - return {"member": member} - - @Consumer.sending('disconnect_waiter', conditions=[isAdmin]) - def disconnect_event(self, waiter: Waiter): - return {"waiter": waiter} - def isWaiter(self): - return self.member is not None and self.member.waiting == True - - @Consumer.sending("refused", conditions=[isWaiter]) - def refused(self, waiter_id: str): - self.member = None - self.manager.remove(self.room.id, self) - return {"waiter_id": waiter_id} - - @Consumer.sending("banned", conditions=[isMember]) - def banned(self): - self.member = None - self.manager.remove(self.room.id, self) - return {} - - @Consumer.sending('ping', conditions=[isMember]) - def ping(self): - return {} - async def disconnect(self): - print("DISCONNECT", self.member) - self.manager.remove(self.room.id, self) - await self.disconnect_self() - - - -class RoomManager: - def __init__(self): - self.active_connections: Dict[str, List[RoomConsumer]] = {} - - def add(self, group: str, ws: RoomConsumer): - - if group not in self.active_connections: - self.active_connections[group] = [] - #print("adding", ws, self.active_connections[group]) - if ws not in self.active_connections[group]: - #print('ACTUALLY ADDING') - self.active_connections[group].append(ws) - - def remove(self, group: str, ws: RoomConsumer): - if group in self.active_connections: - if ws in self.active_connections[group]: - self.active_connections[group].remove(ws) - - async def broadcast(self, message, group: str, exclude: list[RoomConsumer] = []): - if group in self.active_connections: - #print(self.active_connections[group], exclude) - for connection in list(set(self.active_connections[group])): - if connection not in exclude: - #print('SEND TO', connection, message) - await connection.send(message) - - async def send_to(self, group, id_code, msg): - if group in self.active_connections: - members = [c for c in self.active_connections[group] if c.member.id_code == id_code] - for m in members: - await m.send(msg) - - async def send_to_admin(self, group, msg): - if group in self.active_connections: - members = [c for c in self.active_connections[group] if c.member.is_admin == True] - for m in members: - await m.send(msg) - manager = RoomManager() + @router.websocket('/ws/room/{room_id}') async def room_ws(ws: WebSocket, room: Room | None = Depends(check_room), db: Session = Depends(get_session)): if room is None: diff --git a/backend/api/services/websocket.py b/backend/api/services/websocket.py new file mode 100644 index 0000000..8987ffc --- /dev/null +++ b/backend/api/services/websocket.py @@ -0,0 +1,137 @@ +from typing import List, Callable, Any, Dict +from pydantic import validate_arguments, BaseModel +from fastapi.websockets import WebSocketDisconnect, WebSocket +from pydantic.error_wrappers import ValidationError +import inspect + +def make_event_decorator(eventsDict): + def _(name: str | List, conditions: List[Callable | bool] = []): + def add_event(func): + model = validate_arguments(func).model + if type(name) == str: + eventsDict[name] = {"func": func, + "conditions": conditions, "model": model} + if type(name) == list: + for n in name: + eventsDict[n] = {"func": func, + "conditions": conditions, "model": model} + return func + return add_event + return _ + + + +def dict_model(model: BaseModel, exclude: List[str]): + value = {} + for n, f in model: + if n not in exclude: + value[n] = f + return value +def dict_all(obj: Any): + if isinstance(obj, dict): + value = {} + for k, v in obj.items(): + if isinstance(v, dict): + v = dict_all(v) + value[k] = dict(v) + elif isinstance(v, BaseModel): + value[k] = dict(v) + else: + try: + value[k] = dict(v) + except: + value[k] = v + return value + return dict(obj) + + +class Event(BaseModel): + func: Callable + conditions: List[Callable | bool] + model: BaseModel + +class Consumer: + events: Dict[str, Event] = {} + sendings: Dict[str, Any] = {} + event = make_event_decorator(events) + sending = make_event_decorator(sendings) + + def __init__(self, ws: WebSocket): + self.ws: WebSocket = ws + #self.events: Dict[str, Callable] = {} + + async def connect(self): + pass + + async def validation_error_handler(self, e: ValidationError): + errors = e.errors() + await self.ws.send_json({"type": "error", "data": {"detail": [{ers['loc'][-1]: ers['msg']} for ers in errors]}}) + + async def send(self, payload): + type = payload.get('type', None) + print('TYPE', type, self.member) + if type is not None: + event_wrapper = self.sendings.get(type, None) + if event_wrapper is not None: + handler = event_wrapper.get('func') + conditions = event_wrapper.get('conditions') + + is_valid = all([(await c(self)) if inspect.iscoroutinefunction(c) else c(self) if inspect.isfunction(c) else c == True if isinstance(c, bool) else True for c in conditions]) + + if handler is not None and is_valid: + model = event_wrapper.get("model") + + data = payload.get('data') or {} + try: + validated_payload = model(self=self, **data) + except ValidationError as e: + print("ERROR", e) + await self.ws.send_json({"type": "error", "data": {"msg": "Oops there was an error"}}) + return + + validated_payload = dict_model(validated_payload, + exclude=["v__duplicate_kwargs", "args", 'kwargs', "self"]) + try: + parsed_payload = handler( + self, **validated_payload) + + await self.ws.send_json({'type': type, "data": dict_all(parsed_payload)}) + return + except Exception as e: + return + return + await self.ws.send_json(payload) + + async def receive(self, data): + event = data.get('type', None) + if event is not None: + event_wrapper = self.events.get(event, None) + if event_wrapper is not None: + handler = event_wrapper.get('func') + conditions = event_wrapper.get('conditions') + + is_valid = all([(await c(self)) if inspect.iscoroutinefunction(c) else c(self) if inspect.isfunction(c) else c == True if isinstance(c, bool) else True for c in conditions]) + + if handler is not None and is_valid: + model = event_wrapper.get("model") + + payload = data.get('data') or {} + try: + validated_payload = model(self=self, **payload) + except ValidationError as e: + await self.validation_error_handler(e) + return + + await handler(**{k: v for k, v in validated_payload.dict().items() if k not in ["v__duplicate_kwargs", "args", 'kwargs']}) + + async def disconnect(self): + pass + + async def run(self): + await self.connect() + try: + while True: + data = await self.ws.receive_json() + await self.receive(data) + except WebSocketDisconnect: + await self.disconnect() diff --git a/backend/api/tests/test_room.py b/backend/api/tests/test_room.py index 93a6781..866d189 100644 --- a/backend/api/tests/test_room.py +++ b/backend/api/tests/test_room.py @@ -138,7 +138,8 @@ def test_join_waiter_not_found(client: TestClient): admin.receive_json() admin.send_json({"type": "accept", "data": {"waiter_id": "OOOO"}}) data = admin.receive_json() - assert data == {"type": "error", "data": {"msg": "Utilisateur introuvable"}} + assert data == {"type": "error", "data": { + "msg": "Utilisateur en list d'attente introuvable"}} def test_join_no_auth(client: TestClient): room = test_create_room_no_auth(client=client) @@ -219,7 +220,8 @@ def test_join_auth_refused(client: TestClient): "waiter": {"waiter_id": waiter_id, "username": "lilian2"}}} admin.send_json({"type": "refuse", "data": { "waiter_id": waiter_id}}) - + adata = admin.receive_json() + assert adata == {"type": "successfullyRefused", "data": {"waiter_id": waiter_id}} mdata = member.receive_json() assert mdata == {"type": "refused", "data": { "waiter_id": waiter_id}}