# Source: Generateurv2/backend/env/lib/python3.10/site-packages/channels_redis/pubsub.py
# Snapshot: 2022-06-24 17:14:37 +02:00 (474 lines, 18 KiB, Python)

import asyncio
import functools
import logging
import types
import uuid
import aioredis
import msgpack
from .utils import _consistent_hash
logger = logging.getLogger(__name__)
def _wrap_close(proxy, loop):
original_impl = loop.close
def _wrapper(self, *args, **kwargs):
if loop in proxy._layers:
layer = proxy._layers[loop]
del proxy._layers[loop]
loop.run_until_complete(layer.flush())
self.close = original_impl
return self.close(*args, **kwargs)
loop.close = types.MethodType(_wrapper, loop)
async def _async_proxy(obj, name, *args, **kwargs):
# Must be defined as a function and not a method due to
# https://bugs.python.org/issue38364
layer = obj._get_layer()
return await getattr(layer, name)(*args, **kwargs)
class RedisPubSubChannelLayer:
    """
    Pub/sub channel-layer frontend.

    Lazily creates one ``RedisPubSubLoopLayer`` per running event loop
    and forwards the channel-layer API to it, so the same object can be
    shared safely across loops.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Stored untouched; replayed into each per-loop layer we build.
        self._args = args
        self._kwargs = kwargs
        # Maps event loop -> RedisPubSubLoopLayer bound to that loop.
        self._layers = {}

    def __getattr__(self, name):
        proxied = {
            "new_channel",
            "send",
            "receive",
            "group_add",
            "group_discard",
            "group_send",
            "flush",
        }
        if name not in proxied:
            # Anything else resolves directly on the current loop's layer.
            return getattr(self._get_layer(), name)
        # Async API methods are wrapped so the per-loop layer lookup
        # happens at call time, on the then-current event loop.
        return functools.partial(_async_proxy, self, name)

    def serialize(self, message):
        """
        Serializes message to a byte string.
        """
        return msgpack.packb(message)

    def deserialize(self, message):
        """
        Deserializes from a byte string.
        """
        return msgpack.unpackb(message)

    def _get_layer(self):
        loop = asyncio.get_running_loop()
        layer = self._layers.get(loop)
        if layer is None:
            layer = RedisPubSubLoopLayer(
                *self._args,
                **self._kwargs,
                channel_layer=self,
            )
            self._layers[loop] = layer
            # Ensure the layer is flushed when its loop closes.
            _wrap_close(self, loop)
        return layer
class RedisPubSubLoopLayer:
    """
    Channel Layer that uses Redis's pub/sub functionality.

    One instance of this class is bound to exactly one event loop (see
    `RedisPubSubChannelLayer._get_layer()`); it keeps per-channel asyncio
    queues and delegates all Redis traffic to per-host shard connections.
    """

    def __init__(
        self,
        hosts=None,
        prefix="asgi",
        on_disconnect=None,
        on_reconnect=None,
        channel_layer=None,
        **kwargs,
    ):
        if hosts is None:
            hosts = [("localhost", 6379)]
        assert (
            isinstance(hosts, list) and len(hosts) > 0
        ), "`hosts` must be a list with at least one Redis server"

        self.prefix = prefix

        # Optional message types injected into every in-process queue when a
        # subscribe connection drops / comes back (see _notify_consumers).
        self.on_disconnect = on_disconnect
        self.on_reconnect = on_reconnect
        # Back-reference to the RedisPubSubChannelLayer proxy; used for
        # serialize()/deserialize().
        self.channel_layer = channel_layer

        # Each consumer gets its own *specific* channel, created with the `new_channel()` method.
        # This dict maps `channel_name` to a queue of messages for that channel.
        self.channels = {}

        # A channel can subscribe to zero or more groups.
        # This dict maps `group_name` to set of channel names who are subscribed to that group.
        self.groups = {}

        # For each host, we create a `RedisSingleShardConnection` to manage the connection to that host.
        self._shards = [RedisSingleShardConnection(host, self) for host in hosts]

    def _get_shard(self, channel_or_group_name):
        """
        Return the shard that is used exclusively for this channel or group.
        """
        # Consistent hashing keeps a given name pinned to the same host.
        return self._shards[_consistent_hash(channel_or_group_name, len(self._shards))]

    def _get_group_channel_name(self, group):
        """
        Return the channel name used by a group.
        Includes '__group__' in the returned
        string so that these names are distinguished
        from those returned by `new_channel()`.
        Technically collisions are possible, but it
        takes what I believe is intentional abuse in
        order to have colliding names.
        """
        return f"{self.prefix}__group__{group}"

    async def _subscribe_to_channel(self, channel):
        # Register the in-process queue *before* subscribing so messages
        # that arrive immediately after SUBSCRIBE have somewhere to land.
        self.channels[channel] = asyncio.Queue()
        shard = self._get_shard(channel)
        await shard.subscribe(channel)

    extensions = ["groups", "flush"]

    ################################################################################
    # Channel layer API
    ################################################################################

    async def send(self, channel, message):
        """
        Send a message onto a (general or specific) channel.
        """
        shard = self._get_shard(channel)
        await shard.publish(channel, self.channel_layer.serialize(message))

    async def new_channel(self, prefix="specific."):
        """
        Returns a new channel name that can be used by a consumer in our
        process as a specific channel.
        """
        channel = f"{self.prefix}{prefix}{uuid.uuid4().hex}"
        await self._subscribe_to_channel(channel)
        return channel

    async def receive(self, channel):
        """
        Receive the first message that arrives on the channel.
        If more than one coroutine waits on the same channel, a random one
        of the waiting coroutines will get the result.
        """
        if channel not in self.channels:
            await self._subscribe_to_channel(channel)

        q = self.channels[channel]

        try:
            message = await q.get()
        except asyncio.CancelledError:
            # We assume here that the reason we are cancelled is because the consumer
            # is exiting, therefore we need to cleanup by unsubscribe below. Indeed,
            # currently the way that Django Channels works, this is a safe assumption.
            # In the future, Django Channels could change to call a *new* method that
            # would serve as the antithesis of `new_channel()`; this new method might
            # be named `delete_channel()`. If that were the case, we would do the
            # following cleanup from that new `delete_channel()` method, but, since
            # that's not how Django Channels works (yet), we do the cleanup below:
            if channel in self.channels:
                del self.channels[channel]
                try:
                    shard = self._get_shard(channel)
                    await shard.unsubscribe(channel)
                except BaseException:
                    logger.exception("Unexpected exception while cleaning-up channel:")
                    # We don't re-raise here because we want the CancelledError to be the one re-raised.
            raise

        return self.channel_layer.deserialize(message)

    ################################################################################
    # Groups extension
    ################################################################################

    async def group_add(self, group, channel):
        """
        Adds the channel name to a group.

        Raises RuntimeError if `channel` was not created in this process.
        """
        if channel not in self.channels:
            raise RuntimeError(
                "You can only call group_add() on channels that exist in-process.\n"
                "Consumers are encouraged to use the common pattern:\n"
                f" self.channel_layer.group_add({repr(group)}, self.channel_name)"
            )

        group_channel = self._get_group_channel_name(group)
        if group_channel not in self.groups:
            self.groups[group_channel] = set()
        group_channels = self.groups[group_channel]
        if channel not in group_channels:
            group_channels.add(channel)

        # Subscribe even if already in the set: shard.subscribe() is
        # idempotent and this also reasserts the Redis subscription.
        shard = self._get_shard(group_channel)
        await shard.subscribe(group_channel)

    async def group_discard(self, group, channel):
        """
        Removes the channel from a group.
        """
        group_channel = self._get_group_channel_name(group)
        assert group_channel in self.groups
        group_channels = self.groups[group_channel]
        assert channel in group_channels
        group_channels.remove(channel)

        # Only drop the Redis subscription once the last local channel
        # has left the group.
        if len(group_channels) == 0:
            del self.groups[group_channel]
            shard = self._get_shard(group_channel)
            await shard.unsubscribe(group_channel)

    async def group_send(self, group, message):
        """
        Send the message to all subscribers of the group.
        """
        group_channel = self._get_group_channel_name(group)
        shard = self._get_shard(group_channel)
        await shard.publish(group_channel, self.channel_layer.serialize(message))

    ################################################################################
    # Flush extension
    ################################################################################

    async def flush(self):
        """
        Flush the layer, making it like new. It can continue to be used as if it
        was just created. This also closes connections, serving as a clean-up
        method; connections will be re-opened if you continue using this layer.
        """
        self.channels = {}
        self.groups = {}
        for shard in self._shards:
            await shard.flush()
def on_close_noop(sender, exc=None):
    """
    Deliberate no-op `on_close` callback for the Receiver.

    The Receiver's default `on_close` shuts it down as soon as the last
    subscriber unsubscribes. That is not what we want: the Receiver should
    keep running even while nobody is subscribed, because soon someone
    *will* subscribe again and things should simply continue from there.
    Installing this do-nothing callback achieves that.
    """
class RedisSingleShardConnection:
    """
    Manages the connections to a single Redis host (shard): one connection
    for publishing and one for subscribing, with automatic reconnection and
    resubscription when the subscribe connection dies.
    """

    def __init__(self, host, channel_layer):
        # `host` may be a dict of aioredis connection kwargs or a bare
        # address; normalize to a kwargs dict either way.
        self.host = host.copy() if type(host) is dict else {"address": host}
        # A `master_name` key switches this shard into Sentinel mode.
        self.master_name = self.host.pop("master_name", None)
        self.channel_layer = channel_layer
        # Channel/group names this shard should be subscribed to; used to
        # resubscribe after a reconnect.
        self._subscribed_to = set()
        self._lock = None
        self._redis = None
        self._pub_conn = None
        self._sub_conn = None
        self._receiver = None
        self._receive_task = None
        self._keepalive_task = None

    async def publish(self, channel, message):
        conn = await self._get_pub_conn()
        await conn.publish(channel, message)

    async def subscribe(self, channel):
        # Idempotent: a second subscribe to the same name is a no-op.
        if channel not in self._subscribed_to:
            self._subscribed_to.add(channel)
            conn = await self._get_sub_conn()
            await conn.subscribe(self._receiver.channel(channel))

    async def unsubscribe(self, channel):
        if channel in self._subscribed_to:
            self._subscribed_to.remove(channel)
            conn = await self._get_sub_conn()
            await conn.unsubscribe(channel)

    async def flush(self):
        # Cancel background tasks first so they don't race the teardown
        # below (the keepalive task would otherwise recreate connections).
        for task in [self._keepalive_task, self._receive_task]:
            if task is not None:
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass
        self._keepalive_task = None
        self._receive_task = None
        self._receiver = None
        if self._sub_conn is not None:
            self._sub_conn.close()
            await self._sub_conn.wait_closed()
            self._put_redis_conn(self._sub_conn)
            self._sub_conn = None
        if self._pub_conn is not None:
            self._pub_conn.close()
            await self._pub_conn.wait_closed()
            self._put_redis_conn(self._pub_conn)
            self._pub_conn = None
        self._subscribed_to = set()

    async def _get_pub_conn(self):
        """
        Return the connection to this shard that is used for *publishing* messages.

        If the connection is dead, automatically reconnect.
        """
        # Lock is created lazily so it binds to the running loop.
        if self._lock is None:
            self._lock = asyncio.Lock()
        async with self._lock:
            if self._pub_conn is not None and self._pub_conn.closed:
                self._put_redis_conn(self._pub_conn)
                self._pub_conn = None
            # Retry forever, once per second, until a connection succeeds.
            while self._pub_conn is None:
                try:
                    self._pub_conn = await self._get_redis_conn()
                except BaseException:
                    self._put_redis_conn(self._pub_conn)
                    logger.warning(
                        f"Failed to connect to Redis publish host: {self.host}; will try again in 1 second..."
                    )
                    await asyncio.sleep(1)
            return self._pub_conn

    async def _get_sub_conn(self):
        """
        Return the connection to this shard that is used for *subscribing* to channels.

        If the connection is dead, automatically reconnect and resubscribe to all our channels!
        """
        # Start the periodic connection check on first use (see _do_keepalive).
        if self._keepalive_task is None:
            self._keepalive_task = asyncio.ensure_future(self._do_keepalive())
        if self._lock is None:
            self._lock = asyncio.Lock()
        async with self._lock:
            if self._sub_conn is not None and self._sub_conn.closed:
                self._put_redis_conn(self._sub_conn)
                self._sub_conn = None
                self._notify_consumers(self.channel_layer.on_disconnect)
            if self._sub_conn is None:
                # Tear down the old receiver task before building a new one.
                if self._receive_task is not None:
                    self._receive_task.cancel()
                    try:
                        await self._receive_task
                    except asyncio.CancelledError:
                        # This is the normal case, that `asyncio.CancelledError` is thrown. All good.
                        pass
                    except BaseException:
                        logger.exception(
                            "Unexpected exception while canceling the receiver task:"
                        )
                        # Don't re-raise here. We don't actually care why `_receive_task` didn't exit cleanly.
                    self._receive_task = None
                # Retry forever, once per second, until a connection succeeds.
                while self._sub_conn is None:
                    try:
                        self._sub_conn = await self._get_redis_conn()
                    except BaseException:
                        self._put_redis_conn(self._sub_conn)
                        logger.warning(
                            f"Failed to connect to Redis subscribe host: {self.host}; will try again in 1 second..."
                        )
                        await asyncio.sleep(1)
                self._receiver = aioredis.pubsub.Receiver(on_close=on_close_noop)
                self._receive_task = asyncio.ensure_future(self._do_receiving())
                if len(self._subscribed_to) > 0:
                    # Do our best to recover by resubscribing to the channels that we were previously subscribed to.
                    resubscribe_to = [
                        self._receiver.channel(name) for name in self._subscribed_to
                    ]
                    await self._sub_conn.subscribe(*resubscribe_to)
                    self._notify_consumers(self.channel_layer.on_reconnect)
            return self._sub_conn

    async def _do_receiving(self):
        # Pump every message from the receiver into the matching in-process
        # queue(s): specific channels directly, group channels fanned out.
        async for ch, message in self._receiver.iter():
            name = ch.name
            if isinstance(name, bytes):
                # Reversing what happens here:
                # https://github.com/aio-libs/aioredis-py/blob/8a207609b7f8a33e74c7c8130d97186e78cc0052/aioredis/util.py#L17
                name = name.decode()
            if name in self.channel_layer.channels:
                self.channel_layer.channels[name].put_nowait(message)
            elif name in self.channel_layer.groups:
                for channel_name in self.channel_layer.groups[name]:
                    if channel_name in self.channel_layer.channels:
                        self.channel_layer.channels[channel_name].put_nowait(message)

    def _notify_consumers(self, mtype):
        # Push a synthetic `{"type": mtype}` message (e.g. on_disconnect /
        # on_reconnect) onto every local channel queue, if configured.
        if mtype is not None:
            for channel in self.channel_layer.channels.values():
                channel.put_nowait(
                    self.channel_layer.channel_layer.serialize({"type": mtype})
                )

    async def _ensure_redis(self):
        if self._redis is None:
            if self.master_name is None:
                self._redis = await aioredis.create_redis_pool(**self.host)
            else:
                # aioredis default timeout is way too low
                self._redis = await aioredis.sentinel.create_sentinel(
                    timeout=2, **self.host
                )

    def _get_aioredis_pool(self):
        if self.master_name is None:
            return self._redis._pool_or_conn
        else:
            return self._redis.master_for(self.master_name)._pool_or_conn

    async def _get_redis_conn(self):
        await self._ensure_redis()
        conn = await self._get_aioredis_pool().acquire()
        return aioredis.Redis(conn)

    def _put_redis_conn(self, conn):
        # Return a connection to the aioredis pool; tolerates None (the
        # reconnect loops call this with a still-unset connection).
        if conn:
            self._get_aioredis_pool().release(conn._pool_or_conn)

    async def _do_keepalive(self):
        """
        This task's simple job is just to call `self._get_sub_conn()` periodically.

        Why? Well, calling `self._get_sub_conn()` has the nice side-effect that if
        that connection has died (because Redis was restarted, or there was a networking
        hiccup, for example), then calling `self._get_sub_conn()` will reconnect and
        restore our old subscriptions. Thus, we want to do this on a predictable schedule.
        This is kinda a sub-optimal way to achieve this, but I can't find a way in aioredis
        to get a notification when the connection dies. I find this (sub-optimal) method
        of checking the connection state works fine for my app; if Redis restarts, we reconnect
        and resubscribe *quickly enough*; I mean, Redis restarting is already bad because it
        will cause messages to get lost, and this periodic check at least minimizes the
        damage *enough*.

        Note you wouldn't need this if you were *sure* that there would be a lot of subscribe/
        unsubscribe events on your site, because such events each call `self._get_sub_conn()`.
        Thus, on a site with heavy traffic this task may not be necessary, but also maybe it is.
        Why? Well, in a heavy traffic site you probably have more than one Django server replicas,
        so it might be the case that one of your replicas is under-utilized and this periodic
        connection check will be beneficial in the same way as it is for a low-traffic site.
        """
        while True:
            await asyncio.sleep(1)
            try:
                await self._get_sub_conn()
            except Exception:
                logger.exception("Unexpected exception in keepalive task:")