ansible/bin/ansible-connection

355 lines
13 KiB
Text
Raw Normal View History

2016-10-26 17:38:08 +00:00
#!/usr/bin/env python
# (c) 2016, Ansible, Inc. <support@ansible.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
########################################################
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
2016-10-26 17:38:08 +00:00
__requires__ = ['ansible']
2016-10-26 17:38:08 +00:00
try:
import pkg_resources
except Exception:
pass
import fcntl
import os
import shlex
import signal
import socket
import struct
import sys
import time
import traceback
import syslog
import datetime
2016-10-26 17:38:08 +00:00
from io import BytesIO
from ansible import constants as C
from ansible.module_utils._text import to_bytes, to_native
from ansible.module_utils.six.moves import cPickle, StringIO
from ansible.playbook.play_context import PlayContext
from ansible.plugins import connection_loader
from ansible.utils.path import unfrackpath, makedirs_safe
from ansible.errors import AnsibleConnectionFailure
from ansible.utils.display import Display
2016-10-26 17:38:08 +00:00
2016-10-26 17:38:08 +00:00
def do_fork():
    '''
    Does the required double fork for a daemon process. Based on
    http://code.activestate.com/recipes/66012-fork-a-daemon-process-on-unix/

    Returns the pid of the first child (> 0) in the calling process and 0 in
    the daemonized grandchild. The intermediate child exits with status 0.
    Any process in which an OSError is raised along the way exits with
    status 1.
    '''
    try:
        pid = os.fork()
        if pid > 0:
            # original process: hand the child's pid back to the caller
            return pid

        # first child: change directory, start a new session and reset the
        # umask before forking again
        os.chdir("/")
        os.setsid()
        os.umask(0)

        pid = os.fork()
        if pid > 0:
            # intermediate child: only the grandchild keeps running
            sys.exit(0)

        # grandchild: send stdout/stderr to the configured log file, or to
        # /dev/null when no log path is set
        if C.DEFAULT_LOG_PATH != '':
            log_path = C.DEFAULT_LOG_PATH
        else:
            log_path = '/dev/null'

        # open() replaces the Python 2 only file() builtin; binary mode is
        # used because unbuffered (buffering=0) text mode is rejected on
        # Python 3
        out_file = open(log_path, 'ab+')
        err_file = open(log_path, 'ab+', 0)

        os.dup2(out_file.fileno(), sys.stdout.fileno())
        os.dup2(err_file.fileno(), sys.stderr.fileno())
        os.close(sys.stdin.fileno())

        return pid
    except OSError:
        sys.exit(1)
def send_data(s, data):
    '''
    Writes one length-prefixed message to the socket ``s``.

    The message is framed as an 8 byte, network byte order, unsigned length
    followed by ``data`` itself. Returns whatever ``s.sendall()`` returns.
    '''
    header = struct.pack('!Q', len(data))
    framed = header + data
    return s.sendall(framed)
def recv_data(s):
    '''
    Reads one length-prefixed message from the socket ``s``.

    Expects an 8 byte, network byte order, unsigned length followed by that
    many bytes of payload. Returns the payload as bytes, or None if the peer
    closes the socket before the header or the payload is complete.
    '''
    def read_exact(count):
        # keep calling recv() until exactly `count` bytes have arrived;
        # an empty read means the other side closed the connection
        collected = b""
        remaining = count
        while remaining > 0:
            chunk = s.recv(remaining)
            if not chunk:
                return None
            collected += chunk
            remaining -= len(chunk)
        return collected

    # size of a packed unsigned long long
    header = read_exact(8)
    if header is None:
        return None

    (payload_len,) = struct.unpack('!Q', header)
    return read_exact(payload_len)
class Server():
    '''
    Daemon side of a persistent connection.

    Opens the connection plugin named by the play context, then listens on a
    unix domain socket and serves EXEC, PUT, FETCH and CONTEXT requests
    received over it until no client connects within
    C.PERSISTENT_CONNECT_TIMEOUT seconds.
    '''

    def __init__(self, path, play_context):
        '''
        Connects to the remote host and binds the control socket.

        :arg path: filesystem path of the unix socket to create
        :arg play_context: play context used to load the connection plugin
        :raises AnsibleConnectionFailure: if the plugin does not report
            itself as connected after _connect()
        '''
        display.vvvv("starting new persistent socket with path %s" % path, play_context.remote_addr)

        self.path = path
        self.play_context = play_context

        self._start_time = datetime.datetime.now()

        display.vvv("using connection plugin %s" % self.play_context.connection, play_context.remote_addr)

        self.conn = connection_loader.get(play_context.connection, play_context, sys.stdin)
        self.conn._connect()
        if not self.conn.connected:
            raise AnsibleConnectionFailure('unable to connect to remote host')

        connection_time = datetime.datetime.now() - self._start_time
        display.vvvv('connection established in %s' % connection_time, play_context.remote_addr)

        self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        self.socket.bind(path)
        self.socket.listen(1)

        signal.signal(signal.SIGALRM, self.alarm_handler)

    def dispatch(self, obj, name, *args, **kwargs):
        '''
        Calls the method ``name`` on ``obj`` if it exists (and is truthy) and
        returns its result; otherwise returns None.
        '''
        meth = getattr(obj, name, None)
        if meth:
            return meth(*args, **kwargs)

    def alarm_handler(self, signum, frame):
        '''
        Alarm handler
        '''
        # FIXME: this should also set internal flags for other
        #        areas of code to check, so they can terminate
        #        earlier than the socket going back to the accept
        #        call and failing there.
        #
        # hooks the connection plugin to handle any cleanup
        self.dispatch(self.conn, 'alarm_handler', signum, frame)
        self.socket.close()

    @staticmethod
    def _payload(data, marker):
        '''
        Returns everything in ``data`` after the leading ``marker``. The
        caller must already have checked that ``data`` starts with
        ``marker``.
        '''
        # slicing (rather than data.split(marker)[1]) keeps payloads intact
        # when they happen to contain the marker text themselves
        return data[len(marker):]

    def _handle_request(self, data):
        '''
        Executes a single request read from a client.

        Returns a (rc, stdout, stderr) tuple to send back to the client, or
        None for a successfully applied CONTEXT request, which gets no reply.
        Any exception raised while handling the request is reported as
        rc 255 with the traceback in stderr.
        '''
        host = self.play_context.remote_addr
        rc = 255
        try:
            if data.startswith(b'EXEC: '):
                display.vvvv("socket operation is EXEC", host)
                cmd = self._payload(data, b'EXEC: ')
                (rc, stdout, stderr) = self.conn.exec_command(cmd)
            elif data.startswith((b'PUT: ', b'FETCH: ')):
                (op, src, dst) = shlex.split(to_native(data))
                stdout = stderr = ''
                # a failed transfer falls through to the handler below so
                # the client gets the traceback instead of an empty stderr
                if op == 'FETCH:':
                    display.vvvv("socket operation is FETCH", host)
                    self.conn.fetch_file(src, dst)
                elif op == 'PUT:':
                    display.vvvv("socket operation is PUT", host)
                    self.conn.put_file(src, dst)
                rc = 0
            elif data.startswith(b'CONTEXT: '):
                display.vvvv("socket operation is CONTEXT", host)
                # NOTE(review): unpickling is only safe for trusted peers;
                # this assumes the control socket is reachable only by the
                # owning user - TODO confirm
                pc_data = cPickle.loads(self._payload(data, b'CONTEXT: '))

                pc = PlayContext()
                pc.deserialize(pc_data)

                self.dispatch(self.conn, 'update_play_context', pc)
                return None
            else:
                display.vvvv("socket operation is UNKNOWN", host)
                stdout = ''
                stderr = 'Invalid action specified'
        except Exception:
            stdout = ''
            stderr = traceback.format_exc()

        return (rc, stdout, stderr)

    def _serve_client(self, s):
        '''
        Reads requests from the accepted client socket ``s`` and answers
        each one until the client stops sending data.
        '''
        while True:
            data = recv_data(s)
            if not data:
                break
            display.vvvv("received data on socket with len %s" % len(data), self.play_context.remote_addr)

            signal.alarm(C.DEFAULT_TIMEOUT)

            result = self._handle_request(data)
            if result is None:
                # CONTEXT update: nothing is sent back, and the alarm set
                # above stays armed until the next request re-arms it
                continue

            signal.alarm(0)

            (rc, stdout, stderr) = result
            display.vvvv("socket operation completed with rc %s" % rc, self.play_context.remote_addr)

            send_data(s, to_bytes(str(rc)))
            send_data(s, to_bytes(stdout))
            send_data(s, to_bytes(stderr))

    def _shutdown(self):
        '''
        Closes the connection plugin and the listening socket, then removes
        the socket file. Each step is attempted even if an earlier one fails.
        '''
        for closer in (self.conn.close, self.socket.close):
            try:
                closer()
            except Exception:
                display.vvvv(traceback.format_exc())
        try:
            os.remove(self.path)
        except OSError:
            # e.g. the socket file is already gone
            display.vvvv(traceback.format_exc())

    def run(self):
        '''
        Accepts clients on the control socket and serves them one at a time.

        Returns once accept() fails (which is what happens after the alarm
        handler closes the socket) or an unexpected error occurs; the
        connection and the socket file are cleaned up in either case.
        '''
        try:
            while True:
                # set the alarm, if we don't get an accept before it
                # goes off we exit (via an exception caused by the socket
                # getting closed while waiting on accept())
                # FIXME: is this the best way to exit? as noted above in the
                #        handler we should probably be setting a flag to check
                #        here and in other parts of the code
                signal.alarm(C.PERSISTENT_CONNECT_TIMEOUT)
                try:
                    (s, addr) = self.socket.accept()
                    display.vvvv('incoming request accepted on persistent socket', self.play_context.remote_addr)
                    # clear the alarm
                    # FIXME: potential race condition here between the accept and
                    #        time to this call.
                    signal.alarm(0)
                except Exception:
                    break

                try:
                    self._serve_client(s)
                finally:
                    # release the client socket even when serving it failed
                    s.close()
        except Exception:
            display.vvvv(traceback.format_exc())
        finally:
            # when done, close the connection properly and cleanup
            # the socket file so it can be recreated
            end_time = datetime.datetime.now()
            delta = end_time - self._start_time
            display.v('shutting down control socket, connection was active for %s secs' % delta, self.play_context.remote_addr)
            self._shutdown()
def main():
    '''
    Client entry point for a persistent connection.

    Reads a pickled play context from stdin (terminated by a line containing
    #END_INIT#), starts the daemon for the matching control socket if its
    socket file does not exist yet, then forwards the first non-blank request
    line from stdin to the daemon and copies the daemon's stdout/stderr reply
    to this process. Always leaves via sys.exit(): with the daemon's return
    code, 0 if stdin ended without a request, 255 if the control socket could
    not be used, or a "FAIL: ..." message if the play context could not be
    read.
    '''
    try:
        # read the play context data via stdin, which means depickling it
        # FIXME: as noted above, we will probably need to deserialize the
        #        connection loader here as well at some point, otherwise this
        #        won't find role- or playbook-based connection plugins
        # NOTE(review): unpickling is only safe for trusted input; stdin is
        # presumably written by the parent ansible process - TODO confirm
        cur_line = sys.stdin.readline()
        init_data = ''
        while cur_line.strip() != '#END_INIT#':
            if cur_line == '':
                raise Exception("EOL found before init data was complete")
            init_data += cur_line
            cur_line = sys.stdin.readline()
        src = BytesIO(to_bytes(init_data))
        pc_data = cPickle.load(src)

        pc = PlayContext()
        pc.deserialize(pc_data)
    except Exception as e:
        # FIXME: better error message/handling/logging
        sys.stderr.write(traceback.format_exc())
        sys.exit("FAIL: %s" % e)

    display.verbosity = pc.verbosity

    ssh = connection_loader.get('ssh', class_only=True)
    m = ssh._create_control_path(pc.remote_addr, pc.port, pc.remote_user)

    # create the persistent connection dir if need be and create the paths
    # which we will be using later
    tmp_path = unfrackpath("$HOME/.ansible/pc")
    makedirs_safe(tmp_path)
    lk_path = unfrackpath("%s/.ansible_pc_lock" % tmp_path)
    sf_path = unfrackpath(m % dict(directory=tmp_path))

    # if the socket file doesn't exist, spin up the daemon process
    lock_fd = os.open(lk_path, os.O_RDWR | os.O_CREAT, 0o600)
    fcntl.lockf(lock_fd, fcntl.LOCK_EX)

    if not os.path.exists(sf_path):
        pid = do_fork()
        if pid == 0:
            # daemonized process: bring up the server, then release the lock
            # before serving requests
            rc = 0
            try:
                server = Server(sf_path, pc)
            except AnsibleConnectionFailure as exc:
                display.vvv(str(exc), pc.remote_addr)
                rc = 1
            except Exception as exc:
                display.vvv(traceback.format_exc(), pc.remote_addr)
                rc = 1
            fcntl.lockf(lock_fd, fcntl.LOCK_UN)
            os.close(lock_fd)
            if rc == 0:
                server.run()
            sys.exit(rc)

    fcntl.lockf(lock_fd, fcntl.LOCK_UN)
    os.close(lock_fd)

    # now connect to the daemon process
    # FIXME: if the socket file existed but the daemonized process was killed,
    #        the connection will timeout here. Need to make this more resilient.
    rc = 0
    while rc == 0:
        data = sys.stdin.readline()
        if data == '':
            break
        if data.strip() == '':
            continue

        sf = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        attempts = 1
        while True:
            try:
                sf.connect(sf_path)
                break
            except socket.error:
                # FIXME: better error handling/logging/message here
                time.sleep(C.PERSISTENT_CONNECT_INTERVAL)
                attempts += 1
                if attempts > C.PERSISTENT_CONNECT_RETRIES:
                    display.vvvv('number of connection attempts exceeded, unable to connect to control socket')
                    display.vvvv('persistent_connect_interval=%s, persistent_connect_retries=%s' % (C.PERSISTENT_CONNECT_INTERVAL, C.PERSISTENT_CONNECT_RETRIES))
                    sys.stderr.write('failed to connect to control socket')
                    sys.exit(255)

        # send the play_context back into the connection so the connection
        # can handle any privilege escalation activities
        # (bytes are concatenated directly: formatting them through '%s'
        # would embed their repr on Python 3 and corrupt the pickle)
        send_data(sf, b'CONTEXT: ' + src.getvalue())
        src.close()

        send_data(sf, to_bytes(data.strip()))

        rc_data = recv_data(sf)
        stdout = recv_data(sf)
        stderr = recv_data(sf)
        sf.close()

        # recv_data() returns None when the daemon closes the socket before
        # a full message arrives; report that instead of failing in int()
        if rc_data is None or stdout is None or stderr is None:
            sys.stderr.write('control socket closed before a complete response was received')
            sys.exit(255)

        rc = int(rc_data, 10)

        sys.stdout.write(to_native(stdout))
        sys.stderr.write(to_native(stderr))

        # only a single request is handled per invocation
        break

    sys.exit(rc)
if __name__ == '__main__':
    # bind the module-level `display` before running: main() and Server
    # look this name up as a global
    display = Display()
    main()