# Generateurv2/backend/env/lib/python3.10/site-packages/aioredis/connection.py
# 2022-06-24 17:14:37 +02:00
#
# 550 lines
# 20 KiB
# Python

import types
import asyncio
import socket
import warnings
import sys
from functools import partial
from collections import deque
from contextlib import contextmanager
from .util import (
encode_command,
wait_ok,
_NOTSET,
_set_result,
_set_exception,
coerced_keys_dict,
decode,
parse_url,
get_event_loop,
)
from .parser import Reader
from .stream import open_connection, open_unix_connection
from .errors import (
ConnectionClosedError,
ConnectionForcedCloseError,
RedisError,
ProtocolError,
ReplyError,
WatchVariableError,
ReadOnlyError,
MaxClientsError
)
from .pubsub import Channel
from .abc import AbcChannel
from .abc import AbcConnection
from .log import logger
# Public API of this module.
__all__ = ['create_connection', 'RedisConnection']

# Value passed as ``limit=`` to open_connection()/open_unix_connection()
# in create_connection().
MAX_CHUNK_SIZE = 64 * 1024

# Pub/Sub commands, in both str and bytes spelling.  RedisConnection.execute()
# redirects these to execute_pubsub(), which asserts membership here.
_PUBSUB_COMMANDS = tuple(
    spelling
    for _name in ('SUBSCRIBE', 'PSUBSCRIBE', 'UNSUBSCRIBE', 'PUNSUBSCRIBE')
    for spelling in (_name, _name.encode('ascii'))
)
async def create_connection(address, *, db=None, password=None, ssl=None,
                            encoding=None, parser=None, loop=None,
                            timeout=None, connection_cls=None):
    """Creates redis connection.

    Opens connection to Redis server specified by address argument.
    Address argument can be one of the following:
    * A tuple (or list) representing (host, port) pair for TCP connections;
    * A string representing either Redis URI or unix domain socket path.

    When a Redis URI is given, the db, password, encoding and timeout
    options parsed from it take precedence over the keyword arguments;
    an ssl option in the URI must not conflict with the ``ssl`` argument.

    SSL argument is passed through to the stream opener
    (open_connection / open_unix_connection).
    By default SSL/TLS is not used.

    By default no timeout is applied at the connection stage; however
    you can limit the time spent trying to open a connection via
    the `timeout` keyword argument (None or a number greater than 0).

    Encoding argument can be used to decode byte-replies to strings.
    By default no decoding is done.

    Parser parameter can be used to pass custom Redis protocol parser class.
    By default ``aioredis.parser.Reader`` is used by RedisConnection
    (presumably hiredis.Reader when it is available -- see that module).

    Return value is RedisConnection instance or a connection_cls if it is
    given.

    This function is a coroutine.

    Raises:
      ValueError: if ``timeout`` is not None and not greater than 0.
      asyncio.TimeoutError: if opening the connection exceeds ``timeout``.
      Any error from AUTH/SELECT; the connection is closed first.
    """
    assert isinstance(address, (tuple, list, str)), "tuple or str expected"
    if isinstance(address, str):
        address, options = parse_url(address)
        logger.debug("Parsed Redis URI %r", address)
        db = options.setdefault('db', db)
        password = options.setdefault('password', password)
        encoding = options.setdefault('encoding', encoding)
        timeout = options.setdefault('timeout', timeout)
        if 'ssl' in options:
            assert options['ssl'] or (not options['ssl'] and not ssl), (
                "Conflicting ssl options are set", options['ssl'], ssl)
            ssl = ssl or options['ssl']

    if timeout is not None and timeout <= 0:
        raise ValueError("Timeout has to be None or a number greater than 0")

    if connection_cls:
        assert issubclass(connection_cls, AbcConnection),\
            "connection_class does not meet the AbcConnection contract"
        cls = connection_cls
    else:
        cls = RedisConnection

    if loop is not None and sys.version_info >= (3, 8, 0):
        warnings.warn("The loop argument is deprecated",
                      DeprecationWarning)

    is_tcp = isinstance(address, (list, tuple))
    if is_tcp:
        host, port = address
        logger.debug("Creating tcp connection to %r", address)
        reader, writer = await asyncio.wait_for(open_connection(
            host, port, limit=MAX_CHUNK_SIZE, ssl=ssl),
            timeout)
    else:
        logger.debug("Creating unix connection to %r", address)
        reader, writer = await asyncio.wait_for(open_unix_connection(
            address, ssl=ssl, limit=MAX_CHUNK_SIZE),
            timeout)

    # From here on the stream is open: any failure before a connection
    # object owns it must close the transport, or it would leak.
    try:
        sock = writer.transport.get_extra_info('socket')
        if sock is not None:
            if is_tcp:
                # Disable Nagle's algorithm so small commands go out at once.
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            address = sock.getpeername()
        if is_tcp:
            # Normalize to a (host, port) tuple; e.g. IPv6 peers report a
            # longer address tuple.
            address = tuple(address[:2])
        conn = cls(reader, writer, encoding=encoding,
                   address=address, parser=parser)
    except Exception:
        writer.transport.close()
        raise

    try:
        if password is not None:
            await conn.auth(password)
        if db is not None:
            await conn.select(db)
    except Exception:
        conn.close()
        await conn.wait_closed()
        raise
    except asyncio.CancelledError:
        # Not an Exception subclass on Python >= 3.8; still release the
        # connection, but don't await anything while being cancelled.
        conn.close()
        raise
    return conn
class RedisConnection(AbcConnection):
    """Redis connection.

    Owns a reader/writer stream pair.  Commands are written by
    :meth:`execute` / :meth:`execute_pubsub`, each queuing a
    ``(future, encoding, callback)`` waiter.  A background task
    (:meth:`_read_data`) reads parsed replies and resolves those waiters in
    FIFO order, or routes them to Pub/Sub channels while the connection is
    in Pub/Sub mode.
    """

    def __init__(self, reader, writer, *, address, encoding=None,
                 parser=None, loop=None):
        if loop is not None and sys.version_info >= (3, 8):
            warnings.warn("The loop argument is deprecated",
                          DeprecationWarning)
        if parser is None:
            parser = Reader
        assert callable(parser), (
            "Parser argument is not callable", parser)
        self._reader = reader
        self._writer = writer
        self._address = address
        # FIFO of (future, encoding, callback), one per command awaiting
        # its reply.
        self._waiters = deque()
        self._reader.set_parser(
            parser(protocolError=ProtocolError, replyError=ReplyError)
        )
        self._reader_task = asyncio.ensure_future(self._read_data())
        self._close_msg = None
        self._db = 0
        self._closing = False
        self._closed = False
        # Set when the reader task finishes; awaited by wait_closed().
        self._close_state = asyncio.Event()
        self._reader_task.add_done_callback(lambda x: self._close_state.set())
        # deque of (encoding, callback) recorded for replies received while
        # inside MULTI (used to post-process EXEC results), or None.
        self._in_transaction = None
        self._transaction_error = None  # XXX: never used?
        # Number of active subscriptions; non-zero means Pub/Sub mode.
        self._in_pubsub = 0
        self._pubsub_channels = coerced_keys_dict()
        self._pubsub_patterns = coerced_keys_dict()
        self._encoding = encoding
        # bytearray collecting encoded commands inside _buffered(), else None.
        self._pipeline_buffer = None

    def __repr__(self):
        return '<RedisConnection [db:{}]>'.format(self._db)

    async def _read_data(self):
        """Response reader task.

        Reads parsed replies until EOF or a fatal error and dispatches each
        to the Pub/Sub handler or to the oldest pending waiter.  When the
        loop stops, the connection is scheduled to close with the error
        that stopped it.
        """
        last_error = ConnectionClosedError(
            "Connection has been closed by server")
        while not self._reader.at_eof():
            try:
                obj = await self._reader.readobj()
            except asyncio.CancelledError:
                # NOTE: reader can get cancelled from `close()` method only.
                last_error = RuntimeError('this is unexpected')
                break
            except ProtocolError as exc:
                # ProtocolError is fatal
                # so connection must be closed
                if self._in_transaction is not None:
                    self._transaction_error = exc
                last_error = exc
                break
            except Exception as exc:
                # NOTE: for QUIT command connection error can be received
                # before response
                last_error = exc
                break

            if (obj == b'' or obj is None) and self._reader.at_eof():
                logger.debug("Connection has been closed by server,"
                             " response: %r", obj)
                last_error = ConnectionClosedError("Reader at end of file")
                break
            if isinstance(obj, MaxClientsError):
                last_error = obj
                break
            try:
                if self._in_pubsub:
                    self._process_pubsub(obj)
                else:
                    self._process_data(obj)
            except Exception as exc:
                # An unexpected reply (e.g. one with no pending waiter) would
                # otherwise kill this task without ever closing the
                # connection, leaving callers' futures pending forever.
                logger.exception("Failed to process reply %r", obj)
                last_error = exc
                break
        self._closing = True
        get_event_loop().call_soon(self._do_close, last_error)

    def _process_data(self, obj):
        """Resolve the oldest pending waiter with reply ``obj``.

        Error replies become the waiter's exception (``READONLY`` replies
        are converted to ReadOnlyError).  Other replies are decoded with the
        waiter's encoding and passed through its callback; a failure in
        either step becomes the waiter's exception instead.
        """
        assert len(self._waiters) > 0, (type(obj), obj)
        waiter, encoding, cb = self._waiters.popleft()
        if isinstance(obj, RedisError):
            if isinstance(obj, ReplyError):
                if obj.args[0].startswith('READONLY'):
                    obj = ReadOnlyError(obj.args[0])
            _set_exception(waiter, obj)
            if self._in_transaction is not None:
                self._transaction_error = obj
        else:
            if encoding is not None:
                try:
                    obj = decode(obj, encoding)
                except Exception as exc:
                    _set_exception(waiter, exc)
                    return
            if cb is not None:
                try:
                    obj = cb(obj)
                except Exception as exc:
                    _set_exception(waiter, exc)
                    return
            _set_result(waiter, obj)
            if self._in_transaction is not None:
                # Remember how to post-process this command's EXEC result.
                self._in_transaction.append((encoding, cb))

    def _process_pubsub(self, obj, *, process_waiters=True):
        """Handle a reply received while in Pub/Sub mode.

        ``obj`` is unpacked as ``(kind, *args, data)``.  (Un)subscribe
        confirmations resolve a pending waiter (when ``process_waiters`` is
        true and one exists), close the channel/pattern on unsubscribe and
        store ``data`` as the subscription count.  Messages are put into the
        matching channel/pattern queue.
        """
        kind, *args, data = obj
        if kind in (b'subscribe', b'unsubscribe'):
            chan, = args
            if process_waiters and self._in_pubsub and self._waiters:
                self._process_data(obj)
            if kind == b'unsubscribe':
                ch = self._pubsub_channels.pop(chan, None)
                if ch:
                    ch.close()
            self._in_pubsub = data
        elif kind in (b'psubscribe', b'punsubscribe'):
            chan, = args
            if process_waiters and self._in_pubsub and self._waiters:
                self._process_data(obj)
            if kind == b'punsubscribe':
                ch = self._pubsub_patterns.pop(chan, None)
                if ch:
                    ch.close()
            self._in_pubsub = data
        elif kind == b'message':
            chan, = args
            self._pubsub_channels[chan].put_nowait(data)
        elif kind == b'pmessage':
            pattern, chan = args
            self._pubsub_patterns[pattern].put_nowait((chan, data))
        elif kind == b'pong':
            if process_waiters and self._in_pubsub and self._waiters:
                self._process_data(data or b'PONG')
        else:
            logger.warning("Unknown pubsub message received %r", obj)

    @contextmanager
    def _buffered(self):
        """Collect commands issued inside the context; write them at once.

        Nested use is a pass-through; only the outermost context flushes.
        If the body raises, nothing is written and the waiters of the
        buffered (never sent) commands are cancelled, so later replies are
        not matched to them.
        """
        # XXX: we must ensure that no await happens
        # as long as we buffer commands.
        # Probably we can set some error-raising callback on enter
        # and remove it on exit
        # if some await happens in between -> throw an error.
        # This is creepy solution, 'cause some one might want to await
        # on some other source except redis.
        # So we must only raise error when someone tries to await
        # pending aioredis future
        # One of solutions is to return coroutine instead of a future
        # in `execute` method.
        # In a coroutine we can check if buffering is enabled and raise error.
        # TODO: describe in docs difference in pipeline mode for
        # conn.execute vs pipeline.execute()
        if self._pipeline_buffer is not None:
            yield self
            return
        self._pipeline_buffer = bytearray()
        # Last waiter queued before buffering started.  Waiters are FIFO, so
        # everything queued after it belongs to a buffered command (even if
        # older replies get consumed from the left meanwhile).
        marker = self._waiters[-1] if self._waiters else None
        try:
            yield self
        except BaseException:
            while self._waiters and self._waiters[-1] is not marker:
                fut, *_ = self._waiters.pop()
                fut.cancel()
            raise
        else:
            self._writer.write(self._pipeline_buffer)
        finally:
            self._pipeline_buffer = None

    def execute(self, command, *args, encoding=_NOTSET):
        """Executes redis command and returns Future waiting for the answer.

        Raises:
        * TypeError if any of args can not be encoded as bytes.
        * ReplyError on redis '-ERR' responses.
        * ProtocolError when response can not be decoded meaning connection
          is broken.
        * ConnectionClosedError when either client or server has closed the
          connection.
        """
        if self._reader is None or self._reader.at_eof():
            msg = self._close_msg or "Connection closed or corrupted"
            raise ConnectionClosedError(msg)
        if command is None:
            raise TypeError("command must not be None")
        if None in args:
            raise TypeError("args must not contain None")
        command = command.upper().strip()
        is_pubsub = command in _PUBSUB_COMMANDS
        is_ping = command in ('PING', b'PING')
        if self._in_pubsub and not (is_pubsub or is_ping):
            raise RedisError("Connection in SUBSCRIBE mode")
        elif is_pubsub:
            logger.warning("Deprecated. Use `execute_pubsub` method directly")
            return self.execute_pubsub(command, *args)

        # Reply callbacks that keep connection state (db, transaction) in
        # sync with the server.
        if command in ('SELECT', b'SELECT'):
            cb = partial(self._set_db, args=args)
        elif command in ('MULTI', b'MULTI'):
            cb = self._start_transaction
        elif command in ('EXEC', b'EXEC'):
            cb = partial(self._end_transaction, discard=False)
            # EXEC results are decoded per item by _end_transaction.
            encoding = None
        elif command in ('DISCARD', b'DISCARD'):
            cb = partial(self._end_transaction, discard=True)
        else:
            cb = None
        if encoding is _NOTSET:
            encoding = self._encoding

        fut = get_event_loop().create_future()
        # Encode fully before touching the (possibly shared) pipeline
        # buffer, so a failing encode cannot leave a partial command in it.
        cmd = encode_command(command, *args)
        if self._pipeline_buffer is None:
            self._writer.write(cmd)
        else:
            self._pipeline_buffer.extend(cmd)
        self._waiters.append((fut, encoding, cb))
        return fut

    def execute_pubsub(self, command, *channels):
        """Executes redis (p)subscribe/(p)unsubscribe commands.

        Returns asyncio.gather coroutine waiting for all channels/patterns
        to receive answers.
        """
        command = command.upper().strip()
        assert command in _PUBSUB_COMMANDS, (
            "Pub/Sub command expected", command)
        if self._reader is None or self._reader.at_eof():
            msg = self._close_msg or "Connection closed or corrupted"
            raise ConnectionClosedError(msg)
        if None in set(channels):
            raise TypeError("args must not contain None")
        if not len(channels):
            raise TypeError("No channels/patterns supplied")
        # PSUBSCRIBE (10 chars) and PUNSUBSCRIBE (12 chars) are the pattern
        # variants; SUBSCRIBE is 9 and UNSUBSCRIBE 11.
        is_pattern = len(command) in (10, 12)
        mkchannel = partial(Channel, is_pattern=is_pattern)
        channels = [ch if isinstance(ch, AbcChannel) else mkchannel(ch)
                    for ch in channels]
        if not all(ch.is_pattern == is_pattern for ch in channels):
            raise ValueError("Not all channels {} match command {}"
                             .format(channels, command))
        cmd = encode_command(command, *(ch.name for ch in channels))
        res = []
        # The server confirms each channel separately, so queue one waiter
        # per channel.
        for ch in channels:
            fut = get_event_loop().create_future()
            res.append(fut)
            cb = partial(self._update_pubsub, ch=ch)
            self._waiters.append((fut, None, cb))
        if self._pipeline_buffer is None:
            self._writer.write(cmd)
        else:
            self._pipeline_buffer.extend(cmd)
        return asyncio.gather(*res)

    def close(self):
        """Close connection."""
        self._do_close(ConnectionForcedCloseError())

    def _do_close(self, exc):
        """Tear the connection down once; later calls are no-ops.

        Closes the transport, cancels the reader task, fails all pending
        waiters with ``exc`` (ConnectionForcedCloseError when ``exc`` is
        None) and closes every Pub/Sub channel and pattern with ``exc``.
        """
        if self._closed:
            return
        self._closed = True
        self._closing = False
        self._writer.transport.close()
        self._reader_task.cancel()
        self._reader_task = None
        self._writer = None
        self._reader = None
        self._pipeline_buffer = None

        if exc is not None:
            self._close_msg = str(exc)

        while self._waiters:
            waiter, *spam = self._waiters.popleft()
            logger.debug("Cancelling waiter %r", (waiter, spam))
            if exc is None:
                _set_exception(waiter, ConnectionForcedCloseError())
            else:
                _set_exception(waiter, exc)
        while self._pubsub_channels:
            _, ch = self._pubsub_channels.popitem()
            logger.debug("Closing pubsub channel %r", ch)
            ch.close(exc)
        while self._pubsub_patterns:
            _, ch = self._pubsub_patterns.popitem()
            logger.debug("Closing pubsub pattern %r", ch)
            ch.close(exc)

    @property
    def closed(self):
        """True if connection is closed (or closing).

        If the reader has hit EOF but no close has started yet, this also
        schedules the close.
        """
        closed = self._closing or self._closed
        if not closed and self._reader and self._reader.at_eof():
            self._closing = closed = True
            get_event_loop().call_soon(self._do_close, None)
        return closed

    async def wait_closed(self):
        """Coroutine waiting until connection is closed."""
        await self._close_state.wait()

    @property
    def db(self):
        """Currently selected db index."""
        return self._db

    @property
    def encoding(self):
        """Current set codec or None."""
        return self._encoding

    @property
    def address(self):
        """Redis server address, either host-port tuple or str."""
        return self._address

    def select(self, db):
        """Change the selected database for the current connection.

        Returns ``wait_ok`` of the SELECT future; presumably it resolves to
        True on an OK reply (see util.wait_ok).
        """
        if not isinstance(db, int):
            raise TypeError("DB must be of int type, not {!r}".format(db))
        if db < 0:
            raise ValueError("DB must be greater or equal 0, got {!r}"
                             .format(db))
        fut = self.execute('SELECT', db)
        return wait_ok(fut)

    def _set_db(self, ok, args):
        """SELECT reply callback: record the newly selected db index."""
        assert ok in {b'OK', 'OK'}, ("Unexpected result of SELECT", ok)
        self._db = args[0]
        return ok

    def _start_transaction(self, ok):
        """MULTI reply callback: start recording per-command reply info."""
        assert self._in_transaction is None, (
            "Connection is already in transaction", self._in_transaction)
        self._in_transaction = deque()
        self._transaction_error = None
        return ok

    def _end_transaction(self, obj, discard):
        """EXEC/DISCARD reply callback.

        For DISCARD returns ``obj`` unchanged.  For EXEC, decodes each item
        of the result list and applies its recorded callback; per-item
        failures are returned in place of the value.  A None result (a
        WATCHed key changed) becomes a list of WatchVariableError.
        """
        assert self._in_transaction is not None, (
            "Connection is not in transaction", obj)
        self._transaction_error = None
        recall, self._in_transaction = self._in_transaction, None
        recall.popleft()  # ignore first (its _start_transaction)
        if discard:
            return obj
        assert isinstance(obj, list) or obj is None, (
            "Unexpected MULTI/EXEC result", obj, recall)
        # TODO: need to be able to re-try transaction
        if obj is None:
            err = WatchVariableError("WATCH variable has changed")
            obj = [err] * len(recall)
        assert len(obj) == len(recall), (
            "Wrong number of result items in mutli-exec", obj, recall)
        res = []
        for o, (encoding, cb) in zip(obj, recall):
            if not isinstance(o, RedisError):
                try:
                    if encoding:
                        o = decode(o, encoding)
                    if cb:
                        o = cb(o)
                except Exception as err:
                    res.append(err)
                    continue
            res.append(o)
        return res

    def _update_pubsub(self, obj, *, ch):
        """(P)(un)subscribe reply callback for channel ``ch``.

        Records the subscription count and registers ``ch`` for a new
        (pattern) subscription.
        """
        kind, *pattern, channel, subscriptions = obj
        self._in_pubsub, was_in_pubsub = subscriptions, self._in_pubsub
        # XXX: the channels/patterns storage should be refactored.
        # if code which supposed to read from channel/pattern
        # failed (exception in reader or else) than
        # the channel object will still reside in memory
        # and leak memory (messages will be put in queue).
        if kind == b'subscribe' and channel not in self._pubsub_channels:
            self._pubsub_channels[channel] = ch
        elif kind == b'psubscribe' and channel not in self._pubsub_patterns:
            self._pubsub_patterns[channel] = ch
        if not was_in_pubsub:
            # This reply arrived through _process_data (not yet in Pub/Sub
            # mode), so run the Pub/Sub bookkeeping here without
            # dispatching to waiters again.
            self._process_pubsub(obj, process_waiters=False)
        return obj

    @property
    def in_transaction(self):
        """Set to True when MULTI command was issued."""
        return self._in_transaction is not None

    @property
    def in_pubsub(self):
        """Indicates that connection is in PUB/SUB mode.

        Provides the number of subscribed channels.
        """
        return self._in_pubsub

    @property
    def pubsub_channels(self):
        """Returns read-only channels dict."""
        return types.MappingProxyType(self._pubsub_channels)

    @property
    def pubsub_patterns(self):
        """Returns read-only patterns dict."""
        return types.MappingProxyType(self._pubsub_patterns)

    def auth(self, password):
        """Authenticate to server."""
        fut = self.execute('AUTH', password)
        return wait_ok(fut)