This commit is contained in:
Kilton937342 2022-09-18 22:43:04 +02:00
parent 562a160a3d
commit 84b38cc12d
15 changed files with 508 additions and 169 deletions

92
backend/api/api.json Normal file
View File

@ -0,0 +1,92 @@
{
"routes": {},
"events": {
"login": {
"input": { "token": null, "reconnect_code": null },
"output": {
"type": "loggedIn",
"member": { "username": "str", "reconnect_code": "str ou null" }
},
"broadcast": {
"all": {
"type": "connect",
"member": { "id": "int (member id!)" }
}
},
"errors": [
{
"type": "error",
"error": { "status": "401", "msg": "Membre introuvable" }
}
]
},
"join": {
"input": { "username": "str", "user": null },
"output": {
"public?": {
"type": "accepted",
"member": {
"username": "str",
"reconnect_code": "str ou null"
}
},
"private?": {
"type": "waiting",
"waiter": { "id_code": "str", "username": "str" }
}
},
"broadcast": {
"public?": {
"all": {
"type": "joined",
"member": {
"id": "int",
"username": "str",
"reconnect_code": "str uniquement pour l'admin"
}
}
},
"private?": {
"admin": {
"type": "waiter",
"waiter": { "id_code": "str", "username": "str" }
}
}
},
"errors": [
{
"status": "400",
"msg": "User input (trop long ou déjà pris, etc)"
}
]
},
"accept": {
"input": { "waiter_id": "str" },
"broadcast": {
"waiter": {
"type": "accepted",
"member": {
"username": "str",
"reconnect_code": "str ou null"
}
},
"all": {
"type": "joined",
"member": {
"id": "int",
"username": "str",
"reconnect_code": "str uniquement pour l'admin"
}
}
}
},
"reject": {
"input": { "waiter_id": "str" },
"output": {
"type": "successfullyRejected",
"waiter": { "id": "int" }
},
"broadcast": { "waiter": { "type": "rejected" } }
}
}
}

View File

@ -1,6 +1,7 @@
from datetime import timedelta from datetime import timedelta
from redis import Redis from redis import Redis
from pydantic import BaseModel from pydantic import BaseModel
SECRET_KEY = "6323081020d8939e6385dd688a26cbca0bb34ed91997959167637319ba4f6f3e" SECRET_KEY = "6323081020d8939e6385dd688a26cbca0bb34ed91997959167637319ba4f6f3e"
ALGORITHM = "HS256" ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30 ACCESS_TOKEN_EXPIRE_MINUTES = 30

View File

@ -4,7 +4,6 @@ from uuid import UUID
import uuid import uuid
from sqlmodel import Field, SQLModel, Relationship from sqlmodel import Field, SQLModel, Relationship
from pydantic import validator, BaseModel from pydantic import validator, BaseModel
from database.db import get_session, get_session
from services.password import validate_password from services.password import validate_password
from services.schema import as_form from services.schema import as_form
@ -25,7 +24,7 @@ class User(UserBase, table=True):
exercices: List['Exercice'] = Relationship(back_populates='author') exercices: List['Exercice'] = Relationship(back_populates='author')
tags: List['Tag'] = Relationship(back_populates='author') tags: List['Tag'] = Relationship(back_populates='author')
members: List['Member'] = Relationship(back_populates='user') members: List["Member"] = Relationship(back_populates='user')
@as_form @as_form
class UserEdit(UserBase): class UserEdit(UserBase):
@ -41,8 +40,6 @@ class UserRegister(BaseModel):
password: str password: str
password_confirm: str password_confirm: str
@validator('password') @validator('password')
def password_validation(cls, v): def password_validation(cls, v):
is_valid = validate_password(v) is_valid = validate_password(v)

View File

@ -1,5 +1,3 @@
import random
import string
from sqlmodel import SQLModel, create_engine, Session, select from sqlmodel import SQLModel, create_engine, Session, select
sqlite_file_name = "database.db" sqlite_file_name = "database.db"

View File

@ -1,6 +1,33 @@
from sqlmodel import Session from fastapi import Depends
from database.room.models import RoomCreate from sqlmodel import Session, select
from database.db import get_session
from database.room.models import Anonymous, Member, Room, RoomCreate
from database.auth.models import User from database.auth.models import User
from services.database import generate_unique_code
def create_room_db(*,room: RoomCreate, user: User | None = None,username: str, db: Session): def create_room_db(*,room: RoomCreate, user: User | None = None, username: str | None = None, db: Session):
return id_code = generate_unique_code(Room,s=db)
room_obj = Room(**room.dict(exclude_unset=True), id_code=id_code)
if user is not None:
member = Member(user_id=user.id, room=room_obj)
db.add(member)
db.commit()
db.refresh(member)
if username is not None:
reconnect_code = generate_unique_code(Anonymous, s=db, field_name='reconnect_code')
anonymous = Anonymous(username=username, reconnect_code=reconnect_code)
member = Member(anonymous=anonymous, room=room_obj)
db.add(member)
db.commit()
db.refresh(member)
if username is None and user is None:
raise ValueError('Username or user required')
return {"room": room_obj, "member": member}
def check_room(room_id: str, db: Session = Depends(get_session)):
room = db.exec(select(Room).where(Room.id_code==room_id)).first()
return room
def userInRoom(room: Room, user: User, db: Session):
member = db.exec(select(Member).where(Member.room_id == room.id, Member.user_id == user.id)).first()
return member

View File

@ -1,5 +1,7 @@
from typing import List, Optional, TYPE_CHECKING from typing import List, Optional, TYPE_CHECKING
from sqlmodel import SQLModel, Field, Relationship from sqlmodel import SQLModel, Field, Relationship
from database.auth.models import UserRead
if TYPE_CHECKING: if TYPE_CHECKING:
from database.auth.models import User from database.auth.models import User
@ -13,13 +15,11 @@ class RoomCreate(RoomBase):
class Room(RoomBase, table=True): class Room(RoomBase, table=True):
id: Optional[int] = Field(default=None, primary_key=True) id: Optional[int] = Field(default=None, primary_key=True)
id_code: str id_code: str = Field(index=True)
members: List['Member'] = Relationship(back_populates="room") members: List['Member'] = Relationship(back_populates="room")
class RoomRead(RoomBase):
id_code: str
#members: List[]
class AnonymousBase(SQLModel): class AnonymousBase(SQLModel):
username: str = Field(max_length=20) username: str = Field(max_length=20)
@ -29,7 +29,7 @@ class AnonymousCreate(AnonymousBase):
class Anonymous(AnonymousBase, table=True): class Anonymous(AnonymousBase, table=True):
id: Optional[int] = Field(default=None, primary_key=True) id: Optional[int] = Field(default=None, primary_key=True)
reconnect_code: str reconnect_code: str = Field(index=True)
member: 'Member' = Relationship(back_populates="anonymous") member: 'Member' = Relationship(back_populates="anonymous")
@ -38,12 +38,28 @@ class Member(SQLModel, table = True):
id: Optional[int] = Field(default=None, primary_key=True) id: Optional[int] = Field(default=None, primary_key=True)
user_id: Optional[int] = Field(foreign_key="user.id", default=None) user_id: Optional[int] = Field(foreign_key="user.id", default=None)
anonymous_id: Optional[int] = Field(foreign_key="anonymous.id", default=None) user: Optional["User"] = Relationship(back_populates='members')
anonymous_id: Optional[int] = Field(foreign_key="anonymous.id", default=None)
anonymous: Optional[Anonymous] = Relationship(back_populates="member") anonymous: Optional[Anonymous] = Relationship(back_populates="member")
user: Optional['User'] = Relationship(back_populates='members')
room_id: int = Field(foreign_key="room.id") room_id: int = Field(foreign_key="room.id")
room: Room = Relationship(back_populates='members') room: Room = Relationship(back_populates='members')
is_admin: bool = False
waiting: bool = True
online: bool = False
class RoomRead(RoomBase):
id_code: str
#members: List['Member']
class AnonymousRead(AnonymousBase):
reconnect_code: str
class Username(SQLModel):
username: str
class MemberRead(SQLModel):
anonymous: AnonymousRead = None
user: Username = None

View File

@ -1,11 +1,13 @@
#import schemas.base #import schemas.base
from sqlmodel import SQLModel, Field
from services.password import get_password_hash from services.password import get_password_hash
from sqlmodel import Session, select from sqlmodel import Session, select
from database.auth.crud import create_user_db from database.auth.crud import create_user_db
from services.auth import get_current_user_optional, jwt_required from services.auth import get_current_user_optional, jwt_required
from fastapi.openapi.utils import get_openapi from fastapi.openapi.utils import get_openapi
from database.auth.models import User, UserBase, UserRead from database.auth.models import User, UserRead
from database.exercices.models import Exercice, ExerciceRead from database.exercices.models import Exercice, ExerciceRead
from database.room.models import Room, Anonymous, Member
import database.db import database.db
from fastapi_pagination import add_pagination from fastapi_pagination import add_pagination
from fastapi.responses import PlainTextResponse from fastapi.responses import PlainTextResponse
@ -14,7 +16,7 @@ from fastapi import FastAPI, HTTPException, Depends, Request, status, Header
from fastapi_jwt_auth import AuthJWT from fastapi_jwt_auth import AuthJWT
from fastapi_jwt_auth.exceptions import AuthJWTException from fastapi_jwt_auth.exceptions import AuthJWTException
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from typing import List, Sequence from typing import List, Optional, Sequence
from tortoise.contrib.pydantic import pydantic_model_creator from tortoise.contrib.pydantic import pydantic_model_creator
from fastapi import FastAPI, HTTPException, params from fastapi import FastAPI, HTTPException, params
from tortoise import Tortoise from tortoise import Tortoise
@ -30,7 +32,6 @@ import config
from sqladmin import Admin, ModelView from sqladmin import Admin, ModelView
from database.db import engine from database.db import engine
from fastapi.security import OAuth2PasswordBearer, HTTPBearer from fastapi.security import OAuth2PasswordBearer, HTTPBearer
from pydantic import Field
app = FastAPI(title="API Generateur d'exercices") app = FastAPI(title="API Generateur d'exercices")
origins = [ origins = [
"http://localhost:8000", "http://localhost:8000",
@ -78,8 +79,6 @@ async def validation_exception_handler(request, exc: RequestValidationError|Vali
#JWT AUTH #JWT AUTH
@AuthJWT.load_config @AuthJWT.load_config
def get_config(): def get_config():
return config.settings return config.settings

View File

@ -18,7 +18,6 @@ class Token(BaseModel):
token_type: str token_type: str
refresh_token: str refresh_token: str
@router.post("/login", response_model=Token) @router.post("/login", response_model=Token)
def login_for_access_token(user: User = Depends(authenticate_user)): def login_for_access_token(user: User = Depends(authenticate_user)):
Authorize = AuthJWT() Authorize = AuthJWT()

View File

@ -1,7 +1,257 @@
from fastapi import APIRouter from sqlmodel import select
from pydantic import BaseModel
from typing import Any, Callable, Dict, List, Optional
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect, status, Query
from config import ALGORITHM, SECRET_KEY
from database.auth.crud import get_user_from_clientId_db
from database.auth.models import User
from database.db import get_session
from database.room.crud import check_room, create_room_db, userInRoom
from sqlmodel import Session
from database.room.models import Anonymous, Member, MemberRead, Room, RoomCreate, RoomRead
from services.auth import get_current_user_optional
from fastapi.exceptions import HTTPException
from jose import jwt, exceptions
router = APIRouter(tags=["room"]) router = APIRouter(tags=["room"])
@router.post('/room')
def create_room(): class RoomAndMember(BaseModel):
room: RoomRead
member: MemberRead
@router.post('/room', response_model=RoomAndMember)
def create_room(room: RoomCreate, username: Optional[str] = Query(default=None, max_length=20), user: User | None = Depends(get_current_user_optional), db: Session = Depends(get_session)):
room_obj = create_room_db(room=room, user=user, username=username, db=db)
return room_obj
class ConnectionManager:
def __init__(self):
self.active_connections: Dict[str, List[WebSocket]] = {}
async def add(self, group: str, ws: WebSocket):
if group not in self.active_connections:
self.active_connections[group] = []
if ws not in self.active_connections[group]:
self.active_connections[group].append(ws)
def remove(self, group: str, ws: WebSocket):
if group in self.active_connections:
if ws in self.active_connections[group]:
self.active_connections[group].remove(ws)
async def broadcast(self, message, group: str, exclude: list[WebSocket] = []):
if group in self.active_connections:
for connection in self.active_connections[group]:
if connection not in exclude:
await connection.send_json(message)
manager = ConnectionManager()
def make_event_decorator(eventsDict):
def _(name: str):
def add_event(func):
eventsDict[name] = func
return func
return add_event
return _
class Consumer:
events: Dict[str, Callable] = {}
event = make_event_decorator(events)
def __init__(self, ws: WebSocket):
self.ws: WebSocket = ws
#self.events: Dict[str, Callable] = {}
async def connect(self):
pass
async def receive(self, data):
event = data.get('type', None)
if event is not None:
handler = self.events.get(event, None)
if handler is not None:
payload = data.get('data')
await handler(self, payload)
async def disconnect(self):
pass
async def run(self):
await self.connect()
try:
while True:
data = await self.ws.receive_json()
await self.receive(data)
except WebSocketDisconnect:
await self.disconnect()
class Token(BaseModel):
token: str
def get_user_from_token(token: str, db: Session):
try:
decoded = jwt.decode(token=token, key=SECRET_KEY,
algorithms=[ALGORITHM])
except exceptions.ExpiredSignatureError:
return False
clientId = decoded.get('sub')
return get_user_from_clientId_db(clientId=clientId, db=db)
def get_member_from_user(user_id: int, room_id: int, db: Session):
member = db.exec(select(Member).where(Member.room_id ==
room_id, Member.user_id == user_id)).first()
return member
def get_member_from_anonymous(anonymous_id: int, room_id: int, db: Session):
member = db.exec(select(Member).where(Member.room_id ==
room_id, Member.anonymous_id == anonymous_id)).first()
return member
def get_anonymous_from_code(reconnect_code: str, db: Session):
anonymous = db.exec(select(Anonymous).where(
Anonymous.reconnect_code == reconnect_code)).first()
return anonymous
def connect_member(member: Member, db: Session):
member.online = True
db.add(member)
db.commit()
db.refresh(member)
return member
def disconnect_member(member: Member, db: Session):
member.online = False
db.add(member)
db.commit()
db.refresh(member)
return member
def validate_username(username: str, room: Room, db: Session):
members = select(Member).where(Member.room_id == room.id, Member.anonymous_id != None)
def create_anonymous_member(username: str, room: Room, db: Session):
pass
class RoomConsumer(Consumer):
def __init__(self, ws: WebSocket, room: Room, manager: ConnectionManager, db: Session):
self.room = room
self.ws = ws
self.manager = manager
self.db = db
self.member = None
#WS Utilities
async def connect(self):
await self.ws.accept()
async def send(self, type: str, payload: Any):
await self.ws.send_json({'type': type, "data": payload})
async def send_to_all_room(self, type: str, payload: Any, exclude: bool = False):
await self.manager.broadcast({'type': type, "data": payload}, f'{self.room.id}__member', [exclude == True and self.ws])
await self.manager.broadcast({'type': type, "data": payload}, f'{self.room.id}__admin', [exclude == True and self.ws])
async def send_to_admin(self, type: str, payload: Any, exclude: bool = False):
await self.manager.broadcast({'type': type, "data": payload}, f'{self.room.id}__admin', [exclude == True and self.ws])
async def send_to_members(self, type: str, payload: Any, exclude: bool = False):
await self.manager.broadcast({'type': type, "data": payload}, f'{self.room.id}__member', [exclude == True and self.ws])
def add_to_admin(self):
self.manager.add(f'{self.room.id}__admin', self.ws)
def add_to_members(self):
self.manager.add(f'{self.room.id}__members', self.ws)
def add_to_groups(self):
if isinstance(self.member, Member):
if self.member.is_admin == True:
self.add_to_admin()
if self.member.is_admin == False:
self.add_to_members()
async def connect_self(self):
if isinstance(self.member, Member):
connect_member(self.member, self.db)
await self.send_to_all_room(type="connect", payload={}, exclude=True)
async def disconnect_self(self):
if isinstance(self.member, Member):
disconnect_member(self.member, self.db)
await self.send_to_all_room(type="disconnect", payload={}, exclude=True)
#DB Utilities
#Events
@Consumer.event('login')
async def login(self, data):
if 'token' in data:
token = data.get('token')
user = get_user_from_token(token, db=self.db)
if user == False:
await self.send()
return return
if user is None:
return
member = get_member_from_user(
user_id=user.id, room_id=self.room.id, db=self.db)
if member is None:
return
self.member = member
self.add_to_groups()
self.connect_self()
await self.send()
elif "reconnect_code" in data:
reconnect_code = data.get('reconnect_code')
anonymous = get_anonymous_from_code(
reconnect_code=reconnect_code, db=self.db)
if anonymous is None:
return
member = get_member_from_anonymous(
anonymous_id=anonymous.id, room_id=self.room.id, db=self.db)
if member is None:
return
self.member = member
self.add_to_groups()
self.connect_self()
await self.send(type="accepted")
@Consumer.event('join')
async def join(self, data):
if "token" in data:
return
else:
return
return
async def disconnect(self):
await self.disconnect_self()
@router.websocket('/ws/room/{room_id}')
async def room_ws(ws: WebSocket, room: Room | None = Depends(check_room), db: Session = Depends(get_session)):
if room is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail='Room not found')
consumer = RoomConsumer(ws=ws, room=room, manager=manager, db=db)
await consumer.run()

View File

@ -29,17 +29,14 @@ def jwt_required(Authorize: AuthJWT = Depends(), token: str = Depends(bearer)):
Authorize.jwt_required() Authorize.jwt_required()
return Authorize return Authorize
def jwt_optional(Authorize: AuthJWT = Depends()): def jwt_optional(Authorize: AuthJWT = Depends()):
Authorize.jwt_optional() Authorize.jwt_optional()
return Authorize return Authorize
def jwt_refresh_required(Authorize: AuthJWT = Depends(), token: str = Depends(bearer)): def jwt_refresh_required(Authorize: AuthJWT = Depends(), token: str = Depends(bearer)):
Authorize.jwt_refresh_token_required() Authorize.jwt_refresh_token_required()
return Authorize return Authorize
def fresh_jwt_required(Authorize: AuthJWT = Depends(), token: str = Depends(bearer)): def fresh_jwt_required(Authorize: AuthJWT = Depends(), token: str = Depends(bearer)):
Authorize.fresh_jwt_required() Authorize.fresh_jwt_required()
return Authorize return Authorize
@ -47,7 +44,6 @@ def fresh_jwt_required(Authorize: AuthJWT = Depends(), token: str = Depends(bear
def get_current_clientId(Authorize: AuthJWT = Depends(jwt_required)): def get_current_clientId(Authorize: AuthJWT = Depends(jwt_required)):
return Authorize.get_jwt_subject() return Authorize.get_jwt_subject()
def get_current_user(clientId: str = Depends(get_current_clientId), db: Session = Depends(get_session)): def get_current_user(clientId: str = Depends(get_current_clientId), db: Session = Depends(get_session)):
user = get_user_from_clientId_db(clientId, db) user = get_user_from_clientId_db(clientId, db)
if not user: if not user:

View File

@ -4,11 +4,13 @@ from sqlmodel import select, Session
from sqlmodel import SQLModel from sqlmodel import SQLModel
def generate_unique_code(model: SQLModel, s: Session, length: int = 6): def generate_unique_code(model: SQLModel, s: Session, field_name='id_code', length: int = 6):
if getattr(model, field_name, None) is None:
raise KeyError("Invalid field name")
while True: while True:
code = ''.join(random.choices(string.ascii_uppercase, k=length)) code = ''.join(random.choices(string.ascii_uppercase, k=length))
is_unique = s.exec(select(model).where( is_unique = s.exec(select(model).where(
model.id_code == code)).first() == None getattr(model, field_name) == code)).first() == None
if is_unique: if is_unique:
break break
return code return code

View File

@ -1,108 +0,0 @@
import uuid
from pydantic import validator
import io
import os
from fastapi import UploadFile, Form
from pathlib import Path
from typing import IO, Any, List, Optional, Type
from fastapi import FastAPI
from sqlmodel import Field, Session, SQLModel, create_engine, select
from services.exoValidation import get_support_from_data
from services.io import add_fast_api_root, get_filename, get_or_create_dir, remove_fastapi_root
class FileFieldMeta(type):
def __getitem__(self, upload_root: str,) -> Type['FileField']:
return type('MyFieldValue', (FileField,), {'upload_root': upload_root})
class FileField(str, metaclass=FileFieldMeta):
upload_root: str
@classmethod
def __get_validators__(cls):
yield cls.validate
@classmethod
def validate(cls, value: str | IO, values):
print(cls.upload_root, cls.default_naming_field)
upload_root = get_or_create_dir(
add_fast_api_root(cls.upload_root))
if not isinstance(value, str):
value.seek(0)
is_binary = isinstance(value, io.BytesIO)
name = get_filename(value, 'exo_source.py')
parent = get_or_create_dir(os.path.join(
upload_root, values['id_code']))
mode = 'w+' if not is_binary else 'wb+'
path = os.path.join(parent, name)
with open(path, mode) as f:
f.write(value.read())
return remove_fastapi_root(path)
else:
if not os.path.exists(value):
raise ValueError('File does not exist')
return value
class Hero(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
id_code : str
path: FileField['/testing', 'id_code']
class HeroCreate(SQLModel):
path: FileField[42]
class HeroRead(SQLModel):
id: int
id_code: str
path: str
sqlite_file_name = "testing.db"
sqlite_url = f"sqlite:///{sqlite_file_name}"
connect_args = {"check_same_thread": False}
engine = create_engine(sqlite_url, echo=True, connect_args=connect_args)
def create_db_and_tables():
SQLModel.metadata.create_all(engine)
app = FastAPI()
@app.on_event("startup")
def on_startup():
create_db_and_tables()
@app.get("/heroes/", response_model=List[HeroRead])
def read_heroes():
with Session(engine) as session:
heroes = session.exec(select(Hero)).all()
return heroes
@app.post("/heroes/", )
def create_hero(file: UploadFile, name: str = Form()):
with Session(engine) as session:
file_obj = file.file._file
db_hero = Hero(path=file_obj, id_code=name)
session.add(db_hero)
session.commit()
session.refresh(db_hero)
return "db_hero"

View File

@ -4,7 +4,6 @@ VALID_USERNAME = 'lilian'
VALID_PASSWORD = 'Test12345' VALID_PASSWORD = 'Test12345'
def test_register(client: TestClient, username = VALID_USERNAME): def test_register(client: TestClient, username = VALID_USERNAME):
print('usernae')
r = client.post('/register', data={"username": username, 'password': VALID_PASSWORD, 'password_confirm': VALID_PASSWORD}) r = client.post('/register', data={"username": username, 'password': VALID_PASSWORD, 'password_confirm': VALID_PASSWORD})
data = r.json() data = r.json()
print(data) print(data)

View File

@ -0,0 +1,79 @@
from fastapi import HTTPException
from fastapi.testclient import TestClient
from tests.test_auth import test_register
def test_create_room_no_auth(client: TestClient):
r = client.post('/room', json={"name": "test_room",
"public": False}, params={'username': "lilian"})
print(r.json())
assert "id_code" in r.json()['room']
assert "reconnect_code" in r.json()['member']['anonymous']
assert {"room": {**r.json()['room'], 'id_code': None}, "member": {**r.json()['member'], "anonymous": {**r.json()['member']['anonymous'], "reconnect_code": None}}} == {"room": {"id_code": None, "name": "test_room",
"public": False}, 'member': {"anonymous": {"username": "lilian", "reconnect_code": None}, "user": None}}
return r.json()
def test_create_room_no_auth_invalid(client: TestClient):
r = client.post('/room', json={"name": "test_room"*21,
"public": False}, params={'username': "lilian"*21})
print(r.json())
assert r.json() == {'detail': {'username_error': 'ensure this value has at most 20 characters',
'name_error': 'ensure this value has at most 20 characters'}}
def test_create_room_auth(client: TestClient, token = None):
if token is None:
token = test_register(client=client)['access']
r = client.post('/room', json={"name": "test_room",
"public": False}, headers={"Authorization": "Bearer " + token})
print(r.json())
assert "id_code" in r.json()['room']
assert {**r.json(), "room": {**r.json()['room'], 'id_code': None}} == {"room": {"id_code": None, "name": "test_room",
"public": False}, 'member': {"user": {"username": "lilian"}, "anonymous": None}}
return r.json()
def test_room_not_found(client: TestClient):
try:
with client.websocket_connect('/ws/room/eee') as r:
pass
except HTTPException as e :
assert True
except Exception:
assert False
def test_login_no_auth(client: TestClient):
room = test_create_room_no_auth(client=client)
member = room['member']['anonymous']
with client.websocket_connect(f"/ws/room/" + room['room']['id_code']) as ws:
ws.send_json({"type": "login", "data": {"reconnect_code": member['reconnect_code']}})
data = ws.receive_json()
print(data)
assert data == {'type': "loggedIn", "data": {"member": {"username": member['username'], "reconnect_code": member['reconnect_code'], "isAdmin": True}}}
def test_login_auth(client: TestClient):
token = test_register(client=client)['access']
room = test_create_room_auth(client=client, token=token)
member = room['member']['user']
with client.websocket_connect(f"/ws/room/" + room['room']['id_code']) as ws:
ws.send_json({"type": "login", "data": {"token": token}})
data = ws.receive_json()
print(data)
assert data == {'type': "loggedIn", "data": {"member": {"username": member['username'], "isAdmin": True}}}
def test_join_no_auth(client: TestClient):
room = test_create_room_no_auth(client=client)
member = room['member']['anonymous']
with client.websocket_connect(f"/ws/room/" + room['room']['id_code']) as admin:
admin.send_json({"type": "login", "data": {
"reconnect_code": member['reconnect_code']}})
with client.websocket_connect(f"/ws/room/" + room['room']['id_code']) as member:
member.send_json({"type":"join", "data": {"username": "member"}})
mdata = member.receive_json()
assert "id_code" in mdata['data']['waiter']
assert mdata == {"type": "waiting", "data": {"waiter": {
"username": "member", "id_code": mdata['data']['waiter']}}}
adata = admin.receive_json()
assert adata == {'type': "waiter", 'data': {
"waiter": {"id_code": mdata['data']['waiter'], "username": "member"}}}
admin.send({"type": "accept", "data": {"waiter_id": mdata['data']['waiter']}})

View File

@ -22,33 +22,25 @@ class ConnectionManager:
def __init__(self): def __init__(self):
self.active_connections: Dict[str,List[WebSocket]] = {} self.active_connections: Dict[str,List[WebSocket]] = {}
async def connect(self, websocket: WebSocket, room_id): async def add(self, group, ws):
await websocket.accept()
if room_id not in self.active_connections: if group not in self.active_connections:
self.active_connections[room_id] = [] self.active_connections[group] = []
self.active_connections[room_id].append(websocket) if ws not in self.active_connections[group]:
self.active_connections[group].append(ws)
async def add(self, room_id, ws): def remove(self, ws: WebSocket, group):
if room_id not in self.active_connections: if group in self.active_connections:
self.active_connections[room_id] = [] if ws in self.active_connections[group]:
self.active_connections[group].remove(ws)
self.active_connections[room_id].append(ws)
def remove(self, websocket: WebSocket, room_id):
if room_id in self.active_connections:
try:
self.active_connections[room_id].remove(websocket)
except:
pass
async def send_personal_message(self, message: str, websocket: WebSocket): async def send_personal_message(self, message: str, websocket: WebSocket):
await websocket.send_text(message) await websocket.send_text(message)
async def broadcast(self, message: str, room_id): async def broadcast(self, message: str, group):
if room_id in self.active_connections: if group in self.active_connections:
for connection in self.active_connections[room_id]: for connection in self.active_connections[group]:
await connection.send_json(message) await connection.send_json(message)