"""
This module implements an asyncio :class:`asyncio.Protocol` for a
request-reply :class:`Channel`.
"""
from asyncio import coroutine
import asyncio
import collections
import itertools
import logging
import struct
import traceback
from aiomas.exceptions import RemoteException
import aiomas.codecs
# Public names exported by ``from <module> import *``.
__all__ = [
    'REQUEST',
    'RESULT',
    'EXCEPTION',
    'DEFAULT_CODEC',
    'open_connection',
    'start_server',
    'Header',
    'Request',
    'Channel',
    'ChannelProtocol',
]

# Message types: the first element of every ``(msg_type, msg_id, content)``
# payload that travels over a channel.
REQUEST, RESULT, EXCEPTION = range(3)

# Codec class that is instantiated when the caller does not pass one.
DEFAULT_CODEC = aiomas.codecs.JSON

# Every message is prefixed with the payload length, packed as an unsigned
# 32 bit integer in network byte order.
Header = struct.Struct('!L')

logger = logging.getLogger(__name__)
@coroutine
def open_connection(addr, *,
                    loop=None, codec=None, extra_serializers=(), **kwds):
    """Return a :class:`Channel` connected to *addr*.

    This is a convenience wrapper for
    :meth:`asyncio.BaseEventLoop.create_connection()` and
    :meth:`asyncio.BaseEventLoop.create_unix_connection`.

    If *addr* is a tuple ``(host, port)``, a TCP connection will be created.
    If *addr* is a string, it should be a path name pointing to the unix domain
    socket to connect to.  Any other type raises a :exc:`ValueError`.

    You can optionally provide the event *loop* to use.

    By default, the :class:`~aiomas.codecs.JSON` *codec* is used.  You
    can override this by passing any subclass of :class:`aiomas.codecs.Codec`
    as *codec*.

    You can also pass a list of *extra_serializers* for the codec.  The list
    entries need to be callables that return a tuple with the arguments for
    :meth:`~aiomas.codecs.Codec.add_serializer()`.

    The remaining keyword arguments *kwds* are forwarded to
    :meth:`asyncio.BaseEventLoop.create_connection()` and
    :meth:`asyncio.BaseEventLoop.create_unix_connection` respectively.

    """
    if loop is None:
        loop = asyncio.get_event_loop()

    # A single codec instance is shared with the protocol created below.
    codec = codec() if codec else DEFAULT_CODEC()
    for s in extra_serializers:
        codec.add_serializer(*s())

    def factory():
        return ChannelProtocol(codec, loop=loop)

    # isinstance() also accepts tuple/str subclasses (e.g. a namedtuple
    # ``(host, port)``), which an exact type comparison would reject.
    if isinstance(addr, tuple):
        t, p = yield from loop.create_connection(factory, *addr, **kwds)
    elif isinstance(addr, str):
        t, p = yield from loop.create_unix_connection(factory, addr, **kwds)
    else:
        # Wrap *addr* in a 1-tuple so that %-formatting can never mistake it
        # for the argument tuple itself.
        raise ValueError('Unkown address type: %s' % (addr,))

    return p.channel
@coroutine
def start_server(addr, client_connected_cb, *,
                 loop=None, codec=None, extra_serializers=(), **kwds):
    """Start a server listening on *addr* and call *client_connected_cb*
    for every client connecting to it.

    This function is a convenience wrapper for
    :meth:`asyncio.BaseEventLoop.create_server()` and
    :meth:`asyncio.BaseEventLoop.create_unix_server`.

    If *addr* is a tuple ``(host, port)``, a TCP socket will be created.  If
    *addr* is a string, a unix domain socket at this path will be created.
    Any other type raises a :exc:`ValueError`.

    The single argument of the callable *client_connected_cb* is a new instance
    of :class:`Channel`.

    You can optionally provide the event *loop* to use.

    By default, the :class:`~aiomas.codecs.JSON` *codec* is used.  You can
    override this by passing any subclass of :class:`aiomas.codecs.Codec` as
    *codec*.

    You can also pass a list of *extra_serializers* for the codec.  The list
    entries need to be callables that return a tuple with the arguments for
    :meth:`~aiomas.codecs.Codec.add_serializer()`.

    The remaining keyword arguments *kwds* are forwarded to
    :meth:`asyncio.BaseEventLoop.create_server()` and
    :meth:`asyncio.BaseEventLoop.create_unix_server` respectively.

    """
    if loop is None:
        loop = asyncio.get_event_loop()

    if not codec:
        codec = DEFAULT_CODEC

    def factory():
        # Every accepted connection gets its own codec instance.
        c = codec()
        for s in extra_serializers:
            c.add_serializer(*s())
        return ChannelProtocol(c, client_connected_cb, loop=loop)

    # isinstance() also accepts tuple/str subclasses (e.g. a namedtuple
    # ``(host, port)``), which an exact type comparison would reject.
    if isinstance(addr, tuple):
        return (yield from loop.create_server(factory, *addr, **kwds))
    elif isinstance(addr, str):
        return (yield from loop.create_unix_server(factory, addr, **kwds))
    else:
        # Wrap *addr* in a 1-tuple so that %-formatting can never mistake it
        # for the argument tuple itself.
        raise ValueError('Unkown address type: %s' % (addr,))
class ChannelProtocol(asyncio.Protocol):
    """Asyncio :class:`asyncio.Protocol` which connects the low level transport
    with the high level :class:`Channel` API.

    The *codec* is used to (de)serialize messages.  It should be a sub-class of
    :class:`aiomas.codecs.Codec`.

    Optionally you can also pass a function/coroutine *client_connected_cb*
    that will be executed when a new connection is made (see
    :func:`start_server()`).

    *loop* is the event loop handed to the :class:`Channel` and used for
    flow control futures.

    """
    def __init__(self, codec, client_connected_cb=None, *, loop):
        super().__init__()
        self.codec = codec
        self.transport = None
        self.channel = None
        self._client_connected_cb = client_connected_cb
        self._loop = loop
        self._buffer = bytearray()
        # Total size (header + payload) of the message currently being
        # received or ``None`` if its header is not yet complete.
        self._read_size = None

        # For flow control
        self._paused = False
        self._drain_waiter = None
        self._connection_lost = False

        # Outgoing messages are queued as ``(done_future, data)`` and written
        # one after another by a background task.
        # NOTE(review): the queue, this task and the futures created in
        # write() are bound to the default event loop, not to *loop* --
        # confirm that both are always the same loop.
        self._out_msgs = asyncio.Queue()
        # ``asyncio.async()`` is a syntax error since Python 3.7;
        # ``ensure_future()`` is its drop-in replacement.
        self._task_process_out_msgs = asyncio.ensure_future(
            self._process_out_msgs())

    def connection_made(self, transport):
        """Create a new :class:`Channel` instance for a new connection.

        Also call the *client_connected_cb* if one was passed to this class.

        """
        self.transport = transport
        self.channel = Channel(self, self.codec, transport, loop=self._loop)

        if self._client_connected_cb is not None:
            res = self._client_connected_cb(self.channel)
            if asyncio.iscoroutine(res):
                # The callback is a coroutine function: run it as a task.
                asyncio.ensure_future(res, loop=self._loop)

    def connection_lost(self, exc):
        """Set a :exc:`ConnectionError` to the :class:`Channel` to indicate
        that the connection is closed."""
        if exc is None:  # pragma: no branch
            exc = ConnectionError('Connection closed')
        self.channel._set_exception(exc)

        # Flow control
        self._connection_lost = True
        self._task_process_out_msgs.cancel()

        # Wake up the writer if currently paused.
        if not self._paused:
            return
        waiter = self._drain_waiter
        if waiter is None:
            return
        self._drain_waiter = None
        if waiter.done():
            return
        waiter.set_exception(exc)

    def data_received(self, data):
        """Buffer incoming data until we have a complete message and then
        pass it to :class:`Channel`.

        Messages are length-prefixed.  The first four bytes (in network byte
        order) encode the length of the following payload.  The payload is
        a triple ``(msg_type, msg_id, content)`` encoded with the specified
        *codec*.

        """
        self._buffer.extend(data)
        while True:
            # We may have more than one message in the buffer,
            # so we loop over the buffer until we got all complete messages.

            if self._read_size is None and len(self._buffer) >= Header.size:
                # Received the complete header of a new message
                self._read_size = Header.unpack_from(self._buffer)[0]
                # TODO: Check for too large messages?
                self._read_size += Header.size

            if self._read_size and len(self._buffer) >= self._read_size:
                # At least one complete message is in the buffer
                data = self._buffer[Header.size:self._read_size]
                # Drop the consumed message in place instead of copying the
                # remainder of the buffer into a new bytearray.
                del self._buffer[:self._read_size]
                self._read_size = None
                msg_type, msg_id, content = self.codec.decode(data)
                try:
                    self.channel._feed_data(msg_type, msg_id, content)
                except RuntimeError as exc:
                    self.channel._set_exception(exc)

            else:
                # No complete message in the buffer.  We are done.
                break

    def eof_received(self):
        """Set a :exc:`ConnectionResetError` to the :class:`Channel`."""
        # In previous revisions, an IncompleteMessage error was raised if we
        # already received the beginning of a new message.  However, having
        # two types of exceptions raised by this method makes things more
        # complicated for the user.  The benefit of the IncompleteMessage was
        # not big enough.
        self.channel._set_exception(ConnectionResetError())

    @coroutine
    def write(self, content):
        """Prefix the already encoded *content* with its length header and
        queue it for writing to the transport.

        The coroutine finishes once the message was handed to the transport
        and the transport's write buffer was drained.

        This method is a coroutine.

        """
        content = Header.pack(len(content)) + content
        done = asyncio.Future()
        self._out_msgs.put_nowait((done, content))
        yield from done

    def pause_writing(self):
        """Set the *paused* flag to ``True``.

        Can only be called if we are not already paused.

        """
        assert not self._paused
        self._paused = True
        if self._loop.get_debug():
            logger.debug("%r pauses writing", self)

    def resume_writing(self):
        """Set the *paused* flag to ``False`` and trigger the waiter future.

        Can only be called if we are paused.

        """
        assert self._paused
        self._paused = False
        if self._loop.get_debug():
            logger.debug("%r resumes writing", self)

        waiter = self._drain_waiter
        if waiter is not None:
            self._drain_waiter = None
            if not waiter.done():
                waiter.set_result(None)

    @coroutine
    def _process_out_msgs(self):
        """Write queued messages to the transport, one at a time, until the
        task is cancelled by :meth:`connection_lost()`."""
        try:
            while True:
                done, content = yield from self._out_msgs.get()
                self.transport.write(content)
                yield from self._drain_helper()
                done.set_result(None)
        except asyncio.CancelledError:
            assert self._connection_lost

    @coroutine
    def _drain_helper(self, before=False):
        """Wait until writing is resumed if the transport paused it.

        Raise a :exc:`ConnectionResetError` if the connection was lost.  The
        argument *before* is not used.

        """
        if self._connection_lost:
            raise ConnectionResetError('Connection lost')
        if not self._paused:
            return
        waiter = self._drain_waiter
        assert waiter is None or waiter.cancelled()
        waiter = asyncio.Future(loop=self._loop)
        self._drain_waiter = waiter
        yield from waiter
class Request:
    """Represents a request returned by :meth:`Channel.recv()`.  You shouldn't
    instantiate it yourself.

    *content* contains the incoming message.

    *message_id* is the ID for that message.  It is unique within a channel.

    *protocol* is the channel's :class:`ChannelProtocol` instance that is used
    for writing back the reply.

    To reply to that request you can ``yield from`` :meth:`Request.reply()`
    or :meth:`Request.fail()`.

    """
    def __init__(self, content, message_id, protocol):
        self._content = content
        self._msg_id = message_id
        self._protocol = protocol

    @property
    def content(self):
        """The content of the incoming message."""
        return self._content

    @coroutine
    def reply(self, result):
        """Reply to the request with the provided result.

        If *result* cannot be encoded by the protocol's codec, the encoding
        error is sent back via :meth:`fail()` instead.

        """
        protocol = self._protocol
        content = (RESULT, self._msg_id, result)
        try:
            content = protocol.codec.encode(content)
        except Exception as e:
            # Report the failure to the peer rather than leaving the request
            # unanswered.
            return (yield from self.fail(e))
        else:
            yield from protocol.write(content)

    @coroutine
    def fail(self, exception):
        """Indicate a failure described by the *exception* instance.

        This will raise a :exc:`~aiomas.exceptions.RemoteException` on the
        other side of the channel.

        """
        protocol = self._protocol
        # Only the formatted traceback text is sent, not the exception
        # object itself.
        stacktrace = traceback.format_exception(exception.__class__, exception,
                                                exception.__traceback__)
        content = (EXCEPTION, self._msg_id, ''.join(stacktrace))
        content = protocol.codec.encode(content)
        yield from protocol.write(content)
class Channel:
    """A Channel represents a request-reply channel between two endpoints.  An
    instance of it is returned by :func:`open_connection()` or is passed to the
    callback of :func:`start_server()`.

    *protocol* is an instance of :class:`ChannelProtocol`.

    *codec* is the codec instance used to encode outgoing requests.

    *transport* is an :class:`asyncio.BaseTransport`.

    *loop* is an instance of an :class:`asyncio.BaseEventLoop`.

    """
    def __init__(self, protocol, codec, transport, loop):
        self._protocol = protocol
        self._codec = codec
        self._transport = transport
        self._loop = loop

        self._message_id = itertools.count()
        # message_id -> reply future, or (reply future, request content) if
        # the loop runs in debug mode.
        self._out_messages = {}
        self._in_queue = collections.deque()
        self._waiter = None  # A future.
        self._exception = None

    @property
    def codec(self):
        """The codec used to de-/encode messages sent via the channel."""
        return self._codec

    @property
    def transport(self):
        """The transport of this channel (see the `Python documentation
        <https://docs.python.org/3/library/asyncio-protocol.html#transports>`_
        for details).

        It is ``None`` after :meth:`close()` was called.

        """
        return self._transport

    def get_extra_info(self, name, default=None):
        """Return the transport's extra info *name* or *default* if it is not
        available or if the channel has already been closed."""
        transport = self._transport
        if transport is None:
            return default
        return transport.get_extra_info(name, default)

    def send(self, content):
        """Send a request *content* to the other end and return a future which
        is triggered when a reply arrives.

        One of the following exceptions may be raised:

        - :exc:`~aiomas.exceptions.RemoteException`: The remote site raised an
          exception during the computation of the result.

        - :exc:`ConnectionError` (or its subclass :exc:`ConnectionResetError`):
          The connection was closed during the request.

        - :exc:`RuntimeError`:

          - If an invalid message type was received.

          - If the future returned by this method was already triggered or
            canceled by a third party when an answer to the request arrives
            (e.g., if a task containing the future is cancelled).  You get
            more detailed exception messages if you `enable asyncio's debug
            mode`__

            __ https://docs.python.org/3/library/asyncio-dev.html

        .. code-block:: python

           try:
               result = yield from channel.send('ohai')
           except RemoteException as exc:
               print(exc)

        """
        if self._exception is not None:
            raise self._exception

        message_id = next(self._message_id)
        # Encode before registering the reply future: if the codec rejects
        # *content*, no orphaned entry is left behind in ``_out_messages``.
        data = self._codec.encode((REQUEST, message_id, content))

        out_message = asyncio.Future(loop=self._loop)
        if self._loop.get_debug():
            # Keep the request around for more helpful error messages.
            self._out_messages[message_id] = (out_message, content)
        else:
            self._out_messages[message_id] = out_message

        # ``asyncio.async()`` is a syntax error since Python 3.7;
        # ``ensure_future()`` is its drop-in replacement.
        asyncio.ensure_future(self._protocol.write(data), loop=self._loop)
        return out_message

    @coroutine
    def recv(self):
        """Wait for an incoming :class:`Request` and return it.

        May raise one of the following exceptions:

        - :exc:`ConnectionError` (or its subclass :exc:`ConnectionResetError`):
          The connection was closed during the request.

        - :exc:`RuntimeError`: If two processes try to read from the same
          channel or if an invalid message type was received.

        """
        if self._exception is not None:
            raise self._exception

        if not self._in_queue:
            if self._waiter is not None:
                raise RuntimeError('recv() called while another coroutine is '
                                   'already waiting for incoming data.')
            self._waiter = asyncio.Future(loop=self._loop)
            try:
                yield from self._waiter
            finally:
                self._waiter = None

        return self._in_queue.popleft()

    def close(self):
        """Close the channel's transport.

        Calling this method more than once has no further effect.

        """
        if self._transport is not None:
            transport = self._transport
            self._transport = None
            return transport.close()

    def _feed_data(self, msg_type, msg_id, content):
        """Called by :class:`ChannelProtocol` when a new message arrived.

        Requests are queued for :meth:`recv()`; results and exceptions
        trigger the future returned by the corresponding :meth:`send()`.
        Raise a :exc:`RuntimeError` for unknown message types or if the reply
        future is already done.

        """
        if msg_type == REQUEST:
            message = Request(content, msg_id, self._protocol)
            self._in_queue.append(message)

            waiter = self._waiter
            if waiter is not None:
                self._waiter = None
                # The waiter may already be done (e.g. cancelled); the request
                # stays in the queue for the next recv() call.
                if not waiter.done():
                    waiter.set_result(False)

        elif msg_type == RESULT or msg_type == EXCEPTION:
            if self._loop.get_debug():
                message, req = self._out_messages.pop(msg_id)
            else:
                message = self._out_messages.pop(msg_id)

            if message.done():
                errmsg = 'Request reply already set.'
                if message.cancelled():
                    errmsg = 'Request was cancelled.'
                if self._loop.get_debug():
                    errmsg += ' Request: %s; Reply: %s' % (req, content)
                raise RuntimeError(errmsg)

            if msg_type == RESULT:
                message.set_result(content)
            else:
                origin = self.get_extra_info('peername')
                message.set_exception(RemoteException(origin, content))

        else:
            raise RuntimeError('Invalid message type %d' % msg_type)

    def _set_exception(self, exc):
        """Set an exception as result for all futures managed by the Channel
        in order to stop all coroutines from reading/writing to the socket."""
        self._exception = exc

        # Set exception to wait-recv future
        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_exception(exc)

        # Set exception to all message futures which wait for a reply
        for msg in self._out_messages.values():
            if self._loop.get_debug():
                msg, _ = msg
            if not msg.done():
                msg.set_exception(exc)

        self.close()