Why does unpickling several files in parallel using a Pool object "freeze"?

Question:

I was trying to unpickle several files in parallel; however, the process just "hangs."

First, I create 6 small pickle files:

import pickle
from multiprocessing import Pool

class A:
    x: int
    y: str
    
    def __init__(self, x:int, y:str):
        self.x = x
        self.y = y

for i in range(6):
    with open(str(i) + '.pkl', 'wb') as f:
        pickle.dump(A(i, str(i) + 'abcdefg'), f)

Next, I try to load them in parallel:

def load(file):
    with open(file, 'rb') as f:
        data = pickle.load(f)

with Pool(6) as p:
    p.map(load, [str(number) + '.pkl' for number in range(6)])

When I launch this from Spyder, all 8 logical processors (4 "actual" cores) show close to 100% utilization in Windows Task Manager; however, the process does not terminate. If I try to stop the process from Spyder, it fails to do so; the only way to stop it is to close Spyder entirely.

However, if I do this in a normal for loop, it takes a fraction of a second:

for file in [str(number) + '.pkl' for number in range(6)]:
    load(file)

Can someone explain why this is? Also, is there some way to fix the original code such that it actually works? (The files I’m actually trying to load and process are much larger, which is why I would like the parallelism; the above code is just an example to help people reproduce the issue).

I have read several Q&As about loading files with a Pool, such as this Q&A; however, none of them helped.

Answers:

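When using multiprocessing, you have to put the code that starts the pool behind the main check, if __name__ == '__main__': (see this). On Windows, multiprocessing uses the "spawn" start method. Each worker is a fresh Python interpreter that re-imports your script before it runs load, so all unguarded top-level code runs again in every worker. In your script, that code writes the files and creates the Pool. A worker cannot start its own pool while it is still starting up, so it exits with an error. The parent pool then keeps replacing workers that never finish a task, which is consistent with the CPUs sitting near 100% while nothing completes. For the same reason, the functions and classes the workers need (load, A) must be defined at module level in an importable script file. If it still hangs inside Spyder, try running the script from a terminal (python yourscript.py). The multiprocessing documentation notes that the Pool examples do not work in the interactive interpreter.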

So something like this will work for you:

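import pickle
from multiprocessing import Pool

class A:
    def __init__(self, x: int, y: str):
        self.x = x
        self.y = y

def make():
    # Create the six sample files (only needs to run once)
    for i in range(6):
        with open(str(i) + '.pkl', 'wb') as f:
            pickle.dump(A(i, str(i) + 'abcdefg'), f)

def load(file):
    # Runs in a worker process; the returned object is pickled and sent back to the parent
    with open(file, 'rb') as f:
        return pickle.load(f)

def main():
    with Pool() as p:  # Pool() starts one worker per CPU by default
        results = p.map(load, [str(number) + '.pkl' for number in range(6)])
    print([(a.x, a.y) for a in results])

if __name__ == '__main__':
    # Only the parent process runs this; spawned workers re-import this
    # file but skip this block, so they never create their own Pool
    make()
    main()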
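Pool.map pickles each worker's return value and sends it back to the parent. If a worker returns a large unpickled object, that object is serialized again in the worker and deserialized again in the parent. For the larger files mentioned in the question, one option is to do the processing inside the worker and return only a small result. The sketch below assumes that approach. load_and_summarize and its (x, len(y)) result are hypothetical stand-ins for your real processing. The files go into a temporary directory so the example cleans up after itself.

```python
import os
import pickle
import tempfile
from multiprocessing import Pool


class A:
    # Must live at module level so worker processes can find it when unpickling.
    def __init__(self, x: int, y: str):
        self.x = x
        self.y = y


def make(directory, n=6):
    """Write n small pickle files into directory and return their paths."""
    paths = []
    for i in range(n):
        path = os.path.join(directory, f"{i}.pkl")
        with open(path, "wb") as f:
            pickle.dump(A(i, f"{i}abcdefg"), f)
        paths.append(path)
    return paths


def load_and_summarize(path):
    # Hypothetical processing step: do the heavy work inside the worker and
    # return only a small result instead of sending the whole object back.
    with open(path, "rb") as f:
        obj = pickle.load(f)
    return obj.x, len(obj.y)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        paths = make(tmp)
        with Pool() as pool:
            summaries = pool.map(load_and_summarize, paths)
        print(summaries)
```

Run as a script, this should print [(0, 8), (1, 8), (2, 8), (3, 8), (4, 8), (5, 8)], since pool.map keeps results in input order. Whether this beats loading in the parent depends on how much processing each file needs relative to its size.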
Answered By: OldBill