137 lines
4.9 KiB
Python
137 lines
4.9 KiB
Python
import inspect
import logging
from typing import Any, Callable, Dict, List, Type

from fastapi.websockets import WebSocket, WebSocketDisconnect
from pydantic import BaseModel, validate_arguments
from pydantic.error_wrappers import ValidationError
|
|
|
|
def make_event_decorator(eventsDict):
|
|
def _(name: str | List, conditions: List[Callable | bool] = []):
|
|
def add_event(func):
|
|
model = validate_arguments(func).model
|
|
if type(name) == str:
|
|
eventsDict[name] = {"func": func,
|
|
"conditions": conditions, "model": model}
|
|
if type(name) == list:
|
|
for n in name:
|
|
eventsDict[n] = {"func": func,
|
|
"conditions": conditions, "model": model}
|
|
return func
|
|
return add_event
|
|
return _
|
|
|
|
|
|
|
|
def dict_model(model: BaseModel, exclude: List[str]):
    """Collect the ``(name, value)`` pairs yielded by iterating *model* into a dict.

    Pairs whose name appears in *exclude* are skipped. Values are stored by
    reference, with no recursive conversion. A repeated name keeps its last
    value.
    """
    return {field: val for field, val in model if field not in exclude}
|
|
def _to_plain(value: Any) -> Any:
    """Recursively convert dicts, lists, tuples and pydantic models into plain built-ins."""
    # Scalars pass straight through. The old code fed every value to dict(),
    # which turned "" into {} and ["ab", "cd"] into {"a": "b", "c": "d"}.
    if value is None or isinstance(value, (str, bytes, int, float)):
        return value
    if isinstance(value, dict):
        return {k: _to_plain(v) for k, v in value.items()}
    if isinstance(value, list):
        return [_to_plain(v) for v in value]
    if isinstance(value, tuple):
        return tuple(_to_plain(v) for v in value)
    if isinstance(value, BaseModel):
        # Iterating a pydantic model yields (field_name, value) pairs.
        return {k: _to_plain(v) for k, v in value}
    if hasattr(value, "keys"):
        # Other mapping-like objects: the same duck-typed test dict() applies.
        return {k: _to_plain(v) for k, v in dict(value).items()}
    return value


def dict_all(obj: Any):
    """Return *obj* as a dict whose nested values are plain built-ins.

    *obj* may be a dict or anything ``dict()`` accepts, such as a pydantic
    model or an iterable of key/value pairs. Values are converted recursively:
    - dicts, pydantic models and other mapping-like objects become dicts;
    - lists and tuples keep their type, with converted items;
    - everything else (strings, numbers, None, ...) is returned unchanged.

    Raises:
        TypeError, ValueError: if *obj* is not a dict and ``dict(obj)`` fails.
    """
    if not isinstance(obj, dict):
        obj = dict(obj)
    return {key: _to_plain(val) for key, val in obj.items()}
|
|
|
|
|
|
class Event(BaseModel):
    """Schema of one handler registration: the handler, its guards and its argument model.

    NOTE(review): the registration code stores plain dicts with these three
    keys rather than ``Event`` instances. Confirm whether this model is
    instantiated anywhere.
    """

    # The registered handler function.
    func: Callable
    # Guards checked before the handler runs: callables or literal booleans.
    conditions: List[Callable | bool]
    # The argument model produced by ``validate_arguments(func).model``. It is
    # a model class, not an instance, so the field is ``Type[BaseModel]``.
    model: Type[BaseModel]
|
|
|
|
class Consumer:
    """Route JSON WebSocket messages to registered, argument-validated handlers.

    Messages are dicts shaped like ``{"type": <name>, "data": {...}}``.
    - ``receive`` dispatches to handlers registered through ``event``.
    - ``send`` transforms outgoing payloads with handlers registered through
      ``sending``.

    Each registration carries conditions plus the model produced by
    ``validate_arguments`` for the handler. That model validates ``data``
    before the handler is invoked.
    """

    # NOTE(review): both registries are class attributes of Consumer itself,
    # so every subclass shares them. Registering the same name twice replaces
    # the earlier handler. Confirm that is intended.
    events: Dict[str, Dict[str, Any]] = {}
    sendings: Dict[str, Dict[str, Any]] = {}

    event = make_event_decorator(events)
    sending = make_event_decorator(sendings)

    # Model fields that are excluded from the handler's keyword arguments.
    _INTERNAL_FIELDS = ("v__duplicate_kwargs", "args", "kwargs")

    def __init__(self, ws: WebSocket):
        self.ws: WebSocket = ws

    async def connect(self):
        """Hook awaited once by ``run`` before the receive loop; does nothing by default."""

    async def validation_error_handler(self, e: ValidationError):
        """Send ``e`` to the client as a list of ``{field: message}`` entries."""
        errors = e.errors()
        await self.ws.send_json({"type": "error", "data": {"detail": [{ers['loc'][-1]: ers['msg']} for ers in errors]}})

    async def _conditions_pass(self, conditions) -> bool:
        """Return True when every registration condition allows the call.

        Any callable (function, bound method, partial, ...) is called with this
        consumer, and its result is awaited when awaitable. Any other value
        counts by its truthiness. Evaluation stops at the first failing
        condition.

        Previously only plain functions were called. Bound methods and partials
        were skipped and silently treated as passing.
        """
        for condition in conditions or ():
            if callable(condition):
                result = condition(self)
                if inspect.isawaitable(result):
                    result = await result
            else:
                result = condition
            if not result:
                return False
        return True

    async def send(self, payload):
        """Send *payload* to the client, transforming it with a registered sender.

        When ``payload["type"]`` has a handler registered through ``sending``:
        1. Its conditions are checked; nothing is sent if one fails.
        2. ``payload["data"]`` is validated against the handler's model; a
           generic error message is sent if validation fails.
        3. The handler's result is awaited if it is awaitable.
        4. ``{"type": ..., "data": dict_all(result)}`` is sent.

        Payloads without a type, or whose type has no sender, are sent as-is.
        Errors raised by the handler or while sending its result are logged
        and swallowed.
        """
        event_type = payload.get('type', None)
        event_wrapper = self.sendings.get(event_type) if event_type is not None else None
        if event_wrapper is None:
            await self.ws.send_json(payload)
            return

        handler = event_wrapper.get('func')
        if handler is None or not await self._conditions_pass(event_wrapper.get('conditions')):
            return

        model = event_wrapper.get("model")
        data = payload.get('data') or {}
        try:
            # The consumer always fills the handler's ``self`` slot, even if
            # ``data`` happens to carry a "self" key.
            validated_payload = model(**{**data, "self": self})
        except ValidationError:
            await self.ws.send_json({"type": "error", "data": {"msg": "Oops there was an error"}})
            return

        kwargs = dict_model(validated_payload, exclude=[*self._INTERNAL_FIELDS, "self"])
        try:
            result = handler(self, **kwargs)
            if inspect.isawaitable(result):
                # Async senders previously produced an un-awaited coroutine here.
                result = await result
            await self.ws.send_json({'type': event_type, "data": dict_all(result)})
        except Exception:
            # Best effort, as before: a failing sender must not break the
            # caller, but the failure is now logged instead of discarded.
            logging.getLogger(__name__).exception("sender for %r failed", event_type)

    async def receive(self, data):
        """Dispatch one decoded client message to its handler registered through ``event``.

        The following messages are ignored:
        - messages that are not JSON objects or have no ``"type"``;
        - messages whose ``"type"`` is unhashable or unregistered;
        - messages whose handler conditions fail.

        Invalid ``"data"`` is reported to the client instead of raising.
        """
        # This input comes from the client. Malformed messages used to raise
        # out of ``run`` and kill the connection.
        if not isinstance(data, dict):
            return
        event = data.get('type', None)
        if event is None:
            return
        try:
            event_wrapper = self.events.get(event)
        except TypeError:
            return  # unhashable "type" (JSON list/object)
        if event_wrapper is None:
            return

        handler = event_wrapper.get('func')
        if handler is None or not await self._conditions_pass(event_wrapper.get('conditions')):
            return

        payload = data.get('data') or {}
        if not isinstance(payload, dict):
            await self.ws.send_json({"type": "error", "data": {"detail": [{"data": "value is not a valid dict"}]}})
            return

        model = event_wrapper.get("model")
        try:
            # The server-side consumer overrides any client-supplied "self".
            validated_payload = model(**{**payload, "self": self})
        except ValidationError as e:
            await self.validation_error_handler(e)
            return

        # ``.dict()`` is kept for compatibility. It passes "self" as a keyword
        # argument, and it hands nested pydantic models to the handler as dicts.
        await handler(**{k: v for k, v in validated_payload.dict().items() if k not in self._INTERNAL_FIELDS})

    async def disconnect(self):
        """Hook awaited by ``run`` when the receive loop ends; does nothing by default."""

    async def run(self):
        """Await ``connect``, then pass every received JSON message to ``receive``.

        ``disconnect`` is awaited when the client disconnects. It is also
        awaited when the loop stops on any other exception, which is then
        re-raised. Previously cleanup was skipped in that case.
        """
        await self.connect()
        try:
            while True:
                data = await self.ws.receive_json()
                await self.receive(data)
        except WebSocketDisconnect:
            pass
        finally:
            await self.disconnect()
|