Enable host_key checking at the strategy level

Implements a new method in the ssh connection plugin (fetch_and_store_key)
which is used to prefetch the key using ssh-keyscan.
This commit is contained in:
James Cammarata 2015-12-15 09:39:13 -05:00
parent 15135f3c16
commit e5c2c03dea
6 changed files with 273 additions and 33 deletions

View file

@ -32,6 +32,7 @@ from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVar
from ansible.playbook.conditional import Conditional
from ansible.playbook.task import Task
from ansible.template import Templar
from ansible.utils.connection import get_smart_connection_type
from ansible.utils.encrypt import key_for_hostname
from ansible.utils.listify import listify_lookup_plugin_terms
from ansible.utils.unicode import to_unicode
@ -564,21 +565,7 @@ class TaskExecutor:
conn_type = self._play_context.connection
if conn_type == 'smart':
conn_type = 'ssh'
if sys.platform.startswith('darwin') and self._play_context.password:
# due to a current bug in sshpass on OSX, which can trigger
# a kernel panic even for non-privileged users, we revert to
# paramiko on that OS when a SSH password is specified
conn_type = "paramiko"
else:
# see if SSH can support ControlPersist if not use paramiko
try:
cmd = subprocess.Popen(['ssh','-o','ControlPersist'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
(out, err) = cmd.communicate()
if "Bad configuration option" in err or "Usage:" in err:
conn_type = "paramiko"
except OSError:
conn_type = "paramiko"
conn_type = get_smart_connection_type(self._play_context)
connection = self._shared_loader_obj.connection_loader.get(conn_type, self._play_context, self._new_stdin)
if not connection:

View file

@ -57,6 +57,7 @@ class Host:
name=self.name,
vars=self.vars.copy(),
address=self.address,
has_hostkey=self.has_hostkey,
uuid=self._uuid,
gathered_facts=self._gathered_facts,
groups=groups,
@ -65,10 +66,11 @@ class Host:
def deserialize(self, data):
self.__init__()
self.name = data.get('name')
self.vars = data.get('vars', dict())
self.address = data.get('address', '')
self._uuid = data.get('uuid', uuid.uuid4())
self.name = data.get('name')
self.vars = data.get('vars', dict())
self.address = data.get('address', '')
self.has_hostkey = data.get('has_hostkey', False)
self._uuid = data.get('uuid', uuid.uuid4())
groups = data.get('groups', [])
for group_data in groups:
@ -89,6 +91,7 @@ class Host:
self._gathered_facts = False
self._uuid = uuid.uuid4()
self.has_hostkey = False
def __repr__(self):
return self.get_name()

View file

@ -23,11 +23,11 @@ __metaclass__ = type
import fcntl
import gettext
import os
from abc import ABCMeta, abstractmethod, abstractproperty
from functools import wraps
from ansible.compat.six import with_metaclass
from ansible.compat.six import with_metaclass
from ansible import constants as C
from ansible.errors import AnsibleError
from ansible.plugins import shell_loader
@ -233,3 +233,4 @@ class ConnectionBase(with_metaclass(ABCMeta, object)):
f = self._play_context.connection_lockfd
fcntl.lockf(f, fcntl.LOCK_UN)
display.vvvv('CONNECTION: pid %d released lock on %d' % (os.getpid(), f))

View file

@ -19,7 +19,12 @@
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
from ansible.compat.six import text_type
import base64
import fcntl
import hmac
import operator
import os
import pipes
import pty
@ -28,9 +33,13 @@ import shlex
import subprocess
import time
from hashlib import md5, sha1, sha256
from ansible import constants as C
from ansible.errors import AnsibleError, AnsibleConnectionFailure, AnsibleFileNotFound
from ansible.plugins.connection import ConnectionBase
from ansible.utils.boolean import boolean
from ansible.utils.connection import get_smart_connection_type
from ansible.utils.path import unfrackpath, makedirs_safe
from ansible.utils.unicode import to_bytes, to_unicode
@ -41,7 +50,128 @@ except ImportError:
display = Display()
SSHPASS_AVAILABLE = None
HASHED_KEY_MAGIC = "|1|"
def split_args(argstring):
"""
Takes a string like '-o Foo=1 -o Bar="foo bar"' and returns a
list ['-o', 'Foo=1', '-o', 'Bar=foo bar'] that can be added to
the argument list. The list will not contain any empty elements.
"""
return [to_unicode(x.strip()) for x in shlex.split(to_bytes(argstring)) if x.strip()]
def get_ssh_opts(play_context):
# FIXME: caching may help here
opts_dict = dict()
try:
cmd = ['ssh', '-G', play_context.remote_addr]
res = subprocess.check_output(cmd)
for line in res.split('\n'):
if ' ' in line:
(key, val) = line.split(' ', 1)
else:
key = line
val = ''
opts_dict[key.lower()] = val
# next, we manually override any options that are being
# set via ssh_args or due to the fact that `ssh -G` doesn't
# actually use the options set via -o
for opt in ['ssh_args', 'ssh_common_args', 'ssh_extra_args']:
attr = getattr(play_context, opt, None)
if attr is not None:
args = split_args(attr)
for arg in args:
if '=' in arg:
(key, val) = arg.split('=', 1)
opts_dict[key.lower()] = val
return opts_dict
except subprocess.CalledProcessError:
return dict()
def host_in_known_hosts(host, ssh_opts):
# the setting from the ssh_opts may actually be multiple files, so
# we use shlex.split and simply take the first one specified
user_host_file = os.path.expanduser(shlex.split(ssh_opts.get('userknownhostsfile', '~/.ssh/known_hosts'))[0])
host_file_list = []
host_file_list.append(user_host_file)
host_file_list.append("/etc/ssh/ssh_known_hosts")
host_file_list.append("/etc/ssh/ssh_known_hosts2")
hfiles_not_found = 0
for hf in host_file_list:
if not os.path.exists(hf):
continue
try:
host_fh = open(hf)
except (OSError, IOError) as e:
continue
else:
data = host_fh.read()
host_fh.close()
for line in data.split("\n"):
line = line.strip()
if line is None or " " not in line:
continue
tokens = line.split()
if not tokens:
continue
if tokens[0].find(HASHED_KEY_MAGIC) == 0:
# this is a hashed known host entry
try:
(kn_salt, kn_host) = tokens[0][len(HASHED_KEY_MAGIC):].split("|",2)
hash = hmac.new(kn_salt.decode('base64'), digestmod=sha1)
hash.update(host)
if hash.digest() == kn_host.decode('base64'):
return True
except:
# invalid hashed host key, skip it
continue
else:
# standard host file entry
if host in tokens[0]:
return True
return False
def fetch_ssh_host_key(play_context, ssh_opts):
keyscan_cmd = ['ssh-keyscan']
if play_context.port:
keyscan_cmd.extend(['-p', text_type(play_context.port)])
if boolean(ssh_opts.get('hashknownhosts', 'no')):
keyscan_cmd.append('-H')
keyscan_cmd.append(play_context.remote_addr)
p = subprocess.Popen(keyscan_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, close_fds=True)
(stdout, stderr) = p.communicate()
if stdout == '':
raise AnsibleConnectionFailure("Failed to connect to the host to fetch the host key: %s." % stderr)
else:
return stdout
def add_host_key(host_key, ssh_opts):
# the setting from the ssh_opts may actually be multiple files, so
# we use shlex.split and simply take the first one specified
user_known_hosts = os.path.expanduser(shlex.split(ssh_opts.get('userknownhostsfile', '~/.ssh/known_hosts'))[0])
user_ssh_dir = os.path.dirname(user_known_hosts)
if not os.path.exists(user_ssh_dir):
raise AnsibleError("the user ssh directory does not exist: %s" % user_ssh_dir)
elif not os.path.isdir(user_ssh_dir):
raise AnsibleError("%s is not a directory" % user_ssh_dir)
try:
display.vv("adding to known_hosts file: %s" % user_known_hosts)
with open(user_known_hosts, 'a') as f:
f.write(host_key)
except (OSError, IOError) as e:
raise AnsibleError("error when trying to access the known hosts file: '%s', error was: %s" % (user_known_hosts, text_type(e)))
class Connection(ConnectionBase):
''' ssh based connections '''
@ -62,6 +192,56 @@ class Connection(ConnectionBase):
def _connect(self):
return self
@staticmethod
def fetch_and_store_key(host, play_context):
ssh_opts = get_ssh_opts(play_context)
if not host_in_known_hosts(play_context.remote_addr, ssh_opts):
display.debug("host %s does not have a known host key, fetching it" % host)
# build the list of valid host key types, for use later as we scan for keys.
# we also use this to determine the most preferred key when multiple keys are available
valid_host_key_types = [x.lower() for x in ssh_opts.get('hostbasedkeytypes', '').split(',')]
# attempt to fetch the key with ssh-keyscan. More than one key may be
# returned, so we save all and use the above list to determine which
host_key_data = fetch_ssh_host_key(play_context, ssh_opts).strip().split('\n')
host_keys = dict()
for host_key in host_key_data:
(host_info, key_type, key_hash) = host_key.strip().split(' ', 3)
key_type = key_type.lower()
if key_type in valid_host_key_types and key_type not in host_keys:
host_keys[key_type.lower()] = host_key
if len(host_keys) == 0:
raise AnsibleConnectionFailure("none of the available host keys found were in the HostBasedKeyTypes configuration option")
# now we determine the preferred key by sorting the above dict on the
# index of the key type in the valid keys list
preferred_key = sorted(host_keys.items(), cmp=lambda x,y: cmp(valid_host_key_types.index(x), valid_host_key_types.index(y)), key=operator.itemgetter(0))[0]
# shamelessly copied from here:
# https://github.com/ojarva/python-sshpubkeys/blob/master/sshpubkeys/__init__.py#L39
# (which shamelessly copied it from somewhere else...)
(host_info, key_type, key_hash) = preferred_key[1].strip().split(' ', 3)
decoded_key = key_hash.decode('base64')
fp_plain = md5(decoded_key).hexdigest()
key_data = ':'.join(a+b for a, b in zip(fp_plain[::2], fp_plain[1::2]))
# prompt the user to add the key
# if yes, add it, otherwise raise AnsibleConnectionFailure
display.display("\nThe authenticity of host %s (%s) can't be established." % (host.name, play_context.remote_addr))
display.display("%s key fingerprint is SHA256:%s." % (key_type.upper(), sha256(decoded_key).digest().encode('base64').strip()))
display.display("%s key fingerprint is MD5:%s." % (key_type.upper(), key_data))
response = display.prompt("Are you sure you want to continue connecting (yes/no)? ")
display.display("")
if boolean(response):
add_host_key(host_key, ssh_opts)
return True
else:
raise AnsibleConnectionFailure("Host key validation failed.")
return False
@staticmethod
def _sshpass_available():
global SSHPASS_AVAILABLE
@ -100,15 +280,6 @@ class Connection(ConnectionBase):
return controlpersist, controlpath
@staticmethod
def _split_args(argstring):
"""
Takes a string like '-o Foo=1 -o Bar="foo bar"' and returns a
list ['-o', 'Foo=1', '-o', 'Bar=foo bar'] that can be added to
the argument list. The list will not contain any empty elements.
"""
return [to_unicode(x.strip()) for x in shlex.split(to_bytes(argstring)) if x.strip()]
def _add_args(self, explanation, args):
"""
Adds the given args to self._command and displays a caller-supplied
@ -157,7 +328,7 @@ class Connection(ConnectionBase):
# Next, we add [ssh_connection]ssh_args from ansible.cfg.
if self._play_context.ssh_args:
args = self._split_args(self._play_context.ssh_args)
args = split_args(self._play_context.ssh_args)
self._add_args("ansible.cfg set ssh_args", args)
# Now we add various arguments controlled by configuration file settings
@ -210,7 +381,7 @@ class Connection(ConnectionBase):
for opt in ['ssh_common_args', binary + '_extra_args']:
attr = getattr(self._play_context, opt, None)
if attr is not None:
args = self._split_args(attr)
args = split_args(attr)
self._add_args("PlayContext set %s" % opt, args)
# Check if ControlPersist is enabled and add a ControlPath if one hasn't

View file

@ -29,7 +29,7 @@ import zlib
from jinja2.exceptions import UndefinedError
from ansible import constants as C
from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVariable
from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVariable, AnsibleConnectionFailure
from ansible.executor.play_iterator import PlayIterator
from ansible.executor.process.worker import WorkerProcess
from ansible.executor.task_result import TaskResult
@ -39,6 +39,7 @@ from ansible.playbook.helpers import load_list_of_blocks
from ansible.playbook.included_file import IncludedFile
from ansible.plugins import action_loader, connection_loader, filter_loader, lookup_loader, module_loader, test_loader
from ansible.template import Templar
from ansible.utils.connection import get_smart_connection_type
from ansible.vars.unsafe_proxy import wrap_var
try:
@ -139,6 +140,33 @@ class StrategyBase:
display.debug("entering _queue_task() for %s/%s" % (host, task))
if C.HOST_KEY_CHECKING and not host.has_hostkey:
# caveat here, regarding with loops. It is assumed that none of the connection
# related variables would contain '{{item}}' as it would cause some really
# weird loops. As is, if someone did something odd like that they would need
# to disable host key checking
templar = Templar(loader=self._loader, variables=task_vars)
temp_pc = play_context.set_task_and_variable_override(task=task, variables=task_vars, templar=templar)
temp_pc.post_validate(templar)
if temp_pc.connection in ('smart', 'ssh') and get_smart_connection_type(temp_pc) == 'ssh':
try:
# get the ssh connection plugin's class, and use its builtin
# static method to fetch and save the key to the known_hosts file
ssh_conn = connection_loader.get('ssh', class_only=True)
ssh_conn.fetch_and_store_key(host, temp_pc)
except AnsibleConnectionFailure as e:
# if that fails, add the host to the list of unreachable
# hosts and send the appropriate callback
self._tqm._unreachable_hosts[host.name] = True
self._tqm._stats.increment('dark', host.name)
tr = TaskResult(host=host, task=task, return_data=dict(msg=text_type(e)))
self._tqm.send_callback('v2_runner_on_unreachable', tr)
return
# finally, we set the has_hostkey flag to true for this
# host so we can skip it quickly in the future
host.has_hostkey = True
task_vars['hostvars'] = self._tqm.hostvars
# and then queue the new task
display.debug("%s - putting task (%s) in queue" % (host, task))

View file

@ -0,0 +1,50 @@
# (c) 2015, Ansible, Inc. <support@ansible.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
import subprocess
import sys
__all__ = ['get_smart_connection_type']
def get_smart_connection_type(play_context):
'''
Uses the ssh command with the ControlPersist option while checking
for an error to determine if we should use ssh or paramiko. Also
may take other factors into account.
'''
conn_type = 'ssh'
if sys.platform.startswith('darwin') and play_context.password:
# due to a current bug in sshpass on OSX, which can trigger
# a kernel panic even for non-privileged users, we revert to
# paramiko on that OS when a SSH password is specified
conn_type = "paramiko"
else:
# see if SSH can support ControlPersist if not use paramiko
try:
cmd = subprocess.Popen(['ssh','-o','ControlPersist'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
(out, err) = cmd.communicate()
if "Bad configuration option" in err or "Usage:" in err:
conn_type = "paramiko"
except OSError:
conn_type = "paramiko"
return conn_type