allow jinja2 unique filter compat (#45637)

* allow jinja2 unique filter compat
* detect if unique is provided, fallback with warning
* handle j2 specific params
* now all filters using unique must pass environment
* added env to tests

also normalized on how we normally import and use exceptoins
This commit is contained in:
Brian Coca 2018-09-25 14:27:02 -04:00 committed by GitHub
parent f4f5d941e5
commit 32ec69d827
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 81 additions and 39 deletions

View file

@ -27,14 +27,49 @@ import collections
import itertools
import math
from ansible import errors
from jinja2.filters import environmentfilter
from ansible.errors import AnsibleFilterError
from ansible.module_utils import basic
from ansible.module_utils.six import binary_type, text_type
from ansible.module_utils.six.moves import zip, zip_longest
from ansible.module_utils._text import to_native
from ansible.module_utils._text import to_native, to_text
try:
from jinja2.filters import do_unique
HAS_UNIQUE = True
except ImportError:
HAS_UNIQUE = False
try:
from __main__ import display
except ImportError:
from ansible.utils.display import Display
display = Display()
def unique(a):
@environmentfilter
def unique(environment, a, case_sensitive=False, attribute=None):
error = None
try:
if HAS_UNIQUE:
c = set(do_unique(environment, a, case_sensitive=case_sensitive, attribute=attribute))
except Exception as e:
if case_sensitive or attribute:
raise AnsibleFilterError("Jinja2's unique filter failed and we cannot fall back to Ansible's version "
"as it does not support the parameters supplied", orig_exc=e)
else:
display.warning('Falling back to Ansible unique filter as Jinaj2 one failed: %s' % to_text(e))
error = e
if not HAS_UNIQUE or error:
# handle Jinja2 specific attributes when using Ansible's version
if case_sensitive or attribute:
raise AnsibleFilterError("Ansible's unique filter does not support case_sensitive nor attribute parameters, "
"you need a newer version of Jinja2 that provides their version of the filter.")
if isinstance(a, collections.Hashable):
c = set(a)
else:
@ -45,36 +80,40 @@ def unique(a):
return c
def intersect(a, b):
@environmentfilter
def intersect(environment, a, b):
if isinstance(a, collections.Hashable) and isinstance(b, collections.Hashable):
c = set(a) & set(b)
else:
c = unique([x for x in a if x in b])
c = unique(environment, [x for x in a if x in b])
return c
def difference(a, b):
@environmentfilter
def difference(environment, a, b):
if isinstance(a, collections.Hashable) and isinstance(b, collections.Hashable):
c = set(a) - set(b)
else:
c = unique([x for x in a if x not in b])
c = unique(environment, [x for x in a if x not in b])
return c
def symmetric_difference(a, b):
@environmentfilter
def symmetric_difference(environment, a, b):
if isinstance(a, collections.Hashable) and isinstance(b, collections.Hashable):
c = set(a) ^ set(b)
else:
isect = intersect(a, b)
c = [x for x in union(a, b) if x not in isect]
isect = intersect(environment, a, b)
c = [x for x in union(environment, a, b) if x not in isect]
return c
def union(a, b):
@environmentfilter
def union(environment, a, b):
if isinstance(a, collections.Hashable) and isinstance(b, collections.Hashable):
c = set(a) | set(b)
else:
c = unique(a + b)
c = unique(environment, a + b)
return c
@ -95,14 +134,14 @@ def logarithm(x, base=math.e):
else:
return math.log(x, base)
except TypeError as e:
raise errors.AnsibleFilterError('log() can only be used on numbers: %s' % str(e))
raise AnsibleFilterError('log() can only be used on numbers: %s' % str(e))
def power(x, y):
try:
return math.pow(x, y)
except TypeError as e:
raise errors.AnsibleFilterError('pow() can only be used on numbers: %s' % str(e))
raise AnsibleFilterError('pow() can only be used on numbers: %s' % str(e))
def inversepower(x, base=2):
@ -112,7 +151,7 @@ def inversepower(x, base=2):
else:
return math.pow(x, 1.0 / float(base))
except (ValueError, TypeError) as e:
raise errors.AnsibleFilterError('root() can only be used on numbers: %s' % str(e))
raise AnsibleFilterError('root() can only be used on numbers: %s' % str(e))
def human_readable(size, isbits=False, unit=None):
@ -120,7 +159,7 @@ def human_readable(size, isbits=False, unit=None):
try:
return basic.bytes_to_human(size, isbits, unit)
except Exception:
raise errors.AnsibleFilterError("human_readable() can't interpret following string: %s" % size)
raise AnsibleFilterError("human_readable() can't interpret following string: %s" % size)
def human_to_bytes(size, default_unit=None, isbits=False):
@ -128,7 +167,7 @@ def human_to_bytes(size, default_unit=None, isbits=False):
try:
return basic.human_to_bytes(size, default_unit, isbits)
except Exception:
raise errors.AnsibleFilterError("human_to_bytes() can't interpret following string: %s" % size)
raise AnsibleFilterError("human_to_bytes() can't interpret following string: %s" % size)
def rekey_on_member(data, key, duplicates='error'):
@ -141,7 +180,7 @@ def rekey_on_member(data, key, duplicates='error'):
value would be duplicated or to overwrite previous entries if that's the case.
"""
if duplicates not in ('error', 'overwrite'):
raise errors.AnsibleFilterError("duplicates parameter to rekey_on_member has unknown value: {0}".format(duplicates))
raise AnsibleFilterError("duplicates parameter to rekey_on_member has unknown value: {0}".format(duplicates))
new_obj = {}
@ -150,24 +189,24 @@ def rekey_on_member(data, key, duplicates='error'):
elif isinstance(data, collections.Iterable) and not isinstance(data, (text_type, binary_type)):
iterate_over = data
else:
raise errors.AnsibleFilterError("Type is not a valid list, set, or dict")
raise AnsibleFilterError("Type is not a valid list, set, or dict")
for item in iterate_over:
if not isinstance(item, collections.Mapping):
raise errors.AnsibleFilterError("List item is not a valid dict")
raise AnsibleFilterError("List item is not a valid dict")
try:
key_elem = item[key]
except KeyError:
raise errors.AnsibleFilterError("Key {0} was not found".format(key))
raise AnsibleFilterError("Key {0} was not found".format(key))
except Exception as e:
raise errors.AnsibleFilterError(to_native(e))
raise AnsibleFilterError(to_native(e))
# Note: if new_obj[key_elem] exists it will always be a non-empty dict (it will at
# minimun contain {key: key_elem}
if new_obj.get(key_elem, None):
if duplicates == 'error':
raise errors.AnsibleFilterError("Key {0} is not unique, cannot correctly turn into dict".format(key_elem))
raise AnsibleFilterError("Key {0} is not unique, cannot correctly turn into dict".format(key_elem))
elif duplicates == 'overwrite':
new_obj[key_elem] = item
else:

View file

@ -4,9 +4,10 @@
# Make coding more python3-ish
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
import pytest
from jinja2 import Environment
import ansible.plugins.filter.mathstuff as ms
from ansible.errors import AnsibleFilterError
@ -22,41 +23,43 @@ TWO_SETS_DATA = (([1, 2], [3, 4], ([], sorted([1, 2]), sorted([1, 2, 3, 4]), sor
(['a', 'b', 'c'], ['d', 'c', 'e'], (['c'], sorted(['a', 'b']), sorted(['a', 'b', 'd', 'e']), sorted(['a', 'b', 'c', 'e', 'd']))),
)
env = Environment()
@pytest.mark.parametrize('data, expected', UNIQUE_DATA)
class TestUnique:
def test_unhashable(self, data, expected):
assert sorted(ms.unique(list(data))) == expected
assert sorted(ms.unique(env, list(data))) == expected
def test_hashable(self, data, expected):
assert sorted(ms.unique(tuple(data))) == expected
assert sorted(ms.unique(env, tuple(data))) == expected
@pytest.mark.parametrize('dataset1, dataset2, expected', TWO_SETS_DATA)
class TestIntersect:
def test_unhashable(self, dataset1, dataset2, expected):
assert sorted(ms.intersect(list(dataset1), list(dataset2))) == expected[0]
assert sorted(ms.intersect(env, list(dataset1), list(dataset2))) == expected[0]
def test_hashable(self, dataset1, dataset2, expected):
assert sorted(ms.intersect(tuple(dataset1), tuple(dataset2))) == expected[0]
assert sorted(ms.intersect(env, tuple(dataset1), tuple(dataset2))) == expected[0]
@pytest.mark.parametrize('dataset1, dataset2, expected', TWO_SETS_DATA)
class TestDifference:
def test_unhashable(self, dataset1, dataset2, expected):
assert sorted(ms.difference(list(dataset1), list(dataset2))) == expected[1]
assert sorted(ms.difference(env, list(dataset1), list(dataset2))) == expected[1]
def test_hashable(self, dataset1, dataset2, expected):
assert sorted(ms.difference(tuple(dataset1), tuple(dataset2))) == expected[1]
assert sorted(ms.difference(env, tuple(dataset1), tuple(dataset2))) == expected[1]
@pytest.mark.parametrize('dataset1, dataset2, expected', TWO_SETS_DATA)
class TestSymmetricDifference:
def test_unhashable(self, dataset1, dataset2, expected):
assert sorted(ms.symmetric_difference(list(dataset1), list(dataset2))) == expected[2]
assert sorted(ms.symmetric_difference(env, list(dataset1), list(dataset2))) == expected[2]
def test_hashable(self, dataset1, dataset2, expected):
assert sorted(ms.symmetric_difference(tuple(dataset1), tuple(dataset2))) == expected[2]
assert sorted(ms.symmetric_difference(env, tuple(dataset1), tuple(dataset2))) == expected[2]
class TestMin: