Python 3: Catching warnings during multiprocessing

Question:

Too long; didn’t read

The warnings.catch_warnings() context manager is not thread safe. How do I use it in a parallel processing environment?

Background

The code below solves a maximization problem using parallel processing with Python’s multiprocessing module. It takes a list of (immutable) widgets, partitions them up (see Efficient multiprocessing of massive, brute force maximization in Python 3), finds the maxima (“finalists”) of all the partitions, and then finds the maximum (“champion”) of those “finalists.” If I understand my own code correctly (and I wouldn’t be here if I did), I’m sharing memory with all the child processes to give them the input widgets, and multiprocessing uses an operating-system-level pipe and pickling to send the finalist widgets back to the main process when the workers are done.

Source of the problem

I want to catch the redundant widget warnings being caused by widgets’ re-instantiation after the unpickling that happens when the widgets come out of the inter-process pipe. When widget objects instantiate, they validate their own data, emitting warnings from the Python standard warnings module to tell the app’s user that the widget suspects there is a problem with the user’s input data. Because unpickling causes objects to instantiate, my understanding of the code implies that each widget object is reinstantiated exactly once, after it comes out of the pipe, if and only if it is a finalist. See the next section for why this isn’t correct.

The widgets were already created before being frobnicated, so the user is already painfully aware of what input he got wrong and doesn’t want to hear about it again. These are the warnings I’d like to catch with the warnings module’s catch_warnings() context manager (i.e., a with statement).

Failed solutions

In my tests I’ve narrowed down when the superfluous warnings are being emitted to anywhere between what I’ve labeled below as Line A and Line B. What surprises me is that the warnings are being emitted in places other than just near output_queue.get(). This implies to me that multiprocessing sends the widgets to the workers using pickling.

The upshot is that putting a context manager created by warnings.catch_warnings() even around everything from Line A to Line B and setting the right warnings filter inside this context does not catch the warnings. This implies to me that the warnings are being emitted in the worker processes. Putting this context manager around the worker code does not catch the warnings either.
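To make the reasoning above concrete: catch_warnings() only saves the current process's module-level warnings.filters list on entry and restores it on exit. The minimal in-process sketch below (no multiprocessing involved; filters_snapshot is an illustrative helper) shows that a filter installed inside the block is just an entry at the front of that list. That is why it is shared by every thread in the process, and why nothing in catch_warnings() itself communicates with other processes.

```python
import warnings


def filters_snapshot():
    """Copy of this process's module-level warning filter list."""
    return list(warnings.filters)


before = filters_snapshot()

with warnings.catch_warnings():
    # simplefilter() inserts an entry at the front of warnings.filters;
    # every thread in this process sees it until the block exits.
    warnings.simplefilter("ignore", UserWarning)
    inside = filters_snapshot()

# On exit, catch_warnings() restores the list it saved on entry.
after = filters_snapshot()

print(inside[0])  # the filter that was added inside the block
print(after == before)
```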

The code

This example omits the code for deciding if the problem size is too small to bother with forking processes, importing multiprocessing, and defining my_frobnal_counter and my_load_balancer.

"Call `frobnicate(list_of_widgets)` to get the widget with the most frobnals"

def frobnicate_parallel_worker(widgets, output_queue):
    resultant_widget = max(widgets, key=my_frobnal_counter)
    output_queue.put(resultant_widget)

def frobnicate_parallel(widgets):
    output_queue = multiprocessing.Queue()
    # partitions: Generator yielding tuples of sets
    partitions = my_load_balancer(widgets)
    processes = []
    # Line A: Possible start of where the warnings are coming from.
    for partition in partitions:
        p = multiprocessing.Process(
                 target=frobnicate_parallel_worker,
                 args=(partition, output_queue))
        processes.append(p)
        p.start()
    finalists = []
    for p in processes:
        finalists.append(output_queue.get())
    # Avoid deadlocks in Unix by draining queue before joining processes
    for p in processes:
        p.join()
    # Line B: Warnings no longer possible after here.
    return max(finalists, key=my_frobnal_counter)
Asked By: wkschwartz


Answers:

You can try overriding the Process.run method so that it uses warnings.catch_warnings.

>>> from multiprocessing import Process
>>> 
>>> def yell(text):
...    import warnings
...    print 'about to yell %s' % text
...    warnings.warn(text)
... 
>>> class CustomProcess(Process):
...    def run(self, *args, **kwargs):
...       import warnings
...       with warnings.catch_warnings():
...          warnings.simplefilter("ignore")
...          return Process.run(self, *args, **kwargs)
... 
>>> if __name__ == '__main__':
...    quiet = CustomProcess(target=yell, args=('...not!',))
...    quiet.start()
...    quiet.join()
...    noisy = Process(target=yell, args=('AAAAAAaaa!',))
...    noisy.start()
...    noisy.join()
... 
about to yell ...not!
about to yell AAAAAAaaa!
__main__:4: UserWarning: AAAAAAaaa!
>>> 
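The transcript above is Python 2 (note the print statement). Here is a minimal Python 3 sketch of the same idea. QuietProcess and yell are illustrative names. Process.run() takes no arguments and executes inside the child, so the filter is installed in the child. With the spawn start method, the class and the target must live in an importable module (not typed into the REPL), and start() must be called under if __name__ == '__main__':.

```python
import multiprocessing
import warnings


class QuietProcess(multiprocessing.Process):
    """Hypothetical Process subclass whose target runs with warnings ignored.

    run() executes inside the child process, so the filter is installed
    in the child rather than in the parent.
    """

    def run(self):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            super().run()


def yell(text):
    """Illustrative target that emits a UserWarning."""
    print("about to yell %s" % text)
    warnings.warn(text)
```

Usage is the same as for Process: QuietProcess(target=yell, args=('...not!',)) followed by start() and join(). One caveat, assuming the spawn start method: the child unpickles the Process object, including its args, before run() is entered. Anything emitted while the arguments themselves are being unpickled therefore falls outside this filter.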

Or you can use some of the internals (__warningregistry__):

>>> from multiprocessing import Process
>>> import exceptions
>>> def yell(text):
...    import warnings
...    print 'about to yell %s' % text
...    warnings.warn(text)
...    # not filtered
...    warnings.warn('complimentary second warning.')
... 
>>> WARNING_TEXT = 'AAAAaaaaa!'
>>> WARNING_TYPE = exceptions.UserWarning
>>> WARNING_LINE = 4
>>> 
>>> class SelectiveProcess(Process):
...    def run(self, *args, **kwargs):
...       registry = globals().setdefault('__warningregistry__', {})
...       registry[(WARNING_TEXT, WARNING_TYPE, WARNING_LINE)] = True
...       return Process.run(self, *args, **kwargs)
... 
>>> if __name__ == '__main__':
...    p = SelectiveProcess(target=yell, args=(WARNING_TEXT,))
...    p.start()
...    p.join()
... 
about to yell AAAAaaaaa!
__main__:6: UserWarning: complimentary second warning.
>>> 
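The __warningregistry__ trick above depends on interpreter internals. In Python 3 the exceptions module no longer exists (UserWarning is a builtin), and the registry also carries a version entry and is reset when the filters change, so it may not behave as shown. The sketch below swaps in the public warnings.filterwarnings() API for the same kind of selective suppression. WARNING_TEXT, yell and this SelectiveProcess are illustrative and not the code above.

```python
import multiprocessing
import re
import warnings

# Illustrative: the exact warning text to silence in the child.
WARNING_TEXT = "AAAAaaaaa!"


class SelectiveProcess(multiprocessing.Process):
    """Hypothetical Process subclass that silences one specific warning."""

    def run(self):
        with warnings.catch_warnings():
            # `message` is a regex matched case-insensitively against the
            # start of the warning text; escape it and anchor the end so
            # only this exact text is ignored.
            warnings.filterwarnings(
                "ignore",
                message=re.escape(WARNING_TEXT) + r"\Z",
                category=UserWarning,
            )
            super().run()


def yell(text):
    """Illustrative target that emits two different UserWarnings."""
    print("about to yell %s" % text)
    warnings.warn(text)
    # Different text, so this one is not filtered.
    warnings.warn("complimentary second warning.")
```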
Answered By: dnozay

The unpickling would not cause the __init__ to be executed twice. I ran the following code on Windows, and it doesn’t happen (each __init__ is run precisely once).

Therefore, you need to provide us with the code from my_load_balancer and from the widgets’ class. At this point, your question simply doesn’t provide enough information.

As a random guess, you might check whether my_load_balancer makes copies of widgets, causing them to be instantiated once again.

import multiprocessing

"Call `frobnicate(list_of_widgets)` to get the widget with the most frobnals"

def my_load_balancer(widgets):
    partitions = tuple(set() for _ in range(8))
    for i, widget in enumerate(widgets):
        partitions[i % 8].add(widget)
    for partition in partitions:
        yield partition

def my_frobnal_counter(widget):
    return widget.id

def frobnicate_parallel_worker(widgets, output_queue):
    resultant_widget = max(widgets, key=my_frobnal_counter)
    output_queue.put(resultant_widget)

def frobnicate_parallel(widgets):
    output_queue = multiprocessing.Queue()
    # partitions: generator yielding sets of widgets
    partitions = my_load_balancer(widgets)
    processes = []
    # Line A: Possible start of where the warnings are coming from.
    for partition in partitions:
        p = multiprocessing.Process(
                 target=frobnicate_parallel_worker,
                 args=(partition, output_queue))
        processes.append(p)
        p.start()
    finalists = []
    for p in processes:
        finalists.append(output_queue.get())
    # Avoid deadlocks in Unix by draining queue before joining processes
    for p in processes:
        p.join()
    # Line B: Warnings no longer possible after here.
    return max(finalists, key=my_frobnal_counter)

class Widget:
    id = 0
    def __init__(self):
        print('initializing Widget {}'.format(self.id))
        self.id = Widget.id
        Widget.id += 1

    def __str__(self):
        return str(self.id)

    def __repr__(self):
        return str(self)

def main():

    widgets = [Widget() for _ in range(16)]
    result = frobnicate_parallel(widgets)
    print(result.id)


if __name__ == '__main__':
    main()
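multiprocessing.Queue moves objects between processes with pickle, so the point above can also be checked without multiprocessing at all. The minimal sketch below uses a hypothetical CheckedWidget that warns from __init__. By default, pickle recreates the instance without calling __init__. A class that defines its own __reduce__ or __setstate__ could behave differently.

```python
import pickle
import warnings


class CheckedWidget:
    """Hypothetical widget that validates its input in __init__."""

    init_calls = 0  # how many times __init__ has run, across all instances

    def __init__(self, value):
        CheckedWidget.init_calls += 1
        if value < 0:
            warnings.warn("suspicious widget value %r" % (value,))
        self.value = value


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    original = CheckedWidget(-1)  # runs __init__, emits one warning
    # Round-trip through pickle, as multiprocessing.Queue does.
    restored = pickle.loads(pickle.dumps(original))

print("__init__ calls:", CheckedWidget.init_calls)
print("warnings emitted:", len(caught))
print("restored value:", restored.value)
```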
Answered By: max

Years later, I finally have a solution (found while working on an unrelated problem). I’ve tested this on Python 3.7, 3.8, and 3.9.

Temporarily patch sys.warnoptions with the empty list []. You only need to do this around the call to process.start(). sys.warnoptions is documented as an implementation detail that you shouldn’t modify manually; the official recommendations are to use the functions in the warnings module and to set PYTHONWARNINGS in os.environ. Neither of those works here. The only thing that seems to work is patching sys.warnoptions. In a test, you can do the following:

import multiprocessing
from unittest.mock import patch
# my_function: the function you want the child process to run
p = multiprocessing.Process(target=my_function)
with patch('sys.warnoptions', []):
    p.start()
p.join()

If you don’t want to use unittest.mock, just patch by hand:

import multiprocessing
import sys
# my_function: the function you want the child process to run
p = multiprocessing.Process(target=my_function)
old_warnoptions = sys.warnoptions
try:
    sys.warnoptions = []
    p.start()
finally:
    sys.warnoptions = old_warnoptions
p.join()
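The try/finally pattern above can be wrapped in a small reusable context manager. cleared_warnoptions is a hypothetical name. One plausible reason this works, assuming the spawn start method (the default on Windows and, since Python 3.8, on macOS): the child is a fresh interpreter, and multiprocessing builds its command line partly from the parent's sys.warnoptions (passed as -W options). Filters added at run time through the warnings functions are not carried over.

```python
import contextlib
import sys


@contextlib.contextmanager
def cleared_warnoptions():
    """Hypothetical helper: temporarily replace sys.warnoptions with []."""
    saved = sys.warnoptions
    sys.warnoptions = []
    try:
        yield
    finally:
        # Restore the original list object even if start() raises.
        sys.warnoptions = saved
```

Usage: wrap only the start call, as in "with cleared_warnoptions(): p.start()", and call p.join() outside the block.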
Answered By: wkschwartz