Jan 17, 2017

Distributed NumPy on a Cluster with Dask Arrays

This work is supported by Continuum Analytics, the XDATA Program, and the Data Driven Discovery Initiative from the Moore Foundation.

This page includes embedded large profiles. It may look better on the actual site rather than through syndicated pages like planet.python, and it may take a while to load on non-broadband connections (total size is around 20MB).

Summary

We analyze a stack of images in parallel with NumPy arrays distributed across a cluster of machines on Amazon’s EC2 with Dask array. This is a model application shared among many image analysis groups ranging from satellite imagery to bio-medical applications. We go through a series of common operations:

  1. Inspect a sample of images locally with Scikit Image
  2. Construct a distributed Dask.array around all of our images
  3. Process and re-center images with Numba
  4. Transpose data to get a time-series for every pixel, compute FFTs

This last step is quite fun. Even if you skim through the rest of this article I recommend checking out the last section.

Inspect Dataset

I asked a colleague at the US National Institutes of Health (NIH) for a biggish imaging dataset. He came back with the following message:

Electron microscopy may be generating the biggest ndarray datasets in the field - terabytes regularly. Neuroscience needs EM to see connections between neurons, because the critical features of neural synapses (connections) are below the diffraction limit of light microscopes. This type of research has been called “connectomics”. Many groups are looking at machine vision approaches to follow small neuron parts from one slice to the next.

This data is from drosophila: http://emdata.janelia.org/. Here is an example 2d slice of the data http://emdata.janelia.org/api/node/bf1/grayscale/raw/xy/2000_2000/1800_2300_5000.

import skimage.io
import matplotlib.pyplot as plt

sample = skimage.io.imread('http://emdata.janelia.org/api/node/bf1/grayscale/raw/xy/2000_2000/1800_2300_5000')
skimage.io.imshow(sample)

Sample electron microscopy image from stack

The last number in the URL is an index into a large stack of about 10000 images. We can change that number to get different slices through our 3D dataset.

samples = [skimage.io.imread('http://emdata.janelia.org/api/node/bf1/grayscale/raw/xy/2000_2000/1800_2300_%d' % i)
           for i in [1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000]]

fig, axarr = plt.subplots(1, 9, sharex=True, sharey=True, figsize=(24, 2.5))
for i, sample in enumerate(samples):
    axarr[i].imshow(sample, cmap='gray')

Sample electron microscopy images over time

We see that our field of interest wanders across the frame over time and drops off in the beginning and at the end.

Create a Distributed Array

Even though our data is spread across many files, we still want to think of it as a single logical 3D array. We know how to get any particular 2D slice of that array using Scikit-image. Now we’re going to use Dask.array to stitch all of those Scikit-image calls into a single distributed array.

import dask.array as da
from dask import delayed

imread = delayed(skimage.io.imread, pure=True)  # Lazy version of imread

urls = ['http://emdata.janelia.org/api/node/bf1/grayscale/raw/xy/2000_2000/1800_2300_%d' % i
        for i in range(10000)]                  # A list of our URLs

lazy_values = [imread(url) for url in urls]     # Lazily evaluate imread on each url

arrays = [da.from_delayed(lazy_value,           # Construct a small Dask array
                          dtype=sample.dtype,   # for every lazy value
                          shape=sample.shape)
          for lazy_value in lazy_values]

stack = da.stack(arrays, axis=0)                # Stack all small Dask arrays into one

>>> stack
dask.array<shape=(10000, 2000, 2000), dtype=uint8, chunksize=(1, 2000, 2000)>

>>> stack = stack.rechunk((20, 2000, 2000)) # combine chunks to reduce overhead
>>> stack
dask.array<shape=(10000, 2000, 2000), dtype=uint8, chunksize=(20, 2000, 2000)>

So here we’ve constructed a lazy Dask.array from 10 000 delayed calls to skimage.io.imread. We haven’t done any actual work yet, we’ve just constructed a parallel array that knows how to get any particular slice of data by downloading the right image if necessary. This gives us a full NumPy-like abstraction on top of all of these remote images. For example we can now download a particular image just by slicing our Dask array.

>>> stack[5000, :, :].compute()
array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]], dtype=uint8)

>>> stack[5000, :, :].mean().compute()
11.49902425
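
Because the array is lazy, other NumPy-style expressions stay just as cheap until we ask for a result. As a small sketch (not from the original post), here we crop a twenty-image window that lines up with a single chunk, so only that chunk's images are ever downloaded:

# Lazy slicing: nothing downloads until compute() is called, and only the
# chunks that the result depends on are fetched.
subvolume = stack[4000:4020, 500:1500, 500:1500]   # aligns with one 20-image chunk
subvolume.mean(axis=0).compute()                   # fetches just that chunk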

However we probably don’t want to operate too much further without connecting to a cluster. That way we can just download all of the images once into distributed RAM and start doing some real computations. I happen to have ten m4.2xlarges on Amazon’s EC2 (8 cores, 30GB RAM each) running Dask workers. So we’ll connect to those.

from dask.distributed import Client, progress
client = Client('scheduler-address:8786')

>>> client
<Client: scheduler="scheduler-address:8786" processes=10 cores=80>

I’ve replaced the actual address of my scheduler (something like 54.183.180.153) with scheduler-address. Let’s go ahead and bring in all of our images, persisting the array into concrete data in memory.

stack = client.persist(stack)

This starts downloads of our 10 000 images across our 10 workers. When this completes we have 10 000 NumPy arrays spread around on our cluster, coordinated by our single logical Dask array. This takes a while, about five minutes. We’re mostly network bound here (Janelia’s servers are not co-located with our compute nodes). Here is a parallel profile of the computation as an interactive Bokeh plot.
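
Since we imported progress alongside Client above, a convenient way to watch this from a notebook (a minimal sketch, not in the original post) is to hand the persisted collection to it:

progress(stack)   # live progress bar while the images load across the cluster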

There will be a few of these profile plots throughout the blogpost, so you might want to familiarize yourself with them now. Every horizontal rectangle in this plot corresponds to a single Python function running somewhere in our cluster over time. Because we called skimage.io.imread 10 000 times there are 10 000 purple rectangles. Their position along the y-axis denotes which of the 80 cores in our cluster they ran on and their position along the x-axis denotes their start and stop times. You can hover over each rectangle (function) for more information on what kind of task it was, how long it took, etc. In the image below, purple rectangles are skimage.io.imread calls and red rectangles are data transfer between workers in our cluster. Click the magnifying glass icons in the upper right of the image to enable zooming tools.

Now that we have persisted our Dask array in memory, our data is based on hundreds of concrete in-memory NumPy arrays across the cluster, rather than on hundreds of lazy scikit-image calls. Now we can do all sorts of fun distributed array computations more quickly.

For example we can easily see our field of interest move across the frame by averaging across time:

skimage.io.imshow(stack.mean(axis=0).compute())

Average image over time

Or we can see when the field of interest is actually present within the frame by averaging across x and y:

plt.plot(stack.mean(axis=[1, 2]).compute())

Image brightness over time

By looking at the profile plots for each case we can see that averaging over time involves much more inter-node communication, which can be quite expensive in this case.
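
To put rough numbers on that claim, here is a quick sketch (not from the original post; timings depend heavily on your network and cluster) that wall-clocks both reductions:

import time

for axis in [0, (1, 2)]:                  # average over time, then over x and y
    start = time.time()
    stack.mean(axis=axis).compute()
    print(axis, time.time() - start)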

Recenter Images with Numba

In order to remove the spatial offset across time we’re going to compute a centroid for each slice and then crop the image around that center. I looked up centroids in the Scikit-Image docs and came across a function that did way more than what I was looking for, so I just quickly coded up a solution in pure Python and then JIT-ed it with Numba (which makes this run at C speeds).

from numba import jit

@jit(nogil=True)
def centroid(im):
    n, m = im.shape
    total_x = 0
    total_y = 0
    total = 0
    for i in range(n):
        for j in range(m):
            total += im[i, j]
            total_x += i * im[i, j]
            total_y += j * im[i, j]

    if total > 0:
        total_x /= total
        total_y /= total
    return total_x, total_y

>>> centroid(sample) # this takes around 9ms
(748.7325324581344, 802.4893005160851)
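
As a sanity check (a sketch assuming scipy is available; not part of the original post), this intensity-weighted centroid should agree with scipy.ndimage.center_of_mass:

>>> from scipy import ndimage
>>> ndimage.center_of_mass(sample)   # same intensity-weighted centroid; should match the result above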

def recenter(im):
    x, y = centroid(im.squeeze())
    x, y = int(x), int(y)
    if x < 500:
        x = 500
    if y < 500:
        y = 500
    if x > 1500:
        x = 1500
    if y > 1500:
        y = 1500

    return im[..., x-500:x+500, y-500:y+500]

plt.figure(figsize=(8, 8))
skimage.io.imshow(recenter(sample))

Recentered image

Now we map this function across our distributed array.

import numpy as np

def recenter_block(block):
    """ Recenter a short stack of images """
    return np.stack([recenter(block[i]) for i in range(block.shape[0])])

recentered = stack.map_blocks(recenter_block,
                              chunks=(20, 1000, 1000),  # chunk size changes
                              dtype=stack.dtype)
recentered = client.persist(recentered)

This profile provides a good opportunity to talk about a scheduling failure; things went a bit wrong here. Towards the beginning we quickly recenter several images (Numba is fast), taking around 300-400ms for each block of twenty images. However as some workers finish all of their allotted tasks, the scheduler erroneously starts to load balance, moving images from busy workers to idle workers. Unfortunately the network at this time appeared to be much slower than expected and so the move + compute elsewhere strategy ended up being much slower than just letting the busy workers finish their work. The scheduler keeps track of expected compute times and transfer times precisely to avoid mistakes like this one. These sorts of issues are rare, but do occur on occasion.
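
If you run into similar misbehavior yourself, one knob to experiment with is the scheduler's work-stealing heuristic. This is a hedged sketch: the configuration interface below comes from newer Dask versions than the one used in this post, and it must be set in the environment where the scheduler starts.

import dask

# Keep tasks on the workers they were first assigned to by disabling
# the scheduler's work-stealing (load-balancing) heuristic.
dask.config.set({"distributed.scheduler.work-stealing": False})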

We check our work by averaging our re-centered images across time and displaying that to the screen. We see that our images are better centered with each other, as expected.

skimage.io.imshow(recentered.mean(axis=0))

Recentered time average

This shows how easy it is to create fast in-memory code with Numba and then scale it out with Dask.array. The two projects complement each other nicely, giving us near-optimal performance with intuitive code across a cluster.

Rechunk to Time Series by Pixel

We’re now going to rearrange our data from being partitioned by time slice to being partitioned by pixel. This will allow us to run computations like Fast Fourier Transforms (FFTs) on each time series efficiently. Switching the chunk pattern back and forth like this is generally a very difficult operation for distributed arrays because every slice of the array contributes to every time series. We have N-squared communication.

This analysis may not be appropriate for this data (we won’t learn any useful science from doing this), but it represents a very frequently asked question, so I wanted to include it.

Currently our Dask array has chunkshape (20, 1000, 1000), meaning that our data is collected into 500 NumPy arrays across the cluster, each of size (20, 1000, 1000).

>>> recentered
dask.array<shape=(10000, 1000, 1000), dtype=uint8, chunksize=(20, 1000, 1000)>

But we want to change this shape so that the chunks cover the entire first axis. We want all data for any particular pixel to be in the same NumPy array, not spread across hundreds of different NumPy arrays. We could solve this by rechunking so that each pixel is its own block, like the following:

>>> rechunked = recentered.rechunk((10000, 1, 1))

However this would result in one million chunks (there are one million pixels), which would lead to a bit of scheduling overhead. Instead we’ll collect our time series into 10 x 10 groups of one hundred pixels. This will help us to reduce overhead.

>>> # rechunked = recentered.rechunk((10000, 1, 1)) # Too many chunks
>>> rechunked = recentered.rechunk((10000, 10, 10)) # Use larger chunks
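
As a quick check on scheduling overhead (a small sketch, not in the original post), we can compare how many blocks each choice would produce:

>>> rechunked.numblocks        # ten thousand blocks in total
(1, 100, 100)

The per-pixel rechunk would have had numblocks of (1, 1000, 1000), i.e. one million blocks.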

Now we compute the FFT of each pixel, take the absolute value and square it to get the power spectrum. Finally, to conserve space we’ll downgrade the dtype to float32 (our original data is only 8-bit anyway).

x = da.fft.fft(rechunked, axis=0)
power = abs(x ** 2).astype('float32')

power = client.persist(power, optimize_graph=False)

This is a fun profile to inspect; it includes both the rechunking and the subsequent FFTs. We’ve included a real-time trace during execution, the full profile, as well as some diagnostic plots from a single worker. These plots total up to around 20MB. I sincerely apologize to those without broadband access.

Here is a real time plot of the computation finishing over time:

Dask task stream of rechunk + fft

And here is a single interactive plot of the entire computation after itcompletes. Zoom with the tools in the upper right. Hover over rectangles toget more information. Remember that red is communication.

Screenshots of the diagnostic dashboard of a single worker during this computation.

Worker communications during FFT

Worker communications during FFT

This computation starts with a lot of communication while we rechunk and realign our data (recent optimizations here by Antoine Pitrou in dask #417). Then we transition into doing thousands of small FFTs and other arithmetic operations. All of the plots above show a nice transition from heavy communication to heavy processing with some overlap each way (once some complex blocks are available we get to start overlapping communication and computation). Inter-worker communication was around 100-300 MB/s (typical for Amazon’s EC2) and CPU load remained high. We’re making good use of our hardware.

Finally we can inspect the results. We see that the power spectrum is very boring in the corner, and has typical activity towards the center of the image.

plt.semilogy(1 + power[:, 0, 0].compute())

Power spectrum near edge

plt.semilogy(1 + power[:, 500, 500].compute())

Power spectrum at center

Final Thoughts

This blogpost showed a non-trivial image processing workflow, emphasizing the following points:

  1. Construct a Dask array from lazy SKImage calls.
  2. Use NumPy syntax with Dask.array to aggregate distributed data across a cluster.
  3. Build a centroid function with Numba. Use Numba and Dask together to clean up an image stack.
  4. Rechunk to facilitate time-series operations. Perform FFTs.

Hopefully this example has components that look similar to what you want to do with your data on your hardware. We would love to see more applications like this out there in the wild.

What we could have done better

As always with all computationally focused blogposts, we’ll include a section on what went wrong and what we could have done better with more time.

  1. Communication is too expensive: Inter-worker communications that should be taking 200ms are taking up to 10 or 20 seconds. We need to take a closer look at our communications pipeline (which normally performs just fine on other computations) to see if something is acting up. Discussion here dask/distributed #776 and early work here dask/distributed #810.
  2. Faulty load balancing: We discovered a case where our load-balancing heuristics misbehaved, incorrectly moving data between workers when it would have been better to leave everything alone. This is likely due to the oddly low bandwidth issues observed above.
  3. Loading from disk blocks network I/O: While doing this we discovered an issue where loading large amounts of data from disk can block workers from responding to network requests (dask/distributed #774).
  4. Larger datasets: It would be fun to try this on a much larger dataset to see how the solutions here scale.