Generateurv2/backend/env/lib/python3.10/site-packages/aioredis/pubsub.py
2022-06-24 17:14:37 +02:00

454 lines
14 KiB
Python

import asyncio
import json
import types
import collections
import warnings
import sys
from .abc import AbcChannel
from .util import _converters # , _set_result
from .errors import ChannelClosedError
from .log import logger
# Public names exported by a star-import of this module.
__all__ = ["Channel", "EndOfStream", "Receiver"]

# Unique sentinel marking the end of a pub/sub message stream.
# It is compared by identity (``is EndOfStream``), never by value.
EndOfStream = object()
class Channel(AbcChannel):
    """Pub/Sub channel (or pattern) backed by a ``ClosableQueue``."""

    def __init__(self, name, is_pattern, loop=None):
        """Set up an open, empty channel.

        :param name: channel name or pattern; it is passed through the
            ``_converters`` entry looked up by ``type(name)``.
        :param is_pattern: True when ``name`` is a pattern subscription.
        :param loop: deprecated and unused; on Python 3.8+ passing it
            emits a ``DeprecationWarning``.
        """
        loop_given = loop is not None
        if loop_given and sys.version_info >= (3, 8):
            warnings.warn("The loop argument is deprecated",
                          DeprecationWarning)
        self._queue = ClosableQueue()
        convert = _converters[type(name)]
        self._name = convert(name)
        self._is_pattern = is_pattern

    def __repr__(self):
        template = "<{} name:{!r}, is_pattern:{}, qsize:{}>"
        return template.format(
            self.__class__.__name__,
            self._name,
            self._is_pattern,
            self._queue.qsize(),
        )

    @property
    def name(self):
        """Channel name/pattern in its converted (encoded) form."""
        return self._name

    @property
    def is_pattern(self):
        """True if this channel represents a pattern subscription."""
        return self._is_pattern

    @property
    def is_active(self):
        """True while the underlying queue is not exhausted.

        That is, until the channel has been closed *and* every queued
        message has been consumed.

        Can be used with ``while``:

        >>> ch = conn.pubsub_channels['chan:1']
        >>> while ch.is_active:
        ...     msg = await ch.get()   # may block for a long time
        """
        return not self._queue.exhausted

    async def get(self, *, encoding=None, decoder=None):
        """Wait for the next message and return it.

        :param encoding: if given, the payload is ``.decode()``-d with it.
        :param decoder: optional callable applied to the (possibly decoded)
            payload.
        :returns: the payload; ``(destination_channel, payload)`` for a
            pattern channel; or ``None`` when the end-of-stream marker
            is dequeued.
        :raises aioredis.ChannelClosedError: if the channel is closed
            and has no messages left.
        """
        assert decoder is None or callable(decoder), decoder
        queue = self._queue
        if queue.exhausted:
            raise ChannelClosedError()

        item = await queue.get()
        if item is EndOfStream:
            # TODO: an explicit end-of-stream value may be needed here:
            #   ``None`` can collide with a legitimate ``decoder`` result,
            #   so callers have to consult ``is_active`` to tell them apart.
            return None

        pattern_mode = self._is_pattern
        if pattern_mode:
            # Pattern messages are queued as (destination channel, payload).
            dest_channel, payload = item
        else:
            dest_channel, payload = None, item

        if encoding is not None:
            payload = payload.decode(encoding)
        if decoder is not None:
            payload = decoder(payload)

        return (dest_channel, payload) if pattern_mode else payload

    async def get_json(self, encoding='utf-8'):
        """Shortcut for ``get()`` that parses the payload with ``json.loads``."""
        message = await self.get(encoding=encoding, decoder=json.loads)
        return message

    def iter(self, *, encoding=None, decoder=None):
        """Return an async iterator that yields results of ``get()``.

        Usage example:

        >>> async for msg in ch.iter():
        ...     print(msg)
        """
        def still_active(channel):
            return channel.is_active

        return _IterHelper(self,
                           is_active=still_active,
                           encoding=encoding,
                           decoder=decoder)

    async def wait_message(self):
        """Wait until a message is available or the channel is closed.

        Returns False if the channel is (or becomes) inactive,
        True when a message can be fetched.

        Possible usage:

        >>> while (await ch.wait_message()):
        ...     msg = await ch.get()
        """
        if not self.is_active:
            return False
        if self._queue.empty():
            await self._queue.wait()
            return self.is_active
        return True

    # internal methods

    def put_nowait(self, data):
        """Enqueue ``data`` without waiting."""
        self._queue.put(data)

    def close(self, exc=None):
        """Mark the channel as inactive.

        Internal method, will be called from connection
        on `unsubscribe` command.  ``exc`` is accepted but not used here.
        """
        if self._queue.closed:
            return
        self._queue.close()
class _IterHelper:
    """Async iterator that repeatedly awaits ``ch.get(*args, **kw)``.

    Iteration ends when ``is_active(ch)`` is falsy or when ``get()``
    returns ``None``.
    """

    __slots__ = ('_ch', '_is_active', '_args', '_kw')

    def __init__(self, ch, is_active, *args, **kw):
        self._ch, self._is_active = ch, is_active
        # Extra arguments are forwarded verbatim to every ``ch.get()`` call.
        self._args, self._kw = args, kw

    def __aiter__(self):
        return self

    async def __anext__(self):
        source = self._ch
        if self._is_active(source):
            item = await source.get(*self._args, **self._kw)
            if item is not None:
                return item
        # Either the source is no longer active or get() signalled the end.
        raise StopAsyncIteration    # noqa
class Receiver:
    """Multi-producers, single-consumer Pub/Sub queue.

    Can be used in cases where a single consumer task
    must read messages from several different channels
    (where pattern subscriptions may not work well
    or channels can be added/removed dynamically).

    Example use case:

    >>> from aioredis.pubsub import Receiver
    >>> from aioredis.abc import AbcChannel
    >>> mpsc = Receiver()
    >>> async def reader(mpsc):
    ...     async for channel, msg in mpsc.iter():
    ...         assert isinstance(channel, AbcChannel)
    ...         print("Got {!r} in channel {!r}".format(msg, channel))
    >>> asyncio.ensure_future(reader(mpsc))
    >>> await redis.subscribe(mpsc.channel('channel:1'),
    ...                       mpsc.channel('channel:3'),
    ...                       mpsc.channel('channel:5'))
    >>> await redis.psubscribe(mpsc.pattern('hello'))
    >>> # publishing 'Hello world' into 'hello-channel'
    >>> # will print this message:
    Got b'Hello world' in channel b'hello-channel'
    >>> # when all is done:
    >>> await redis.unsubscribe('channel:1', 'channel:3', 'channel:5')
    >>> await redis.punsubscribe('hello')
    >>> mpsc.stop()
    >>> # any message received after stop() will be ignored.
    """

    def __init__(self, loop=None, on_close=None):
        """Create an empty receiver.

        :param loop: deprecated and unused; passing it emits a
            ``DeprecationWarning``.
        :param on_close: optional callable invoked as
            ``on_close(sender, exc=exc)`` each time a sender is closed;
            defaults to :meth:`check_stop`.
        """
        assert on_close is None or callable(on_close), (
            "on_close must be None or callable", on_close)
        if loop is not None:
            warnings.warn("The loop argument is deprecated",
                          DeprecationWarning)
        if on_close is None:
            on_close = self.check_stop
        self._queue = ClosableQueue()
        # Maps (encoded name, is_pattern) -> _Sender.
        self._refs = {}
        self._on_close = on_close

    def __repr__(self):
        return ('<Receiver is_active:{}, senders:{}, qsize:{}>'
                .format(self.is_active, len(self._refs), self._queue.qsize()))

    def channel(self, name):
        """Create a channel.

        Returns ``_Sender`` object implementing
        :class:`~aioredis.abc.AbcChannel`.  The same object is returned
        for repeated calls with the same name while it stays registered.
        """
        enc_name = _converters[type(name)](name)
        key = (enc_name, False)
        if key not in self._refs:
            self._refs[key] = _Sender(self, enc_name, is_pattern=False)
        return self._refs[key]

    def pattern(self, pattern):
        """Create a pattern channel.

        Returns ``_Sender`` object implementing
        :class:`~aioredis.abc.AbcChannel`.  The same object is returned
        for repeated calls with the same pattern while it stays registered.
        """
        enc_pattern = _converters[type(pattern)](pattern)
        key = (enc_pattern, True)
        if key not in self._refs:
            self._refs[key] = _Sender(self, enc_pattern, is_pattern=True)
        return self._refs[key]

    @property
    def channels(self):
        """Read-only channels dict."""
        return types.MappingProxyType({
            ch.name: ch for ch in self._refs.values()
            if not ch.is_pattern})

    @property
    def patterns(self):
        """Read-only patterns dict."""
        return types.MappingProxyType({
            ch.name: ch for ch in self._refs.values()
            if ch.is_pattern})

    async def get(self, *, encoding=None, decoder=None):
        """Wait for and return pub/sub message from one of channels.

        Return value is either:

        * tuple of two elements: channel & message;

        * tuple of two elements for pattern subscriptions:
          pattern channel & ``(target channel, message)``;

        * or None in case Receiver is not active or has just been stopped.

        :param encoding: if given, the message is ``.decode()``-d with it.
        :param decoder: optional callable applied to the (possibly decoded)
            message.
        :raises aioredis.ChannelClosedError: If listener is stopped
            and all messages have been received.
        """
        # TODO: add note about raised exception and end marker.
        #   Flow before ClosableQueue:
        #   - ch.get() -> message
        #   - ch.close() -> ch.put(None)
        #   - ch.get() -> None
        #   - ch.get() -> ChannelClosedError
        #   Current flow:
        #   - ch.get() -> message
        #   - ch.close() -> ch._closed = True
        #   - ch.get() -> ChannelClosedError
        assert decoder is None or callable(decoder), decoder
        if self._queue.exhausted:
            raise ChannelClosedError()
        obj = await self._queue.get()
        if obj is EndOfStream:
            return
        # Items are queued by _put_nowait() as (sender, data).
        ch, msg = obj
        if ch.is_pattern:
            dest_ch, msg = msg
        if encoding is not None:
            msg = msg.decode(encoding)
        if decoder is not None:
            msg = decoder(msg)
        if ch.is_pattern:
            return ch, (dest_ch, msg)
        return ch, msg

    async def wait_message(self):
        """Block until a new message appears or the receiver is stopped.

        Returns True when a message can be fetched with :meth:`get`,
        False when the queue was closed with nothing left to read.
        """
        if not self._queue.empty():
            return True
        if self._queue.closed:
            return False
        await self._queue.wait()
        # wait() returns once a message is queued or the queue is closed.
        # Report whether a message is really pending instead of
        # ``is_active``: the latter is False as soon as the last sender is
        # gone, even if its final messages are still queued, which would
        # make ``while await wait_message()`` loops drop those messages.
        return not self._queue.empty()

    @property
    def is_active(self):
        """Returns True if listener has any active subscription."""
        if self._queue.exhausted:
            return False
        return any(ch.is_active for ch in self._refs.values())

    def stop(self):
        """Stop receiving messages.

        All new messages after this call will be ignored,
        so you must call unsubscribe before stopping this listener.
        """
        self._queue.close()
        # TODO: discard all senders as they might still be active.
        #   Channels storage in Connection should be refactored:
        #   if we drop _Senders here they will still be subscribed
        #   and will reside in memory although messages will be discarded.

    def iter(self, *, encoding=None, decoder=None):
        """Returns async iterator.

        Iteration continues until the queue is exhausted (closed and
        fully drained).

        Usage example:

        >>> async for ch, msg in mpsc.iter():
        ...     print(ch, msg)
        """
        return _IterHelper(self,
                           is_active=lambda r: not r._queue.exhausted,
                           encoding=encoding,
                           decoder=decoder)

    def check_stop(self, channel, exc=None):
        """Default ``on_close`` callback: stop once no senders remain.

        Called after a sender has been removed from the registry;
        ``channel`` and ``exc`` are accepted but not used here.
        """
        # NOTE: this is a fast-path implementation,
        # if overridden, implementation should use public API:
        #
        # if self.is_active and not (self.channels or self.patterns):
        if not self._refs:
            self.stop()

    # internal methods

    def _put_nowait(self, data, *, sender):
        """Queue ``(sender, data)``; log and drop it if already stopped."""
        if self._queue.closed:
            logger.warning("Pub/Sub listener message after stop:"
                           " sender: %r, data: %r",
                           sender, data)
            return
        self._queue.put((sender, data))

    def _close(self, sender, exc=None):
        """Unregister ``sender`` and invoke the ``on_close`` callback."""
        self._refs.pop((sender.name, sender.is_pattern))
        self._on_close(sender, exc=exc)
class _Sender(AbcChannel):
    """Write-only channel that forwards messages to its ``Receiver``.

    Direct ``.get()`` calls are rejected with ``RuntimeError``.
    """

    def __init__(self, receiver, name, is_pattern):
        self._receiver = receiver
        convert = _converters[type(name)]
        self._name = convert(name)
        self._is_pattern = is_pattern
        self._closed = False

    def __repr__(self):
        template = "<{} name:{!r}, is_pattern:{}, receiver:{!r}>"
        return template.format(
            self.__class__.__name__,
            self._name,
            self._is_pattern,
            self._receiver,
        )

    @property
    def name(self):
        """Channel name or pattern in its converted (encoded) form."""
        return self._name

    @property
    def is_pattern(self):
        """True if this sender represents a pattern subscription."""
        return self._is_pattern

    @property
    def is_active(self):
        """True until ``close()`` has been called."""
        return not self._closed

    async def get(self, *, encoding=None, decoder=None):
        """Always raises: messages must be read through the receiver."""
        raise RuntimeError("MPSC channel does not allow direct get() calls")

    def put_nowait(self, data):
        """Hand ``data`` over to the owning receiver, tagged with this sender."""
        self._receiver._put_nowait(data, sender=self)

    def close(self, exc=None):
        """Mark the sender closed and notify the receiver (first call only)."""
        # TODO: close() is exclusive so we can not share same _Sender
        # between different connections.
        # This needs to be fixed.
        if not self._closed:
            self._closed = True
            self._receiver._close(self, exc=exc)
class ClosableQueue:
    """Unbounded FIFO queue that can be closed to wake up its waiters."""

    def __init__(self):
        self._queue = collections.deque()
        # Set whenever an item is added or the queue is closed.
        self._event = asyncio.Event()
        self._closed = False

    async def wait(self):
        """Return True once the queue has an item or has been closed."""
        while True:
            if self._queue or self._closed:
                return True
            await self._event.wait()

    async def get(self):
        """Return the oldest item, or ``EndOfStream`` if closed and drained."""
        await self.wait()
        assert self._queue or self._closed, (
            "Unexpected queue state", self._queue, self._closed)
        pending = self._queue
        if self._closed and not pending:
            return EndOfStream
        head = pending.popleft()
        if not pending:
            # Nothing left: later waiters must block until the next put().
            self._event.clear()
        return head

    def put(self, item):
        """Append ``item`` and wake waiters; a no-op once closed."""
        if not self._closed:
            self._queue.append(item)
            self._event.set()

    def close(self):
        """Flag the queue as closed and wake every waiter."""
        self._closed = True
        self._event.set()

    @property
    def closed(self):
        """True after ``close()`` has been called."""
        return self._closed

    @property
    def exhausted(self):
        """True when the queue is closed and holds no more items."""
        return self._closed and not self._queue

    def empty(self):
        """True when no items are queued."""
        return not self._queue

    def qsize(self):
        """Number of queued items."""
        return len(self._queue)

    def __repr__(self):
        state = 'closed' if self._closed else 'open'
        return '<Queue {} size:{}>'.format(state, len(self._queue))