shuts down persistent connections at end of play run (#32825)

This change will now track any created persistent connection and shut it
down at the end of the play run.  This change also includes an update to
properly honor the reset_connection meta handler.
This commit is contained in:
Peter Sprygada 2017-11-22 10:30:06 -05:00 committed by John R Barker
parent 9d56ffa4ed
commit 69575e25d0
7 changed files with 62 additions and 32 deletions

View file

@ -51,6 +51,8 @@ class ConnectionProcess(object):
self.srv = JsonRpcServer()
self.sock = None
self.connection = None
def start(self):
try:
messages = list()
@ -67,6 +69,7 @@ class ConnectionProcess(object):
self.connection = connection_loader.get(self.play_context.connection, self.play_context, '/dev/null')
self.connection.set_options()
self.connection._connect()
self.connection._socket_path = self.socket_path
self.srv.register(self.connection)
messages.append('connection to remote device started successfully')
@ -84,7 +87,7 @@ class ConnectionProcess(object):
def run(self):
try:
while True:
while self.connection.connected:
signal.signal(signal.SIGALRM, self.connect_timeout)
signal.signal(signal.SIGTERM, self.handler)
signal.alarm(C.PERSISTENT_CONNECT_TIMEOUT)
@ -135,24 +138,19 @@ class ConnectionProcess(object):
def shutdown(self):
""" Shuts down the local domain socket
"""
if not os.path.exists(self.socket_path):
return
try:
if self.sock:
self.sock.close()
if self.connection:
self.connection.close()
except Exception:
pass
finally:
if os.path.exists(self.socket_path):
os.remove(self.socket_path)
setattr(self.connection, '_socket_path', None)
setattr(self.connection, '_connected', False)
if os.path.exists(self.socket_path):
try:
if self.sock:
self.sock.close()
if self.connection:
self.connection.close()
except:
pass
finally:
if os.path.exists(self.socket_path):
os.remove(self.socket_path)
setattr(self.connection, '_socket_path', None)
setattr(self.connection, '_connected', False)
display.display('shutdown complete', log_only=True)
def do_EXEC(self, data):

View file

@ -209,6 +209,15 @@ class Connection(ConnectionBase):
return 0, to_bytes(self._manager.session_id, errors='surrogate_or_strict'), b''
def reset(self):
'''
Reset the connection
'''
if self._socket_path:
display.vvvv('resetting persistent connection for socket_path %s' % self._socket_path, host=self._play_context.remote_addr)
self.close()
display.vvvv('reset call on connection instance', host=self._play_context.remote_addr)
def close(self):
if self._manager:
self._manager.close_session()

View file

@ -230,7 +230,8 @@ class Connection(ConnectionBase):
'''
if self._socket_path:
display.vvvv('resetting persistent connection for socket_path %s' % self._socket_path, host=self._play_context.remote_addr)
self.shutdown()
self.close()
display.vvvv('reset call on connection instance', host=self._play_context.remote_addr)
def close(self):
'''

View file

@ -34,8 +34,9 @@ from ansible.executor.process.worker import WorkerProcess
from ansible.executor.task_result import TaskResult
from ansible.inventory.host import Host
from ansible.module_utils.six.moves import queue as Queue
from ansible.module_utils.six import iteritems, string_types
from ansible.module_utils.six import iteritems, itervalues, string_types
from ansible.module_utils._text import to_text
from ansible.module_utils.connection import Connection
from ansible.playbook.helpers import load_list_of_blocks
from ansible.playbook.included_file import IncludedFile
from ansible.playbook.task_include import TaskInclude
@ -132,7 +133,15 @@ class StrategyBase:
self._results_thread.daemon = True
self._results_thread.start()
# holds the list of active (persistent) connections to be shutdown at
# play completion
self._active_connections = dict()
def cleanup(self):
# close active persistent connections
for sock in itervalues(self._active_connections):
conn = Connection(sock)
conn.reset()
self._final_q.put(_sentinel)
self._results_thread.join()
@ -892,8 +901,13 @@ class StrategyBase:
iterator._host_states[host.name].run_state = iterator.ITERATING_COMPLETE
msg = "ending play"
elif meta_action == 'reset_connection':
connection = connection_loader.get(play_context.connection, play_context, os.devnull)
play_context.set_options_from_plugin(connection)
if target_host in self._active_connections:
connection = Connection(self._active_connections[target_host])
del self._active_connections[target_host]
else:
connection = connection_loader.get(play_context.connection, play_context, os.devnull)
play_context.set_options_from_plugin(connection)
if connection:
connection.reset()
msg = 'reset connection'
@ -920,3 +934,12 @@ class StrategyBase:
if host.name not in self._tqm._unreachable_hosts:
hosts_left.append(host)
return hosts_left
def update_active_connections(self, results):
''' updates the current active persistent connections '''
for r in results:
if 'args' in r._task_fields:
socket_path = r._task_fields['args'].get('_ansible_socket')
if socket_path:
if r._host not in self._active_connections:
self._active_connections[r._host] = socket_path

View file

@ -175,6 +175,8 @@ class StrategyModule(StrategyBase):
results = self._process_pending_results(iterator)
host_results.extend(results)
self.update_active_connections(results)
try:
included_files = IncludedFile.process_include_results(
host_results,

View file

@ -292,6 +292,8 @@ class StrategyModule(StrategyBase):
host_results.extend(results)
self.update_active_connections(results)
try:
included_files = IncludedFile.process_include_results(
host_results,

View file

@ -44,15 +44,10 @@ class JsonRpcServer(object):
kwargs = params
rpc_method = None
if method in ('shutdown', 'reset'):
rpc_method = getattr(self, 'shutdown')
else:
for obj in self._objects:
rpc_method = getattr(obj, method, None)
if rpc_method:
break
for obj in self._objects:
rpc_method = getattr(obj, method, None)
if rpc_method:
break
if not rpc_method:
error = self.method_not_found()