2023-02-28 10:21:08 +01:00

147 lines
5.1 KiB
Python

import inspect
import logging
from typing import List, Callable, Any, Dict, Type

from fastapi.websockets import WebSocketDisconnect, WebSocket
from pydantic import validate_arguments, BaseModel
from pydantic.error_wrappers import ValidationError
def make_event_decorator(eventsDict):
def _(name: str | List, conditions: List[Callable | bool] = []):
def add_event(func):
model = validate_arguments(func).model
if type(name) == str:
eventsDict[name] = {"func": func,
"conditions": conditions, "model": model}
if type(name) == list:
for n in name:
eventsDict[n] = {"func": func,
"conditions": conditions, "model": model}
return func
return add_event
return _
def dict_model(model: BaseModel, exclude: List[str]):
    """Return the ``(name, value)`` pairs of ``model`` as a shallow dict.

    Iterates ``model`` directly, so values are kept as they are (nothing is
    converted recursively), and drops every pair whose name is in ``exclude``.
    """
    return {
        field_name: field_value
        for field_name, field_value in model
        if field_name not in exclude
    }
def dict_all(obj: Any):
    """Convert ``obj`` to a plain dict, converting dict-like values one level down.

    When ``obj`` is a dict, each value is handled as follows:

    * nested dicts are processed recursively;
    * JSON-native scalars and sequences are kept untouched;
    * ``BaseModel`` instances become a shallow ``dict`` of their fields;
    * anything else is passed to ``dict()`` on a best-effort basis and kept
      unchanged when that fails.

    Any other ``obj`` is returned as ``dict(obj)``.
    """
    if not isinstance(obj, dict):
        return dict(obj)

    # These must never reach ``dict()``: it turns "" and [] into {}, and a
    # list of two-character strings into a bogus mapping.
    passthrough = (str, bytes, bytearray, int, float, type(None), list, tuple)

    value = {}
    for k, v in obj.items():
        if isinstance(v, dict):
            value[k] = dict_all(v)
        elif isinstance(v, passthrough):
            value[k] = v
        elif isinstance(v, BaseModel):
            value[k] = dict(v)
        else:
            try:
                value[k] = dict(v)
            except Exception:
                # Best effort: not convertible, so keep the value as it is.
                # ``Exception`` (not a bare except) lets KeyboardInterrupt,
                # SystemExit and task cancellation propagate.
                value[k] = v
    return value
class Event(BaseModel):
    """Shape of one registry entry: a handler, its conditions and its model.

    The field names match the keys of the entries written by
    ``make_event_decorator``.
    """

    # The registered handler function.
    func: Callable
    # Guards checked before the handler runs: callables or plain booleans.
    conditions: List[Callable | bool]
    # ``validate_arguments(func).model`` is a model *class*, not an instance,
    # hence ``Type[BaseModel]`` rather than ``BaseModel``.
    model: Type[BaseModel]
class Consumer:
    """Dispatch typed JSON WebSocket messages to registered handlers.

    Handlers for incoming messages are registered with
    ``@Consumer.event(name, conditions)`` and formatters for outgoing
    messages with ``@Consumer.sending(name, conditions)``. A message is a
    JSON object with a ``"type"`` key selecting the handler and an optional
    ``"data"`` object supplying its keyword arguments, which are validated
    against the handler's signature before the call.
    """

    # Class-level registries: created once and shared by every instance.
    events: Dict[str, Event] = {}
    sendings: Dict[str, Any] = {}
    event = make_event_decorator(events)
    sending = make_event_decorator(sendings)

    # Fields of the validated-arguments model that are bookkeeping rather
    # than handler parameters, and must not be forwarded to the handler.
    _MODEL_INTERNALS = ("v__duplicate_kwargs", "args", "kwargs")

    def __init__(self, ws: WebSocket):
        self.ws: WebSocket = ws

    async def connect(self):
        """Accept the WebSocket; ``run`` stops if this returns ``False``."""
        await self.ws.accept()
        return True

    async def validation_error_handler(self, e: ValidationError):
        """Send one ``{field: message}`` item per validation error."""
        detail = [{err['loc'][-1]: err['msg']} for err in e.errors()]
        await self.ws.send_json({"type": "error", "data": {"detail": detail}})

    async def _conditions_met(self, conditions) -> bool:
        """Evaluate every condition against ``self``; True when all hold.

        Coroutine functions are awaited, plain functions are called, booleans
        are used as-is, and any other object counts as satisfied. All
        conditions are evaluated, even after one has failed.
        """
        outcomes = []
        for cond in conditions or ():
            if inspect.iscoroutinefunction(cond):
                outcomes.append(await cond(self))
            elif inspect.isfunction(cond):
                outcomes.append(cond(self))
            elif isinstance(cond, bool):
                outcomes.append(cond)
            else:
                outcomes.append(True)
        return all(outcomes)

    async def send(self, payload):
        """Send ``payload``, routing it through a registered sending handler.

        If ``payload["type"]`` has no registered sending handler, the payload
        is sent unchanged. Otherwise ``payload["data"]`` is validated against
        the handler's signature, the handler is called (and awaited when it
        returns an awaitable), and its result is sent as the new ``"data"``.
        A validation failure sends a generic error message instead.
        """
        event_type = payload.get('type', None)
        wrapper = None
        if event_type is not None:
            wrapper = self.sendings.get(event_type, None)
        if wrapper is None:
            await self.ws.send_json(payload)
            return

        handler = wrapper.get('func')
        is_valid = await self._conditions_met(wrapper.get('conditions'))
        if handler is None or not is_valid:
            # NOTE(review): a registered type whose conditions fail is
            # dropped rather than forwarded raw -- confirm this is intended.
            return

        model = wrapper.get("model")
        data = payload.get('data') or {}
        try:
            # "self" is merged last so a "self" key inside ``data`` cannot
            # collide with it (a duplicate keyword raises TypeError).
            validated = model(**{**data, "self": self})
        except ValidationError:
            await self.ws.send_json({"type": "error", "data": {"msg": "Oops there was an error"}})
            return

        kwargs = dict_model(validated, exclude=[*self._MODEL_INTERNALS, "self"])
        try:
            result = handler(self, **kwargs)
            if inspect.isawaitable(result):
                result = await result
            await self.ws.send_json({'type': event_type, "data": dict_all(result)})
        except Exception:
            # Sending stays best-effort, but the failure is no longer silent.
            logging.getLogger(__name__).exception(
                "sending handler for %r failed", event_type)

    async def receive(self, data):
        """Dispatch one decoded incoming frame to its registered handler.

        Frames that are not JSON objects, carry no ``"type"``, or name an
        unregistered type are ignored, as are frames whose conditions fail.
        Invalid arguments are reported to the client through
        ``validation_error_handler``.
        """
        if not isinstance(data, dict):
            # e.g. a top-level JSON array or scalar: nothing to dispatch on.
            return
        event = data.get('type', None)
        if event is None:
            return
        try:
            wrapper = self.events.get(event, None)
        except TypeError:
            # Unhashable "type" value (list/object) sent by the client.
            return
        if wrapper is None:
            return

        handler = wrapper.get('func')
        is_valid = await self._conditions_met(wrapper.get('conditions'))
        if handler is None or not is_valid:
            return

        payload = data.get('data') or {}
        if not isinstance(payload, dict):
            # ``**payload`` below needs a mapping; report it like any other
            # validation problem instead of crashing the receive loop.
            await self.ws.send_json(
                {"type": "error", "data": {"detail": [{"data": "value is not a valid dict"}]}})
            return

        model = wrapper.get("model")
        try:
            # "self" is merged last: a client-supplied "self" key must not
            # collide with (or replace) the consumer instance.
            validated = model(**{**payload, "self": self})
        except ValidationError as e:
            await self.validation_error_handler(e)
            return

        kwargs = {k: v for k, v in validated.dict().items()
                  if k not in self._MODEL_INTERNALS}
        result = handler(**kwargs)
        if inspect.isawaitable(result):
            await result

    async def disconnect(self):
        """Hook called by ``run`` when the client disconnects; no-op here."""
        pass

    async def run(self):
        """Accept the connection, then dispatch frames until disconnect."""
        accepted = await self.connect()
        if accepted is False:
            return
        try:
            while True:
                data = await self.ws.receive_json()
                await self.receive(data)
        except WebSocketDisconnect:
            await self.disconnect()