Determine how many times dask computed something

Question:

Question

I’m wondering if it is possible with dask (specifically dask arrays) to know if and when something has been computed. I’m thinking of unit tests wanting to know how many times dask computed an array. Similar to mock objects knowing how many times they were called. Does something like this exist already? If not, is there a better way than making a custom callback? If this doesn’t exist, is it something the dask core devs would be interested in adding to core dask for testing?

Details

Say I have a function which takes in an xarray DataArray, does some stuff to it, and returns it. There are cases where dask arrays are implicitly converted to numpy arrays, for example when a new dask user doesn’t know the best dask-friendly way to do something. I would like to write my unit tests to make sure that neither I nor another contributor accidentally hurts the performance of a function. This is especially important because test data is often a simplified or small version of real-world cases, so the performance hit of computing a dask array multiple times may not show up in the tests.

Asked By: djhoese


Answers:

There are a variety of ways to trigger on execution.

One would be to specify a custom scheduler:

import dask

def my_scheduler(dsk, keys, **kwargs):
    # Runs once per compute: report it, then hand off to dask's synchronous scheduler.
    print('computing!')
    return dask.get(dsk, keys, **kwargs)

with dask.config.set(scheduler=my_scheduler):
    ...

Custom callbacks, like the ones you suggest, are also pretty easy to implement.
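As a rough sketch of the callback approach, one could subclass `dask.callbacks.Callback` and count how often a scheduler run starts. The `ComputeCounter` name and the example array are hypothetical, and this assumes a local scheduler (callbacks are not used by the distributed scheduler):

```python
import dask.array as da
from dask.callbacks import Callback


class ComputeCounter(Callback):
    """Hypothetical helper: count how many times a local scheduler run starts."""

    def __init__(self):
        super().__init__()
        self.count = 0

    def _start(self, dsk):
        # Invoked at the start of each scheduler run, i.e. once per compute call.
        self.count += 1


arr = da.ones((4, 4), chunks=2)

# Entering the callback registers it; it is removed again on exit.
with ComputeCounter() as counter:
    arr.sum().compute(scheduler="synchronous")
    arr.mean().compute(scheduler="synchronous")

print(counter.count)
```

A test could then assert on `counter.count` in much the same way one would check a mock's call count.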

If you’re using dask array exclusively, then you could look at array plugins.
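A minimal sketch of an array plugin, assuming the `array_plugins` config key described in the dask array documentation. Note that plugins run whenever a dask array is constructed, not when it is computed, so this observes graph building rather than execution. The `record_array` function and the `created` list are hypothetical names:

```python
import dask
import dask.array as da

created = []  # shapes of dask arrays seen by the plugin


def record_array(x):
    # Hypothetical plugin: called with each newly constructed dask array.
    # Returning None leaves the array unchanged.
    created.append(x.shape)


with dask.config.set(array_plugins=[record_array]):
    x = da.ones((10, 1), chunks=(5, 1))
    y = x.dot(x.T)

print(len(created))
```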

There are a variety of other approaches used in the test suite.

Answered By: MRocklin

Here is what I ended up doing as a simple solution based on MRocklin’s answer.

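import dask

class CustomScheduler(object):
    """Scheduler that raises once more than max_computes computations are requested."""

    def __init__(self, max_computes=1):
        self.max_computes = max_computes
        self.total_computes = 0

    def __call__(self, dsk, keys, **kwargs):
        # dask calls the scheduler once per compute, so each call counts as one computation.
        self.total_computes += 1
        if self.total_computes > self.max_computes:
            raise RuntimeError("Too many dask computations were scheduled: {}".format(self.total_computes))
        # Delegate the actual work to dask's synchronous scheduler.
        return dask.get(dsk, keys, **kwargs)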

I then use it like this. With max_computes=0, any compute inside the block raises a RuntimeError:

with dask.config.set(scheduler=CustomScheduler(0)):
    # dask array stuff
    ...
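To illustrate how this might look in a test, here is a small self-contained sketch. The `lazy_op` and `eager_op` functions are hypothetical stand-ins for code under test; the second one converts to numpy, which implicitly computes the array and so trips the scheduler:

```python
import dask
import dask.array as da
import numpy as np


class CustomScheduler(object):
    def __init__(self, max_computes=1):
        self.max_computes = max_computes
        self.total_computes = 0

    def __call__(self, dsk, keys, **kwargs):
        self.total_computes += 1
        if self.total_computes > self.max_computes:
            raise RuntimeError("Too many dask computations were scheduled: {}".format(self.total_computes))
        return dask.get(dsk, keys, **kwargs)


def lazy_op(arr):
    # Stays lazy: only extends the task graph.
    return arr + 1


def eager_op(arr):
    # np.asarray on a dask array computes it implicitly.
    return np.asarray(arr) + 1


arr = da.zeros((4,), chunks=2)
caught = False

with dask.config.set(scheduler=CustomScheduler(max_computes=0)):
    result = lazy_op(arr)  # no compute, so nothing is raised
    try:
        eager_op(arr)
    except RuntimeError:
        caught = True  # the implicit compute was detected

print(caught)
```

In a real test suite the `try`/`except` would more likely be `pytest.raises(RuntimeError)`, or simply omitted so that any unexpected compute fails the test.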

This answer was posted as an edit to the question Determine how many times dask computed something by the OP djhoese under CC BY-SA 4.0.

Answered By: vvvvv