Add support for using the Jinja2 built-ins such as range

This commit is contained in:
Daniel Hokka Zakrisson 2012-12-19 09:42:15 +01:00
parent ab9e9486ee
commit 0f1706220b

View file

@ -242,14 +242,18 @@ def template(basedir, text, vars, expand_lists=False):
class _jinja2_vars(object): class _jinja2_vars(object):
''' helper class to template all variable content before jinja2 sees it ''' ''' helper class to template all variable content before jinja2 sees it '''
def __init__(self, basedir, vars): def __init__(self, basedir, vars, globals):
self.basedir = basedir self.basedir = basedir
self.vars = vars self.vars = vars
self.globals = globals
def __contains__(self, k): def __contains__(self, k):
return k in self.vars return k in self.vars or k in self.globals
def __getitem__(self, varname): def __getitem__(self, varname):
if varname not in self.vars: if varname not in self.vars:
raise KeyError("undefined variable: %s" % varname) if varname in self.globals:
return self.globals[varname]
else:
raise KeyError("undefined variable: %s" % varname)
var = self.vars[varname] var = self.vars[varname]
# HostVars is special, return it as-is # HostVars is special, return it as-is
if isinstance(var, dict) and type(var) != dict: if isinstance(var, dict) and type(var) != dict:
@ -308,7 +312,7 @@ def template_from_file(basedir, path, vars):
# This line performs deep Jinja2 magic that uses the _jinja2_vars object for vars # This line performs deep Jinja2 magic that uses the _jinja2_vars object for vars
# Ideally, this could use some API where setting shared=True and the object won't get # Ideally, this could use some API where setting shared=True and the object won't get
# passed through dict(o), but I have not found that yet. # passed through dict(o), but I have not found that yet.
res = jinja2.utils.concat(t.root_render_func(t.new_context(_jinja2_vars(basedir, vars), shared=True))) res = jinja2.utils.concat(t.root_render_func(t.new_context(_jinja2_vars(basedir, vars, t.globals), shared=True)))
if data.endswith('\n') and not res.endswith('\n'): if data.endswith('\n') and not res.endswith('\n'):
res = res + '\n' res = res + '\n'