room first part
This commit is contained in:
parent
8391c2e64a
commit
b639c5a88e
@ -1,37 +1,40 @@
|
|||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
from sqlmodel import Session, select, col
|
from sqlmodel import Session, select, col
|
||||||
from database.db import get_session
|
from database.db import get_session
|
||||||
from database.room.models import Anonymous, Member, Room, RoomCreate
|
from database.room.models import Anonymous, Member, Room, RoomCreate, Waiter, MemberRead
|
||||||
from database.auth.models import User
|
from database.auth.models import User
|
||||||
from services.database import generate_unique_code
|
from services.database import generate_unique_code
|
||||||
from database.auth.crud import get_user_from_token
|
from database.auth.crud import get_user_from_token
|
||||||
|
|
||||||
|
|
||||||
def check_room(room_id: str, db: Session = Depends(get_session)):
|
def check_room(room_id: str, db: Session = Depends(get_session)):
|
||||||
room = db.exec(select(Room).where(Room.id_code == room_id)).first()
|
room = db.exec(select(Room).where(Room.id_code == room_id)).first()
|
||||||
return room
|
return room
|
||||||
|
|
||||||
|
|
||||||
def create_room_db(*,room: RoomCreate, user: User | None = None, username: str | None = None, db: Session):
|
def create_room_db(*, room: RoomCreate, user: User | None = None, username: str | None = None, db: Session):
|
||||||
id_code = generate_unique_code(Room,s=db)
|
id_code = generate_unique_code(Room, s=db)
|
||||||
member_id = generate_unique_code(Member, s=db)
|
member_id = generate_unique_code(Member, s=db)
|
||||||
room_obj = Room(**room.dict(exclude_unset=True), id_code=id_code)
|
room_obj = Room(**room.dict(exclude_unset=True), id_code=id_code)
|
||||||
if user is not None:
|
if user is not None:
|
||||||
member = Member(user_id=user.id, room=room_obj, is_admin=True, id_code=member_id)
|
member = Member(user_id=user.id, room=room_obj,
|
||||||
db.add(member)
|
is_admin=True, id_code=member_id)
|
||||||
db.commit()
|
db.add(member)
|
||||||
db.refresh(member)
|
db.commit()
|
||||||
if username is not None:
|
db.refresh(member)
|
||||||
reconnect_code = generate_unique_code(Anonymous, s=db, field_name='reconnect_code')
|
if username is not None:
|
||||||
anonymous = Anonymous(username=username, reconnect_code=reconnect_code)
|
reconnect_code = generate_unique_code(
|
||||||
member = Member(anonymous=anonymous, room=room_obj, is_admin=True, id_code=member_id)
|
Anonymous, s=db, field_name='reconnect_code')
|
||||||
db.add(member)
|
anonymous = Anonymous(username=username, reconnect_code=reconnect_code)
|
||||||
db.commit()
|
member = Member(anonymous=anonymous, room=room_obj,
|
||||||
db.refresh(member)
|
is_admin=True, id_code=member_id)
|
||||||
if username is None and user is None:
|
db.add(member)
|
||||||
raise ValueError('Username or user required')
|
db.commit()
|
||||||
|
db.refresh(member)
|
||||||
return {"room": room_obj, "member": member}
|
if username is None and user is None:
|
||||||
|
raise ValueError('Username or user required')
|
||||||
|
|
||||||
|
return {"room": room_obj, "member": member}
|
||||||
|
|
||||||
|
|
||||||
def get_member_from_user(user_id: int, room_id: int, db: Session):
|
def get_member_from_user(user_id: int, room_id: int, db: Session):
|
||||||
@ -70,6 +73,24 @@ def get_anonymous_from_code(reconnect_code: str, db: Session):
|
|||||||
return anonymous
|
return anonymous
|
||||||
|
|
||||||
|
|
||||||
|
def create_member(*, room: Room, user: User | None = None, anonymous: Anonymous | None = None, waiting: bool = False, db: Session):
|
||||||
|
member_id = generate_unique_code(Member, s=db)
|
||||||
|
member = Member(room=room, user=user, anonymous=anonymous, waiting=waiting,
|
||||||
|
id_code=member_id)
|
||||||
|
db.add(member)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(member)
|
||||||
|
return member
|
||||||
|
|
||||||
|
|
||||||
|
def get_or_create_member(*, room: Room, user: User | None = None, anonymous: Anonymous | None = None, waiting: bool = False, db: Session):
|
||||||
|
member = user is not None and get_member_from_user(user.id, room.id, db)
|
||||||
|
if member is not None and member is not False:
|
||||||
|
return member
|
||||||
|
member= create_member(room=room, user=user, anonymous=anonymous, waiting=waiting, db=db)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def connect_member(member: Member, db: Session):
|
def connect_member(member: Member, db: Session):
|
||||||
member.online = True
|
member.online = True
|
||||||
db.add(member)
|
db.add(member)
|
||||||
@ -92,7 +113,6 @@ def disconnect_member(member: Member, db: Session):
|
|||||||
|
|
||||||
|
|
||||||
def validate_username(username: str, room: Room, db: Session = Depends(get_session)):
|
def validate_username(username: str, room: Room, db: Session = Depends(get_session)):
|
||||||
print('VALIDATE', username)
|
|
||||||
if len(username) > 20:
|
if len(username) > 20:
|
||||||
return None
|
return None
|
||||||
members = select(Member.anonymous_id).where(
|
members = select(Member.anonymous_id).where(
|
||||||
@ -116,7 +136,21 @@ def create_anonymous_member(username: str, room: Room, db: Session):
|
|||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(member)
|
db.refresh(member)
|
||||||
return member
|
return member
|
||||||
|
def create_anonymous(username: str, room: Room, db: Session):
|
||||||
|
username = validate_username(username, room, db)
|
||||||
|
if username is None:
|
||||||
|
return None
|
||||||
|
reconnect_code = generate_unique_code(
|
||||||
|
Anonymous, s=db, field_name="reconnect_code")
|
||||||
|
anonymous = Anonymous(username=username, reconnect_code=reconnect_code)
|
||||||
|
db.add(anonymous)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(anonymous)
|
||||||
|
return anonymous
|
||||||
|
|
||||||
|
def check_user_in_room(user_id: int, room_id: int, db: Session):
|
||||||
|
user = db.exec(select(Member).where(Member.user_id==user_id, Member.room_id == room_id)).first()
|
||||||
|
return user
|
||||||
|
|
||||||
def create_user_member(user: User, room: Room, db: Session):
|
def create_user_member(user: User, room: Room, db: Session):
|
||||||
member = get_member_from_user(user.id, room.id, db)
|
member = get_member_from_user(user.id, room.id, db)
|
||||||
@ -147,30 +181,32 @@ def create_anonymous_waiter(username: str, room: Room, db: Session):
|
|||||||
return member
|
return member
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def create_user_waiter(user: User, room: Room, db: Session):
|
def create_user_waiter(user: User, room: Room, db: Session):
|
||||||
member = get_member_from_user(user.id, room.id, db)
|
member = get_member_from_user(user.id, room.id, db)
|
||||||
if member is not None:
|
if member is not None:
|
||||||
return member
|
return member
|
||||||
member_id = generate_unique_code(Member, s=db)
|
|
||||||
member = Member(room=room, user=user, waiting=True,
|
member = create_member(room=room, user=user, waiting=True,
|
||||||
id_code=member_id)
|
db=db)
|
||||||
db.add(member)
|
|
||||||
db.commit()
|
|
||||||
db.refresh(member)
|
|
||||||
return member
|
return member
|
||||||
|
|
||||||
|
|
||||||
def get_waiter(waiter_code: str, db: Session):
|
def get_waiter(waiter_code: str, db: Session):
|
||||||
return db.exec(select(Member).where(Member.id_code == waiter_code)).first()
|
return db.exec(select(Member).where(Member.id_code == waiter_code, Member.waiting == True)).first()
|
||||||
|
|
||||||
|
|
||||||
|
def get_member(id_code: str, room_id: str, db: Session):
|
||||||
|
return db.exec(select(Member).where(Member.id_code == id_code, Member.room_id == room_id)).first()
|
||||||
|
|
||||||
|
|
||||||
def get_member(id_code: str,room_id:str, db: Session):
|
|
||||||
return db.exec(select(Member).where(Member.id_code == id_code, Member.room_id== room_id)).first()
|
|
||||||
def delete_member(member: Member, db: Session):
|
def delete_member(member: Member, db: Session):
|
||||||
db.delete(member)
|
db.delete(member)
|
||||||
db.commit()
|
db.commit()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def accept_waiter(member: Member, db: Session):
|
def accept_waiter(member: Member, db: Session):
|
||||||
member.waiting = False
|
member.waiting = False
|
||||||
member.waiter_code = None
|
member.waiter_code = None
|
||||||
@ -185,7 +221,16 @@ def refuse_waiter(member: Member, db: Session):
|
|||||||
db.commit()
|
db.commit()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def leave_room(member: Member, db: Session):
|
def leave_room(member: Member, db: Session):
|
||||||
db.delete(member)
|
db.delete(member)
|
||||||
db.commit()
|
db.commit()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_member(member: Member) -> MemberRead | Waiter:
|
||||||
|
member_obj = member.user or member.anonymous
|
||||||
|
if member.waiting == False:
|
||||||
|
return MemberRead(username=member_obj.username, reconnect_code=getattr(member_obj, "reconnect_code", ""), isUser=member.user_id != None, isAdmin=member.is_admin, id_code=member.id_code).dict()
|
||||||
|
if member.waiting == True:
|
||||||
|
return Waiter(username=member_obj.username, waiter_id=member.id_code).dict()
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from pydantic import root_validator
|
from pydantic import root_validator, BaseModel
|
||||||
from typing import List, Optional, TYPE_CHECKING
|
from typing import List, Optional, TYPE_CHECKING
|
||||||
from sqlmodel import SQLModel, Field, Relationship
|
from sqlmodel import SQLModel, Field, Relationship
|
||||||
|
|
||||||
@ -76,6 +76,7 @@ class MemberRead(SQLModel):
|
|||||||
isUser: bool
|
isUser: bool
|
||||||
isAdmin: bool
|
isAdmin: bool
|
||||||
id_code: str
|
id_code: str
|
||||||
|
|
||||||
class MemberSerializer(MemberRead):
|
class MemberSerializer(MemberRead):
|
||||||
member: MemberWithRelations
|
member: MemberWithRelations
|
||||||
|
|
||||||
@ -90,3 +91,11 @@ class MemberSerializer(MemberRead):
|
|||||||
return {"username": member_obj.username, "reconnect_code": getattr(member_obj, "reconnect_code", ""), "isAdmin": member.is_admin, "isUser": member.user != None}
|
return {"username": member_obj.username, "reconnect_code": getattr(member_obj, "reconnect_code", ""), "isAdmin": member.is_admin, "isUser": member.user != None}
|
||||||
|
|
||||||
|
|
||||||
|
class RoomAndMember(BaseModel):
|
||||||
|
room: RoomRead
|
||||||
|
member: MemberRead
|
||||||
|
|
||||||
|
|
||||||
|
class Waiter(BaseModel):
|
||||||
|
username: str
|
||||||
|
waiter_id: str
|
||||||
|
@ -53,8 +53,7 @@ admin = Admin(app, engine)
|
|||||||
|
|
||||||
class UserAdmin(ModelView, model=User):
|
class UserAdmin(ModelView, model=User):
|
||||||
column_list = [User.id, User.username]
|
column_list = [User.id, User.username]
|
||||||
|
|
||||||
|
|
||||||
admin.add_view(UserAdmin)
|
admin.add_view(UserAdmin)
|
||||||
|
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
|
234
backend/api/routes/room/consumer.py
Normal file
234
backend/api/routes/room/consumer.py
Normal file
@ -0,0 +1,234 @@
|
|||||||
|
from fastapi.websockets import WebSocket
|
||||||
|
from services.websocket import Consumer
|
||||||
|
from typing import Any, TYPE_CHECKING
|
||||||
|
from database.room.models import Room, Member, MemberRead, Waiter
|
||||||
|
from sqlmodel import Session
|
||||||
|
from database.room.crud import serialize_member,check_user_in_room, create_anonymous, create_member, create_room_db, delete_member, get_member, get_member_from_token, get_member_from_reconnect_code, connect_member, disconnect_member, create_anonymous_member, create_anonymous_waiter, create_user_member, create_user_waiter, get_or_create_member, get_waiter, accept_waiter, leave_room, refuse_waiter, check_room
|
||||||
|
from database.auth.crud import get_user_from_token
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from routes.room.routes import RoomManager
|
||||||
|
|
||||||
|
class RoomConsumer(Consumer):
|
||||||
|
|
||||||
|
def __init__(self, ws: WebSocket, room: Room, manager: "RoomManager", db: Session):
|
||||||
|
self.room = room
|
||||||
|
self.ws = ws
|
||||||
|
self.manager = manager
|
||||||
|
self.db = db
|
||||||
|
self.member = None
|
||||||
|
|
||||||
|
# WS Utilities
|
||||||
|
|
||||||
|
async def connect(self):
|
||||||
|
await self.ws.accept()
|
||||||
|
|
||||||
|
async def direct_send(self, type: str, payload: Any):
|
||||||
|
await self.ws.send_json({'type': type, "data": payload})
|
||||||
|
|
||||||
|
async def send_to_admin(self, type: str, payload: Any, exclude: bool = False):
|
||||||
|
await self.manager.send_to_admin(self.room.id, {'type': type, "data": payload})
|
||||||
|
|
||||||
|
async def send_to(self, type: str, payload: Any, member_id, exclude: bool = False):
|
||||||
|
await self.manager.send_to(self.room.id, member_id, {'type': type, "data": payload})
|
||||||
|
|
||||||
|
async def broadcast(self, type, payload, exclude=False):
|
||||||
|
await self.manager.broadcast({"type": type, "data": payload}, self.room.id, exclude=[exclude == True and self])
|
||||||
|
|
||||||
|
def add_to_group(self):
|
||||||
|
self.manager.add(self.room.id, self)
|
||||||
|
|
||||||
|
async def connect_self(self):
|
||||||
|
if isinstance(self.member, Member):
|
||||||
|
connect_member(self.member, self.db)
|
||||||
|
await self.broadcast(type="connect", payload={"member": serialize_member(self.member)}, exclude=True)
|
||||||
|
|
||||||
|
async def disconnect_self(self):
|
||||||
|
if isinstance(self.member, Member):
|
||||||
|
disconnect_member(self.member, self.db)
|
||||||
|
if self.member.waiting is False:
|
||||||
|
await self.broadcast(type="disconnect", payload={"member": serialize_member(self.member)})
|
||||||
|
else:
|
||||||
|
await self.send_to_admin(type="disconnect_waiter", payload={"waiter": serialize_member(self.member)})
|
||||||
|
|
||||||
|
async def loginMember(self, member: Member):
|
||||||
|
if member.room_id == self.room.id and member.waiting == False:
|
||||||
|
self.member = member
|
||||||
|
await self.connect_self()
|
||||||
|
self.add_to_group()
|
||||||
|
await self.direct_send(type="loggedIn", payload={"member": serialize_member(self.member)})
|
||||||
|
|
||||||
|
async def send_error(self, msg):
|
||||||
|
await self.direct_send(type="error", payload={"msg": msg})
|
||||||
|
# Conditions
|
||||||
|
|
||||||
|
async def isAdminReceive(self):
|
||||||
|
is_admin = self.member is not None and self.member.is_admin == True
|
||||||
|
if not is_admin:
|
||||||
|
await self.direct_send(type="error", payload={"msg": "Vous n'avez pas la permission de faire ca"})
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def isAdmin(self):
|
||||||
|
return self.member is not None and self.member.is_admin == True
|
||||||
|
|
||||||
|
async def isMember(self):
|
||||||
|
if self.member is None:
|
||||||
|
await self.send_error("Vous n'êtes connecté à aucune salle")
|
||||||
|
return self.member is not None and self.member.waiting == False
|
||||||
|
|
||||||
|
def isWaiter(self):
|
||||||
|
return self.member is not None and self.member.waiting == True
|
||||||
|
# Received Events
|
||||||
|
@Consumer.event('login')
|
||||||
|
async def login(self, token: str | None = None, reconnect_code: str | None = None):
|
||||||
|
if reconnect_code is None and token is None:
|
||||||
|
await self.direct_send(type="error", payload={"msg": "Veuillez spécifier une méthode de connection"})
|
||||||
|
return
|
||||||
|
|
||||||
|
if token is not None:
|
||||||
|
member = get_member_from_token(token, self.room.id, self.db)
|
||||||
|
if member == False:
|
||||||
|
await self.send_error("Token expired")
|
||||||
|
return
|
||||||
|
if member is None:
|
||||||
|
await self.send_error("Utilisateur introuvable dans cette salle")
|
||||||
|
return
|
||||||
|
|
||||||
|
elif reconnect_code is not None:
|
||||||
|
member = get_member_from_reconnect_code(
|
||||||
|
reconnect_code, self.room.id, db=self.db)
|
||||||
|
if member is None:
|
||||||
|
await self.send_error("Utilisateur introuvable dans cette salle")
|
||||||
|
return
|
||||||
|
|
||||||
|
await self.loginMember(member)
|
||||||
|
|
||||||
|
@Consumer.event('join')
|
||||||
|
async def join(self, token: str | None = None, username: str | None = None):
|
||||||
|
if token is not None:
|
||||||
|
user = get_user_from_token(token, self.db)
|
||||||
|
|
||||||
|
if user is None:
|
||||||
|
await self.send_error("Utilisateur introuvable")
|
||||||
|
return
|
||||||
|
if user is False:
|
||||||
|
await self.send_error("Token expired")
|
||||||
|
return
|
||||||
|
|
||||||
|
userInRoom = check_user_in_room(user.id, self.room.id, self.db)
|
||||||
|
|
||||||
|
if userInRoom is not None:
|
||||||
|
await self.loginMember(userInRoom)
|
||||||
|
return
|
||||||
|
|
||||||
|
waiter = create_member(
|
||||||
|
user=user, room=self.room, waiting=self.room.public is False, db=self.db)
|
||||||
|
|
||||||
|
elif username is not None:
|
||||||
|
anonymous = create_anonymous(username, self.room, self.db)
|
||||||
|
if anonymous is None:
|
||||||
|
await self.send_error("Nom d'utilisateur invalide ou indisponible")
|
||||||
|
return
|
||||||
|
|
||||||
|
waiter = create_member(
|
||||||
|
anonymous=anonymous, room=self.room, waiting=self.room.public is False, db=self.db)
|
||||||
|
|
||||||
|
self.member = waiter
|
||||||
|
self.add_to_group()
|
||||||
|
|
||||||
|
if self.room.public is False:
|
||||||
|
await self.direct_send(type="waiting", payload={"waiter": serialize_member(self.member)})
|
||||||
|
await self.send_to_admin(type="waiter", payload={"waiter": serialize_member(self.member)})
|
||||||
|
else:
|
||||||
|
await self.broadcast(type="joined", payload={"member": serialize_member(self.member)}, exclude=True)
|
||||||
|
await self.direct_send(type="accepted", payload={"member": serialize_member(self.member)})
|
||||||
|
|
||||||
|
|
||||||
|
@Consumer.event('accept', conditions=[isAdminReceive])
|
||||||
|
async def accept(self, waiter_id: str):
|
||||||
|
waiter = get_waiter(waiter_id, self.db)
|
||||||
|
if waiter is None:
|
||||||
|
await self.send_error("Utilisateur en list d'attente introuvable")
|
||||||
|
return
|
||||||
|
member = accept_waiter(waiter, self.db)
|
||||||
|
await self.send_to(type="accepted", payload={"member": serialize_member(member)}, member_id=waiter_id)
|
||||||
|
await self.broadcast(type="joined", payload={"member": serialize_member(member)})
|
||||||
|
|
||||||
|
@Consumer.event('refuse', conditions=[isAdminReceive])
|
||||||
|
async def accept(self, waiter_id: str):
|
||||||
|
waiter = get_waiter(waiter_id, self.db)
|
||||||
|
member = refuse_waiter(waiter, self.db)
|
||||||
|
await self.send_to(type="refused", payload={'waiter_id': waiter_id}, member_id=waiter_id)
|
||||||
|
await self.direct_send(type="successfullyRefused", payload= {"waiter_id": waiter_id})
|
||||||
|
|
||||||
|
@Consumer.event('ping_room')
|
||||||
|
async def proom(self):
|
||||||
|
await self.broadcast(type='ping', payload={}, exclude=True)
|
||||||
|
|
||||||
|
async def isConnected(self):
|
||||||
|
if self.member is None:
|
||||||
|
await self.direct_send(type="error", payload={"msg": "Vous n'êtes connecté à aucune salle"})
|
||||||
|
return self.member is not None
|
||||||
|
|
||||||
|
|
||||||
|
@Consumer.event('leave', conditions=[isMember])
|
||||||
|
async def leave(self):
|
||||||
|
if self.member.is_admin is True:
|
||||||
|
await self.send_error("Vous ne pouvez pas quitter une salle dont vous êtes l'administrateur")
|
||||||
|
return
|
||||||
|
member_obj = serialize_member(self.member)
|
||||||
|
leave_room(self.member, self.db)
|
||||||
|
|
||||||
|
await self.direct_send(type="successfully_leaved", payload={})
|
||||||
|
await self.broadcast(type='leaved', payload={"member": member_obj})
|
||||||
|
self.member = None
|
||||||
|
|
||||||
|
@Consumer.event('ban', conditions=[isAdminReceive])
|
||||||
|
async def ban(self, member_id: str):
|
||||||
|
member = get_member(member_id, self.room.id, self.db)
|
||||||
|
if member == None:
|
||||||
|
await self.send_error("Utilisateur introuvable")
|
||||||
|
return
|
||||||
|
if member.is_admin is True:
|
||||||
|
await self.send_error("Vous ne pouvez pas bannir un administrateur")
|
||||||
|
return
|
||||||
|
|
||||||
|
member_serialized = serialize_member(member)
|
||||||
|
leave_room(member, self.db)
|
||||||
|
await self.send_to(type="banned", payload={}, member_id=member.id_code)
|
||||||
|
await self.broadcast(type="leaved", payload={"member": member_serialized})
|
||||||
|
|
||||||
|
# Sending Events
|
||||||
|
|
||||||
|
@Consumer.sending(['connect', "disconnect", "joined"], conditions=[isMember])
|
||||||
|
def joined(self, member: MemberRead):
|
||||||
|
if self.member.id_code == member.id_code:
|
||||||
|
raise ValueError("") # Prevent from sending event
|
||||||
|
if self.member.is_admin == False:
|
||||||
|
member.reconnect_code = ""
|
||||||
|
return {"member": member}
|
||||||
|
|
||||||
|
@Consumer.sending(['waiter', "disconnect_waiter"], conditions=[isAdmin])
|
||||||
|
def waiter(self, waiter: Waiter):
|
||||||
|
return {"waiter": waiter}
|
||||||
|
|
||||||
|
@Consumer.sending("refused", conditions=[isWaiter])
|
||||||
|
def refused(self, waiter_id: str):
|
||||||
|
self.member = None
|
||||||
|
self.manager.remove(self.room.id, self)
|
||||||
|
return {"waiter_id": waiter_id}
|
||||||
|
|
||||||
|
@Consumer.sending("banned", conditions=[isMember])
|
||||||
|
def banned(self):
|
||||||
|
self.member = None
|
||||||
|
self.manager.remove(self.room.id, self)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
@Consumer.sending('ping', conditions=[isMember])
|
||||||
|
def ping(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def disconnect(self):
|
||||||
|
self.manager.remove(self.room.id, self)
|
||||||
|
await self.disconnect_self()
|
||||||
|
|
39
backend/api/routes/room/manager.py
Normal file
39
backend/api/routes/room/manager.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
from typing import TYPE_CHECKING, Dict, List
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from routes.room.consumer import RoomConsumer
|
||||||
|
|
||||||
|
class RoomManager:
|
||||||
|
def __init__(self):
|
||||||
|
self.active_connections: Dict[str, List["RoomConsumer"]] = {}
|
||||||
|
|
||||||
|
def add(self, group: str, member: "RoomConsumer"):
|
||||||
|
|
||||||
|
if group not in self.active_connections:
|
||||||
|
self.active_connections[group] = []
|
||||||
|
if member not in self.active_connections[group]:
|
||||||
|
self.active_connections[group].append(member)
|
||||||
|
|
||||||
|
def remove(self, group: str, member: "RoomConsumer"):
|
||||||
|
if group in self.active_connections:
|
||||||
|
if member in self.active_connections[group]:
|
||||||
|
self.active_connections[group].remove(member)
|
||||||
|
|
||||||
|
async def broadcast(self, message, group: str, exclude: list["RoomConsumer"] = []):
|
||||||
|
if group in self.active_connections:
|
||||||
|
for connection in list(set(self.active_connections[group])):
|
||||||
|
if connection not in exclude:
|
||||||
|
await connection.send(message)
|
||||||
|
|
||||||
|
async def send_to(self, group, id_code, msg):
|
||||||
|
if group in self.active_connections:
|
||||||
|
members = [c for c in self.active_connections[group]
|
||||||
|
if c.member.id_code == id_code]
|
||||||
|
for m in members:
|
||||||
|
await m.send(msg)
|
||||||
|
|
||||||
|
async def send_to_admin(self, group, msg):
|
||||||
|
if group in self.active_connections:
|
||||||
|
members = [c for c in self.active_connections[group]
|
||||||
|
if c.member.is_admin == True]
|
||||||
|
for m in members:
|
||||||
|
await m.send(msg)
|
@ -1,551 +1,33 @@
|
|||||||
from database.room.crud import create_room_db, delete_member, get_member, get_member_from_token, get_member_from_reconnect_code, connect_member, disconnect_member, create_anonymous_member, create_anonymous_waiter, create_user_member, create_user_waiter, get_waiter, accept_waiter, leave_room, refuse_waiter, check_room
|
from database.room.crud import serialize_member,check_user_in_room, create_anonymous, create_member, create_room_db, delete_member, get_member, get_member_from_token, get_member_from_reconnect_code, connect_member, disconnect_member, create_anonymous_member, create_anonymous_waiter, create_user_member, create_user_waiter, get_or_create_member, get_waiter, accept_waiter, leave_room, refuse_waiter, check_room
|
||||||
from pydantic.error_wrappers import ValidationError
|
|
||||||
from pydantic import validate_arguments
|
|
||||||
from sqlmodel import select
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect, status, Query
|
from fastapi import APIRouter, Depends, WebSocket, status, Query
|
||||||
from config import ALGORITHM, SECRET_KEY
|
from config import ALGORITHM, SECRET_KEY
|
||||||
from database.auth.crud import get_user_from_clientId_db
|
from database.auth.crud import get_user_from_clientId_db
|
||||||
from database.auth.models import User
|
from database.auth.models import User
|
||||||
from database.db import get_session
|
from database.db import get_session
|
||||||
|
|
||||||
from sqlmodel import Session
|
from sqlmodel import Session
|
||||||
from database.room.models import Anonymous, MemberSerializer, MemberWithRelations, MemberRead, Room, RoomCreate, RoomRead, Member
|
from database.room.models import Room, RoomCreate, RoomAndMember
|
||||||
|
from routes.room.consumer import RoomConsumer
|
||||||
|
from routes.room.manager import RoomManager
|
||||||
from services.auth import get_current_user_optional
|
from services.auth import get_current_user_optional
|
||||||
from fastapi.exceptions import HTTPException
|
from fastapi.exceptions import HTTPException
|
||||||
from jose import jwt, exceptions
|
from database.auth.crud import get_user_from_token
|
||||||
import inspect
|
from services.websocket import Consumer
|
||||||
router = APIRouter(tags=["room"])
|
router = APIRouter(tags=["room"])
|
||||||
|
|
||||||
|
|
||||||
class RoomAndMember(BaseModel):
|
|
||||||
room: RoomRead
|
|
||||||
member: MemberRead
|
|
||||||
|
|
||||||
|
|
||||||
class Waiter(BaseModel):
|
|
||||||
username: str
|
|
||||||
waiter_id: str
|
|
||||||
|
|
||||||
|
|
||||||
def serialize_member(member: Member) -> MemberRead | Waiter:
|
|
||||||
member_obj = member.user or member.anonymous
|
|
||||||
if member.waiting == False:
|
|
||||||
return MemberRead(username=member_obj.username, reconnect_code=getattr(member_obj, "reconnect_code", ""), isUser=member.user_id != None, isAdmin=member.is_admin, id_code=member.id_code)
|
|
||||||
if member.waiting == True:
|
|
||||||
return Waiter(username=member_obj.username, waiter_id=member.id_code)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post('/room', response_model=RoomAndMember)
|
@router.post('/room', response_model=RoomAndMember)
|
||||||
def create_room(room: RoomCreate, username: Optional[str] = Query(default=None, max_length=20), user: User | None = Depends(get_current_user_optional), db: Session = Depends(get_session)):
|
def create_room(room: RoomCreate, username: Optional[str] = Query(default=None, max_length=20), user: User | None = Depends(get_current_user_optional), db: Session = Depends(get_session)):
|
||||||
room_obj = create_room_db(room=room, user=user, username=username, db=db)
|
room_obj = create_room_db(room=room, user=user, username=username, db=db)
|
||||||
return {'room': room_obj['room'], "member": serialize_member(room_obj['member'])}
|
return {'room': room_obj['room'], "member": serialize_member(room_obj['member'])}
|
||||||
|
|
||||||
|
|
||||||
class ConnectionManager:
|
|
||||||
def __init__(self):
|
|
||||||
self.active_connections: Dict[str, List[WebSocket]] = {}
|
|
||||||
|
|
||||||
def add(self, group: str, ws: WebSocket):
|
|
||||||
if group not in self.active_connections:
|
|
||||||
self.active_connections[group] = []
|
|
||||||
|
|
||||||
if ws not in self.active_connections[group]:
|
|
||||||
self.active_connections[group].append(ws)
|
|
||||||
|
|
||||||
def remove(self, group: str, ws: WebSocket):
|
|
||||||
if group in self.active_connections:
|
|
||||||
if ws in self.active_connections[group]:
|
|
||||||
self.active_connections[group].remove(ws)
|
|
||||||
|
|
||||||
async def broadcast(self, message, group: str, exclude: list[WebSocket] = []):
|
|
||||||
if group in self.active_connections:
|
|
||||||
for connection in self.active_connections[group]:
|
|
||||||
if connection not in exclude:
|
|
||||||
await connection.send_json(message)
|
|
||||||
|
|
||||||
|
|
||||||
def make_event_decorator(eventsDict):
|
|
||||||
def _(name: str | List, conditions: List[Callable | bool] = []):
|
|
||||||
def add_event(func):
|
|
||||||
model = validate_arguments(func).model
|
|
||||||
if type(name) == str:
|
|
||||||
eventsDict[name] = {"func": func,
|
|
||||||
"conditions": conditions, "model": model}
|
|
||||||
if type(name) == list:
|
|
||||||
for n in name:
|
|
||||||
eventsDict[n] = {"func": func,
|
|
||||||
"conditions": conditions, "model": model}
|
|
||||||
return func
|
|
||||||
return add_event
|
|
||||||
return _
|
|
||||||
|
|
||||||
|
|
||||||
class Event(BaseModel):
|
|
||||||
func: Callable
|
|
||||||
conditions: List[Callable | bool]
|
|
||||||
model: BaseModel
|
|
||||||
|
|
||||||
|
|
||||||
def dict_model(model: BaseModel, exclude: List[str]):
|
|
||||||
value = {}
|
|
||||||
for n, f in model:
|
|
||||||
if n not in exclude:
|
|
||||||
value[n] = f
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
def dict_all(obj: Any):
|
|
||||||
if isinstance(obj, dict):
|
|
||||||
value = {}
|
|
||||||
for k, v in obj.items():
|
|
||||||
if isinstance(v, dict):
|
|
||||||
v = dict_all(v)
|
|
||||||
value[k] = dict(v)
|
|
||||||
elif isinstance(v, BaseModel):
|
|
||||||
value[k] = dict(v)
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
value[k] = dict(v)
|
|
||||||
except:
|
|
||||||
value[k] = v
|
|
||||||
return value
|
|
||||||
return dict(obj)
|
|
||||||
|
|
||||||
|
|
||||||
class Consumer:
|
|
||||||
events: Dict[str, Event] = {}
|
|
||||||
sendings: Dict[str, Any] = {}
|
|
||||||
event = make_event_decorator(events)
|
|
||||||
sending = make_event_decorator(sendings)
|
|
||||||
|
|
||||||
def __init__(self, ws: WebSocket):
|
|
||||||
self.ws: WebSocket = ws
|
|
||||||
#self.events: Dict[str, Callable] = {}
|
|
||||||
|
|
||||||
async def connect(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def validation_error_handler(self, e: ValidationError):
|
|
||||||
errors = e.errors()
|
|
||||||
await self.ws.send_json({"type": "error", "data": {"detail": [{ers['loc'][-1]: ers['msg']} for ers in errors]}})
|
|
||||||
|
|
||||||
async def send(self, payload):
|
|
||||||
type = payload.get('type', None)
|
|
||||||
print('TYPE', type, self.member)
|
|
||||||
if type is not None:
|
|
||||||
event_wrapper = self.sendings.get(type, None)
|
|
||||||
if event_wrapper is not None:
|
|
||||||
handler = event_wrapper.get('func')
|
|
||||||
conditions = event_wrapper.get('conditions')
|
|
||||||
|
|
||||||
is_valid = all([(await c(self)) if inspect.iscoroutinefunction(c) else c(self) if inspect.isfunction(c) else c == True if isinstance(c, bool) else True for c in conditions])
|
|
||||||
|
|
||||||
if handler is not None and is_valid:
|
|
||||||
model = event_wrapper.get("model")
|
|
||||||
|
|
||||||
data = payload.get('data') or {}
|
|
||||||
try:
|
|
||||||
validated_payload = model(self=self, **data)
|
|
||||||
except ValidationError as e:
|
|
||||||
print("ERROR", e)
|
|
||||||
await self.ws.send_json({"type": "error", "data": {"msg": "Oops there was an error"}})
|
|
||||||
return
|
|
||||||
|
|
||||||
validated_payload = dict_model(validated_payload,
|
|
||||||
exclude=["v__duplicate_kwargs", "args", 'kwargs', "self"])
|
|
||||||
try:
|
|
||||||
parsed_payload = handler(
|
|
||||||
self, **validated_payload)
|
|
||||||
|
|
||||||
await self.ws.send_json({'type': type, "data": dict_all(parsed_payload)})
|
|
||||||
return
|
|
||||||
except Exception as e:
|
|
||||||
return
|
|
||||||
return
|
|
||||||
#print('pls')
|
|
||||||
await self.ws.send_json(payload)
|
|
||||||
#print('sent')
|
|
||||||
|
|
||||||
async def receive(self, data):
|
|
||||||
event = data.get('type', None)
|
|
||||||
if event is not None:
|
|
||||||
event_wrapper = self.events.get(event, None)
|
|
||||||
if event_wrapper is not None:
|
|
||||||
handler = event_wrapper.get('func')
|
|
||||||
conditions = event_wrapper.get('conditions')
|
|
||||||
|
|
||||||
is_valid = all([(await c(self)) if inspect.iscoroutinefunction(c) else c(self) if inspect.isfunction(c) else c == True if isinstance(c, bool) else True for c in conditions])
|
|
||||||
|
|
||||||
if handler is not None and is_valid:
|
|
||||||
model = event_wrapper.get("model")
|
|
||||||
|
|
||||||
payload = data.get('data') or {}
|
|
||||||
try:
|
|
||||||
validated_payload = model(self=self, **payload)
|
|
||||||
except ValidationError as e:
|
|
||||||
await self.validation_error_handler(e)
|
|
||||||
return
|
|
||||||
|
|
||||||
await handler(**{k: v for k, v in validated_payload.dict().items() if k not in ["v__duplicate_kwargs", "args", 'kwargs']})
|
|
||||||
|
|
||||||
async def disconnect(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def run(self):
|
|
||||||
await self.connect()
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
data = await self.ws.receive_json()
|
|
||||||
await self.receive(data)
|
|
||||||
except WebSocketDisconnect:
|
|
||||||
await self.disconnect()
|
|
||||||
|
|
||||||
|
|
||||||
class ConsumerManager:
|
|
||||||
def __init__(self):
|
|
||||||
self.active_connections: Dict[str, List[Consumer]] = {}
|
|
||||||
|
|
||||||
def add(self, group: str, ws: Consumer):
|
|
||||||
|
|
||||||
if group not in self.active_connections:
|
|
||||||
self.active_connections[group] = []
|
|
||||||
#print("adding", ws, self.active_connections[group])
|
|
||||||
if ws not in self.active_connections[group]:
|
|
||||||
#print('ACTUALLY ADDING')
|
|
||||||
self.active_connections[group].append(ws)
|
|
||||||
|
|
||||||
def remove(self, group: str, ws: Consumer):
|
|
||||||
if group in self.active_connections:
|
|
||||||
if ws in self.active_connections[group]:
|
|
||||||
self.active_connections[group].remove(ws)
|
|
||||||
|
|
||||||
async def broadcast(self, message, group: str, exclude: list[Consumer] = []):
|
|
||||||
if group in self.active_connections:
|
|
||||||
#print(self.active_connections[group], exclude)
|
|
||||||
for connection in list(set(self.active_connections[group])):
|
|
||||||
if connection not in exclude:
|
|
||||||
#print('SEND TO', connection, message)
|
|
||||||
await connection.send(message)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Token(BaseModel):
|
|
||||||
token: str
|
|
||||||
|
|
||||||
|
|
||||||
def get_user_from_token(token: str, db: Session):
|
|
||||||
try:
|
|
||||||
decoded = jwt.decode(token=token, key=SECRET_KEY,
|
|
||||||
algorithms=[ALGORITHM])
|
|
||||||
except exceptions.ExpiredSignatureError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
clientId = decoded.get('sub')
|
|
||||||
return get_user_from_clientId_db(clientId=clientId, db=db)
|
|
||||||
|
|
||||||
|
|
||||||
def check_same_member(member: Member, memberR: MemberRead):
|
|
||||||
return (member.user is not None and memberR.username == member.user.username) or (member.anonymous and memberR.reconnect_code == member.anonymous.reconnect_code)
|
|
||||||
|
|
||||||
class RoomConsumer(Consumer):
|
|
||||||
|
|
||||||
def __init__(self, ws: WebSocket, room: Room, manager: "RoomManager", db: Session):
|
|
||||||
self.room = room
|
|
||||||
self.ws = ws
|
|
||||||
self.manager = manager
|
|
||||||
self.db = db
|
|
||||||
self.member = None
|
|
||||||
# WS Utilities
|
|
||||||
|
|
||||||
async def connect(self):
|
|
||||||
await self.ws.accept()
|
|
||||||
|
|
||||||
async def direct_send(self, type: str, payload: Any):
|
|
||||||
await self.ws.send_json({'type': type, "data": payload})
|
|
||||||
|
|
||||||
async def send_to_admin(self, type: str, payload: Any, exclude: bool = False):
|
|
||||||
await self.manager.send_to_admin(self.room.id, {'type': type, "data": payload})
|
|
||||||
|
|
||||||
async def send_to(self, type: str, payload: Any,member_id, exclude: bool = False):
|
|
||||||
await self.manager.send_to(self.room.id, member_id,{'type': type, "data": payload})
|
|
||||||
|
|
||||||
async def broadcast(self, type, payload, exclude=False):
|
|
||||||
await self.manager.broadcast({"type": type, "data": payload}, self.room.id, exclude=[exclude == True and self])
|
|
||||||
|
|
||||||
def add_to_group(self):
|
|
||||||
self.manager.add(self.room.id, self)
|
|
||||||
|
|
||||||
async def connect_self(self):
|
|
||||||
if isinstance(self.member, Member):
|
|
||||||
connect_member(self.member, self.db)
|
|
||||||
await self.broadcast(type="connect", payload={"member": serialize_member(self.member).dict()}, exclude=True)
|
|
||||||
|
|
||||||
async def disconnect_self(self):
|
|
||||||
if isinstance(self.member, Member):
|
|
||||||
disconnect_member(self.member, self.db)
|
|
||||||
if self.member.waiting is False:
|
|
||||||
await self.broadcast(type="disconnect", payload={"member": serialize_member(self.member).dict()})
|
|
||||||
else:
|
|
||||||
await self.broadcast(type="disconnect_waiter", payload={"waiter": serialize_member(self.member).dict()})
|
|
||||||
|
|
||||||
# DB Utilities
|
|
||||||
|
|
||||||
# Received Events
|
|
||||||
@Consumer.event('login')
|
|
||||||
async def login(self, token: str | None = None, reconnect_code: str | None = None):
|
|
||||||
if token is not None:
|
|
||||||
member = get_member_from_token(token, self.room.id, self.db)
|
|
||||||
if member == False:
|
|
||||||
await self.direct_send(type="error", payload={"msg": "Token expired"})
|
|
||||||
return
|
|
||||||
if member is None:
|
|
||||||
await self.direct_send(type="error", payload={"msg": "Utilisateur introuvable dans cette salle"})
|
|
||||||
return
|
|
||||||
self.member = member
|
|
||||||
|
|
||||||
await self.connect_self()
|
|
||||||
self.add_to_group()
|
|
||||||
await self.direct_send(type="loggedIn", payload={"member": serialize_member(self.member).dict()})
|
|
||||||
|
|
||||||
elif reconnect_code is not None:
|
|
||||||
member = get_member_from_reconnect_code(
|
|
||||||
reconnect_code, self.room.id, db=self.db)
|
|
||||||
if member is None:
|
|
||||||
await self.direct_send(type="error", payload={"msg": "Utilisateur introuvable dans cette salle"})
|
|
||||||
return
|
|
||||||
|
|
||||||
self.member = member
|
|
||||||
|
|
||||||
await self.connect_self()
|
|
||||||
self.add_to_group()
|
|
||||||
await self.direct_send(type="loggedIn", payload={"member": serialize_member(self.member).dict()})
|
|
||||||
if reconnect_code is None and token is None:
|
|
||||||
await self.direct_send(type="error", payload={"msg": "Veuillez spécifier une méthode de connection"})
|
|
||||||
|
|
||||||
@Consumer.event('join')
|
|
||||||
async def join(self, token: str | None = None, username: str | None = None):
|
|
||||||
if self.room.public == False:
|
|
||||||
if token is not None:
|
|
||||||
user = get_user_from_token(token, self.db)
|
|
||||||
|
|
||||||
if user is None:
|
|
||||||
await self.direct_send(type="error", payload={"msg": "Utilisateur introuvable"})
|
|
||||||
return
|
|
||||||
if user is False:
|
|
||||||
await self.direct_send(type="error", payload={"msg": "Token expired"})
|
|
||||||
return
|
|
||||||
|
|
||||||
waiter = create_user_waiter(user, self.room, self.db)
|
|
||||||
|
|
||||||
if waiter.waiting is False:
|
|
||||||
self.member = waiter
|
|
||||||
# await self.connect_self()
|
|
||||||
self.add_to_group()
|
|
||||||
await self.connect_self()
|
|
||||||
await self.direct_send(type="loggedIn", payload={"member": serialize_member(self.member).dict()})
|
|
||||||
return
|
|
||||||
|
|
||||||
self.member = waiter
|
|
||||||
self.add_to_group()
|
|
||||||
await self.direct_send(type="waiting", payload={"waiter": serialize_member(self.member).dict()})
|
|
||||||
await self.send_to_admin(type="waiter", payload={"waiter": serialize_member(self.member).dict()})
|
|
||||||
|
|
||||||
elif username is not None:
|
|
||||||
waiter = create_anonymous_waiter(username, self.room, self.db)
|
|
||||||
if waiter is None:
|
|
||||||
await self.direct_send(type="error", payload={"msg": "Nom d'utilisateur invalide ou indisponible"})
|
|
||||||
return
|
|
||||||
self.member = waiter
|
|
||||||
self.add_to_group()
|
|
||||||
await self.direct_send(type="waiting", payload={"waiter": serialize_member(self.member).dict()})
|
|
||||||
await self.send_to_admin(type="waiter", payload={"waiter": serialize_member(self.member).dict()})
|
|
||||||
else:
|
|
||||||
if token is not None:
|
|
||||||
user = get_user_from_token(token, self.db)
|
|
||||||
if user is None:
|
|
||||||
await self.direct_send(type="error", payload={"msg": "Utilisateur introuvable"})
|
|
||||||
return
|
|
||||||
if user is False:
|
|
||||||
await self.direct_send(type="error", payload={"msg": "Token expired"})
|
|
||||||
return
|
|
||||||
|
|
||||||
member = create_user_member(user, self.room, self.db)
|
|
||||||
if member is None:
|
|
||||||
return
|
|
||||||
self.member = member
|
|
||||||
self.add_to_group()
|
|
||||||
await self.broadcast(type="joined", payload={"member": serialize_member(self.member).dict()}, exclude=True)
|
|
||||||
await self.direct_send(type="accepted", payload={"member": serialize_member(self.member).dict()})
|
|
||||||
elif username is not None:
|
|
||||||
member = create_anonymous_member(username, self.room, self.db)
|
|
||||||
if member is None:
|
|
||||||
await self.direct_send(type="error", data={"msg": "Nom d'utilisateur indisponible"})
|
|
||||||
return
|
|
||||||
self.member = member
|
|
||||||
self.add_to_group()
|
|
||||||
|
|
||||||
await self.broadcast(type="joined", payload={"member": serialize_member(self.member).dict()}, exclude=True)
|
|
||||||
await self.direct_send(type="accepted", payload={"member": serialize_member(self.member).dict()})
|
|
||||||
|
|
||||||
async def isAdminReceive(self):
|
|
||||||
is_admin = self.member is not None and self.member.is_admin == True
|
|
||||||
if not is_admin:
|
|
||||||
await self.direct_send(type="error", payload={"msg": "Vous n'avez pas la permission de faire ca"})
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def isAdmin(self):
|
|
||||||
return self.member is not None and self.member.is_admin == True
|
|
||||||
|
|
||||||
@Consumer.event('accept', conditions=[isAdminReceive])
|
|
||||||
async def accept(self, waiter_id: str):
|
|
||||||
waiter = get_waiter(waiter_id, self.db)
|
|
||||||
if waiter is None:
|
|
||||||
await self.direct_send(type="error", payload={'msg': "Utilisateur introuvable"})
|
|
||||||
return
|
|
||||||
member = accept_waiter(waiter, self.db)
|
|
||||||
await self.send_to(type="accepted", payload={"member": serialize_member(member).dict()}, member_id=waiter_id)
|
|
||||||
await self.broadcast(type="joined", payload={"member": serialize_member(member).dict()})
|
|
||||||
|
|
||||||
@Consumer.event('refuse', conditions=[isAdminReceive])
|
|
||||||
async def accept(self, waiter_id: str):
|
|
||||||
waiter = get_waiter(waiter_id, self.db)
|
|
||||||
member = refuse_waiter(waiter, self.db)
|
|
||||||
await self.send_to(type="refused", payload={'waiter_id': waiter_id}, member_id=waiter_id)
|
|
||||||
# await self.broadcast(type="joined", payload={"member": serialize_member(member).dict()})
|
|
||||||
|
|
||||||
@Consumer.event('ping_room')
|
|
||||||
async def proom(self):
|
|
||||||
await self.broadcast(type='ping', payload={}, exclude=True)
|
|
||||||
async def hasRoom(self):
|
|
||||||
if self.member is None:
|
|
||||||
await self.direct_send(type="error", payload={"msg": "Vous n'êtes connecté à aucune salle"})
|
|
||||||
return self.member
|
|
||||||
@Consumer.event('leave', conditions=[hasRoom])
|
|
||||||
async def leave(self):
|
|
||||||
if self.member.is_admin is True:
|
|
||||||
await self.direct_send(type="error", payload={"msg": "Vous ne pouvez pas quitter une salle dont vous êtes l'administrateur"})
|
|
||||||
return
|
|
||||||
member_obj = serialize_member(self.member).dict()
|
|
||||||
leave_room(self.member, self.db)
|
|
||||||
await self.direct_send(type="successfully_leaved", payload = {})
|
|
||||||
await self.broadcast(type='leaved', payload ={"member": member_obj})
|
|
||||||
self.member = None
|
|
||||||
|
|
||||||
@Consumer.event('ban', conditions=[isAdminReceive])
|
|
||||||
async def ban(self, member_id: str):
|
|
||||||
member = get_member(member_id, self.room.id, self.db)
|
|
||||||
if member == None:
|
|
||||||
await self.direct_send(type="error", payload={"msg": "Utilisateur introuvable"})
|
|
||||||
return
|
|
||||||
if member.is_admin is True:
|
|
||||||
await self.direct_send(type="error", payload={"msg": "Vous ne pouvez pas bannir un administrateur"})
|
|
||||||
return
|
|
||||||
member_serialized = serialize_member(member)
|
|
||||||
|
|
||||||
leave_room(member, self.db)
|
|
||||||
await self.send_to(type="banned", payload={}, member_id=member.id_code)
|
|
||||||
await self.broadcast(type="leaved", payload = {"member": member_serialized.dict()})
|
|
||||||
|
|
||||||
def isMember(self):
|
|
||||||
return self.member is not None and self.member.waiting == False
|
|
||||||
|
|
||||||
# Sending Events
|
|
||||||
@Consumer.sending("joined", conditions=[isMember])
|
|
||||||
def joined(self, member: MemberRead):
|
|
||||||
if self.member.id_code == member.id_code:
|
|
||||||
raise ValueError("")
|
|
||||||
if self.member.is_admin == False:
|
|
||||||
member.reconnect_code = ""
|
|
||||||
return {"member": member}
|
|
||||||
|
|
||||||
@Consumer.sending(['waiter', "disconnect_waiter"], conditions=[isAdmin])
|
|
||||||
def waiter(self, waiter: Waiter):
|
|
||||||
return {"waiter": waiter}
|
|
||||||
|
|
||||||
@Consumer.sending(['connect', "disconnect", ""])
|
|
||||||
def connect_event(self, member: MemberRead):
|
|
||||||
if not self.member.is_admin:
|
|
||||||
member.reconnect_code = ""
|
|
||||||
return {"member": member}
|
|
||||||
|
|
||||||
@Consumer.sending('disconnect')
|
|
||||||
def disconnect_event(self, member: MemberRead):
|
|
||||||
if not self.member.is_admin:
|
|
||||||
member.reconnect_code = ""
|
|
||||||
return {"member": member}
|
|
||||||
|
|
||||||
@Consumer.sending('disconnect_waiter', conditions=[isAdmin])
|
|
||||||
def disconnect_event(self, waiter: Waiter):
|
|
||||||
return {"waiter": waiter}
|
|
||||||
def isWaiter(self):
|
|
||||||
return self.member is not None and self.member.waiting == True
|
|
||||||
|
|
||||||
@Consumer.sending("refused", conditions=[isWaiter])
|
|
||||||
def refused(self, waiter_id: str):
|
|
||||||
self.member = None
|
|
||||||
self.manager.remove(self.room.id, self)
|
|
||||||
return {"waiter_id": waiter_id}
|
|
||||||
|
|
||||||
@Consumer.sending("banned", conditions=[isMember])
|
|
||||||
def banned(self):
|
|
||||||
self.member = None
|
|
||||||
self.manager.remove(self.room.id, self)
|
|
||||||
return {}
|
|
||||||
|
|
||||||
@Consumer.sending('ping', conditions=[isMember])
|
|
||||||
def ping(self):
|
|
||||||
return {}
|
|
||||||
async def disconnect(self):
|
|
||||||
print("DISCONNECT", self.member)
|
|
||||||
self.manager.remove(self.room.id, self)
|
|
||||||
await self.disconnect_self()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class RoomManager:
|
|
||||||
def __init__(self):
|
|
||||||
self.active_connections: Dict[str, List[RoomConsumer]] = {}
|
|
||||||
|
|
||||||
def add(self, group: str, ws: RoomConsumer):
|
|
||||||
|
|
||||||
if group not in self.active_connections:
|
|
||||||
self.active_connections[group] = []
|
|
||||||
#print("adding", ws, self.active_connections[group])
|
|
||||||
if ws not in self.active_connections[group]:
|
|
||||||
#print('ACTUALLY ADDING')
|
|
||||||
self.active_connections[group].append(ws)
|
|
||||||
|
|
||||||
def remove(self, group: str, ws: RoomConsumer):
|
|
||||||
if group in self.active_connections:
|
|
||||||
if ws in self.active_connections[group]:
|
|
||||||
self.active_connections[group].remove(ws)
|
|
||||||
|
|
||||||
async def broadcast(self, message, group: str, exclude: list[RoomConsumer] = []):
|
|
||||||
if group in self.active_connections:
|
|
||||||
#print(self.active_connections[group], exclude)
|
|
||||||
for connection in list(set(self.active_connections[group])):
|
|
||||||
if connection not in exclude:
|
|
||||||
#print('SEND TO', connection, message)
|
|
||||||
await connection.send(message)
|
|
||||||
|
|
||||||
async def send_to(self, group, id_code, msg):
|
|
||||||
if group in self.active_connections:
|
|
||||||
members = [c for c in self.active_connections[group] if c.member.id_code == id_code]
|
|
||||||
for m in members:
|
|
||||||
await m.send(msg)
|
|
||||||
|
|
||||||
async def send_to_admin(self, group, msg):
|
|
||||||
if group in self.active_connections:
|
|
||||||
members = [c for c in self.active_connections[group] if c.member.is_admin == True]
|
|
||||||
for m in members:
|
|
||||||
await m.send(msg)
|
|
||||||
|
|
||||||
manager = RoomManager()
|
manager = RoomManager()
|
||||||
|
|
||||||
|
|
||||||
@router.websocket('/ws/room/{room_id}')
|
@router.websocket('/ws/room/{room_id}')
|
||||||
async def room_ws(ws: WebSocket, room: Room | None = Depends(check_room), db: Session = Depends(get_session)):
|
async def room_ws(ws: WebSocket, room: Room | None = Depends(check_room), db: Session = Depends(get_session)):
|
||||||
if room is None:
|
if room is None:
|
||||||
|
137
backend/api/services/websocket.py
Normal file
137
backend/api/services/websocket.py
Normal file
@ -0,0 +1,137 @@
|
|||||||
|
from typing import List, Callable, Any, Dict
|
||||||
|
from pydantic import validate_arguments, BaseModel
|
||||||
|
from fastapi.websockets import WebSocketDisconnect, WebSocket
|
||||||
|
from pydantic.error_wrappers import ValidationError
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
def make_event_decorator(eventsDict):
|
||||||
|
def _(name: str | List, conditions: List[Callable | bool] = []):
|
||||||
|
def add_event(func):
|
||||||
|
model = validate_arguments(func).model
|
||||||
|
if type(name) == str:
|
||||||
|
eventsDict[name] = {"func": func,
|
||||||
|
"conditions": conditions, "model": model}
|
||||||
|
if type(name) == list:
|
||||||
|
for n in name:
|
||||||
|
eventsDict[n] = {"func": func,
|
||||||
|
"conditions": conditions, "model": model}
|
||||||
|
return func
|
||||||
|
return add_event
|
||||||
|
return _
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def dict_model(model: BaseModel, exclude: List[str]):
|
||||||
|
value = {}
|
||||||
|
for n, f in model:
|
||||||
|
if n not in exclude:
|
||||||
|
value[n] = f
|
||||||
|
return value
|
||||||
|
def dict_all(obj: Any):
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
value = {}
|
||||||
|
for k, v in obj.items():
|
||||||
|
if isinstance(v, dict):
|
||||||
|
v = dict_all(v)
|
||||||
|
value[k] = dict(v)
|
||||||
|
elif isinstance(v, BaseModel):
|
||||||
|
value[k] = dict(v)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
value[k] = dict(v)
|
||||||
|
except:
|
||||||
|
value[k] = v
|
||||||
|
return value
|
||||||
|
return dict(obj)
|
||||||
|
|
||||||
|
|
||||||
|
class Event(BaseModel):
|
||||||
|
func: Callable
|
||||||
|
conditions: List[Callable | bool]
|
||||||
|
model: BaseModel
|
||||||
|
|
||||||
|
class Consumer:
|
||||||
|
events: Dict[str, Event] = {}
|
||||||
|
sendings: Dict[str, Any] = {}
|
||||||
|
event = make_event_decorator(events)
|
||||||
|
sending = make_event_decorator(sendings)
|
||||||
|
|
||||||
|
def __init__(self, ws: WebSocket):
|
||||||
|
self.ws: WebSocket = ws
|
||||||
|
#self.events: Dict[str, Callable] = {}
|
||||||
|
|
||||||
|
async def connect(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def validation_error_handler(self, e: ValidationError):
|
||||||
|
errors = e.errors()
|
||||||
|
await self.ws.send_json({"type": "error", "data": {"detail": [{ers['loc'][-1]: ers['msg']} for ers in errors]}})
|
||||||
|
|
||||||
|
async def send(self, payload):
|
||||||
|
type = payload.get('type', None)
|
||||||
|
print('TYPE', type, self.member)
|
||||||
|
if type is not None:
|
||||||
|
event_wrapper = self.sendings.get(type, None)
|
||||||
|
if event_wrapper is not None:
|
||||||
|
handler = event_wrapper.get('func')
|
||||||
|
conditions = event_wrapper.get('conditions')
|
||||||
|
|
||||||
|
is_valid = all([(await c(self)) if inspect.iscoroutinefunction(c) else c(self) if inspect.isfunction(c) else c == True if isinstance(c, bool) else True for c in conditions])
|
||||||
|
|
||||||
|
if handler is not None and is_valid:
|
||||||
|
model = event_wrapper.get("model")
|
||||||
|
|
||||||
|
data = payload.get('data') or {}
|
||||||
|
try:
|
||||||
|
validated_payload = model(self=self, **data)
|
||||||
|
except ValidationError as e:
|
||||||
|
print("ERROR", e)
|
||||||
|
await self.ws.send_json({"type": "error", "data": {"msg": "Oops there was an error"}})
|
||||||
|
return
|
||||||
|
|
||||||
|
validated_payload = dict_model(validated_payload,
|
||||||
|
exclude=["v__duplicate_kwargs", "args", 'kwargs', "self"])
|
||||||
|
try:
|
||||||
|
parsed_payload = handler(
|
||||||
|
self, **validated_payload)
|
||||||
|
|
||||||
|
await self.ws.send_json({'type': type, "data": dict_all(parsed_payload)})
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
return
|
||||||
|
return
|
||||||
|
await self.ws.send_json(payload)
|
||||||
|
|
||||||
|
async def receive(self, data):
|
||||||
|
event = data.get('type', None)
|
||||||
|
if event is not None:
|
||||||
|
event_wrapper = self.events.get(event, None)
|
||||||
|
if event_wrapper is not None:
|
||||||
|
handler = event_wrapper.get('func')
|
||||||
|
conditions = event_wrapper.get('conditions')
|
||||||
|
|
||||||
|
is_valid = all([(await c(self)) if inspect.iscoroutinefunction(c) else c(self) if inspect.isfunction(c) else c == True if isinstance(c, bool) else True for c in conditions])
|
||||||
|
|
||||||
|
if handler is not None and is_valid:
|
||||||
|
model = event_wrapper.get("model")
|
||||||
|
|
||||||
|
payload = data.get('data') or {}
|
||||||
|
try:
|
||||||
|
validated_payload = model(self=self, **payload)
|
||||||
|
except ValidationError as e:
|
||||||
|
await self.validation_error_handler(e)
|
||||||
|
return
|
||||||
|
|
||||||
|
await handler(**{k: v for k, v in validated_payload.dict().items() if k not in ["v__duplicate_kwargs", "args", 'kwargs']})
|
||||||
|
|
||||||
|
async def disconnect(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
await self.connect()
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
data = await self.ws.receive_json()
|
||||||
|
await self.receive(data)
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
await self.disconnect()
|
@ -138,7 +138,8 @@ def test_join_waiter_not_found(client: TestClient):
|
|||||||
admin.receive_json()
|
admin.receive_json()
|
||||||
admin.send_json({"type": "accept", "data": {"waiter_id": "OOOO"}})
|
admin.send_json({"type": "accept", "data": {"waiter_id": "OOOO"}})
|
||||||
data = admin.receive_json()
|
data = admin.receive_json()
|
||||||
assert data == {"type": "error", "data": {"msg": "Utilisateur introuvable"}}
|
assert data == {"type": "error", "data": {
|
||||||
|
"msg": "Utilisateur en list d'attente introuvable"}}
|
||||||
|
|
||||||
def test_join_no_auth(client: TestClient):
|
def test_join_no_auth(client: TestClient):
|
||||||
room = test_create_room_no_auth(client=client)
|
room = test_create_room_no_auth(client=client)
|
||||||
@ -219,7 +220,8 @@ def test_join_auth_refused(client: TestClient):
|
|||||||
"waiter": {"waiter_id": waiter_id, "username": "lilian2"}}}
|
"waiter": {"waiter_id": waiter_id, "username": "lilian2"}}}
|
||||||
admin.send_json({"type": "refuse", "data": {
|
admin.send_json({"type": "refuse", "data": {
|
||||||
"waiter_id": waiter_id}})
|
"waiter_id": waiter_id}})
|
||||||
|
adata = admin.receive_json()
|
||||||
|
assert adata == {"type": "successfullyRefused", "data": {"waiter_id": waiter_id}}
|
||||||
mdata = member.receive_json()
|
mdata = member.receive_json()
|
||||||
assert mdata == {"type": "refused", "data": {
|
assert mdata == {"type": "refused", "data": {
|
||||||
"waiter_id": waiter_id}}
|
"waiter_id": waiter_id}}
|
||||||
|
Loading…
Reference in New Issue
Block a user