We present an experiment on how to pass data from a loosely coupled parallel computing system like Dask to a tightly coupled parallel computing system like MPI.
We give motivation and a complete digestible example.
Here is a gist of the code and results.
Disclaimer: Nothing in this post is polished or production-ready. This is an experiment designed to start a conversation. No long-term support is offered.
We often get the following question:
How do I use Dask to pre-process my data, but then pass those results to a traditional MPI application?
You might want to do this because you’re supporting legacy code written in MPI, or because your computation requires tightly coupled parallelism of the sort that only MPI can deliver.
The simplest thing to do, of course, is to write your Dask results to disk and then load them back from disk with MPI. Depending on the cost of your computation relative to the cost of loading data, this might be a great choice.
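For reference, a minimal sketch of that disk-based hand-off, assuming a shared filesystem and Zarr as the on-disk format (the path below is made up), might look like this:

# dask_side.py: write the preprocessed Dask array to shared storage
import dask.array as da

x = da.random.random(100000000, chunks=(1000000,))  # stand-in for real preprocessing
x.to_zarr("/shared/scratch/preprocessed.zarr")      # hypothetical shared path

# mpi_side.py: each MPI rank reads back its own slice of the array
from mpi4py import MPI
import zarr

comm = MPI.COMM_WORLD
z = zarr.open("/shared/scratch/preprocessed.zarr", mode="r")
step = z.shape[0] // comm.Get_size()                # ignoring any remainder for simplicity
local = z[comm.Get_rank() * step : (comm.Get_rank() + 1) * step]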
For the rest of this blogpost we’re going to assume that it’s not.
We have a trivial MPI library written in MPI4Py where each rank just prints out all the data that it was given. In principle, though, it could call into C++ code and do arbitrary MPI things.
# my_mpi_lib.py
from mpi4py import MPI
comm = MPI.COMM_WORLD
def print_data_and_rank(chunks: list):
    """ Fake function that mocks out how an MPI function should operate

    -  It takes in a list of chunks of data that are present on this machine
    -  It does whatever it wants to with this data and MPI
       Here for simplicity we just print the data and print the rank
    -  Maybe it returns something
    """
    rank = comm.Get_rank()

    for chunk in chunks:
        print("on rank:", rank)
        print(chunk)

    return sum(chunk.sum() for chunk in chunks)
In our Dask program we’re going to use Dask normally to load in data, do some preprocessing, and then hand off all of that data to each MPI rank, which will call the print_data_and_rank function above to initialize the MPI computation.
# my_dask_script.py
# Set up Dask workers from within an MPI job using the dask_mpi project
# See https://dask-mpi.readthedocs.io/en/latest/
from dask_mpi import initialize
initialize()
from dask.distributed import Client, wait, futures_of
client = Client()
# Use Dask Array to "load" data (actually just create random data here)
import dask.array as da
x = da.random.random(100000000, chunks=(1000000,))
x = x.persist()
wait(x)
# Find out where data is on each worker
# TODO: This could be improved on the Dask side to reduce boilerplate
from toolz import first
from collections import defaultdict
key_to_part_dict = {str(part.key): part for part in futures_of(x)}
who_has = client.who_has(x)
worker_map = defaultdict(list)
for key, workers in who_has.items():
    worker_map[first(workers)].append(key_to_part_dict[key])
# Call an MPI-enabled function on the list of data present on each worker
from my_mpi_lib import print_data_and_rank
futures = [client.submit(print_data_and_rank, list_of_parts, workers=worker)
           for worker, list_of_parts in worker_map.items()]
wait(futures)
client.close()
Then we can call this mix of Dask and an MPI program using normal mpirun or mpiexec commands.
mpirun -np 5 python my_dask_script.py
So MPI started up and ran our script. The dask-mpi project sets up a Dask scheduler on rank 0, runs our client code on rank 1, and then runs a bunch of workers on ranks 2 and above.
Our script then created a Dask array; in a real workload it would presumably read in data from some source and do more complex Dask manipulations before continuing on.
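For example, the data-loading portion of the script might look more like this in practice (the input path and the normalization step are hypothetical):

import dask.array as da

x = da.from_zarr("/data/measurements.zarr")  # hypothetical input dataset
x = (x - x.mean()) / x.std()                 # some preprocessing with ordinary Dask operations
x = x.persist()
wait(x)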
We then wait until all of the Dask work has finished and the cluster is in a quiet state. We then query the scheduler to find out where all of that data lives. That’s this code here:
# Find out where data is on each worker
# TODO: This could be improved on the Dask side to reduce boilerplate
from toolz import first
from collections import defaultdict
key_to_part_dict = {str(part.key): part for part in futures_of(x)}
who_has = client.who_has(x)
worker_map = defaultdict(list)
for key, workers in who_has.items():
    worker_map[first(workers)].append(key_to_part_dict[key])
Admittedly, this code is gross, and not particularly friendly or obvious to non-Dask experts (or even to Dask experts themselves; I had to steal this from the Dask-XGBoost project, which does the same trick).
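The end result is simple, though: worker_map is just a dictionary mapping each worker’s address to the list of futures whose chunks live on that worker. We can sanity-check it before handing things off:

# Peek at the mapping: one entry per worker, each holding that worker's chunks
for worker, parts in worker_map.items():
    print(worker, "holds", len(parts), "chunks")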
But after that we just call our MPI library’s initialization function, print_data_and_rank, on all of our data using Dask’s Futures interface. That function gets the data directly from local memory (the Dask workers and the MPI ranks are in the same process) and does whatever the MPI application wants.
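Because print_data_and_rank also returns the sum of its chunks, the same futures carry those per-worker results back to the client, so we could optionally gather them:

# Optional: collect the per-worker return values back on the client
results = client.gather(futures)
print("total:", sum(results))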
This could be improved in a few ways: