Move binary module detection into executor/module_common.py

This commit is contained in:
Matt Martz 2016-05-11 15:14:01 -05:00
parent 3466e73c50
commit 0faddfa168
6 changed files with 37 additions and 36 deletions

View file

@ -490,6 +490,13 @@ def recursive_finder(name, data, py_module_names, py_module_cache, zf):
# Save memory; the file won't have to be read again for this ansible module. # Save memory; the file won't have to be read again for this ansible module.
del py_module_cache[py_module_file] del py_module_cache[py_module_file]
def _is_binary(module_path):
textchars = bytearray(set([7, 8, 9, 10, 12, 13, 27]) | set(range(0x20, 0x100)) - set([0x7f]))
with open(module_path, 'rb') as f:
start = f.read(1024)
return bool(start.translate(None, textchars))
def _find_snippet_imports(module_name, module_data, module_path, module_args, task_vars, module_compression): def _find_snippet_imports(module_name, module_data, module_path, module_args, task_vars, module_compression):
""" """
Given the source of the module, convert it to a Jinja2 template to insert Given the source of the module, convert it to a Jinja2 template to insert
@ -521,11 +528,13 @@ def _find_snippet_imports(module_name, module_data, module_path, module_args, ta
module_substyle = 'jsonargs' module_substyle = 'jsonargs'
elif b'WANT_JSON' in module_data: elif b'WANT_JSON' in module_data:
module_substyle = module_style = 'non_native_want_json' module_substyle = module_style = 'non_native_want_json'
elif _is_binary(module_path):
module_substyle = module_style = 'binary'
shebang = None shebang = None
# Neither old-style nor non_native_want_json modules should be modified # Neither old-style, non_native_want_json nor binary modules should be modified
# except for the shebang line (Done by modify_module) # except for the shebang line (Done by modify_module)
if module_style in ('old', 'non_native_want_json'): if module_style in ('old', 'non_native_want_json', 'binary'):
return module_data, module_style, shebang return module_data, module_style, shebang
output = BytesIO() output = BytesIO()
@ -731,7 +740,9 @@ def modify_module(module_name, module_path, module_args, task_vars=dict(), modul
(module_data, module_style, shebang) = _find_snippet_imports(module_name, module_data, module_path, module_args, task_vars, module_compression) (module_data, module_style, shebang) = _find_snippet_imports(module_name, module_data, module_path, module_args, task_vars, module_compression)
if shebang is None: if module_style == 'binary':
return (module_path, module_data, module_style, shebang)
elif shebang is None:
lines = module_data.split(b"\n", 1) lines = module_data.split(b"\n", 1)
if lines[0].startswith(b"#!"): if lines[0].startswith(b"#!"):
shebang = lines[0].strip() shebang = lines[0].strip()
@ -753,4 +764,4 @@ def modify_module(module_name, module_path, module_args, task_vars=dict(), modul
else: else:
shebang = to_bytes(shebang, errors='strict') shebang = to_bytes(shebang, errors='strict')
return (module_data, module_style, shebang) return (module_path, module_data, module_style, shebang)

View file

@ -100,13 +100,6 @@ class ActionBase(with_metaclass(ABCMeta, object)):
return True return True
return False return False
def _is_binary(self, module_path):
textchars = bytearray(set([7, 8, 9, 10, 12, 13, 27]) | set(range(0x20, 0x100)) - set([0x7f]))
with open(module_path, 'rb') as f:
start = f.read(1024)
return bool(start.translate(None, textchars))
def _configure_module(self, module_name, module_args, task_vars=None): def _configure_module(self, module_name, module_args, task_vars=None):
''' '''
Handles the loading and templating of the module code through the Handles the loading and templating of the module code through the
@ -152,12 +145,9 @@ class ActionBase(with_metaclass(ABCMeta, object)):
"run 'git submodule update --init --recursive' to correct this problem." % (module_name)) "run 'git submodule update --init --recursive' to correct this problem." % (module_name))
# insert shared code and arguments into the module # insert shared code and arguments into the module
(module_data, module_style, module_shebang) = modify_module(module_name, module_path, module_args, task_vars=task_vars, module_compression=self._play_context.module_compression) (module_path, module_data, module_style, module_shebang) = modify_module(module_name, module_path, module_args, task_vars=task_vars, module_compression=self._play_context.module_compression)
if self._is_binary(module_path): return (module_style, module_shebang, module_data, module_path)
return ('non_native_want_json', None, None, module_path, True)
return (module_style, module_shebang, module_data, module_path, False)
def _compute_environment_string(self): def _compute_environment_string(self):
''' '''
@ -301,7 +291,7 @@ class ActionBase(with_metaclass(ABCMeta, object)):
return remote_path return remote_path
def _fixup_perms(self, remote_path, remote_user, execute=False, recursive=True): def _fixup_perms(self, remote_path, remote_user, execute=True, recursive=True):
""" """
We need the files we upload to be readable (and sometimes executable) We need the files we upload to be readable (and sometimes executable)
by the user being sudo'd to but we want to limit other people's access by the user being sudo'd to but we want to limit other people's access
@ -579,9 +569,8 @@ class ActionBase(with_metaclass(ABCMeta, object)):
# let module know our verbosity # let module know our verbosity
module_args['_ansible_verbosity'] = display.verbosity module_args['_ansible_verbosity'] = display.verbosity
(module_style, shebang, module_data, module_path, is_binary) = self._configure_module(module_name=module_name, module_args=module_args, task_vars=task_vars) (module_style, shebang, module_data, module_path) = self._configure_module(module_name=module_name, module_args=module_args, task_vars=task_vars)
if not shebang and module_style != 'binary':
if not shebang and not is_binary:
raise AnsibleError("module (%s) is missing interpreter line" % module_name) raise AnsibleError("module (%s) is missing interpreter line" % module_name)
# a remote tmp path may be necessary and not already created # a remote tmp path may be necessary and not already created
@ -593,13 +582,13 @@ class ActionBase(with_metaclass(ABCMeta, object)):
if tmp: if tmp:
remote_module_filename = self._connection._shell.get_remote_filename(module_path) remote_module_filename = self._connection._shell.get_remote_filename(module_path)
remote_module_path = self._connection._shell.join_path(tmp, remote_module_filename) remote_module_path = self._connection._shell.join_path(tmp, remote_module_filename)
if module_style in ['old', 'non_native_want_json']: if module_style in ('old', 'non_native_want_json', 'binary'):
# we'll also need a temp file to hold our module arguments # we'll also need a temp file to hold our module arguments
args_file_path = self._connection._shell.join_path(tmp, 'args') args_file_path = self._connection._shell.join_path(tmp, 'args')
if remote_module_path or module_style != 'new': if remote_module_path or module_style != 'new':
display.debug("transferring module to remote") display.debug("transferring module to remote")
if is_binary: if module_style == 'binary':
self._transfer_file(module_path, remote_module_path) self._transfer_file(module_path, remote_module_path)
else: else:
self._transfer_data(remote_module_path, module_data) self._transfer_data(remote_module_path, module_data)
@ -610,7 +599,7 @@ class ActionBase(with_metaclass(ABCMeta, object)):
for k,v in iteritems(module_args): for k,v in iteritems(module_args):
args_data += '%s="%s" ' % (k, pipes.quote(text_type(v))) args_data += '%s="%s" ' % (k, pipes.quote(text_type(v)))
self._transfer_data(args_file_path, args_data) self._transfer_data(args_file_path, args_data)
elif module_style == 'non_native_want_json': elif module_style in ('non_native_want_json', 'binary'):
self._transfer_data(args_file_path, json.dumps(module_args)) self._transfer_data(args_file_path, json.dumps(module_args))
display.debug("done transferring module to remote") display.debug("done transferring module to remote")
@ -618,7 +607,7 @@ class ActionBase(with_metaclass(ABCMeta, object)):
# Fix permissions of the tmp path and tmp files. This should be # Fix permissions of the tmp path and tmp files. This should be
# called after all files have been transferred. # called after all files have been transferred.
self._fixup_perms(tmp, remote_user, recursive=True, execute=is_binary) self._fixup_perms(tmp, remote_user, recursive=True)
cmd = "" cmd = ""
in_data = None in_data = None

View file

@ -54,18 +54,18 @@ class ActionModule(ActionBase):
module_args['_ansible_no_log'] = True module_args['_ansible_no_log'] = True
# configure, upload, and chmod the target module # configure, upload, and chmod the target module
(module_style, shebang, module_data, module_path, is_binary) = self._configure_module(module_name=module_name, module_args=module_args, task_vars=task_vars) (module_style, shebang, module_data, module_path) = self._configure_module(module_name=module_name, module_args=module_args, task_vars=task_vars)
if is_binary: if module_style == 'binary':
self._transfer_file(module_path, remote_module_path) self._transfer_file(module_path, remote_module_path)
else: else:
self._transfer_data(remote_module_path, module_data) self._transfer_data(remote_module_path, module_data)
# configure, upload, and chmod the async_wrapper module # configure, upload, and chmod the async_wrapper module
(async_module_style, shebang, async_module_data, _, _) = self._configure_module(module_name='async_wrapper', module_args=dict(), task_vars=task_vars) (async_module_style, shebang, async_module_data, _) = self._configure_module(module_name='async_wrapper', module_args=dict(), task_vars=task_vars)
self._transfer_data(async_module_path, async_module_data) self._transfer_data(async_module_path, async_module_data)
argsfile = None argsfile = None
if module_style == 'non_native_want_json': if module_style in ('non_native_want_json', 'binary'):
argsfile = self._transfer_data(self._connection._shell.join_path(tmp, 'arguments'), json.dumps(module_args)) argsfile = self._transfer_data(self._connection._shell.join_path(tmp, 'arguments'), json.dumps(module_args))
elif module_style == 'old': elif module_style == 'old':
args_data = "" args_data = ""

View file

@ -165,6 +165,7 @@ class ShellBase(object):
# don't quote the cmd if it's an empty string, because this will break pipelining mode # don't quote the cmd if it's an empty string, because this will break pipelining mode
if cmd.strip() != '': if cmd.strip() != '':
cmd = pipes.quote(cmd) cmd = pipes.quote(cmd)
cmd_parts = [] cmd_parts = []
if shebang: if shebang:
shebang = shebang.replace("#!", "").strip() shebang = shebang.replace("#!", "").strip()

View file

@ -54,8 +54,8 @@ class ShellModule(object):
return path return path
return '\'%s\'' % path return '\'%s\'' % path
# powershell requires that script files end with .ps1
def get_remote_filename(self, pathname): def get_remote_filename(self, pathname):
# powershell requires that script files end with .ps1
base_name = os.path.basename(pathname.strip()) base_name = os.path.basename(pathname.strip())
name, ext = os.path.splitext(base_name.strip()) name, ext = os.path.splitext(base_name.strip())
if ext.lower() not in ['.ps1', '.exe']: if ext.lower() not in ['.ps1', '.exe']:

View file

@ -217,7 +217,7 @@ class TestActionBase(unittest.TestCase):
with patch.object(os, 'rename') as m: with patch.object(os, 'rename') as m:
mock_task.args = dict(a=1, foo='fö〩') mock_task.args = dict(a=1, foo='fö〩')
mock_connection.module_implementation_preferences = ('',) mock_connection.module_implementation_preferences = ('',)
(style, shebang, data, module_path, is_binary) = action_base._configure_module(mock_task.action, mock_task.args) (style, shebang, data, path) = action_base._configure_module(mock_task.action, mock_task.args)
self.assertEqual(style, "new") self.assertEqual(style, "new")
self.assertEqual(shebang, b"#!/usr/bin/python") self.assertEqual(shebang, b"#!/usr/bin/python")
@ -229,7 +229,7 @@ class TestActionBase(unittest.TestCase):
mock_task.action = 'win_copy' mock_task.action = 'win_copy'
mock_task.args = dict(b=2) mock_task.args = dict(b=2)
mock_connection.module_implementation_preferences = ('.ps1',) mock_connection.module_implementation_preferences = ('.ps1',)
(style, shebang, data, module_path, is_binary) = action_base._configure_module('stat', mock_task.args) (style, shebang, data, path) = action_base._configure_module('stat', mock_task.args)
self.assertEqual(style, "new") self.assertEqual(style, "new")
self.assertEqual(shebang, None) self.assertEqual(shebang, None)
@ -572,7 +572,7 @@ class TestActionBase(unittest.TestCase):
action_base._low_level_execute_command = MagicMock() action_base._low_level_execute_command = MagicMock()
action_base._fixup_perms = MagicMock() action_base._fixup_perms = MagicMock()
action_base._configure_module.return_value = ('new', '#!/usr/bin/python', 'this is the module data', None, False) action_base._configure_module.return_value = ('new', '#!/usr/bin/python', 'this is the module data', 'path')
action_base._late_needs_tmp_path.return_value = False action_base._late_needs_tmp_path.return_value = False
action_base._compute_environment_string.return_value = '' action_base._compute_environment_string.return_value = ''
action_base._connection.has_pipelining = True action_base._connection.has_pipelining = True
@ -581,12 +581,12 @@ class TestActionBase(unittest.TestCase):
self.assertEqual(action_base._execute_module(module_name='foo', module_args=dict(z=9, y=8, x=7), task_vars=dict(a=1)), dict(rc=0, stdout="ok", stdout_lines=['ok'])) self.assertEqual(action_base._execute_module(module_name='foo', module_args=dict(z=9, y=8, x=7), task_vars=dict(a=1)), dict(rc=0, stdout="ok", stdout_lines=['ok']))
# test with needing/removing a remote tmp path # test with needing/removing a remote tmp path
action_base._configure_module.return_value = ('old', '#!/usr/bin/python', 'this is the module data', None, False) action_base._configure_module.return_value = ('old', '#!/usr/bin/python', 'this is the module data', 'path')
action_base._late_needs_tmp_path.return_value = True action_base._late_needs_tmp_path.return_value = True
action_base._make_tmp_path.return_value = '/the/tmp/path' action_base._make_tmp_path.return_value = '/the/tmp/path'
self.assertEqual(action_base._execute_module(), dict(rc=0, stdout="ok", stdout_lines=['ok'])) self.assertEqual(action_base._execute_module(), dict(rc=0, stdout="ok", stdout_lines=['ok']))
action_base._configure_module.return_value = ('non_native_want_json', '#!/usr/bin/python', 'this is the module data', None, False) action_base._configure_module.return_value = ('non_native_want_json', '#!/usr/bin/python', 'this is the module data', 'path')
self.assertEqual(action_base._execute_module(), dict(rc=0, stdout="ok", stdout_lines=['ok'])) self.assertEqual(action_base._execute_module(), dict(rc=0, stdout="ok", stdout_lines=['ok']))
play_context.become = True play_context.become = True
@ -594,14 +594,14 @@ class TestActionBase(unittest.TestCase):
self.assertEqual(action_base._execute_module(), dict(rc=0, stdout="ok", stdout_lines=['ok'])) self.assertEqual(action_base._execute_module(), dict(rc=0, stdout="ok", stdout_lines=['ok']))
# test an invalid shebang return # test an invalid shebang return
action_base._configure_module.return_value = ('new', '', 'this is the module data', None, False) action_base._configure_module.return_value = ('new', '', 'this is the module data', 'path')
action_base._late_needs_tmp_path.return_value = False action_base._late_needs_tmp_path.return_value = False
self.assertRaises(AnsibleError, action_base._execute_module) self.assertRaises(AnsibleError, action_base._execute_module)
# test with check mode enabled, once with support for check # test with check mode enabled, once with support for check
# mode and once with support disabled to raise an error # mode and once with support disabled to raise an error
play_context.check_mode = True play_context.check_mode = True
action_base._configure_module.return_value = ('new', '#!/usr/bin/python', 'this is the module data', None, False) action_base._configure_module.return_value = ('new', '#!/usr/bin/python', 'this is the module data', 'path')
self.assertEqual(action_base._execute_module(), dict(rc=0, stdout="ok", stdout_lines=['ok'])) self.assertEqual(action_base._execute_module(), dict(rc=0, stdout="ok", stdout_lines=['ok']))
action_base._supports_check_mode = False action_base._supports_check_mode = False
self.assertRaises(AnsibleError, action_base._execute_module) self.assertRaises(AnsibleError, action_base._execute_module)