481 lines
16 KiB
Python
481 lines
16 KiB
Python
import asyncio
|
|
import collections
|
|
import types
|
|
import warnings
|
|
import sys
|
|
|
|
from .connection import create_connection, _PUBSUB_COMMANDS
|
|
from .log import logger
|
|
from .util import parse_url, CloseEvent
|
|
from .errors import PoolClosedError
|
|
from .abc import AbcPool
|
|
from .locks import Lock
|
|
|
|
|
|
async def create_pool(address, *, db=None, password=None, ssl=None,
                      encoding=None, minsize=1, maxsize=10,
                      parser=None, loop=None, create_connection_timeout=None,
                      pool_cls=None, connection_cls=None):
    """Create a Redis connections pool and fill it up to *minsize*.

    :param address: server address.  When it is a ``str`` it is parsed
        with ``parse_url``; the ``db``, ``password``, ``encoding`` and
        ``timeout`` options found in the URL take precedence over the
        matching keyword arguments (``timeout`` maps to
        *create_connection_timeout*).
    :param db: database index to select on new connections.
    :param password: password used by new connections.
    :param ssl: ssl argument for new connections.  A falsy ``ssl`` URL
        option conflicts with a truthy *ssl* argument (AssertionError);
        otherwise a truthy *ssl* argument is kept over the URL value.
    :param encoding: encoding used by new connections.
    :param minsize: number of connections created up front.
    :param maxsize: maximum number of connections in the pool.
    :param parser: parser class passed to new connections.
    :param loop: deprecated; forwarded to the pool constructor.
    :param create_connection_timeout: timeout for creating a connection.
    :param pool_cls: optional pool class; must subclass ``AbcPool``.
    :param connection_cls: optional connection class for new connections.
    :returns: an instance of *pool_cls* if given, else ``ConnectionsPool``.
    :raises: any error raised while filling the pool; the pool is closed
        (and its close awaited) before the error is re-raised.
    """
    if pool_cls:
        assert issubclass(pool_cls, AbcPool),\
            "pool_class does not meet the AbcPool contract"
        cls = pool_cls
    else:
        cls = ConnectionsPool
    if isinstance(address, str):
        address, options = parse_url(address)
        # setdefault: a value present in the URL wins over the kwarg.
        db = options.setdefault('db', db)
        password = options.setdefault('password', password)
        encoding = options.setdefault('encoding', encoding)
        create_connection_timeout = options.setdefault(
            'timeout', create_connection_timeout)
        if 'ssl' in options:
            # Equivalent to the former "A or (not A and not B)": only a
            # falsy URL ssl option combined with a truthy ssl kwarg is a
            # conflict.
            assert options['ssl'] or not ssl, (
                "Conflicting ssl options are set", options['ssl'], ssl)
            ssl = ssl or options['ssl']
        # TODO: minsize/maxsize

    pool = cls(address, db, password, encoding,
               minsize=minsize, maxsize=maxsize,
               ssl=ssl, parser=parser,
               create_connection_timeout=create_connection_timeout,
               connection_cls=connection_cls,
               loop=loop)
    try:
        await pool._fill_free(override_min=False)
    except (Exception, asyncio.CancelledError):
        # CancelledError is a BaseException since Python 3.8; without it
        # a cancelled create_pool would leak the partially-filled pool.
        pool.close()
        await pool.wait_closed()
        raise

    return pool
|
|
|
|
|
|
class ConnectionsPool(AbcPool):
    """Redis connections pool.

    Bookkeeping:

    * ``_pool`` -- free connections, a deque whose ``maxlen`` is maxsize;
    * ``_used`` -- connections handed out by ``acquire`` plus a pub/sub
      connection taken from the free pool by ``get_connection``;
    * ``_acquiring`` -- number of connections currently being created;
    * ``_pubsub_conn`` -- connection pinned for pub/sub commands, if any.

    ``size`` is ``freesize + len(_used) + _acquiring``.
    """

    def __init__(self, address, db=None, password=None, encoding=None,
                 *, minsize, maxsize, ssl=None, parser=None,
                 create_connection_timeout=None,
                 connection_cls=None,
                 loop=None):
        assert isinstance(minsize, int) and minsize >= 0, (
            "minsize must be int >= 0", minsize, type(minsize))
        assert maxsize is not None, "Arbitrary pool size is disallowed."
        assert isinstance(maxsize, int) and maxsize > 0, (
            "maxsize must be int > 0", maxsize, type(maxsize))
        assert minsize <= maxsize, (
            "Invalid pool min/max sizes", minsize, maxsize)
        if loop is not None and sys.version_info >= (3, 8):
            warnings.warn("The loop argument is deprecated",
                          DeprecationWarning)
        self._address = address
        self._db = db
        self._password = password
        self._ssl = ssl
        self._encoding = encoding
        self._parser_class = parser
        self._minsize = minsize
        self._create_connection_timeout = create_connection_timeout
        # The deque's maxlen doubles as the pool's maxsize (see `maxsize`).
        self._pool = collections.deque(maxlen=maxsize)
        self._used = set()
        self._acquiring = 0
        self._cond = asyncio.Condition(lock=Lock())
        self._close_state = CloseEvent(self._do_close)
        self._pubsub_conn = None
        self._connection_cls = connection_cls

    def __repr__(self):
        return '<{} [db:{}, size:[{}:{}], free:{}]>'.format(
            self.__class__.__name__, self.db,
            self.minsize, self.maxsize, self.freesize)

    @property
    def minsize(self):
        """Minimum pool size."""
        return self._minsize

    @property
    def maxsize(self):
        """Maximum pool size."""
        return self._pool.maxlen

    @property
    def size(self):
        """Current pool size: free + used + being created."""
        return self.freesize + len(self._used) + self._acquiring

    @property
    def freesize(self):
        """Current number of free connections."""
        return len(self._pool)

    @property
    def address(self):
        """Address new pool connections are created for."""
        return self._address

    async def clear(self):
        """Clear pool connections.

        Close and remove all free connections.
        """
        async with self._cond:
            await self._do_clear()

    async def _do_clear(self):
        """Close every free connection and wait for all of them to close."""
        waiters = []
        while self._pool:
            conn = self._pool.popleft()
            conn.close()
            waiters.append(conn.wait_closed())
        await asyncio.gather(*waiters)

    async def _do_close(self):
        """Close free, used and pub/sub connections (run via CloseEvent)."""
        async with self._cond:
            assert not self._acquiring, self._acquiring
            waiters = []
            while self._pool:
                conn = self._pool.popleft()
                conn.close()
                waiters.append(conn.wait_closed())
            for conn in self._used:
                conn.close()
                waiters.append(conn.wait_closed())
            # A pub/sub connection taken from the free pool is in _used and
            # already handled above; one created by _wait_execute_pubsub is
            # not tracked anywhere else, so close it explicitly.
            pubsub = self._pubsub_conn
            if pubsub is not None and pubsub not in self._used:
                pubsub.close()
                waiters.append(pubsub.wait_closed())
            await asyncio.gather(*waiters)
            logger.debug("Closed %d connection(s)", len(waiters))

    def close(self):
        """Close all free and in-progress connections and mark pool as closed.
        """
        if not self._close_state.is_set():
            self._close_state.set()

    @property
    def closed(self):
        """True if pool is closed."""
        return self._close_state.is_set()

    async def wait_closed(self):
        """Wait until pool gets closed."""
        await self._close_state.wait()

    @property
    def db(self):
        """Currently selected db index."""
        return self._db or 0

    @property
    def encoding(self):
        """Current set codec or None."""
        return self._encoding

    def execute(self, command, *args, **kw):
        """Execute a Redis command and return an awaitable for its result.

        If ``get_connection`` finds a free connection the command is sent
        through it right away; otherwise a coroutine is returned that
        acquires a connection, runs the command and releases it.  Either
        awaitable is passed through ``_check_result``.
        """
        conn, address = self.get_connection(command, args)
        if conn is not None:
            fut = conn.execute(command, *args, **kw)
            return self._check_result(fut, command, args, kw)
        else:
            coro = self._wait_execute(address, command, args, kw)
            return self._check_result(coro, command, args, kw)

    def execute_pubsub(self, command, *channels):
        """Executes Redis (p)subscribe/(p)unsubscribe commands.

        ConnectionsPool picks separate connection for pub/sub
        and uses it until explicitly closed or disconnected
        (unsubscribing from all channels/patterns will leave connection
        locked for pub/sub use).

        There is no auto-reconnect for this PUB/SUB connection.

        Returns asyncio.gather coroutine waiting for all channels/patterns
        to receive answers.
        """
        conn, address = self.get_connection(command)
        if conn is not None:
            return conn.execute_pubsub(command, *channels)
        else:
            return self._wait_execute_pubsub(address, command, channels, {})

    def get_connection(self, command, args=()):
        """Get a free connection from the pool without waiting.

        Returns a ``(connection, address)`` tuple.  ``connection`` is None
        when no usable free connection exists; the caller must then wait
        for one.  For pub/sub commands the pinned pub/sub connection is
        reused, or a free connection is moved out of the pool, pinned and
        marked as used.  Non-pub/sub connections stay in the free pool.
        """
        # TODO: find a better way to determine if connection is free
        #       and not heavily used.
        command = command.upper().strip()
        is_pubsub = command in _PUBSUB_COMMANDS
        if is_pubsub and self._pubsub_conn:
            if not self._pubsub_conn.closed:
                return self._pubsub_conn, self._pubsub_conn.address
            # The pinned connection died: stop counting it as used, else
            # it occupies a `size` slot forever.  `discard` because a
            # connection created by _wait_execute_pubsub is not in _used.
            self._used.discard(self._pubsub_conn)
            self._pubsub_conn = None
        for _ in range(self.freesize):
            conn = self._pool[0]
            # Rotate so successive calls spread commands over connections.
            self._pool.rotate(1)
            if conn.closed:  # or conn._waiters: (eg: busy connection)
                continue
            if conn.in_pubsub:
                continue
            if is_pubsub:
                self._pubsub_conn = conn
                self._pool.remove(conn)
                self._used.add(conn)
            return conn, conn.address
        return None, self._address  # figure out

    def _check_result(self, fut, *data):
        """Hook to check result or catch exception (like MovedError).

        This method can be coroutine.
        """
        return fut

    async def _wait_execute(self, address, command, args, kw):
        """Acquire connection and execute command."""
        conn = await self.acquire(command, args)
        try:
            return (await conn.execute(command, *args, **kw))
        finally:
            self.release(conn)

    async def _wait_execute_pubsub(self, address, command, args, kw):
        """Create (if needed) the pinned pub/sub connection and run command.

        Raises PoolClosedError if the pool is closed.
        """
        if self.closed:
            raise PoolClosedError("Pool is closed")
        assert self._pubsub_conn is None or self._pubsub_conn.closed, (
            "Expected no or closed connection", self._pubsub_conn)
        async with self._cond:
            if self.closed:
                raise PoolClosedError("Pool is closed")
            if self._pubsub_conn is None or self._pubsub_conn.closed:
                conn = await self._create_new_connection(address)
                self._pubsub_conn = conn
            conn = self._pubsub_conn
            return (await conn.execute_pubsub(command, *args, **kw))

    async def select(self, db):
        """Changes db index for all free connections.

        All previously acquired connections will be closed when released.
        Returns the combined (``and``-ed) result of the selects.
        """
        res = True
        async with self._cond:
            # Iterate a snapshot: get_connection may rotate the deque while
            # we await, which would make index-based access skip or repeat
            # connections.
            for conn in list(self._pool):
                # Always await: a short-circuiting `res and await ...`
                # would leave later connections on the old db.
                ok = await conn.select(db)
                res = res and ok
            self._db = db
        return res

    async def auth(self, password):
        """Store *password* for new connections and AUTH free ones."""
        self._password = password
        async with self._cond:
            # Snapshot for the same reason as in `select`.
            for conn in list(self._pool):
                await conn.auth(password)

    @property
    def in_pubsub(self):
        """``in_pubsub`` of the live pinned pub/sub connection, else 0."""
        if self._pubsub_conn and not self._pubsub_conn.closed:
            return self._pubsub_conn.in_pubsub
        return 0

    @property
    def pubsub_channels(self):
        """Channels of the live pub/sub connection, else an empty mapping."""
        if self._pubsub_conn and not self._pubsub_conn.closed:
            return self._pubsub_conn.pubsub_channels
        return types.MappingProxyType({})

    @property
    def pubsub_patterns(self):
        """Patterns of the live pub/sub connection, else an empty mapping."""
        if self._pubsub_conn and not self._pubsub_conn.closed:
            return self._pubsub_conn.pubsub_patterns
        return types.MappingProxyType({})

    async def acquire(self, command=None, args=()):
        """Acquires a connection from free pool.

        Creates new connection if needed (up to maxsize); otherwise waits
        on the pool condition until one is released.  *command* and *args*
        are not used here.  Raises PoolClosedError if the pool is closed.
        """
        if self.closed:
            raise PoolClosedError("Pool is closed")
        async with self._cond:
            if self.closed:
                raise PoolClosedError("Pool is closed")
            while True:
                await self._fill_free(override_min=True)
                if self.freesize:
                    conn = self._pool.popleft()
                    assert not conn.closed, conn
                    assert conn not in self._used, (conn, self._used)
                    self._used.add(conn)
                    return conn
                else:
                    await self._cond.wait()

    def release(self, conn):
        """Returns used connection back into pool.

        The connection is closed instead of returned when it is in a
        transaction, in subscribe mode, has pending commands, has a db
        index that differs from the pool's, or when the free queue is full.
        Waiters on the pool condition are then woken up.
        """
        assert conn in self._used, (
            "Invalid connection, maybe from other pool", conn)
        self._used.remove(conn)
        if not conn.closed:
            if conn.in_transaction:
                logger.warning(
                    "Connection %r is in transaction, closing it.", conn)
                conn.close()
            elif conn.in_pubsub:
                logger.warning(
                    "Connection %r is in subscribe mode, closing it.", conn)
                conn.close()
            elif conn._waiters:
                logger.warning(
                    "Connection %r has pending commands, closing it.", conn)
                conn.close()
            elif conn.db == self.db:
                if self.maxsize and self.freesize < self.maxsize:
                    self._pool.append(conn)
                else:
                    # consider this connection as old and close it.
                    conn.close()
            else:
                conn.close()
        # FIXME: check event loop is not closed
        asyncio.ensure_future(self._wakeup())

    def _drop_closed(self):
        """Remove closed connections from the free pool, keeping order."""
        # One full pass: closed heads are popped, open ones rotated to the
        # back, so surviving connections end up in their original order.
        for _ in range(self.freesize):
            conn = self._pool[0]
            if conn.closed:
                self._pool.popleft()
            else:
                self._pool.rotate(-1)

    async def _fill_free(self, *, override_min):
        """Top up the free pool.

        First creates (and pings) connections until ``size`` reaches
        ``minsize``.  If there is still no free connection and
        *override_min* is true, creates further connections (not pinged)
        until one is free or ``size`` reaches ``maxsize``.
        """
        # drop closed connections first
        self._drop_closed()
        while self.size < self.minsize:
            self._acquiring += 1
            try:
                conn = await self._create_new_connection(self._address)
                # Check the health of the new connection; if the check
                # fails close it before propagating so it is not leaked.
                try:
                    await conn.execute('ping')
                except (Exception, asyncio.CancelledError):
                    conn.close()
                    raise
                self._pool.append(conn)
            finally:
                self._acquiring -= 1
                # connection may be closed at yield point
                self._drop_closed()
        if self.freesize:
            return
        if override_min:
            while not self._pool and self.size < self.maxsize:
                self._acquiring += 1
                try:
                    conn = await self._create_new_connection(self._address)
                    self._pool.append(conn)
                finally:
                    self._acquiring -= 1
                    # connection may be closed at yield point
                    self._drop_closed()

    def _create_new_connection(self, address):
        """Return the create_connection coroutine using pool settings."""
        return create_connection(address,
                                 db=self._db,
                                 password=self._password,
                                 ssl=self._ssl,
                                 encoding=self._encoding,
                                 parser=self._parser_class,
                                 timeout=self._create_connection_timeout,
                                 connection_cls=self._connection_cls,
                                 )

    async def _wakeup(self, closing_conn=None):
        """Notify one waiter on the pool condition.

        If *closing_conn* is given, also wait for it to close.
        """
        async with self._cond:
            self._cond.notify()
        if closing_conn is not None:
            await closing_conn.wait_closed()

    def __enter__(self):
        raise RuntimeError(
            "'await' should be used as a context manager expression")

    def __exit__(self, *args):
        pass    # pragma: nocover

    def __await__(self):
        # To make `with await pool` work
        conn = yield from self.acquire().__await__()
        return _ConnectionContextManager(self, conn)

    def get(self):
        '''Return async context manager for working with connection.

        async with pool.get() as conn:
            await conn.execute('get', 'my-key')
        '''
        return _AsyncConnectionContextManager(self)
|
|
|
|
|
|
class _ConnectionContextManager:
    """Synchronous context manager wrapping an already-acquired connection.

    ``__enter__`` yields the connection; ``__exit__`` releases it back to
    the pool and then drops both references.  Exceptions are not
    suppressed.
    """

    __slots__ = ('_pool', '_conn')

    def __init__(self, pool, conn):
        self._pool, self._conn = pool, conn

    def __enter__(self):
        return self._conn

    def __exit__(self, exc_type, exc_value, tb):
        owner, held = self._pool, self._conn
        try:
            owner.release(held)
        finally:
            # Forget pool and connection even if release() raised.
            self._pool = self._conn = None
|
|
|
|
|
|
class _AsyncConnectionContextManager:
    """Async context manager that acquires a connection for its body.

    ``__aenter__`` awaits ``pool.acquire()`` and yields the connection;
    ``__aexit__`` releases it back to the pool and then drops both
    references.  Exceptions are not suppressed.
    """

    __slots__ = ('_pool', '_conn')

    def __init__(self, pool):
        self._pool = pool
        self._conn = None

    async def __aenter__(self):
        self._conn = await self._pool.acquire()
        return self._conn

    async def __aexit__(self, exc_type, exc_value, tb):
        owner, held = self._pool, self._conn
        try:
            owner.release(held)
        finally:
            # Forget pool and connection even if release() raised.
            self._pool = self._conn = None
|