diff --git a/server/pulp/server/async/tasks.py b/server/pulp/server/async/tasks.py index 2f1a32d..6cd5c04 100644 --- a/server/pulp/server/async/tasks.py +++ b/server/pulp/server/async/tasks.py @@ -20,13 +20,13 @@ from pulp.server.async.celery_instance import celery, RESOURCE_MANAGER_QUEUE, \ DEDICATED_QUEUE_EXCHANGE from pulp.server.exceptions import PulpException, MissingResource, \ PulpCodedException +from pulp.server.db.connection import suppress_connect_warning from pulp.server.db.model import Worker, ReservedResource, TaskStatus from pulp.server.exceptions import NoWorkers from pulp.server.managers.repo import _common as common_utils from pulp.server.managers import factory as managers from pulp.server.managers.schedule import utils - controller = control.Control(app=celery) _logger = logging.getLogger(__name__) @@ -100,7 +100,8 @@ class PulpTask(CeleryTask): """ args = self._type_transform(args) kwargs = self._type_transform(kwargs) - return super(PulpTask, self).__call__(*args, **kwargs) + with suppress_connect_warning(_logger): + return super(PulpTask, self).__call__(*args, **kwargs) @task(base=PulpTask, acks_late=True) @@ -450,24 +451,27 @@ class Task(PulpTask, ReservedTaskMixin): This overrides PulpTask's __call__() method. We use this method for task state tracking of Pulp tasks. """ - # Check task status and skip running the task if task state is 'canceled'. - try: - task_status = TaskStatus.objects.get(task_id=self.request.id) - except DoesNotExist: - task_status = None - if task_status and task_status['state'] == constants.CALL_CANCELED_STATE: - _logger.debug("Task cancel received for task-id : [%s]" % self.request.id) - return - # Update start_time and set the task state to 'running' for asynchronous tasks. - # Skip updating status for eagerly executed tasks, since we don't want to track - # synchronous tasks in our database. 
- if not self.request.called_directly: - now = datetime.now(dateutils.utc_tz()) - start_time = dateutils.format_iso8601_datetime(now) - # Using 'upsert' to avoid a possible race condition described in the apply_async method - # above. - TaskStatus.objects(task_id=self.request.id).update_one( - set__state=constants.CALL_RUNNING_STATE, set__start_time=start_time, upsert=True) + with suppress_connect_warning(_logger): + # Check task status and skip running the task if task state is 'canceled'. + try: + task_status = TaskStatus.objects.get(task_id=self.request.id) + except DoesNotExist: + task_status = None + if task_status and task_status['state'] == constants.CALL_CANCELED_STATE: + _logger.debug("Task cancel received for task-id : [%s]" % self.request.id) + return + # Update start_time and set the task state to 'running' for asynchronous tasks. + # Skip updating status for eagerly executed tasks, since we don't want to track + # synchronous tasks in our database. + if not self.request.called_directly: + now = datetime.now(dateutils.utc_tz()) + start_time = dateutils.format_iso8601_datetime(now) + # Using 'upsert' to avoid a possible race condition described + # in the apply_async method above. + TaskStatus.objects(task_id=self.request.id).update_one( + set__state=constants.CALL_RUNNING_STATE, set__start_time=start_time, + upsert=True) + # Run the actual task _logger.debug("Running task : [%s]" % self.request.id) return super(Task, self).__call__(*args, **kwargs) @@ -491,34 +495,35 @@ class Task(PulpTask, ReservedTaskMixin): % {'id': kwargs['scheduled_call_id']}) utils.reset_failure_count(kwargs['scheduled_call_id']) if not self.request.called_directly: - now = datetime.now(dateutils.utc_tz()) - finish_time = dateutils.format_iso8601_datetime(now) - task_status = TaskStatus.objects.get(task_id=task_id) - task_status['finish_time'] = finish_time - task_status['result'] = retval - - # Only set the state to finished if it's not already in a complete state. 
This is - # important for when the task has been canceled, so we don't move the task from canceled - # to finished. - if task_status['state'] not in constants.CALL_COMPLETE_STATES: - task_status['state'] = constants.CALL_FINISHED_STATE - if isinstance(retval, TaskResult): - task_status['result'] = retval.return_value - if retval.error: - task_status['error'] = retval.error.to_dict() - if retval.spawned_tasks: - task_list = [] - for spawned_task in retval.spawned_tasks: - if isinstance(spawned_task, AsyncResult): - task_list.append(spawned_task.task_id) - elif isinstance(spawned_task, dict): - task_list.append(spawned_task['task_id']) - task_status['spawned_tasks'] = task_list - if isinstance(retval, AsyncResult): - task_status['spawned_tasks'] = [retval.task_id, ] - task_status['result'] = None - - task_status.save() + with suppress_connect_warning(_logger): + now = datetime.now(dateutils.utc_tz()) + finish_time = dateutils.format_iso8601_datetime(now) + task_status = TaskStatus.objects.get(task_id=task_id) + task_status['finish_time'] = finish_time + task_status['result'] = retval + + # Only set the state to finished if it's not already in a complete state. + # This is important for when the task has been canceled, so we don't move + # the task from canceled to finished. 
+ if task_status['state'] not in constants.CALL_COMPLETE_STATES: + task_status['state'] = constants.CALL_FINISHED_STATE + if isinstance(retval, TaskResult): + task_status['result'] = retval.return_value + if retval.error: + task_status['error'] = retval.error.to_dict() + if retval.spawned_tasks: + task_list = [] + for spawned_task in retval.spawned_tasks: + if isinstance(spawned_task, AsyncResult): + task_list.append(spawned_task.task_id) + elif isinstance(spawned_task, dict): + task_list.append(spawned_task['task_id']) + task_status['spawned_tasks'] = task_list + if isinstance(retval, AsyncResult): + task_status['spawned_tasks'] = [retval.task_id, ] + task_status['result'] = None + + task_status.save() common_utils.delete_working_directory() def on_failure(self, exc, task_id, args, kwargs, einfo): diff --git a/server/pulp/server/db/connection.py b/server/pulp/server/db/connection.py index b419560..0223a01 100644 --- a/server/pulp/server/db/connection.py +++ b/server/pulp/server/db/connection.py @@ -5,21 +5,21 @@ import itertools import logging import ssl import time +import warnings +from contextlib import contextmanager from gettext import gettext as _ import mongoengine +import semantic_version from pymongo.collection import Collection from pymongo.errors import AutoReconnect, OperationFailure from pymongo.son_manipulator import NamespaceInjector from pulp.common import error_codes - from pulp.server import config from pulp.server.compat import wraps from pulp.server.exceptions import PulpCodedException, PulpException -import semantic_version - _CONNECTION = None _DATABASE = None @@ -168,6 +168,87 @@ def initialize(name=None, seeds=None, max_pool_size=None, replica_set=None, max_ raise +@contextmanager +def suppress_connect_warning(logger): + """ + A context manager that will suppress pymongo's connect before fork warning + + python's warnings module gives you a way to filter warnings (warnings.filterwarnings), + and a way to catch warnings 
(warnings.catch_warnings), but not a way to do both. This
+    context manager filters out the specific python warning about connecting before fork,
+    while allowing all other warnings to normally be issued, so they aren't covered up
+    by this context manager.
+
+    The behavior seen here is based on the warnings.catch_warnings context manager, which
+    also works by stashing the original showwarning function and replacing it with a custom
+    function while the context is entered.
+
+    Outright replacement of functions in the warnings module is recommended by that module.
+
+    The logger from the calling module is used to help identify which call to
+    this context manager suppressed the pymongo warning.
+
+    :param logger: logger from the module using this context manager
+    :type logger: logging.Logger
+    """
+    try:
+        warning_func_name = warnings.showwarning.func_name
+    except AttributeError:
+        warning_func_name = None
+
+    # if the current showwarning func isn't already pymongo_suppressing_showwarning, replace it.
+    # checking this makes this context manager reentrant with itself, since it won't replace
+    # showwarning functions already replaced by this CM, but will replace all others
+    if warning_func_name != 'pymongo_suppressing_showwarning':
+        original_showwarning = warnings.showwarning
+
+        # this is effectively a functools.partial used to generate a version of the warning catcher
+        # using the passed-in logger and original showwarning function, but the logging module
+        # rudely checks the type before calling this, and does not accept partials
+        def pymongo_suppressing_showwarning(*args, **kwargs):
+            return _pymongo_warning_catcher(logger, original_showwarning, *args, **kwargs)
+
+        try:
+            # replace warnings.showwarning with our pymongo warning catcher,
+            # using the passed-in logger and the current showwarning function
+            warnings.showwarning = pymongo_suppressing_showwarning
+            yield
+        finally:
+            # whatever happens, restore the original showwarning function
+            warnings.showwarning = original_showwarning
+    else:
+        # showwarning already replaced outside this context manager, nothing to do
+        yield
+
+
+def _pymongo_warning_catcher(logger, showwarning, message, category, *args, **kwargs):
+    """
+    An implementation of warnings.showwarning that suppresses pymongo's connect before fork warning
+
+    This is intended to be wrapped with functools.partial by the mechanism that replaces
+    the warnings.showwarning function, with the first arg being the logger used to report
+    the suppressed warning, and the second being the original warnings.showwarning
+    function, through which all other warnings will be passed.
+
+    :param logger: logger used to report the suppressed connect-before-fork warning
+    :type logger: logging.Logger
+    :param showwarning: The "real" warnings.showwarning function, for passing unrelated warnings
+    :type showwarning: types.FunctionType
+
+    All remaining args are the same as warnings.showwarning, and are only used here for filtering
+    """
+    message_expected = 'MongoClient opened before fork'
+    # message is an instance of category, which becomes the warning message when cast as str
+    if category is UserWarning and message_expected in str(message):
+        # warning is pymongo connect before fork warning, log it...
+        logger.debug('pymongo reported connection before fork, ignoring')
+        # ...and filter it out for the rest of this process's lifetime
+        warnings.filterwarnings('ignore', message_expected)
+    else:
+        # not interested in this warning, run it through the provided showwarning function
+        showwarning(message, category, *args, **kwargs)
+
+
 def _connect_to_one_of_seeds(connection_kwargs, seeds_list, db_name):
     """
     Helper function to iterate over a list of database seeds till a successful connection is made
diff --git a/server/test/unit/server/db/test_connection.py b/server/test/unit/server/db/test_connection.py
index dc834a8..d6c2d60 100644
--- a/server/test/unit/server/db/test_connection.py
+++ b/server/test/unit/server/db/test_connection.py
@@ -1,4 +1,5 @@
 import unittest
+import warnings
 
 from mock import call, patch, MagicMock, Mock
 from pymongo.errors import AutoReconnect
@@ -908,3 +909,80 @@ class TestUnsafeRetry(unittest.TestCase):
         final_answer = mock_func()
         m_logger.error.assert_called_once_with('mock_func operation failed on mock_coll')
         self.assertTrue(final_answer is 'final')
+
+
+@patch('warnings.showwarning', autospec=True)
+class TestSuppressBeforeForkWarning(unittest.TestCase):
+    def test_warning_suppressed(self, showwarning):
+        logger = Mock()
+
+        with connection.suppress_connect_warning(logger):
+            self.assertTrue(warnings.showwarning is not
warnings._show_warning) + # The string to match is in this warning... + warnings.warn('MongoClient opened before fork mock warning') + # ...but not this warning + warnings.warn('Mock warning unrelated to MongoClient') + + # two warnings were raised in-context: one should emit a debug log message, + # the other should have called showwarning as-normal + self.assertEqual(logger.debug.call_count, 1) + self.assertEqual(showwarning.call_count, 1) + + def test_warning_restored(self, showwarning): + logger = Mock() + + with connection.suppress_connect_warning(logger): + # inside the context, warnings.showwarning has been replaced + self.assertTrue(warnings.showwarning is not showwarning) + + # upon leaving the context, warnings.showwarning is restored + self.assertTrue(warnings.showwarning is showwarning) + + def test_warning_restored_after_exception(self, showwarning): + logger = Mock() + showwarning.side_effect = Exception('Oh no!') + + with connection.suppress_connect_warning(logger): + self.assertTrue(warnings.showwarning is not showwarning) + self.assertRaises(Exception, warnings.warn, 'This will explode.') + + # despite the exception warnings.showwarning is restored, + # even if an exception was raised + self.assertTrue(warnings.showwarning is showwarning) + + def test_reentrant(self, showwarning): + logger = Mock() + + with connection.suppress_connect_warning(logger): + self.assertTrue(warnings.showwarning is not showwarning) + suppressing_showwarning = warnings.showwarning + with connection.suppress_connect_warning(logger): + self.assertTrue(warnings.showwarning is not showwarning) + + # showwarning should not be replaced in this inner context, so the version + # seen in the outer context should still be the current version seen + self.assertTrue(warnings.showwarning is suppressing_showwarning) + + # nesting suppress_connect_warning contexts does not implode the universe, + # but does still restore showwarning + self.assertTrue(warnings.showwarning is showwarning) 
+
+    def test_warnings_ignored(self, showwarning):
+        logger = Mock()
+
+        with connection.suppress_connect_warning(logger):
+            warnings.warn('MongoClient opened before fork mock warning')
+            self.assertEqual(showwarning.call_count, 0)
+
+        # after catching and logging the connect before fork warning, future warnings should be
+        # ignored. verify this first by snooping around in warnings.filters and checking that
+        # the first filter (and therefore newest, based on warnings.filterwarnings behavior)
+        # is the one added by the suppress_connect_warning context manager
+        action, regex = warnings.filters[0][:2]
+        self.assertEqual(action, 'ignore')
+        self.assertEqual(regex.pattern, 'MongoClient opened before fork')
+
+        # also verify this by issuing a matching warning outside of the suppressing context,
+        # and seeing that showwarning is not called
+        warnings.warn('MongoClient opened before fork mock warning')
+        self.assertEqual(showwarning.call_count, 0)