generateur_v3/backend/api/routes/room/routes.py

556 lines
22 KiB
Python
Raw Normal View History

2022-09-25 22:32:19 +02:00
from database.room.crud import create_room_db, delete_member, get_member, get_member_from_token, get_member_from_reconnect_code, connect_member, disconnect_member, create_anonymous_member, create_anonymous_waiter, create_user_member, create_user_waiter, get_waiter, accept_waiter, leave_room, refuse_waiter, check_room
2022-09-21 22:31:50 +02:00
from pydantic.error_wrappers import ValidationError
from pydantic import validate_arguments
2022-09-18 22:43:04 +02:00
from sqlmodel import select
from pydantic import BaseModel
from typing import Any, Callable, Dict, List, Optional
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect, status, Query
from config import ALGORITHM, SECRET_KEY
from database.auth.crud import get_user_from_clientId_db
from database.auth.models import User
from database.db import get_session
2022-09-25 22:32:19 +02:00
2022-09-18 22:43:04 +02:00
from sqlmodel import Session
2022-09-21 22:31:50 +02:00
from database.room.models import Anonymous, MemberSerializer, MemberWithRelations, MemberRead, Room, RoomCreate, RoomRead, Member
2022-09-18 22:43:04 +02:00
from services.auth import get_current_user_optional
from fastapi.exceptions import HTTPException
from jose import jwt, exceptions
2022-09-21 22:31:50 +02:00
import inspect
2022-09-16 21:50:55 +02:00
# Router grouping every HTTP and WebSocket endpoint of the room feature.
router = APIRouter(
    tags=["room"],
)
2022-09-18 22:43:04 +02:00
# Response model of POST /room: the created room together with the serialized
# member returned by create_room_db (presumably the room's creator).
# Kept docstring-free on purpose: pydantic would publish a docstring as the
# schema description.
class RoomAndMember(BaseModel):
    room: RoomRead
    member: MemberRead
2022-09-21 22:31:50 +02:00
# Public representation of a member still waiting for admin approval
# (produced by serialize_member when member.waiting is set).
class Waiter(BaseModel):
    # Display name of the waiting user/anonymous visitor.
    username: str
    # The member's id_code; used by admins to accept/refuse the waiter.
    waiter_id: str
def serialize_member(member: Member) -> MemberRead | Waiter:
    """Convert a ``Member`` row into its public representation.

    Members still waiting for approval become a ``Waiter``; all others become
    a ``MemberRead``. The display name comes from the linked ``user`` when
    present, otherwise from the ``anonymous`` record; ``reconnect_code`` is
    taken from that record when it has one, ``""`` otherwise.
    """
    member_obj = member.user or member.anonymous
    # Truthiness instead of the previous ``== False`` / ``== True`` pair: a
    # ``waiting`` value that was neither fell through and returned None, which
    # every caller then dereferenced with ``.dict()``.
    if member.waiting:
        return Waiter(username=member_obj.username, waiter_id=member.id_code)
    return MemberRead(
        username=member_obj.username,
        reconnect_code=getattr(member_obj, "reconnect_code", ""),
        isUser=member.user_id is not None,
        isAdmin=member.is_admin,
        id_code=member.id_code,
    )
2022-09-21 22:31:50 +02:00
2022-09-18 22:43:04 +02:00
@router.post('/room', response_model=RoomAndMember)
def create_room(
    room: RoomCreate,
    username: Optional[str] = Query(default=None, max_length=20),
    user: User | None = Depends(get_current_user_optional),
    db: Session = Depends(get_session),
):
    # Create a room, then answer with the room and its serialized first member.
    # (No docstring: FastAPI would publish it as the OpenAPI description.)
    created = create_room_db(room=room, user=user, username=username, db=db)
    first_member = serialize_member(created['member'])
    return {'room': created['room'], "member": first_member}
2022-09-18 22:43:04 +02:00
class ConnectionManager:
    """Track raw ``WebSocket`` connections grouped under a string key.

    NOTE(review): not used by the code visible in this file (the consumers
    go through ``ConsumerManager`` / ``RoomManager``).
    """

    def __init__(self):
        self.active_connections: Dict[str, List[WebSocket]] = {}

    def add(self, group: str, ws: WebSocket):
        """Register ``ws`` in ``group``; registering twice is a no-op."""
        connections = self.active_connections.setdefault(group, [])
        if ws not in connections:
            connections.append(ws)

    def remove(self, group: str, ws: WebSocket):
        """Unregister ``ws`` from ``group``; unknown groups/sockets are ignored."""
        connections = self.active_connections.get(group)
        if connections is not None and ws in connections:
            connections.remove(ws)

    async def broadcast(self, message, group: str, exclude: Optional[List[WebSocket]] = None):
        """Send ``message`` as JSON to every socket of ``group`` not in ``exclude``."""
        # ``None`` instead of a shared mutable ``[]`` default.
        excluded = exclude or []
        # Iterate over a snapshot: add()/remove() can run while we are
        # suspended in send_json, and mutating the list being iterated would
        # make the loop skip connections.
        for connection in list(self.active_connections.get(group, [])):
            if connection not in excluded:
                await connection.send_json(message)
def make_event_decorator(eventsDict):
    """Build a decorator factory that registers handlers into ``eventsDict``.

    Usage: ``@decorator(name, conditions=[...])``. ``name`` is a single event
    name or a list/tuple of names. Each name is mapped to a dict with keys
    ``"func"`` (the handler), ``"conditions"`` (callables or booleans gating
    the handler) and ``"model"`` (the argument model that pydantic's
    ``validate_arguments`` builds from the handler's signature, later used to
    validate payloads). The handler itself is returned unchanged.
    """
    def _(name: str | List, conditions: Optional[List[Callable | bool]] = None):
        names = [name] if isinstance(name, str) else list(name)
        # Fresh list per registration instead of a shared mutable default.
        checks = [] if conditions is None else list(conditions)

        def add_event(func):
            model = validate_arguments(func).model
            for event_name in names:
                eventsDict[event_name] = {"func": func, "conditions": checks, "model": model}
            return func
        return add_event
    return _
2022-09-21 22:31:50 +02:00
# Shape of one registry entry of Consumer.events / Consumer.sendings.
# NOTE(review): make_event_decorator stores plain dicts with these keys
# ("func", "conditions", "model"), not Event instances, so this model is only
# used as a type annotation. "model" actually holds a pydantic model *class*.
class Event(BaseModel):
    # Handler: awaited for received events, called synchronously for sendings.
    func: Callable
    # Gates evaluated before the handler: functions (sync or async, called
    # with the consumer) or plain booleans.
    conditions: List[Callable | bool]
    # Argument-validation model built by validate_arguments from the handler.
    model: BaseModel
def dict_model(model: BaseModel, exclude: List[str]):
    """Return the model's fields as a shallow dict, skipping names in ``exclude``.

    Iterating a pydantic model yields ``(field_name, value)`` pairs; values
    are kept as-is (nested models are not converted).
    """
    return {
        field_name: field_value
        for field_name, field_value in model
        if field_name not in exclude
    }
def dict_all(obj: Any):
    """Convert ``obj`` into a dict suitable for ``send_json``.

    For a ``dict``: nested dicts are converted recursively, pydantic models
    are turned into a shallow ``dict(model)`` (field name -> value), and all
    other values are kept unchanged. Any non-dict ``obj`` is passed to
    ``dict()`` directly.
    """
    if not isinstance(obj, dict):
        return dict(obj)
    converted = {}
    for key, value in obj.items():
        if isinstance(value, dict):
            converted[key] = dict_all(value)
        elif isinstance(value, BaseModel):
            converted[key] = dict(value)
        else:
            # Previously ``dict(value)`` was attempted under a bare except,
            # which silently rewrote plain values (``""`` became ``{}``, a
            # list of pairs became a dict). Plain values now pass through.
            converted[key] = value
    return converted
2022-09-18 22:43:04 +02:00
class Consumer:
    """Base class dispatching WebSocket JSON messages to registered handlers.

    Messages have the shape ``{"type": <event name>, "data": {...}}``.

    * ``@Consumer.event(name, conditions=[...])`` registers an async handler
      for messages received from the client (see ``receive``).
    * ``@Consumer.sending(name, conditions=[...])`` registers a sync handler
      that filters/transforms a message before it is sent to this client
      (see ``send``).

    Both registries are class attributes, so they are shared by every
    subclass.
    """

    # Fields added by validate_arguments' model that are not handler arguments.
    _INTERNAL_FIELDS = ("v__duplicate_kwargs", "args", "kwargs")

    events: Dict[str, Event] = {}
    sendings: Dict[str, Any] = {}

    event = make_event_decorator(events)
    sending = make_event_decorator(sendings)

    def __init__(self, ws: WebSocket):
        self.ws: WebSocket = ws

    async def connect(self):
        """Hook run before the receive loop starts; no-op by default."""

    async def validation_error_handler(self, e: ValidationError):
        """Report a payload validation failure (field -> message) to the client."""
        errors = e.errors()
        await self.ws.send_json({"type": "error", "data": {"detail": [{ers['loc'][-1]: ers['msg']} for ers in errors]}})

    async def _check_conditions(self, conditions) -> bool:
        """Evaluate every condition and return True if all of them pass.

        Async functions are awaited and plain functions called with the
        consumer; booleans are used as-is; any other object counts as
        passing. All conditions are evaluated (no short-circuit) because a
        condition may report an error to the client as a side effect.
        """
        results = []
        for condition in conditions:
            if inspect.iscoroutinefunction(condition):
                results.append(await condition(self))
            elif inspect.isfunction(condition):
                results.append(condition(self))
            elif isinstance(condition, bool):
                results.append(condition)
            else:
                results.append(True)
        return all(results)

    async def send(self, payload):
        """Send ``payload`` to this client through its sending handler, if any.

        When ``payload["type"]`` has a registered sending handler, the message
        is sent only if all its conditions pass; the handler receives the
        validated ``data`` fields and its return value becomes the new
        ``data``. A handler may raise to suppress the message. Payloads with
        no type, or a type without a handler, are forwarded unchanged.
        """
        event_type = payload.get('type', None)
        event_wrapper = self.sendings.get(event_type) if event_type is not None else None
        if event_wrapper is None:
            await self.ws.send_json(payload)
            return

        handler = event_wrapper.get('func')
        is_valid = await self._check_conditions(event_wrapper.get('conditions') or [])
        if handler is None or not is_valid:
            return

        model = event_wrapper.get("model")
        data = payload.get('data') or {}
        try:
            validated_payload = model(self=self, **data)
        except ValidationError:
            await self.ws.send_json({"type": "error", "data": {"msg": "Oops there was an error"}})
            return
        arguments = dict_model(validated_payload, exclude=[*self._INTERNAL_FIELDS, "self"])
        try:
            parsed_payload = handler(self, **arguments)
            await self.ws.send_json({'type': event_type, "data": dict_all(parsed_payload)})
        except Exception:
            # Deliberate best-effort: sending handlers raise to suppress a
            # message for this recipient, and one failing recipient must not
            # abort a broadcast to the others.
            return

    async def receive(self, data):
        """Dispatch one client message to its registered event handler.

        Unknown or missing event types are ignored; failed conditions skip the
        handler; malformed or invalid payloads are reported to the client.
        """
        # Client-controlled input: a non-object message or "data" field (or a
        # "self" key) used to raise and tear down the whole receive loop.
        if not isinstance(data, dict):
            await self.ws.send_json({"type": "error", "data": {"msg": "Invalid message"}})
            return
        event = data.get('type', None)
        if event is None:
            return
        event_wrapper = self.events.get(event, None)
        if event_wrapper is None:
            return

        handler = event_wrapper.get('func')
        is_valid = await self._check_conditions(event_wrapper.get('conditions') or [])
        if handler is None or not is_valid:
            return

        payload = data.get('data') or {}
        if not isinstance(payload, dict) or 'self' in payload:
            await self.ws.send_json({"type": "error", "data": {"msg": "Invalid message"}})
            return
        model = event_wrapper.get("model")
        try:
            validated_payload = model(self=self, **payload)
        except ValidationError as e:
            await self.validation_error_handler(e)
            return
        await handler(**{k: v for k, v in validated_payload.dict().items() if k not in self._INTERNAL_FIELDS})

    async def disconnect(self):
        """Hook run once the receive loop has ended; no-op by default."""

    async def run(self):
        """Run the connect hook, process messages until the socket ends, then clean up."""
        await self.connect()
        try:
            while True:
                data = await self.ws.receive_json()
                await self.receive(data)
        except WebSocketDisconnect:
            pass
        finally:
            # Clean up on *any* exit, not only a clean disconnect; otherwise an
            # unexpected error left a dead consumer registered in its group.
            await self.disconnect()
2022-09-21 22:31:50 +02:00
class ConsumerManager:
    """Track ``Consumer`` instances grouped under a string key.

    NOTE(review): not used by the code visible in this file (rooms use
    ``RoomManager``).
    """

    def __init__(self):
        self.active_connections: Dict[str, List[Consumer]] = {}

    def add(self, group: str, ws: Consumer):
        """Register ``ws`` in ``group``; registering twice is a no-op."""
        consumers = self.active_connections.setdefault(group, [])
        if ws not in consumers:
            consumers.append(ws)

    def remove(self, group: str, ws: Consumer):
        """Unregister ``ws`` from ``group``; unknown groups/consumers are ignored."""
        consumers = self.active_connections.get(group)
        if consumers is not None and ws in consumers:
            consumers.remove(ws)

    async def broadcast(self, message, group: str, exclude: Optional[List[Consumer]] = None):
        """Pass ``message`` to ``send`` of every consumer of ``group`` not in ``exclude``."""
        excluded = exclude or []
        # Snapshot in insertion order (add() already prevents duplicates; the
        # previous list(set(...)) only made the delivery order arbitrary).
        for consumer in list(self.active_connections.get(group, [])):
            if consumer not in excluded:
                await consumer.send(message)
2022-09-25 22:32:19 +02:00
2022-09-21 22:31:50 +02:00
2022-09-18 22:43:04 +02:00
# Pydantic model wrapping a single JWT string.
# NOTE(review): not referenced by the code visible in this file.
class Token(BaseModel):
    token: str
def get_user_from_token(token: str, db: Session):
    """Resolve the user referenced by the ``sub`` claim of a JWT.

    Returns:
        ``False`` if the token has expired; ``None`` if the token cannot be
        decoded (bad signature, malformed, ...) or has no ``sub`` claim;
        otherwise the result of ``get_user_from_clientId_db`` for that claim
        (presumably ``None`` when no user matches).
    """
    try:
        decoded = jwt.decode(token=token, key=SECRET_KEY,
                             algorithms=[ALGORITHM])
    except exceptions.ExpiredSignatureError:
        return False
    except exceptions.JWTError:
        # Any other decoding failure used to escape and abort the caller's
        # WebSocket loop; callers already treat None as "user not found".
        return None
    clientId = decoded.get('sub')
    if clientId is None:
        return None
    return get_user_from_clientId_db(clientId=clientId, db=db)
2022-09-25 22:32:19 +02:00
def check_same_member(member: Member, memberR: MemberRead) -> bool:
    """Return True if the serialized ``memberR`` designates ``member``.

    Registered users are matched by username, anonymous members by their
    reconnect code. Always returns a real bool (the previous and/or chain
    could yield None).
    """
    if member.user is not None and memberR.username == member.user.username:
        return True
    return bool(member.anonymous) and memberR.reconnect_code == member.anonymous.reconnect_code
2022-09-18 22:43:04 +02:00
class RoomConsumer(Consumer):
    """WebSocket consumer handling one client connection to a room.

    Received events cover login/join (as member or waiter), admin actions
    (accept, refuse, ban), leaving and pinging; sending handlers filter the
    room's outgoing events per recipient.
    """

    def __init__(self, ws: WebSocket, room: Room, manager: "RoomManager", db: Session):
        self.room = room
        self.ws = ws
        self.manager = manager
        self.db = db
        # Member row this connection acts as; None until login/join succeeds
        # and again after leaving, being refused or being banned.
        self.member = None

    # WS Utilities

    async def connect(self):
        await self.ws.accept()

    async def direct_send(self, type: str, payload: Any):
        """Send a message to this connection only, bypassing sending handlers."""
        await self.ws.send_json({'type': type, "data": payload})

    async def send_to_admin(self, type: str, payload: Any, exclude: bool = False):
        """Send a message to the room's admin connections (``exclude`` is unused)."""
        await self.manager.send_to_admin(self.room.id, {'type': type, "data": payload})

    async def send_to(self, type: str, payload: Any, member_id, exclude: bool = False):
        """Send a message to the connections of member ``member_id`` (``exclude`` is unused)."""
        await self.manager.send_to(self.room.id, member_id, {'type': type, "data": payload})

    async def broadcast(self, type, payload, exclude=False):
        """Send a message to every connection of the room; ``exclude`` skips this one."""
        excluded = [self] if exclude else []
        await self.manager.broadcast({"type": type, "data": payload}, self.room.id, exclude=excluded)

    def add_to_group(self):
        self.manager.add(self.room.id, self)

    async def connect_self(self):
        """Mark the member connected and announce it to the other connections."""
        if isinstance(self.member, Member):
            connect_member(self.member, self.db)
            await self.broadcast(type="connect", payload={"member": serialize_member(self.member).dict()}, exclude=True)

    async def disconnect_self(self):
        """Mark the member disconnected and announce it (as member or waiter)."""
        if isinstance(self.member, Member):
            disconnect_member(self.member, self.db)
            if self.member.waiting:
                await self.broadcast(type="disconnect_waiter", payload={"waiter": serialize_member(self.member).dict()})
            else:
                await self.broadcast(type="disconnect", payload={"member": serialize_member(self.member).dict()})

    async def _resolve_user(self, token: str):
        """Decode ``token`` into a user; report failures to the client and return None."""
        user = get_user_from_token(token, self.db)
        if user is None:
            await self.direct_send(type="error", payload={"msg": "Utilisateur introuvable"})
            return None
        if user is False:
            await self.direct_send(type="error", payload={"msg": "Token expired"})
            return None
        return user

    # DB Utilities
    # Received Events

    @Consumer.event('login')
    async def login(self, token: str | None = None, reconnect_code: str | None = None):
        """Log this connection in as an existing member; ``token`` wins over ``reconnect_code``."""
        if token is None and reconnect_code is None:
            await self.direct_send(type="error", payload={"msg": "Veuillez spécifier une méthode de connection"})
            return
        if token is not None:
            member = get_member_from_token(token, self.room.id, self.db)
            if member == False:
                await self.direct_send(type="error", payload={"msg": "Token expired"})
                return
        else:
            member = get_member_from_reconnect_code(reconnect_code, self.room.id, db=self.db)
        if member is None:
            await self.direct_send(type="error", payload={"msg": "Utilisateur introuvable dans cette salle"})
            return
        self.member = member
        await self.connect_self()
        self.add_to_group()
        await self.direct_send(type="loggedIn", payload={"member": serialize_member(self.member).dict()})

    @Consumer.event('join')
    async def join(self, token: str | None = None, username: str | None = None):
        """Join as a waiter (private room) or directly as a member (public room).

        A JWT ``token`` takes precedence over an anonymous ``username``;
        when neither is given nothing happens.
        """
        if self.room.public == False:
            await self._join_private(token, username)
        else:
            await self._join_public(token, username)

    async def _join_private(self, token, username):
        """Register a waiter and notify the admins (private rooms)."""
        if token is not None:
            user = await self._resolve_user(token)
            if user is None:
                return
            waiter = create_user_waiter(user, self.room, self.db)
            self.member = waiter
            self.add_to_group()
            if waiter.waiting is False:
                # create_user_waiter returned a non-waiting member (presumably
                # the user already belongs to the room): log straight in.
                await self.connect_self()
                await self.direct_send(type="loggedIn", payload={"member": serialize_member(self.member).dict()})
                return
        elif username is not None:
            waiter = create_anonymous_waiter(username, self.room, self.db)
            if waiter is None:
                await self.direct_send(type="error", payload={"msg": "Nom d'utilisateur invalide ou indisponible"})
                return
            self.member = waiter
            self.add_to_group()
        else:
            return
        await self.direct_send(type="waiting", payload={"waiter": serialize_member(self.member).dict()})
        await self.send_to_admin(type="waiter", payload={"waiter": serialize_member(self.member).dict()})

    async def _join_public(self, token, username):
        """Create a member directly and announce it (public rooms)."""
        if token is not None:
            user = await self._resolve_user(token)
            if user is None:
                return
            member = create_user_member(user, self.room, self.db)
            if member is None:
                # NOTE(review): silently ignored, unlike the anonymous branch — confirm intent.
                return
        elif username is not None:
            member = create_anonymous_member(username, self.room, self.db)
            if member is None:
                # Was ``data=...``, which is not a direct_send parameter (TypeError).
                await self.direct_send(type="error", payload={"msg": "Nom d'utilisateur indisponible"})
                return
        else:
            return
        self.member = member
        self.add_to_group()
        await self.broadcast(type="joined", payload={"member": serialize_member(self.member).dict()}, exclude=True)
        await self.direct_send(type="accepted", payload={"member": serialize_member(self.member).dict()})

    async def isAdminReceive(self):
        """Condition for admin-only received events; reports an error otherwise."""
        is_admin = self.member is not None and self.member.is_admin == True
        if not is_admin:
            await self.direct_send(type="error", payload={"msg": "Vous n'avez pas la permission de faire ca"})
            return False
        return True

    def isAdmin(self):
        return self.member is not None and self.member.is_admin == True

    @Consumer.event('accept', conditions=[isAdminReceive])
    async def accept(self, waiter_id: str):
        """Admin only: accept a waiter and announce the new member to the room."""
        waiter = get_waiter(waiter_id, self.db)
        if waiter is None:
            await self.direct_send(type="error", payload={'msg': "Utilisateur introuvable"})
            return
        # NOTE(review): get_waiter is not given self.room.id, so nothing here
        # verifies the waiter belongs to this admin's room — confirm.
        member = accept_waiter(waiter, self.db)
        await self.send_to(type="accepted", payload={"member": serialize_member(member).dict()}, member_id=waiter_id)
        await self.broadcast(type="joined", payload={"member": serialize_member(member).dict()})

    # Previously also named ``accept``, shadowing the method above.
    @Consumer.event('refuse', conditions=[isAdminReceive])
    async def refuse(self, waiter_id: str):
        """Admin only: refuse a waiter and notify their connection."""
        waiter = get_waiter(waiter_id, self.db)
        if waiter is None:
            await self.direct_send(type="error", payload={'msg': "Utilisateur introuvable"})
            return
        refuse_waiter(waiter, self.db)
        await self.send_to(type="refused", payload={'waiter_id': waiter_id}, member_id=waiter_id)

    @Consumer.event('ping_room')
    async def proom(self):
        await self.broadcast(type='ping', payload={}, exclude=True)

    async def hasRoom(self):
        """Condition requiring a logged-in member; reports an error otherwise."""
        if self.member is None:
            await self.direct_send(type="error", payload={"msg": "Vous n'êtes connecté à aucune salle"})
        return self.member

    @Consumer.event('leave', conditions=[hasRoom])
    async def leave(self):
        """Leave the room (forbidden for its administrator)."""
        if self.member.is_admin is True:
            await self.direct_send(type="error", payload={"msg": "Vous ne pouvez pas quitter une salle dont vous êtes l'administrateur"})
            return
        member_obj = serialize_member(self.member).dict()
        leave_room(self.member, self.db)
        await self.direct_send(type="successfully_leaved", payload={})
        await self.broadcast(type='leaved', payload={"member": member_obj})
        self.member = None
        # Also leave the manager group: a consumer left registered with
        # member None broke RoomManager.send_to / send_to_admin, which read
        # ``member`` on every connection of the room.
        self.manager.remove(self.room.id, self)

    @Consumer.event('ban', conditions=[isAdminReceive])
    async def ban(self, member_id: str):
        """Admin only: remove a (non-admin) member from the room and notify them."""
        member = get_member(member_id, self.room.id, self.db)
        if member is None:
            await self.direct_send(type="error", payload={"msg": "Utilisateur introuvable"})
            return
        if member.is_admin is True:
            await self.direct_send(type="error", payload={"msg": "Vous ne pouvez pas bannir un administrateur"})
            return
        member_serialized = serialize_member(member)
        # Read before leave_room, which may delete/expire the row.
        banned_id = member.id_code
        leave_room(member, self.db)
        await self.send_to(type="banned", payload={}, member_id=banned_id)
        await self.broadcast(type="leaved", payload={"member": member_serialized.dict()})

    def isMember(self):
        return self.member is not None and self.member.waiting == False

    # Sending Events

    @Consumer.sending("joined", conditions=[isMember])
    def joined(self, member: MemberRead):
        if self.member.id_code == member.id_code:
            # Raising suppresses the message (Consumer.send swallows handler
            # errors): a member is not told about their own join.
            raise ValueError("")
        if self.member.is_admin == False:
            member.reconnect_code = ""
        return {"member": member}

    @Consumer.sending(['waiter', "disconnect_waiter"], conditions=[isAdmin])
    def waiter(self, waiter: Waiter):
        return {"waiter": waiter}

    # Single handler for both events; the previous duplicate definitions
    # (and a stray "" registration) are removed.
    @Consumer.sending(['connect', "disconnect"])
    def connect_event(self, member: MemberRead):
        # Non-admin recipients never see reconnect codes.
        if not self.member.is_admin:
            member.reconnect_code = ""
        return {"member": member}

    def isWaiter(self):
        return self.member is not None and self.member.waiting == True

    @Consumer.sending("refused", conditions=[isWaiter])
    def refused(self, waiter_id: str):
        self.member = None
        self.manager.remove(self.room.id, self)
        return {"waiter_id": waiter_id}

    @Consumer.sending("banned", conditions=[isMember])
    def banned(self):
        self.member = None
        self.manager.remove(self.room.id, self)
        return {}

    @Consumer.sending('ping', conditions=[isMember])
    def ping(self):
        return {}

    async def disconnect(self):
        """Leave the room group, then mark the member disconnected."""
        self.manager.remove(self.room.id, self)
        await self.disconnect_self()
class RoomManager:
    """Track the ``RoomConsumer`` connections of each room, keyed by room id."""

    def __init__(self):
        self.active_connections: Dict[str, List[RoomConsumer]] = {}

    def add(self, group: str, ws: RoomConsumer):
        """Register ``ws`` in ``group``; registering twice is a no-op."""
        consumers = self.active_connections.setdefault(group, [])
        if ws not in consumers:
            consumers.append(ws)

    def remove(self, group: str, ws: RoomConsumer):
        """Unregister ``ws`` from ``group``; unknown groups/consumers are ignored."""
        consumers = self.active_connections.get(group)
        if consumers is not None and ws in consumers:
            consumers.remove(ws)

    def _snapshot(self, group):
        """Copy of ``group``'s consumers (empty if unknown), safe across awaits."""
        return list(self.active_connections.get(group, []))

    async def broadcast(self, message, group: str, exclude: Optional[List[RoomConsumer]] = None):
        """Pass ``message`` to ``send`` of every consumer of ``group`` not in ``exclude``."""
        excluded = exclude or []
        for consumer in self._snapshot(group):
            if consumer not in excluded:
                await consumer.send(message)

    async def send_to(self, group, id_code, msg):
        """Send ``msg`` to every consumer logged in as member ``id_code``."""
        # Consumers without a member (e.g. after leaving) are skipped instead
        # of raising AttributeError and aborting delivery.
        targets = [c for c in self._snapshot(group)
                   if c.member is not None and c.member.id_code == id_code]
        for target in targets:
            await target.send(msg)

    async def send_to_admin(self, group, msg):
        """Send ``msg`` to every consumer whose member is an admin."""
        targets = [c for c in self._snapshot(group)
                   if c.member is not None and c.member.is_admin]
        for target in targets:
            await target.send(msg)


# Process-wide registry shared by every room WebSocket endpoint.
manager = RoomManager()
2022-09-21 22:31:50 +02:00
2022-09-18 22:43:04 +02:00
@router.websocket('/ws/room/{room_id}')
async def room_ws(ws: WebSocket, room: Room | None = Depends(check_room), db: Session = Depends(get_session)):
    """WebSocket endpoint of a room: run a ``RoomConsumer`` until the socket ends.

    An unknown room is rejected by closing the handshake with code 1008
    (policy violation); ``HTTPException`` is meant for HTTP routes, not
    WebSocket routes.
    """
    if room is None:
        await ws.close(code=status.WS_1008_POLICY_VIOLATION)
        return
    consumer = RoomConsumer(ws=ws, room=room, manager=manager, db=db)
    await consumer.run()