room first part

This commit is contained in:
Kilton937342 2022-09-26 10:04:02 +02:00
parent 8391c2e64a
commit b639c5a88e
8 changed files with 515 additions and 568 deletions

View File

@ -1,7 +1,7 @@
from fastapi import Depends from fastapi import Depends
from sqlmodel import Session, select, col from sqlmodel import Session, select, col
from database.db import get_session from database.db import get_session
from database.room.models import Anonymous, Member, Room, RoomCreate from database.room.models import Anonymous, Member, Room, RoomCreate, Waiter, MemberRead
from database.auth.models import User from database.auth.models import User
from services.database import generate_unique_code from services.database import generate_unique_code
from database.auth.crud import get_user_from_token from database.auth.crud import get_user_from_token
@ -17,14 +17,17 @@ def create_room_db(*,room: RoomCreate, user: User | None = None, username: str |
member_id = generate_unique_code(Member, s=db) member_id = generate_unique_code(Member, s=db)
room_obj = Room(**room.dict(exclude_unset=True), id_code=id_code) room_obj = Room(**room.dict(exclude_unset=True), id_code=id_code)
if user is not None: if user is not None:
member = Member(user_id=user.id, room=room_obj, is_admin=True, id_code=member_id) member = Member(user_id=user.id, room=room_obj,
is_admin=True, id_code=member_id)
db.add(member) db.add(member)
db.commit() db.commit()
db.refresh(member) db.refresh(member)
if username is not None: if username is not None:
reconnect_code = generate_unique_code(Anonymous, s=db, field_name='reconnect_code') reconnect_code = generate_unique_code(
Anonymous, s=db, field_name='reconnect_code')
anonymous = Anonymous(username=username, reconnect_code=reconnect_code) anonymous = Anonymous(username=username, reconnect_code=reconnect_code)
member = Member(anonymous=anonymous, room=room_obj, is_admin=True, id_code=member_id) member = Member(anonymous=anonymous, room=room_obj,
is_admin=True, id_code=member_id)
db.add(member) db.add(member)
db.commit() db.commit()
db.refresh(member) db.refresh(member)
@ -70,6 +73,24 @@ def get_anonymous_from_code(reconnect_code: str, db: Session):
return anonymous return anonymous
def create_member(*, room: Room, user: User | None = None, anonymous: Anonymous | None = None, waiting: bool = False, db: Session):
member_id = generate_unique_code(Member, s=db)
member = Member(room=room, user=user, anonymous=anonymous, waiting=waiting,
id_code=member_id)
db.add(member)
db.commit()
db.refresh(member)
return member
def get_or_create_member(*, room: Room, user: User | None = None, anonymous: Anonymous | None = None, waiting: bool = False, db: Session):
member = user is not None and get_member_from_user(user.id, room.id, db)
if member is not None and member is not False:
return member
member= create_member(room=room, user=user, anonymous=anonymous, waiting=waiting, db=db)
def connect_member(member: Member, db: Session): def connect_member(member: Member, db: Session):
member.online = True member.online = True
db.add(member) db.add(member)
@ -92,7 +113,6 @@ def disconnect_member(member: Member, db: Session):
def validate_username(username: str, room: Room, db: Session = Depends(get_session)): def validate_username(username: str, room: Room, db: Session = Depends(get_session)):
print('VALIDATE', username)
if len(username) > 20: if len(username) > 20:
return None return None
members = select(Member.anonymous_id).where( members = select(Member.anonymous_id).where(
@ -116,7 +136,21 @@ def create_anonymous_member(username: str, room: Room, db: Session):
db.commit() db.commit()
db.refresh(member) db.refresh(member)
return member return member
def create_anonymous(username: str, room: Room, db: Session):
username = validate_username(username, room, db)
if username is None:
return None
reconnect_code = generate_unique_code(
Anonymous, s=db, field_name="reconnect_code")
anonymous = Anonymous(username=username, reconnect_code=reconnect_code)
db.add(anonymous)
db.commit()
db.refresh(anonymous)
return anonymous
def check_user_in_room(user_id: int, room_id: int, db: Session):
user = db.exec(select(Member).where(Member.user_id==user_id, Member.room_id == room_id)).first()
return user
def create_user_member(user: User, room: Room, db: Session): def create_user_member(user: User, room: Room, db: Session):
member = get_member_from_user(user.id, room.id, db) member = get_member_from_user(user.id, room.id, db)
@ -147,30 +181,32 @@ def create_anonymous_waiter(username: str, room: Room, db: Session):
return member return member
def create_user_waiter(user: User, room: Room, db: Session): def create_user_waiter(user: User, room: Room, db: Session):
member = get_member_from_user(user.id, room.id, db) member = get_member_from_user(user.id, room.id, db)
if member is not None: if member is not None:
return member return member
member_id = generate_unique_code(Member, s=db)
member = Member(room=room, user=user, waiting=True, member = create_member(room=room, user=user, waiting=True,
id_code=member_id) db=db)
db.add(member)
db.commit()
db.refresh(member)
return member return member
def get_waiter(waiter_code: str, db: Session): def get_waiter(waiter_code: str, db: Session):
return db.exec(select(Member).where(Member.id_code == waiter_code)).first() return db.exec(select(Member).where(Member.id_code == waiter_code, Member.waiting == True)).first()
def get_member(id_code: str, room_id: str, db: Session): def get_member(id_code: str, room_id: str, db: Session):
return db.exec(select(Member).where(Member.id_code == id_code, Member.room_id == room_id)).first() return db.exec(select(Member).where(Member.id_code == id_code, Member.room_id == room_id)).first()
def delete_member(member: Member, db: Session): def delete_member(member: Member, db: Session):
db.delete(member) db.delete(member)
db.commit() db.commit()
return None return None
def accept_waiter(member: Member, db: Session): def accept_waiter(member: Member, db: Session):
member.waiting = False member.waiting = False
member.waiter_code = None member.waiter_code = None
@ -185,7 +221,16 @@ def refuse_waiter(member: Member, db: Session):
db.commit() db.commit()
return None return None
def leave_room(member: Member, db: Session): def leave_room(member: Member, db: Session):
db.delete(member) db.delete(member)
db.commit() db.commit()
return None return None
def serialize_member(member: Member) -> MemberRead | Waiter:
member_obj = member.user or member.anonymous
if member.waiting == False:
return MemberRead(username=member_obj.username, reconnect_code=getattr(member_obj, "reconnect_code", ""), isUser=member.user_id != None, isAdmin=member.is_admin, id_code=member.id_code).dict()
if member.waiting == True:
return Waiter(username=member_obj.username, waiter_id=member.id_code).dict()

View File

@ -1,4 +1,4 @@
from pydantic import root_validator from pydantic import root_validator, BaseModel
from typing import List, Optional, TYPE_CHECKING from typing import List, Optional, TYPE_CHECKING
from sqlmodel import SQLModel, Field, Relationship from sqlmodel import SQLModel, Field, Relationship
@ -76,6 +76,7 @@ class MemberRead(SQLModel):
isUser: bool isUser: bool
isAdmin: bool isAdmin: bool
id_code: str id_code: str
class MemberSerializer(MemberRead): class MemberSerializer(MemberRead):
member: MemberWithRelations member: MemberWithRelations
@ -90,3 +91,11 @@ class MemberSerializer(MemberRead):
return {"username": member_obj.username, "reconnect_code": getattr(member_obj, "reconnect_code", ""), "isAdmin": member.is_admin, "isUser": member.user != None} return {"username": member_obj.username, "reconnect_code": getattr(member_obj, "reconnect_code", ""), "isAdmin": member.is_admin, "isUser": member.user != None}
class RoomAndMember(BaseModel):
room: RoomRead
member: MemberRead
class Waiter(BaseModel):
username: str
waiter_id: str

View File

@ -54,7 +54,6 @@ admin = Admin(app, engine)
class UserAdmin(ModelView, model=User): class UserAdmin(ModelView, model=User):
column_list = [User.id, User.username] column_list = [User.id, User.username]
admin.add_view(UserAdmin) admin.add_view(UserAdmin)
@app.on_event("startup") @app.on_event("startup")

View File

@ -0,0 +1,234 @@
from fastapi.websockets import WebSocket
from services.websocket import Consumer
from typing import Any, TYPE_CHECKING
from database.room.models import Room, Member, MemberRead, Waiter
from sqlmodel import Session
from database.room.crud import serialize_member,check_user_in_room, create_anonymous, create_member, create_room_db, delete_member, get_member, get_member_from_token, get_member_from_reconnect_code, connect_member, disconnect_member, create_anonymous_member, create_anonymous_waiter, create_user_member, create_user_waiter, get_or_create_member, get_waiter, accept_waiter, leave_room, refuse_waiter, check_room
from database.auth.crud import get_user_from_token
if TYPE_CHECKING:
from routes.room.routes import RoomManager
class RoomConsumer(Consumer):
def __init__(self, ws: WebSocket, room: Room, manager: "RoomManager", db: Session):
self.room = room
self.ws = ws
self.manager = manager
self.db = db
self.member = None
# WS Utilities
async def connect(self):
await self.ws.accept()
async def direct_send(self, type: str, payload: Any):
await self.ws.send_json({'type': type, "data": payload})
async def send_to_admin(self, type: str, payload: Any, exclude: bool = False):
await self.manager.send_to_admin(self.room.id, {'type': type, "data": payload})
async def send_to(self, type: str, payload: Any, member_id, exclude: bool = False):
await self.manager.send_to(self.room.id, member_id, {'type': type, "data": payload})
async def broadcast(self, type, payload, exclude=False):
await self.manager.broadcast({"type": type, "data": payload}, self.room.id, exclude=[exclude == True and self])
def add_to_group(self):
self.manager.add(self.room.id, self)
async def connect_self(self):
if isinstance(self.member, Member):
connect_member(self.member, self.db)
await self.broadcast(type="connect", payload={"member": serialize_member(self.member)}, exclude=True)
async def disconnect_self(self):
if isinstance(self.member, Member):
disconnect_member(self.member, self.db)
if self.member.waiting is False:
await self.broadcast(type="disconnect", payload={"member": serialize_member(self.member)})
else:
await self.send_to_admin(type="disconnect_waiter", payload={"waiter": serialize_member(self.member)})
async def loginMember(self, member: Member):
if member.room_id == self.room.id and member.waiting == False:
self.member = member
await self.connect_self()
self.add_to_group()
await self.direct_send(type="loggedIn", payload={"member": serialize_member(self.member)})
async def send_error(self, msg):
await self.direct_send(type="error", payload={"msg": msg})
# Conditions
async def isAdminReceive(self):
is_admin = self.member is not None and self.member.is_admin == True
if not is_admin:
await self.direct_send(type="error", payload={"msg": "Vous n'avez pas la permission de faire ca"})
return False
return True
def isAdmin(self):
return self.member is not None and self.member.is_admin == True
async def isMember(self):
if self.member is None:
await self.send_error("Vous n'êtes connecté à aucune salle")
return self.member is not None and self.member.waiting == False
def isWaiter(self):
return self.member is not None and self.member.waiting == True
# Received Events
@Consumer.event('login')
async def login(self, token: str | None = None, reconnect_code: str | None = None):
if reconnect_code is None and token is None:
await self.direct_send(type="error", payload={"msg": "Veuillez spécifier une méthode de connection"})
return
if token is not None:
member = get_member_from_token(token, self.room.id, self.db)
if member == False:
await self.send_error("Token expired")
return
if member is None:
await self.send_error("Utilisateur introuvable dans cette salle")
return
elif reconnect_code is not None:
member = get_member_from_reconnect_code(
reconnect_code, self.room.id, db=self.db)
if member is None:
await self.send_error("Utilisateur introuvable dans cette salle")
return
await self.loginMember(member)
@Consumer.event('join')
async def join(self, token: str | None = None, username: str | None = None):
if token is not None:
user = get_user_from_token(token, self.db)
if user is None:
await self.send_error("Utilisateur introuvable")
return
if user is False:
await self.send_error("Token expired")
return
userInRoom = check_user_in_room(user.id, self.room.id, self.db)
if userInRoom is not None:
await self.loginMember(userInRoom)
return
waiter = create_member(
user=user, room=self.room, waiting=self.room.public is False, db=self.db)
elif username is not None:
anonymous = create_anonymous(username, self.room, self.db)
if anonymous is None:
await self.send_error("Nom d'utilisateur invalide ou indisponible")
return
waiter = create_member(
anonymous=anonymous, room=self.room, waiting=self.room.public is False, db=self.db)
self.member = waiter
self.add_to_group()
if self.room.public is False:
await self.direct_send(type="waiting", payload={"waiter": serialize_member(self.member)})
await self.send_to_admin(type="waiter", payload={"waiter": serialize_member(self.member)})
else:
await self.broadcast(type="joined", payload={"member": serialize_member(self.member)}, exclude=True)
await self.direct_send(type="accepted", payload={"member": serialize_member(self.member)})
@Consumer.event('accept', conditions=[isAdminReceive])
async def accept(self, waiter_id: str):
waiter = get_waiter(waiter_id, self.db)
if waiter is None:
await self.send_error("Utilisateur en list d'attente introuvable")
return
member = accept_waiter(waiter, self.db)
await self.send_to(type="accepted", payload={"member": serialize_member(member)}, member_id=waiter_id)
await self.broadcast(type="joined", payload={"member": serialize_member(member)})
@Consumer.event('refuse', conditions=[isAdminReceive])
async def accept(self, waiter_id: str):
waiter = get_waiter(waiter_id, self.db)
member = refuse_waiter(waiter, self.db)
await self.send_to(type="refused", payload={'waiter_id': waiter_id}, member_id=waiter_id)
await self.direct_send(type="successfullyRefused", payload= {"waiter_id": waiter_id})
@Consumer.event('ping_room')
async def proom(self):
await self.broadcast(type='ping', payload={}, exclude=True)
async def isConnected(self):
if self.member is None:
await self.direct_send(type="error", payload={"msg": "Vous n'êtes connecté à aucune salle"})
return self.member is not None
@Consumer.event('leave', conditions=[isMember])
async def leave(self):
if self.member.is_admin is True:
await self.send_error("Vous ne pouvez pas quitter une salle dont vous êtes l'administrateur")
return
member_obj = serialize_member(self.member)
leave_room(self.member, self.db)
await self.direct_send(type="successfully_leaved", payload={})
await self.broadcast(type='leaved', payload={"member": member_obj})
self.member = None
@Consumer.event('ban', conditions=[isAdminReceive])
async def ban(self, member_id: str):
member = get_member(member_id, self.room.id, self.db)
if member == None:
await self.send_error("Utilisateur introuvable")
return
if member.is_admin is True:
await self.send_error("Vous ne pouvez pas bannir un administrateur")
return
member_serialized = serialize_member(member)
leave_room(member, self.db)
await self.send_to(type="banned", payload={}, member_id=member.id_code)
await self.broadcast(type="leaved", payload={"member": member_serialized})
# Sending Events
@Consumer.sending(['connect', "disconnect", "joined"], conditions=[isMember])
def joined(self, member: MemberRead):
if self.member.id_code == member.id_code:
raise ValueError("") # Prevent from sending event
if self.member.is_admin == False:
member.reconnect_code = ""
return {"member": member}
@Consumer.sending(['waiter', "disconnect_waiter"], conditions=[isAdmin])
def waiter(self, waiter: Waiter):
return {"waiter": waiter}
@Consumer.sending("refused", conditions=[isWaiter])
def refused(self, waiter_id: str):
self.member = None
self.manager.remove(self.room.id, self)
return {"waiter_id": waiter_id}
@Consumer.sending("banned", conditions=[isMember])
def banned(self):
self.member = None
self.manager.remove(self.room.id, self)
return {}
@Consumer.sending('ping', conditions=[isMember])
def ping(self):
return {}
async def disconnect(self):
self.manager.remove(self.room.id, self)
await self.disconnect_self()

View File

@ -0,0 +1,39 @@
from typing import TYPE_CHECKING, Dict, List
if TYPE_CHECKING:
from routes.room.consumer import RoomConsumer
class RoomManager:
def __init__(self):
self.active_connections: Dict[str, List["RoomConsumer"]] = {}
def add(self, group: str, member: "RoomConsumer"):
if group not in self.active_connections:
self.active_connections[group] = []
if member not in self.active_connections[group]:
self.active_connections[group].append(member)
def remove(self, group: str, member: "RoomConsumer"):
if group in self.active_connections:
if member in self.active_connections[group]:
self.active_connections[group].remove(member)
async def broadcast(self, message, group: str, exclude: list["RoomConsumer"] = []):
if group in self.active_connections:
for connection in list(set(self.active_connections[group])):
if connection not in exclude:
await connection.send(message)
async def send_to(self, group, id_code, msg):
if group in self.active_connections:
members = [c for c in self.active_connections[group]
if c.member.id_code == id_code]
for m in members:
await m.send(msg)
async def send_to_admin(self, group, msg):
if group in self.active_connections:
members = [c for c in self.active_connections[group]
if c.member.is_admin == True]
for m in members:
await m.send(msg)

View File

@ -1,551 +1,33 @@
from database.room.crud import create_room_db, delete_member, get_member, get_member_from_token, get_member_from_reconnect_code, connect_member, disconnect_member, create_anonymous_member, create_anonymous_waiter, create_user_member, create_user_waiter, get_waiter, accept_waiter, leave_room, refuse_waiter, check_room from database.room.crud import serialize_member,check_user_in_room, create_anonymous, create_member, create_room_db, delete_member, get_member, get_member_from_token, get_member_from_reconnect_code, connect_member, disconnect_member, create_anonymous_member, create_anonymous_waiter, create_user_member, create_user_waiter, get_or_create_member, get_waiter, accept_waiter, leave_room, refuse_waiter, check_room
from pydantic.error_wrappers import ValidationError
from pydantic import validate_arguments
from sqlmodel import select
from pydantic import BaseModel from pydantic import BaseModel
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect, status, Query from fastapi import APIRouter, Depends, WebSocket, status, Query
from config import ALGORITHM, SECRET_KEY from config import ALGORITHM, SECRET_KEY
from database.auth.crud import get_user_from_clientId_db from database.auth.crud import get_user_from_clientId_db
from database.auth.models import User from database.auth.models import User
from database.db import get_session from database.db import get_session
from sqlmodel import Session from sqlmodel import Session
from database.room.models import Anonymous, MemberSerializer, MemberWithRelations, MemberRead, Room, RoomCreate, RoomRead, Member from database.room.models import Room, RoomCreate, RoomAndMember
from routes.room.consumer import RoomConsumer
from routes.room.manager import RoomManager
from services.auth import get_current_user_optional from services.auth import get_current_user_optional
from fastapi.exceptions import HTTPException from fastapi.exceptions import HTTPException
from jose import jwt, exceptions from database.auth.crud import get_user_from_token
import inspect from services.websocket import Consumer
router = APIRouter(tags=["room"]) router = APIRouter(tags=["room"])
class RoomAndMember(BaseModel):
room: RoomRead
member: MemberRead
class Waiter(BaseModel):
username: str
waiter_id: str
def serialize_member(member: Member) -> MemberRead | Waiter:
member_obj = member.user or member.anonymous
if member.waiting == False:
return MemberRead(username=member_obj.username, reconnect_code=getattr(member_obj, "reconnect_code", ""), isUser=member.user_id != None, isAdmin=member.is_admin, id_code=member.id_code)
if member.waiting == True:
return Waiter(username=member_obj.username, waiter_id=member.id_code)
@router.post('/room', response_model=RoomAndMember) @router.post('/room', response_model=RoomAndMember)
def create_room(room: RoomCreate, username: Optional[str] = Query(default=None, max_length=20), user: User | None = Depends(get_current_user_optional), db: Session = Depends(get_session)): def create_room(room: RoomCreate, username: Optional[str] = Query(default=None, max_length=20), user: User | None = Depends(get_current_user_optional), db: Session = Depends(get_session)):
room_obj = create_room_db(room=room, user=user, username=username, db=db) room_obj = create_room_db(room=room, user=user, username=username, db=db)
return {'room': room_obj['room'], "member": serialize_member(room_obj['member'])} return {'room': room_obj['room'], "member": serialize_member(room_obj['member'])}
class ConnectionManager:
def __init__(self):
self.active_connections: Dict[str, List[WebSocket]] = {}
def add(self, group: str, ws: WebSocket):
if group not in self.active_connections:
self.active_connections[group] = []
if ws not in self.active_connections[group]:
self.active_connections[group].append(ws)
def remove(self, group: str, ws: WebSocket):
if group in self.active_connections:
if ws in self.active_connections[group]:
self.active_connections[group].remove(ws)
async def broadcast(self, message, group: str, exclude: list[WebSocket] = []):
if group in self.active_connections:
for connection in self.active_connections[group]:
if connection not in exclude:
await connection.send_json(message)
def make_event_decorator(eventsDict):
def _(name: str | List, conditions: List[Callable | bool] = []):
def add_event(func):
model = validate_arguments(func).model
if type(name) == str:
eventsDict[name] = {"func": func,
"conditions": conditions, "model": model}
if type(name) == list:
for n in name:
eventsDict[n] = {"func": func,
"conditions": conditions, "model": model}
return func
return add_event
return _
class Event(BaseModel):
func: Callable
conditions: List[Callable | bool]
model: BaseModel
def dict_model(model: BaseModel, exclude: List[str]):
value = {}
for n, f in model:
if n not in exclude:
value[n] = f
return value
def dict_all(obj: Any):
if isinstance(obj, dict):
value = {}
for k, v in obj.items():
if isinstance(v, dict):
v = dict_all(v)
value[k] = dict(v)
elif isinstance(v, BaseModel):
value[k] = dict(v)
else:
try:
value[k] = dict(v)
except:
value[k] = v
return value
return dict(obj)
class Consumer:
events: Dict[str, Event] = {}
sendings: Dict[str, Any] = {}
event = make_event_decorator(events)
sending = make_event_decorator(sendings)
def __init__(self, ws: WebSocket):
self.ws: WebSocket = ws
#self.events: Dict[str, Callable] = {}
async def connect(self):
pass
async def validation_error_handler(self, e: ValidationError):
errors = e.errors()
await self.ws.send_json({"type": "error", "data": {"detail": [{ers['loc'][-1]: ers['msg']} for ers in errors]}})
async def send(self, payload):
type = payload.get('type', None)
print('TYPE', type, self.member)
if type is not None:
event_wrapper = self.sendings.get(type, None)
if event_wrapper is not None:
handler = event_wrapper.get('func')
conditions = event_wrapper.get('conditions')
is_valid = all([(await c(self)) if inspect.iscoroutinefunction(c) else c(self) if inspect.isfunction(c) else c == True if isinstance(c, bool) else True for c in conditions])
if handler is not None and is_valid:
model = event_wrapper.get("model")
data = payload.get('data') or {}
try:
validated_payload = model(self=self, **data)
except ValidationError as e:
print("ERROR", e)
await self.ws.send_json({"type": "error", "data": {"msg": "Oops there was an error"}})
return
validated_payload = dict_model(validated_payload,
exclude=["v__duplicate_kwargs", "args", 'kwargs', "self"])
try:
parsed_payload = handler(
self, **validated_payload)
await self.ws.send_json({'type': type, "data": dict_all(parsed_payload)})
return
except Exception as e:
return
return
#print('pls')
await self.ws.send_json(payload)
#print('sent')
async def receive(self, data):
event = data.get('type', None)
if event is not None:
event_wrapper = self.events.get(event, None)
if event_wrapper is not None:
handler = event_wrapper.get('func')
conditions = event_wrapper.get('conditions')
is_valid = all([(await c(self)) if inspect.iscoroutinefunction(c) else c(self) if inspect.isfunction(c) else c == True if isinstance(c, bool) else True for c in conditions])
if handler is not None and is_valid:
model = event_wrapper.get("model")
payload = data.get('data') or {}
try:
validated_payload = model(self=self, **payload)
except ValidationError as e:
await self.validation_error_handler(e)
return
await handler(**{k: v for k, v in validated_payload.dict().items() if k not in ["v__duplicate_kwargs", "args", 'kwargs']})
async def disconnect(self):
pass
async def run(self):
await self.connect()
try:
while True:
data = await self.ws.receive_json()
await self.receive(data)
except WebSocketDisconnect:
await self.disconnect()
class ConsumerManager:
def __init__(self):
self.active_connections: Dict[str, List[Consumer]] = {}
def add(self, group: str, ws: Consumer):
if group not in self.active_connections:
self.active_connections[group] = []
#print("adding", ws, self.active_connections[group])
if ws not in self.active_connections[group]:
#print('ACTUALLY ADDING')
self.active_connections[group].append(ws)
def remove(self, group: str, ws: Consumer):
if group in self.active_connections:
if ws in self.active_connections[group]:
self.active_connections[group].remove(ws)
async def broadcast(self, message, group: str, exclude: list[Consumer] = []):
if group in self.active_connections:
#print(self.active_connections[group], exclude)
for connection in list(set(self.active_connections[group])):
if connection not in exclude:
#print('SEND TO', connection, message)
await connection.send(message)
class Token(BaseModel):
token: str
def get_user_from_token(token: str, db: Session):
try:
decoded = jwt.decode(token=token, key=SECRET_KEY,
algorithms=[ALGORITHM])
except exceptions.ExpiredSignatureError:
return False
clientId = decoded.get('sub')
return get_user_from_clientId_db(clientId=clientId, db=db)
def check_same_member(member: Member, memberR: MemberRead):
return (member.user is not None and memberR.username == member.user.username) or (member.anonymous and memberR.reconnect_code == member.anonymous.reconnect_code)
class RoomConsumer(Consumer):
def __init__(self, ws: WebSocket, room: Room, manager: "RoomManager", db: Session):
self.room = room
self.ws = ws
self.manager = manager
self.db = db
self.member = None
# WS Utilities
async def connect(self):
await self.ws.accept()
async def direct_send(self, type: str, payload: Any):
await self.ws.send_json({'type': type, "data": payload})
async def send_to_admin(self, type: str, payload: Any, exclude: bool = False):
await self.manager.send_to_admin(self.room.id, {'type': type, "data": payload})
async def send_to(self, type: str, payload: Any,member_id, exclude: bool = False):
await self.manager.send_to(self.room.id, member_id,{'type': type, "data": payload})
async def broadcast(self, type, payload, exclude=False):
await self.manager.broadcast({"type": type, "data": payload}, self.room.id, exclude=[exclude == True and self])
def add_to_group(self):
self.manager.add(self.room.id, self)
async def connect_self(self):
if isinstance(self.member, Member):
connect_member(self.member, self.db)
await self.broadcast(type="connect", payload={"member": serialize_member(self.member).dict()}, exclude=True)
async def disconnect_self(self):
if isinstance(self.member, Member):
disconnect_member(self.member, self.db)
if self.member.waiting is False:
await self.broadcast(type="disconnect", payload={"member": serialize_member(self.member).dict()})
else:
await self.broadcast(type="disconnect_waiter", payload={"waiter": serialize_member(self.member).dict()})
# DB Utilities
# Received Events
@Consumer.event('login')
async def login(self, token: str | None = None, reconnect_code: str | None = None):
if token is not None:
member = get_member_from_token(token, self.room.id, self.db)
if member == False:
await self.direct_send(type="error", payload={"msg": "Token expired"})
return
if member is None:
await self.direct_send(type="error", payload={"msg": "Utilisateur introuvable dans cette salle"})
return
self.member = member
await self.connect_self()
self.add_to_group()
await self.direct_send(type="loggedIn", payload={"member": serialize_member(self.member).dict()})
elif reconnect_code is not None:
member = get_member_from_reconnect_code(
reconnect_code, self.room.id, db=self.db)
if member is None:
await self.direct_send(type="error", payload={"msg": "Utilisateur introuvable dans cette salle"})
return
self.member = member
await self.connect_self()
self.add_to_group()
await self.direct_send(type="loggedIn", payload={"member": serialize_member(self.member).dict()})
if reconnect_code is None and token is None:
await self.direct_send(type="error", payload={"msg": "Veuillez spécifier une méthode de connection"})
@Consumer.event('join')
async def join(self, token: str | None = None, username: str | None = None):
if self.room.public == False:
if token is not None:
user = get_user_from_token(token, self.db)
if user is None:
await self.direct_send(type="error", payload={"msg": "Utilisateur introuvable"})
return
if user is False:
await self.direct_send(type="error", payload={"msg": "Token expired"})
return
waiter = create_user_waiter(user, self.room, self.db)
if waiter.waiting is False:
self.member = waiter
# await self.connect_self()
self.add_to_group()
await self.connect_self()
await self.direct_send(type="loggedIn", payload={"member": serialize_member(self.member).dict()})
return
self.member = waiter
self.add_to_group()
await self.direct_send(type="waiting", payload={"waiter": serialize_member(self.member).dict()})
await self.send_to_admin(type="waiter", payload={"waiter": serialize_member(self.member).dict()})
elif username is not None:
waiter = create_anonymous_waiter(username, self.room, self.db)
if waiter is None:
await self.direct_send(type="error", payload={"msg": "Nom d'utilisateur invalide ou indisponible"})
return
self.member = waiter
self.add_to_group()
await self.direct_send(type="waiting", payload={"waiter": serialize_member(self.member).dict()})
await self.send_to_admin(type="waiter", payload={"waiter": serialize_member(self.member).dict()})
else:
if token is not None:
user = get_user_from_token(token, self.db)
if user is None:
await self.direct_send(type="error", payload={"msg": "Utilisateur introuvable"})
return
if user is False:
await self.direct_send(type="error", payload={"msg": "Token expired"})
return
member = create_user_member(user, self.room, self.db)
if member is None:
return
self.member = member
self.add_to_group()
await self.broadcast(type="joined", payload={"member": serialize_member(self.member).dict()}, exclude=True)
await self.direct_send(type="accepted", payload={"member": serialize_member(self.member).dict()})
elif username is not None:
member = create_anonymous_member(username, self.room, self.db)
if member is None:
await self.direct_send(type="error", data={"msg": "Nom d'utilisateur indisponible"})
return
self.member = member
self.add_to_group()
await self.broadcast(type="joined", payload={"member": serialize_member(self.member).dict()}, exclude=True)
await self.direct_send(type="accepted", payload={"member": serialize_member(self.member).dict()})
async def isAdminReceive(self):
is_admin = self.member is not None and self.member.is_admin == True
if not is_admin:
await self.direct_send(type="error", payload={"msg": "Vous n'avez pas la permission de faire ca"})
return False
return True
def isAdmin(self):
return self.member is not None and self.member.is_admin == True
@Consumer.event('accept', conditions=[isAdminReceive])
async def accept(self, waiter_id: str):
waiter = get_waiter(waiter_id, self.db)
if waiter is None:
await self.direct_send(type="error", payload={'msg': "Utilisateur introuvable"})
return
member = accept_waiter(waiter, self.db)
await self.send_to(type="accepted", payload={"member": serialize_member(member).dict()}, member_id=waiter_id)
await self.broadcast(type="joined", payload={"member": serialize_member(member).dict()})
@Consumer.event('refuse', conditions=[isAdminReceive])
async def accept(self, waiter_id: str):
waiter = get_waiter(waiter_id, self.db)
member = refuse_waiter(waiter, self.db)
await self.send_to(type="refused", payload={'waiter_id': waiter_id}, member_id=waiter_id)
# await self.broadcast(type="joined", payload={"member": serialize_member(member).dict()})
@Consumer.event('ping_room')
async def proom(self):
await self.broadcast(type='ping', payload={}, exclude=True)
async def hasRoom(self):
if self.member is None:
await self.direct_send(type="error", payload={"msg": "Vous n'êtes connecté à aucune salle"})
return self.member
@Consumer.event('leave', conditions=[hasRoom])
async def leave(self):
if self.member.is_admin is True:
await self.direct_send(type="error", payload={"msg": "Vous ne pouvez pas quitter une salle dont vous êtes l'administrateur"})
return
member_obj = serialize_member(self.member).dict()
leave_room(self.member, self.db)
await self.direct_send(type="successfully_leaved", payload = {})
await self.broadcast(type='leaved', payload ={"member": member_obj})
self.member = None
@Consumer.event('ban', conditions=[isAdminReceive])
async def ban(self, member_id: str):
member = get_member(member_id, self.room.id, self.db)
if member == None:
await self.direct_send(type="error", payload={"msg": "Utilisateur introuvable"})
return
if member.is_admin is True:
await self.direct_send(type="error", payload={"msg": "Vous ne pouvez pas bannir un administrateur"})
return
member_serialized = serialize_member(member)
leave_room(member, self.db)
await self.send_to(type="banned", payload={}, member_id=member.id_code)
await self.broadcast(type="leaved", payload = {"member": member_serialized.dict()})
def isMember(self):
return self.member is not None and self.member.waiting == False
# Sending Events
@Consumer.sending("joined", conditions=[isMember])
def joined(self, member: MemberRead):
if self.member.id_code == member.id_code:
raise ValueError("")
if self.member.is_admin == False:
member.reconnect_code = ""
return {"member": member}
@Consumer.sending(['waiter', "disconnect_waiter"], conditions=[isAdmin])
def waiter(self, waiter: Waiter):
return {"waiter": waiter}
@Consumer.sending(['connect', "disconnect", ""])
def connect_event(self, member: MemberRead):
if not self.member.is_admin:
member.reconnect_code = ""
return {"member": member}
@Consumer.sending('disconnect')
def disconnect_event(self, member: MemberRead):
if not self.member.is_admin:
member.reconnect_code = ""
return {"member": member}
@Consumer.sending('disconnect_waiter', conditions=[isAdmin])
def disconnect_event(self, waiter: Waiter):
return {"waiter": waiter}
def isWaiter(self):
return self.member is not None and self.member.waiting == True
@Consumer.sending("refused", conditions=[isWaiter])
def refused(self, waiter_id: str):
self.member = None
self.manager.remove(self.room.id, self)
return {"waiter_id": waiter_id}
@Consumer.sending("banned", conditions=[isMember])
def banned(self):
self.member = None
self.manager.remove(self.room.id, self)
return {}
@Consumer.sending('ping', conditions=[isMember])
def ping(self):
return {}
async def disconnect(self):
print("DISCONNECT", self.member)
self.manager.remove(self.room.id, self)
await self.disconnect_self()
class RoomManager:
def __init__(self):
self.active_connections: Dict[str, List[RoomConsumer]] = {}
def add(self, group: str, ws: RoomConsumer):
if group not in self.active_connections:
self.active_connections[group] = []
#print("adding", ws, self.active_connections[group])
if ws not in self.active_connections[group]:
#print('ACTUALLY ADDING')
self.active_connections[group].append(ws)
def remove(self, group: str, ws: RoomConsumer):
if group in self.active_connections:
if ws in self.active_connections[group]:
self.active_connections[group].remove(ws)
async def broadcast(self, message, group: str, exclude: list[RoomConsumer] = []):
if group in self.active_connections:
#print(self.active_connections[group], exclude)
for connection in list(set(self.active_connections[group])):
if connection not in exclude:
#print('SEND TO', connection, message)
await connection.send(message)
async def send_to(self, group, id_code, msg):
if group in self.active_connections:
members = [c for c in self.active_connections[group] if c.member.id_code == id_code]
for m in members:
await m.send(msg)
async def send_to_admin(self, group, msg):
if group in self.active_connections:
members = [c for c in self.active_connections[group] if c.member.is_admin == True]
for m in members:
await m.send(msg)
manager = RoomManager() manager = RoomManager()
@router.websocket('/ws/room/{room_id}') @router.websocket('/ws/room/{room_id}')
async def room_ws(ws: WebSocket, room: Room | None = Depends(check_room), db: Session = Depends(get_session)): async def room_ws(ws: WebSocket, room: Room | None = Depends(check_room), db: Session = Depends(get_session)):
if room is None: if room is None:

View File

@ -0,0 +1,137 @@
from typing import List, Callable, Any, Dict
from pydantic import validate_arguments, BaseModel
from fastapi.websockets import WebSocketDisconnect, WebSocket
from pydantic.error_wrappers import ValidationError
import inspect
def make_event_decorator(eventsDict):
def _(name: str | List, conditions: List[Callable | bool] = []):
def add_event(func):
model = validate_arguments(func).model
if type(name) == str:
eventsDict[name] = {"func": func,
"conditions": conditions, "model": model}
if type(name) == list:
for n in name:
eventsDict[n] = {"func": func,
"conditions": conditions, "model": model}
return func
return add_event
return _
def dict_model(model: BaseModel, exclude: List[str]):
value = {}
for n, f in model:
if n not in exclude:
value[n] = f
return value
def dict_all(obj: Any):
if isinstance(obj, dict):
value = {}
for k, v in obj.items():
if isinstance(v, dict):
v = dict_all(v)
value[k] = dict(v)
elif isinstance(v, BaseModel):
value[k] = dict(v)
else:
try:
value[k] = dict(v)
except:
value[k] = v
return value
return dict(obj)
class Event(BaseModel):
func: Callable
conditions: List[Callable | bool]
model: BaseModel
class Consumer:
events: Dict[str, Event] = {}
sendings: Dict[str, Any] = {}
event = make_event_decorator(events)
sending = make_event_decorator(sendings)
def __init__(self, ws: WebSocket):
self.ws: WebSocket = ws
#self.events: Dict[str, Callable] = {}
async def connect(self):
pass
async def validation_error_handler(self, e: ValidationError):
errors = e.errors()
await self.ws.send_json({"type": "error", "data": {"detail": [{ers['loc'][-1]: ers['msg']} for ers in errors]}})
async def send(self, payload):
type = payload.get('type', None)
print('TYPE', type, self.member)
if type is not None:
event_wrapper = self.sendings.get(type, None)
if event_wrapper is not None:
handler = event_wrapper.get('func')
conditions = event_wrapper.get('conditions')
is_valid = all([(await c(self)) if inspect.iscoroutinefunction(c) else c(self) if inspect.isfunction(c) else c == True if isinstance(c, bool) else True for c in conditions])
if handler is not None and is_valid:
model = event_wrapper.get("model")
data = payload.get('data') or {}
try:
validated_payload = model(self=self, **data)
except ValidationError as e:
print("ERROR", e)
await self.ws.send_json({"type": "error", "data": {"msg": "Oops there was an error"}})
return
validated_payload = dict_model(validated_payload,
exclude=["v__duplicate_kwargs", "args", 'kwargs', "self"])
try:
parsed_payload = handler(
self, **validated_payload)
await self.ws.send_json({'type': type, "data": dict_all(parsed_payload)})
return
except Exception as e:
return
return
await self.ws.send_json(payload)
async def receive(self, data):
event = data.get('type', None)
if event is not None:
event_wrapper = self.events.get(event, None)
if event_wrapper is not None:
handler = event_wrapper.get('func')
conditions = event_wrapper.get('conditions')
is_valid = all([(await c(self)) if inspect.iscoroutinefunction(c) else c(self) if inspect.isfunction(c) else c == True if isinstance(c, bool) else True for c in conditions])
if handler is not None and is_valid:
model = event_wrapper.get("model")
payload = data.get('data') or {}
try:
validated_payload = model(self=self, **payload)
except ValidationError as e:
await self.validation_error_handler(e)
return
await handler(**{k: v for k, v in validated_payload.dict().items() if k not in ["v__duplicate_kwargs", "args", 'kwargs']})
async def disconnect(self):
pass
async def run(self):
await self.connect()
try:
while True:
data = await self.ws.receive_json()
await self.receive(data)
except WebSocketDisconnect:
await self.disconnect()

View File

@ -138,7 +138,8 @@ def test_join_waiter_not_found(client: TestClient):
admin.receive_json() admin.receive_json()
admin.send_json({"type": "accept", "data": {"waiter_id": "OOOO"}}) admin.send_json({"type": "accept", "data": {"waiter_id": "OOOO"}})
data = admin.receive_json() data = admin.receive_json()
assert data == {"type": "error", "data": {"msg": "Utilisateur introuvable"}} assert data == {"type": "error", "data": {
"msg": "Utilisateur en list d'attente introuvable"}}
def test_join_no_auth(client: TestClient): def test_join_no_auth(client: TestClient):
room = test_create_room_no_auth(client=client) room = test_create_room_no_auth(client=client)
@ -219,7 +220,8 @@ def test_join_auth_refused(client: TestClient):
"waiter": {"waiter_id": waiter_id, "username": "lilian2"}}} "waiter": {"waiter_id": waiter_id, "username": "lilian2"}}}
admin.send_json({"type": "refuse", "data": { admin.send_json({"type": "refuse", "data": {
"waiter_id": waiter_id}}) "waiter_id": waiter_id}})
adata = admin.receive_json()
assert adata == {"type": "successfullyRefused", "data": {"waiter_id": waiter_id}}
mdata = member.receive_json() mdata = member.receive_json()
assert mdata == {"type": "refused", "data": { assert mdata == {"type": "refused", "data": {
"waiter_id": waiter_id}} "waiter_id": waiter_id}}