# Source code for volttron.platform.auth

# -*- coding: utf-8 -*- {{{
# vim: set fenc=utf-8 ft=python sw=4 ts=4 sts=4 et:
#
# Copyright 2019, Battelle Memorial Institute.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This material was prepared as an account of work sponsored by an agency of
# the United States Government. Neither the United States Government nor the
# United States Department of Energy, nor Battelle, nor any of their
# employees, nor any jurisdiction or organization that has cooperated in the
# development of these materials, makes any warranty, express or
# implied, or assumes any legal liability or responsibility for the accuracy,
# completeness, or usefulness or any information, apparatus, product,
# software, or process disclosed, or represents that its use would not infringe
# privately owned rights. Reference herein to any specific commercial product,
# process, or service by trade name, trademark, manufacturer, or otherwise
# does not necessarily constitute or imply its endorsement, recommendation, or
# favoring by the United States Government or any agency thereof, or
# Battelle Memorial Institute. The views and opinions of authors expressed
# herein do not necessarily state or reflect those of the
# United States Government or any agency thereof.
#
# PACIFIC NORTHWEST NATIONAL LABORATORY operated by
# BATTELLE for the UNITED STATES DEPARTMENT OF ENERGY
# under Contract DE-AC05-76RL01830
# }}}


import bisect
import logging
import os
import random
import re
import shutil
import uuid
from collections import defaultdict

import gevent
import gevent.core
from gevent.fileobject import FileObject
from zmq import green as zmq

from volttron.platform import jsonapi
from volttron.platform.agent.known_identities import VOLTTRON_CENTRAL_PLATFORM, CONTROL, MASTER_WEB
from volttron.platform.vip.agent.errors import VIPError
from volttron.platform.vip.pubsubservice import ProtectedPubSubTopics
from .agent.utils import strip_comments, create_file_if_missing, watch_file
from .vip.agent import Agent, Core, RPC
from .vip.socket import encode_key, BASE64_ENCODED_CURVE_KEY_LEN

# Module-level logger shared by everything in this file.
_log = logging.getLogger(__name__)

# Characters that must be backslash-escaped when user-id parts are joined
# with commas (used by dump_user).
_dump_re = re.compile(r"([,\\])")
# Either an escaped character (captured in group 1) or a bare separating
# comma (used by load_user).
_load_re = re.compile(r"\\(.)|,")


def isregex(obj):
    """Return True if ``obj`` is written as a ``/pattern/`` regex literal.

    At least two items are required, so a lone ``'/'`` is not treated as an
    empty regex.
    """
    if len(obj) < 2:
        return False
    return obj[0] == '/' and obj[-1] == '/'
def dump_user(*args):
    """Join ``args`` into one comma-separated user id.

    Commas and backslashes inside each part are backslash-escaped, so
    load_user can split the result back into the original parts.
    """
    escaped_parts = (_dump_re.sub(r'\\\1', part) for part in args)
    return ','.join(escaped_parts)
def load_user(string):
    """Split a user id produced by dump_user back into its parts."""
    # An escaped character (group 1) is kept literally. An unescaped comma
    # becomes a NUL sentinel, which is then used as the split separator.
    unescaped = _load_re.sub(lambda m: m.group(1) or '\x00', string)
    return unescaped.split('\x00')
class AuthException(Exception):
    """General exception for any auth error"""
class AuthService(Agent):
    """Platform agent that answers ZAP authentication requests and
    distributes authorization data.

    It binds the ZAP handler socket (``inproc://zeromq.zap.01``) and
    authenticates connections against ``self.auth_entries``, which are
    loaded from the auth file. It exposes RPC methods to query
    capabilities/groups/roles and to approve, deny or delete failed
    authorizations. Auth and protected-topic updates are pushed to peers
    and to the pubsub subsystem.
    """

    def __init__(self, auth_file, protected_topics_file, setup_mode, aip, *args, **kwargs):
        """
        :param auth_file: path of the auth file (stored as an absolute path)
        :param protected_topics_file: path of the protected-topics JSON file
        :param setup_mode: when true, failed authentications are added to
            the auth file and answered with SUCCESS (see zap_loop)
        :param aip: object providing ``agent_uuid_from_pid`` (used by
            authenticate for local NULL connections)
        :param kwargs: ``allow_any`` (default False) accepts connections
            that no auth entry matches. Remaining args/kwargs go to Agent.
        """
        self.allow_any = kwargs.pop('allow_any', False)
        super(AuthService, self).__init__(*args, **kwargs)

        # This agent is started before the router so we need
        # to keep it from blocking.
        self.core.delay_running_event_set = False

        self.auth_file_path = os.path.abspath(auth_file)
        self.auth_file = AuthFile(self.auth_file_path)
        self.aip = aip
        self.zap_socket = None
        self._zap_greenlet = None
        self.auth_entries = []
        self._is_connected = False
        self._protected_topics_file = protected_topics_file
        self._protected_topics_file_path = os.path.abspath(protected_topics_file)
        # Replaced by _read_protected_topics_file on a successful read. It
        # starts empty so zap_loop and get_protected_topics still work when
        # that read fails (before, the attribute was simply missing).
        self._protected_topics = {}
        self._protected_topics_for_rmq = ProtectedPubSubTopics()
        self._setup_mode = setup_mode
        self._auth_failures = []
        self._auth_denied = []
        self._auth_approved = []
        # user -> set of protected topics (see _check_rmq_topic_permissions)
        self._user_to_permissions = defaultdict(set)

    @Core.receiver('onsetup')
    def setup_zap(self, sender, **kwargs):
        """Bind the ZAP socket, load both config files and start watching them."""
        self.zap_socket = zmq.Socket(zmq.Context.instance(), zmq.ROUTER)
        self.zap_socket.bind('inproc://zeromq.zap.01')
        if self.allow_any:
            _log.warning('insecure permissive authentication enabled')
        self.read_auth_file()
        self._read_protected_topics_file()
        self.core.spawn(watch_file, self.auth_file_path, self.read_auth_file)
        self.core.spawn(watch_file, self._protected_topics_file_path,
                        self._read_protected_topics_file)
        if self.core.messagebus == 'rmq':
            self.vip.peerlist.onadd.connect(self._check_topic_rules)

    def read_auth_file(self):
        """(Re)load enabled auth entries and, once connected, push updates to peers.

        Exceptions raised while sending updates are logged and re-raised.
        """
        _log.info('loading auth file %s', self.auth_file_path)
        entries = self.auth_file.read_allow_entries()
        entries = [entry for entry in entries if entry.enabled]
        # sort the entries so the regex credentials follow the concrete creds
        entries.sort()
        self.auth_entries = entries
        if self._is_connected:
            try:
                _log.debug("Sending auth updates to peers")
                # Give it few seconds for platform to startup or for the
                # router to detect agent install/remove action
                gevent.sleep(2)
                self._send_update()
            except BaseException as e:
                _log.error("Exception sending auth updates to peer. {}".format(e))
                raise
        _log.info('auth file %s loaded', self.auth_file_path)

    def get_protected_topics(self):
        """Return the protected-topics mapping most recently loaded from file."""
        return self._protected_topics

    def _read_protected_topics_file(self):
        """Load the protected-topics file and forward it to the message bus.

        If reading or parsing fails, the error is logged and the previous
        value of ``self._protected_topics`` is kept.
        """
        # Read protected topics file and send to router
        try:
            create_file_if_missing(self._protected_topics_file)
            with open(self._protected_topics_file) as fil:
                # Use gevent FileObject to avoid blocking the thread
                data = FileObject(fil, close=False).read()
                self._protected_topics = jsonapi.loads(data) if data else {}
                if self.core.messagebus == 'rmq':
                    self._load_protected_topics_for_rmq()
                    # Deferring the RMQ topic permissions to after "onstart" event
                else:
                    self._send_protected_update_to_pubsub(self._protected_topics)
        except Exception:
            _log.exception('error loading %s', self._protected_topics_file)

    def _send_update(self):
        """Send the user->capabilities map to every other peer via ``auth.update``.

        It also refreshes RMQ topic permissions or the zmq pubsub auth data.
        The peer list is fetched with up to 3 attempts. If every attempt
        fails and no peers are known, the last error is raised.
        """
        user_to_caps = self.get_user_to_capabilities()
        i = 0
        exception = None
        peers = None
        # peerlist times out lots of times when running test suite. This happens even with higher timeout in get()
        # but if we retry peerlist succeeds by second attempt most of the time!!!
        while not peers and i < 3:
            try:
                i = i + 1
                peers = self.vip.peerlist().get(timeout=0.5)
            # gevent.Timeout derives from BaseException, so list it explicitly
            # instead of catching BaseException (which would also swallow
            # GreenletExit when this greenlet is killed).
            except (Exception, gevent.Timeout) as e:
                _log.warning("Attempt {} to get peerlist failed with exception {}".format(i, e))
                peers = list(self.vip.peerlist.peers_list)
                _log.warning("Got list of peers from subsystem directly: {}".format(peers))
                exception = e

        if not peers:
            if exception is not None:
                raise exception
            # The peer list was empty without any error: there is no one to
            # notify over RPC (previously this executed ``raise None``).
            peers = []

        _log.debug("after getting peerlist to send auth updates")

        for peer in peers:
            if peer != self.core.identity:
                self.vip.rpc.call(peer, 'auth.update', user_to_caps)

        if self.core.messagebus == 'rmq':
            self._check_rmq_topic_permissions()
        else:
            self._send_auth_update_to_pubsub()

    def _send_auth_update_to_pubsub(self):
        """Send the user->capabilities map to the pubsub subsystem as ``auth_update``."""
        user_to_caps = self.get_user_to_capabilities()
        # Send auth update message to router
        json_msg = jsonapi.dumpb(
            dict(capabilities=user_to_caps)
        )
        frames = [zmq.Frame(b'auth_update'), zmq.Frame(json_msg)]
        # <recipient, subsystem, args, msg_id, flags>
        self.core.socket.send_vip(b'', b'pubsub', frames, copy=False)

    def _send_protected_update_to_pubsub(self, contents):
        """Send ``contents`` as ``protected_update`` to pubsub, only once connected.

        A VIPError while sending is logged, not raised.
        """
        protected_topics_msg = jsonapi.dumpb(contents)

        frames = [zmq.Frame(b'protected_update'), zmq.Frame(protected_topics_msg)]
        if self._is_connected:
            try:
                # <recipient, subsystem, args, msg_id, flags>
                self.core.socket.send_vip(b'', b'pubsub', frames, copy=False)
            except VIPError as ex:
                _log.error("Error in sending protected topics update to clear PubSub: " + str(ex))

    @Core.receiver('onstop')
    def stop_zap(self, sender, **kwargs):
        """Kill the greenlet running zap_loop, if it was started."""
        if self._zap_greenlet is not None:
            self._zap_greenlet.kill()

    @Core.receiver('onfinish')
    def unbind_zap(self, sender, **kwargs):
        """Unbind the ZAP socket created in setup_zap, if any."""
        if self.zap_socket is not None:
            self.zap_socket.unbind('inproc://zeromq.zap.01')

    @Core.receiver('onstart')
    def zap_loop(self, sender, **kwargs):
        """
        The zap loop is the starting of the authentication process for
        the VOLTTRON zmq message bus.  It talks directly with the low
        level socket so all responses must be byte like objects, in
        this case we are going to send zmq frames across the wire.

        Successful requests, and failures in setup mode, are answered at
        once. Other failures are queued in ``wait_list`` and answered with
        FAIL after a random delay. The delay doubles (capped at 100 s) while
        the same address keeps failing inside its current delay window.

        :param sender:
        :param kwargs:
        :return:
        """
        self._is_connected = True
        self._zap_greenlet = gevent.getcurrent()
        sock = self.zap_socket
        time = gevent.core.time
        # address -> (expire, delay) for addresses with delayed failures
        blocked = {}
        # sorted list of (expire, address, response) waiting for a FAIL reply
        wait_list = []
        # poll timeout in milliseconds; None blocks until a request arrives
        timeout = None
        if self.core.messagebus == 'rmq':
            # Check the topic permissions of all the connected agents
            self._check_rmq_topic_permissions()
        else:
            self._send_protected_update_to_pubsub(self._protected_topics)

        while True:
            events = sock.poll(timeout)
            now = time()
            if events:
                zap = sock.recv_multipart()

                version = zap[2]
                if version != b'1.0':
                    continue
                domain, address, userid, kind = zap[4:8]
                credentials = zap[8:]
                if kind == b'CURVE':
                    credentials[0] = encode_key(credentials[0])
                elif kind not in [b'NULL', b'PLAIN']:
                    continue
                response = zap[:4]
                domain = domain.decode("utf-8")
                address = address.decode("utf-8")
                kind = kind.decode("utf-8")
                user = self.authenticate(domain, address, kind, credentials)
                _log.info("AUTH: After authenticate user id: {0}, {1}".format(user, userid))
                if user:
                    _log.info(
                        'authentication success: userid=%r domain=%r, address=%r, '
                        'mechanism=%r, credentials=%r, user=%r',
                        userid, domain, address, kind, credentials[:1], user)
                    response.extend([b'200', b'SUCCESS', user.encode("utf-8"), b''])
                    sock.send_multipart(response)
                else:
                    userid = str(uuid.uuid4())
                    # NULL requests carry no credential frames at all.
                    credential = credentials[0] if credentials else None
                    # Log only the first frame: for PLAIN the second frame
                    # is the password.
                    _log.info(
                        'authentication failure: userid=%r, domain=%r, address=%r, '
                        'mechanism=%r, credentials=%r',
                        userid, domain, address, kind, credentials[:1])
                    # If in setup mode, add/update auth entry
                    if self._setup_mode:
                        self._update_auth_entry(domain, address, kind, credential, userid)
                        _log.info(
                            'new authentication entry added in setup mode: domain=%r, address=%r, '
                            'mechanism=%r, credentials=%r, user_id=%r',
                            domain, address, kind, credentials[:1], userid)
                        response.extend([b'200', b'SUCCESS', b'', b''])
                        _log.debug("AUTH response: {}".format(response))
                        sock.send_multipart(response)
                    else:
                        self._update_auth_failures(domain, address, kind, credential, userid)

                        try:
                            expire, delay = blocked[address]
                        except KeyError:
                            delay = random.random()
                        else:
                            if now >= expire:
                                delay = random.random()
                            else:
                                delay = min(delay * 2, 100)
                        expire = now + delay
                        # insort actually inserts; bisect.bisect only returned
                        # an index, so FAIL replies were never sent.
                        bisect.insort(wait_list, (expire, address, response))
                        blocked[address] = expire, delay

            # Send the FAIL replies whose delay has elapsed.
            while wait_list:
                expire, address, response = wait_list[0]
                if now < expire:
                    break
                wait_list.pop(0)
                response.extend([b'400', b'FAIL', b'', b''])
                sock.send_multipart(response)
                try:
                    if now >= blocked[address][0]:
                        blocked.pop(address)
                except KeyError:
                    pass
            # zmq poll timeouts are milliseconds; expire and now are seconds.
            timeout = (wait_list[0][0] - now) * 1000 if wait_list else None

    def authenticate(self, domain, address, mechanism, credentials):
        """Return the user id for a ZAP request, or None to reject it.

        Checks, in order:

        1. The first matching auth entry. Its ``user_id`` is returned, or a
           dump_user id if the user_id is empty.
        2. For NULL connections with a ``localhost:`` address: an agent
           found by pid, when the address has a pid field.
        3. Same case: the platform's own uid.
        4. ``allow_any``, if set.

        Malformed numeric fields in the address are treated as no match
        instead of raising, so one bad request cannot stop zap_loop.
        """
        for entry in self.auth_entries:
            if entry.match(domain, address, mechanism, credentials):
                return entry.user_id or dump_user(
                    domain, address, mechanism, *credentials[:1])
        if mechanism == 'NULL' and address.startswith('localhost:'):
            parts = address.split(':')[1:]
            if len(parts) > 2:
                try:
                    pid = int(parts[2])
                except ValueError:
                    pid = None
                if pid is not None:
                    agent_uuid = self.aip.agent_uuid_from_pid(pid)
                    if agent_uuid:
                        return dump_user(domain, address, 'AGENT', agent_uuid)
            try:
                uid = int(parts[0])
            except ValueError:
                uid = None
            if uid is not None and uid == os.getuid():
                return dump_user(domain, address, mechanism, *credentials[:1])
        if self.allow_any:
            return dump_user(domain, address, mechanism, *credentials[:1])
        return None

    @RPC.export
    def get_user_to_capabilities(self):
        """RPC method

        Gets a mapping of all users to their capabilities.

        :returns: mapping of users to capabilities
        :rtype: dict
        """
        user_to_caps = {}
        for entry in self.auth_entries:
            user_to_caps[entry.user_id] = entry.capabilities
        return user_to_caps

    @RPC.export
    def get_authorizations(self, user_id):
        """RPC method

        Gets capabilities, groups, and roles for a given user.

        An entry matches if its ``user_id`` equals ``user_id``, or if
        ``user_id`` is a 4-part dump_user id whose parts match the entry.

        :param user_id: user id field from VOLTTRON Interconnect Protocol
        :type user_id: str
        :returns: tuple of capability-dict, group-list, role-list, or None
            if no entry matches
        :rtype: tuple
        """
        use_parts = True
        try:
            domain, address, mechanism, credentials = load_user(user_id)
        except ValueError:
            use_parts = False
        for entry in self.auth_entries:
            if entry.user_id == user_id:
                return entry.capabilities, entry.groups, entry.roles
            elif use_parts:
                if entry.match(domain, address, mechanism, [credentials]):
                    return entry.capabilities, entry.groups, entry.roles
        return None

    @RPC.export
    @RPC.allow(capabilities="allow_auth_modifications")
    def approve_authorization_failure(self, user_id):
        """RPC method

        Approves a previously failed (or denied) authorization: adds an
        auth entry for it and moves the record to the approved list.

        :param user_id: user id field from VOLTTRON Interconnect Protocol
        :type user_id: str
        """
        # Iterate over copies: removing items from a list while iterating
        # it skips the element that follows each removal.
        for pending in list(self._auth_failures):
            if user_id == pending['user_id']:
                self._update_auth_entry(
                    pending['domain'],
                    pending['address'],
                    pending['mechanism'],
                    pending['credentials'],
                    pending['user_id']
                )
                self._auth_approved.append(pending)
                self._auth_failures.remove(pending)

        for pending in list(self._auth_denied):
            if user_id == pending['user_id']:
                self._update_auth_entry(
                    pending['domain'],
                    pending['address'],
                    pending['mechanism'],
                    pending['credentials'],
                    pending['user_id']
                )
                self._auth_approved.append(pending)
                self._auth_denied.remove(pending)

    @RPC.export
    @RPC.allow(capabilities="allow_auth_modifications")
    def deny_authorization_failure(self, user_id):
        """RPC method

        Denies a previously failed (or approved) authorization. For an
        approved record, auth entries with its credentials are removed.

        :param user_id: user id field from VOLTTRON Interconnect Protocol
        :type user_id: str
        """
        # Iterate over copies because matching records are removed.
        for pending in list(self._auth_failures):
            if user_id == pending['user_id']:
                self._auth_denied.append(pending)
                self._auth_failures.remove(pending)

        for pending in list(self._auth_approved):
            if user_id == pending['user_id']:
                self._remove_auth_entry(pending['credentials'])
                self._auth_denied.append(pending)
                self._auth_approved.remove(pending)

    @RPC.export
    @RPC.allow(capabilities="allow_auth_modifications")
    def delete_authorization_failure(self, user_id):
        """RPC method

        Deletes the failed, approved and denied records for a user id.
        For an approved record, auth entries with its credentials are also
        removed.

        :param user_id: user id field from VOLTTRON Interconnect Protocol
        :type user_id: str
        """
        # Iterate over copies because matching records are removed.
        for pending in list(self._auth_failures):
            if user_id == pending['user_id']:
                self._auth_failures.remove(pending)

        for pending in list(self._auth_approved):
            if user_id == pending['user_id']:
                self._remove_auth_entry(pending['credentials'])
                self._auth_approved.remove(pending)

        for pending in list(self._auth_denied):
            if user_id == pending['user_id']:
                self._auth_denied.remove(pending)

    @RPC.export
    def get_authorization_failures(self):
        """RPC method: return a shallow copy of the pending failure records."""
        return list(self._auth_failures)

    @RPC.export
    def get_authorization_approved(self):
        """RPC method: return a shallow copy of the approved records."""
        return list(self._auth_approved)

    @RPC.export
    def get_authorization_denied(self):
        """RPC method: return a shallow copy of the denied records."""
        return list(self._auth_denied)

    def _get_authorizations(self, user_id, index):
        """Convenience method for getting authorization component by index"""
        auths = self.get_authorizations(user_id)
        if auths:
            return auths[index]
        return []

    @RPC.export
    def get_capabilities(self, user_id):
        """RPC method

        Gets capabilities for a given user.

        :param user_id: user id field from VOLTTRON Interconnect Protocol
        :type user_id: str
        :returns: mapping of capability name to its arguments (or None);
            an empty list if the user is unknown
        :rtype: dict
        """
        return self._get_authorizations(user_id, 0)

    @RPC.export
    def get_groups(self, user_id):
        """RPC method

        Gets groups for a given user.

        :param user_id: user id field from VOLTTRON Interconnect Protocol
        :type user_id: str
        :returns: list of groups
        :rtype: list
        """
        return self._get_authorizations(user_id, 1)

    @RPC.export
    def get_roles(self, user_id):
        """RPC method

        Gets roles for a given user.

        :param user_id: user id field from VOLTTRON Interconnect Protocol
        :type user_id: str
        :returns: list of roles
        :rtype: list
        """
        return self._get_authorizations(user_id, 2)

    def _update_auth_entry(self, domain, address, mechanism, credential, user_id):
        """Add an auth entry for a connection without overwriting existing ones.

        An AuthException raised by ``auth_file.add`` is logged, not raised.
        """
        # Make a new entry
        fields = {
            "domain": domain,
            "address": address,
            "mechanism": mechanism,
            "credentials": credential,
            "user_id": user_id,
            "groups": "",
            "roles": "",
            "capabilities": "",
            "comments": "Auth entry added in setup mode",
        }
        new_entry = AuthEntry(**fields)

        try:
            self.auth_file.add(new_entry, overwrite=False)
        except AuthException as err:
            _log.error('ERROR: %s\n' % str(err))

    def _remove_auth_entry(self, credential):
        """Remove auth-file entries with ``credential``; AuthException is logged."""
        try:
            self.auth_file.remove_by_credentials(credential)
        except AuthException as err:
            _log.error('ERROR: %s\n' % str(err))

    def _update_auth_failures(self, domain, address, mechanism, credential, user_id):
        """Record a failed authentication.

        If a denied or pending record has the same domain, address,
        mechanism and credential, its ``retries`` count is incremented.
        Otherwise a new pending failure record is appended.
        """
        for entry in self._auth_denied:
            # Check if failure entry has been denied. If so, increment the failure's denied count
            if ((entry['domain'] == domain) and
                    (entry['address'] == address) and
                    (entry['mechanism'] == mechanism) and
                    (entry['credentials'] == credential)):
                entry['retries'] += 1
                return

        for entry in self._auth_failures:
            # Check if failure entry exists. If so, increment the failure count
            if ((entry['domain'] == domain) and
                    (entry['address'] == address) and
                    (entry['mechanism'] == mechanism) and
                    (entry['credentials'] == credential)):
                entry['retries'] += 1
                return
        # Add a new failure entry
        fields = {
            "domain": domain,
            "address": address,
            "mechanism": mechanism,
            "credentials": credential,
            "user_id": user_id,
            "retries": 1
        }
        self._auth_failures.append(dict(fields))
        return

    def _load_protected_topics_for_rmq(self):
        """Rebuild ``_protected_topics_for_rmq`` from the 'write-protect' list.

        If an item lacks 'topic' or 'capabilities', the error is logged and
        the previous value is kept.
        """
        try:
            write_protect = self._protected_topics['write-protect']
        except KeyError:
            write_protect = []

        topics = ProtectedPubSubTopics()
        try:
            for entry in write_protect:
                topics.add(entry['topic'], entry['capabilities'])
        except KeyError:
            _log.exception('invalid format for protected topics ')
        else:
            self._protected_topics_for_rmq = topics

    def _check_topic_rules(self, sender, **kwargs):
        """Peerlist ``onadd`` handler: re-check RMQ topic permissions shortly after."""
        delay = 0.05
        self.core.spawn_later(delay, self._check_rmq_topic_permissions)

    def _check_rmq_topic_permissions(self):
        """
        Go through the topic permissions for each agent based on the protected topic setting.
        Update the permissions for the agent/user based on the latest configuration
        :return:
        """
        # NOTE(review): this method returns immediately, so the permission
        # computation below never runs. It is kept for when RMQ topic
        # permissions are re-enabled.
        return
        # Get agent to capabilities mapping
        user_to_caps = self.get_user_to_capabilities()
        # Get topics to capabilities mapping
        topic_to_caps = self._protected_topics_for_rmq.get_topic_caps()  # topic to caps

        peers = self.vip.peerlist().get(timeout=5)

        if not user_to_caps or not topic_to_caps:
            # clear all old permission rules
            for peer in peers:
                self._user_to_permissions[peer].clear()
        else:
            for topic, caps_for_topic in topic_to_caps.items():
                for user in user_to_caps:
                    try:
                        caps_for_user = user_to_caps[user]
                        common_caps = list(set(caps_for_user).intersection(caps_for_topic))
                        if common_caps:
                            self._user_to_permissions[user].add(topic)
                        else:
                            try:
                                self._user_to_permissions[user].remove(topic)
                            except KeyError:
                                if not self._user_to_permissions[user]:
                                    self._user_to_permissions[user] = set()
                    except KeyError:
                        try:
                            self._user_to_permissions[user].remove(topic)
                        except KeyError:
                            if not self._user_to_permissions[user]:
                                self._user_to_permissions[user] = set()

        all_topics = set()
        for user in user_to_caps:
            all_topics.update(self._user_to_permissions[user])

        # Set topic permissions now
        for peer in peers:
            not_allowed = all_topics.difference(self._user_to_permissions[peer])
            self._update_topic_permission_tokens(peer, not_allowed)

    def _update_topic_permission_tokens(self, identity, not_allowed):
        """
        Make rules for read and write permission on topic (routing key)
        for an agent based on protected topics setting

        NOTE(review): the computed permissions are never applied; the
        calls to set_topic_permissions_for_user are commented out.

        :param identity: identity of the agent
        :param not_allowed: protected topics this agent may not write to
        :return:
        """
        read_tokens = ["{instance}.{identity}".format(instance=self.core.instance_name, identity=identity),
                       "__pubsub__.*"]
        write_tokens = ["{instance}.*".format(instance=self.core.instance_name, identity=identity)]

        if not not_allowed:
            write_tokens.append("__pubsub__.{instance}.*".format(instance=self.core.instance_name))
        else:
            not_allowed_string = "|".join(not_allowed)
            write_tokens.append("__pubsub__.{instance}.".format(instance=self.core.instance_name) +
                                "^(!({not_allow})).*$".format(not_allow=not_allowed_string))
        current = self.core.rmq_mgmt.get_topic_permissions_for_user(identity)
        if current and isinstance(current, list):
            current = current[0]
            dift = False
            read_allowed_str = "|".join(read_tokens)
            write_allowed_str = "|".join(write_tokens)
            if re.search(current['read'], read_allowed_str):
                dift = True
                current["read"] = read_allowed_str
            if re.search(current["write"], write_allowed_str):
                dift = True
                current["write"] = write_allowed_str
            # if dift:
            #     set_topic_permissions_for_user(current, identity)
        else:
            current = dict()
            current["exchange"] = "volttron"
            current["read"] = "|".join(read_tokens)
            current["write"] = "|".join(write_tokens)
            # set_topic_permissions_for_user(current, identity)

    def _check_token(self, actual, allowed):
        """Return a copy of ``actual`` with the items found in ``allowed`` removed."""
        pending = actual[:]
        for tk in actual:
            if tk in allowed:
                pending.remove(tk)
        return pending
class String(str):
    """A ``str`` whose ``match`` tests equality, or, for values written as
    ``/pattern/``, a full regex match of ``pattern``.
    """

    def __new__(cls, value):
        obj = super(String, cls).__new__(cls, value)
        if isregex(obj):
            # Group the body so an alternation is anchored as a whole.
            # Without the group, '^a|b$' means "starts with a" OR
            # "ends with b".
            obj.regex = regex = re.compile('^(?:' + obj[1:-1] + ')$')
            # fullmatch also rejects a trailing newline, which '$' alone
            # would accept.
            obj.match = lambda val: bool(regex.fullmatch(val))
        return obj

    def match(self, value):
        """Return True if ``value`` equals this string (non-regex values)."""
        return value == self
class List(list):
    """A list of matchers that matches a value when any element does."""

    def match(self, value):
        """Return True as soon as one element's ``match(value)`` is truthy."""
        return any(elem.match(value) for elem in self)
class AuthEntryInvalid(AuthException):
    """Raised when an AuthEntry (or one of its fields) is invalid."""
class AuthEntry(object):
    """An authentication entry contains fields for authenticating and
    granting permissions to an agent that connects to the platform.

    :param str domain: Name assigned to locally bound address
    :param str address: Remote address of the agent
    :param str mechanism: Authentication mechanism, valid options are
        'NULL' (no authentication), 'PLAIN' (username/password),
        'CURVE' (CurveMQ public/private keys)
    :param str credentials: Value depends on `mechanism` parameter:
        `None` if mechanism is 'NULL'; password if mechanism is
        'PLAIN'; encoded public key if mechanism is 'CURVE' (see
        :py:meth:`volttron.platform.vip.socket.encode_key` for method
        to encode public key)
    :param str user_id: Name to associate with agent (Note: this does
        not have to match the agent's VIP identity). A random uuid4 string
        is used when omitted.
    :param list capabilities: Authorized capabilities for this agent
    :param list roles: Authorized roles for this agent. (Role names map
        to a set of capabilities)
    :param list groups: Authorized groups for this agent. (Group names
        map to a set of roles)
    :param str comments: Comments to associate with entry
    :param bool enabled: Entry will only be used if this value is True
    :param kwargs: These extra arguments will be ignored
    :raises AuthEntryInvalid: if mechanism or credentials are invalid
    """

    def __init__(self, domain=None, address=None, mechanism='CURVE',
                 credentials=None, user_id=None, groups=None, roles=None,
                 capabilities=None, comments=None, enabled=True, **kwargs):
        self.domain = AuthEntry._build_field(domain)
        self.address = AuthEntry._build_field(address)
        self.mechanism = mechanism
        self.credentials = AuthEntry._build_field(credentials)
        self.groups = AuthEntry._build_field(groups) or []
        self.roles = AuthEntry._build_field(roles) or []
        self.capabilities = AuthEntry.build_capabilities_field(capabilities) or {}
        self.comments = AuthEntry._build_field(comments)
        if user_id is None:
            user_id = str(uuid.uuid4())
        self.user_id = user_id
        self.enabled = enabled
        if kwargs:
            _log.debug(
                'auth record has unrecognized keys: %r' % (list(kwargs.keys()),))
        self._check_validity()

    def _has_regex_credentials(self):
        # String instances get a ``regex`` attribute only for /.../ values;
        # None (NULL mechanism) and List credentials have none.
        return hasattr(self.credentials, 'regex')

    def __lt__(self, other):
        """Entries with non-regex credentials are less than entries with
        regex credentials. When sorted, non-regex credentials are checked first.

        Two entries of the same kind compare as equal. This makes ``<`` a
        valid ordering, and the stable ``list.sort`` keeps file order
        within each kind. The previous version returned True for any
        non-regex ``self``.
        """
        if not isinstance(other, AuthEntry):
            return NotImplemented
        return (not self._has_regex_credentials() and
                other._has_regex_credentials())

    @staticmethod
    def _build_field(value):
        """Wrap a value: falsy -> None, str -> String, iterable -> List of String."""
        if not value:
            return None
        if isinstance(value, str):
            return String(value)
        return List(String(elem) for elem in value)

    @staticmethod
    def build_capabilities_field(value):
        """Normalize capabilities into a ``{name: args-or-None}`` dict.

        :param value: a str, a dict, or a list of strs/dicts
        :returns: None for a falsy value, otherwise the normalized dict
        :raises AuthEntryInvalid: for any other element type
        """
        if not value:
            return None

        if isinstance(value, list):
            result = dict()
            for elem in value:
                # Add a capability if it is not there yet, or replace an
                # existing one that has no args. So a capability given with
                # args overrides the same capability given as a bare str.
                # (Before, the result of dict.update, always None, was
                # tested here, so the last element always won.)
                for name, args in AuthEntry._get_capability(elem).items():
                    if result.get(name) is None:
                        result[name] = args
            _log.debug("Returning field _build_capabilities {}".format(result))
            return result
        else:
            return AuthEntry._get_capability(value)

    @staticmethod
    def _get_capability(value):
        """Convert one capability (str or dict) into a dict."""
        # Literal braces are doubled so str.format does not read them as
        # replacement fields (which raised KeyError instead of
        # AuthEntryInvalid).
        err_message = "Invalid capability value: {} of type {}. Capability entries can only be a string or " \
                      "dictionary or list containing string/dictionary. " \
                      "dictionaries should be of the format {{'capability_name':None}} or " \
                      "{{'capability_name':{{'arg1':'value',...}}}}"
        if isinstance(value, str):
            return {value: None}
        elif isinstance(value, dict):
            return value
        else:
            raise AuthEntryInvalid(err_message.format(value, type(value)))

    def add_capabilities(self, capabilities):
        """Merge ``capabilities`` (any form accepted by build_capabilities_field) into this entry."""
        temp = AuthEntry.build_capabilities_field(capabilities)
        if temp:
            self.capabilities.update(temp)

    def match(self, domain, address, mechanism, credentials):
        """Return True if the request fields match this entry.

        ``credentials`` is a list; only its first element is compared.
        Domain and address match anything when unset on the entry.
        """
        return ((self.domain is None or self.domain.match(domain)) and
                (self.address is None or self.address.match(address)) and
                self.mechanism == mechanism and
                (self.mechanism == 'NULL' or
                 (len(self.credentials) > 0 and
                  len(credentials) > 0 and
                  self.credentials.match(credentials[0]))))

    def __str__(self):
        return ('domain={0.domain!r}, address={0.address!r}, '
                'mechanism={0.mechanism!r}, credentials={0.credentials!r}, '
                'user_id={0.user_id!r}, capabilities={0.capabilities!r}'.format(self))

    def __repr__(self):
        cls = self.__class__
        return '%s.%s(%s)' % (cls.__module__, cls.__name__, self)

    @staticmethod
    def valid_credentials(cred, mechanism='CURVE'):
        """Raises AuthEntryInvalid if credentials are invalid"""
        AuthEntry.valid_mechanism(mechanism)
        if mechanism == 'NULL':
            return
        if cred is None:
            raise AuthEntryInvalid(
                'credentials parameter is required for mechanism {}'
                .format(mechanism))
        if isregex(cred):
            return
        if mechanism == 'CURVE' and len(cred) != BASE64_ENCODED_CURVE_KEY_LEN:
            raise AuthEntryInvalid('Invalid CURVE public key {}'.format(cred))

    @staticmethod
    def valid_mechanism(mechanism):
        """Raises AuthEntryInvalid if mechanism is invalid"""
        if mechanism not in ('NULL', 'PLAIN', 'CURVE'):
            raise AuthEntryInvalid(
                'mechanism must be either "NULL", "PLAIN" or "CURVE"')

    def _check_validity(self):
        """Raises AuthEntryInvalid if entry is invalid"""
        AuthEntry.valid_credentials(self.credentials, self.mechanism)
class AuthFile(object):
    """Reads and writes the platform's JSON authentication file."""

    def __init__(self, auth_file=None):
        """Bind to ``auth_file`` (default ``$VOLTTRON_HOME/auth.json``) and
        upgrade its format if it is older than ``self.version``."""
        if auth_file is None:
            volttron_home = os.path.expanduser(
                os.environ.get('VOLTTRON_HOME', '~/.volttron'))
            auth_file = os.path.join(volttron_home, 'auth.json')
        self.auth_file = auth_file
        self._check_for_upgrade()

    @property
    def version(self):
        """Auth-file format version written by this code (a fresh dict each call)."""
        return {'major': 1, 'minor': 2}

    def _check_for_upgrade(self):
        """Upgrade the file when its version differs and is not newer (major)."""
        allow_list, groups, roles, version = self._read()
        current = self.version
        if version == current:
            return
        if version['major'] <= current['major']:
            self._upgrade(allow_list, groups, roles, version)
        else:
            _log.error('This version of VOLTTRON cannot parse {}. '
                       'Please upgrade VOLTTRON or move or delete '
                       'this file.'.format(self.auth_file))

    def _read(self):
        """Return ``(allow_list, groups, roles, version)`` parsed from the file.

        The file is created if missing. Comments are stripped before JSON
        parsing. Any error is logged and the defaults are returned, including
        version ``{'major': 0, 'minor': 0}``.
        """
        auth_data = {}
        try:
            create_file_if_missing(self.auth_file)
            with open(self.auth_file) as fil:
                # Use gevent FileObject to avoid blocking the thread
                raw_text = FileObject(fil, close=False).read()
            if isinstance(raw_text, bytes):
                raw_text = raw_text.decode("utf-8")
            data = strip_comments(raw_text)
            if data:
                auth_data = jsonapi.loads(data)
        except Exception:
            _log.exception('error loading %s', self.auth_file)

        return (auth_data.get('allow', []),
                auth_data.get('groups', {}),
                auth_data.get('roles', {}),
                auth_data.get('version', {'major': 0, 'minor': 0}))
[docs] def read(self): """Gets the allowed entries, groups, and roles from the auth file. :returns: tuple of allow-entries-list, groups-dict, roles-dict :rtype: tuple """ allow_list, groups, roles, _ = self._read() entries = self._get_entries(allow_list) self._use_groups_and_roles(entries, groups, roles) return entries, groups, roles
def _upgrade(self, allow_list, groups, roles, version): backup = self.auth_file + '.' + str(uuid.uuid4()) + '.bak' shutil.copy(self.auth_file, backup) _log.info('Created backup of {} at {}'.format(self.auth_file, backup)) def warn_invalid(entry, msg=''): _log.warn('Invalid entry {} in auth file {}. {}' .format(entry, self.auth_file, msg)) def upgrade_0_to_1(allow_list): new_allow_list = [] for entry in allow_list: try: credentials = entry['credentials'] except KeyError: warn_invalid(entry) continue if isregex(credentials): msg = 'Cannot upgrade entries with regex credentials' warn_invalid(entry, msg) continue if credentials == 'NULL': mechanism = 'NULL' credentials = None else: match = re.match(r'^(PLAIN|CURVE):(.*)', credentials) if match is None: msg = 'Expected NULL, PLAIN, or CURVE credentials' warn_invalid(entry, msg) continue try: mechanism = match.group(1) credentials = match.group(2) except IndexError: warn_invalid(entry, 'Unexpected credential format') continue new_allow_list.append({ "domain": entry.get('domain'), "address": entry.get('address'), "mechanism": mechanism, "credentials": credentials, "user_id": entry.get('user_id'), "groups": entry.get('groups', []), "roles": entry.get('roles', []), "capabilities": entry.get('capabilities', []), "comments": entry.get('comments'), "enabled": entry.get('enabled', True) }) return new_allow_list def upgrade_1_0_to_1_1(allow_list): new_allow_list = [] user_id_set = set() for entry in allow_list: user_id = entry.get('user_id') if user_id: if user_id in user_id_set: new_user_id = str(uuid.uuid4()) msg = ('user_id {} is already present in ' 'authentication entry. 
Changed to user_id to ' '{}').format(user_id, new_user_id) _log.warn(msg) user_id_ = new_user_id else: user_id = str(uuid.uuid4()) user_id_set.add(user_id) entry['user_id'] = user_id new_allow_list.append(entry) return new_allow_list def upgrade_1_1_to_1_2(allow_list): new_allow_list = [] for entry in allow_list: user_id = entry.get('user_id') if user_id in [CONTROL, VOLTTRON_CENTRAL_PLATFORM]: user_id = '/.*/' capabilities = entry.get('capabilities') entry['capabilities'] = AuthEntry.build_capabilities_field(capabilities) or {} entry['capabilities']['edit_config_store'] = {'identity': user_id} new_allow_list.append(entry) return new_allow_list if version['major'] == 0: allow_list = upgrade_0_to_1(allow_list) version['major'] = 1 version['minor'] = 0 if version['major'] == 1 and version['minor'] == 0: allow_list = upgrade_1_0_to_1_1(allow_list) version['minor'] = 1 if version['major'] == 1 and version['minor'] == 1: allow_list = upgrade_1_1_to_1_2(allow_list) entries = self._get_entries(allow_list) self._write(entries, groups, roles)
[docs] def read_allow_entries(self): """Gets the allowed entries from the auth file. :returns: list of allow-entries :rtype: list """ return self.read()[0]
[docs] def find_by_credentials(self, credentials): """Find all entries that have the given credentials :param str credentials: The credentials to search for :return: list of entries :rtype: list """ return [entry for entry in self.read_allow_entries() if str(entry.credentials) == credentials]
def _get_entries(self, allow_list): entries = [] for file_entry in allow_list: try: entry = AuthEntry(**file_entry) except TypeError: _log.warn('invalid entry %r in auth file %s', file_entry, self.auth_file) except AuthEntryInvalid as e: _log.warn('invalid entry %r in auth file %s (%s)', file_entry, self.auth_file, str(e)) else: entries.append(entry) return entries def _use_groups_and_roles(self, entries, groups, roles): """Add capabilities to each entry based on groups and roles""" for entry in entries: entry_roles = entry.roles # Each group is a list of roles for group in entry.groups: entry_roles += groups.get(group, []) capabilities = [] # Each role is a list of capabilities for role in entry_roles: capabilities += roles.get(role, []) entry.add_capabilities(list(set(capabilities))) def _check_if_exists(self, entry): """Raises AuthFileEntryAlreadyExists if entry is already in file""" for index, prev_entry in enumerate(self.read_allow_entries()): if entry.user_id == prev_entry.user_id: raise AuthFileUserIdAlreadyExists(entry.user_id, [index]) # Compare AuthEntry objects component-wise, rather than # using match, because match will evaluate regex. if (prev_entry.domain == entry.domain and prev_entry.address == entry.address and prev_entry.mechanism == entry.mechanism and prev_entry.credentials == entry.credentials): raise AuthFileEntryAlreadyExists([index]) def _update_by_indices(self, auth_entry, indices): """Updates all entries at given indices with auth_entry""" for index in indices: self.update_by_index(auth_entry, index)
def add(self, auth_entry, overwrite=False):
    """Add an AuthEntry to the auth file.

    :param auth_entry: authentication entry to add
    :param overwrite: when True, an entry that conflicts with an existing
        one replaces that entry instead of raising
    :type auth_entry: AuthEntry
    :type overwrite: bool

    .. warning:: If overwrite is False and auth_entry matches an existing
                 entry, AuthFileEntryAlreadyExists (or its subclass
                 AuthFileUserIdAlreadyExists) is raised.
    """
    try:
        self._check_if_exists(auth_entry)
    except AuthFileEntryAlreadyExists as err:
        if not overwrite:
            raise err
        _log.debug("Updating existing auth entry with {} ".format(auth_entry))
        self._update_by_indices(auth_entry, err.indices)
    else:
        entries, groups, roles = self.read()
        self._write(entries + [auth_entry], groups, roles)
        _log.debug("Added auth entry {} ".format(auth_entry))
    # Runs after both the update and the append path. NOTE(review): the
    # reason for this one-second gevent yield is not visible here --
    # presumably it lets other greenlets observe the file change; confirm.
    gevent.sleep(1)
def remove_by_credentials(self, credentials):
    """Remove every entry whose credentials equal ``credentials``.

    Groups and roles are written back unchanged.

    :param credentials: entries with exactly this credential value are removed
    :type credentials: str
    """
    entries, groups, roles = self.read()
    kept = list(filter(lambda entry: entry.credentials != credentials,
                       entries))
    self._write(kept, groups, roles)
def remove_by_index(self, index):
    """Remove the single entry at ``index`` from the auth file.

    This is a convenience wrapper around ``remove_by_indices``.

    :param index: index of entry to remove
    :type index: int

    .. warning:: Calling with out-of-range index will raise
                 AuthFileIndexError
    """
    self.remove_by_indices(indices=[index])
def remove_by_indices(self, indices):
    """Remove the entries at the given indices from the auth file.

    Every index is first resolved against the current entry list. Negative
    indices count from the end, as with Python lists. Duplicates, including
    a negative index and the positive index it refers to, therefore remove
    the entry only once. Nothing is written if any index is out of range.

    :param indices: list of indices of entries to remove
    :type indices: list

    .. warning:: Calling with out-of-range index will raise
                 AuthFileIndexError
    """
    entries, groups, roles = self.read()
    count = len(entries)
    positions = set()
    for index in indices:
        position = index + count if index < 0 else index
        if not 0 <= position < count:
            raise AuthFileIndexError(index)
        positions.add(position)
    # Delete from the back so earlier positions stay valid.
    for position in sorted(positions, reverse=True):
        del entries[position]
    self._write(entries, groups, roles)
def _set_groups_or_roles(self, groups_or_roles, is_group=True): param_name = 'groups' if is_group else 'roles' if not isinstance(groups_or_roles, dict): raise ValueError('{} parameter must be dict'.format(param_name)) for key, value in groups_or_roles.items(): if not isinstance(value, list): raise ValueError('each value of the {} dict must be ' 'a list'.format(param_name)) entries, groups, roles = self.read() if is_group: groups = groups_or_roles else: roles = groups_or_roles self._write(entries, groups, roles)
def set_groups(self, groups):
    """Define the mapping of group names to role lists.

    :param groups: dict where the keys are group names and the values are
                   lists of role names
    :type groups: dict

    .. warning:: Calling with invalid groups will raise ValueError
    """
    self._set_groups_or_roles(groups, is_group=True)
def set_roles(self, roles):
    """Define the mapping of role names to capability lists.

    :param roles: dict where the keys are role names and the values are
                  lists of capability names
    :type roles: dict

    .. warning:: Calling with invalid roles will raise ValueError
    """
    self._set_groups_or_roles(roles, is_group=False)
def update_by_index(self, auth_entry, index):
    """Overwrite the entry at ``index`` with ``auth_entry`` and save the file.

    :param auth_entry: replacement authorization entry
    :param index: position of the entry to replace
    :type auth_entry: AuthEntry
    :type index: int

    .. warning:: Calling with out-of-range index will raise
                 AuthFileIndexError
    """
    entries, groups, roles = self.read()
    try:
        entries[index] = auth_entry
    except IndexError:
        raise AuthFileIndexError(index)
    else:
        self._write(entries, groups, roles)
def _write(self, entries, groups, roles):
    """Serialize the auth data to ``self.auth_file`` as indented JSON.

    The document has the keys ``allow`` (each entry's ``vars()``), ``groups``,
    ``roles`` and ``version`` (taken from ``self.version``). The file is
    truncated and rewritten.
    """
    payload = {
        'allow': list(map(vars, entries)),
        'groups': groups,
        'roles': roles,
        'version': self.version,
    }
    with open(self.auth_file, 'w') as auth_fp:
        jsonapi.dump(payload, auth_fp, indent=2)
class AuthFileIndexError(AuthException, IndexError):
    """Exception for invalid indices provided to AuthFile.

    :param indices: a single index or a list of indices; always stored as a
        list in ``self.indices``
    :param message: optional message; when omitted, one is built from
        ``indices``
    """

    def __init__(self, indices, message=None):
        if not isinstance(indices, list):
            indices = [indices]
        if message is None:
            message = 'Invalid {}: {}'.format(
                'indices' if len(indices) > 1 else 'index', indices)
        super(AuthFileIndexError, self).__init__(message)
        self.indices = indices
class AuthFileEntryAlreadyExists(AuthFileIndexError):
    """Raised when an entry being added duplicates one already in the file.

    :param indicies: index (or list of indices) of the matching entry; the
        parameter spelling is kept as-is since callers may pass it by name
    :param message: optional override for the default message
    """

    def __init__(self, indicies, message=None):
        if message is None:
            template = ('entry matches domain, address and credentials at '
                        'index {}')
            message = template.format(indicies)
        super(AuthFileEntryAlreadyExists, self).__init__(indicies, message)
class AuthFileUserIdAlreadyExists(AuthFileEntryAlreadyExists):
    """Raised when an entry being added uses a user_id that is already taken.

    :param user_id: the conflicting user_id
    :param indicies: index (or list of indices) of the existing entry; the
        parameter spelling is kept as-is since callers may pass it by name
    :param message: optional override for the default message
    """

    def __init__(self, user_id, indicies, message=None):
        if message is None:
            template = 'user_id {} is already in use at index {}'
            message = template.format(user_id, indicies)
        super(AuthFileUserIdAlreadyExists, self).__init__(indicies, message)