diff --git a/v2/ansible/parsing/vault/__init__.py b/v2/ansible/parsing/vault/__init__.py index 92c99fdad5..ddb92e4e7d 100644 --- a/v2/ansible/parsing/vault/__init__.py +++ b/v2/ansible/parsing/vault/__init__.py @@ -22,6 +22,7 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type +import sys import os import shlex import shutil @@ -35,7 +36,10 @@ from hashlib import sha256 from hashlib import md5 from binascii import hexlify from binascii import unhexlify +from six import binary_type, byte2int, PY2, text_type from ansible import constants as C +from ansible.utils.unicode import to_unicode, to_bytes + try: from Crypto.Hash import SHA256, HMAC @@ -60,13 +64,13 @@ except ImportError: # AES IMPORTS try: from Crypto.Cipher import AES as AES - HAS_AES = True + HAS_AES = True except ImportError: - HAS_AES = False + HAS_AES = False CRYPTO_UPGRADE = "ansible-vault requires a newer version of pycrypto than the one installed on your platform. You may fix this with OS-specific commands such as: yum install python-devel; rpm -e --nodeps python-crypto; pip install pycrypto" -HEADER='$ANSIBLE_VAULT' +HEADER=u'$ANSIBLE_VAULT' CIPHER_WHITELIST=['AES', 'AES256'] class VaultLib(object): @@ -76,26 +80,28 @@ class VaultLib(object): self.cipher_name = None self.version = '1.1' - def is_encrypted(self, data): + def is_encrypted(self, data): + data = to_unicode(data) if data.startswith(HEADER): return True else: return False def encrypt(self, data): + data = to_unicode(data) if self.is_encrypted(data): raise errors.AnsibleError("data is already encrypted") if not self.cipher_name: self.cipher_name = "AES256" - #raise errors.AnsibleError("the cipher must be set before encrypting data") + # raise errors.AnsibleError("the cipher must be set before encrypting data") - if 'Vault' + self.cipher_name in globals() and self.cipher_name in CIPHER_WHITELIST: + if 'Vault' + self.cipher_name in globals() and self.cipher_name in CIPHER_WHITELIST: cipher = globals()['Vault' + self.cipher_name] this_cipher = cipher() else: - raise errors.AnsibleError("%s cipher could not be found" % self.cipher_name) + raise errors.AnsibleError("{} cipher could not be found".format(self.cipher_name)) """ # combine sha + data @@ -106,11 +112,13 @@ class VaultLib(object): # encrypt sha + data enc_data = this_cipher.encrypt(data, self.password) - # add header + # add header tmp_data = self._add_header(enc_data) return tmp_data def decrypt(self, data): + data = to_bytes(data) + if self.password is None: raise errors.AnsibleError("A vault password must be specified to decrypt data") @@ -121,48 +129,47 @@ class VaultLib(object): data = self._split_header(data) # create the cipher object - if 'Vault' + self.cipher_name in globals() and self.cipher_name in CIPHER_WHITELIST: - cipher = globals()['Vault' + self.cipher_name] + ciphername = to_unicode(self.cipher_name) + if 'Vault' + ciphername in globals() and ciphername in CIPHER_WHITELIST: + cipher = globals()['Vault' + ciphername] this_cipher = cipher() else: - raise errors.AnsibleError("%s cipher could not be found" % self.cipher_name) + raise errors.AnsibleError("{} cipher could not be found".format(ciphername)) # try to unencrypt data data = this_cipher.decrypt(data, self.password) if data is None: raise errors.AnsibleError("Decryption failed") - return data + return data - def _add_header(self, data): + def _add_header(self, data): # combine header and encrypted data in 80 char columns #tmpdata = hexlify(data) - tmpdata = [data[i:i+80] for i in range(0, len(data), 80)] - + tmpdata = [to_bytes(data[i:i+80]) for i in range(0, len(data), 80)] if not self.cipher_name: raise errors.AnsibleError("the cipher must be set before adding a header") - dirty_data = HEADER + ";" + str(self.version) + ";" + self.cipher_name + "\n" - + dirty_data = to_bytes(HEADER + ";" + self.version + ";" + self.cipher_name + "\n") for l in tmpdata: - dirty_data += l + '\n' + dirty_data += l + b'\n' return dirty_data - def _split_header(self, data): + def _split_header(self, data): # used by decrypt - tmpdata = data.split('\n') - tmpheader = tmpdata[0].strip().split(';') + tmpdata = data.split(b'\n') + tmpheader = tmpdata[0].strip().split(b';') - self.version = str(tmpheader[1].strip()) - self.cipher_name = str(tmpheader[2].strip()) - clean_data = '\n'.join(tmpdata[1:]) + self.version = to_unicode(tmpheader[1].strip()) + self.cipher_name = to_unicode(tmpheader[2].strip()) + clean_data = b'\n'.join(tmpdata[1:]) """ - # strip out newline, join, unhex + # strip out newline, join, unhex clean_data = [ x.strip() for x in clean_data ] clean_data = unhexlify(''.join(clean_data)) """ @@ -176,9 +183,9 @@ class VaultLib(object): pass class VaultEditor(object): - # uses helper methods for write_file(self, filename, data) - # to write a file so that code isn't duplicated for simple - # file I/O, ditto read_file(self, filename) and launch_editor(self, filename) + # uses helper methods for write_file(self, filename, data) + # to write a file so that code isn't duplicated for simple + # file I/O, ditto read_file(self, filename) and launch_editor(self, filename) # ... "Don't Repeat Yourself", etc. def __init__(self, cipher_name, password, filename): @@ -302,7 +309,7 @@ class VaultEditor(object): if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2 or not HAS_HASH: raise errors.AnsibleError(CRYPTO_UPGRADE) - # decrypt + # decrypt tmpdata = self.read_data(self.filename) this_vault = VaultLib(self.password) dec_data = this_vault.decrypt(tmpdata) @@ -324,7 +331,7 @@ class VaultEditor(object): return tmpdata def write_data(self, data, filename): - if os.path.isfile(filename): + if os.path.isfile(filename): os.remove(filename) f = open(filename, "wb") f.write(data) @@ -369,9 +376,10 @@ class VaultAES(object): """ Create a key and an initialization vector """ - d = d_i = '' + d = d_i = b'' while len(d) < key_length + iv_length: - d_i = md5(d_i + password + salt).digest() + text = "{}{}{}".format(d_i, password, salt) + d_i = md5(to_bytes(text)).digest() d += d_i key = d[:key_length] @@ -385,45 +393,49 @@ class VaultAES(object): # combine sha + data - this_sha = sha256(data).hexdigest() + this_sha = sha256(to_bytes(data)).hexdigest() tmp_data = this_sha + "\n" + data - in_file = BytesIO(tmp_data) + in_file = BytesIO(to_bytes(tmp_data)) in_file.seek(0) out_file = BytesIO() bs = AES.block_size - # Get a block of random data. EL does not have Crypto.Random.new() + # Get a block of random data. EL does not have Crypto.Random.new() # so os.urandom is used for cross platform purposes salt = os.urandom(bs - len('Salted__')) key, iv = self.aes_derive_key_and_iv(password, salt, key_length, bs) cipher = AES.new(key, AES.MODE_CBC, iv) - out_file.write('Salted__' + salt) + full = to_bytes(b'Salted__' + salt) + out_file.write(full) + print(repr(full)) finished = False while not finished: chunk = in_file.read(1024 * bs) if len(chunk) == 0 or len(chunk) % bs != 0: padding_length = (bs - len(chunk) % bs) or bs - chunk += padding_length * chr(padding_length) + chunk += to_bytes(padding_length * chr(padding_length)) finished = True out_file.write(cipher.encrypt(chunk)) out_file.seek(0) enc_data = out_file.read() + #print(enc_data) tmp_data = hexlify(enc_data) + assert isinstance(tmp_data, binary_type) return tmp_data - + def decrypt(self, data, password, key_length=32): """ Read encrypted data from in_file and write decrypted to out_file """ # http://stackoverflow.com/a/14989032 - data = ''.join(data.split('\n')) + data = b''.join(data.split(b'\n')) data = unhexlify(data) in_file = BytesIO(data) @@ -431,29 +443,35 @@ class VaultAES(object): out_file = BytesIO() bs = AES.block_size - salt = in_file.read(bs)[len('Salted__'):] + tmpsalt = in_file.read(bs) + print(repr(tmpsalt)) + salt = tmpsalt[len('Salted__'):] key, iv = self.aes_derive_key_and_iv(password, salt, key_length, bs) cipher = AES.new(key, AES.MODE_CBC, iv) - next_chunk = '' + next_chunk = b'' finished = False while not finished: chunk, next_chunk = next_chunk, cipher.decrypt(in_file.read(1024 * bs)) if len(next_chunk) == 0: - padding_length = ord(chunk[-1]) + if PY2: + padding_length = ord(chunk[-1]) + else: + padding_length = chunk[-1] + chunk = chunk[:-padding_length] finished = True out_file.write(chunk) # reset the stream pointer to the beginning out_file.seek(0) - new_data = out_file.read() + new_data = to_unicode(out_file.read()) # split out sha and verify decryption split_data = new_data.split("\n") this_sha = split_data[0] this_data = '\n'.join(split_data[1:]) - test_sha = sha256(this_data).hexdigest() + test_sha = sha256(to_bytes(this_data)).hexdigest() if this_sha != test_sha: raise errors.AnsibleError("Decryption failed") @@ -465,7 +483,7 @@ class VaultAES(object): class VaultAES256(object): """ - Vault implementation using AES-CTR with an HMAC-SHA256 authentication code. + Vault implementation using AES-CTR with an HMAC-SHA256 authentication code. Keys are derived using PBKDF2 """ @@ -481,7 +499,7 @@ class VaultAES256(object): keylength = 32 # match the size used for counter.new to avoid extra work - ivlength = 16 + ivlength = 16 hash_function = SHA256 @@ -489,7 +507,7 @@ class VaultAES256(object): pbkdf2_prf = lambda p, s: HMAC.new(p, s, hash_function).digest() - derivedkey = PBKDF2(password, salt, dkLen=(2 * keylength) + ivlength, + derivedkey = PBKDF2(password, salt, dkLen=(2 * keylength) + ivlength, count=10000, prf=pbkdf2_prf) key1 = derivedkey[:keylength] @@ -523,28 +541,28 @@ class VaultAES256(object): cipher = AES.new(key1, AES.MODE_CTR, counter=ctr) # ENCRYPT PADDED DATA - cryptedData = cipher.encrypt(data) + cryptedData = cipher.encrypt(data) # COMBINE SALT, DIGEST AND DATA hmac = HMAC.new(key2, cryptedData, SHA256) - message = "%s\n%s\n%s" % ( hexlify(salt), hmac.hexdigest(), hexlify(cryptedData) ) + message = b''.join([hexlify(salt), b"\n", to_bytes(hmac.hexdigest()), b"\n", hexlify(cryptedData)]) message = hexlify(message) return message def decrypt(self, data, password): # SPLIT SALT, DIGEST, AND DATA - data = ''.join(data.split("\n")) + data = b''.join(data.split(b"\n")) data = unhexlify(data) - salt, cryptedHmac, cryptedData = data.split("\n", 2) + salt, cryptedHmac, cryptedData = data.split(b"\n", 2) salt = unhexlify(salt) cryptedData = unhexlify(cryptedData) key1, key2, iv = self.gen_key_initctr(password, salt) - # EXIT EARLY IF DIGEST DOESN'T MATCH + # EXIT EARLY IF DIGEST DOESN'T MATCH hmacDecrypt = HMAC.new(key2, cryptedData, SHA256) - if not self.is_equal(cryptedHmac, hmacDecrypt.hexdigest()): + if not self.is_equal(cryptedHmac, to_bytes(hmacDecrypt.hexdigest())): return None # SET THE COUNTER AND THE CIPHER @@ -555,19 +573,31 @@ class VaultAES256(object): decryptedData = cipher.decrypt(cryptedData) # UNPAD DATA - padding_length = ord(decryptedData[-1]) + try: + padding_length = ord(decryptedData[-1]) + except TypeError: + padding_length = decryptedData[-1] + decryptedData = decryptedData[:-padding_length] - return decryptedData + return to_unicode(decryptedData) def is_equal(self, a, b): + """ + Comparing 2 byte arrrays in constant time + to avoid timing attacks. + + It would be nice if there was a library for this but + hey. + """ # http://codahale.com/a-lesson-in-timing-attacks/ if len(a) != len(b): return False - + result = 0 for x, y in zip(a, b): - result |= ord(x) ^ ord(y) - return result == 0 - - + if PY2: + result |= ord(x) ^ ord(y) + else: + result |= x ^ y + return result == 0 diff --git a/v2/test/parsing/vault/test_vault.py b/v2/test/parsing/vault/test_vault.py index 5609596404..2aaac27fc7 100644 --- a/v2/test/parsing/vault/test_vault.py +++ b/v2/test/parsing/vault/test_vault.py @@ -31,6 +31,7 @@ from binascii import hexlify from nose.plugins.skip import SkipTest from ansible.compat.tests import unittest +from ansible.utils.unicode import to_bytes, to_unicode from ansible import errors from ansible.parsing.vault import VaultLib @@ -70,8 +71,8 @@ class TestVaultLib(unittest.TestCase): def test_is_encrypted(self): v = VaultLib(None) - assert not v.is_encrypted("foobar"), "encryption check on plaintext failed" - data = "$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(six.b("ansible")) + assert not v.is_encrypted(u"foobar"), "encryption check on plaintext failed" + data = u"$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible") assert v.is_encrypted(data), "encryption check on headered text failed" def test_add_header(self): @@ -79,9 +80,9 @@ class TestVaultLib(unittest.TestCase): v.cipher_name = "TEST" sensitive_data = "ansible" data = v._add_header(sensitive_data) - lines = data.split('\n') + lines = data.split(b'\n') assert len(lines) > 1, "failed to properly add header" - header = lines[0] + header = to_unicode(lines[0]) assert header.endswith(';TEST'), "header does end with cipher name" header_parts = header.split(';') assert len(header_parts) == 3, "header has the wrong number of parts" @@ -91,10 +92,10 @@ class TestVaultLib(unittest.TestCase): def test_split_header(self): v = VaultLib('ansible') - data = "$ANSIBLE_VAULT;9.9;TEST\nansible" + data = b"$ANSIBLE_VAULT;9.9;TEST\nansible" rdata = v._split_header(data) - lines = rdata.split('\n') - assert lines[0] == "ansible" + lines = rdata.split(b'\n') + assert lines[0] == b"ansible" assert v.cipher_name == 'TEST', "cipher name was not set" assert v.version == "9.9" @@ -102,7 +103,7 @@ class TestVaultLib(unittest.TestCase): if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2: raise SkipTest v = VaultLib('ansible') - v.cipher_name = 'AES' + v.cipher_name = u'AES' enc_data = v.encrypt("foobar") dec_data = v.decrypt(enc_data) assert enc_data != "foobar", "encryption failed"