Enable writing plugins for jinja2 tests

This commit is contained in:
Devin Christensen 2014-11-26 17:58:45 -07:00
parent 6a8062baad
commit 1bf5224f82
13 changed files with 216 additions and 99 deletions

View file

@ -156,6 +156,7 @@ DEFAULT_CONNECTION_PLUGIN_PATH = get_config(p, DEFAULTS, 'connection_plugins', '
DEFAULT_LOOKUP_PLUGIN_PATH = get_config(p, DEFAULTS, 'lookup_plugins', 'ANSIBLE_LOOKUP_PLUGINS', '~/.ansible/plugins/lookup_plugins:/usr/share/ansible_plugins/lookup_plugins') DEFAULT_LOOKUP_PLUGIN_PATH = get_config(p, DEFAULTS, 'lookup_plugins', 'ANSIBLE_LOOKUP_PLUGINS', '~/.ansible/plugins/lookup_plugins:/usr/share/ansible_plugins/lookup_plugins')
DEFAULT_VARS_PLUGIN_PATH = get_config(p, DEFAULTS, 'vars_plugins', 'ANSIBLE_VARS_PLUGINS', '~/.ansible/plugins/vars_plugins:/usr/share/ansible_plugins/vars_plugins') DEFAULT_VARS_PLUGIN_PATH = get_config(p, DEFAULTS, 'vars_plugins', 'ANSIBLE_VARS_PLUGINS', '~/.ansible/plugins/vars_plugins:/usr/share/ansible_plugins/vars_plugins')
DEFAULT_FILTER_PLUGIN_PATH = get_config(p, DEFAULTS, 'filter_plugins', 'ANSIBLE_FILTER_PLUGINS', '~/.ansible/plugins/filter_plugins:/usr/share/ansible_plugins/filter_plugins') DEFAULT_FILTER_PLUGIN_PATH = get_config(p, DEFAULTS, 'filter_plugins', 'ANSIBLE_FILTER_PLUGINS', '~/.ansible/plugins/filter_plugins:/usr/share/ansible_plugins/filter_plugins')
DEFAULT_TEST_PLUGIN_PATH = get_config(p, DEFAULTS, 'test_plugins', 'ANSIBLE_TEST_PLUGINS', '~/.ansible/plugins/test_plugins:/usr/share/ansible_plugins/test_plugins')
CACHE_PLUGIN = get_config(p, DEFAULTS, 'fact_caching', 'ANSIBLE_CACHE_PLUGIN', 'memory') CACHE_PLUGIN = get_config(p, DEFAULTS, 'fact_caching', 'ANSIBLE_CACHE_PLUGIN', 'memory')
CACHE_PLUGIN_CONNECTION = get_config(p, DEFAULTS, 'fact_caching_connection', 'ANSIBLE_CACHE_PLUGIN_CONNECTION', None) CACHE_PLUGIN_CONNECTION = get_config(p, DEFAULTS, 'fact_caching_connection', 'ANSIBLE_CACHE_PLUGIN_CONNECTION', None)

View file

@ -74,55 +74,6 @@ def to_nice_json(a, *args, **kw):
return to_json(a, *args, **kw) return to_json(a, *args, **kw)
return json.dumps(a, indent=4, sort_keys=True, *args, **kw) return json.dumps(a, indent=4, sort_keys=True, *args, **kw)
def failed(*a, **kw):
''' Test if task result yields failed '''
item = a[0]
if type(item) != dict:
raise errors.AnsibleFilterError("|failed expects a dictionary")
rc = item.get('rc',0)
failed = item.get('failed',False)
if rc != 0 or failed:
return True
else:
return False
def success(*a, **kw):
''' Test if task result yields success '''
return not failed(*a, **kw)
def changed(*a, **kw):
''' Test if task result yields changed '''
item = a[0]
if type(item) != dict:
raise errors.AnsibleFilterError("|changed expects a dictionary")
if not 'changed' in item:
changed = False
if ('results' in item # some modules return a 'results' key
and type(item['results']) == list
and type(item['results'][0]) == dict):
for result in item['results']:
changed = changed or result.get('changed', False)
else:
changed = item.get('changed', False)
return changed
def skipped(*a, **kw):
''' Test if task result yields skipped '''
item = a[0]
if type(item) != dict:
raise errors.AnsibleFilterError("|skipped expects a dictionary")
skipped = item.get('skipped', False)
return skipped
def mandatory(a):
''' Make a variable mandatory '''
try:
a
except NameError:
raise errors.AnsibleFilterError('Mandatory variable not defined.')
else:
return a
def bool(a): def bool(a):
''' return a bool for the arg ''' ''' return a bool for the arg '''
if a is None or type(a) == bool: if a is None or type(a) == bool:
@ -142,27 +93,6 @@ def fileglob(pathname):
''' return list of matched files for glob ''' ''' return list of matched files for glob '''
return glob.glob(pathname) return glob.glob(pathname)
def regex(value='', pattern='', ignorecase=False, match_type='search'):
''' Expose `re` as a boolean filter using the `search` method by default.
This is likely only useful for `search` and `match` which already
have their own filters.
'''
if ignorecase:
flags = re.I
else:
flags = 0
_re = re.compile(pattern, flags=flags)
_bool = __builtins__.get('bool')
return _bool(getattr(_re, match_type, 'search')(value))
def match(value, pattern='', ignorecase=False):
''' Perform a `re.match` returning a boolean '''
return regex(value, pattern, ignorecase, 'match')
def search(value, pattern='', ignorecase=False):
''' Perform a `re.search` returning a boolean '''
return regex(value, pattern, ignorecase, 'search')
def regex_replace(value='', pattern='', replacement='', ignorecase=False): def regex_replace(value='', pattern='', replacement='', ignorecase=False):
''' Perform a `re.sub` returning a string ''' ''' Perform a `re.sub` returning a string '''
@ -299,19 +229,6 @@ class FilterModule(object):
'realpath': partial(unicode_wrap, os.path.realpath), 'realpath': partial(unicode_wrap, os.path.realpath),
'relpath': partial(unicode_wrap, os.path.relpath), 'relpath': partial(unicode_wrap, os.path.relpath),
# failure testing
'failed' : failed,
'success' : success,
# changed testing
'changed' : changed,
# skip testing
'skipped' : skipped,
# variable existence
'mandatory': mandatory,
# value as boolean # value as boolean
'bool': bool, 'bool': bool,
@ -333,9 +250,6 @@ class FilterModule(object):
'fileglob': fileglob, 'fileglob': fileglob,
# regex # regex
'match': match,
'search': search,
'regex': regex,
'regex_replace': regex_replace, 'regex_replace': regex_replace,
# ? : ; # ? : ;

View file

@ -67,13 +67,6 @@ def max(a):
_max = __builtins__.get('max') _max = __builtins__.get('max')
return _max(a); return _max(a);
def isnotanumber(x):
try:
return math.isnan(x)
except TypeError:
return False
def logarithm(x, base=math.e): def logarithm(x, base=math.e):
try: try:
if base == 10: if base == 10:
@ -107,7 +100,6 @@ class FilterModule(object):
def filters(self): def filters(self):
return { return {
# general math # general math
'isnan': isnotanumber,
'min' : min, 'min' : min,
'max' : max, 'max' : max,

View file

@ -0,0 +1,113 @@
# (c) 2012, Jeroen Hoekx <jeroen@hoekx.be>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
import re
from ansible import errors
def failed(*a, **kw):
''' Test if task result yields failed '''
item = a[0]
if type(item) != dict:
raise errors.AnsibleFilterError("|failed expects a dictionary")
rc = item.get('rc',0)
failed = item.get('failed',False)
if rc != 0 or failed:
return True
else:
return False
def success(*a, **kw):
''' Test if task result yields success '''
return not failed(*a, **kw)
def changed(*a, **kw):
''' Test if task result yields changed '''
item = a[0]
if type(item) != dict:
raise errors.AnsibleFilterError("|changed expects a dictionary")
if not 'changed' in item:
changed = False
if ('results' in item # some modules return a 'results' key
and type(item['results']) == list
and type(item['results'][0]) == dict):
for result in item['results']:
changed = changed or result.get('changed', False)
else:
changed = item.get('changed', False)
return changed
def skipped(*a, **kw):
''' Test if task result yields skipped '''
item = a[0]
if type(item) != dict:
raise errors.AnsibleFilterError("|skipped expects a dictionary")
skipped = item.get('skipped', False)
return skipped
def mandatory(a):
''' Make a variable mandatory '''
try:
a
except NameError:
raise errors.AnsibleFilterError('Mandatory variable not defined.')
else:
return a
def regex(value='', pattern='', ignorecase=False, match_type='search'):
''' Expose `re` as a boolean filter using the `search` method by default.
This is likely only useful for `search` and `match` which already
have their own filters.
'''
if ignorecase:
flags = re.I
else:
flags = 0
_re = re.compile(pattern, flags=flags)
_bool = __builtins__.get('bool')
return _bool(getattr(_re, match_type, 'search')(value))
def match(value, pattern='', ignorecase=False):
''' Perform a `re.match` returning a boolean '''
return regex(value, pattern, ignorecase, 'match')
def search(value, pattern='', ignorecase=False):
''' Perform a `re.search` returning a boolean '''
return regex(value, pattern, ignorecase, 'search')
class TestModule(object):
''' Ansible core jinja2 tests '''
def tests(self):
return {
# failure testing
'failed' : failed,
'success' : success,
# changed testing
'changed' : changed,
# skip testing
'skipped' : skipped,
# variable existence
'mandatory': mandatory,
# regex
'match': match,
'search': search,
'regex': regex,
}

View file

@ -0,0 +1,36 @@
# (c) 2014, Brian Coca <bcoca@ansible.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
from __future__ import absolute_import
import math
from ansible import errors
def isnotanumber(x):
try:
return math.isnan(x)
except TypeError:
return False
class TestModule(object):
''' Ansible math jinja2 tests '''
def tests(self):
return {
# general math
'isnan': isnotanumber,
}

View file

@ -1403,7 +1403,11 @@ def safe_eval(expr, locals={}, include_exceptions=False):
for filter in filter_loader.all(): for filter in filter_loader.all():
filter_list.extend(filter.filters().keys()) filter_list.extend(filter.filters().keys())
CALL_WHITELIST = C.DEFAULT_CALLABLE_WHITELIST + filter_list test_list = []
for test in test_loader.all():
test_list.extend(test.tests().keys())
CALL_WHITELIST = C.DEFAULT_CALLABLE_WHITELIST + filter_list + test_list
class CleansingNodeVisitor(ast.NodeVisitor): class CleansingNodeVisitor(ast.NodeVisitor):
def generic_visit(self, node, inside_call=False): def generic_visit(self, node, inside_call=False):

View file

@ -296,6 +296,13 @@ filter_loader = PluginLoader(
'filter_plugins' 'filter_plugins'
) )
test_loader = PluginLoader(
'TestModule',
'ansible.runner.test_plugins',
C.DEFAULT_TEST_PLUGIN_PATH,
'test_plugins'
)
fragment_loader = PluginLoader( fragment_loader = PluginLoader(
'ModuleDocFragment', 'ModuleDocFragment',
'ansible.utils.module_docs_fragments', 'ansible.utils.module_docs_fragments',

View file

@ -39,6 +39,7 @@ from ansible.utils import to_bytes, to_unicode
class Globals(object): class Globals(object):
FILTERS = None FILTERS = None
TESTS = None
def __init__(self): def __init__(self):
pass pass
@ -54,10 +55,26 @@ def _get_filters():
filters = {} filters = {}
for fp in plugins: for fp in plugins:
filters.update(fp.filters()) filters.update(fp.filters())
filters.update(_get_tests())
Globals.FILTERS = filters Globals.FILTERS = filters
return Globals.FILTERS return Globals.FILTERS
def _get_tests():
''' return test plugin instances '''
if Globals.TESTS is not None:
return Globals.TESTS
from ansible import utils
plugins = [ x for x in utils.plugins.test_loader.all()]
tests = {}
for tp in plugins:
tests.update(tp.tests())
Globals.TESTS = tests
return Globals.TESTS
def _get_extensions(): def _get_extensions():
''' return jinja2 extensions to load ''' ''' return jinja2 extensions to load '''
@ -237,6 +254,7 @@ def template_from_file(basedir, path, vars, vault_password=None):
environment = jinja2.Environment(loader=loader, trim_blocks=True, extensions=_get_extensions()) environment = jinja2.Environment(loader=loader, trim_blocks=True, extensions=_get_extensions())
environment.filters.update(_get_filters()) environment.filters.update(_get_filters())
environment.tests.update(_get_tests())
environment.globals['lookup'] = my_lookup environment.globals['lookup'] = my_lookup
environment.globals['finalize'] = my_finalize environment.globals['finalize'] = my_finalize
if fail_on_undefined: if fail_on_undefined:
@ -351,6 +369,7 @@ def template_from_string(basedir, data, vars, fail_on_undefined=False):
environment = jinja2.Environment(trim_blocks=True, undefined=StrictUndefined, extensions=_get_extensions(), finalize=my_finalize) environment = jinja2.Environment(trim_blocks=True, undefined=StrictUndefined, extensions=_get_extensions(), finalize=my_finalize)
environment.filters.update(_get_filters()) environment.filters.update(_get_filters())
environment.tests.update(_get_tests())
environment.template_class = J2Template environment.template_class = J2Template
if '_original_file' in vars: if '_original_file' in vars:

View file

@ -162,6 +162,7 @@ DEFAULT_CONNECTION_PLUGIN_PATH = get_config(p, DEFAULTS, 'connection_plugins', '
DEFAULT_LOOKUP_PLUGIN_PATH = get_config(p, DEFAULTS, 'lookup_plugins', 'ANSIBLE_LOOKUP_PLUGINS', '~/.ansible/plugins/lookup_plugins:/usr/share/ansible_plugins/lookup_plugins') DEFAULT_LOOKUP_PLUGIN_PATH = get_config(p, DEFAULTS, 'lookup_plugins', 'ANSIBLE_LOOKUP_PLUGINS', '~/.ansible/plugins/lookup_plugins:/usr/share/ansible_plugins/lookup_plugins')
DEFAULT_VARS_PLUGIN_PATH = get_config(p, DEFAULTS, 'vars_plugins', 'ANSIBLE_VARS_PLUGINS', '~/.ansible/plugins/vars_plugins:/usr/share/ansible_plugins/vars_plugins') DEFAULT_VARS_PLUGIN_PATH = get_config(p, DEFAULTS, 'vars_plugins', 'ANSIBLE_VARS_PLUGINS', '~/.ansible/plugins/vars_plugins:/usr/share/ansible_plugins/vars_plugins')
DEFAULT_FILTER_PLUGIN_PATH = get_config(p, DEFAULTS, 'filter_plugins', 'ANSIBLE_FILTER_PLUGINS', '~/.ansible/plugins/filter_plugins:/usr/share/ansible_plugins/filter_plugins') DEFAULT_FILTER_PLUGIN_PATH = get_config(p, DEFAULTS, 'filter_plugins', 'ANSIBLE_FILTER_PLUGINS', '~/.ansible/plugins/filter_plugins:/usr/share/ansible_plugins/filter_plugins')
DEFAULT_TEST_PLUGIN_PATH = get_config(p, DEFAULTS, 'test_plugins', 'ANSIBLE_TEST_PLUGINS', '~/.ansible/plugins/test_plugins:/usr/share/ansible_plugins/test_plugins')
CACHE_PLUGIN = get_config(p, DEFAULTS, 'fact_caching', 'ANSIBLE_CACHE_PLUGIN', 'memory') CACHE_PLUGIN = get_config(p, DEFAULTS, 'fact_caching', 'ANSIBLE_CACHE_PLUGIN', 'memory')
CACHE_PLUGIN_CONNECTION = get_config(p, DEFAULTS, 'fact_caching_connection', 'ANSIBLE_CACHE_PLUGIN_CONNECTION', None) CACHE_PLUGIN_CONNECTION = get_config(p, DEFAULTS, 'fact_caching_connection', 'ANSIBLE_CACHE_PLUGIN_CONNECTION', None)

View file

@ -311,6 +311,13 @@ filter_loader = PluginLoader(
'filter_plugins' 'filter_plugins'
) )
test_loader = PluginLoader(
'TestModule',
'ansible.plugins.test',
C.DEFAULT_TEST_PLUGIN_PATH,
'test_plugins'
)
fragment_loader = PluginLoader( fragment_loader = PluginLoader(
'ModuleDocFragment', 'ModuleDocFragment',
'ansible.utils.module_docs_fragments', 'ansible.utils.module_docs_fragments',

View file

@ -28,7 +28,7 @@ from jinja2.runtime import StrictUndefined
from ansible import constants as C from ansible import constants as C
from ansible.errors import AnsibleError, AnsibleFilterError, AnsibleUndefinedVariable from ansible.errors import AnsibleError, AnsibleFilterError, AnsibleUndefinedVariable
from ansible.plugins import filter_loader, lookup_loader from ansible.plugins import filter_loader, lookup_loader, test_loader
from ansible.template.safe_eval import safe_eval from ansible.template.safe_eval import safe_eval
from ansible.template.template import AnsibleJ2Template from ansible.template.template import AnsibleJ2Template
from ansible.template.vars import AnsibleJ2Vars from ansible.template.vars import AnsibleJ2Vars
@ -57,6 +57,7 @@ class Templar:
self._loader = loader self._loader = loader
self._basedir = loader.get_basedir() self._basedir = loader.get_basedir()
self._filters = None self._filters = None
self._tests = None
self._available_variables = variables self._available_variables = variables
# flags to determine whether certain failures during templating # flags to determine whether certain failures during templating
@ -93,9 +94,26 @@ class Templar:
self._filters = dict() self._filters = dict()
for fp in plugins: for fp in plugins:
self._filters.update(fp.filters()) self._filters.update(fp.filters())
self._filters.update(self._get_tests())
return self._filters.copy() return self._filters.copy()
def _get_tests(self):
'''
Returns tests plugins, after loading and caching them if need be
'''
if self._tests is not None:
return self._tests.copy()
plugins = [x for x in test_loader.all()]
self._tests = dict()
for fp in plugins:
self._tests.update(fp.tests())
return self._tests.copy()
def _get_extensions(self): def _get_extensions(self):
''' '''
Return jinja2 extensions to load. Return jinja2 extensions to load.
@ -229,6 +247,7 @@ class Templar:
environment = Environment(trim_blocks=True, undefined=StrictUndefined, extensions=self._get_extensions(), finalize=self._finalize) environment = Environment(trim_blocks=True, undefined=StrictUndefined, extensions=self._get_extensions(), finalize=self._finalize)
environment.filters.update(self._get_filters()) environment.filters.update(self._get_filters())
environment.tests.update(self._get_tests())
environment.template_class = AnsibleJ2Template environment.template_class = AnsibleJ2Template
# FIXME: may not be required anymore, as the basedir stuff will # FIXME: may not be required anymore, as the basedir stuff will

View file

@ -23,7 +23,7 @@ import sys
from six.moves import builtins from six.moves import builtins
from ansible import constants as C from ansible import constants as C
from ansible.plugins import filter_loader from ansible.plugins import filter_loader, test_loader
def safe_eval(expr, locals={}, include_exceptions=False): def safe_eval(expr, locals={}, include_exceptions=False):
''' '''
@ -77,7 +77,11 @@ def safe_eval(expr, locals={}, include_exceptions=False):
for filter in filter_loader.all(): for filter in filter_loader.all():
filter_list.extend(filter.filters().keys()) filter_list.extend(filter.filters().keys())
CALL_WHITELIST = C.DEFAULT_CALLABLE_WHITELIST + filter_list test_list = []
for test in test_loader.all():
test_list.extend(test.tests().keys())
CALL_WHITELIST = C.DEFAULT_CALLABLE_WHITELIST + filter_list + test_list
class CleansingNodeVisitor(ast.NodeVisitor): class CleansingNodeVisitor(ast.NodeVisitor):
def generic_visit(self, node, inside_call=False): def generic_visit(self, node, inside_call=False):