Feb 11, 2017

Experiment with Dask and TensorFlow


This work is supported by Continuum Analytics, the XDATA Program, and the Data Driven Discovery Initiative from the Moore Foundation.


This post briefly describes potential interactions between Dask and TensorFlow and then goes through a concrete example using them together for distributed training with a moderately complex architecture.

This post was written in haste and the attached experiment is of low quality; see the disclaimers below. A similar and much better example with XGBoost is included in the comments at the end.


Dask and TensorFlow both provide distributed computing in Python. TensorFlow excels at deep learning applications while Dask is more generic. We can combine both together in a few applications:

  1. Simple data parallelism: hyper-parameter searches during training and predicting already-trained models against large datasets are both trivial to distribute with Dask, as they would be with any distributed computing system (Hadoop/Spark/Flink/etc.). We won’t discuss this topic much. It should be straightforward.
  2. Deployment: A common pain point with TensorFlow is that setup isn’t well automated. This plagues all distributed systems, especially those that are run on a wide variety of cluster managers (see the cluster deployment blogpost for more information). Fortunately, if you already have a Dask cluster running, it’s trivial to stand up a distributed TensorFlow network on top of it, running within the same processes.
  3. Pre-processing: We pre-process data with dask.dataframe or dask.array, and then hand that data off to TensorFlow for training. If Dask and TensorFlow are co-located on the same processes then this movement is efficient. Working together, we can build efficient and general-use deep learning pipelines.

In this blogpost we look very briefly at the first case of simple parallelism. Then we go into more depth on an experiment that uses Dask and TensorFlow in a more complex situation. We’ll find we can accomplish a fairly sophisticated workflow easily, both because TensorFlow is sensible to set up and because Dask can be flexible in advanced situations.

Motivation and Disclaimers

Distributed deep learning is fundamentally changing the way humanity solves some very hard computing problems like natural language translation, speech-to-text transcription, image recognition, etc. However, distributed deep learning also suffers from public excitement, which may distort our image of its utility. Distributed deep learning is often not the correct choice. This is for two reasons:

  1. Focusing on single machine computation is often a better use of time. Model design, GPU hardware, etc. can have a more dramatic impact than scaling out. For newcomers to deep learning, watching an online video lecture series may be a better use of time than reading this blogpost.
  2. Traditional machine learning techniques like logistic regression and gradient boosted trees can be more effective than deep learning if you have limited data. They can also sometimes provide valuable interpretability.

Regardless, there are some concrete take-aways, even if distributed deep learning is not relevant to your application:

  1. TensorFlow is straightforward to set up from Python.
  2. Dask is sufficiently flexible out of the box to support complex settings and workflows.
  3. We’ll see an example of a typical distributed learning approach that generalizes beyond deep learning.

Additionally, the author does not claim expertise in deep learning and wrote this blogpost in haste.

Simple Parallelism

Most parallel computing is simple. We easily apply one function to lots of data, perhaps with slight variation. In the case of deep learning this can enable a couple of common workflows:

  1. Build many different models, train each on the same data, and choose the best-performing one. Using Dask’s concurrent.futures interface, this looks something like the following:

    # Hyperparameter search
    from dask.distributed import Client

    client = Client('dask-scheduler-address:8786')
    # train_and_score is a placeholder for your own training/evaluation function
    scores = client.map(train_and_score, hyper_param_list, data=data)
    best = client.submit(max, scores)

  2. Given an already-trained model, use it to predict outcomes on lots of data. Here we use a big data collection like dask.dataframe:

    # Distributed prediction
    import dask.dataframe as dd

    df = dd.read_parquet('...')
    ...  # do some preprocessing here
    df['outcome'] = df.map_partitions(predict)  # predict: your trained model's prediction function

These techniques are relatively straightforward if you have modest exposure to Dask and TensorFlow (or any other machine learning library like scikit-learn), so I’m going to ignore them for now and focus on more complex situations.
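
The hyper-parameter search pattern (map one training function over many parameter values, then keep the best score) does not depend on Dask itself. Here is a minimal sketch that uses the standard library’s concurrent.futures, whose interface Dask’s Client extends. The toy train_and_evaluate function, which fits a single number by gradient descent, and the candidate learning rates are hypothetical stand-ins. They are not part of the original experiment.

```python
from concurrent.futures import ThreadPoolExecutor
from functools import partial


def train_and_evaluate(learning_rate, data, steps=20):
    # Hypothetical toy model: a single weight w fit to `data` by gradient
    # descent on mean squared error; returns a score where higher is better
    w = 0.0
    for _ in range(steps):
        grad = sum(2 * (w - x) for x in data) / len(data)
        w -= learning_rate * grad
    loss = sum((w - x) ** 2 for x in data) / len(data)
    return -loss


def search(learning_rates, data):
    with ThreadPoolExecutor(max_workers=4) as executor:
        # One task per candidate, like client.map(..., hyper_param_list, data=data)
        scores = list(executor.map(partial(train_and_evaluate, data=data),
                                   learning_rates))
    # Keep the best-performing candidate, like client.submit(max, scores)
    best_score, best_lr = max(zip(scores, learning_rates))
    return best_lr, best_score


best_lr, best_score = search([0.001, 0.01, 0.1, 0.5, 1.5], [1.0, 2.0, 3.0, 4.0, 5.0])
print(best_lr, best_score)  # 0.5 -2.0
```

If you replace the executor with a Dask Client (client.map followed by client.submit), the same pattern should spread across a cluster without other changes to the training function.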

Interested readers may find this blogpost on TensorFlow and Spark of interest. It is a nice writeup that goes over these two techniques in more detail.

A Distributed TensorFlow Application

We’re going to replicate this TensorFlow example, which uses multiple machines to train a model that fits in memory, using parameter servers for coordination. Our TensorFlow network will have three different kinds of servers:

distributed TensorFlow training graph
  1. Workers: which will get updated parameters, consume training data, and use that data to generate updates to send back to the parameter servers
  2. Parameter Servers: which will hold onto model parameters, synchronizing with the workers as necessary
  3. Scorer: which will periodically test the current parameters against validation/test data and emit a current cross_entropy score to see how well the system is running.

This is a fairly typical approach when the model fits on one machine but we want to use multiple machines, either to accelerate training or because data volumes are too large.
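
A toy, pure-Python sketch can illustrate how these roles divide the work. In it, a hypothetical ParameterServer holds one weight. Worker threads pull that weight, compute a gradient on their own shard of data, and push an update back, while a score function plays the scorer. Linear regression on y = 3x stands in for the neural network. Updates are applied asynchronously, roughly analogous to the sync_replicas = False setting used later. The sketch shows the pattern only; it is not how TensorFlow implements parameter servers.

```python
import threading


class ParameterServer:
    """Hypothetical toy parameter server holding a single weight."""

    def __init__(self, w=0.0):
        self._w = w
        self._lock = threading.Lock()

    def pull(self):
        with self._lock:
            return self._w

    def push(self, delta):
        with self._lock:
            self._w += delta


def gradient(w, shard):
    # d/dw of the mean squared error for the model y = w * x on one shard
    return sum(2 * (w * x - y) * x for x, y in shard) / len(shard)


def worker(ps, shard, learning_rate, steps):
    # Worker role: fetch parameters, consume training data, push an update
    for _ in range(steps):
        w = ps.pull()
        ps.push(-learning_rate * gradient(w, shard))


def score(w, validation):
    # Scorer role: evaluate the current parameters on held-out data
    return sum((w * x - y) ** 2 for x, y in validation) / len(validation)


def train(shards, learning_rate=0.01, steps=200):
    ps = ParameterServer()
    threads = [threading.Thread(target=worker, args=(ps, s, learning_rate, steps))
               for s in shards]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return ps.pull()


shards = [[(1.0, 3.0), (2.0, 6.0)], [(3.0, 9.0), (4.0, 12.0)]]
w = train(shards)
print(round(w, 3), score(w, [(5.0, 15.0)]))
```

Because the workers never wait for each other, one may push an update computed from a slightly stale weight. With a small learning rate this still converges, which is the trade-off asynchronous training makes in exchange for less coordination.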

We’ll use TensorFlow to do all of the actual training and scoring. We’ll use Dask to do everything else. In particular, we’re about to do the following:

  1. Prepare data with dask.array
  2. Set up TensorFlow workers as long-running tasks
  3. Feed data from Dask to TensorFlow while scores remain poor
  4. Let TensorFlow handle training using its own network

Prepare Data with Dask.array

For this toy example we’re just going to use the MNIST data that comes with TensorFlow. However, we’ll artificially inflate this data by concatenating it to itself many times across a cluster:

def get_mnist():
    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets('/tmp/mnist-data', one_hot=True)
    return mnist.train.images, mnist.train.labels

import dask.array as da
from dask import delayed

datasets = [delayed(get_mnist)() for i in range(20)] # 20 versions of same dataset
images = [d[0] for d in datasets]
labels = [d[1] for d in datasets]

images = [da.from_delayed(im, shape=(55000, 784), dtype='float32') for im in images]
labels = [da.from_delayed(la, shape=(55000, 10), dtype='float32') for la in labels]

images = da.concatenate(images, axis=0)
labels = da.concatenate(labels, axis=0)

>>> images
dask.array<concate..., shape=(1100000, 784), dtype=float32, chunksize=(55000, 784)>

images, labels = client.persist([images, labels])  # persist data in memory

This gives us a moderately large distributed array of around a million tiny images. If we wanted to, we could inspect or clean up this data using normal dask.array constructs:

import matplotlib.pyplot as plt

im = images[1].compute().reshape((28, 28))
plt.imshow(im, cmap='gray')

mnist number 3

im = images.mean(axis=0).compute().reshape((28, 28))
plt.imshow(im, cmap='gray')

mnist mean

im = images.var(axis=0).compute().reshape((28, 28))
plt.imshow(im, cmap='gray')

mnist var

This shows how one can use Dask collections to clean, pre-process, and generate features from data in parallel before sending it to TensorFlow. In our simple case we won’t actually do any of this, but it’s useful in more real-world situations.
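
Reductions such as images.mean(axis=0) work chunk by chunk. Each block produces a small partial result, and those partials are then combined, so no single machine ever needs the whole array. The pure-Python sketch below shows that idea for a column-wise mean, carrying a (count, sums) pair per chunk. The helper names are hypothetical and this is not Dask’s own code.

```python
def chunk_stats(chunk):
    # Partial result for one block of rows: row count and per-column sums
    return len(chunk), [sum(col) for col in zip(*chunk)]


def combine_mean(partials):
    # Aggregate the (count, sums) partials into overall per-column means
    n = sum(count for count, _ in partials)
    totals = [sum(col) for col in zip(*(sums for _, sums in partials))]
    return [t / n for t in totals]


def chunked_mean(rows, chunk_rows):
    # Split rows into blocks, reduce each block, then combine the partials
    chunks = [rows[i:i + chunk_rows] for i in range(0, len(rows), chunk_rows)]
    return combine_mean([chunk_stats(c) for c in chunks])


print(chunked_mean([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], 2))  # [3.0, 4.0]
```

Carrying the count alongside the sums is what lets chunks of unequal size combine correctly. Averaging the per-chunk means directly would weight a short final chunk too heavily.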

Finally, after doing our preprocessing on the distributed array of all of our data, we’re going to collect images and labels together and batch them into smaller chunks. Again we use some dask.array constructs, and dask.delayed when things get messy.

images = images.rechunk((10000, 784))
labels = labels.rechunk((10000, 10))

images = images.to_delayed().flatten().tolist()
labels = labels.to_delayed().flatten().tolist()
batches = [delayed([im, la]) for im, la in zip(images, labels)]

batches = client.compute(batches)

Now we have 110 pairs of NumPy arrays (1,100,000 rows in chunks of 10,000) in distributed memory, waiting to be sent to a TensorFlow worker.
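
The rechunk-and-zip step amounts to splitting the row axis into fixed-size blocks and pairing the matching blocks of images and labels. The hypothetical helpers below sketch that arithmetic in plain Python. Lists stand in for the arrays, so nothing here is Dask-specific.

```python
def chunk_bounds(n_rows, chunk_rows):
    # Row ranges from splitting n_rows into blocks of chunk_rows
    # (the last block may be shorter), like rechunk along axis 0
    return [(start, min(start + chunk_rows, n_rows))
            for start in range(0, n_rows, chunk_rows)]


def make_batches(images, labels, chunk_rows):
    # Pair matching image and label blocks, like zip(images, labels) above
    if len(images) != len(labels):
        raise ValueError("images and labels must have the same number of rows")
    return [(images[a:b], labels[a:b])
            for a, b in chunk_bounds(len(images), chunk_rows)]


bounds = chunk_bounds(20 * 55000, 10000)  # 20 copies of 55,000 MNIST rows
print(len(bounds), bounds[-1])  # 110 (1090000, 1100000)
```

Since the images and labels are rechunked with the same row size, their blocks line up one-to-one, which is what makes the simple zip safe.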

Setting up TensorFlow workers alongside Dask workers

Dask workers are just normal Python processes. TensorFlow can launch itself from a normal Python process. We’ve made a small function here that launches TensorFlow servers alongside Dask workers, using Dask’s ability to run long-running tasks and maintain user-defined state. All together, this is about 80 lines of code (including comments and docstrings) and allows us to define our TensorFlow network on top of Dask as follows:

pip install git+

from dask.distributed import Client  # we already had this above
client = Client('dask-scheduler-address:8786')

from dask_tensorflow import start_tensorflow
tf_spec, dask_spec = start_tensorflow(client, ps=1, worker=4, scorer=1)

>>> tf_spec.as_dict()
{'ps': [''],
'scorer': [''],
'worker': ['',

>>> dask_spec
{'ps': ['tcp://'],
'scorer': ['tcp://'],
'worker': ['tcp://',

This starts three groups of TensorFlow servers in the Dask worker processes. TensorFlow will manage its own communication but co-exist right alongside Dask on the same machines and in the same shared memory spaces (note that in the specs above the IP addresses match but the ports differ).

This also sets up a normal Python queue along which Dask can safely send information to TensorFlow. This is how we’ll send those batches of training data between the two services.
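
Because Dask and TensorFlow share a process, handing a batch over amounts to putting a Python object on a queue.Queue. The object is passed by reference and is neither serialized nor copied. The standard-library sketch below illustrates this, with a thread standing in for the long-running TensorFlow loop. The names transfer and consume are hypothetical and only mimic the roles of transfer_dask_to_tensorflow and the worker’s queue.get calls later in the post.

```python
import queue
import threading


def transfer(batch, q):
    # Producer side: hand a batch to the consumer living in the same process
    q.put(batch)


def consume(q, n, out):
    # Consumer side: a long-running loop pulling batches off the queue
    for _ in range(n):
        out.append(q.get(timeout=5))


def demo(n=5):
    q = queue.Queue()
    received = []
    consumer = threading.Thread(target=consume, args=(q, n, received))
    consumer.start()
    # Toy (image, label) pairs shaped like one MNIST row each
    batches = [([float(i)] * 784, [0.0] * 10) for i in range(n)]
    for b in batches:
        transfer(b, q)
    consumer.join()
    return batches, received


batches, received = demo()
print(all(r is b for r, b in zip(received, batches)))  # True
```

The identity check shows that the consumer sees the very same objects the producer created. That is why co-locating the two systems keeps the handoff cheap.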

Define TensorFlow Model and Distribute Roles

Now is the part of the blogpost where my expertise wanes. I’m just going to copy-paste-and-modify a canned example from the TensorFlow documentation. This is a simplistic model for this problem and it’s entirely possible that I’m making transcription errors. But still, it should get the point across. You can safely ignore most of this code. Dask stuff gets interesting again towards the bottom:

import math
import tempfile
import time
from queue import Empty

import tensorflow as tf  # TensorFlow 1.x graph API, as in the original example
# 2017-era Dask name; later releases call this worker_client
from distributed.worker_client import local_client

IMAGE_PIXELS = 28  # MNIST images are 28x28
hidden_units = 100
learning_rate = 0.01
sync_replicas = False
num_workers = len(dask_spec['worker'])
replicas_to_aggregate = num_workers

def model(server):
    job_name = server.server_def.job_name
    task_index = server.server_def.task_index
    worker_device = "/job:%s/task:%d" % (job_name, task_index)
    # Only the first training worker initializes the model (not the scorer)
    is_chief = job_name == 'worker' and task_index == 0

    with tf.device(tf.train.replica_device_setter(
            worker_device=worker_device,
            ps_device="/job:ps/cpu:0",
            cluster=tf_spec)):

        global_step = tf.Variable(0, name="global_step", trainable=False)

        # Variables of the hidden layer
        hid_w = tf.Variable(
            tf.truncated_normal(
                [IMAGE_PIXELS * IMAGE_PIXELS, hidden_units],
                stddev=1.0 / IMAGE_PIXELS),
            name="hid_w")
        hid_b = tf.Variable(tf.zeros([hidden_units]), name="hid_b")

        # Variables of the softmax layer
        sm_w = tf.Variable(
            tf.truncated_normal(
                [hidden_units, 10],
                stddev=1.0 / math.sqrt(hidden_units)),
            name="sm_w")
        sm_b = tf.Variable(tf.zeros([10]), name="sm_b")

        # Ops: located on the worker specified with task_index
        x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS])
        y_ = tf.placeholder(tf.float32, [None, 10])

        hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
        hid = tf.nn.relu(hid_lin)

        y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
        cross_entropy = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))

        opt = tf.train.AdamOptimizer(learning_rate)

        if sync_replicas:
            # Aggregate gradients from all workers before applying an update
            opt = tf.train.SyncReplicasOptimizer(
                opt,
                replicas_to_aggregate=replicas_to_aggregate,
                total_num_replicas=num_workers,
                name="mnist_sync_replicas")

        train_step = opt.minimize(cross_entropy, global_step=global_step)

        if sync_replicas:
            local_init_op = opt.local_step_init_op
            if is_chief:
                local_init_op = opt.chief_init_op

            ready_for_local_init_op = opt.ready_for_local_init_op

            # Initial token and chief queue runners required by the sync_replicas mode
            chief_queue_runner = opt.get_chief_queue_runner()
            sync_init_op = opt.get_init_tokens_op()

        init_op = tf.global_variables_initializer()
        train_dir = tempfile.mkdtemp()

        if sync_replicas:
            sv = tf.train.Supervisor(
                is_chief=is_chief,
                logdir=train_dir,
                init_op=init_op,
                local_init_op=local_init_op,
                ready_for_local_init_op=ready_for_local_init_op,
                recovery_wait_secs=1,
                global_step=global_step)
        else:
            sv = tf.train.Supervisor(
                is_chief=is_chief,
                logdir=train_dir,
                init_op=init_op,
                recovery_wait_secs=1,
                global_step=global_step)

        # Only see the parameter servers and this task's own device
        sess_config = tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=False,
            device_filters=["/job:ps", worker_device])

        # The chief worker (task_index==0) session will prepare the session,
        # while the remaining workers will wait for the preparation to complete.
        if is_chief:
            print("Worker %d: Initializing session..." % task_index)
        else:
            print("Worker %d: Waiting for session to be initialized..." %
                  task_index)

        sess = sv.prepare_or_wait_for_session(server.target, config=sess_config)

        if sync_replicas and is_chief:
            # Chief worker will start the chief queue runner and call the init op.
            sess.run(sync_init_op)
            sv.start_queue_runners(sess, [chief_queue_runner])

        return sess, x, y_, train_step, global_step, cross_entropy

def ps_task():
    with local_client() as c:
        # Block forever, serving parameters to the TensorFlow workers
        c.worker.tensorflow_server.join()

def scoring_task():
    with local_client() as c:
        # Scores Channel
        scores = c.channel('scores', maxlen=10)

        # Make Model (keep x and y_ so we can feed the validation data)
        server = c.worker.tensorflow_server
        sess, x, y_, _, _, cross_entropy = model(server)

        # Testing Data
        from tensorflow.examples.tutorials.mnist import input_data
        mnist = input_data.read_data_sets('/tmp/mnist-data', one_hot=True)
        test_data = {x: mnist.validation.images,
                     y_: mnist.validation.labels}

        # Main Loop: publish the validation score about once per second
        while True:
            score = sess.run(cross_entropy, feed_dict=test_data)
            scores.append(float(score))

            time.sleep(1)


def worker_task():
    with local_client() as c:
        scores = c.channel('scores')

        server = c.worker.tensorflow_server
        queue = c.worker.tensorflow_queue

        # Make model
        sess, x, y_, train_step, global_step, _ = model(server)

        # Main loop: train until the scorer reports a good enough score
        while not scores or scores.data[-1] > 1000:
            try:
                batch = queue.get(timeout=0.5)
            except Empty:
                continue  # no new data yet; re-check the stopping condition

            train_data = {x: batch[0],
                          y_: batch[1]}

            sess.run([train_step, global_step], feed_dict=train_data)

The last three functions defined here, ps_task, scoring_task and worker_task, are the functions that we want to run on each of our three groups of TensorFlow server types. The parameter server task just starts a long-running task and passively joins the TensorFlow network:

def ps_task():
    with local_client() as c:
        c.worker.tensorflow_server.join()

The scorer task opens up an inter-worker channel of communication named “scores”, creates the TensorFlow model, then every second scores the current state of the model against validation data. It reports the score on the inter-worker channel:

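def scoring_task():
    with local_client() as c:
        scores = c.channel('scores', maxlen=10)  # inter-worker channel

        # Make Model
        sess, x, y_, _, _, cross_entropy = model(c.worker.tensorflow_server)

        ...  # load the validation data into test_data, as above

        while True:
            score = sess.run(cross_entropy, feed_dict=test_data)
            scores.append(float(score))
            time.sleep(1)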

The worker task makes the model, listens on the Dask-TensorFlow queue for new training data, and continues training until the last reported score is good enough.

def worker_task():
    with local_client() as c:
        scores = c.channel('scores')

        queue = c.worker.tensorflow_queue

        # Make model
        sess, x, y_, train_step, global_step, _ = model(c.worker.tensorflow_server)

        while not scores or scores.data[-1] > 1000:
            batch = queue.get()

            train_data = {x: batch[0],
                          y_: batch[1]}

            sess.run([train_step, global_step], feed_dict=train_data)
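
You can try the stopping rule in worker_task without Dask or TensorFlow. The worker keeps training while the latest published score is above 1000, and it uses a timeout so that it never blocks forever on an empty queue. In this standard-library sketch, a bounded collections.deque stands in for the “scores” channel, a queue.Queue stands in for the Dask-TensorFlow queue, and a background thread plays the scorer. The falling toy_score function is a hypothetical stand-in for the validation cross-entropy.

```python
import queue
import threading
import time
from collections import deque

THRESHOLD = 1000.0  # the post stops once the reported score is 1000 or below


def toy_score(steps):
    # Hypothetical stand-in for validation cross-entropy; falls as steps grow
    return 10000.0 / (1 + steps)


def run(num_batches=50):
    scores = deque(maxlen=10)   # plays the 'scores' channel (maxlen=10)
    batches = queue.Queue()     # plays worker.tensorflow_queue
    for i in range(num_batches):
        batches.put(i)
    state = {"steps": 0}
    lock = threading.Lock()
    stop = threading.Event()

    def scorer():
        # Periodically publish the current score, like scoring_task
        while not stop.is_set():
            with lock:
                scores.append(toy_score(state["steps"]))
            time.sleep(0.005)

    scorer_thread = threading.Thread(target=scorer)
    scorer_thread.start()

    # Worker loop, like worker_task: train until the latest score is good enough
    while not scores or scores[-1] > THRESHOLD:
        try:
            batches.get(timeout=0.05)
        except queue.Empty:
            continue  # no data right now; re-check the stopping condition
        with lock:
            state["steps"] += 1  # "train" on the batch

    stop.set()
    scorer_thread.join()
    return state["steps"], scores[-1]


steps, last_score = run()
print(steps >= 9, last_score <= THRESHOLD)  # True True
```

Because the scorer publishes asynchronously, the worker may train a few more steps than strictly necessary before it sees a good enough score. The post accepts that slack in return for keeping scoring off the training path.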

We launch these tasks on the Dask workers that have the corresponding TensorFlow servers (see tf_spec and dask_spec above):

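# pure=False gives each long-running task its own key, even with identical arguments
ps_tasks = [client.submit(ps_task, workers=[addr], pure=False)
            for addr in dask_spec['ps']]

worker_tasks = [client.submit(worker_task, workers=[addr], pure=False)
                for addr in dask_spec['worker']]

scorer_task = client.submit(scoring_task, workers=[dask_spec['scorer'][0]], pure=False)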

This starts long-running tasks that just sit there, waiting for external stimulation:

long running TensorFlow tasks

Finally, we construct a function to dump each of our batches of data from our Dask.array (from the very beginning of this post) into the Dask-TensorFlow queues on our workers. We make sure to only run these tasks where the Dask worker has a corresponding TensorFlow training worker:

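from distributed.worker_client import get_worker

def transfer_dask_to_tensorflow(batch):
    # Runs on a Dask worker: hand the batch to the co-located TensorFlow worker
    worker = get_worker()
    worker.tensorflow_queue.put(batch)

dump = client.map(transfer_dask_to_tensorflow, batches,
                  workers=dask_spec['worker'], pure=False)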

If we want to, we can track progress in our local session by subscribing to the same inter-worker channel:

scores = client.channel('scores')

We can use this to dump data into the workers repeatedly until they converge.

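from dask.distributed import wait

while not scores or scores.data[-1] > 1000:
    dump = client.map(transfer_dask_to_tensorflow, batches,
                      workers=dask_spec['worker'], pure=False)
    wait(dump)  # finish this pass over the data before checking the score again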


We discussed a non-trivial way to use TensorFlow to accomplish distributed machine learning. We used Dask to support TensorFlow in a few ways:

  1. Trivially set up the TensorFlow network
  2. Prepare and clean data
  3. Coordinate progress and stopping criteria

We found it convenient that Dask and TensorFlow could play nicely with each other. Dask supported TensorFlow without getting in the way. The fact that both libraries play nicely within Python and the greater PyData stack (NumPy/Pandas) makes it trivial to move data between them without costly or complex tricks.

Additionally, we didn’t have to work to integrate these two systems. There is no need for a separate collaborative effort to integrate Dask and TensorFlow at a core level. Instead, they are designed in a way that fosters this type of interaction without special attention or effort.

This is also the first blogpost that I’ve written that, from a Dask perspective, uses some more complex features like long-running tasks or publishing state between workers with channels. These more advanced features are invaluable when creating more complex/bespoke parallel computing systems, such as are often found within companies.

What we could have done better

From a deep learning perspective this example is both elementary and incomplete. It would have been nice to train on a dataset that was larger and more complex than MNIST. Also it would be nice to see the effects of training over time and the performance of using different numbers of workers. In defense of this blogpost I can only claim that Dask shouldn’t affect any of these scaling results, because TensorFlow is entirely in control at these stages and TensorFlow already has plenty of published scaling information.

Generally speaking though, this experiment was done in a weekend afternoon and the blogpost was written in a few hours shortly afterwards. If anyone is interested in performing and publishing a more serious distributed deep learning experiment with TensorFlow and Dask, I would be happy to support them on the Dask side. I think that there is plenty to learn here about best practices.


The following individuals contributed to the construction of this blogpost:

  • Stephan Hoyer contributed with conversations about how TensorFlow is used in practice and with concrete experience on deployment.
  • Will Warner and Erik Welch both provided valuable editing and language recommendations.