TRANSCRIPT
Adasum: Adaptive Summation of Gradients for Deep Learning
Saeed Maleki, Madan Musuvathi, Todd Mytkowicz, Olli Saarikivi
Microsoft Research
Emad Barsoum, Sergii Dymchenko, Jaliya Ekanayake, Vadim Eksarevskiy, Maxim Lukiyanov, Tianju Xu
Microsoft
[Slide figure: training data shown as the character stream "THIS IS AN ELEPHANT REALLY!", with the loss Loss(w, x) as a sum of per-item losses L_i.]
Motivation: Neural Network Training is Inherently Sequential
• Stochastic Gradient Descent (SGD)
  • Workhorse for training a neural network
  • Inherently sequential
• Parallel runs do not achieve the same "per-epoch accuracy" as sequential training
Motivation: More parallelism = less accuracy
Adasum addresses this
Background: SGD and Gradients
[Slide figure: the loss Loss(w, x) = Σ_i L_i as a sum over data items, with the data split into partitions D1, D2, D3.]
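For reference, here is a minimal sequential SGD loop over per-example losses L_i, matching the background above; the loss, data, and names are illustrative only and not from the talk.

import numpy as np

def sgd(w0, data, grad_fn, lr=0.01):
    """Sequential SGD: each step depends on the weights produced by the previous step."""
    w = w0.copy()
    for x in data:            # inherently sequential: step i needs w from step i-1
        g = grad_fn(w, x)     # gradient of the per-example loss L_i at the current w
        w = w - lr * g        # SGD update
    return w

# Illustrative example: least-squares loss L_i(w) = 0.5 * (w . x_i - y_i)^2 on toy data.
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
y = X @ np.array([1.0, -2.0, 0.5])
grad = lambda w, i: (X[i] @ w - y[i]) * X[i]
print(sgd(np.zeros(3), range(len(X)), grad, lr=0.1))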
[Slide figure: two SGD trajectories starting from w0 on the loss surface: one over D1 producing the update Δw and reaching w0 + Δw, the other over D2 stepping through w1, w2 to the local model w3.]
Taylor expansion of SGD around w0:
SGD(w0 + Δw) = SGD(w0) + SGD′(w0)·Δw + O(‖Δw‖²)
• SGD(w0) is the local model w3, so SGD(w0 + Δw) ≈ w3 + f(Δw), where the correction is approximately linear: f(Δw) ≈ M·Δw for a matrix M
• M = I − αH, where H is the Hessian matrix and α is the learning rate
• Very costly to calculate H
Adasum: Symbolic SGD
SGD(w0 + Δw) ≈ w3 + (I − αH)·Δw = w3 + Δw − αH·Δw
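As a sanity check of the expansion above, here is a small numerical sketch (not from the talk) on a quadratic loss, where the Hessian H is constant and the first-order expansion is exact up to floating-point error; all names and values are illustrative.

import numpy as np

rng = np.random.default_rng(1)
A = rng.normal(size=(3, 3))
H = A @ A.T + np.eye(3)            # Hessian of the quadratic loss 0.5 * w^T H w
grad = lambda w: H @ w             # its gradient
alpha = 0.1                        # learning rate

def sgd_step(w):
    return w - alpha * grad(w)     # one SGD step

w0 = rng.normal(size=3)
dw = 0.01 * rng.normal(size=3)

lhs = sgd_step(w0 + dw)                              # SGD applied at the shifted point
rhs = sgd_step(w0) + (np.eye(3) - alpha * H) @ dw    # SGD(w0) + (I - alpha*H) * dw
print(np.allclose(lhs, rhs))                         # True: exact for a quadratic loss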
Adasum: Symbolic SGD
SGD(w0 + Δw) ≈ w3 + Δw − αH·Δw
• The αH·Δw term is usually ignored
• Depending on the direction of Δw, moving to w3 + Δw this term is sometimes ok to ignore and sometimes not ok to ignore
• Adasum instead uses Δw′, the projection of Δw onto the space orthogonal to the other gradient; moving to w3 + Δw′, the term is now ok to ignore
[Slide figure: loss-surface diagrams over D1 and D2 contrasting w3 + Δw with w3 + Δw′.]
• Is Δw′ a "good" direction to move along?
• Δw is the gradient w.r.t. the green bowl
• Any direction that has a positive inner product with the gradients decays the loss
• Δw′ is a "good" direction
• The Adasum operator sums Δw with the projection onto the space orthogonal to the orange gradient
Adasum: Adaptive Sum
• Adasum combines Δw from any number of processors
• Adasum combines Δw from different processors by projection and summation (see the sketch after this list). Effectively:
  • they are added when they are orthogonal
  • only one is taken when they are parallel
• Traditionally, Δw from different processors are:
  • either summed: can be too aggressive
  • or averaged: can be too conservative
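The talk does not show the combination rule itself; the following is a minimal sketch of a pairwise adaptive-sum rule with the two properties listed above (orthogonal gradients add, parallel gradients are not double-counted). The exact coefficients are my reading of the Adasum rule and should be treated as an assumption, not as Horovod's implementation.

import numpy as np

def adasum_pair(g1: np.ndarray, g2: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Combine two gradients adaptively: sum when orthogonal, keep one copy when parallel."""
    dot = np.dot(g1, g2)
    # Scale each gradient down by half of its alignment with the other one.
    c1 = 1.0 - dot / (2.0 * np.dot(g1, g1) + eps)
    c2 = 1.0 - dot / (2.0 * np.dot(g2, g2) + eps)
    return c1 * g1 + c2 * g2

# Orthogonal gradients are simply summed ...
print(adasum_pair(np.array([1.0, 0.0]), np.array([0.0, 1.0])))  # -> [1. 1.]
# ... while identical (parallel) gradients yield a single copy, not double.
print(adasum_pair(np.array([1.0, 1.0]), np.array([1.0, 1.0])))  # -> [1. 1.]

For more than two processors, such a pairwise rule can be applied recursively in a reduction tree, which is presumably how "any number of processors" is handled.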
Orthogonality of Gradients
• We use the Pythagorean theorem to define an orthogonality measure (see the sketch below)
• For P gradients it ranges between:
  • 1 for all orthogonal
  • 1/P for all parallel
• Gradients start out all parallel
• Later in the training they become more orthogonal
• Convergence starts out slow but speeds up later
[Plot: gradient orthogonality during training, measured on BERT with 64 GPUs.]
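The transcript does not spell out the exact orthogonality measure; one Pythagorean-style ratio with the stated range (1 when all P gradients are mutually orthogonal, 1/P when they are all parallel) is Σ_i ‖g_i‖² / ‖Σ_i g_i‖², sketched below as an assumption.

import numpy as np

def orthogonality_ratio(grads):
    """Pythagorean-style measure: 1 if all gradients are mutually orthogonal,
    1/P if all P gradients are identical (fully parallel). Assumed definition."""
    total = np.sum(grads, axis=0)
    return float(sum(np.dot(g, g) for g in grads) / np.dot(total, total))

P = 4
orthogonal = [np.eye(P)[i] for i in range(P)]   # mutually orthogonal unit vectors
parallel = [np.ones(P) for _ in range(P)]       # identical gradients
print(orthogonality_ratio(orthogonal))  # 1.0
print(orthogonality_ratio(parallel))    # 0.25 == 1/P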
Adasum is in Horovod
• Horovod is an open-source distributed training framework by Uber that supports both PyTorch and TensorFlow
• Adasum is integrated in Horovod
• Adasum is easy to use (see the sketch after this list):
  • hvd.allreduce(gradients, op=hvd.Adasum)
• No new hyperparameters
• Adasum allows scaling SGD
  • Minimizes convergence slowdown at scale
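A minimal PyTorch sketch of enabling Adasum through Horovod's DistributedOptimizer; the model, optimizer, and hyperparameters are placeholders, and the exact API surface should be checked against the Horovod version that ships Adasum support.

import torch
import horovod.torch as hvd

hvd.init()
if torch.cuda.is_available():
    torch.cuda.set_device(hvd.local_rank())

model = torch.nn.Linear(10, 1)                          # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Wrap the optimizer so gradients are reduced across workers with the Adasum operation.
optimizer = hvd.DistributedOptimizer(
    optimizer,
    named_parameters=model.named_parameters(),
    op=hvd.Adasum,
)

# Keep all workers starting from the same state.
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)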
Result – Adasum vs. Allreduce/Averaging Latency on 64 GPUs
• Adasum has some computation overhead
  • But no communication overhead
• Almost negligible overhead
  • Communication latency dominates the computation latency
[Plot: allreduce latency comparison on 64 GPUs.]
Results – Adasum Convergence on MNIST
• Standard 2-layer CNN gets 99.3% in 2 epochs sequentially with batch size 32
• Tuned LR for averaging and Adasum
• Default Horovod fails at 16 GPUs, while Adasum still works with 32
• With tuned learning rates Adasum has much better convergence for 16 and 32 GPUs
[Chart: test accuracy at 2 epochs on 4, 8, 16, and 32 GPUs (y-axis 97.8% to 99.3%); series: Adasum (tuned LR), Adasum, Horovod (tuned LR), Default Horovod.]
Results – Adasum Convergence on word2vec
• word2vec model: embeddings for words
  • analogy example: London is to England as Paris is to France
• Runs on 32 nodes
• Tuned LR for averaging
• Adasum matches sequential accuracy
Results – Adasum Convergence on ResNet
• The MLPerf ResNet-50 benchmark on ImageNet
  • 16K batch size
  • 64 GPUs
• Adasum reaches the MLPerf target accuracy of 74.9% in 69 epochs
• Default Horovod never reaches target accuracy in 90 epochs
[Plot: validation accuracy vs. epochs for Adasum and Default Horovod.]
Results – Adasum Convergence on BERT
• BERT is one of the most commonly trained models nowadays
• Bigger batch size = more parallelism
  • 64k and 128k
• LAMB is the optimizer used at large scales
• Adasum beats LAMB with a 128k batch size
Please use Adasum!
• And let us know what you think!
• https://github.com/horovod/