"""WebSocket consumer that dispatches JSON messages to typed handlers.

Handlers are registered by message ``"type"`` through the factories built by
:func:`make_event_decorator`.  Each handler's signature is turned into a
pydantic model (via ``validate_arguments``) that validates the message's
``"data"`` before the handler runs.  Written against the pydantic v1 API
(``validate_arguments``, ``pydantic.error_wrappers``).
"""
import inspect
import logging
from typing import Any, Callable, Dict, List

from fastapi.websockets import WebSocket, WebSocketDisconnect
from pydantic import BaseModel, validate_arguments
from pydantic.error_wrappers import ValidationError

logger = logging.getLogger(__name__)

# Bookkeeping fields carried by the ``validate_arguments`` model next to the real
# handler parameters; they must not be forwarded to the handler.
_VALIDATOR_FIELDS = ("v__duplicate_kwargs", "args", "kwargs")

# Values that ``dict()`` would mangle instead of rejecting: dict("") == {},
# dict([]) == {}, dict([[1, 2]]) == {1: 2}, dict(["ab"]) == {"a": "b"}.
# dict_all passes these through untouched.
_NO_DICT_COERCION = (str, bytes, bytearray, list, tuple, set, frozenset)


def make_event_decorator(eventsDict):
    """Build a decorator factory that registers handlers into *eventsDict*.

    The returned factory is used as ``@factory(name, conditions=[...])``.
    For every registered name, ``eventsDict[name]`` is set to a dict with the
    keys ``"func"`` (the original, undecorated function), ``"conditions"`` and
    ``"model"`` (the model ``validate_arguments`` builds from the function's
    signature).  The decorated function is returned unchanged.

    Factory parameters:
        name: a single type name, or an iterable (list, tuple, ...) of names
            that should all map to the same handler.
        conditions: callables and/or bools checked by the consumer before the
            handler runs; defaults to no conditions.

    Raises:
        TypeError: if *name* is neither a string nor an iterable of names.
    """

    def _(name: str | List, conditions: List[Callable | bool] | None = None):
        # A fresh list per registration instead of one shared mutable default.
        conditions = [] if conditions is None else conditions

        if isinstance(name, str):
            names = [name]
        else:
            try:
                names = list(name)
            except TypeError:
                raise TypeError(
                    "event name must be a str or an iterable of str, "
                    f"got {type(name).__name__}"
                ) from None

        def add_event(func):
            model = validate_arguments(func).model
            for n in names:
                eventsDict[n] = {"func": func, "conditions": conditions, "model": model}
            return func

        return add_event

    return _


def dict_model(model: BaseModel, exclude: List[str]):
    """Return a shallow ``{field: value}`` dict of *model*, skipping names in *exclude*.

    Field values are taken as-is, so nested models stay model instances.
    """
    return {name: value for name, value in model if name not in exclude}


def _plain_value(value: Any):
    """Convert a single value of a response dict for :func:`dict_all`."""
    if isinstance(value, dict):
        return dict_all(value)
    if isinstance(value, BaseModel):
        return dict(value)
    if isinstance(value, _NO_DICT_COERCION):
        return value
    try:
        # Other objects (e.g. mapping-like ones) are converted when dict() accepts them.
        return dict(value)
    except (TypeError, ValueError):
        return value


def dict_all(obj: Any):
    """Convert a handler result into a dict suitable for ``send_json``.

    For a dict, nested dicts are converted recursively, model values become
    shallow dicts, strings and builtin sequences/sets are kept unchanged, and
    any other value is converted with ``dict()`` when possible (kept as-is
    otherwise).  Any non-dict *obj* is passed to ``dict()`` directly, which
    raises ``TypeError``/``ValueError`` if it cannot be converted.
    """
    if isinstance(obj, dict):
        return {key: _plain_value(value) for key, value in obj.items()}
    return dict(obj)


class Event(BaseModel):
    """Declared shape of a handler-registry entry.

    NOTE(review): nothing in this module instantiates ``Event``.  The
    registries hold plain dicts, and their ``"model"`` entry is a model
    *class*, not an instance.
    """

    func: Callable
    conditions: List[Callable | bool]
    model: BaseModel


class Consumer:
    """Dispatch JSON WebSocket messages to registered handlers by their ``"type"``.

    ``event`` registers handlers for messages received from the client, and
    ``sending`` registers handlers for payloads passed to :meth:`send`.  Access
    both decorators through the class (e.g. ``Consumer.event("name")``).  Looking
    them up on an instance would bind the instance as ``name``.

    NOTE(review): ``events``/``sendings`` are created once on ``Consumer``.
    Subclasses that do not override them share these registries, so two
    subclasses registering the same type name overwrite each other.
    """

    # type name -> {"func": ..., "conditions": [...], "model": <model class>}
    events: Dict[str, Dict[str, Any]] = {}
    sendings: Dict[str, Dict[str, Any]] = {}
    event = make_event_decorator(events)
    sending = make_event_decorator(sendings)

    def __init__(self, ws: WebSocket):
        self.ws: WebSocket = ws

    async def connect(self):
        """Hook awaited once by :meth:`run` before the receive loop; no-op by default."""

    async def validation_error_handler(self, e: ValidationError):
        """Report *e* to the client as ``{"type": "error", "data": {"detail": [...]}}``.

        Each detail entry maps the last element of the error's ``loc`` to its message.
        """
        errors = e.errors()
        await self.ws.send_json(
            {"type": "error", "data": {"detail": [{ers['loc'][-1]: ers['msg']} for ers in errors]}}
        )

    async def _conditions_met(self, conditions) -> bool:
        """Return True when every condition passes for this consumer.

        A callable is called with the consumer, its result is awaited when it is
        awaitable, and the result is then tested for truthiness.  Booleans count
        as themselves, and any other value counts as passing.  Evaluation stops
        at the first failing condition.
        """
        for condition in conditions:
            if callable(condition):
                result = condition(self)
                if inspect.isawaitable(result):
                    result = await result
            elif isinstance(condition, bool):
                result = condition
            else:
                result = True
            if not result:
                return False
        return True

    async def send(self, payload):
        """Send *payload* to the client, routing it through a ``sending`` handler if one matches.

        When ``payload["type"]`` names a registered sending handler whose
        conditions pass:

        1. ``payload["data"]`` is validated against the handler's signature.
        2. The handler is called as ``handler(self, **validated_args)``; an
           awaitable result is awaited.
        3. The result is sent as ``{"type": type, "data": dict_all(result)}``.

        A validation failure sends a generic error message instead.  A failing
        handler is logged and nothing is sent.  In every other case the payload
        is sent unchanged.
        """
        event_type = payload.get('type', None)
        event_wrapper = None if event_type is None else self.sendings.get(event_type, None)
        if event_wrapper is not None:
            handler = event_wrapper.get('func')
            if handler is not None and await self._conditions_met(
                event_wrapper.get('conditions') or ()
            ):
                data = payload.get('data') or {}
                try:
                    validated = event_wrapper.get("model")(self=self, **data)
                except ValidationError as exc:
                    logger.warning("payload for sending %r failed validation: %s", event_type, exc)
                    await self.ws.send_json({"type": "error", "data": {"msg": "Oops there was an error"}})
                    return
                arguments = dict_model(validated, exclude=[*_VALIDATOR_FIELDS, "self"])
                try:
                    result = handler(self, **arguments)
                    if inspect.isawaitable(result):
                        result = await result
                    await self.ws.send_json({'type': event_type, "data": dict_all(result)})
                except Exception:
                    # Best-effort delivery: don't propagate, but never hide the failure.
                    logger.exception("sending handler for %r failed", event_type)
                return
        # NOTE(review): a registered sending type whose conditions fail also reaches
        # this line, so its raw payload is sent unchanged. Confirm this is intended.
        return await self.ws.send_json(payload)

    async def receive(self, data):
        """Dispatch one decoded client message to its registered ``event`` handler.

        These messages are ignored:

        * messages that are not JSON objects;
        * messages without a usable ``"type"``;
        * messages naming an unregistered type;
        * messages whose conditions fail.

        Otherwise ``data["data"]`` is validated against the handler's signature.
        Validation errors are reported through :meth:`validation_error_handler`.
        The handler is then called with keyword arguments, including
        ``self=<consumer>``.
        """
        # Everything below is client-controlled: reject shapes that would
        # otherwise raise and kill the receive loop.
        if not isinstance(data, dict):
            return
        event = data.get('type', None)
        if event is None:
            return
        try:
            event_wrapper = self.events.get(event, None)
        except TypeError:  # unhashable "type", e.g. a JSON list or object
            return
        if event_wrapper is None:
            return
        handler = event_wrapper.get('func')
        if handler is None or not await self._conditions_met(event_wrapper.get('conditions') or ()):
            return

        payload = data.get('data') or {}
        if not isinstance(payload, dict):
            await self.ws.send_json(
                {"type": "error", "data": {"detail": [{"data": "value is not a valid dict"}]}}
            )
            return
        # "self" is always the consumer; a client-sent "self" key would collide with it.
        fields = {key: value for key, value in payload.items() if key != "self"}
        try:
            validated = event_wrapper.get("model")(self=self, **fields)
        except ValidationError as e:
            await self.validation_error_handler(e)
            return

        # Shallow conversion keeps nested model arguments as model instances.
        result = handler(**dict_model(validated, exclude=[*_VALIDATOR_FIELDS]))
        if inspect.isawaitable(result):
            await result

    async def disconnect(self):
        """Hook awaited by :meth:`run` when the receive loop ends; no-op by default."""

    async def run(self):
        """Await :meth:`connect`, then dispatch client messages until the socket closes.

        :meth:`disconnect` runs whenever the loop ends.  That covers a normal
        ``WebSocketDisconnect`` and also any other exception, which still
        propagates after ``disconnect`` has run.
        """
        await self.connect()
        try:
            while True:
                data = await self.ws.receive_json()
                await self.receive(data)
        except WebSocketDisconnect:
            pass
        finally:
            await self.disconnect()