diff --git a/backend/api/api.json b/backend/api/api.json new file mode 100644 index 0000000..779c6c7 --- /dev/null +++ b/backend/api/api.json @@ -0,0 +1,92 @@ +{ + "routes": {}, + "events": { + "login": { + "input": { "token": null, "reconnect_code": null }, + "output": { + "type": "loggedIn", + "member": { "username": "str", "reconnect_code": "str ou null" } + }, + "broadcast": { + "all": { + "type": "connect", + "member": { "id": "int (member id!)" } + } + }, + "errors": [ + { + "type": "error", + "error": { "status": "401", "msg": "Membre introuvable" } + } + ] + }, + "join": { + "input": { "username": "str", "user": null }, + "output": { + "public?": { + "type": "accepted", + "member": { + "username": "str", + "reconnect_code": "str ou null" + } + }, + "private?": { + "type": "waiting", + "waiter": { "id_code": "str", "username": "str" } + } + }, + "broadcast": { + "public?": { + "all": { + "type": "joined", + "member": { + "id": "int", + "username": "str", + "reconnect_code": "str uniquement pour l'admin" + } + } + }, + "private?": { + "admin": { + "type": "waiter", + "waiter": { "id_code": "str", "username": "str" } + } + } + }, + "errors": [ + { + "status": "400", + "msg": "User input (trop long ou déjà pris, etc)" + } + ] + }, + "accept": { + "input": { "waiter_id": "str" }, + "broadcast": { + "waiter": { + "type": "accepted", + "member": { + "username": "str", + "reconnect_code": "str ou null" + } + }, + "all": { + "type": "joined", + "member": { + "id": "int", + "username": "str", + "reconnect_code": "str uniquement pour l'admin" + } + } + } + }, + "reject": { + "input": { "waiter_id": "str" }, + "output": { + "type": "successfullyRejected", + "waiter": { "id": "int" } + }, + "broadcast": { "waiter": { "type": "rejected" } } + } + } +} diff --git a/backend/api/config.py b/backend/api/config.py index 31b0686..04d07da 100644 --- a/backend/api/config.py +++ b/backend/api/config.py @@ -1,6 +1,7 @@ from datetime import timedelta from redis import Redis from pydantic import BaseModel + SECRET_KEY = "6323081020d8939e6385dd688a26cbca0bb34ed91997959167637319ba4f6f3e" ALGORITHM = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES = 30 diff --git a/backend/api/database/auth/models.py b/backend/api/database/auth/models.py index b4a245f..5a37d66 100644 --- a/backend/api/database/auth/models.py +++ b/backend/api/database/auth/models.py @@ -4,7 +4,6 @@ from uuid import UUID import uuid from sqlmodel import Field, SQLModel, Relationship from pydantic import validator, BaseModel -from database.db import get_session, get_session from services.password import validate_password from services.schema import as_form @@ -25,7 +24,7 @@ class User(UserBase, table=True): exercices: List['Exercice'] = Relationship(back_populates='author') tags: List['Tag'] = Relationship(back_populates='author') - members: List['Member'] = Relationship(back_populates='user') + members: List["Member"] = Relationship(back_populates='user') @as_form class UserEdit(UserBase): @@ -41,8 +40,6 @@ class UserRegister(BaseModel): password: str password_confirm: str - - @validator('password') def password_validation(cls, v): is_valid = validate_password(v) diff --git a/backend/api/database/db.py b/backend/api/database/db.py index e2005e7..aa073d0 100644 --- a/backend/api/database/db.py +++ b/backend/api/database/db.py @@ -1,5 +1,3 @@ -import random -import string from sqlmodel import SQLModel, create_engine, Session, select sqlite_file_name = "database.db" diff --git a/backend/api/database/room/crud.py b/backend/api/database/room/crud.py index a69f02d..ea79cd2 100644 --- a/backend/api/database/room/crud.py +++ b/backend/api/database/room/crud.py @@ -1,6 +1,33 @@ -from sqlmodel import Session -from database.room.models import RoomCreate +from fastapi import Depends +from sqlmodel import Session, select +from database.db import get_session +from database.room.models import Anonymous, Member, Room, RoomCreate from database.auth.models import User +from services.database import generate_unique_code -def create_room_db(*,room: RoomCreate, user: User | None = None,username: str, db: Session): - return \ No newline at end of file +def create_room_db(*,room: RoomCreate, user: User | None = None, username: str | None = None, db: Session): + id_code = generate_unique_code(Room,s=db) + room_obj = Room(**room.dict(exclude_unset=True), id_code=id_code) + if user is not None: + member = Member(user_id=user.id, room=room_obj) + db.add(member) + db.commit() + db.refresh(member) + if username is not None: + reconnect_code = generate_unique_code(Anonymous, s=db, field_name='reconnect_code') + anonymous = Anonymous(username=username, reconnect_code=reconnect_code) + member = Member(anonymous=anonymous, room=room_obj) + db.add(member) + db.commit() + db.refresh(member) + if username is None and user is None: + raise ValueError('Username or user required') + return {"room": room_obj, "member": member} + +def check_room(room_id: str, db: Session = Depends(get_session)): + room = db.exec(select(Room).where(Room.id_code==room_id)).first() + return room + +def userInRoom(room: Room, user: User, db: Session): + member = db.exec(select(Member).where(Member.room_id == room.id, Member.user_id == user.id)).first() + return member diff --git a/backend/api/database/room/models.py b/backend/api/database/room/models.py index 065543b..857f758 100644 --- a/backend/api/database/room/models.py +++ b/backend/api/database/room/models.py @@ -1,5 +1,7 @@ from typing import List, Optional, TYPE_CHECKING from sqlmodel import SQLModel, Field, Relationship + +from database.auth.models import UserRead if TYPE_CHECKING: from database.auth.models import User @@ -7,19 +9,17 @@ if TYPE_CHECKING: class RoomBase(SQLModel): name: str = Field(max_length=20) public: bool = Field(default=False) - + class RoomCreate(RoomBase): pass class Room(RoomBase, table=True): id: Optional[int] = Field(default=None, primary_key=True) - id_code: str + id_code: str = Field(index=True) members: List['Member'] = Relationship(back_populates="room") -class RoomRead(RoomBase): - id_code: str - #members: List[] + class AnonymousBase(SQLModel): username: str = Field(max_length=20) @@ -29,8 +29,8 @@ class AnonymousCreate(AnonymousBase): class Anonymous(AnonymousBase, table=True): id: Optional[int] = Field(default=None, primary_key=True) - reconnect_code: str - + reconnect_code: str = Field(index=True) + member: 'Member' = Relationship(back_populates="anonymous") @@ -38,12 +38,28 @@ class Member(SQLModel, table = True): id: Optional[int] = Field(default=None, primary_key=True) user_id: Optional[int] = Field(foreign_key="user.id", default=None) + user: Optional["User"] = Relationship(back_populates='members') + anonymous_id: Optional[int] = Field(foreign_key="anonymous.id", default=None) - anonymous: Optional[Anonymous] = Relationship(back_populates="member") - user: Optional['User'] = Relationship(back_populates='members') - + room_id: int = Field(foreign_key="room.id") room: Room = Relationship(back_populates='members') - \ No newline at end of file + is_admin: bool = False + + waiting: bool = True + + online: bool = False + +class RoomRead(RoomBase): + id_code: str + #members: List['Member'] + +class AnonymousRead(AnonymousBase): + reconnect_code: str +class Username(SQLModel): + username: str +class MemberRead(SQLModel): + anonymous: AnonymousRead = None + user: Username = None \ No newline at end of file diff --git a/backend/api/main.py b/backend/api/main.py index 0a84a16..d711c41 100644 --- a/backend/api/main.py +++ b/backend/api/main.py @@ -1,11 +1,13 @@ #import schemas.base +from sqlmodel import SQLModel, Field from services.password import get_password_hash from sqlmodel import Session, select from database.auth.crud import create_user_db from services.auth import get_current_user_optional, jwt_required from fastapi.openapi.utils import get_openapi -from database.auth.models import User, UserBase, UserRead +from database.auth.models import User, UserRead from database.exercices.models import Exercice, ExerciceRead +from database.room.models import Room, Anonymous, Member import database.db from fastapi_pagination import add_pagination from fastapi.responses import PlainTextResponse @@ -14,7 +16,7 @@ from fastapi import FastAPI, HTTPException, Depends, Request, status, Header from fastapi_jwt_auth import AuthJWT from fastapi_jwt_auth.exceptions import AuthJWTException from fastapi.responses import JSONResponse -from typing import List, Sequence +from typing import List, Optional, Sequence from tortoise.contrib.pydantic import pydantic_model_creator from fastapi import FastAPI, HTTPException, params from tortoise import Tortoise @@ -30,7 +32,6 @@ import config from sqladmin import Admin, ModelView from database.db import engine from fastapi.security import OAuth2PasswordBearer, HTTPBearer -from pydantic import Field app = FastAPI(title="API Generateur d'exercices") origins = [ "http://localhost:8000", @@ -78,8 +79,6 @@ async def validation_exception_handler(request, exc: RequestValidationError|Vali #JWT AUTH - - @AuthJWT.load_config def get_config(): return config.settings diff --git a/backend/api/routes/auth/routes.py b/backend/api/routes/auth/routes.py index 9f1a6ac..7e55be9 100644 --- a/backend/api/routes/auth/routes.py +++ b/backend/api/routes/auth/routes.py @@ -18,7 +18,6 @@ class Token(BaseModel): token_type: str refresh_token: str - @router.post("/login", response_model=Token) def login_for_access_token(user: User = Depends(authenticate_user)): Authorize = AuthJWT() diff --git a/backend/api/routes/room/routes.py b/backend/api/routes/room/routes.py index 881643a..985ee77 100644 --- a/backend/api/routes/room/routes.py +++ b/backend/api/routes/room/routes.py @@ -1,7 +1,257 @@ -from fastapi import APIRouter - +from sqlmodel import select +from pydantic import BaseModel +from typing import Any, Callable, Dict, List, Optional +from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect, status, Query +from config import ALGORITHM, SECRET_KEY +from database.auth.crud import get_user_from_clientId_db +from database.auth.models import User +from database.db import get_session +from database.room.crud import check_room, create_room_db, userInRoom +from sqlmodel import Session +from database.room.models import Anonymous, Member, MemberRead, Room, RoomCreate, RoomRead +from services.auth import get_current_user_optional +from fastapi.exceptions import HTTPException +from jose import jwt, exceptions router = APIRouter(tags=["room"]) -@router.post('/room') -def create_room(): - return \ No newline at end of file + +class RoomAndMember(BaseModel): + room: RoomRead + member: MemberRead + + +@router.post('/room', response_model=RoomAndMember) +def create_room(room: RoomCreate, username: Optional[str] = Query(default=None, max_length=20), user: User | None = Depends(get_current_user_optional), db: Session = Depends(get_session)): + room_obj = create_room_db(room=room, user=user, username=username, db=db) + return room_obj + + +class ConnectionManager: + def __init__(self): + self.active_connections: Dict[str, List[WebSocket]] = {} + + async def add(self, group: str, ws: WebSocket): + if group not in self.active_connections: + self.active_connections[group] = [] + + if ws not in self.active_connections[group]: + self.active_connections[group].append(ws) + + def remove(self, group: str, ws: WebSocket): + if group in self.active_connections: + if ws in self.active_connections[group]: + self.active_connections[group].remove(ws) + + async def broadcast(self, message, group: str, exclude: list[WebSocket] = []): + if group in self.active_connections: + for connection in self.active_connections[group]: + if connection not in exclude: + await connection.send_json(message) + + +manager = ConnectionManager() + + +def make_event_decorator(eventsDict): + def _(name: str): + def add_event(func): + eventsDict[name] = func + return func + + return add_event + return _ + + +class Consumer: + events: Dict[str, Callable] = {} + event = make_event_decorator(events) + + def __init__(self, ws: WebSocket): + self.ws: WebSocket = ws + #self.events: Dict[str, Callable] = {} + + async def connect(self): + pass + + async def receive(self, data): + event = data.get('type', None) + if event is not None: + handler = self.events.get(event, None) + if handler is not None: + payload = data.get('data') + await handler(self, payload) + + async def disconnect(self): + pass + + async def run(self): + await self.connect() + try: + while True: + data = await self.ws.receive_json() + await self.receive(data) + except WebSocketDisconnect: + await self.disconnect() + + +class Token(BaseModel): + token: str + + +def get_user_from_token(token: str, db: Session): + try: + decoded = jwt.decode(token=token, key=SECRET_KEY, + algorithms=[ALGORITHM]) + except exceptions.ExpiredSignatureError: + return False + + clientId = decoded.get('sub') + return get_user_from_clientId_db(clientId=clientId, db=db) + + +def get_member_from_user(user_id: int, room_id: int, db: Session): + member = db.exec(select(Member).where(Member.room_id == + room_id, Member.user_id == user_id)).first() + return member + + +def get_member_from_anonymous(anonymous_id: int, room_id: int, db: Session): + member = db.exec(select(Member).where(Member.room_id == + room_id, Member.anonymous_id == anonymous_id)).first() + return member + + +def get_anonymous_from_code(reconnect_code: str, db: Session): + anonymous = db.exec(select(Anonymous).where( + Anonymous.reconnect_code == reconnect_code)).first() + return anonymous + + +def connect_member(member: Member, db: Session): + member.online = True + db.add(member) + db.commit() + db.refresh(member) + return member + + +def disconnect_member(member: Member, db: Session): + member.online = False + db.add(member) + db.commit() + db.refresh(member) + return member + + +def validate_username(username: str, room: Room, db: Session): + members = select(Member).where(Member.room_id == room.id, Member.anonymous_id != None) + +def create_anonymous_member(username: str, room: Room, db: Session): + pass + +class RoomConsumer(Consumer): + def __init__(self, ws: WebSocket, room: Room, manager: ConnectionManager, db: Session): + self.room = room + self.ws = ws + self.manager = manager + self.db = db + self.member = None + #WS Utilities + async def connect(self): + await self.ws.accept() + + async def send(self, type: str, payload: Any): + await self.ws.send_json({'type': type, "data": payload}) + + async def send_to_all_room(self, type: str, payload: Any, exclude: bool = False): + await self.manager.broadcast({'type': type, "data": payload}, f'{self.room.id}__member', [exclude == True and self.ws]) + await self.manager.broadcast({'type': type, "data": payload}, f'{self.room.id}__admin', [exclude == True and self.ws]) + + async def send_to_admin(self, type: str, payload: Any, exclude: bool = False): + await self.manager.broadcast({'type': type, "data": payload}, f'{self.room.id}__admin', [exclude == True and self.ws]) + + async def send_to_members(self, type: str, payload: Any, exclude: bool = False): + await self.manager.broadcast({'type': type, "data": payload}, f'{self.room.id}__member', [exclude == True and self.ws]) + + def add_to_admin(self): + self.manager.add(f'{self.room.id}__admin', self.ws) + + def add_to_members(self): + self.manager.add(f'{self.room.id}__members', self.ws) + + def add_to_groups(self): + if isinstance(self.member, Member): + if self.member.is_admin == True: + self.add_to_admin() + if self.member.is_admin == False: + self.add_to_members() + + async def connect_self(self): + if isinstance(self.member, Member): + connect_member(self.member, self.db) + await self.send_to_all_room(type="connect", payload={}, exclude=True) + + async def disconnect_self(self): + if isinstance(self.member, Member): + disconnect_member(self.member, self.db) + await self.send_to_all_room(type="disconnect", payload={}, exclude=True) + + #DB Utilities + + #Events + @Consumer.event('login') + async def login(self, data): + if 'token' in data: + token = data.get('token') + user = get_user_from_token(token, db=self.db) + if user == False: + await self.send() + return + if user is None: + return + + member = get_member_from_user( + user_id=user.id, room_id=self.room.id, db=self.db) + if member is None: + return + + self.member = member + self.add_to_groups() + self.connect_self() + await self.send() + + elif "reconnect_code" in data: + reconnect_code = data.get('reconnect_code') + anonymous = get_anonymous_from_code( + reconnect_code=reconnect_code, db=self.db) + if anonymous is None: + return + + member = get_member_from_anonymous( + anonymous_id=anonymous.id, room_id=self.room.id, db=self.db) + if member is None: + return + + self.member = member + self.add_to_groups() + self.connect_self() + await self.send(type="accepted") + + @Consumer.event('join') + async def join(self, data): + if "token" in data: + return + else: + return + return + + async def disconnect(self): + await self.disconnect_self() + +@router.websocket('/ws/room/{room_id}') +async def room_ws(ws: WebSocket, room: Room | None = Depends(check_room), db: Session = Depends(get_session)): + if room is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail='Room not found') + consumer = RoomConsumer(ws=ws, room=room, manager=manager, db=db) + await consumer.run() diff --git a/backend/api/services/auth.py b/backend/api/services/auth.py index 524babd..b70c07a 100644 --- a/backend/api/services/auth.py +++ b/backend/api/services/auth.py @@ -29,17 +29,14 @@ def jwt_required(Authorize: AuthJWT = Depends(), token: str = Depends(bearer)): Authorize.jwt_required() return Authorize - def jwt_optional(Authorize: AuthJWT = Depends()): Authorize.jwt_optional() return Authorize - def jwt_refresh_required(Authorize: AuthJWT = Depends(), token: str = Depends(bearer)): Authorize.jwt_refresh_token_required() return Authorize - def fresh_jwt_required(Authorize: AuthJWT = Depends(), token: str = Depends(bearer)): Authorize.fresh_jwt_required() return Authorize @@ -47,7 +44,6 @@ def fresh_jwt_required(Authorize: AuthJWT = Depends(), token: str = Depends(bear def get_current_clientId(Authorize: AuthJWT = Depends(jwt_required)): return Authorize.get_jwt_subject() - def get_current_user(clientId: str = Depends(get_current_clientId), db: Session = Depends(get_session)): user = get_user_from_clientId_db(clientId, db) if not user: diff --git a/backend/api/services/database.py b/backend/api/services/database.py index 2877e10..c27ee48 100644 --- a/backend/api/services/database.py +++ b/backend/api/services/database.py @@ -4,11 +4,13 @@ from sqlmodel import select, Session from sqlmodel import SQLModel -def generate_unique_code(model: SQLModel, s: Session, length: int = 6): +def generate_unique_code(model: SQLModel, s: Session, field_name='id_code', length: int = 6): + if getattr(model, field_name, None) is None: + raise KeyError("Invalid field name") while True: code = ''.join(random.choices(string.ascii_uppercase, k=length)) is_unique = s.exec(select(model).where( - model.id_code == code)).first() == None + getattr(model, field_name) == code)).first() == None if is_unique: break return code diff --git a/backend/api/testing.py b/backend/api/testing.py index 16b10ac..e69de29 100644 --- a/backend/api/testing.py +++ b/backend/api/testing.py @@ -1,108 +0,0 @@ -import uuid -from pydantic import validator -import io -import os -from fastapi import UploadFile, Form -from pathlib import Path -from typing import IO, Any, List, Optional, Type - -from fastapi import FastAPI -from sqlmodel import Field, Session, SQLModel, create_engine, select -from services.exoValidation import get_support_from_data - -from services.io import add_fast_api_root, get_filename, get_or_create_dir, remove_fastapi_root - - -class FileFieldMeta(type): - def __getitem__(self, upload_root: str,) -> Type['FileField']: - return type('MyFieldValue', (FileField,), {'upload_root': upload_root}) - - -class FileField(str, metaclass=FileFieldMeta): - upload_root: str - @classmethod - def __get_validators__(cls): - yield cls.validate - - @classmethod - def validate(cls, value: str | IO, values): - print(cls.upload_root, cls.default_naming_field) - upload_root = get_or_create_dir( - add_fast_api_root(cls.upload_root)) - - if not isinstance(value, str): - value.seek(0) - - is_binary = isinstance(value, io.BytesIO) - - name = get_filename(value, 'exo_source.py') - - parent = get_or_create_dir(os.path.join( - upload_root, values['id_code'])) - - mode = 'w+' if not is_binary else 'wb+' - - path = os.path.join(parent, name) - with open(path, mode) as f: - f.write(value.read()) - - return remove_fastapi_root(path) - - else: - if not os.path.exists(value): - raise ValueError('File does not exist') - return value - -class Hero(SQLModel, table=True): - id: Optional[int] = Field(default=None, primary_key=True) - id_code : str - path: FileField['/testing', 'id_code'] - - - - -class HeroCreate(SQLModel): - path: FileField[42] - -class HeroRead(SQLModel): - id: int - id_code: str - path: str - - -sqlite_file_name = "testing.db" -sqlite_url = f"sqlite:///{sqlite_file_name}" - -connect_args = {"check_same_thread": False} -engine = create_engine(sqlite_url, echo=True, connect_args=connect_args) - - -def create_db_and_tables(): - SQLModel.metadata.create_all(engine) - - -app = FastAPI() - - -@app.on_event("startup") -def on_startup(): - create_db_and_tables() - - -@app.get("/heroes/", response_model=List[HeroRead]) -def read_heroes(): - with Session(engine) as session: - heroes = session.exec(select(Hero)).all() - return heroes - -@app.post("/heroes/", ) -def create_hero(file: UploadFile, name: str = Form()): - with Session(engine) as session: - file_obj = file.file._file - db_hero = Hero(path=file_obj, id_code=name) - session.add(db_hero) - session.commit() - session.refresh(db_hero) - return "db_hero" - - diff --git a/backend/api/tests/test_auth.py b/backend/api/tests/test_auth.py index 4dfa23b..25084e4 100644 --- a/backend/api/tests/test_auth.py +++ b/backend/api/tests/test_auth.py @@ -4,7 +4,6 @@ VALID_USERNAME = 'lilian' VALID_PASSWORD = 'Test12345' def test_register(client: TestClient, username = VALID_USERNAME): - print('usernae') r = client.post('/register', data={"username": username, 'password': VALID_PASSWORD, 'password_confirm': VALID_PASSWORD}) data = r.json() print(data) diff --git a/backend/api/tests/test_room.py b/backend/api/tests/test_room.py new file mode 100644 index 0000000..4ba237b --- /dev/null +++ b/backend/api/tests/test_room.py @@ -0,0 +1,79 @@ +from fastapi import HTTPException +from fastapi.testclient import TestClient +from tests.test_auth import test_register + + +def test_create_room_no_auth(client: TestClient): + r = client.post('/room', json={"name": "test_room", + "public": False}, params={'username': "lilian"}) + print(r.json()) + assert "id_code" in r.json()['room'] + assert "reconnect_code" in r.json()['member']['anonymous'] + assert {"room": {**r.json()['room'], 'id_code': None}, "member": {**r.json()['member'], "anonymous": {**r.json()['member']['anonymous'], "reconnect_code": None}}} == {"room": {"id_code": None, "name": "test_room", + "public": False}, 'member': {"anonymous": {"username": "lilian", "reconnect_code": None}, "user": None}} + return r.json() + +def test_create_room_no_auth_invalid(client: TestClient): + r = client.post('/room', json={"name": "test_room"*21, + "public": False}, params={'username': "lilian"*21}) + print(r.json()) + assert r.json() == {'detail': {'username_error': 'ensure this value has at most 20 characters', + 'name_error': 'ensure this value has at most 20 characters'}} + +def test_create_room_auth(client: TestClient, token = None): + if token is None: + token = test_register(client=client)['access'] + r = client.post('/room', json={"name": "test_room", + "public": False}, headers={"Authorization": "Bearer " + token}) + print(r.json()) + assert "id_code" in r.json()['room'] + assert {**r.json(), "room": {**r.json()['room'], 'id_code': None}} == {"room": {"id_code": None, "name": "test_room", + "public": False}, 'member': {"user": {"username": "lilian"}, "anonymous": None}} + return r.json() + +def test_room_not_found(client: TestClient): + try: + with client.websocket_connect('/ws/room/eee') as r: + pass + except HTTPException as e : + assert True + except Exception: + assert False + +def test_login_no_auth(client: TestClient): + room = test_create_room_no_auth(client=client) + member = room['member']['anonymous'] + with client.websocket_connect(f"/ws/room/" + room['room']['id_code']) as ws: + ws.send_json({"type": "login", "data": {"reconnect_code": member['reconnect_code']}}) + data = ws.receive_json() + print(data) + assert data == {'type': "loggedIn", "data": {"member": {"username": member['username'], "reconnect_code": member['reconnect_code'], "isAdmin": True}}} + +def test_login_auth(client: TestClient): + token = test_register(client=client)['access'] + room = test_create_room_auth(client=client, token=token) + member = room['member']['user'] + with client.websocket_connect(f"/ws/room/" + room['room']['id_code']) as ws: + ws.send_json({"type": "login", "data": {"token": token}}) + data = ws.receive_json() + print(data) + assert data == {'type': "loggedIn", "data": {"member": {"username": member['username'], "isAdmin": True}}} + +def test_join_no_auth(client: TestClient): + room = test_create_room_no_auth(client=client) + member = room['member']['anonymous'] + with client.websocket_connect(f"/ws/room/" + room['room']['id_code']) as admin: + admin.send_json({"type": "login", "data": { + "reconnect_code": member['reconnect_code']}}) + with client.websocket_connect(f"/ws/room/" + room['room']['id_code']) as member: + member.send_json({"type":"join", "data": {"username": "member"}}) + mdata = member.receive_json() + assert "id_code" in mdata['data']['waiter'] + assert mdata == {"type": "waiting", "data": {"waiter": { + "username": "member", "id_code": mdata['data']['waiter']}}} + + adata = admin.receive_json() + assert adata == {'type': "waiter", 'data': { + "waiter": {"id_code": mdata['data']['waiter'], "username": "member"}}} + + admin.send({"type": "accept", "data": {"waiter_id": mdata['data']['waiter']}}) \ No newline at end of file diff --git a/backend/api_old/apis/room/websocket.py b/backend/api_old/apis/room/websocket.py index e45104d..69370a5 100644 --- a/backend/api_old/apis/room/websocket.py +++ b/backend/api_old/apis/room/websocket.py @@ -22,33 +22,25 @@ class ConnectionManager: def __init__(self): self.active_connections: Dict[str,List[WebSocket]] = {} - async def connect(self, websocket: WebSocket, room_id): - await websocket.accept() - - if room_id not in self.active_connections: - self.active_connections[room_id] = [] + async def add(self, group, ws): + + if group not in self.active_connections: + self.active_connections[group] = [] - self.active_connections[room_id].append(websocket) - - async def add(self, room_id, ws): - if room_id not in self.active_connections: - self.active_connections[room_id] = [] + if ws not in self.active_connections[group]: + self.active_connections[group].append(ws) - self.active_connections[room_id].append(ws) - - def remove(self, websocket: WebSocket, room_id): - if room_id in self.active_connections: - try: - self.active_connections[room_id].remove(websocket) - except: - pass - + def remove(self, ws: WebSocket, group): + if group in self.active_connections: + if ws in self.active_connections[group]: + self.active_connections[group].remove(ws) + async def send_personal_message(self, message: str, websocket: WebSocket): await websocket.send_text(message) - async def broadcast(self, message: str, room_id): - if room_id in self.active_connections: - for connection in self.active_connections[room_id]: + async def broadcast(self, message: str, group): + if group in self.active_connections: + for connection in self.active_connections[group]: await connection.send_json(message)