Make authorized_key preserve key order (#5339)

* Make authorized_key preserve key order

Track the ordering of keys in the original file (rank)
and try to preserve it when writing out updates.

Fixes #4780
This commit is contained in:
Adrian Likins 2016-10-21 12:28:28 -04:00 committed by Matt Clay
parent e7fcfa981e
commit 29978344ea

View file

@ -149,6 +149,7 @@ import os.path
import tempfile import tempfile
import re import re
import shlex import shlex
from operator import itemgetter
from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.pycompat24 import get_exception from ansible.module_utils.pycompat24 import get_exception
@ -303,10 +304,13 @@ def parseoptions(module, options):
return options_dict return options_dict
def parsekey(module, raw_key): def parsekey(module, raw_key, rank=None):
''' '''
parses a key, which may or may not contain a list parses a key, which may or may not contain a list
of ssh-key options at the beginning of ssh-key options at the beginning
rank indicates the keys original ordering, so that
it can be written out in the same order.
''' '''
VALID_SSH2_KEY_TYPES = [ VALID_SSH2_KEY_TYPES = [
@ -333,6 +337,10 @@ def parsekey(module, raw_key):
lex.whitespace_split = True lex.whitespace_split = True
key_parts = list(lex) key_parts = list(lex)
if key_parts and key_parts[0] == '#':
# comment line, invalid line, etc.
return (raw_key, 'skipped', None, None, rank)
for i in range(0, len(key_parts)): for i in range(0, len(key_parts)):
if key_parts[i] in VALID_SSH2_KEY_TYPES: if key_parts[i] in VALID_SSH2_KEY_TYPES:
type_index = i type_index = i
@ -355,7 +363,7 @@ def parsekey(module, raw_key):
if len(key_parts) > (type_index + 1): if len(key_parts) > (type_index + 1):
comment = " ".join(key_parts[(type_index + 2):]) comment = " ".join(key_parts[(type_index + 2):])
return (key, key_type, options, comment) return (key, key_type, options, comment, rank)
def readkeys(module, filename): def readkeys(module, filename):
@ -364,15 +372,15 @@ def readkeys(module, filename):
keys = {} keys = {}
f = open(filename) f = open(filename)
for line in f.readlines(): for rank_index, line in enumerate(f.readlines()):
key_data = parsekey(module, line) key_data = parsekey(module, line, rank=rank_index)
if key_data: if key_data:
# use key as identifier # use key as identifier
keys[key_data[0]] = key_data keys[key_data[0]] = key_data
else: else:
# for an invalid line, just append the line # for an invalid line, just set the line
# to the array so it will be re-output later # dict key to the line so it will be re-output later
keys[line] = line keys[line] = (line, 'skipped', None, None, rank_index)
f.close() f.close()
return keys return keys
@ -380,10 +388,17 @@ def writekeys(module, filename, keys):
fd, tmp_path = tempfile.mkstemp('', 'tmp', os.path.dirname(filename)) fd, tmp_path = tempfile.mkstemp('', 'tmp', os.path.dirname(filename))
f = open(tmp_path,"w") f = open(tmp_path,"w")
# FIXME: only the f.writelines() needs to be in try clause
try: try:
for index, key in keys.items(): new_keys = keys.values()
# order the new_keys by their original ordering, via the rank item in the tuple
ordered_new_keys = sorted(new_keys, key=itemgetter(4))
for key in ordered_new_keys:
try: try:
(keyhash,type,options,comment) = key (keyhash, key_type, options, comment, rank) = key
option_str = "" option_str = ""
if options: if options:
option_strings = [] option_strings = []
@ -394,7 +409,15 @@ def writekeys(module, filename, keys):
option_strings.append("%s=%s" % (option_key, value)) option_strings.append("%s=%s" % (option_key, value))
option_str = ",".join(option_strings) option_str = ",".join(option_strings)
option_str += " " option_str += " "
key_line = "%s%s %s %s\n" % (option_str, type, keyhash, comment)
# comment line or invalid line, just leave it
if not key_type:
key_line = key
if key_type == 'skipped':
key_line = key[0]
else:
key_line = "%s%s %s %s\n" % (option_str, key_type, keyhash, comment)
except: except:
key_line = key key_line = key
f.writelines(key_line) f.writelines(key_line)
@ -430,43 +453,47 @@ def enforce_state(module, params):
module.fail_json(msg=error_msg % key) module.fail_json(msg=error_msg % key)
# extract individual keys into an array, skipping blank lines and comments # extract individual keys into an array, skipping blank lines and comments
key = [s for s in key.splitlines() if s and not s.startswith('#')] new_keys = [s for s in key.splitlines() if s and not s.startswith('#')]
# check current state -- just get the filename, don't create file # check current state -- just get the filename, don't create file
do_write = False do_write = False
params["keyfile"] = keyfile(module, user, do_write, path, manage_dir) params["keyfile"] = keyfile(module, user, do_write, path, manage_dir)
existing_keys = readkeys(module, params["keyfile"]) existing_keys = readkeys(module, params["keyfile"])
# Add a place holder for keys that should exist in the state=present and # Add a place holder for keys that should exist in the state=present and
# exclusive=true case # exclusive=true case
keys_to_exist = [] keys_to_exist = []
# we will order any non exclusive new keys higher than all the existing keys,
# resulting in the new keys being written to the key file after existing keys, but
# in the order of new_keys
max_rank_of_existing_keys = len(existing_keys)
# Check our new keys, if any of them exist we'll continue. # Check our new keys, if any of them exist we'll continue.
for new_key in key: for rank_index, new_key in enumerate(new_keys):
parsed_new_key = parsekey(module, new_key) parsed_new_key = parsekey(module, new_key, rank=rank_index)
if not parsed_new_key: if not parsed_new_key:
module.fail_json(msg="invalid key specified: %s" % new_key) module.fail_json(msg="invalid key specified: %s" % new_key)
if key_options is not None: if key_options is not None:
parsed_options = parseoptions(module, key_options) parsed_options = parseoptions(module, key_options)
parsed_new_key = (parsed_new_key[0], parsed_new_key[1], parsed_options, parsed_new_key[3]) # rank here is the rank in the provided new keys, which may be unrelated to rank in existing_keys
parsed_new_key = (parsed_new_key[0], parsed_new_key[1], parsed_options, parsed_new_key[3], parsed_new_key[4])
matched = False matched = False
non_matching_keys = [] non_matching_keys = []
if parsed_new_key[0] in existing_keys: if parsed_new_key[0] in existing_keys:
# Then we check if everything matches, including # Then we check if everything (except the rank at index 4) matches, including
# the key type and options. If not, we append this # the key type and options. If not, we append this
# existing key to the non-matching list # existing key to the non-matching list
# We only want it to match everything when the state # We only want it to match everything when the state
# is present # is present
if parsed_new_key != existing_keys[parsed_new_key[0]] and state == "present": if parsed_new_key[:4] != existing_keys[parsed_new_key[0]][:4] and state == "present":
non_matching_keys.append(existing_keys[parsed_new_key[0]]) non_matching_keys.append(existing_keys[parsed_new_key[0]])
else: else:
matched = True matched = True
# handle idempotent state=present # handle idempotent state=present
if state=="present": if state=="present":
keys_to_exist.append(parsed_new_key[0]) keys_to_exist.append(parsed_new_key[0])
@ -476,8 +503,12 @@ def enforce_state(module, params):
del existing_keys[non_matching_key[0]] del existing_keys[non_matching_key[0]]
do_write = True do_write = True
# new key that didn't exist before. Where should it go in the ordering?
if not matched: if not matched:
existing_keys[parsed_new_key[0]] = parsed_new_key # We want the new key to be after existing keys if not exclusive (rank > max_rank_of_existing_keys)
total_rank = max_rank_of_existing_keys + parsed_new_key[4]
# replace existing key tuple with new parsed key with its total rank
existing_keys[parsed_new_key[0]] = (parsed_new_key[0], parsed_new_key[1], parsed_new_key[2], parsed_new_key[3], total_rank)
do_write = True do_write = True
elif state=="absent": elif state=="absent":
@ -487,6 +518,7 @@ def enforce_state(module, params):
do_write = True do_write = True
# remove all other keys to honor exclusive # remove all other keys to honor exclusive
# for 'exclusive', make sure keys are written in the order the new keys were
if state == "present" and exclusive: if state == "present" and exclusive:
to_remove = frozenset(existing_keys).difference(keys_to_exist) to_remove = frozenset(existing_keys).difference(keys_to_exist)
for key in to_remove: for key in to_remove: