Get thread index from a multiprocessing.pool.ThreadPool in Python

Question:

I’m using a library whose multiprocessing is implemented using multiprocessing.pool.ThreadPool(processes).

How can I get or compute the index of the current thread within the pool (from 0 to processes - 1)?

I’ve been reading the documentation and searching the web without finding a convincing solution. I can get the thread ID (using threading.get_ident()), and I could go through all the threads to build a mapping between their index and their ID, but I would need some kind of time.sleep() to make sure I visit them all. Can you think of a better solution?

Asked By: Jav


Answers:

The idea is to create a worker function, called test_worker in the example below, that returns its thread identity and the argument it is called with, which takes on values of 0 … pool size – 1. We then submit tasks with:

pool.map(test_worker, range(POOLSIZE), 1)

Specifying a chunksize of 1 is meant to give each thread exactly one task: the first thread gets argument 0, the second gets argument 1, and so on. For this to work, test_worker must give up control of the processor to the other threads in the pool. Tasks are placed on a single queue in lists of chunksize tasks, and each pool thread takes the next available list and processes the tasks in it. If test_worker consisted only of a return statement, the task would be so trivial that the first thread might grab every list before the other threads got a chance to run. To avoid this, we insert a call to time.sleep in the worker.

from multiprocessing.pool import ThreadPool
import threading
import time


def test_worker(i):
    # To ensure that the worker gives up control of the processor we sleep.
    # Otherwise, the same thread may be given all the tasks to process.
    time.sleep(.1)
    return threading.get_ident(), i

def real_worker(x):
    # return the argument squared and the id of the thread that did the work
    return x**2, threading.get_ident()

POOLSIZE = 5
with ThreadPool(POOLSIZE) as pool:
    # chunksize = 1 is critical to be sure that we have 1 task per thread:
    thread_dict = {result[0]: result[1]
                   for result in pool.map(test_worker, range(POOLSIZE), 1)}
    assert len(thread_dict) == POOLSIZE
    print(thread_dict)
    # Use "ident" rather than "id" to avoid shadowing the built-in id().
    value, ident = pool.apply(real_worker, (7,))
    print(value)  # should be 49
    assert ident in thread_dict
    print('thread index = ', thread_dict[ident])

Prints:

{16880: 0, 16152: 1, 7400: 2, 13320: 3, 168: 4}
49
thread index =  4

A Version That Does Not Use sleep

from multiprocessing.pool import ThreadPool
import threading

def test_worker(i, event):
    if event:
        event.wait()
    return threading.get_ident(), i

def real_worker(x):
    return x**2, threading.get_ident()


# Let's use a really big pool size for a good test:
POOLSIZE = 500
events = [threading.Event() for _ in range(POOLSIZE - 1)]
with ThreadPool(POOLSIZE) as pool:
    thread_dict = {}
    # These first POOLSIZE - 1 tasks will wait until we set their events
    results = [pool.apply_async(test_worker, args=(i, event)) for i, event in enumerate(events)]
    # This last one is not passed an event and so it does not wait.
    # When it completes, we can be sure the other tasks, which have been submitted before it
    # have already been picked up by the other threads in the pool.
    # Use "ident" rather than "id" to avoid shadowing the built-in id().
    ident, index = pool.apply(test_worker, args=(POOLSIZE - 1, None))
    thread_dict[ident] = index
    # let the others complete:
    for event in events:
        event.set()
    for result in results:
        ident, index = result.get()
        thread_dict[ident] = index
    assert len(thread_dict) == POOLSIZE
    value, ident = pool.apply(real_worker, (7,))
    print(value)  # should be 49
    assert ident in thread_dict
    print('thread index = ', thread_dict[ident])

Prints:

49
thread index =  499
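
A possible variation on the event-based version replaces the individual events with a single threading.Barrier. This is only a sketch: the helper names build_thread_index and probe are hypothetical, and it assumes the pool really has POOLSIZE live worker threads, since otherwise the barrier would never be released. Each probe task blocks at the barrier, so no thread can pick up a second probe, and the barrier only opens once every thread is holding exactly one.

```python
from multiprocessing.pool import ThreadPool
import threading

POOLSIZE = 8


def build_thread_index(pool, size):
    """Return a dict mapping each pool thread's ident to an index in 0..size-1.

    Hypothetical helper: assumes `pool` has exactly `size` worker threads.
    """
    barrier = threading.Barrier(size)

    def probe(i):
        # Block until `size` different threads are each holding one probe task.
        barrier.wait()
        return threading.get_ident(), i

    # chunksize=1 so that every probe is queued as a separate task.
    return dict(pool.map(probe, range(size), 1))


with ThreadPool(POOLSIZE) as pool:
    thread_index = build_thread_index(pool, POOLSIZE)
    # Any later task can now look up the index of the thread it runs on.
    ident = pool.apply(threading.get_ident)
    print('thread index =', thread_index[ident])
```

Because the probes are plain closures, this only works with ThreadPool; a process-based Pool would need picklable functions and a different synchronization primitive.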
Answered By: Booboo

It is possible to get the index of the thread in the ThreadPool without using sleep, by using the initializer function. This is a function that is called once immediately after the thread is started. It can be used to acquire resources, such as a database connection, to use exactly one connection per thread.

Use threading.local() to make sure that each thread can store and access its own resource. In the example below we treat the index in the ThreadPool as a resource. Use a Queue to make sure no two threads grab the same resource.

from multiprocessing.pool import ThreadPool
import threading
import time
import queue

POOL_SIZE = 4
local_storage = threading.local()

def init_thread_resource(resources):
    local_storage.pool_idx = resources.get(False)
    print(f'\nThread {threading.get_ident()} has pool_idx {local_storage.pool_idx}')
    ## A thread can also initialize other things here, meant for only 1 thread, e.g.
    # local_storage.db_connection = open_db_connection()

def task(item):
    # When running this example you may see all the tasks are picked up by one thread. 
    # Uncomment time.sleep below to see each of the threads do some work.
    # This is not required to assign a unique index to each thread.
    # time.sleep(1)
    return f'Thread {threading.get_ident()} with pool_idx {local_storage.pool_idx} handled {item}'

def run_concurrently():
    # Initialize the resources
    resources = queue.Queue(POOL_SIZE)  # one resource per thread
    for pool_idx in range(POOL_SIZE):
        resources.put(pool_idx, False)
    container = range(500, 500 + POOL_SIZE)  # Offset by 500 to not confuse the items with the pool_idx
    with ThreadPool(POOL_SIZE, init_thread_resource, [resources]) as pool:
        records = pool.map(task, container)
    print('\n'.join(records))

run_concurrently()

This outputs:

Thread 32904 with pool_idx 0 handled 500
Thread 14532 with pool_idx 1 handled 501
Thread 32008 with pool_idx 2 handled 502
Thread 31552 with pool_idx 3 handled 503
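
The same initializer idea can also be sketched with a lock-protected counter instead of a pre-filled Queue. The names init_thread_index, _counter and _counter_lock below are hypothetical, and the sketch assumes the pool never replaces a worker thread; if it did, the new thread would receive an index of POOL_SIZE or higher.

```python
from multiprocessing.pool import ThreadPool
import itertools
import threading

POOL_SIZE = 4
_local = threading.local()
_counter = itertools.count()        # hands out 0, 1, 2, ...
_counter_lock = threading.Lock()    # makes the hand-out explicit and thread-safe


def init_thread_index():
    # Runs once in each worker thread, right after the thread starts.
    with _counter_lock:
        _local.pool_idx = next(_counter)


def task(item):
    # Each task reads the index stored for the thread it happens to run on.
    return item, _local.pool_idx


with ThreadPool(POOL_SIZE, init_thread_index) as pool:
    results = pool.map(task, range(500, 500 + POOL_SIZE))

for item, pool_idx in results:
    print(f'item {item} was handled by pool_idx {pool_idx}')
```

Which index handles which item is not deterministic, and a single thread may well handle several items; only the uniqueness of the index per thread is guaranteed.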
Answered By: MathKid