Scaling Distributed Machine Learning with System and Algorithm Co-design
Mu Li, Thesis Defense
CSD, CMU, Feb 2nd, 2017
Large-scale problems

\min_w \sum_{i=1}^n f_i(w)

✧ Distributed systems
✧ Large-scale optimization methods
Large Scale Machine Learning
✧ Machine learning learns from data
✧ More data
  ✓ better accuracy
  ✓ can use more complex models
[Figure: accuracy vs. data size, with more complex models]
Ads Click Prediction
✧ Predict if an ad will be clicked
✧ Each ad impression is an example
✧ Logistic regression
  ✓ A single machine processes 1 million examples per second
✧ A typical industrial-size problem has
  ✓ 100 billion examples
  ✓ 10 billion unique features
[Figure: training data size (TB) per year, 2010–2014]
Image Recognition
✧ Recognize the object in an image
✧ Convolutional neural network
✧ A state-of-the-art network has
  ✓ Hundreds of layers
  ✓ Billions of floating-point operations for processing a single image
Distributed Computing for Large Scale Problems
✧ Distribute the workload among many machines
✧ Widely available thanks to cloud providers (AWS, GCP, Azure)
✧ Challenges
  ✓ Limited communication bandwidth (10x less than memory bandwidth)
  ✓ Large synchronization cost (1 ms latency)
  ✓ Job failures
[Figure: machines connected by switches in a tree topology]
[Figure: job failure rate (%) vs. total machine time (100–10,000 hours)]
Distributed Optimization for Large Scale ML

\min_w \sum_{i=1}^n f_i(w)

✧ Partition the data across m workers: \bigcup_{i=1}^m I_i = \{1, 2, \ldots, n\}
✧ Each worker k computes its partial gradient \sum_{i \in I_k} \partial f_i(w_t); the partial gradients are aggregated to update w_t \to w_{t+1}
✧ Challenges
  ✓ Massive communication traffic
  ✓ Expensive global synchronization
Scaling Distributed Machine Learning
✧ Distributed Systems
  ✓ Large data size, complex models
  ✓ Fault tolerant
  ✓ Easy to use
✧ Large Scale Optimization
  ✓ Communication efficient
  ✓ Convergence guarantee
Scaling Distributed Machine Learning
✧ Distributed Systems
  ✓ Parameter Server for machine learning
  ✓ MXNet for deep learning
✧ Large Scale Optimization
  ✓ DBPG for non-convex non-smooth f_i
  ✓ EMSO for efficient minibatch SGD
Parameter Server for machine learning
Existing Open Source Systems in 2012
✧ MPI (message passing interface)
  ✓ Hard to use for sparse problems
  ✓ No fault tolerance
✧ Key-value stores, e.g. Redis
  ✓ Expensive individual key-value pair communication
  ✓ Difficult to program on the server side
✧ Hadoop/Spark
  ✓ BSP data consistency makes an efficient implementation challenging
Parameter Server Architecture [Smola’10, Dean’12]
[Figure: worker machines hold the training data and push gradients to / pull parameters from the server machines, which hold and update the model]
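To make the push/pull pattern concrete, here is a minimal worker-side sketch. The ps handle with pull/push methods and the logistic_grad helper are hypothetical illustrations, not the actual parameter server API.

import numpy as np

def worker_loop(ps, local_data, num_iters):
    """Worker: pull the current model, compute a local gradient, push it back.

    `ps` is a hypothetical handle exposing pull(key) and push(key, value);
    the real system batches keys and communicates asynchronously."""
    for t in range(num_iters):
        w = ps.pull("w")                      # fetch the latest model from the servers
        grad = np.zeros_like(w)
        for x, y in local_data:               # this worker's partition of the training data
            grad += logistic_grad(w, x, y)
        ps.push("w_grad", grad)               # servers aggregate the pushes and update w

def logistic_grad(w, x, y):
    # Gradient of log(1 + exp(-y * <x, w>)) with respect to w.
    return -y * x / (1.0 + np.exp(y * np.dot(x, w)))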
Key Features of our Implementation [Li et al, OSDI’14]
✧ Trade off data consistency for speed
  ✓ Flexible consistency models
  ✓ User-defined filters
✧ Fault tolerance with chain replication
Flexible Consistency Model
✧ Flexible models via a task dependency graph
[Figure: task dependency graphs for three consistency models]
  ✓ Sequential / BSP: each iteration's gradient push & pull executes only after the previous iteration's dependency has finished
  ✓ Bounded delay / SSP [Langford 09, Cipar 13]: a bounded number of iterations may run ahead (sketched below)
  ✓ Eventual / total asynchronous [Smola 10]: no dependency between iterations
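As a rough illustration of the bounded-delay model, the sketch below lets at most max_delay iterations of push/pull run ahead before blocking. The pull_async/push_async calls returning waitable handles are assumed for illustration, not part of the real interface.

def bounded_delay_loop(ps, compute_gradient, num_iters, max_delay):
    """Run iterations while keeping at most `max_delay` pushes outstanding
    (max_delay = 0 degenerates to sequential / BSP execution)."""
    pending = []  # outstanding push handles, oldest first
    for t in range(num_iters):
        w = ps.pull_async("w").wait()      # latest model visible to this worker
        grad = compute_gradient(w, t)      # user-supplied gradient computation
        pending.append(ps.push_async("w_grad", grad))
        # Block until no more than `max_delay` iterations run ahead.
        while len(pending) > max_delay:
            pending.pop(0).wait()
    for h in pending:                       # drain remaining pushes at the end
        h.wait()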
User-defined Filters
✧ User-defined encoder/decoder for efficient communication between machines
[Figure: data passes through an encoder on the sending machine and a decoder on the receiving machine]
✧ Lossless compression
  ✓ General data compression: LZ, LZR, …
✧ Lossy compression
  ✓ Random skip
  ✓ Fixed-point encoding (sketched below)
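A minimal sketch of what a lossy fixed-point encoding filter could look like; the filter interface and the 16-bit scale chosen here are illustrative assumptions, not the thesis implementation.

import numpy as np

class FixedPointFilter:
    """Lossy filter: quantize float32 values to int16 before sending and
    dequantize on the receiving side (illustrative sketch only)."""

    def __init__(self, scale=1024.0):
        self.scale = scale  # resolution of the fixed-point grid is 1/scale

    def encode(self, values):
        # Clip to the representable range, then round to the nearest grid point.
        q = np.clip(np.round(values * self.scale), -32768, 32767)
        return q.astype(np.int16)           # 2 bytes per value instead of 4

    def decode(self, encoded):
        return encoded.astype(np.float32) / self.scale

# Usage: roughly halves the traffic at the cost of a bounded rounding error.
f = FixedPointFilter()
grad = np.array([0.012, -0.5, 3.14159], dtype=np.float32)
restored = f.decode(f.encode(grad))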
Fault Tolerance with Chain Replication
✧ The model is partitioned over servers by consistent hashing
✧ Chain replication: a server forwards each push to its backup and acks the worker after the backup acks
✧ Option: aggregating updates before replication reduces backup traffic
[Figure: push/ack message flow from workers through server 0 to its backup server 1, with and without aggregation]
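A toy, single-threaded sketch of the forward-then-ack pattern described above, under assumed semantics; the real servers replicate asynchronously and over the network.

class Server:
    """Chain-replication sketch: apply a push locally, forward it to the
    backup, and acknowledge only after the rest of the chain has."""

    def __init__(self, backup=None):
        self.store = {}       # key -> aggregated value
        self.backup = backup  # next server in the chain, or None for the tail

    def push(self, key, value):
        self.store[key] = self.store.get(key, 0.0) + value  # apply the update
        if self.backup is not None:
            self.backup.push(key, value)   # replicate down the chain
        return "ack"                       # ack propagates back to the worker

# Usage: the worker's push returns once both primary and backup hold the update.
tail = Server()
head = Server(backup=tail)
assert head.push("w[0]", 0.5) == "ack" and tail.store["w[0]"] == 0.5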
DBPG for non-convex non-smooth f_i
Proximal Gradient Method [Combettes’09]

\min_{w \in \Omega} \sum_{i=1}^n f_i(w) + h(w)

  ✓ f_i: continuously differentiable but not necessarily convex
  ✓ h: convex but possibly non-smooth
✧ Iterative update
  w_{t+1} = \mathrm{Prox}_{\eta_t}\left[ w_t - \eta_t \sum_{i=1}^n \partial f_i(w_t) \right]
  where \mathrm{Prox}_{\eta}(x) := \arg\min_{y \in \Omega} h(y) + \frac{1}{2\eta} \|x - y\|^2
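For instance, when h(w) = \lambda \|w\|_1, as in the sparse logistic regression used later, the proximal operator has the closed-form soft-thresholding solution. The sketch below is a generic illustration of one proximal gradient iterate, not code from the thesis.

import numpy as np

def prox_l1(x, eta, lam):
    """Proximal operator of h(w) = lam * ||w||_1 with step size eta:
    argmin_y  lam*||y||_1 + (1/(2*eta)) * ||x - y||^2  (soft thresholding)."""
    return np.sign(x) * np.maximum(np.abs(x) - eta * lam, 0.0)

def proximal_gradient_step(w, grad, eta, lam):
    # One iterate: gradient step on the smooth part, then prox on the L1 part.
    return prox_l1(w - eta * grad, eta, lam)

# Example: small entries after the gradient step are shrunk exactly to zero.
w = np.array([0.30, -0.02, 0.00])
print(proximal_gradient_step(w, grad=np.array([0.5, -0.1, 0.2]), eta=0.1, lam=1.0))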
Delayed Block Proximal Gradient [Li et al, NIPS’14]
✧ Algorithm design tailored for the parameter server implementation
  ✓ Update a block of coordinates each time
  ✓ Allow delay among blocks
  ✓ Use filters during communication
✧ Only 300 lines of code
[Figure: data and model blocks]
Convergence Analysis
✧ Assumptions:
  ✓ Block Lipschitz continuity: L_{var} within a block, L_{cor} across blocks
  ✓ Delay is bounded by 𝜏
  ✓ Lossy compression such as the random-skip filter and the significantly-modified filter
✧ DBPG converges to a stationary point if the learning rate satisfies
  \eta_t < \frac{1}{L_{var} + \tau L_{cor}}
Experiments on Ads Click Prediction
✧ Real dataset used in production
  ✓ 170 billion examples, 65 billion unique features, 636 TB in total
✧ 1000 machines
✧ Sparse logistic regression

\min_w \sum_{i=1}^n \log(1 + \exp(-y_i \langle x_i, w \rangle)) + \lambda \|w\|_1

[Figure: computing and waiting time (hours) to reach the same objective value vs. maximal delay 𝜏 ∈ {0, 1, 2, 4, 8, 16}; 𝜏 = 0 is sequential, and the best trade-off is 1.6x faster]
Filters to Reduce Communication Traffic
✧ Key caching
  ✓ Cache feature IDs on both sender and receiver
✧ Data compression
✧ KKT filter
  ✓ Shrink gradients to 0 based on the KKT condition (sketched below)
[Figure: relative communication traffic (%) on servers and workers for the baseline, key caching, compression, and the KKT filter, with reductions ranging from 2x to 40x]
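A minimal sketch of what a KKT-style filter for L1-regularized problems could do on the worker side; the threshold rule and interface here are illustrative assumptions rather than the exact filter in the thesis.

import numpy as np

def kkt_filter(grad, w, lam, slack=1.0):
    """For an L1-regularized objective, a coordinate with w[j] == 0 already
    satisfies its KKT condition whenever |grad[j]| <= lam, so its gradient
    need not be communicated.  `slack` loosens the test heuristically."""
    inactive = (w == 0.0) & (np.abs(grad) <= slack * lam)
    filtered = grad.copy()
    filtered[inactive] = 0.0          # zeros compress well / can be skipped
    return filtered

# Example: the second coordinate is inactive and its gradient is suppressed.
w = np.array([0.7, 0.0, 0.0])
g = np.array([0.4, 0.05, 0.30])
print(kkt_filter(g, w, lam=0.1))      # -> [0.4, 0.0, 0.3]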
MXNet for deep learning
Deep Learning is Unique
✧ Complex workloads
✧ Heterogeneous computing
✧ Easy-to-use programming interface
[Figure: “deep learning” trend in the past 5 years]
Key Features of MXNet [Chen et al, NIPS’15 workshop] (corresponding author)
✧ Easy-to-use front-end
  ✓ Mixed programming
✧ Scalable and efficient back-end
  ✓ Computation and memory optimization
  ✓ Auto-parallelization
  ✓ Scaling to multiple GPUs/machines
Mixed Programming
✧ Imperative programming is flexible
  ✓ e.g. Numpy, Matlab, Torch, …
✧ Declarative programs are easy to optimize
  ✓ e.g. TensorFlow, Theano, Caffe, …

Imperative (good for updating and interacting with the neural network):

import mxnet as mx
a = mx.nd.zeros((100, 50))
b = mx.nd.ones((100, 50))
c = a * b
print(c)
c += 1

Declarative (good for defining the neural network):

import mxnet as mx
net = mx.symbol.Variable('data')
net = mx.symbol.FullyConnected(data=net, num_hidden=128)
net = mx.symbol.SoftmaxOutput(data=net)
model = mx.module.Module(net)
model.forward(data=c)
model.backward()
Back-end System
✧ Front-end: the imperative and declarative programs shown on the previous slide
✧ Back-end: a computation graph (a ⨉ b, + 1; fullc with weight and bias, softmax)
✧ Optimization
  ✓ Memory optimization
  ✓ Operator fusion and runtime compilation
✧ Scheduling
  ✓ Auto-parallelization
Scale to Multiple GPU Machines
✧ Hierarchical parameter server
  ✓ GPU workers, level-1 servers, and level-2 servers on the CPUs
  ✓ Only 1000 lines of code
[Figure: bandwidth hierarchy: 63 GB/s across 4 PCIe 3.0 16x links, 15.75 GB/s per PCIe 3.0 16x link, 1.25 GB/s over 10 Gbit Ethernet; GPUs attach to a PCIe switch, machines connect through a network switch]
Experiment Setup
✧ Dataset
  ✓ 1.2 million images with 1000 classes
✧ ResNet 152-layer model
✧ EC2 p2.8xlarge
  ✓ 8 K80 GPUs per machine
[Figure: GPUs 0–7 attached to the CPU through PCIe switches]
✧ Minibatch SGD
  ✓ Draw a random set of examples I_t at iteration t
  ✓ Update  w_{t+1} = w_t - \frac{\eta_t}{|I_t|} \sum_{i \in I_t} \partial f_i(w_t)
  ✓ Synchronized updating
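A minimal NumPy sketch of one synchronized minibatch update as written above; the data layout and the particular gradient function are illustrative assumptions.

import numpy as np

def sync_minibatch_sgd_step(w, minibatch, grad_fn, lr):
    """One synchronized update: sum the per-example gradients over the whole
    minibatch I_t, then take a single step  w <- w - lr/|I_t| * sum_i grad_i."""
    grads = [grad_fn(w, x, y) for x, y in minibatch]   # in the real system these come
    total = np.sum(grads, axis=0)                      # from many GPUs and are summed
    return w - lr / len(minibatch) * total             # before the shared update

# Example with a linear least-squares gradient (illustrative choice).
grad_fn = lambda w, x, y: (x @ w - y) * x
rng = np.random.default_rng(0)
batch = [(rng.normal(size=3), 1.0) for _ in range(8)]
w = sync_minibatch_sgd_step(np.zeros(3), batch, grad_fn, lr=0.1)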
Communication Cost
✧ Fix the number of GPUs per machine
[Figure: time per batch (sec) vs. number of GPUs (0–128) for 1, 2, 4, and 8 GPUs per machine]
Scalability
[Figure: time per batch (sec) vs. number of GPUs (0–128) for batch sizes per GPU of 4, 8, and 16, compared with the communication cost; up to a 115x speedup]
Convergence
✧ Increase the learning rate by 5x
✧ Increase the learning rate by 10x, and decrease it at epochs 50 and 80
[Figure: top-1 validation accuracy vs. number of epochs (0–120) for batch sizes 256, 2,560, and 5,120]
EMSO for efficient minibatch SGD
Minibatch SGD
✧ Large batch size b
  ✓ Better parallelization within a batch
  ✓ Less switching/communication cost
✧ Small batch size b
  ✓ Faster convergence: the rate is O(1/\sqrt{N} + b/N), where N is the number of examples processed
[Figure: as the batch size b grows, system performance improves but the convergence rate gets worse]
Motivation [Li et al, KDD’14]
✧ Improve the convergence rate for large batch sizes
  ✓ The example variance decreases with the batch size
  ✓ Solve a more “accurate” optimization subproblem over each batch
[Figure: goal: better system performance at large batch size without a worse convergence rate]
Efficient Minibatch SGD
✧ Define f_{I_t}(w) := \sum_{i \in I_t} f_i(w). Minibatch SGD solves a first-order approximation with a conservative penalty:
  w_t = \arg\min_{w \in \Omega} f_{I_t}(w_{t-1}) + \langle \partial f_{I_t}(w_{t-1}), w - w_{t-1} \rangle + \frac{1}{2\eta_t} \|w - w_{t-1}\|_2^2
✧ EMSO instead solves the subproblem with the exact objective at each iteration:
  w_t = \arg\min_{w \in \Omega} f_{I_t}(w) + \frac{1}{2\eta_t} \|w - w_{t-1}\|_2^2
✧ For convex f_i, choosing \eta_t = O(b/\sqrt{N}) gives EMSO a convergence rate of O(1/\sqrt{N}) (compared to O(1/\sqrt{N} + b/N) for minibatch SGD)
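A minimal sketch contrasting the two updates on a single minibatch; solving the EMSO subproblem with a few inner gradient steps is an illustrative choice, not the solver used in the thesis.

import numpy as np

def sgd_update(w_prev, grad_batch, eta):
    # Minibatch SGD: the first-order approximation plus the conservative
    # penalty has the closed-form minimizer w_prev - eta * grad.
    return w_prev - eta * grad_batch(w_prev)

def emso_update(w_prev, grad_batch, eta, inner_steps=20, inner_lr=0.05):
    # EMSO: approximately minimize f_I(w) + ||w - w_prev||^2 / (2*eta)
    # over the same minibatch, here by a few inner gradient steps.
    w = w_prev.copy()
    for _ in range(inner_steps):
        g = grad_batch(w) + (w - w_prev) / eta   # gradient of the subproblem
        w -= inner_lr * g
    return w

# Example with a quadratic minibatch objective f_I(w) = 0.5 * ||w - target||^2.
target = np.array([1.0, -2.0])
grad_batch = lambda w: w - target
w0 = np.zeros(2)
print(sgd_update(w0, grad_batch, eta=0.5), emso_update(w0, grad_batch, eta=0.5))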
Experiment
✧ Ads click prediction with a fixed run time
✧ Extended to deep learning in [Keskar et al, arXiv’16]
[Figure: objective value vs. batch size (1e3–1e5) for SGD and EMSO, on a single machine and on 10 machines]
Large-scale problems

\min_w \sum_{i=1}^n f_i(w)

✧ Distributed systems
✧ Large-scale optimization
✧ Co-design to reduce communication cost
  ✓ Communicate less
  ✓ Message compression
  ✓ Relaxed data consistency

With appropriate computational frameworks and algorithm design, distributed machine learning can be made simple, fast, and scalable, both in theory and in practice.
Acknowledgement
Advisors: QQ & Alex
Committee members
…with 13 other collaborators
Backup slides
Scaling to 16 GPUs in a Single Machine
[Figure: time per batch (sec) vs. number of GPUs (0–16) for batch sizes per GPU of 2, 4, 8, and 16, compared with the communication cost; communication dominates, 15x speedup]
Compare to an L-BFGS Based System
[Figure: objective vs. time (0.1–10 hours, log scale) and computing vs. waiting time (hours) for System A and the Parameter Server]
Sections not Covered
✧ AdaDelay: model the actual delay for asynchronous SGD
✧ Parsa: data placement to reduce communication cost
✧ Difacto: large-scale factorization machines