Enable host_key checking at the strategy level
Implements a new method in the ssh connection plugin (fetch_and_store_key) which is used to prefetch the key using ssh-keyscan.
This commit is contained in:
parent
15135f3c16
commit
e5c2c03dea
6 changed files with 273 additions and 33 deletions
|
@ -32,6 +32,7 @@ from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVar
|
|||
from ansible.playbook.conditional import Conditional
|
||||
from ansible.playbook.task import Task
|
||||
from ansible.template import Templar
|
||||
from ansible.utils.connection import get_smart_connection_type
|
||||
from ansible.utils.encrypt import key_for_hostname
|
||||
from ansible.utils.listify import listify_lookup_plugin_terms
|
||||
from ansible.utils.unicode import to_unicode
|
||||
|
@ -564,21 +565,7 @@ class TaskExecutor:
|
|||
|
||||
conn_type = self._play_context.connection
|
||||
if conn_type == 'smart':
|
||||
conn_type = 'ssh'
|
||||
if sys.platform.startswith('darwin') and self._play_context.password:
|
||||
# due to a current bug in sshpass on OSX, which can trigger
|
||||
# a kernel panic even for non-privileged users, we revert to
|
||||
# paramiko on that OS when a SSH password is specified
|
||||
conn_type = "paramiko"
|
||||
else:
|
||||
# see if SSH can support ControlPersist if not use paramiko
|
||||
try:
|
||||
cmd = subprocess.Popen(['ssh','-o','ControlPersist'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
(out, err) = cmd.communicate()
|
||||
if "Bad configuration option" in err or "Usage:" in err:
|
||||
conn_type = "paramiko"
|
||||
except OSError:
|
||||
conn_type = "paramiko"
|
||||
conn_type = get_smart_connection_type(self._play_context)
|
||||
|
||||
connection = self._shared_loader_obj.connection_loader.get(conn_type, self._play_context, self._new_stdin)
|
||||
if not connection:
|
||||
|
|
|
@ -57,6 +57,7 @@ class Host:
|
|||
name=self.name,
|
||||
vars=self.vars.copy(),
|
||||
address=self.address,
|
||||
has_hostkey=self.has_hostkey,
|
||||
uuid=self._uuid,
|
||||
gathered_facts=self._gathered_facts,
|
||||
groups=groups,
|
||||
|
@ -65,10 +66,11 @@ class Host:
|
|||
def deserialize(self, data):
|
||||
self.__init__()
|
||||
|
||||
self.name = data.get('name')
|
||||
self.vars = data.get('vars', dict())
|
||||
self.address = data.get('address', '')
|
||||
self._uuid = data.get('uuid', uuid.uuid4())
|
||||
self.name = data.get('name')
|
||||
self.vars = data.get('vars', dict())
|
||||
self.address = data.get('address', '')
|
||||
self.has_hostkey = data.get('has_hostkey', False)
|
||||
self._uuid = data.get('uuid', uuid.uuid4())
|
||||
|
||||
groups = data.get('groups', [])
|
||||
for group_data in groups:
|
||||
|
@ -89,6 +91,7 @@ class Host:
|
|||
|
||||
self._gathered_facts = False
|
||||
self._uuid = uuid.uuid4()
|
||||
self.has_hostkey = False
|
||||
|
||||
def __repr__(self):
|
||||
return self.get_name()
|
||||
|
|
|
@ -23,11 +23,11 @@ __metaclass__ = type
|
|||
import fcntl
|
||||
import gettext
|
||||
import os
|
||||
|
||||
from abc import ABCMeta, abstractmethod, abstractproperty
|
||||
|
||||
from functools import wraps
|
||||
from ansible.compat.six import with_metaclass
|
||||
|
||||
from ansible.compat.six import with_metaclass
|
||||
from ansible import constants as C
|
||||
from ansible.errors import AnsibleError
|
||||
from ansible.plugins import shell_loader
|
||||
|
@ -233,3 +233,4 @@ class ConnectionBase(with_metaclass(ABCMeta, object)):
|
|||
f = self._play_context.connection_lockfd
|
||||
fcntl.lockf(f, fcntl.LOCK_UN)
|
||||
display.vvvv('CONNECTION: pid %d released lock on %d' % (os.getpid(), f))
|
||||
|
||||
|
|
|
@ -19,7 +19,12 @@
|
|||
from __future__ import (absolute_import, division, print_function)
|
||||
__metaclass__ = type
|
||||
|
||||
from ansible.compat.six import text_type
|
||||
|
||||
import base64
|
||||
import fcntl
|
||||
import hmac
|
||||
import operator
|
||||
import os
|
||||
import pipes
|
||||
import pty
|
||||
|
@ -28,9 +33,13 @@ import shlex
|
|||
import subprocess
|
||||
import time
|
||||
|
||||
from hashlib import md5, sha1, sha256
|
||||
|
||||
from ansible import constants as C
|
||||
from ansible.errors import AnsibleError, AnsibleConnectionFailure, AnsibleFileNotFound
|
||||
from ansible.plugins.connection import ConnectionBase
|
||||
from ansible.utils.boolean import boolean
|
||||
from ansible.utils.connection import get_smart_connection_type
|
||||
from ansible.utils.path import unfrackpath, makedirs_safe
|
||||
from ansible.utils.unicode import to_bytes, to_unicode
|
||||
|
||||
|
@ -41,7 +50,128 @@ except ImportError:
|
|||
display = Display()
|
||||
|
||||
SSHPASS_AVAILABLE = None
|
||||
HASHED_KEY_MAGIC = "|1|"
|
||||
|
||||
def split_args(argstring):
|
||||
"""
|
||||
Takes a string like '-o Foo=1 -o Bar="foo bar"' and returns a
|
||||
list ['-o', 'Foo=1', '-o', 'Bar=foo bar'] that can be added to
|
||||
the argument list. The list will not contain any empty elements.
|
||||
"""
|
||||
return [to_unicode(x.strip()) for x in shlex.split(to_bytes(argstring)) if x.strip()]
|
||||
|
||||
def get_ssh_opts(play_context):
|
||||
# FIXME: caching may help here
|
||||
opts_dict = dict()
|
||||
try:
|
||||
cmd = ['ssh', '-G', play_context.remote_addr]
|
||||
res = subprocess.check_output(cmd)
|
||||
for line in res.split('\n'):
|
||||
if ' ' in line:
|
||||
(key, val) = line.split(' ', 1)
|
||||
else:
|
||||
key = line
|
||||
val = ''
|
||||
opts_dict[key.lower()] = val
|
||||
|
||||
# next, we manually override any options that are being
|
||||
# set via ssh_args or due to the fact that `ssh -G` doesn't
|
||||
# actually use the options set via -o
|
||||
for opt in ['ssh_args', 'ssh_common_args', 'ssh_extra_args']:
|
||||
attr = getattr(play_context, opt, None)
|
||||
if attr is not None:
|
||||
args = split_args(attr)
|
||||
for arg in args:
|
||||
if '=' in arg:
|
||||
(key, val) = arg.split('=', 1)
|
||||
opts_dict[key.lower()] = val
|
||||
|
||||
return opts_dict
|
||||
except subprocess.CalledProcessError:
|
||||
return dict()
|
||||
|
||||
def host_in_known_hosts(host, ssh_opts):
|
||||
# the setting from the ssh_opts may actually be multiple files, so
|
||||
# we use shlex.split and simply take the first one specified
|
||||
user_host_file = os.path.expanduser(shlex.split(ssh_opts.get('userknownhostsfile', '~/.ssh/known_hosts'))[0])
|
||||
|
||||
host_file_list = []
|
||||
host_file_list.append(user_host_file)
|
||||
host_file_list.append("/etc/ssh/ssh_known_hosts")
|
||||
host_file_list.append("/etc/ssh/ssh_known_hosts2")
|
||||
|
||||
hfiles_not_found = 0
|
||||
for hf in host_file_list:
|
||||
if not os.path.exists(hf):
|
||||
continue
|
||||
try:
|
||||
host_fh = open(hf)
|
||||
except (OSError, IOError) as e:
|
||||
continue
|
||||
else:
|
||||
data = host_fh.read()
|
||||
host_fh.close()
|
||||
|
||||
for line in data.split("\n"):
|
||||
line = line.strip()
|
||||
if line is None or " " not in line:
|
||||
continue
|
||||
tokens = line.split()
|
||||
if not tokens:
|
||||
continue
|
||||
if tokens[0].find(HASHED_KEY_MAGIC) == 0:
|
||||
# this is a hashed known host entry
|
||||
try:
|
||||
(kn_salt, kn_host) = tokens[0][len(HASHED_KEY_MAGIC):].split("|",2)
|
||||
hash = hmac.new(kn_salt.decode('base64'), digestmod=sha1)
|
||||
hash.update(host)
|
||||
if hash.digest() == kn_host.decode('base64'):
|
||||
return True
|
||||
except:
|
||||
# invalid hashed host key, skip it
|
||||
continue
|
||||
else:
|
||||
# standard host file entry
|
||||
if host in tokens[0]:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def fetch_ssh_host_key(play_context, ssh_opts):
|
||||
keyscan_cmd = ['ssh-keyscan']
|
||||
|
||||
if play_context.port:
|
||||
keyscan_cmd.extend(['-p', text_type(play_context.port)])
|
||||
|
||||
if boolean(ssh_opts.get('hashknownhosts', 'no')):
|
||||
keyscan_cmd.append('-H')
|
||||
|
||||
keyscan_cmd.append(play_context.remote_addr)
|
||||
|
||||
p = subprocess.Popen(keyscan_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, close_fds=True)
|
||||
(stdout, stderr) = p.communicate()
|
||||
if stdout == '':
|
||||
raise AnsibleConnectionFailure("Failed to connect to the host to fetch the host key: %s." % stderr)
|
||||
else:
|
||||
return stdout
|
||||
|
||||
def add_host_key(host_key, ssh_opts):
|
||||
# the setting from the ssh_opts may actually be multiple files, so
|
||||
# we use shlex.split and simply take the first one specified
|
||||
user_known_hosts = os.path.expanduser(shlex.split(ssh_opts.get('userknownhostsfile', '~/.ssh/known_hosts'))[0])
|
||||
user_ssh_dir = os.path.dirname(user_known_hosts)
|
||||
|
||||
if not os.path.exists(user_ssh_dir):
|
||||
raise AnsibleError("the user ssh directory does not exist: %s" % user_ssh_dir)
|
||||
elif not os.path.isdir(user_ssh_dir):
|
||||
raise AnsibleError("%s is not a directory" % user_ssh_dir)
|
||||
|
||||
try:
|
||||
display.vv("adding to known_hosts file: %s" % user_known_hosts)
|
||||
with open(user_known_hosts, 'a') as f:
|
||||
f.write(host_key)
|
||||
except (OSError, IOError) as e:
|
||||
raise AnsibleError("error when trying to access the known hosts file: '%s', error was: %s" % (user_known_hosts, text_type(e)))
|
||||
|
||||
class Connection(ConnectionBase):
|
||||
''' ssh based connections '''
|
||||
|
@ -62,6 +192,56 @@ class Connection(ConnectionBase):
|
|||
def _connect(self):
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def fetch_and_store_key(host, play_context):
|
||||
ssh_opts = get_ssh_opts(play_context)
|
||||
if not host_in_known_hosts(play_context.remote_addr, ssh_opts):
|
||||
display.debug("host %s does not have a known host key, fetching it" % host)
|
||||
|
||||
# build the list of valid host key types, for use later as we scan for keys.
|
||||
# we also use this to determine the most preferred key when multiple keys are available
|
||||
valid_host_key_types = [x.lower() for x in ssh_opts.get('hostbasedkeytypes', '').split(',')]
|
||||
|
||||
# attempt to fetch the key with ssh-keyscan. More than one key may be
|
||||
# returned, so we save all and use the above list to determine which
|
||||
host_key_data = fetch_ssh_host_key(play_context, ssh_opts).strip().split('\n')
|
||||
host_keys = dict()
|
||||
for host_key in host_key_data:
|
||||
(host_info, key_type, key_hash) = host_key.strip().split(' ', 3)
|
||||
key_type = key_type.lower()
|
||||
if key_type in valid_host_key_types and key_type not in host_keys:
|
||||
host_keys[key_type.lower()] = host_key
|
||||
|
||||
if len(host_keys) == 0:
|
||||
raise AnsibleConnectionFailure("none of the available host keys found were in the HostBasedKeyTypes configuration option")
|
||||
|
||||
# now we determine the preferred key by sorting the above dict on the
|
||||
# index of the key type in the valid keys list
|
||||
preferred_key = sorted(host_keys.items(), cmp=lambda x,y: cmp(valid_host_key_types.index(x), valid_host_key_types.index(y)), key=operator.itemgetter(0))[0]
|
||||
|
||||
# shamelessly copied from here:
|
||||
# https://github.com/ojarva/python-sshpubkeys/blob/master/sshpubkeys/__init__.py#L39
|
||||
# (which shamelessly copied it from somewhere else...)
|
||||
(host_info, key_type, key_hash) = preferred_key[1].strip().split(' ', 3)
|
||||
decoded_key = key_hash.decode('base64')
|
||||
fp_plain = md5(decoded_key).hexdigest()
|
||||
key_data = ':'.join(a+b for a, b in zip(fp_plain[::2], fp_plain[1::2]))
|
||||
|
||||
# prompt the user to add the key
|
||||
# if yes, add it, otherwise raise AnsibleConnectionFailure
|
||||
display.display("\nThe authenticity of host %s (%s) can't be established." % (host.name, play_context.remote_addr))
|
||||
display.display("%s key fingerprint is SHA256:%s." % (key_type.upper(), sha256(decoded_key).digest().encode('base64').strip()))
|
||||
display.display("%s key fingerprint is MD5:%s." % (key_type.upper(), key_data))
|
||||
response = display.prompt("Are you sure you want to continue connecting (yes/no)? ")
|
||||
display.display("")
|
||||
if boolean(response):
|
||||
add_host_key(host_key, ssh_opts)
|
||||
return True
|
||||
else:
|
||||
raise AnsibleConnectionFailure("Host key validation failed.")
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _sshpass_available():
|
||||
global SSHPASS_AVAILABLE
|
||||
|
@ -100,15 +280,6 @@ class Connection(ConnectionBase):
|
|||
|
||||
return controlpersist, controlpath
|
||||
|
||||
@staticmethod
|
||||
def _split_args(argstring):
|
||||
"""
|
||||
Takes a string like '-o Foo=1 -o Bar="foo bar"' and returns a
|
||||
list ['-o', 'Foo=1', '-o', 'Bar=foo bar'] that can be added to
|
||||
the argument list. The list will not contain any empty elements.
|
||||
"""
|
||||
return [to_unicode(x.strip()) for x in shlex.split(to_bytes(argstring)) if x.strip()]
|
||||
|
||||
def _add_args(self, explanation, args):
|
||||
"""
|
||||
Adds the given args to self._command and displays a caller-supplied
|
||||
|
@ -157,7 +328,7 @@ class Connection(ConnectionBase):
|
|||
# Next, we add [ssh_connection]ssh_args from ansible.cfg.
|
||||
|
||||
if self._play_context.ssh_args:
|
||||
args = self._split_args(self._play_context.ssh_args)
|
||||
args = split_args(self._play_context.ssh_args)
|
||||
self._add_args("ansible.cfg set ssh_args", args)
|
||||
|
||||
# Now we add various arguments controlled by configuration file settings
|
||||
|
@ -210,7 +381,7 @@ class Connection(ConnectionBase):
|
|||
for opt in ['ssh_common_args', binary + '_extra_args']:
|
||||
attr = getattr(self._play_context, opt, None)
|
||||
if attr is not None:
|
||||
args = self._split_args(attr)
|
||||
args = split_args(attr)
|
||||
self._add_args("PlayContext set %s" % opt, args)
|
||||
|
||||
# Check if ControlPersist is enabled and add a ControlPath if one hasn't
|
||||
|
|
|
@ -29,7 +29,7 @@ import zlib
|
|||
from jinja2.exceptions import UndefinedError
|
||||
|
||||
from ansible import constants as C
|
||||
from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVariable
|
||||
from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVariable, AnsibleConnectionFailure
|
||||
from ansible.executor.play_iterator import PlayIterator
|
||||
from ansible.executor.process.worker import WorkerProcess
|
||||
from ansible.executor.task_result import TaskResult
|
||||
|
@ -39,6 +39,7 @@ from ansible.playbook.helpers import load_list_of_blocks
|
|||
from ansible.playbook.included_file import IncludedFile
|
||||
from ansible.plugins import action_loader, connection_loader, filter_loader, lookup_loader, module_loader, test_loader
|
||||
from ansible.template import Templar
|
||||
from ansible.utils.connection import get_smart_connection_type
|
||||
from ansible.vars.unsafe_proxy import wrap_var
|
||||
|
||||
try:
|
||||
|
@ -139,6 +140,33 @@ class StrategyBase:
|
|||
|
||||
display.debug("entering _queue_task() for %s/%s" % (host, task))
|
||||
|
||||
if C.HOST_KEY_CHECKING and not host.has_hostkey:
|
||||
# caveat here, regarding with loops. It is assumed that none of the connection
|
||||
# related variables would contain '{{item}}' as it would cause some really
|
||||
# weird loops. As is, if someone did something odd like that they would need
|
||||
# to disable host key checking
|
||||
templar = Templar(loader=self._loader, variables=task_vars)
|
||||
temp_pc = play_context.set_task_and_variable_override(task=task, variables=task_vars, templar=templar)
|
||||
temp_pc.post_validate(templar)
|
||||
if temp_pc.connection in ('smart', 'ssh') and get_smart_connection_type(temp_pc) == 'ssh':
|
||||
try:
|
||||
# get the ssh connection plugin's class, and use its builtin
|
||||
# static method to fetch and save the key to the known_hosts file
|
||||
ssh_conn = connection_loader.get('ssh', class_only=True)
|
||||
ssh_conn.fetch_and_store_key(host, temp_pc)
|
||||
except AnsibleConnectionFailure as e:
|
||||
# if that fails, add the host to the list of unreachable
|
||||
# hosts and send the appropriate callback
|
||||
self._tqm._unreachable_hosts[host.name] = True
|
||||
self._tqm._stats.increment('dark', host.name)
|
||||
tr = TaskResult(host=host, task=task, return_data=dict(msg=text_type(e)))
|
||||
self._tqm.send_callback('v2_runner_on_unreachable', tr)
|
||||
return
|
||||
|
||||
# finally, we set the has_hostkey flag to true for this
|
||||
# host so we can skip it quickly in the future
|
||||
host.has_hostkey = True
|
||||
|
||||
task_vars['hostvars'] = self._tqm.hostvars
|
||||
# and then queue the new task
|
||||
display.debug("%s - putting task (%s) in queue" % (host, task))
|
||||
|
|
50
lib/ansible/utils/connection.py
Normal file
50
lib/ansible/utils/connection.py
Normal file
|
@ -0,0 +1,50 @@
|
|||
# (c) 2015, Ansible, Inc. <support@ansible.com>
|
||||
#
|
||||
# This file is part of Ansible
|
||||
#
|
||||
# Ansible is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# Ansible is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
from __future__ import (absolute_import, division, print_function)
|
||||
__metaclass__ = type
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
||||
__all__ = ['get_smart_connection_type']
|
||||
|
||||
def get_smart_connection_type(play_context):
|
||||
'''
|
||||
Uses the ssh command with the ControlPersist option while checking
|
||||
for an error to determine if we should use ssh or paramiko. Also
|
||||
may take other factors into account.
|
||||
'''
|
||||
|
||||
conn_type = 'ssh'
|
||||
if sys.platform.startswith('darwin') and play_context.password:
|
||||
# due to a current bug in sshpass on OSX, which can trigger
|
||||
# a kernel panic even for non-privileged users, we revert to
|
||||
# paramiko on that OS when a SSH password is specified
|
||||
conn_type = "paramiko"
|
||||
else:
|
||||
# see if SSH can support ControlPersist if not use paramiko
|
||||
try:
|
||||
cmd = subprocess.Popen(['ssh','-o','ControlPersist'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
(out, err) = cmd.communicate()
|
||||
if "Bad configuration option" in err or "Usage:" in err:
|
||||
conn_type = "paramiko"
|
||||
except OSError:
|
||||
conn_type = "paramiko"
|
||||
|
||||
return conn_type
|
Loading…
Reference in a new issue