Limit the number of CPUs used by dask.compute

Question:

The code below takes approximately 1 second to run on an 8-CPU system. How can I manually configure the number of CPUs used by dask.compute, e.g. to 4 CPUs, so that the same code takes approximately 2 seconds even on an 8-CPU system?

import dask
from time import sleep

def f(x):
    sleep(1)
    return x**2

objs = [dask.delayed(f)(x) for x in range(8)]
print(dask.compute(*objs))  # (0, 1, 4, 9, 16, 25, 36, 49)
Asked By: Russell Burdt


Answers:

There are a few options:

  1. Specify the number of workers when creating the cluster:
from dask.distributed import Client

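# 4 workers with one thread each: at most 4 tasks run at the same time.
# Without threads_per_worker=1, each worker may start several threads
# and run more than one task at once.
client = Client(n_workers=4, threads_per_worker=1)

# the rest of your code is unchanged: dask.compute now uses this client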
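A self-contained sketch of option 1 that also times the run. It assumes `dask.distributed` is installed and uses `processes=False` (in-process workers) only to keep the sketch lightweight; the helper name `run_on_four_workers` is hypothetical. With 4 single-threaded workers, the 8 one-second tasks should run in two waves, i.e. roughly 2 seconds.

```python
import time
from time import sleep

import dask
from dask.distributed import Client


def f(x):
    sleep(1)
    return x**2


def run_on_four_workers():
    """Hypothetical helper: run the 8 tasks on 4 single-threaded workers."""
    objs = [dask.delayed(f)(x) for x in range(8)]
    # processes=False keeps the workers inside this process (a simplification
    # for the sketch); the n_workers/threads_per_worker limits still apply.
    with Client(n_workers=4, threads_per_worker=1, processes=False):
        start = time.perf_counter()
        # while a Client is active, dask.compute uses it as the default scheduler
        results = dask.compute(*objs)
        elapsed = time.perf_counter() - start
    return results, elapsed


if __name__ == "__main__":
    results, elapsed = run_on_four_workers()
    print(results)  # (0, 1, 4, 9, 16, 25, 36, 49)
    print(f"{elapsed:.1f} s")  # roughly 2 s
```

Using the client as a context manager closes the client and its local cluster when the block ends, so the worker limit only applies to the computation inside it.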
  2. Specify how many (and which) workers should execute a task:

client = Client(n_workers=8, threads_per_worker=1)

list_workers = list(client.scheduler_info()['workers'])

# submit only to the first 4 workers; client.compute returns futures,
# so gather them to get the actual results
futures = client.compute(objs, workers=list_workers[:4])
print(client.gather(futures))  # [0, 1, 4, 9, 16, 25, 36, 49]

# note that the workers should still be single-threaded; the difference
# from option 1 is that you can have more workers than the ones used
# (the others stay idle). The `workers` kwarg can also be passed to
# dask.compute rather than client.compute
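To check that the restriction actually holds, each task can report which worker ran it. This sketch calls `dask.distributed.get_worker()` inside the task; `f_with_worker` and `run_on_first_four` are hypothetical names, and `processes=False` is again only a simplification to keep everything in one process.

```python
from time import sleep

import dask
from dask.distributed import Client, get_worker


def f_with_worker(x):
    """Hypothetical variant of f that also returns the executing worker's address."""
    sleep(1)
    return x**2, get_worker().address


def run_on_first_four():
    objs = [dask.delayed(f_with_worker)(x) for x in range(8)]
    with Client(n_workers=8, threads_per_worker=1, processes=False) as client:
        # worker addresses are the keys of the "workers" mapping
        allowed = list(client.scheduler_info()["workers"])[:4]
        # restrict every task to the first four workers; this returns futures
        futures = client.compute(objs, workers=allowed)
        # gather turns the list of futures into a list of concrete results
        results = client.gather(futures)
    return results, allowed


if __name__ == "__main__":
    results, allowed = run_on_first_four()
    print([value for value, _ in results])  # [0, 1, 4, 9, 16, 25, 36, 49]
    print({addr for _, addr in results} <= set(allowed))  # True
```

Because the 8 workers are single-threaded and only 4 of them are allowed, at most 4 tasks run concurrently, while the remaining 4 workers stay idle.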
  3. Use a semaphore so that at most 4 tasks run f at the same time:
from dask.distributed import Client, Semaphore

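client = Client()
# at most 4 tasks can hold a lease at the same time; the others wait
sem = Semaphore(max_leases=4, name="foo")

def fmodified(x, sem):
    # acquire a lease before running f and release it when done
    with sem:
        return f(x)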

objs = [dask.delayed(fmodified)(x, sem) for x in range(8)]
print(dask.compute(*objs))  # (0, 1, 4, 9, 16, 25, 36, 49)
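If a distributed cluster is not needed at all, a different mechanism may be enough: the local threaded scheduler (which `dask.compute` uses by default for delayed objects when no Client is active) accepts a `num_workers` keyword that caps the size of its thread pool. A minimal sketch, passing `scheduler="threads"` explicitly so that an active Client, if any, is not picked up instead:

```python
import time
from time import sleep

import dask


def f(x):
    sleep(1)
    return x**2


objs = [dask.delayed(f)(x) for x in range(8)]

start = time.perf_counter()
# local thread pool limited to 4 threads: 8 one-second tasks run in two waves
results = dask.compute(*objs, scheduler="threads", num_workers=4)
elapsed = time.perf_counter() - start

print(results)  # (0, 1, 4, 9, 16, 25, 36, 49)
print(f"{elapsed:.1f} s")  # roughly 2 s
```

This limits threads in the current process only; it does not create workers or a scheduler, so options 1 to 3 remain the relevant ones when a distributed Client is in use.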

Update: as noted by @mdurant in the comments, if you are running this as a script, the relevant code needs an if __name__ == "__main__": guard so that the worker processes, which import the main module when they start, do not execute it again. For example, the second option from the list above would look like this in a script:

#!/usr/bin/env python3
import dask
from dask.distributed import Client
from time import sleep

def f(x):
    sleep(1)
    return x**2

objs = [dask.delayed(f)(x) for x in range(8)]

if __name__ == "__main__":
    client = Client(n_workers=8, threads_per_worker=1)

    list_workers = list(client.scheduler_info()['workers'])

    # client.compute returns futures; gather them to get the values
    futures = client.compute(objs, workers=list_workers[:4])

    print(client.gather(futures))
Answered By: SultanOrazbayev