Source code for

# -*- coding: utf-8 -*- {{{
# vim: set fenc=utf-8 ft=python sw=4 ts=4 sts=4 et:
# Copyright 2017, Battelle Memorial Institute.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This material was prepared as an account of work sponsored by an agency of
# the United States Government. Neither the United States Government nor the
# United States Department of Energy, nor Battelle, nor any of their
# employees, nor any jurisdiction or organization that has cooperated in the
# development of these materials, makes any warranty, express or
# implied, or assumes any legal liability or responsibility for the accuracy,
# completeness, or usefulness or any information, apparatus, product,
# software, or process disclosed, or represents that its use would not infringe
# privately owned rights. Reference herein to any specific commercial product,
# process, or service by trade name, trademark, manufacturer, or otherwise
# does not necessarily constitute or imply its endorsement, recommendation, or
# favoring by the United States Government or any agency thereof, or
# Battelle Memorial Institute. The views and opinions of authors expressed
# herein do not necessarily state or reflect those of the
# United States Government or any agency thereof.
# under Contract DE-AC05-76RL01830
# }}}

from __future__ import absolute_import

import os
import logging
import zmq
from zmq import Frame, NOBLOCK, ZMQError, EINVAL, EHOSTUNREACH

__all__ = ['BaseRouter', 'OUTGOING', 'INCOMING', 'UNROUTABLE', 'ERROR']


_log = logging.getLogger(__name__)

# Topics passed as the first argument of BaseRouter.issue().
# NOTE(review): these four names are exported through __all__ and used by
# BaseRouter, but their definitions were missing from this copy of the
# module. The code here only requires them to be distinct values; confirm
# the exact values against the upstream source.
OUTGOING = 0
INCOMING = 1
UNROUTABLE = 2
ERROR = 3

# Optimizing by pre-creating frames.
# Maps a zmq errno to the (error number, error message) frame pair that
# BaseRouter._send() puts into the error reply returned to the sender.
# NOTE(review): the message element was lost from this copy; os.strerror()
# is assumed because ``os`` is imported and otherwise unused -- confirm.
_ROUTE_ERRORS = {
    errnum: (zmq.Frame(str(errnum).encode('ascii')),
             zmq.Frame(os.strerror(errnum).encode('ascii')))
    for errnum in [zmq.EHOSTUNREACH, zmq.EAGAIN]
}

# (error number, error message) frame pair used by BaseRouter.route() when
# handle_subsystem() does not recognise the requested subsystem.
# NOTE(review): this definition was missing from this copy of the module;
# EPROTONOSUPPORT is assumed as the error code -- confirm against upstream.
_INVALID_SUBSYSTEM = (
    zmq.Frame(str(zmq.EPROTONOSUPPORT).encode('ascii')),
    zmq.Frame(os.strerror(zmq.EPROTONOSUPPORT).encode('ascii')),
)

class BaseRouter(object):
    '''Abstract base class of VIP router implementation.

    Router implementers should inherit this class and implement the
    setup() method to bind to appropriate addresses, set identities,
    setup authentication, etc, and the poll_sockets() method, which is
    called repeatedly by run().

    The socket will be created by the start() method, which will then
    call the setup() method. Once started, the socket may be polled for
    incoming messages and those messages are handled/routed by calling
    the route() method. During routing, the issue() method, which may
    be implemented, will be called to allow for debugging and logging.
    Custom subsystems may be implemented in the handle_subsystem()
    method. The socket will be closed when the stop() method is called.
    '''

    # Factories used to build the zmq objects; subclasses may replace them.
    _context_class = zmq.Context
    _socket_class = zmq.Socket
    _poller_class = zmq.Poller

    def __init__(self, context=None, default_user_id=None):
        '''Initialize the object instance.

        If context is None (the default), the zmq global context will be
        used for socket creation. default_user_id is the value returned
        by lookup_user_id() when no user id can be determined.
        '''
        self.context = context or self._context_class.instance()
        self.default_user_id = default_user_id
        # The ROUTER socket; created by start().
        self.socket = None
        # Identities (bytes) of the peers currently known to the router.
        self._peers = set()
        self._poller = self._poller_class()
        self._ext_sockets = []
        self._socket_id_mapping = {}

    def run(self):
        '''Main router loop.

        Starts the router, then calls poll_sockets() forever. The socket
        is closed through stop() when the loop is left by an exception.
        '''
        self.start()
        try:
            while True:
                self.poll_sockets()
        finally:
            self.stop()

    def start(self):
        '''Create the socket and call setup().

        The socket is saved in the socket attribute. The setup() method
        is called at the end of the method to perform additional setup.
        '''
        # Raise the context's socket limit *before* creating the socket:
        # libzmq reads ZMQ_MAX_SOCKETS when the context starts up, which
        # happens on creation of its first socket, so setting the option
        # afterwards has no effect.
        self.context.set(zmq.MAX_SOCKETS, 30690)
        self.socket = sock = self._socket_class(self.context, zmq.ROUTER)
        # Report unroutable messages as errors instead of dropping them
        # silently; _send() relies on the resulting EHOSTUNREACH.
        sock.router_mandatory = True
        sock.sndtimeo = 0
        sock.tcp_keepalive = True
        sock.tcp_keepalive_idle = 180
        sock.tcp_keepalive_intvl = 20
        sock.tcp_keepalive_cnt = 6
        sock.set_hwm(6000)
        _log.debug("ROUTER SENDBUF: %s, %s",
                   sock.getsockopt(zmq.SNDBUF), sock.getsockopt(zmq.RCVBUF))
        self.setup()

    def stop(self, linger=1):
        '''Close the socket.

        linger is passed to the socket's close() method. Calling stop()
        before start() has created the socket is a no-op.
        '''
        sock = self.socket
        if sock is None:
            return
        sock.close(linger)

    def setup(self):
        '''Called from start() method to setup the socket.

        Implement this method to bind the socket, set identities and
        options, etc.
        '''
        raise NotImplementedError()

    def poll_sockets(self):
        '''Called inside run method.

        Implement this method to poll the sockets for incoming messages.
        '''
        raise NotImplementedError()

    @property
    def poll(self):
        '''Returns the underlying socket's poll method.'''
        return self.socket.poll

    def handle_subsystem(self, frames, user_id):
        '''Handle additional subsystems and provide a response.

        This method does nothing by default and may be implemented by
        subclasses to provide additional subsystems.

        frames is a list of zmq.Frame objects with the following
        elements:

          [SENDER, RECIPIENT, PROTOCOL, USER_ID, MSG_ID, SUBSYSTEM, ...]

        The return value should be None, if the subsystem is unknown, an
        empty list or False (or other False value) if the message was
        handled but does not require/generate a response, or a list
        containing the following elements:

          [RECIPIENT, SENDER, PROTOCOL, USER_ID, MSG_ID, SUBSYSTEM, ...]
        '''
        pass

    def issue(self, topic, frames, extra=None):
        '''Hook called while routing, intended for debugging and logging.

        topic is one of INCOMING, OUTGOING, UNROUTABLE or ERROR, frames
        is the message concerned and extra carries additional detail: a
        reason string for UNROUTABLE, the (error number, error message)
        pair for ERROR. Does nothing by default.
        '''
        pass

    if zmq.zmq_version_info() >= (4, 1, 0):
        def lookup_user_id(self, sender, recipient, auth_token):
            '''Find and return a user identifier.

            Returns the UTF-8 encoded User-Id metadata property read
            from the recipient frame, or default_user_id if that
            property is not set. May be extended to perform additional
            lookups.
            '''
            # pylint: disable=unused-argument
            # A user id might/should be set by the ZAP authenticator.
            # NOTE(review): the property is read from the recipient frame
            # rather than the sender frame -- presumably because the
            # sender identity frame is added by the ROUTER socket and
            # carries no connection metadata; confirm.
            try:
                return recipient.get('User-Id').encode('utf-8')
            except ZMQError as exc:
                # EINVAL means the property is not set; anything else is
                # a real error.
                if exc.errno != EINVAL:
                    raise
            return self.default_user_id
    else:
        def lookup_user_id(self, sender, recipient, auth_token):
            '''Find and return a user identifier.

            Returns default_user_id by default; this method must be
            overridden to map the sender and auth_token to a user ID.
            The returned value must be a string or None (if the token
            was not found).
            '''
            # pylint: disable=unused-argument
            return self.default_user_id

    def _distribute(self, *parts):
        '''Send a router-originated message made of parts to every peer.

        Peers that turn out to be unreachable are dropped once all peers
        have been tried.
        '''
        drop = set()
        empty = Frame(b'')
        # [RECIPIENT, SENDER, PROTO, USER_ID, MSG_ID, <parts>...]; the
        # recipient slot is filled in per peer below.
        frames = [empty, empty, Frame(b'VIP1'), empty, empty]
        frames.extend(Frame(f) for f in parts)
        for peer in self._peers:
            frames[0] = peer
            drop.update(self._send(frames))
        # Dropping is deferred so self._peers is not modified while it is
        # being iterated.
        for peer in drop:
            self._drop_peer(peer)

    def _drop_pubsub_peers(self, peer):
        '''Drop peers for pubsub subsystem. To be handled by subclasses.'''
        pass

    def _add_pubsub_peers(self, peer):
        '''Add peers for pubsub subsystem. To be handled by subclasses.'''
        pass

    def _add_peer(self, peer):
        '''Register peer and announce it to the already known peers.'''
        if peer in self._peers:
            return
        # Announce before adding, so the new peer is not notified about
        # itself.
        self._distribute(b'peerlist', b'add', peer)
        self._peers.add(peer)
        self._add_pubsub_peers(peer)

    def _drop_peer(self, peer):
        '''Forget peer and announce its removal to the remaining peers.'''
        try:
            self._peers.remove(peer)
        except KeyError:
            # Unknown peer; nothing to announce.
            return
        self._distribute(b'peerlist', b'drop', peer)
        self._drop_pubsub_peers(peer)

    def route(self, frames):
        '''Route one message and return.

        frames is the received message as a list of zmq.Frame objects:

          [SENDER, RECIPIENT, PROTOCOL, AUTH_TOKEN, MSG_ID, SUBSYSTEM, ...]

        If the recipient is the router (empty recipient), the standard
        hello, ping and peerlist subsystems are handled. Other
        subsystems are sent to handle_subsystem() for processing.
        Messages destined for other entities are routed appropriately.
        '''
        socket = self.socket
        issue = self.issue
        issue(INCOMING, frames)
        if len(frames) < 6:
            # Cannot route if there are insufficient frames, such as
            # might happen with a router probe.
            if len(frames) == 2 and frames[0] and not frames[1]:
                issue(UNROUTABLE, frames, 'router probe')
                self._add_peer(frames[0].bytes)
            else:
                issue(UNROUTABLE, frames, 'too few frames')
            return
        sender, recipient, proto, auth_token, msg_id = frames[:5]
        if proto.bytes != b'VIP1':
            # Peer is not talking a protocol we understand
            issue(UNROUTABLE, frames, 'bad VIP signature')
            return
        user_id = self.lookup_user_id(sender, recipient, auth_token)
        if user_id is None:
            user_id = b''
        self._add_peer(sender.bytes)
        subsystem = frames[5]
        if not recipient.bytes:
            # Handle requests directed at the router. Replies keep the
            # (empty) recipient in the sender slot.
            name = subsystem.bytes
            if name == b'hello':
                frames = [sender, recipient, proto, user_id, msg_id,
                          b'hello', b'welcome', b'1.0', socket.identity,
                          sender]
            elif name == b'ping':
                # Any frames after the seventh are echoed back unchanged.
                frames[:7] = [
                    sender, recipient, proto, user_id, msg_id,
                    b'ping', b'pong']
            elif name == b'peerlist':
                try:
                    op = frames[6].bytes
                except IndexError:
                    op = None
                frames = [sender, recipient, proto, b'', msg_id, subsystem]
                if op == b'list':
                    frames.append(b'listing')
                    frames.extend(self._peers)
                else:
                    error = (b'unknown' if op else b'missing') + b' operation'
                    frames.extend([b'error', error])
            elif name == b'error':
                # Errors addressed to the router are not answered.
                return
            else:
                response = self.handle_subsystem(frames, user_id)
                if response is None:
                    # Handler does not know of the subsystem
                    errnum, errmsg = error = _INVALID_SUBSYSTEM
                    issue(ERROR, frames, error)
                    frames = [sender, recipient, proto, b'', msg_id,
                              b'error', errnum, errmsg, b'', subsystem]
                elif not response:
                    # Subsystem does not require a response
                    return
                else:
                    frames = response
        else:
            # Route all other requests to the recipient
            frames[:4] = [recipient, sender, proto, user_id]
        for peer in self._send(frames):
            self._drop_peer(peer)

    def _try_send(self, frames):
        '''Attempt a single non-blocking send of frames.

        Returns None on success. If the send fails with an errno listed
        in _ROUTE_ERRORS, the failure is reported through issue() and
        (errno, (errnum_frame, errmsg_frame)) is returned. Any other
        ZMQError is re-raised.
        '''
        try:
            self.socket.send_multipart(frames, flags=NOBLOCK, copy=False)
            self.issue(OUTGOING, frames)
        except ZMQError as exc:
            error = _ROUTE_ERRORS.get(exc.errno)
            if error is None:
                raise
            self.issue(ERROR, frames, error)
            return exc.errno, error
        return None

    def _send(self, frames):
        '''Send frames to their recipient, reporting failures to the sender.

        Expects outgoing frames:

          [RECIPIENT, SENDER, PROTO, USER_ID, MSG_ID, SUBSYS, ...]

        Returns a list of the peer identities (bytes) that were found to
        be unreachable and should be dropped by the caller.
        '''
        drop = []
        recipient, sender = frames[:2]
        # Try sending the message to its recipient
        failure = self._try_send(frames)
        if failure is None:
            return drop
        code, (errnum, errmsg) = failure
        if code == EHOSTUNREACH:
            drop.append(bytes(recipient))
        if code != EHOSTUNREACH or sender is not frames[0]:
            # Only send errors if the sender and recipient differ
            proto, user_id, msg_id, subsystem = frames[2:6]
            frames = [sender, b'', proto, user_id, msg_id,
                      b'error', errnum, errmsg, recipient, subsystem]
            failure = self._try_send(frames)
            if failure is not None and failure[0] == EHOSTUNREACH:
                drop.append(bytes(sender))
        return drop