Multiprocessing Pool hangs if child process killed

Question:

I launched a pool of worker processes and submitted a bunch of tasks. The system ran low on memory and the oomkiller killed one of the worker processes. The parent process just hung there waiting for the tasks to finish and never returned.

Here’s a runnable example that reproduces the problem. Instead of waiting for oomkiller to kill one of the worker processes, I find the process ids of all the worker processes and tell the first task to kill that process. (The call to ps won’t work on all operating systems.)

import os
import signal
from multiprocessing import Pool
from random import choice
from subprocess import run, PIPE
from time import sleep


def run_task(task):
    """Simulate one work item; item 0 kills another process instead of finishing.

    Args:
        task: A ``(target_process_id, n)`` pair: the pid that item 0 will
            kill, and the number of this item.

    Returns:
        The tuple ``(n, delay)``, where ``delay`` is the ``n + 1`` seconds
        this item slept.
    """
    victim_pid, item = task
    worker_pid = os.getpid()
    print(f'Processing item {item} in process {worker_pid}.')

    seconds = item + 1
    sleep(seconds)

    if item != 0:
        print(f'Item {item} finished.')
        return item, seconds

    # Item 0 stands in for the OOM killer: SIGKILL cannot be caught, so the
    # targeted process dies without reporting anything back.
    print(f'Item {item} killing process {victim_pid}.')
    os.kill(victim_pid, signal.SIGKILL)
    return item, seconds


def main():
    """Reproduce the hang: run ten items on a Pool while item 0 kills a worker.

    The pids of this process's children are read from ``ps``, one of them is
    picked at random, and every task carries that pid so that item 0 can kill
    it. The results loop then waits for an item that never returns, so
    'Closing.' and 'Done.' are not expected to be printed.
    """
    print('Starting.')
    pool = Pool()

    # Ask ps for the pids of this process's direct children. (These ps
    # options are not available on all operating systems.)
    ps_command = ['ps', '-opid', '--no-headers', '--ppid', str(os.getpid())]
    ps_output = run(ps_command, stdout=PIPE, encoding='utf8')
    child_process_ids = list(map(int, ps_output.stdout.splitlines()))

    # NOTE(review): the first and last pids are never chosen; presumably the
    # first is the worker that picks up item 0 and the last is ps itself.
    # Confirm before relying on it.
    candidates = child_process_ids[1:-1]
    target_process_id = choice(candidates)

    tasks = ((target_process_id, index) for index in range(10))
    results = pool.imap_unordered(run_task, tasks)
    for n, delay in results:
        print(f'Received {delay} from item {n}.')

    print('Closing.')
    pool.close()
    pool.join()
    print('Done.')


# Run the demo only when this file is executed as a script.
if __name__ == '__main__':
    main()

When I run that on a system with eight CPUs, I see this output:

Starting.
Processing item 0 in process 303.
Processing item 1 in process 304.
Processing item 2 in process 305.
Processing item 3 in process 306.
Processing item 4 in process 307.
Processing item 5 in process 308.
Processing item 6 in process 309.
Processing item 7 in process 310.
Item 0 killing process 308.
Processing item 8 in process 303.
Received 1 from item 0.
Processing item 9 in process 315.
Item 1 finished.
Received 2 from item 1.
Item 2 finished.
Received 3 from item 2.
Item 3 finished.
Received 4 from item 3.
Item 4 finished.
Received 5 from item 4.
Item 6 finished.
Received 7 from item 6.
Item 7 finished.
Received 8 from item 7.
Item 8 finished.
Received 9 from item 8.
Item 9 finished.
Received 10 from item 9.

You can see that item 5 never returns, and the pool just waits forever.

How can I get the parent process to notice when a child process is killed?

Asked By: Don Kirkby

||

Answers:

This problem is described in Python bug 9205, but they decided to fix it in the concurrent.futures module instead of in the multiprocessing module. In order to take advantage of the fix, switch to the newer process pool.

import os
import signal
from concurrent.futures.process import ProcessPoolExecutor
from random import choice
from subprocess import run, PIPE
from time import sleep


def run_task(task):
    """Sleep for ``n + 1`` seconds, then either finish or kill the target pid.

    Args:
        task: A ``(target_process_id, n)`` pair. Only item ``n == 0`` uses
            the pid: it sends SIGKILL to that process.

    Returns:
        The tuple ``(n, delay)`` with ``delay == n + 1``.
    """
    target_pid, number = task
    print(f'Processing item {number} in process {os.getpid()}.')

    outcome = (number, number + 1)
    sleep(outcome[1])

    is_killer = number == 0
    if is_killer:
        # Simulates a worker being killed from outside, e.g. by the OOM killer.
        print(f'Item {number} killing process {target_pid}.')
        os.kill(target_pid, signal.SIGKILL)
    else:
        print(f'Item {number} finished.')
    return outcome


def main():
    """Run ten items on a ProcessPoolExecutor while item 0 kills a worker.

    The pool is warmed up first so that its worker processes exist, their
    pids are read from ``ps``, and one of them is picked at random as the
    process item 0 will kill. Collecting the results is then expected to fail
    with ``BrokenProcessPool`` instead of waiting forever.

    Raises:
        concurrent.futures.process.BrokenProcessPool: When the executor
            detects the killed worker while results are being collected.
    """
    print('Starting.')
    pool = ProcessPoolExecutor()

    # Force the pool to launch some child processes. A callable must be
    # picklable to reach a worker, so submit a builtin rather than a lambda,
    # and wait for it so a failed warm-up is not silently ignored.
    pool.submit(os.getpid).result()

    # Ask ps for the pids of this process's direct children. (These ps
    # options are not available on all operating systems.)
    ps_output = run(['ps', '-opid', '--no-headers', '--ppid', str(os.getpid())],
                    stdout=PIPE, encoding='utf8')
    child_process_ids = [int(line) for line in ps_output.stdout.splitlines()]
    # NOTE(review): the first and last pids are never chosen; presumably the
    # last one is ps itself. If the executor starts its workers lazily, one
    # warm-up task may leave this slice empty and make choice() raise
    # IndexError -- confirm on the Python version in use.
    target_process_id = choice(child_process_ids[1:-1])

    tasks = ((target_process_id, i) for i in range(10))
    for n, delay in pool.map(run_task, tasks):
        print(f'Received {delay} from item {n}.')

    print('Closing.')
    pool.shutdown()
    print('Done.')


# Run the demo only when this file is executed as a script.
if __name__ == '__main__':
    main()

Now when you run it, you get a clear error message.

Starting.
Processing item 0 in process 549.
Processing item 1 in process 550.
Processing item 2 in process 552.
Processing item 3 in process 551.
Processing item 4 in process 553.
Processing item 5 in process 554.
Processing item 6 in process 555.
Processing item 7 in process 556.
Item 0 killing process 556.
Processing item 8 in process 549.
Received 1 from item 0.
Traceback (most recent call last):
  File "/home/don/.config/JetBrains/PyCharm2020.1/scratches/scratch2.py", line 42, in <module>
    main()
  File "/home/don/.config/JetBrains/PyCharm2020.1/scratches/scratch2.py", line 33, in main
    for n, delay in pool.map(run_task, tasks):
  File "/usr/lib/python3.7/concurrent/futures/process.py", line 483, in _chain_from_iterable_of_lists
    for element in iterable:
  File "/usr/lib/python3.7/concurrent/futures/_base.py", line 598, in result_iterator
    yield fs.pop().result()
  File "/usr/lib/python3.7/concurrent/futures/_base.py", line 428, in result
    return self.__get_result()
  File "/usr/lib/python3.7/concurrent/futures/_base.py", line 384, in __get_result
    raise self._exception
concurrent.futures.process.BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending.
Answered By: Don Kirkby

I ran into the same issue, and concurrent.futures was not much better at handling the problem. I ended up using the Ray module; here is my example code, which retries the killed tasks with a decreasing number of workers. That way, the most memory-hungry tasks have a chance to complete, in the worst case on a single worker. Run it carefully, as the OOM killer might kill other processes too:

import ray
import logging
from multiprocessing import cpu_count
import numpy as np
import psutil

# the default max_retries is 3, but in this case there is no point to retry with the same amount of workers
@ray.remote(max_retries=0)
def f(x):
    """Allocate a large array to put the machine under memory pressure.

    Each call sizes its array as the machine's total memory divided by
    ``cpu_count() - 3``, so running more workers than that at once asks for
    more memory than the machine has.

    Args:
        x: Job identifier; it is only logged and returned.

    Returns:
        ``x`` unchanged, so the caller can tell which job completed.
    """
    logging.warning("worker started %s", x)
    # Clamp the divisor: cpu_count() - 3 is zero on a 3-CPU machine
    # (ZeroDivisionError) and negative below that (invalid array size).
    share_count = max(cpu_count() - 3, 1)
    # np.ones creates float64 elements by default, 8 bytes each, hence the
    # division by 8 to turn a byte budget into an element count.
    allocate = int(psutil.virtual_memory().total / share_count / 8)
    logging.warning("worker allocate %s element float array for %s", allocate, x)
    crash = np.ones([allocate])
    # make sure the interpreter won't optimize out the above allocation
    # (%s rather than %x, so the job id is logged in decimal like the others)
    logging.warning("worker print %s for %s", crash[0], x)
    logging.warning("worker finished %s", x)
    return x

def main():
    """Run every job, retrying crashed ones with one fewer worker each round.

    Starts with ``cpu_count() - 1`` workers and one more job than workers.
    After each round, the jobs that did not complete are resubmitted to a
    fresh Ray session with one worker fewer, until every job has completed
    or a round with a single worker has been tried.

    Raises:
        Exception: Any error other than ``WorkerCrashedError`` is logged with
            its traceback and re-raised.
    """
    processes = cpu_count() - 1
    alljobs = range(processes + 1)
    completedjobs = []

    try:
        while alljobs:
            logging.warning("Number of jobs: %s", len(alljobs))
            logging.warning("Number of workers: %s", processes)
            ray.init(num_cpus=processes)
            try:
                result_ids = [f.remote(i) for i in alljobs]
                while result_ids:
                    try:
                        done_id, result_ids = ray.wait(result_ids, num_returns=1)
                        x = ray.get(done_id[0])
                    except ray.exceptions.WorkerCrashedError:
                        # The crashed task was already removed from
                        # result_ids, so carry on with the remaining ones.
                        logging.warning("Continue after WorkerCrashedError")
                        continue
                    logging.warning("results from %s", x)
                    completedjobs.append(x)
            finally:
                # Close the session even when an unexpected error escapes.
                ray.shutdown()
            # rerun the killed jobs on fewer workers to relieve memory pressure
            alljobs = list(set(alljobs) - set(completedjobs))
            if processes > 1:
                processes -= 1
            else:
                break
    except Exception as ex:
        template = "An exception of type {0} occurred. Arguments:\n{1!r}"
        message = template.format(type(ex).__name__, ex.args)
        logging.exception(message)
        raise

# Run the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()
Answered By: schaman
Categories: questions Tags: ,
Answers are sorted by their score. The answer accepted by the question owner as the best is marked with
at the top-right corner.