Make the timeout decorator raise an exception out of the function's scope (#49921)
* Revert "allow caller to deal with timeout (#49449)" This reverts commit63279823a7
. Flawed on many levels * Adds poor API to a public function * Papers over the fact that the public function is doing something bad by catching exceptions it cannot handle in the first place * Papers over the real cause of the issue which is a bug in the timeout decorator * Doesn't reraise properly * Catches the wrong exception Fixes #49824 Fixes #49817 * Make the timeout decorator properly raise an exception outside of the function's scope signal handlers which raise exceptions will never work well because the exception can be raised anywhere in the called code. This leads to exception race conditions where the exceptions could end up being hanlded by unintended pieces of the called code. The timeout decorator was using just that idiom. It was especially bad because the decorator syntactically occurs outside of the called code but because of the signal handler, the exception was being raised inside of the called code. This change uses a thread instead of a signal to manage the timeout in parallel to the execution of the decorated function. Since raising of the exception happens inside of the decorator, now, instead of inside of a signal handler, the timeout exception is raised from outside of the called code as expected which makes reasoning about where exceptions are to be expected intuitive again. Fixes #43884 * Add a common case test. Adding an integration test driven from our unittests. Most of the time we'll timeout in run_command which is running things in a subprocess. Create a test for that specific case in case anything funky comes up between threading and execve. * Don't use OSError-based TimeoutError as a base class Unlike most standard exceptions, OSError has a specific parameter list with specific meanings. Instead follow the example of other stdlib functions, concurrent.futures and multiprocessing and define a separate TimeoutException. * Add comment and docstring to point out that this is not hte Python3 TimeoutError (cherry picked from commitbd072fe83a
)
This commit is contained in:
parent
19dfb2f396
commit
00a02574c2
2 changed files with 77 additions and 18 deletions
|
@ -16,7 +16,8 @@
|
|||
from __future__ import (absolute_import, division, print_function)
|
||||
__metaclass__ = type
|
||||
|
||||
import signal
|
||||
import multiprocessing
|
||||
import multiprocessing.pool as mp
|
||||
|
||||
# timeout function to make sure some fact gathering
|
||||
# steps do not exceed a time limit
|
||||
|
@ -30,24 +31,25 @@ class TimeoutError(Exception):
|
|||
|
||||
|
||||
def timeout(seconds=None, error_message="Timer expired"):
|
||||
|
||||
"""
|
||||
Timeout decorator to expire after a set number of seconds. This raises an
|
||||
ansible.module_utils.facts.TimeoutError if the timeout is hit before the
|
||||
function completes.
|
||||
"""
|
||||
def decorator(func):
|
||||
def _handle_timeout(signum, frame):
|
||||
msg = 'Timer expired after %s seconds' % globals().get('GATHER_TIMEOUT')
|
||||
raise TimeoutError(msg)
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
local_seconds = seconds
|
||||
if local_seconds is None:
|
||||
local_seconds = globals().get('GATHER_TIMEOUT') or DEFAULT_GATHER_TIMEOUT
|
||||
signal.signal(signal.SIGALRM, _handle_timeout)
|
||||
signal.alarm(local_seconds)
|
||||
timeout_value = seconds
|
||||
if timeout_value is None:
|
||||
timeout_value = globals().get('GATHER_TIMEOUT') or DEFAULT_GATHER_TIMEOUT
|
||||
|
||||
pool = mp.ThreadPool(processes=1)
|
||||
res = pool.apply_async(func, args, kwargs)
|
||||
pool.close()
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
finally:
|
||||
signal.alarm(0)
|
||||
return result
|
||||
return res.get(timeout_value)
|
||||
except multiprocessing.TimeoutError:
|
||||
# This is an ansible.module_utils.common.facts.timeout.TimeoutError
|
||||
raise TimeoutError('Timer expired after %s seconds' % timeout_value)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
|
|
@ -20,13 +20,11 @@
|
|||
from __future__ import (absolute_import, division)
|
||||
__metaclass__ = type
|
||||
|
||||
import sys
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from ansible.compat.tests import unittest
|
||||
from ansible.compat.tests.mock import patch, MagicMock
|
||||
|
||||
from ansible.module_utils.facts import timeout
|
||||
|
||||
|
||||
|
@ -67,6 +65,10 @@ def sleep_amount_explicit_lower(amount):
|
|||
return 'Succeeded after {0} sec'.format(amount)
|
||||
|
||||
|
||||
#
|
||||
# Tests for how the timeout decorator is specified
|
||||
#
|
||||
|
||||
def test_defaults_still_within_bounds():
|
||||
# If the default changes outside of these bounds, some of the tests will
|
||||
# no longer test the right thing. Need to review and update the timeouts
|
||||
|
@ -110,3 +112,58 @@ def test_explicit_timeout():
|
|||
sleep_time = 3
|
||||
with pytest.raises(timeout.TimeoutError):
|
||||
assert sleep_amount_explicit_lower(sleep_time) == '(Not expected to succeed)'
|
||||
|
||||
|
||||
#
|
||||
# Test that exception handling works
|
||||
#
|
||||
|
||||
@timeout.timeout(1)
|
||||
def function_times_out():
|
||||
time.sleep(2)
|
||||
|
||||
|
||||
# This is just about the same test as function_times_out but uses a separate process which is where
|
||||
# we normally have our timeouts. It's more of an integration test than a unit test.
|
||||
@timeout.timeout(1)
|
||||
def function_times_out_in_run_command(am):
|
||||
am.run_command([sys.executable, '-c', 'import time ; time.sleep(2)'])
|
||||
|
||||
|
||||
@timeout.timeout(1)
|
||||
def function_other_timeout():
|
||||
raise TimeoutError('Vanilla Timeout')
|
||||
|
||||
|
||||
@timeout.timeout(1)
|
||||
def function_raises():
|
||||
1 / 0
|
||||
|
||||
|
||||
@timeout.timeout(1)
|
||||
def function_catches_all_exceptions():
|
||||
try:
|
||||
time.sleep(10)
|
||||
except BaseException:
|
||||
raise RuntimeError('We should not have gotten here')
|
||||
|
||||
|
||||
def test_timeout_raises_timeout():
|
||||
with pytest.raises(timeout.TimeoutError):
|
||||
assert function_times_out() == '(Not expected to succeed)'
|
||||
|
||||
|
||||
@pytest.mark.parametrize('stdin', ({},), indirect=['stdin'])
|
||||
def test_timeout_raises_timeout_integration_test(am):
|
||||
with pytest.raises(timeout.TimeoutError):
|
||||
assert function_times_out_in_run_command(am) == '(Not expected to succeed)'
|
||||
|
||||
|
||||
def test_timeout_raises_other_exception():
|
||||
with pytest.raises(ZeroDivisionError):
|
||||
assert function_raises() == '(Not expected to succeed)'
|
||||
|
||||
|
||||
def test_exception_not_caught_by_called_code():
|
||||
with pytest.raises(timeout.TimeoutError):
|
||||
assert function_catches_all_exceptions() == '(Not expected to succeed)'
|
||||
|
|
Loading…
Reference in a new issue