Aug 6, 2020

Comparing Dask-ML and Ray Tune's Model Selection Algorithms


Hyperparameter optimization is the process of deducing model parameters that can’t be learned from data. This process is often time- and resource-consuming, especially in the context of deep learning. A good description of this process can be found at “Tuning the hyper-parameters of an estimator,” and the issues that arise are concisely summarized in Dask-ML’s documentation of “Hyper Parameter Searches.”

There’s a host of libraries and frameworks out there to address this problem. Scikit-Learn’s module has been mirrored in Dask-ML and auto-sklearn, both of which offer advanced hyperparameter optimization techniques. Other implementations that don’t follow the Scikit-Learn interface include Ray Tune, AutoML and Optuna.

Ray recently provided a wrapper to Ray Tune that mirrors the Scikit-Learn API called tune-sklearn (docs, source). The introduction of this library states the following:

 Cutting edge hyperparameter tuning techniques (Bayesian optimization, early stopping, distributed execution) can provide significant speedups over grid search and random search.

 However, the machine learning ecosystem is missing a solution that provides users with the ability to leverage these new algorithms while allowing users to stay within the Scikit-Learn API. In this blog post, we introduce tune-sklearn [Ray’s tuning library] to bridge this gap. Tune-sklearn is a drop-in replacement for Scikit-Learn’s model selection module with state-of-the-art optimization features.


GridSearchCV 2.0 — New and Improved

This claim is inaccurate: for over a year Dask-ML has provided access to “cutting edge hyperparameter tuning techniques” with a Scikit-Learn compatible API. To correct their statement, let’s look at each of the features that Ray’s tune-sklearn provides, and compare them to Dask-ML:

 Here’s what [Ray’s] tune-sklearn has to offer:  
  1. Consistency with Scikit-Learn API
  2. Modern hyperparameter tuning techniques
  3. Framework support
  4. Scale up … [to] multiple cores and even multiple machines.

[Ray’s] Tune-sklearn is also fast.

Dask-ML’s model selection module has every one of these features:

  • Consistency with Scikit-Learn API: Dask-ML’s model selection API mirrors the Scikit-Learn model selection API.
  • Modern hyperparameter tuning techniques: Dask-ML offers state-of-the-art hyperparameter tuning techniques.
  • Framework support: Dask-ML model selection supports many libraries including Scikit-Learn, PyTorch, Keras, LightGBM and XGBoost.
  • Scale up: Dask-ML supports distributed tuning (how could it not?) and larger-than-memory datasets.

Dask-ML is also fast. In “Speed” we show a benchmark between Dask-ML, Ray and Scikit-Learn:

Only time-to-solution is relevant; all of these methods produce similar model scores. See “Speed” for details.

Now, let’s walk through the details on how to use Dask-ML to obtain the 5 features above.

Consistency with the Scikit-Learn API

Dask-ML is consistent with the Scikit-Learn API.

Here’s how to use Scikit-Learn’s, Dask-ML’s and Ray’s tune-sklearn hyperparameter optimization:

## Trimmed example; see appendix for more detail
from sklearn.model_selection import RandomizedSearchCV
search = RandomizedSearchCV(model, params, ...)
search.fit(X, y)

from dask_ml.model_selection import HyperbandSearchCV
search = HyperbandSearchCV(model, params, ...)
search.fit(X, y, classes=[0, 1])

from tune_sklearn import TuneSearchCV
search = TuneSearchCV(model, params, ...)
search.fit(X, y, classes=[0, 1])

The definitions of model and params follow the normal Scikit-Learn definitions as detailed in the appendix.

Clearly, both Dask-ML and Ray’s tune-sklearn are Scikit-Learn compatible. Now let’s focus on how each search performs and how it’s configured.

Modern hyperparameter tuning techniques

Dask-ML offers state-of-the-art hyperparameter tuning techniques in a Scikit-Learn interface.

The introduction of Ray’s tune-sklearn made this claim:

 tune-sklearn is the only Scikit-Learn interface that allows you to easily leverage Bayesian Optimization, HyperBand and other optimization techniques by simply toggling a few parameters.

The state-of-the-art in hyperparameter optimization is currently “Hyperband.” Hyperband reduces the amount of computation required with a principled early stopping scheme; past that, it’s the same as Scikit-Learn’s popular RandomizedSearchCV.
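To make "principled early stopping" concrete, here is a minimal sketch of successive halving, the building block that Hyperband runs several times with different budgets. It is a simplification for illustration only, not Dask-ML's or Ray's implementation; `score_after` and `toy_score` are hypothetical stand-ins for training a model and measuring its validation score.

```python
import random


def successive_halving(params, score_after, eta=3, max_epochs=27):
    """Simplified sketch of one successive-halving bracket.

    params: list of candidate hyperparameter settings.
    score_after: hypothetical callable (param, epochs) -> validation score.
    Returns the best surviving parameter and the total epochs trained.
    """
    survivors = list(params)
    epochs = 1        # cumulative epochs each survivor has been trained for
    prev_epochs = 0
    spent = 0         # total training epochs across all models
    while True:
        # Train every survivor up to `epochs`; only the extra epochs cost anything.
        spent += len(survivors) * (epochs - prev_epochs)
        if len(survivors) == 1 or epochs >= max_epochs:
            break
        # Keep the top 1/eta of the models and stop training the rest early.
        survivors.sort(key=lambda p: score_after(p, epochs), reverse=True)
        survivors = survivors[: max(1, len(survivors) // eta)]
        prev_epochs, epochs = epochs, min(epochs * eta, max_epochs)
    best = max(survivors, key=lambda p: score_after(p, epochs))
    return best, spent


def toy_score(param, epochs):
    # Made-up learning curve: better parameters score higher at every epoch.
    return param * (1 - 0.5 ** epochs)


random.seed(0)
params = [random.random() for _ in range(27)]
best, spent = successive_halving(params, toy_score)
full_cost = len(params) * 27  # cost of training every model for all 27 epochs
```

In this toy setting the bracket spends 81 epochs of training, while training all 27 models to completion would cost 729. Real learning curves can cross, which is one reason Hyperband hedges across several brackets rather than relying on a single one.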

Hyperband works. As such, it’s very popular. After the introduction of Hyperband in 2016 by Li et al., the paper has been cited over 470 times and has been implemented in many different libraries including Dask-ML, Ray Tune, keras-tune, Optuna, AutoML,[1] and Microsoft’s NNI. The original paper shows a rather drastic improvement over all the relevant implementations,[2] and this drastic improvement persists in follow-up works.[3] Some illustrative results from Hyperband are below:

 All algorithms are configured to do the same amount of work except “random 2x” which does twice as much work. “hyperband (finite)” is similar to Dask-ML’s default implementation, and “bracket s=4” is similar to Ray’s default implementation. “random” is a random search. SMAC,[4] spearmint,[5] and TPE[6] are popular Bayesian algorithms.

Hyperband is undoubtedly a “cutting edge” hyperparameter optimization technique. Dask-ML and Ray offer Scikit-Learn implementations of this algorithm that rely on similar implementations, and Dask-ML’s implementation also has a rule of thumb for configuration. Both Dask-ML’s and Ray’s documentation encourage use of Hyperband.

Ray does support using their Hyperband implementation on top of a technique called Bayesian sampling. This changes the hyperparameter sampling scheme for model initialization. This can be used in conjunction with Hyperband’s early stopping scheme. Adding this option to Dask-ML’s Hyperband implementation is future work for Dask-ML.

Framework support

Dask-ML model selection supports many libraries including Scikit-Learn, PyTorch, Keras, LightGBM and XGBoost.

Ray’s tune-sklearn supports these frameworks:

 tune-sklearn is used primarily for tuning Scikit-Learn models, but it also supports and provides examples for many other frameworks with Scikit-Learn wrappers such as Skorch (Pytorch), KerasClassifiers (Keras), and XGBoostClassifiers (XGBoost).

Clearly, both Dask-ML and Ray support many of the same libraries.

However, both Dask-ML and Ray have some qualifications. Certain libraries don’t offer an implementation of partial_fit,[7] so not all of the modern hyperparameter optimization techniques can be offered. Here’s a table comparing different libraries and their support in Dask-ML’s model selection and Ray’s tune-sklearn:

| Model Library | Dask-ML support | Ray support | Dask-ML: early stopping? | Ray: early stopping? |
|---|---|---|---|---|
| Scikit-Learn | ✔ | ✔ | ✔* | ✔* |
| PyTorch (via Skorch) | ✔ | ✔ | ✔ | ✔ |
| Keras (via SciKeras) | ✔ | ✔ | ✔** | ✔** |
| LightGBM | ✔ | ✔ | ❌ | ❌ |
| XGBoost | ✔ | ✔ | ❌ | ❌ |

* Only for the models that implement partial_fit.
** Thanks to work by the Dask developers around scikeras#24.
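Early stopping needs to train a model a little, check its score, and then either resume or abandon it, which is why partial_fit matters here. The sketch below shows an illustrative duck-typing check for that capability; `supports_early_stopping`, `IncrementalModel` and `BatchOnlyModel` are hypothetical names, and this is not how either library performs the check internally.

```python
def supports_early_stopping(model):
    """Return True if `model` exposes a callable `partial_fit` method."""
    return callable(getattr(model, "partial_fit", None))


class IncrementalModel:
    """Hypothetical estimator that can be trained one batch at a time."""

    def __init__(self):
        self.n_calls = 0

    def partial_fit(self, X, y=None):
        self.n_calls += 1  # one more incremental training step
        return self


class BatchOnlyModel:
    """Hypothetical estimator that only supports a single full `fit`."""

    def fit(self, X, y=None):
        return self


incremental_ok = supports_early_stopping(IncrementalModel())
batch_ok = supports_early_stopping(BatchOnlyModel())
```

A model like `BatchOnlyModel` can still be searched over, but only by training every candidate to completion.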

By this measure, Dask-ML and Ray model selection have the same level of framework support. Of course, Dask has tangential integration with LightGBM and XGBoost through Dask-ML’s xgboost module and dask-lightgbm.

Scale up

Dask-ML supports distributed tuning (how could it not?), aka parallelization across multiple machines/cores. In addition, it also supports larger-than-memory data.

 [Ray’s] Tune-sklearn leverages Ray Tune, a library for distributed hyperparameter tuning, to efficiently and transparently parallelize cross validation on multiple cores and even multiple machines.

Naturally, Dask-ML also scales to multiple cores/machines because it relies on Dask. Dask has wide support for different deployment options that span from your personal machine to supercomputers. Dask will very likely work on top of any computing system you have available, including Kubernetes, SLURM, YARN and Hadoop clusters as well as your personal machine.

Dask-ML’s model selection also scales to larger-than-memory datasets, and is thoroughly tested. Support for larger-than-memory data is untested in Ray, and there are no examples detailing how to use Ray Tune with the distributed dataset implementations in PyTorch/Keras.

In addition, I have benchmarked Dask-ML’s model selection module to see how the time-to-solution is affected by the number of Dask workers in “Better and faster hyperparameter optimization with Dask.” That is, how does the time to reach a particular accuracy scale with the number of workers $P$? At first, it’ll scale like $1/P$ but with a large number of workers the serial portion will dictate time to solution according to Amdahl’s Law. Briefly, I found Dask-ML’s HyperbandSearchCV speedup started to saturate around 24 workers for a particular search.
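The saturation described above follows directly from Amdahl's Law. The short sketch below evaluates the formula for a few worker counts; the serial fraction of 5% is an assumed value chosen purely for illustration and was not measured in any benchmark.

```python
def amdahl_speedup(workers, serial_fraction):
    """Speedup predicted by Amdahl's Law for `workers` parallel workers."""
    parallel_fraction = 1.0 - serial_fraction
    return 1.0 / (serial_fraction + parallel_fraction / workers)


serial_fraction = 0.05  # assumed value for illustration, not measured
speedups = {P: amdahl_speedup(P, serial_fraction) for P in (1, 8, 24, 1000)}
speedup_limit = 1.0 / serial_fraction  # upper bound as P grows without limit
```

For small $P$ the speedup is close to $P$, but it can never exceed `1 / serial_fraction`, so adding workers eventually stops helping.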


Speed

Both Dask-ML and Ray are much faster than Scikit-Learn.

Ray’s tune-sklearn runs some benchmarks in the introduction with the GridSearchCV class found in Scikit-Learn and Dask-ML. A fairer benchmark would be to use Dask-ML’s HyperbandSearchCV because it is almost the same as the algorithm in Ray’s tune-sklearn. To be specific, I’m interested in comparing these methods:

  • Scikit-Learn’s RandomizedSearchCV. This is a popular implementation, one that I’ve bootstrapped myself with a custom model.
  • Dask-ML’s HyperbandSearchCV. This is an early stopping technique for RandomizedSearchCV.
  • Ray tune-sklearn’s TuneSearchCV. This is a slightly different early stopping technique than HyperbandSearchCV’s.

Each search is configured to perform the same task: sample 100 parameters and train for no longer than 100 “epochs” or passes through the data.[8] Each estimator is configured as its respective documentation suggests. Each search uses 8 workers with a single cross validation split, and a partial_fit call takes one second with 50,000 examples. The complete setup can be found in the appendix.
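As a rough sanity check on this setup, the following back-of-the-envelope estimate computes how long the search would take with no early stopping at all. It assumes each epoch is a single partial_fit call costing about one second, plus perfect parallelism with zero overhead, so it is an idealized lower bound rather than a measurement.

```python
n_models = 100           # hyperparameter settings sampled
max_epochs = 100         # passes through the data per model
seconds_per_epoch = 1.0  # assumed: one partial_fit call per epoch, ~1 second
n_workers = 8

# Work needed if every model is trained to completion (no early stopping).
total_seconds = n_models * max_epochs * seconds_per_epoch
# Ideal wall-clock time with perfect parallelism and no overhead.
ideal_minutes = total_seconds / n_workers / 60
```

This works out to roughly 20.8 minutes, in the same ballpark as the 20.88 minutes reported for RandomizedSearchCV in the appendix, and it is the baseline that any early stopping scheme is trying to beat.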

Here’s how long each library takes to complete the same search:

Notably, we didn’t improve the Dask-ML codebase for this benchmark, and ran the code as it’s been for the last year.[9] Regardless, it’s possible that other artifacts from biased benchmarks crept into this benchmark.

Clearly, Ray and Dask-ML offer similar performance for 8 workers when compared with Scikit-Learn. To Ray’s credit, their implementation is ~15% faster than Dask-ML’s with 8 workers. We suspect that this performance boost comes from the fact that Ray implements an asynchronous variant of Hyperband. We should investigate this difference between Dask and Ray, and how each balances the tradeoff between the number of FLOPs and time-to-solution. This will vary with the number of workers: the asynchronous variant of Hyperband provides no benefit if used with a single worker.

Dask-ML reaches scores quickly in serial environments, or when the number of workers is small. Dask-ML prioritizes fitting high scoring models: if there are 100 models to fit and only 4 workers available, Dask-ML selects the models that have the highest score. This is most relevant in serial environments;[10] see “Better and faster hyperparameter optimization with Dask” for benchmarks. This feature is omitted from this benchmark, which only focuses on time to solution.
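The prioritization idea can be illustrated with a toy example: given more models than workers, train the highest-scoring models first. This is only a sketch of the concept with made-up scores; `pick_models_to_train` is a hypothetical helper and does not reflect how Dask's scheduler assigns priorities internally.

```python
import heapq


def pick_models_to_train(scores, n_workers):
    """Return the IDs of the `n_workers` highest-scoring models.

    scores: dict mapping a model ID to its most recent validation score.
    """
    return heapq.nlargest(n_workers, scores, key=scores.get)


# 100 hypothetical models with made-up scores; only 4 workers are free.
scores = {f"model-{i}": i / 100 for i in range(100)}
chosen = pick_models_to_train(scores, n_workers=4)
```

With as many workers as models, every model would be chosen and the ordering would no longer matter, which is why this behavior is most visible when workers are scarce.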


Dask-ML and Ray offer the same features for model selection: state-of-the-art features with a Scikit-Learn compatible API, and both implementations have fairly wide support for different frameworks and rely on backends that can scale to many machines.

In addition, the Ray implementation has provided motivation for further development, specifically on the following items:

  1. Adding support for more libraries, including Keras (dask-ml#696, dask-ml#713, scikeras#24). SciKeras is a Scikit-Learn wrapper for Keras that (now) works with Dask-ML model selection because SciKeras models implement the Scikit-Learn model API.
  2. Better documenting the models that Dask-ML supports (dask-ml#699). Dask-ML supports any model that implements the Scikit-Learn interface, and there are wrappers for Keras, PyTorch, LightGBM and XGBoost. Now, Dask-ML’s documentation prominently highlights this fact.

The Ray implementation has also helped motivate and clarify future work. Dask-ML should include the following implementations:

  1. A Bayesian sampling scheme for the Hyperband implementation that’s similar to Ray’s and BOHB’s (dask-ml#697).
  2. A configuration of HyperbandSearchCV that’s well-suited for exploratory hyperparameter searches. An initial implementation is in dask-ml#532, which should be benchmarked against Ray.

Luckily, all of these pieces of development are straightforward modifications because the Dask-ML model selection framework is pretty flexible.

Thank you Tom Augspurger, Matthew Rocklin, Julia Signell, and Benjamin Zaitlen for your feedback, suggestions and edits.


Appendix

Benchmark setup

This is the complete setup for the benchmark between Dask-ML, Scikit-Learn and Ray. Complete details can be found at stsievert/dask-hyperband-comparison.

Let’s create a dummy model that takes 1 second for a partial_fit call with 50,000 examples. This is appropriate for this benchmark; we’re only interested in the time required to finish the search, not how well the models do. Scikit-Learn, Ray and Dask-ML have very similar methods of choosing hyperparameters to evaluate; they differ in their early stopping techniques.

from scipy.stats import uniform
from sklearn.datasets import make_classification
from benchmark import ConstantFunction  # custom module

# Search budget from the benchmark description: 100 sampled
# parameters, each trained for at most 100 epochs.
n_params = 100
max_iter = 100

# This model sleeps for `latency * len(X)` seconds before
# reporting a score of `value`.
model = ConstantFunction(latency=1 / 50e3, max_iter=max_iter)

params = {"value": uniform(0, 1)}
# This dummy dataset mirrors the MNIST dataset
X_train, y_train = make_classification(n_samples=int(60e3), n_features=784)

This model will take 2 minutes to train for 100 epochs (aka passes through the data). Details can be found at stsievert/dask-hyperband-comparison.

Let’s configure our searches to use 8 workers with a single cross-validation split:
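For readers without access to the custom benchmark module, here is a minimal sketch of what such a dummy model could look like, along with the arithmetic behind the 2 minute figure. `SleepyModel` and `training_seconds` are hypothetical names written for this illustration; this is not the author's ConstantFunction class, and the timing ignores the held-out validation split.

```python
import time


class SleepyModel:
    """Hypothetical stand-in for the benchmark's dummy model.

    Each `partial_fit` call sleeps for `latency * len(X)` seconds, and
    `score` always reports the fixed hyperparameter `value`.
    """

    def __init__(self, value=0.0, latency=1 / 50e3):
        self.value = value
        self.latency = latency

    def partial_fit(self, X, y=None, **kwargs):
        time.sleep(self.latency * len(X))  # simulate training cost
        return self

    def score(self, X, y=None):
        return self.value


def training_seconds(latency, n_examples, epochs):
    """Simulated time to train for `epochs` passes over `n_examples` examples."""
    return latency * n_examples * epochs


# 60,000 examples at 1/50,000 seconds each, for 100 epochs.
full_run = training_seconds(latency=1 / 50e3, n_examples=60_000, epochs=100)
```

Under these assumptions `full_run` comes to about 120 seconds. Because the score is fixed by `value`, the searches differ only in how much simulated training time they spend.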

from sklearn.model_selection import RandomizedSearchCV, ShuffleSplit
split = ShuffleSplit(test_size=0.2, n_splits=1)
kwargs = dict(cv=split, refit=False)

search = RandomizedSearchCV(model, params, n_jobs=8, n_iter=n_params, **kwargs)
search.fit(X_train, y_train)  # 20.88 minutes

from dask_ml.model_selection import HyperbandSearchCV
dask_search = HyperbandSearchCV(
    model, params, test_size=0.2, max_iter=max_iter, aggressiveness=4
)

from tune_sklearn import TuneSearchCV
ray_search = TuneSearchCV(
    model, params, n_iter=n_params, max_iters=max_iter, early_stopping=True, **kwargs
)

dask_search.fit(X_train, y_train)  # 2.93 minutes
ray_search.fit(X_train, y_train)  # 2.49 minutes

Full example usage

from sklearn.linear_model import SGDClassifier
from scipy.stats import uniform, loguniform
from sklearn.datasets import make_classification
model = SGDClassifier()
params = {"alpha": loguniform(1e-5, 1e-3), "l1_ratio": uniform(0, 1)}
X, y = make_classification()

from sklearn.model_selection import RandomizedSearchCV
search = RandomizedSearchCV(model, params, ...)
search.fit(X, y)

from dask_ml.model_selection import HyperbandSearchCV
search = HyperbandSearchCV(model, params, ...)
search.fit(X, y, classes=[0, 1])

from tune_sklearn import TuneSearchCV
search = TuneSearchCV(model, params, ...)
search.fit(X, y, classes=[0, 1])


  1. Their implementation of Hyperband in HpBandSter is included in Auto-PyTorch and BOAH.
  2. See Figures 4, 7 and 8 in “Hyperband: A Novel Bandit-Based Approach to Hyperparameter Optimization.”
  3. See Figure 1 of the BOHB paper and a paper from an augmented reality company.
  4. SMAC is described in “Sequential Model-Based Optimization for General Algorithm Configuration,” and is available in AutoML.
  5. Spearmint is described in “Practical Bayesian Optimization of Machine Learning Algorithms,” and is available in HIPS/spearmint.
  6. TPE is described in Section 4 of “Algorithms for Hyperparameter Optimization,” and is available through Hyperopt.
  7. Ray’s tune-sklearn states: “If the estimator does not support partial_fit, a warning will be shown saying early stopping cannot be done and it will simply run the cross-validation on Ray’s parallel back-end.”
  8. I choose to benchmark random searches instead of grid searches because random searches produce better results: grid searches require estimating how important each parameter is. For more detail see “Random Search for Hyperparameter Optimization” by Bergstra and Bengio.
  9. Despite a relevant implementation in dask-ml#527.
  10. Because priority is meaningless if there are an infinite number of workers.