550 lines
20 KiB
Python
550 lines
20 KiB
Python
import types
|
|
import asyncio
|
|
import socket
|
|
import warnings
|
|
import sys
|
|
|
|
from functools import partial
|
|
from collections import deque
|
|
from contextlib import contextmanager
|
|
|
|
from .util import (
|
|
encode_command,
|
|
wait_ok,
|
|
_NOTSET,
|
|
_set_result,
|
|
_set_exception,
|
|
coerced_keys_dict,
|
|
decode,
|
|
parse_url,
|
|
get_event_loop,
|
|
)
|
|
from .parser import Reader
|
|
from .stream import open_connection, open_unix_connection
|
|
from .errors import (
|
|
ConnectionClosedError,
|
|
ConnectionForcedCloseError,
|
|
RedisError,
|
|
ProtocolError,
|
|
ReplyError,
|
|
WatchVariableError,
|
|
ReadOnlyError,
|
|
MaxClientsError
|
|
)
|
|
from .pubsub import Channel
|
|
from .abc import AbcChannel
|
|
from .abc import AbcConnection
|
|
from .log import logger
|
|
|
|
|
|
# Names exported by ``from <module> import *``.
__all__ = [
    'create_connection',
    'RedisConnection',
]
|
|
|
|
# Passed as ``limit`` to the stream openers in ``create_connection``
# (64 KiB).
MAX_CHUNK_SIZE = 64 * 1024

# Pub/Sub commands, each accepted both as ``str`` and as ``bytes``:
# ('SUBSCRIBE', b'SUBSCRIBE', 'PSUBSCRIBE', b'PSUBSCRIBE', ...).
_PUBSUB_COMMANDS = tuple(
    variant
    for name in ('SUBSCRIBE', 'PSUBSCRIBE', 'UNSUBSCRIBE', 'PUNSUBSCRIBE')
    for variant in (name, name.encode('ascii'))
)
|
|
|
|
|
|
async def create_connection(address, *, db=None, password=None, ssl=None,
                            encoding=None, parser=None, loop=None,
                            timeout=None, connection_cls=None):
    """Creates redis connection.

    Opens connection to Redis server specified by address argument.
    Address argument can be one of the following:
    * A tuple representing (host, port) pair for TCP connections;
    * A string representing either Redis URI or unix domain socket path.

    When a URI is given, its ``db``, ``password``, ``encoding`` and
    ``timeout`` options take precedence over the corresponding keyword
    arguments; an ``ssl`` option in the URI must not conflict with the
    ``ssl`` argument.

    SSL argument is passed through to ``open_connection`` /
    ``open_unix_connection``. By default SSL/TLS is not used.

    By default no timeout is applied at the connection stage; you can
    limit the time spent opening a connection via the `timeout` keyword
    argument. It must be ``None`` or a number greater than 0, otherwise
    ``ValueError`` is raised.

    Encoding argument can be used to decode byte-replies to strings.
    By default no decoding is done.

    Parser parameter can be used to pass custom Redis protocol parser class.
    By default hiredis.Reader is used (unless it is missing or platform
    is not CPython).

    If ``password`` / ``db`` are given, AUTH / SELECT are issued right after
    connecting. If either fails (or the coroutine is cancelled meanwhile),
    the connection is closed before the error is re-raised.

    Return value is RedisConnection instance or a connection_cls if it is
    given.

    This function is a coroutine.
    """
    assert isinstance(address, (tuple, list, str)), "tuple or str expected"
    if isinstance(address, str):
        address, options = parse_url(address)
        logger.debug("Parsed Redis URI %r", address)
        # setdefault: a value present in the URI wins over the kwarg.
        db = options.setdefault('db', db)
        password = options.setdefault('password', password)
        encoding = options.setdefault('encoding', encoding)
        timeout = options.setdefault('timeout', timeout)
        if 'ssl' in options:
            assert options['ssl'] or (not options['ssl'] and not ssl), (
                "Conflicting ssl options are set", options['ssl'], ssl)
            ssl = ssl or options['ssl']

    if timeout is not None and timeout <= 0:
        raise ValueError("Timeout has to be None or a number greater than 0")

    if connection_cls:
        assert issubclass(connection_cls, AbcConnection),\
            "connection_class does not meet the AbcConnection contract"
        cls = connection_cls
    else:
        cls = RedisConnection

    if loop is not None and sys.version_info >= (3, 8, 0):
        warnings.warn("The loop argument is deprecated",
                      DeprecationWarning)

    if isinstance(address, (list, tuple)):
        host, port = address
        logger.debug("Creating tcp connection to %r", address)
        reader, writer = await asyncio.wait_for(open_connection(
            host, port, limit=MAX_CHUNK_SIZE, ssl=ssl),
            timeout)
        sock = writer.transport.get_extra_info('socket')
        if sock is not None:
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            address = sock.getpeername()
        # getpeername() returns a 4-tuple for IPv6; keep only (host, port).
        address = tuple(address[:2])
    else:
        logger.debug("Creating unix connection to %r", address)
        reader, writer = await asyncio.wait_for(open_unix_connection(
            address, ssl=ssl, limit=MAX_CHUNK_SIZE),
            timeout)
        sock = writer.transport.get_extra_info('socket')
        if sock is not None:
            address = sock.getpeername()

    conn = cls(reader, writer, encoding=encoding,
               address=address, parser=parser)

    try:
        if password is not None:
            await conn.auth(password)
        if db is not None:
            await conn.select(db)
    except (Exception, asyncio.CancelledError):
        # CancelledError is a BaseException on Python >= 3.8; catch it
        # explicitly so a cancelled AUTH/SELECT does not leak the socket
        # and the connection's reader task.
        conn.close()
        await conn.wait_closed()
        raise
    return conn
|
|
|
|
|
|
class RedisConnection(AbcConnection):
    """Redis connection.

    Wraps a stream ``reader``/``writer`` pair. Every command is written
    to the writer (or appended to a pipeline buffer, see ``_buffered``)
    and a ``(future, encoding, callback)`` waiter is queued for its
    reply. A background task (``_read_data``) parses replies and resolves
    the waiters in FIFO order.
    """

    def __init__(self, reader, writer, *, address, encoding=None,
                 parser=None, loop=None):
        if loop is not None and sys.version_info >= (3, 8):
            warnings.warn("The loop argument is deprecated",
                          DeprecationWarning)
        if parser is None:
            parser = Reader
        assert callable(parser), (
            "Parser argument is not callable", parser)
        self._reader = reader
        self._writer = writer
        self._address = address
        # FIFO of (future, encoding, callback), one per expected reply.
        self._waiters = deque()
        self._reader.set_parser(
            parser(protocolError=ProtocolError, replyError=ReplyError)
        )
        self._reader_task = asyncio.ensure_future(self._read_data())
        self._close_msg = None
        self._db = 0
        self._closing = False
        self._closed = False
        # Set once the reader task finishes; awaited by wait_closed().
        self._close_state = asyncio.Event()
        self._reader_task.add_done_callback(lambda x: self._close_state.set())
        # None outside MULTI; inside MULTI a deque of (encoding, callback)
        # for every successful reply since (and including) MULTI itself.
        self._in_transaction = None
        self._transaction_error = None  # XXX: never used?
        # Subscription count reported by the server in (p)subscribe
        # replies; non-zero means the connection is in Pub/Sub mode.
        self._in_pubsub = 0
        self._pubsub_channels = coerced_keys_dict()
        self._pubsub_patterns = coerced_keys_dict()
        self._encoding = encoding
        # bytearray while commands are being buffered, None otherwise.
        self._pipeline_buffer = None

    def __repr__(self):
        return '<RedisConnection [db:{}]>'.format(self._db)

    async def _read_data(self):
        """Response reader task.

        Reads parsed replies until EOF or a fatal error and dispatches
        them to ``_process_pubsub`` / ``_process_data``. On exit it marks
        the connection as closing and schedules ``_do_close`` with the
        error that stopped the loop.
        """
        last_error = ConnectionClosedError(
            "Connection has been closed by server")
        while not self._reader.at_eof():
            try:
                obj = await self._reader.readobj()
            except asyncio.CancelledError:
                # NOTE: reader can get cancelled from `close()` method only.
                last_error = RuntimeError('this is unexpected')
                break
            except ProtocolError as exc:
                # ProtocolError is fatal
                # so connection must be closed
                if self._in_transaction is not None:
                    self._transaction_error = exc
                last_error = exc
                break
            except Exception as exc:
                # NOTE: for QUIT command connection error can be received
                #       before response
                last_error = exc
                break
            else:
                if (obj == b'' or obj is None) and self._reader.at_eof():
                    logger.debug("Connection has been closed by server,"
                                 " response: %r", obj)
                    last_error = ConnectionClosedError("Reader at end of file")
                    break

                if isinstance(obj, MaxClientsError):
                    last_error = obj
                    break
                if self._in_pubsub:
                    self._process_pubsub(obj)
                else:
                    self._process_data(obj)
        self._closing = True
        get_event_loop().call_soon(self._do_close, last_error)

    def _process_data(self, obj):
        """Processes command results.

        Pops the oldest waiter and either fails it (error reply) or
        resolves it with ``obj`` after optional decoding and the waiter's
        callback. Errors raised by decoding or the callback fail the
        waiter instead of propagating.
        """
        assert len(self._waiters) > 0, (type(obj), obj)
        waiter, encoding, cb = self._waiters.popleft()
        if isinstance(obj, RedisError):
            if isinstance(obj, ReplyError):
                if obj.args[0].startswith('READONLY'):
                    obj = ReadOnlyError(obj.args[0])
            _set_exception(waiter, obj)
            if self._in_transaction is not None:
                ends_transaction = (isinstance(cb, partial) and
                                    cb.func == self._end_transaction)
                if ends_transaction:
                    # EXEC/DISCARD got an error reply (e.g. EXECABORT):
                    # the server has left MULTI and _end_transaction will
                    # never run, so reset the client-side state here.
                    # Otherwise the next MULTI would fail its assertion.
                    self._in_transaction = None
                    self._transaction_error = None
                else:
                    self._transaction_error = obj
        else:
            if encoding is not None:
                try:
                    obj = decode(obj, encoding)
                except Exception as exc:
                    _set_exception(waiter, exc)
                    return
            if cb is not None:
                try:
                    obj = cb(obj)
                except Exception as exc:
                    _set_exception(waiter, exc)
                    return
            _set_result(waiter, obj)
            if self._in_transaction is not None:
                # Remember how to post-process this command's slot in the
                # EXEC result list (see _end_transaction).
                self._in_transaction.append((encoding, cb))

    def _process_pubsub(self, obj, *, process_waiters=True):
        """Processes pubsub messages.

        ``obj`` is a reply list whose first item is the message kind.
        Subscription replies update the subscription count and, when
        ``process_waiters`` is true, resolve the pending waiter. Message
        payloads are pushed into the matching channel/pattern object.
        """
        kind, *args, data = obj
        if kind in (b'subscribe', b'unsubscribe'):
            chan, = args
            if process_waiters and self._in_pubsub and self._waiters:
                self._process_data(obj)
            if kind == b'unsubscribe':
                ch = self._pubsub_channels.pop(chan, None)
                if ch:
                    ch.close()
            self._in_pubsub = data
        elif kind in (b'psubscribe', b'punsubscribe'):
            chan, = args
            if process_waiters and self._in_pubsub and self._waiters:
                self._process_data(obj)
            if kind == b'punsubscribe':
                ch = self._pubsub_patterns.pop(chan, None)
                if ch:
                    ch.close()
            self._in_pubsub = data
        elif kind == b'message':
            chan, = args
            self._pubsub_channels[chan].put_nowait(data)
        elif kind == b'pmessage':
            pattern, chan = args
            self._pubsub_patterns[pattern].put_nowait((chan, data))
        elif kind == b'pong':
            if process_waiters and self._in_pubsub and self._waiters:
                self._process_data(data or b'PONG')
        else:
            logger.warning("Unknown pubsub message received %r", obj)

    @contextmanager
    def _buffered(self):
        """Buffer commands issued inside the ``with`` block.

        On normal exit the whole buffer is sent with a single write.
        Nested use is a no-op: only the outermost block owns the buffer.
        If the block raises, nothing is written and the waiters queued
        inside the block are cancelled. Otherwise those waiters would be
        matched against replies to later, unrelated commands.
        """
        # XXX: we must ensure that no await happens
        #      as long as we buffer commands.
        #   Probably we can set some error-raising callback on enter
        #   and remove it on exit
        #   if some await happens in between -> throw an error.
        #   This is creepy solution, 'cause some one might want to await
        #   on some other source except redis.
        #   So we must only raise error we someone tries to await
        #   pending aioredis future
        # One of solutions is to return coroutine instead of a future
        # in `execute` method.
        # In a coroutine we can check if buffering is enabled and raise error.

        # TODO: describe in docs difference in pipeline mode for
        #   conn.execute vs pipeline.execute()
        if self._pipeline_buffer is not None:
            # Already buffering: the outermost block flushes.
            yield self
            return
        self._pipeline_buffer = bytearray()
        # Assuming no await inside the block (see XXX above), the reader
        # cannot pop waiters meanwhile, so every waiter past this index
        # was queued by this block.
        queued_before = len(self._waiters)
        try:
            yield self
        except BaseException:
            # The buffered commands are never sent: drop their waiters so
            # replies stay aligned with the commands actually written.
            while len(self._waiters) > queued_before:
                waiter, *_ = self._waiters.pop()
                waiter.cancel()
            raise
        else:
            self._writer.write(self._pipeline_buffer)
        finally:
            self._pipeline_buffer = None

    def execute(self, command, *args, encoding=_NOTSET):
        """Executes redis command and returns Future waiting for the answer.

        ``encoding`` defaults to the connection's encoding. It is forced
        to None for EXEC, whose items are decoded per queued command.
        Pub/Sub commands are delegated (deprecated) to ``execute_pubsub``.

        Raises:
        * TypeError if command is None, any of args is None, or any of
          args can not be encoded as bytes.
        * RedisError if the connection is in SUBSCRIBE mode and the
          command is neither a Pub/Sub command nor PING.
        * ReplyError on redis '-ERR' responses.
        * ProtocolError when response can not be decoded meaning connection
          is broken.
        * ConnectionClosedError when either client or server has closed the
          connection.
        """
        if self._reader is None or self._reader.at_eof():
            msg = self._close_msg or "Connection closed or corrupted"
            raise ConnectionClosedError(msg)
        if command is None:
            raise TypeError("command must not be None")
        if None in args:
            raise TypeError("args must not contain None")
        command = command.upper().strip()
        is_pubsub = command in _PUBSUB_COMMANDS
        is_ping = command in ('PING', b'PING')
        if self._in_pubsub and not (is_pubsub or is_ping):
            raise RedisError("Connection in SUBSCRIBE mode")
        elif is_pubsub:
            logger.warning("Deprecated. Use `execute_pubsub` method directly")
            return self.execute_pubsub(command, *args)

        # Some commands change connection state once their reply arrives.
        if command in ('SELECT', b'SELECT'):
            cb = partial(self._set_db, args=args)
        elif command in ('MULTI', b'MULTI'):
            cb = self._start_transaction
        elif command in ('EXEC', b'EXEC'):
            cb = partial(self._end_transaction, discard=False)
            encoding = None
        elif command in ('DISCARD', b'DISCARD'):
            cb = partial(self._end_transaction, discard=True)
        else:
            cb = None
        if encoding is _NOTSET:
            encoding = self._encoding
        fut = get_event_loop().create_future()
        if self._pipeline_buffer is None:
            self._writer.write(encode_command(command, *args))
        else:
            encode_command(command, *args, buf=self._pipeline_buffer)
        self._waiters.append((fut, encoding, cb))
        return fut

    def execute_pubsub(self, command, *channels):
        """Executes redis (p)subscribe/(p)unsubscribe commands.

        ``channels`` may be names or ``AbcChannel`` instances; names are
        wrapped into ``Channel`` objects. One waiter is queued per
        channel/pattern.

        Returns an ``asyncio.gather`` future waiting for all
        channels/patterns to receive answers.

        Raises:
        * ConnectionClosedError if the connection is closed.
        * TypeError if no channels are given or any of them is None.
        * ValueError if a channel's pattern flag does not match the command.
        """
        command = command.upper().strip()
        assert command in _PUBSUB_COMMANDS, (
            "Pub/Sub command expected", command)
        if self._reader is None or self._reader.at_eof():
            # Same message source as execute(), for consistency.
            msg = self._close_msg or "Connection closed or corrupted"
            raise ConnectionClosedError(msg)
        if None in channels:
            raise TypeError("args must not contain None")
        if not channels:
            raise TypeError("No channels/patterns supplied")
        # PSUBSCRIBE is 10 chars and PUNSUBSCRIBE 12 (str or bytes);
        # SUBSCRIBE/UNSUBSCRIBE are 9/11.
        is_pattern = len(command) in (10, 12)
        mkchannel = partial(Channel, is_pattern=is_pattern)
        channels = [ch if isinstance(ch, AbcChannel) else mkchannel(ch)
                    for ch in channels]
        if not all(ch.is_pattern == is_pattern for ch in channels):
            raise ValueError("Not all channels {} match command {}"
                             .format(channels, command))
        cmd = encode_command(command, *(ch.name for ch in channels))
        res = []
        for ch in channels:
            fut = get_event_loop().create_future()
            res.append(fut)
            cb = partial(self._update_pubsub, ch=ch)
            self._waiters.append((fut, None, cb))
        if self._pipeline_buffer is None:
            self._writer.write(cmd)
        else:
            self._pipeline_buffer.extend(cmd)
        return asyncio.gather(*res)

    def close(self):
        """Close connection."""
        self._do_close(ConnectionForcedCloseError())

    def _do_close(self, exc):
        """Tear down the connection once.

        Closes the transport, cancels the reader task, fails all pending
        waiters with ``exc`` (or ConnectionForcedCloseError when ``exc``
        is None) and closes all Pub/Sub channels and patterns.
        """
        if self._closed:
            return
        self._closed = True
        self._closing = False
        self._writer.transport.close()
        self._reader_task.cancel()
        self._reader_task = None
        self._writer = None
        self._reader = None
        self._pipeline_buffer = None

        if exc is not None:
            self._close_msg = str(exc)

        while self._waiters:
            waiter, *spam = self._waiters.popleft()
            logger.debug("Cancelling waiter %r", (waiter, spam))
            if exc is None:
                _set_exception(waiter, ConnectionForcedCloseError())
            else:
                _set_exception(waiter, exc)
        while self._pubsub_channels:
            _, ch = self._pubsub_channels.popitem()
            logger.debug("Closing pubsub channel %r", ch)
            ch.close(exc)
        while self._pubsub_patterns:
            _, ch = self._pubsub_patterns.popitem()
            logger.debug("Closing pubsub pattern %r", ch)
            ch.close(exc)

    @property
    def closed(self):
        """True if connection is closed (or closing).

        If the reader has hit EOF but no close is in progress yet, this
        marks the connection as closing and schedules ``_do_close``.
        """
        closed = self._closing or self._closed
        if not closed and self._reader and self._reader.at_eof():
            self._closing = closed = True
            get_event_loop().call_soon(self._do_close, None)
        return closed

    async def wait_closed(self):
        """Coroutine waiting until connection is closed (reader task done)."""
        await self._close_state.wait()

    @property
    def db(self):
        """Currently selected db index."""
        return self._db

    @property
    def encoding(self):
        """Current set codec or None."""
        return self._encoding

    @property
    def address(self):
        """Redis server address, either host-port tuple or str."""
        return self._address

    def select(self, db):
        """Change the selected database for the current connection.

        Raises TypeError if ``db`` is not an int and ValueError if it is
        negative. Returns the awaitable produced by ``wait_ok``.
        """
        if not isinstance(db, int):
            raise TypeError("DB must be of int type, not {!r}".format(db))
        if db < 0:
            raise ValueError("DB must be greater or equal 0, got {!r}"
                             .format(db))
        fut = self.execute('SELECT', db)
        return wait_ok(fut)

    def _set_db(self, ok, args):
        """SELECT reply callback: record the newly selected db index."""
        assert ok in {b'OK', 'OK'}, ("Unexpected result of SELECT", ok)
        self._db = args[0]
        return ok

    def _start_transaction(self, ok):
        """MULTI reply callback: enter transaction mode."""
        assert self._in_transaction is None, (
            "Connection is already in transaction", self._in_transaction)
        self._in_transaction = deque()
        self._transaction_error = None
        return ok

    def _end_transaction(self, obj, discard):
        """EXEC/DISCARD reply callback: leave transaction mode.

        For DISCARD the reply is returned unchanged. For EXEC, each item
        of the result list is decoded and post-processed with the
        encoding/callback recorded when its command was queued. A nil
        EXEC reply is turned into a list of WatchVariableError. Errors
        are placed in the list instead of being raised.
        """
        assert self._in_transaction is not None, (
            "Connection is not in transaction", obj)
        self._transaction_error = None
        recall, self._in_transaction = self._in_transaction, None
        recall.popleft()  # ignore first (its _start_transaction)
        if discard:
            return obj
        assert isinstance(obj, list) or obj is None, (
            "Unexpected MULTI/EXEC result", obj, recall)
        # TODO: need to be able to re-try transaction
        if obj is None:
            err = WatchVariableError("WATCH variable has changed")
            obj = [err] * len(recall)
        assert len(obj) == len(recall), (
            "Wrong number of result items in mutli-exec", obj, recall)
        res = []
        for o, (encoding, cb) in zip(obj, recall):
            if not isinstance(o, RedisError):
                try:
                    if encoding:
                        o = decode(o, encoding)
                    if cb:
                        o = cb(o)
                except Exception as err:
                    res.append(err)
                    continue
            res.append(o)
        return res

    def _update_pubsub(self, obj, *, ch):
        """(P)SUBSCRIBE/(P)UNSUBSCRIBE reply callback.

        Updates the subscription count and registers ``ch`` on first
        subscription. If the connection was not yet in Pub/Sub mode, the
        reply is also fed to ``_process_pubsub`` (without resolving
        waiters again).
        """
        kind, *pattern, channel, subscriptions = obj
        self._in_pubsub, was_in_pubsub = subscriptions, self._in_pubsub
        # XXX: the channels/patterns storage should be refactored.
        #   if code which supposed to read from channel/pattern
        #   failed (exception in reader or else) than
        #   the channel object will still reside in memory
        #   and leak memory (messages will be put in queue).
        if kind == b'subscribe' and channel not in self._pubsub_channels:
            self._pubsub_channels[channel] = ch
        elif kind == b'psubscribe' and channel not in self._pubsub_patterns:
            self._pubsub_patterns[channel] = ch
        if not was_in_pubsub:
            self._process_pubsub(obj, process_waiters=False)
        return obj

    @property
    def in_transaction(self):
        """Set to True when MULTI command was issued."""
        return self._in_transaction is not None

    @property
    def in_pubsub(self):
        """Indicates that connection is in PUB/SUB mode.

        Provides the number of subscribed channels.
        """
        return self._in_pubsub

    @property
    def pubsub_channels(self):
        """Returns read-only channels dict."""
        return types.MappingProxyType(self._pubsub_channels)

    @property
    def pubsub_patterns(self):
        """Returns read-only patterns dict."""
        return types.MappingProxyType(self._pubsub_patterns)

    def auth(self, password):
        """Authenticate to server; returns the awaitable from ``wait_ok``."""
        fut = self.execute('AUTH', password)
        return wait_ok(fut)