327 lines
11 KiB
Python
327 lines
11 KiB
Python
import logging
|
|
import time
|
|
import traceback
|
|
from urllib.parse import unquote
|
|
|
|
from autobahn.twisted.websocket import (
|
|
ConnectionDeny,
|
|
WebSocketServerFactory,
|
|
WebSocketServerProtocol,
|
|
)
|
|
from twisted.internet import defer
|
|
|
|
from .utils import parse_x_forwarded_for
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class WebSocketProtocol(WebSocketServerProtocol):
    """
    Protocol which supports WebSockets and forwards incoming messages to
    the websocket channels.

    Incoming socket events are pushed onto the application's queue as
    ``websocket.connect`` / ``websocket.receive`` / ``websocket.disconnect``
    messages; replies from the application arrive through ``handle_reply``.
    """

    application_type = "websocket"

    # If we should send no more messages (e.g. we error-closed the socket)
    muted = False

    def onConnect(self, request):
        """
        Handle the opening handshake request.

        Registers this protocol with the server, sanitizes the request
        headers, works out client/server addresses, and asks the server to
        create an application instance for the connection scope.

        Returns a Deferred that stays pending until the application accepts
        (``serverAccept``), rejects (``serverReject``) or fails
        (``handle_exception``).
        """
        self.server = self.factory.server_class
        self.server.protocol_connected(self)
        self.request = request
        self.protocol_to_accept = None
        self.socket_opened = time.time()
        self.last_ping = time.time()
        try:
            # Sanitize and decode headers
            self.clean_headers = []
            for name, value in request.headers.items():
                name = name.encode("ascii")
                # Prevent CVE-2015-0219: drop any header whose name contains
                # an underscore.
                if b"_" in name:
                    continue
                self.clean_headers.append((name.lower(), value.encode("latin1")))
            # Get client address if possible
            peer = self.transport.getPeer()
            host = self.transport.getHost()
            if hasattr(peer, "host") and hasattr(peer, "port"):
                self.client_addr = [str(peer.host), peer.port]
                self.server_addr = [str(host.host), host.port]
            else:
                self.client_addr = None
                self.server_addr = None

            # When proxy headers are configured, let them override the
            # address taken from the transport.
            if self.server.proxy_forwarded_address_header:
                self.client_addr, self.client_scheme = parse_x_forwarded_for(
                    dict(self.clean_headers),
                    self.server.proxy_forwarded_address_header,
                    self.server.proxy_forwarded_port_header,
                    self.server.proxy_forwarded_proto_header,
                    self.client_addr,
                )
            # Decode websocket subprotocol options
            subprotocols = []
            for header, value in self.clean_headers:
                if header == b"sec-websocket-protocol":
                    subprotocols = [
                        x.strip() for x in unquote(value.decode("ascii")).split(",")
                    ]
            # Make new application instance with scope
            self.path = request.path.encode("ascii")
            self.application_deferred = defer.maybeDeferred(
                self.server.create_application,
                self,
                {
                    "type": "websocket",
                    "path": unquote(self.path.decode("ascii")),
                    "raw_path": self.path,
                    "headers": self.clean_headers,
                    "query_string": self._raw_query_string,  # Passed by HTTP protocol
                    "client": self.client_addr,
                    "server": self.server_addr,
                    "subprotocols": subprotocols,
                },
            )
            if self.application_deferred is not None:
                self.application_deferred.addCallback(self.applicationCreateWorked)
                self.application_deferred.addErrback(self.applicationCreateFailed)
        except Exception:
            # Exceptions here are not displayed right, just 500.
            # Turn them into an ERROR log.
            logger.error(traceback.format_exc())
            raise

        # Make a deferred and return it - we'll either call it or err it later on
        self.handshake_deferred = defer.Deferred()
        return self.handshake_deferred

    def applicationCreateWorked(self, application_queue):
        """
        Callback for ``application_deferred``: the application instance was
        created successfully.

        Stores the application's queue, puts the ``websocket.connect``
        message on it and logs a "connecting" action.
        """
        # Store the application's queue
        self.application_queue = application_queue
        # Send over the connect message
        self.application_queue.put_nowait({"type": "websocket.connect"})
        self.server.log_action("websocket", "connecting", self._action_details())

    def applicationCreateFailed(self, failure):
        """
        Errback for ``application_deferred``: logs the failure and returns
        it so it keeps propagating down the errback chain.
        """
        logger.error(failure)
        return failure

    ### Twisted event handling

    def onOpen(self):
        """
        Called once the WebSocket is open; logs a "connected" action.
        """
        # Send news that this channel is open
        logger.debug("WebSocket %s open and established", self.client_addr)
        self.server.log_action("websocket", "connected", self._action_details())

    def onMessage(self, payload, isBinary):
        """
        Forward an incoming frame to the application as ``websocket.receive``.

        Binary payloads are passed through under ``bytes``; text payloads are
        decoded as UTF-8 and passed under ``text``. Frames are ignored while
        the protocol is muted.
        """
        # If we're muted, do nothing.
        if self.muted:
            logger.debug("Muting incoming frame on %s", self.client_addr)
            return
        logger.debug("WebSocket incoming frame on %s", self.client_addr)
        # Any incoming frame counts as activity for the ping check.
        self.last_ping = time.time()
        if isBinary:
            self.application_queue.put_nowait(
                {"type": "websocket.receive", "bytes": payload}
            )
        else:
            self.application_queue.put_nowait(
                {"type": "websocket.receive", "text": payload.decode("utf8")}
            )

    def onClose(self, wasClean, code, reason):
        """
        Called when Twisted closes the socket.

        Unregisters the protocol from the server, sends
        ``websocket.disconnect`` to the application (unless muted or the
        application queue was never created) and logs a "disconnected" action.
        """
        self.server.protocol_disconnected(self)
        logger.debug("WebSocket closed for %s", self.client_addr)
        if not self.muted and hasattr(self, "application_queue"):
            self.application_queue.put_nowait(
                {"type": "websocket.disconnect", "code": code}
            )
        self.server.log_action("websocket", "disconnected", self._action_details())

    ### Internal event handling

    def handle_reply(self, message):
        """
        Dispatch a message sent by the application.

        Handles ``websocket.accept``, ``websocket.close`` and
        ``websocket.send``; any other type is ignored.

        Raises ValueError if the message has no ``type``, if a send is
        attempted before the socket is accepted, or if a send carries both
        ``bytes`` and ``text``.
        """
        if "type" not in message:
            raise ValueError("Message has no type defined")
        if message["type"] == "websocket.accept":
            self.serverAccept(message.get("subprotocol", None))
        elif message["type"] == "websocket.close":
            # Closing before the handshake finished is a rejection.
            if self.state == self.STATE_CONNECTING:
                self.serverReject()
            else:
                self.serverClose(code=message.get("code", None))
        elif message["type"] == "websocket.send":
            if self.state == self.STATE_CONNECTING:
                raise ValueError("Socket has not been accepted, so cannot send over it")
            binary_content = message.get("bytes", None)
            text_content = message.get("text", None)
            # Compare against None rather than truthiness so that empty
            # frames ("" / b"") are still delivered instead of being dropped.
            if binary_content is not None and text_content is not None:
                raise ValueError(
                    "Got invalid WebSocket reply message on %s - contains both bytes and text keys"
                    % (message,)
                )
            if binary_content is not None:
                self.serverSend(binary_content, True)
            elif text_content is not None:
                self.serverSend(text_content, False)

    def handle_exception(self, exception):
        """
        Called by the server when our application tracebacks
        """
        if hasattr(self, "handshake_deferred"):
            # If the handshake is still ongoing, we need to emit a HTTP error
            # code rather than a WebSocket one.
            self.handshake_deferred.errback(
                ConnectionDeny(code=500, reason="Internal server error")
            )
        else:
            self.sendCloseFrame(code=1011)

    def serverAccept(self, subprotocol=None):
        """
        Called when we get a message saying to accept the connection.

        Fires the handshake deferred with the chosen subprotocol and then
        discards it, which marks the handshake as finished.
        """
        self.handshake_deferred.callback(subprotocol)
        del self.handshake_deferred
        logger.debug("WebSocket %s accepted by application", self.client_addr)

    def serverReject(self):
        """
        Called when we get a message saying to reject the connection.

        Fails the handshake with a 403, unregisters the protocol from the
        server and logs a "rejected" action.
        """
        self.handshake_deferred.errback(
            ConnectionDeny(code=403, reason="Access denied")
        )
        del self.handshake_deferred
        self.server.protocol_disconnected(self)
        logger.debug("WebSocket %s rejected by application", self.client_addr)
        self.server.log_action("websocket", "rejected", self._action_details())

    def serverSend(self, content, binary=False):
        """
        Server-side channel message to send a message.

        Text content is encoded as UTF-8 before sending; binary content is
        sent as given. Accepts the connection first if it is still connecting.
        """
        if self.state == self.STATE_CONNECTING:
            self.serverAccept()
        logger.debug("Sent WebSocket packet to client for %s", self.client_addr)
        if binary:
            self.sendMessage(content, binary)
        else:
            self.sendMessage(content.encode("utf8"), binary)

    def serverClose(self, code=None):
        """
        Server-side channel message to close the socket

        Uses close code 1000 when no code is given.
        """
        code = 1000 if code is None else code
        self.sendClose(code=code)

    ### Utils

    def _action_details(self):
        """
        Build the details dict passed to ``server.log_action``: the request
        path plus the client as "host:port" (None when the address is unknown).
        """
        return {
            "path": self.request.path,
            "client": "%s:%s" % tuple(self.client_addr)
            if self.client_addr
            else None,
        }

    def duration(self):
        """
        Returns the time since the socket was opened
        """
        return time.time() - self.socket_opened

    def check_timeouts(self):
        """
        Called periodically to see if we should timeout something
        """
        # Web timeout checking (a negative websocket_timeout disables it)
        if (
            self.duration() > self.server.websocket_timeout
            and self.server.websocket_timeout >= 0
        ):
            self.serverClose()
        # Ping check
        # If we're still connecting, deny the connection
        if self.state == self.STATE_CONNECTING:
            if self.duration() > self.server.websocket_connect_timeout:
                self.serverReject()
        elif self.state == self.STATE_OPEN:
            if (time.time() - self.last_ping) > self.server.ping_interval:
                self._sendAutoPing()
                self.last_ping = time.time()

    def __hash__(self):
        return hash(id(self))

    def __eq__(self, other):
        return id(self) == id(other)

    def __repr__(self):
        # client_addr and path are only assigned part-way through onConnect,
        # so fall back to None rather than raising for a protocol whose
        # handshake has not got that far.
        client_addr = getattr(self, "client_addr", None)
        path = getattr(self, "path", None)
        return f"<WebSocketProtocol client={client_addr!r} path={path!r}>"
|
|
|
|
|
|
class WebSocketFactory(WebSocketServerFactory):
    """
    Factory subclass that remembers what the "main"
    factory is, so WebSocket protocols can access it
    to get reply ID info.
    """

    protocol = WebSocketProtocol

    def __init__(self, server_class, *args, **kwargs):
        """
        Keep a reference to ``server_class`` and pass every remaining
        argument through to the base factory's initializer.
        """
        self.server_class = server_class
        super().__init__(*args, **kwargs)

    def buildProtocol(self, addr):
        """
        Builds protocol instances. We use this to inject the factory object into the protocol.

        Any exception raised while building is logged with its traceback
        and then re-raised.
        """
        try:
            built = super().buildProtocol(addr)
            built.factory = self
        except Exception:
            logger.error("Cannot build protocol: %s", traceback.format_exc())
            raise
        return built
|