How to pass the index of the iterable when using multiprocessing pool
Question:
I would like to call a function task() in parallel N times. The function accepts two arguments: an array, and an index at which to write the result into that array:
    def task(arr, index):
        arr[index] = "some result to return"
To be explicit, the reason for the array is so I can process all the parallel tasks once they have completed. I presume this is ok?
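I have created a multiprocessing pool that calls task():
    import os
    import numpy as np
    from multiprocessing import Pool

    def main():
        N = 10
        arr = np.empty(N)
        pool = Pool(os.cpu_count())
        pool.map(task, arr)  # passes only one element of arr per call, no index
        pool.close()
        # Process results in arr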
However, map() already iterates over arr and passes one element per call, so how do I explicitly pass in the index? Each call to task() should receive one of 0, 1, 2, …, N-1.
Answers:
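You can use starmap() with enumerate(), which unpacks each (index, value) pair into the two arguments of task():
    import multiprocessing as mp
    import numpy as np

    def task(index, arr):
        # Here arr is a single element of the array, not the whole array.
        print(index, arr)

    if __name__ == '__main__':
        N = 10
        arr = np.empty(N)  # uninitialised, so the values are arbitrary
        with mp.Pool(mp.cpu_count()) as pool:
            pool.starmap(task, enumerate(arr))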
Output:
0 6.9180446290108e-310
1 6.9180446290108e-310
2 6.91804453329406e-310
3 6.91804425777776e-310
4 6.9180448957438e-310
5 6.9180105412701e-310
6 6.9180443068017e-310
7 6.91804453327193e-310
9 6.9180449088978e-310
8 6.91804436388567e-310
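On the question's aside about writing into arr from inside task(): each pool worker is a separate process and receives a pickled copy of its arguments, so an assignment made in the worker does not reach the parent's array. A minimal sketch of the common alternative is to return the result from task() and let map() collect the results, which come back in input order. The squared-index workload and the run_tasks() helper are illustrative assumptions, not part of the original post.

```python
import multiprocessing as mp


def task(index):
    # Hypothetical workload: return the result instead of writing it
    # into an array owned by the parent process.
    return index * index


def run_tasks(n, processes=None):
    # Hypothetical helper. pool.map() returns results in the same order
    # as the inputs, so results[i] is the result for index i.
    with mp.Pool(processes) as pool:
        return pool.map(task, range(n))


if __name__ == '__main__':
    results = run_tasks(10)
    # Process results here; results[i] corresponds to index i.
    print(results)
```

If a NumPy array is needed afterwards, the returned list can be converted with np.asarray(results) in the parent process.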