arXiv:1903.06237v3 [cs.LG] 31 Jul 2019
Inefficiency of K-FAC for Large Batch Size Training

Linjian Ma¹*, Gabe Montague¹*, Jiayu Ye¹*, Zhewei Yao¹, Amir Gholami¹, Kurt Keutzer¹, and Michael W. Mahoney¹
¹University of California at Berkeley, {linjian, gabe_montague, yejiayu, zheweiy, amirgh, keutzer, mahoneymw}@berkeley.edu
*Equal contribution. Authors ordered alphabetically.

Abstract. In stochastic optimization, using large batch sizes during training can leverage parallel resources to produce faster wall-clock training times per training epoch. However, for both training loss and testing error, recent results analyzing large batch Stochastic Gradient Descent (SGD) have found sharp diminishing returns beyond a certain critical batch size. In the hopes of addressing this, it has been suggested that the Kronecker-Factored Approximate Curvature (K-FAC) method allows for greater scalability to large batch sizes for non-convex machine learning problems such as neural network optimization, as well as greater robustness to variation in model hyperparameters. Here, we perform a detailed empirical analysis of large batch size training for both K-FAC and SGD, evaluating performance in terms of both wall-clock time and aggregate computational cost. Our main results are twofold: first, we find that neither K-FAC nor SGD exhibits ideal scalability behavior beyond a certain batch size, and that K-FAC does not exhibit improved large-batch scalability behavior as compared to SGD; and second, we find that K-FAC, in addition to requiring more hyperparameters to tune, suffers from hyperparameter sensitivity similar to that of SGD. We discuss extensive results using ResNet and AlexNet on CIFAR-10 and SVHN, respectively, as well as more general implications of our findings.

1. Introduction. As the boundaries of parallelism are pushed by modern hardware and distributed systems, researchers are increasingly turning their attention toward leveraging these advances for faster training of deep neural networks (DNNs). When using the prevailing Stochastic Gradient Descent (SGD) method, a batch of training data is split across computational processing units, which together compute a stochastic gradient used to update the parameters of the DNN.

To allow for efficient parallel scalability to a large number of processors, one would like to use a large batch of training data to compute the stochastic gradient estimate [1]. However, using a large batch size changes the dynamics of the training. It has been demonstrated both theoretically [2, 3, 4] and empirically [5, 6, 7, 8, 9] that, in many cases, training with large-batch SGD comes with significant drawbacks. These include degraded testing performance, worse implicit regularization, and diminishing returns in terms of training loss reduction. Among other things, there exists a critical batch size beyond which these effects are most acute. For practitioners operating on a data-driven computational budget, large batch size comes with the additional inconvenience of increased sensitivity to hyperparameters and thus increased tuning time and cost [8]. In attempts to mitigate these shortcomings, a number of solutions have been proposed and demonstrated to exhibit varying degrees of effectiveness [10, 11, 12, 13, 14, 15, 16, 17].

An important practical consideration in methods such as [10, 11, 13, 14, 15, 17] is the sensitivity to hyperparameters. Generally, this sensitivity is quite strong, and the required tuning process is expensive in terms of both analyst time and in-search training time. If performing large batch training requires significant hyperparameter tuning for each batch size, then one would not achieve any effective speedup in total training time (i.e., hyperparameter tuning time plus final training time). With the exception of the recent work of [12], such considerations are largely ignored in the proposed solutions.

Recently, it has been suggested [8, 7, 18] that the (very) approximate second-order optimization method known as Kronecker-Factored Approximate Curvature (K-FAC) [19] may help to alleviate the issues of large-batch data inefficiency and generalization degradation exhibited by the first-order SGD method. K-FAC views the parameter space as a manifold of distribution space, in which distance between parameter vectors is measured by a variant of the Kullback–Leibler divergence between their corresponding distributions. In certain circumstances, K-FAC has been demonstrated to be comparably effective to SGD with large batch sizes [18, 14], but the effects of batch size on training under K-FAC remain largely unstudied.

In this work, we investigate these issues, evaluating the hypothesis that K-FAC is capable of alleviating the difficulties of large-batch SGD. In particular, we focus on the following two questions regarding K-FAC and large batch training:
• What is the scalability behavior of K-FAC, and how does it compare with that of SGD?
• How does increasing batch size affect the hyperparameter sensitivity of K-FAC?
To answer these questions, we conduct a comprehensive investigation in the context of image classification on CIFAR-10 [20] and SVHN [21]. On CIFAR-10, we evaluate Residual Network (ResNet20 and ResNet32) classifiers [22], and on SVHN, we evaluate an AlexNet classifier [23]. We investigate the problem of large-batch diminishing returns by measuring iteration speedup and comparing it to an ideal
scaling scenario. Our key observations are as follows:
• Performance. Even with extensive hyperparameter tuning, K-FAC has comparable, but not superior, train/test performance to SGD (Fig. 1).
• Speedup. Both K-FAC and SGD have diminishing returns with the increase of batch size. Increasing batch size for K-FAC yields lower, i.e., less prominent, speedup, as compared with SGD, when measured in terms of iterations (Fig. 2).
• Hyperparameter Sensitivity. K-FAC hyperparameter sensitivity depends on both batch size and epochs/iterations. For fixed epochs, i.e., running the same number of epochs, larger batch sizes result in greater hyperparameter sensitivity and smaller regions of hyperparameter space which result in “good convergence.” For fixed iterations, i.e., running the same number of iterations, larger batch sizes result in less sensitivity and larger regions of hyperparameter space which result in “good convergence” (Fig. 3, Fig. 4).

We start with mathematical background and related work in Section 2, followed by a description of our experimental setup in Section 3. Our empirical results demonstrating the inefficiencies of both K-FAC and SGD with large batch sizes appear in Section 4. Our conclusions are in Section 5. Additional material is presented in the Appendix.

2. Background and Related Work. For a supervised learning framework, the goal is to minimize a loss function expressed as

$$L(\theta) = \frac{1}{N} \sum_{i=1}^{N} l(x_i, y_i, \theta), \tag{2.1}$$

where $\theta \in \mathbb{R}^d$ is the vector of model parameters, and $l(x, y, \theta)$ is the loss for a datum $(x, y) \in (X, Y)$. Here, $X$ is the input, $Y$ is the corresponding label, and $N = |X|$ is the cardinality of the training set. SGD is typically used to optimize the loss by taking steps of the form:

$$\theta_{t+1} = \theta_t - \eta_t \frac{1}{|B|} \sum_{(x, y) \in B} \nabla_\theta l(x, y, \theta_t), \tag{2.2}$$

where $B$ is a mini-batch of examples drawn randomly from $X \times Y$, and $\eta_t$ is the learning rate at iteration $t$.
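As a concrete reading of Eq. (2.2), the following minimal Python sketch performs one mini-batch update on a toy two-parameter model. The function names and the squared-error toy loss are assumptions made for illustration only; they are not taken from the experiments in this paper.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
Example = Tuple[float, float]  # (x, y) pair; scalar inputs keep the sketch small


def sgd_step(
    theta: Vector,
    batch: Sequence[Example],
    grad_fn: Callable[[float, float, Vector], Vector],
    lr: float,
) -> Vector:
    """One update of Eq. (2.2): theta minus lr times the mean per-example gradient."""
    avg_grad = [0.0] * len(theta)
    for x, y in batch:
        g = grad_fn(x, y, theta)
        for j, g_j in enumerate(g):
            avg_grad[j] += g_j / len(batch)
    return [t - lr * g for t, g in zip(theta, avg_grad)]


def squared_error_grad(x: float, y: float, theta: Vector) -> Vector:
    """Gradient of the hypothetical toy loss 0.5 * (theta[0] * x + theta[1] - y) ** 2."""
    residual = theta[0] * x + theta[1] - y
    return [residual * x, residual]


# Two-example mini-batch, starting from theta = (0, 0) with learning rate 0.1.
theta = sgd_step([0.0, 0.0], [(1.0, 2.0), (3.0, 4.0)], squared_error_grad, lr=0.1)
```

Here the averaged gradient is (-7, -3), so the update moves the parameters to approximately (0.7, 0.3). Increasing $|B|$ only changes how many per-example gradients enter the average, which is why the batch can be split across processors.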

2.1. Kronecker-Factored Approximate Curvature. As opposed to SGD, which treats model parameter space as Euclidean, natural gradient descent methods [24] for DNN optimization operate in the space of distributions defined by the model, in which the parameter distance between two vectors is defined using the KL-divergence between the two corresponding distributions. Denoting $D_{KL}$ as our vector norm in this space, it can be shown that $D_{KL}(\Delta\theta) \approx \frac{1}{2} \Delta\theta^\top F \Delta\theta$, where $F$ is the Fisher Information Matrix (FIM) defined as:

$$F = \mathbb{E}\left[\nabla_\theta \log p(y|x, \theta) \, \nabla_\theta \log p(y|x, \theta)^\top\right], \tag{2.3}$$

where the expectation is taken over both the model’s training data space and target variable space [19]. The update rule for natural gradient descent then becomes $\theta_{t+1} = \theta_t - \eta_t F^{-1} \nabla_\theta \ell(\theta_t)$. As noted by [14] and others, the FIM is often poorly conditioned for DNNs, leading to unstable training. To counter this effect, a damping term is often added (i.e., the FIM is preconditioned). Using the preconditioned FIM, the update rule then becomes:

$$\theta_{t+1} = \theta_t - \eta_t (F + \lambda I)^{-1} \nabla_\theta \ell(\theta_t), \tag{2.4}$$

with $\lambda$ denoting a positive damping parameter. Due to the computational intractability of the true Fisher matrix, natural gradient methods typically rely on approximations to $F$. For example, [19] proposes an approximation, the K-FAC method, exploiting the assumptions that (i) $F$ is largely block-diagonal¹ and (ii) across the training distribution, the products of unit activations and products of unit output derivatives are statistically independent. While these assumptions are inexact, their accuracy has been empirically verified by the authors in several cases. The approximation can be written as:

$$F_i = \mathbb{E}\left[A_{i-1} A_{i-1}^\top \otimes G_i G_i^\top\right] \approx \mathbb{E}\left[A_{i-1} A_{i-1}^\top\right] \otimes \mathbb{E}\left[G_i G_i^\top\right], \tag{2.5}$$

where $F_i$ represents the Fisher matrix of the $i$-th layer, $G_i$ is the gradient of the loss with respect to the $i$-th layer output before the non-linear activation function, and $A_{i-1}$ is the activation output of the previous layer. Note that this is the approximation form we use in our implementation.
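To indicate why the factored form in Eq. (2.5) is computationally attractive, the following short derivation applies two standard Kronecker-product identities. The symbols $\bar{A}$, $\bar{G}$, and $\nabla W_i$ are notation introduced here for illustration; the sketch assumes both factors are invertible and leaves out damping.

```latex
% Notation introduced for this sketch (not from the source):
%   \bar{A} = E[A_{i-1} A_{i-1}^\top]  (d_in x d_in),
%   \bar{G} = E[G_i G_i^\top]          (d_out x d_out),
%   \nabla W_i is the d_out x d_in gradient of the loss w.r.t. the layer-i weights,
%   vec(.) stacks the columns of a matrix into one vector.
\begin{align*}
F_i &\approx \bar{A} \otimes \bar{G}, \\
(\bar{A} \otimes \bar{G})^{-1} &= \bar{A}^{-1} \otimes \bar{G}^{-1}, \\
(\bar{A} \otimes \bar{G})^{-1} \, \mathrm{vec}(\nabla W_i)
  &= \mathrm{vec}\!\left( \bar{G}^{-1} \, \nabla W_i \, \bar{A}^{-1} \right).
\end{align*}
```

Under these assumptions, the preconditioned gradient of a layer needs only the inverses of one $d_{in} \times d_{in}$ matrix and one $d_{out} \times d_{out}$ matrix, rather than the inverse of a single $(d_{in} d_{out}) \times (d_{in} d_{out})$ matrix.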

¹ Although the K-FAC authors propose an alternative tridiagonal approximation that eases the strength of this assumption, we consider their block diagonal approximation of the Fisher, due to its demonstrated performance.

2.2. Difficulties of Large Batch Training. The problems of large batch training under SGD have been studied in detail through both analytical and empirical studies. [2] proves that for convex cases, increasing batch size by a factor $f$ yields no worse than a factor $f$ speedup in the number of SGD iterations, so long as batch size is below a critical point. The batch sizes falling below this critical point are referred to collectively as the linear scaling regime. [6] empirically investigates this in the context of non-convex training of DNNs for a variety of training workloads; it finds evidence of a similar critical batch size for the non-convex case, before which $f$-fold increases of batch size yield $f$-fold reductions in total iterations needed to converge, and after which diminishing returns are observed, eventually leading to stagnation and no further benefit.

Subsequent to [6], [8, 7] obtain broadly similar conclusions with more detailed studies. In particular, [7] goes further to predict the critical batch size to the nearest order of magnitude, demonstrating that critical batch size can be predicted from the gradient noise scale, representing the noise-to-signal ratio of the stochastic estimation of the gradient. The authors further find that gradient noise scale increases during the course of training. This principle motivates the success of techniques as in [11, 25, 12], in which batch size is adaptively increased during training.

Apart from increasing batch size during training, effort has been undertaken to increase critical batch size and linear scaling throughout the entire training process. [10] attempts to improve SGD scalability by tuning hyperparameters more carefully using a linear batch-size to learning-rate relationship. While this proves effective for the authors’ training setup, [6] demonstrates that for a wide variety of other training workloads a linear scaling rule is ineffective at countering the inefficiencies of large batches.

[8] proposes that K-FAC may help to extend the linear scaling regime and increase the critical batch size further than exhibited with SGD, allowing for greater scalability with large batches. Recent work has applied K-FAC to large-batch training settings, as in training ResNet50 on ImageNet within 35 epochs [14], along with the development of a large-batch parallelized K-FAC implementation [18]. Both provide discussion of this, and both demonstrate “near-perfect” scaling behavior for large-batch K-FAC. [18] goes further in depth to suggest that K-FAC scalability is superior to SGD for high training losses and low batch sizes, for a single training workload. Our work more formally investigates the scalability of K-FAC versus SGD, and it finds evidence to the contrary; that is, for the workload and batch sizes we consider, K-FAC scalability is no better than that of SGD.

3. Experimental Setup. We investigate the performance of both K-FAC and SGD on CIFAR-10 with ResNet20 and ResNet32, and on SVHN with AlexNet. To be comparable with state-of-the-art results for K-FAC [26], we apply batch normalization to the models along with standard data augmentation during the training process. We further regularize with a weight decay parameter of $5 \times 10^{-4}$. We perform extensive hyperparameter tuning individually for each batch size ranging from 128 to 16,384.

3.1. Training Budget and Learning Rate Schedule. In many scenarios, training within a hyperparameter search is stopped after some fixed amount of time. When this time is specified in terms of number of epochs, we call this a (normal) epoch budget. For our training of K-FAC and SGD, we use a modified version of this stopping rule which we refer to as an adjusted epoch budget, in which the epoch limit of training is extended proportionally to the log of the batch size. Specifically, we use the rule: for CIFAR-10, the number of training epochs equals $(\log_2(\text{batch size}/128) + 1) \times 100$; for SVHN, it equals $(\log_2(\text{batch size}/128) + 1) \times 20$. This adjusted schedule allows larger-batch training runs more of a chance to converge by affording them a greater number of iterations than would normally be allowed under a traditional epoch budget.

For experiments on CIFAR-10, we decay the learning rate twice by a factor of ten for each run over the course of training. These two learning rate decays separate the training process into three stages. Because training extends to a greater number of epochs for large batches under the adjusted epoch budget, for large batch runs we allow a proportionally greater number of epochs to pass before learning rate decay. For each run we therefore decay the learning rate at 40% and 80% of the total epochs². We refer to this decay scheme as a scaled learning rate schedule. In Appendix A.1, we empirically validate our reasoning that our scaled learning rate schedule (as opposed to a fixed learning rate schedule³) helps large-batch performance. Similarly, for SVHN, we also choose the scaled learning rate schedule and decay the learning rate by a factor of 5 at 50% of the total epochs.
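The adjusted epoch budget and the scaled learning rate schedule described above can be sketched in a few lines of Python. The function names are hypothetical, and rounding the decay points to whole epochs is an assumption of this sketch rather than a detail stated in the text.

```python
import math
from typing import List, Sequence


def adjusted_epoch_budget(batch_size: int, base_epochs: int, base_batch: int = 128) -> int:
    """Epoch limit (log2(batch_size / 128) + 1) * base_epochs from Section 3.1."""
    return int(round((math.log2(batch_size / base_batch) + 1) * base_epochs))


def scaled_decay_epochs(total_epochs: int, fractions: Sequence[float] = (0.4, 0.8)) -> List[int]:
    """Epochs at which the learning rate is decayed under the scaled schedule.

    Assumption: decay points are rounded to the nearest whole epoch.
    """
    return [int(round(f * total_epochs)) for f in fractions]


# CIFAR-10 uses base_epochs=100; SVHN uses base_epochs=20 with a single decay at 50%.
cifar_budget_small = adjusted_epoch_budget(128, 100)
cifar_budget_large = adjusted_epoch_budget(16384, 100)
cifar_decays_large = scaled_decay_epochs(cifar_budget_large)
svhn_budget_large = adjusted_epoch_budget(16384, 20)
svhn_decays_large = scaled_decay_epochs(svhn_budget_large, fractions=(0.5,))
```

For the smallest batch size (128) the CIFAR-10 rule gives 100 epochs with decays at epochs 40 and 80; for the largest (16,384) it gives 800 epochs with decays at epochs 320 and 640.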

² This schedule can loosely be regarded as a mixture of an epoch-driven schedule, as in [27], and an iteration-based schedule, as in [22].

³ For our investigated fixed learning rate schedule, the two learning rate decays happen at epochs 40 and 80, regardless of the total number of epochs. A similar schedule is used in [27].

3.2. Hyperparameter Tuning. Training with either K-FAC or SGD requires choosing hyperparameters; we describe our tuning procedure here.

K-FAC. For K-FAC, we use the various techniques discussed in [18, 14]. We precondition the Fisher matrix based on Eqn. (2.4) according to the methodology presented in [28, Appendix A.2]. Although [19] argues in favor of an alternative damping mechanism that approximates the damping of Eqn. (2.4), we observed comparable performance using the normal approach. The details of these two methods and our comparison between them are laid out in the Appendix.

For hyperparameter tuning of our CIFAR-10 experiments, we conduct a log-space grid search over 64 configurations with learning rates ranging from $10^{-3}$ to 2.187, and with damping ranging from $10^{-4}$ to 0.2187. Similarly, for our SVHN experiments, we conduct a log-space grid search over 64 configurations with learning rates ranging from $10^{-5}$ to 0.02187, and with damping ranging from $10^{-4}$ to 0.2187. The decay rate for second-order statistics is held constant at 0.9 throughout training. We use update clipping as in [18], with a constant parameter of 0.1.
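One way to reproduce a grid with these endpoints is sketched below for the CIFAR-10 K-FAC search. The spacing is an assumption: the text states only the endpoints and the total of 64 configurations, and a factor of 3 between neighboring values happens to match both ranges exactly, since $10^{-3} \times 3^7 = 2.187$ and $10^{-4} \times 3^7 = 0.2187$. The helper name is hypothetical.

```python
from itertools import product
from typing import List, Tuple


def geometric_grid(start: float, ratio: float, count: int) -> List[float]:
    """Return `count` values start, start * ratio, start * ratio**2, ..."""
    return [start * ratio ** i for i in range(count)]


# Assumption: 8 values per axis spaced by a factor of 3, which reproduces the
# stated endpoints (1e-3 * 3**7 = 2.187 and 1e-4 * 3**7 = 0.2187).
learning_rates = geometric_grid(1e-3, 3.0, 8)
dampings = geometric_grid(1e-4, 3.0, 8)

# Every (learning rate, damping) pair is one configuration: 8 * 8 = 64 in total.
kfac_grid: List[Tuple[float, float]] = list(product(learning_rates, dampings))
```

Each of the 64 pairs would then be trained separately for every batch size, which is what makes the tuning cost grow with the number of batch sizes studied.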

SGD. To ensure a fair comparison between methods, we employ a similarly extensive hyperparameter tuning process for SGD. In particular, we conduct a similar log-space grid search over 64 hyperparameter configurations. For CIFAR-10 experiments, learning rates range from 0.05 to 9.62, and momentum ranges from 0.9 to 0.999. For SVHN experiments, learning rates range from 0.005 to 0.962, and momentum ranges from 0.9 to 0.999.

3.3. Speedup Ratio. We use the speedup ratio [6] to measure the efficiency of large batch training based on iterations. We define the convergence rate $k_c(m)$ as the fewest number of iterations needed to reach a certain criterion $c$ under the batch size $m$, where $c$ is defined as attaining a target accuracy or loss threshold. Here, $k_c(m)$ is a minimum, as it is picked across all configurations of hyperparameters. We then define the speedup ratio $s_c(m; m_0)$ as $k_c(m_0)/k_c(m)$, in which we rely on some small batch size $m_0$ as our reference for convergence rate when comparing to larger batch sizes $m > m_0$. In an ideal scenario, the batch size has no effect on the performance increase per training observation, so in such cases $s_c(m; m_0) = \frac{m}{m_0}$.
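The definition can be made concrete with a small Python sketch. The iteration counts below are invented purely for illustration and are not results from this paper; the function names are likewise hypothetical.

```python
from typing import Dict


def speedup_ratio(iters_to_target: Dict[int, int], m: int, m0: int) -> float:
    """s_c(m; m0) = k_c(m0) / k_c(m), with k_c the fewest iterations to reach target c."""
    return iters_to_target[m0] / iters_to_target[m]


def ideal_speedup(m: int, m0: int) -> float:
    """Ideal scaling: iterations to the target fall in proportion to batch size."""
    return m / m0


# Hypothetical k_c(m) values keyed by batch size (illustration only, not measured data).
k_c = {128: 8000, 256: 4400, 512: 2600}

observed = speedup_ratio(k_c, m=512, m0=128)  # k_c(128) / k_c(512)
ideal = ideal_speedup(512, 128)               # 512 / 128
efficiency = observed / ideal                 # below 1.0 means diminishing returns
```

With these made-up counts, quadrupling the batch size gives a speedup of roughly 3.08 against an ideal of 4, so the run sits below the ideal scaling line.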

It should be noted that for K-FAC speedup we solely measure the number of iterations and ignore the cost of computing the inversion of the Fisher matrix. The latter can become very expensive, and multiple approaches such as stochastic low-rank approximation and/or inexact iterative solves can be used. However, as we will show, K-FAC speedup is far from ideal, even when ignoring this cost of performing more exact computations.

4. Experimental Results. We perform extensive experiments on both CIFAR-10 and SVHN datasets with both K-FAC and SGD. Section 4.1 compares the training and test performances of K-FAC and SGD resulting from extensive hyperparameter tuning for each batch size. Section 4.2 then discusses the large-batch scaling behaviors of K-FAC and SGD and compares them to the ideal scaling scenario [6, 8]. Finally, Section 4.3 investigates the hyperparameter sensitivity of the K-FAC method.

4.1. Comparing Best Performance of K-FAC and SGD. We run K-FAC and SGD for multiple batch sizes, with stopping conditions determined according to the adjusted epoch budget and the scaled learning rate schedule. The highest test accuracy and the lowest training loss achieved for each batch size are plotted in Fig. 1. We include the detailed training trajectories over time for both SGD and K-FAC in Appendix A.3.

In this training context, K-FAC minimizes training loss most effectively for medium-sized batches (around $2^{10}$). Inspecting the training trajectories, we found that both the smallest ($2^7$) and largest ($2^{14}$) batch sizes needed more epochs/iterations to converge. For SGD, however, training loss is minimized prominently at larger batch sizes. When comparing training trajectories with K-FAC, we found that SGD made much more progress per iteration in reducing loss, allowing it to minimize the objective with a smaller number of updates, as shown in Fig. 1. A more detailed comparison of the per-iteration progress of SGD versus K-FAC can be found in the following section. For CIFAR-10 experiments, the gap between SGD and K-FAC in large-batch training loss is also present in their generalization performance. SGD’s greater efficiency in maximizing per-iteration accuracy allows it to attain a higher level of test performance with the same number of training epochs.

Fig. 1: From top to bottom: best test accuracy / training loss versus batch size for SGD and K-FAC with ResNet20 on CIFAR-10, ResNet32 on CIFAR-10, and AlexNet on SVHN, respectively. Large-batch K-FAC does not achieve higher accuracy or lower losses than large-batch SGD, given the same number of training epochs. [Figure not reproduced. Panels per model: Best Test Accuracy vs. Batch Size and Best Training Loss vs. Batch Size; x-axis: batch size from $2^7$ to $2^{14}$; curves: SGD and K-FAC.]

4.2. Large-Batch Scalability of K-FAC and SGD. Training efficiency was measured for each batch size in terms of iterations to a target training loss or test accuracy ($k_c(m)$). The speedup versus batch size relations are displayed in Fig. 2. Dotted lines denote the ideal scaling relationship between batch size and iterations. We normalize each method-target line independently, dividing by the iterations at the smallest batch size $k(2^7)$ so that each of the dotted ideal lines is aligned in the plots, and we take the reciprocal to obtain the speedup function $s(m; 2^7) = k(2^7)/k(m)$, where $m$ is a given batch size. We also present the iterations versus batch size relations for the same targets in Appendix A.4. To ensure a fair comparison between batch sizes, similar to what is done in [6], we select target loss values as follows:
• We wish to analyze how quickly using different batch sizes reaches a given threshold. However, not all thresholds are feasible, since large batch sizes may never reach a low training loss, whereas small batches may reach it easily. Thus, for speedup comparison purposes, we set thresholds such that all the batch sizes can reach them. We choose a selected run that belongs to the worst-performing batch size and method. This selection is made after loss-based hyperparameter tuning is finished.
• Then, for each training stage of the selected run⁴, we generate a training loss/test accuracy target, using the value that is linearly interpolated at 80% between the loss/accuracy values directly before and directly after the training stage.

We use the resulting target values in Fig. 2. For both K-FAC and SGD, diminishing return effects are present. In all examined cases, K-FAC deviates from ideal scaling (dotted lines) to a greater extent than SGD, as batch size increases. This difference explains why in Fig. 1 SGD is increasingly able to outperform K-FAC for large batches, given a fixed budget. We note that for both SGD and K-FAC, the linear scaling regime is largely nonexistent, particularly for the highest-performance targets; diminishing returns begin immediately from the smallest batch size. Fig. 2 provides evidence contrary to the conjectures of [8], which posit that K-FAC may exhibit a larger regime of perfect scaling than SGD. The per-iteration training trajectories supporting Fig. 2 can be observed in Appendix A.3.
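The target-selection rule in the second bullet above can be read as a simple linear interpolation, sketched below. This is one interpretation under stated assumptions: "directly before" and "directly after" are taken to mean the metric values at the start and end of a training stage, and the numbers used are invented for illustration.

```python
def stage_target(value_before: float, value_after: float, fraction: float = 0.8) -> float:
    """Linearly interpolate a target between the metric values bracketing a training stage.

    fraction=0.8 places the target 80% of the way from the value before the
    stage to the value after it.
    """
    return value_before + fraction * (value_after - value_before)


# Hypothetical stage of the selected run: loss falls from 1.0 to 0.2,
# accuracy rises from 0.5 to 0.9 (illustrative values only).
loss_target = stage_target(1.0, 0.2)
accuracy_target = stage_target(0.5, 0.9)
```

In this made-up case the loss target is about 0.36 and the accuracy target about 0.82. Because the selected run belongs to the worst-performing batch size and method, targets built this way are reachable by every other batch size.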

4.3. Hyperparameter Sensitivity of K-FAC. The hyperparameter tuning spaces on all three modelsfor K-FAC are laid out in Fig. 3, which relates the selected hyperparameters for damping and learning ratewith test accuracy achieved under the adjusted epoch budget. The complete list of heatmaps for K-FAChyperparameter tuning are provided in Appendix A.5. All heatmaps we observe demonstrate a consistenttrend in terms of:

4Learning rate decays separate the training process into stages as defined before.

Page 6: New arXiv:1903.06237v3 [cs.LG] 31 Jul 2019 · 2019. 8. 2. · InefficiencyofK-FACforLargeBatchSizeTraining 5 27 28 29 210 2 122 2 13 24 Batch Size 0.90 0.91 0.92 0.93 0.94 Best Test

Inefficiency of K-FAC for Large Batch Size Training 6

27 28 29 210 211 212 213 214

Batch Size

20

21

22

23

24

25

26

27

Spee

dup

Ratio

ResNet20: Speedup Ratio to Test AccuracyIdealK-FAC Acc. 0.60K-FAC Acc. 0.84K-FAC Acc. 0.89SGD Acc. 0.60SGD Acc. 0.84SGD Acc. 0.89

27 28 29 210 211 212 213 214

Batch Size

20

21

22

23

24

25

26

27

Spee

dup

Ratio

ResNet20: Speedup Ratio to Training LossIdealSGD Loss 0.92SGD Loss 0.28SGD Loss 0.12K-FAC Loss 0.92K-FAC Loss 0.28K-FAC Loss 0.12

27 28 29 210 211 212 213 214

Batch Size

20

21

22

23

24

25

26

27

Spee

dup

Ratio

ResNet32: Speedup Ratio to Test AccuracyIdealK-FAC Acc. 0.64K-FAC Acc. 0.87K-FAC Acc. 0.89SGD Acc. 0.64SGD Acc. 0.87SGD Acc. 0.89

27 28 29 210 211 212 213 214

Batch Size

20

21

22

23

24

25

26

27

Spee

dup

Ratio

ResNet32: Speedup Ratio to Training LossIdealSGD Loss 0.94SGD Loss 0.23SGD Loss 0.08K-FAC Loss 0.94K-FAC Loss 0.23K-FAC Loss 0.08

27 28 29 210 211 212 213 214

Batch Size

20

21

22

23

24

25

26

27

Spee

dup

Ratio

AlexNet: Speedup Ratio to Test AccuracyIdealK-FAC Acc. 0.71K-FAC Acc. 0.91SGD Acc. 0.71SGD Acc. 0.91

27 28 29 210 211 212 213 214

Batch Size

20

21

22

23

24

25

26

27

Spee

dup

Ratio

AlexNet: Speedup Ratio to Training LossIdealSGD Loss 0.61SGD Loss 0.12K-FAC Loss 0.61K-FAC Loss 0.12

Fig. 2: From top to bottom: speed up to a target training loss / test accuracy versus batch size for both SGDand K-FAC with ResNet20 on CIFAR-10, ResNet32 on CIFAR-10 and AlexNet on SVHN, respectively. Thediminishing returns effect can be seen to be more prominent in K-FAC (circles) than in SGD (triangles).

• A positive correlation between damping and learning rate.
• A shrinking of the high-accuracy region with increasing batch size.

The second point suggests a relationship between batch size and hyperparameter sensitivity for K-FAC, which can be measured in terms of the volume of hyperparameter space corresponding to successful training. In evaluating hyperparameter sensitivity (or, inversely, robustness), we take the approach of [8], distinguishing between two types of robustness, each corresponding to a different definition of “successful training”:
• Epoch-based robustness, in which success is defined by training to a desired accuracy or loss within a fixed number of epochs.
• Iteration-based robustness, in which success is defined by training to a desired accuracy or loss within a fixed number of iterations.

It is important to note that a set of hyperparameters considered acceptable by a practitioner under an iteration budget may at the same time be considered unacceptable by a practitioner operating under an epoch budget. It is for this reason that we make this distinction.
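One way to make the "volume of hyperparameter space corresponding to successful training" concrete is to count the fraction of grid points whose final accuracy clears a target. The sketch below does this for two of the ResNet20 grids in Fig. 3 (batch sizes 128 and 16,384), using the accuracy values annotated on the heatmaps. The helper name `success_fraction` and the target of 0.8 are illustrative assumptions, not the authors' code, and the row order (damping from 0.0001 to 0.2187) is read off the axis labels.

```python
import numpy as np

def success_fraction(acc_grid, target):
    """Fraction of hyperparameter configurations whose accuracy reaches `target`."""
    acc_grid = np.asarray(acc_grid, dtype=float)
    return float(np.mean(acc_grid >= target))

# Rows: damping 0.0001 ... 0.2187; columns: learning rate 0.001 ... 2.187.
# Values transcribed from the Fig. 3 heatmap annotations (ResNet20, CIFAR-10).
resnet20_bs128 = [
    [0.89, 0.90, 0.90, 0.88, 0.83, 0.61, 0.35, 0.30],
    [0.90, 0.90, 0.89, 0.88, 0.85, 0.58, 0.33, 0.23],
    [0.90, 0.91, 0.90, 0.90, 0.85, 0.66, 0.53, 0.31],
    [0.89, 0.90, 0.91, 0.90, 0.88, 0.82, 0.75, 0.66],
    [0.88, 0.90, 0.91, 0.91, 0.90, 0.89, 0.84, 0.80],
    [0.84, 0.89, 0.90, 0.91, 0.92, 0.90, 0.89, 0.85],
    [0.77, 0.85, 0.90, 0.91, 0.92, 0.92, 0.91, 0.88],
    [0.65, 0.78, 0.87, 0.89, 0.91, 0.92, 0.92, 0.89],
]
resnet20_bs16384 = [
    [0.82, 0.80, 0.70, 0.80, 0.65, 0.45, 0.36, 0.14],
    [0.82, 0.81, 0.80, 0.69, 0.68, 0.43, 0.36, 0.19],
    [0.82, 0.82, 0.78, 0.69, 0.76, 0.59, 0.53, 0.25],
    [0.79, 0.82, 0.80, 0.81, 0.75, 0.75, 0.51, 0.55],
    [0.71, 0.80, 0.83, 0.82, 0.83, 0.83, 0.83, 0.84],
    [0.57, 0.72, 0.80, 0.82, 0.83, 0.86, 0.89, 0.89],
    [0.42, 0.58, 0.72, 0.80, 0.84, 0.86, 0.88, 0.90],
    [0.32, 0.42, 0.58, 0.72, 0.80, 0.85, 0.88, 0.90],
]

target = 0.8  # illustrative accuracy threshold
frac_small = success_fraction(resnet20_bs128, target)
frac_large = success_fraction(resnet20_bs16384, target)
print(f"batch size 128:    {frac_small:.2f} of the 64 configurations reach {target}")
print(f"batch size 16,384: {frac_large:.2f} of the 64 configurations reach {target}")
```

A smaller fraction at the larger batch size corresponds to the shrinking high-accuracy region described above. The same function can be applied to accuracies recorded at an intermediate epoch or iteration to obtain the epoch-based and iteration-based variants.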

Through this lens, the robustness behavior of K-FAC with ResNet20 on CIFAR-10 is exhibited in Fig. 4. We describe the robustness behavior of the other two models in Appendix A.6. Distributions of training accuracy across hyperparameters are represented by box plots composed from the 64 hyperparameter configurations (8 damping parameters, 8 learning rate parameters). In Fig. 4a and 4d, we show the distributions of test accuracies and training losses for each batch size at the end of training under the adjusted epoch budget. The greater spread of accuracies observed for larger batch sizes indicates that, given this budget, we should consider batch sizes from $2^{12}$ to $2^{14}$ as more sensitive to hyperparameter tuning than batch sizes from $2^{8}$ to $2^{11}$. Informally, if we draw a horizontal line at a desired test accuracy of, e.g., 0.8, then the batch sizes with box plots containing the majority of their hyperparameter distribution above the 0.8 line should be favored as


[Figure 3: nine heatmaps of test accuracy (color scale 0.0 to 1.0) over damping (rows) and learning rate (columns). The annotated cell values are reproduced below.]

ResNet20: batch size 128
damping  lr=0.001  0.003  0.009  0.027  0.081  0.243  0.729  2.187
0.0001      0.89   0.9    0.9    0.88   0.83   0.61   0.35   0.3
0.0003      0.9    0.9    0.89   0.88   0.85   0.58   0.33   0.23
0.0009      0.9    0.91   0.9    0.9    0.85   0.66   0.53   0.31
0.0027      0.89   0.9    0.91   0.9    0.88   0.82   0.75   0.66
0.0081      0.88   0.9    0.91   0.91   0.9    0.89   0.84   0.8
0.0243      0.84   0.89   0.9    0.91   0.92   0.9    0.89   0.85
0.0729      0.77   0.85   0.9    0.91   0.92   0.92   0.91   0.88
0.2187      0.65   0.78   0.87   0.89   0.91   0.92   0.92   0.89

ResNet20: batch size 1024
damping  lr=0.001  0.003  0.009  0.027  0.081  0.243  0.729  2.187
0.0001      0.91   0.92   0.91   0.91   0.89   0.82   0.1    0.1
0.0003      0.91   0.92   0.91   0.91   0.89   0.83   0.1    0.1
0.0009      0.91   0.92   0.91   0.91   0.9    0.86   0.41   0.1
0.0027      0.9    0.91   0.92   0.92   0.9    0.89   0.85   0.77
0.0081      0.87   0.9    0.91   0.91   0.91   0.91   0.89   0.87
0.0243      0.83   0.87   0.9    0.91   0.92   0.92   0.91   0.9
0.0729      0.72   0.83   0.87   0.9    0.91   0.92   0.93   0.91
0.2187      0.58   0.72   0.83   0.88   0.9    0.92   0.92   0.92

ResNet20: batch size 16384
damping  lr=0.001  0.003  0.009  0.027  0.081  0.243  0.729  2.187
0.0001      0.82   0.8    0.7    0.8    0.65   0.45   0.36   0.14
0.0003      0.82   0.81   0.8    0.69   0.68   0.43   0.36   0.19
0.0009      0.82   0.82   0.78   0.69   0.76   0.59   0.53   0.25
0.0027      0.79   0.82   0.8    0.81   0.75   0.75   0.51   0.55
0.0081      0.71   0.8    0.83   0.82   0.83   0.83   0.83   0.84
0.0243      0.57   0.72   0.8    0.82   0.83   0.86   0.89   0.89
0.0729      0.42   0.58   0.72   0.8    0.84   0.86   0.88   0.9
0.2187      0.32   0.42   0.58   0.72   0.8    0.85   0.88   0.9

ResNet32: batch size 128
damping  lr=0.001  0.003  0.009  0.027  0.081  0.243  0.729  2.187
0.0001      0.9    0.9    0.9    0.9    0.9    0.91   0.93   0.91
0.0003      0.89   0.9    0.91   0.9    0.9    0.91   0.92   0.92
0.0009      0.89   0.9    0.9    0.91   0.9    0.92   0.93   0.92
0.0027      0.88   0.9    0.9    0.9    0.91   0.92   0.93   0.92
0.0081      0.85   0.89   0.9    0.9    0.9    0.92   0.93   0.91
0.0243      0.83   0.88   0.9    0.9    0.91   0.92   0.93   0.92
0.0729      0.74   0.85   0.89   0.9    0.91   0.91   0.93   0.92
0.2187      0.64   0.77   0.85   0.89   0.91   0.92   0.93   0.92

ResNet32: batch size 1024
damping  lr=0.001  0.003  0.009  0.027  0.081  0.243  0.729  2.187
0.0001      0.9    0.9    0.89   0.86   0.8    0.86   0.9    0.89
0.0003      0.9    0.9    0.9    0.88   0.82   0.86   0.91   0.9
0.0009      0.88   0.9    0.9    0.89   0.84   0.88   0.91   0.92
0.0027      0.87   0.89   0.9    0.89   0.87   0.89   0.92   0.93
0.0081      0.84   0.88   0.9    0.9    0.9    0.9    0.92   0.93
0.0243      0.8    0.85   0.89   0.9    0.9    0.9    0.92   0.93
0.0729      0.69   0.81   0.85   0.89   0.9    0.91   0.92   0.93
0.2187      0.58   0.71   0.82   0.86   0.89   0.91   0.92   0.93

ResNet32: batch size 16384
damping  lr=0.001  0.003  0.009  0.027  0.081  0.243  0.729  2.187
0.0001      0.71   0.4    0.27   0.37   0.32   0.32   0.28   0.25
0.0003      0.79   0.64   0.31   0.43   0.39   0.43   0.26   0.24
0.0009      0.82   0.74   0.69   0.32   0.52   0.48   0.45   0.37
0.0027      0.81   0.8    0.72   0.67   0.52   0.65   0.56   0.64
0.0081      0.71   0.78   0.77   0.78   0.72   0.68   0.74   0.78
0.0243      0.59   0.71   0.77   0.79   0.77   0.74   0.83   0.89
0.0729      0.44   0.59   0.7    0.77   0.82   0.82   0.85   0.89
0.2187      0.33   0.44   0.59   0.7    0.78   0.83   0.87   0.9

AlexNet: batch size 128
damping  lr=1e-05  3e-05  9e-05  0.00027  0.00081  0.00243  0.00729  0.02187
0.0001      0.9    0.92   0.91   0.88     0.11     0.19     0.078    0.19
0.0003      0.88   0.92   0.91   0.91     0.89     0.093    0.068    0.077
0.0009      0.84   0.91   0.91   0.91     0.91     0.91     0.2      0.2
0.0027      0.75   0.88   0.91   0.91     0.91     0.92     0.9      0.62
0.0081      0.53   0.84   0.9    0.9      0.9      0.92     0.92     0.92
0.0243      0.25   0.73   0.87   0.9      0.89     0.9      0.92     0.93
0.0729      0.2    0.38   0.81   0.88     0.89     0.89     0.92     0.92
0.2187      0.2    0.2    0.52   0.84     0.89     0.89     0.9      0.92

AlexNet: batch size 1024
damping  lr=1e-05  3e-05  9e-05  0.00027  0.00081  0.00243  0.00729  0.02187
0.0001      0.91   0.062  0.84   0.097    0.064    0.2      0.16     0.087
0.0003      0.89   0.9    0.87   0.85     0.093    0.2      0.16     0.11
0.0009      0.86   0.9    0.89   0.9      0.9      0.88     0.2      0.2
0.0027      0.8    0.88   0.89   0.88     0.9      0.9      0.91     0.9
0.0081      0.58   0.85   0.89   0.88     0.88     0.91     0.92     0.92
0.0243      0.22   0.71   0.86   0.89     0.88     0.89     0.91     0.91
0.0729      0.2    0.27   0.77   0.87     0.89     0.88     0.9      0.91
0.2187      0.2    0.2    0.31   0.8      0.88     0.89     0.89     0.9

AlexNet: batch size 16384
damping  lr=1e-05  3e-05  9e-05  0.00027  0.00081  0.00243  0.00729  0.02187
0.0001      0.89   0.91   0.91   0.91     0.2      0.11     0.064    0.1
0.0003      0.86   0.89   0.91   0.91     0.9      0.23     0.2      0.092
0.0009      0.76   0.87   0.89   0.91     0.91     0.9      0.21     0.2
0.0027      0.3    0.77   0.87   0.89     0.91     0.9      0.89     0.2
0.0081      0.2    0.31   0.77   0.87     0.89     0.91     0.9      0.89
0.0243      0.2    0.2    0.31   0.77     0.87     0.89     0.9      0.91
0.0729      0.18   0.2    0.2    0.31     0.77     0.86     0.89     0.9
0.2187      0.12   0.18   0.2    0.2      0.31     0.77     0.86     0.89

Fig. 3: From top to bottom: accuracy at the end of training under the adjusted epoch budget versus damping and learning rate for batch sizes 128, 1,024, and 16,384 for CIFAR-10 with ResNet20, CIFAR-10 with ResNet32, and SVHN with AlexNet, respectively. A positive correlation between damping and learning rate is exhibited, as well as a shrinking of the high-accuracy region for large batch sizes.

being more robust.

We can simulate stopping of training in terms of epochs and iterations to extract insight about robustness for budgets other than our own. Regardless of the stopping criterion, we expect that longer training will yield greater robustness (although at the cost of significantly higher computational/budget overhead). Fig. 4b and 4e show how the hyperparameter robustness of different batch sizes changes as a function of stopping epoch. Each group along the x-axis corresponds to a hypothetical epoch budget. The relationship demonstrates that for K-FAC, robustness increases with the amount of training but, more interestingly, decreases with batch size. This can be observed by noting that for any fixed epoch, the distributions of accuracies corresponding to larger batch sizes fall lower than their smaller-batch counterparts, meaning fewer hyperparameter configurations will fall above a desired accuracy threshold. A similar robustness trend is observed with the distributions of training losses. We perform a similar analysis for iteration budgets in Fig. 4c and 4f, and we find, as expected, that robustness increases with training. Unlike the case of an epoch budget, however, we find that for iteration budgets robustness increases with larger batch size. This is observed by noting that the distributions corresponding to large batches are more concentrated towards higher accuracy and lower loss, although the effect is not as pronounced as in the fixed-epoch case.

Together, the results show that (i) epoch-based robustness is inversely related to batch size, and (ii) iteration-based robustness is directly related to batch size. This is analogous to the findings of [8] for SGD. Further work may seek to compare the nature of iteration-based or epoch-based robustness between the two methods in more detail.
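The opposite trends under the two budgets are easier to see once the budgets are converted into one another. The sketch below is illustrative bookkeeping rather than part of the authors' setup: it uses the CIFAR-10 training set size of 50,000 images, while the budgets of 100 epochs and 4,000 iterations and the function names are hypothetical.

```python
import math

def iterations_for_epochs(epochs, batch_size, num_train=50_000):
    """Number of optimizer steps taken in `epochs` full passes over the data."""
    steps_per_epoch = math.ceil(num_train / batch_size)
    return epochs * steps_per_epoch

def epochs_for_iterations(iterations, batch_size, num_train=50_000):
    """Passes over the data covered by `iterations` optimizer steps."""
    return iterations * batch_size / num_train

# Hypothetical budgets, chosen only to show the direction of the effect.
epoch_budget = 100
iteration_budget = 4_000
for batch_size in (256, 1_024, 4_096, 16_384):
    steps = iterations_for_epochs(epoch_budget, batch_size)
    passes = epochs_for_iterations(iteration_budget, batch_size)
    print(f"B={batch_size:6d}  steps in {epoch_budget} epochs: {steps:6d}  "
          f"epochs in {iteration_budget} steps: {passes:8.1f}")
```

Under a fixed epoch budget, a larger batch size takes fewer parameter updates, whereas under a fixed iteration budget it processes more data. This is consistent with the direction of the two robustness trends reported above, although it does not by itself explain their magnitude.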

5. Conclusions. Through extensive experimentation and training on both the CIFAR-10 and SVHN datasets, we find that K-FAC exhibits similar diminishing returns with large batch training as SGD. In our results, K-FAC has comparable but not better training or testing performance than SGD given the same level of training and tuning. Comparing the scalability behaviors of the two methods, we find that K-FAC exhibits a


[Figure 4: six box-plot panels for ResNet20 with K-FAC. (a) Test accuracy vs. batch size ($2^{7}$ to $2^{14}$). (b) Test accuracy vs. epoch (100 to 700) for batch sizes 256, 1,024, 4,096, and 16,384. (c) Test accuracy vs. iterations (1.0k to 4.5k) for batch sizes 4,096, 8,192, and 16,384. (d) Training loss vs. batch size ($2^{7}$ to $2^{14}$). (e) Training loss vs. epoch (100 to 700) for batch sizes 256, 1,024, 4,096, and 16,384. (f) Training loss vs. iterations (1.0k to 4.5k) for batch sizes 4,096, 8,192, and 16,384.]

Fig. 4: (a)(d): Test accuracy / training loss distribution vs. batch size for K-FAC at the end of training under an adjusted epoch budget. Larger batch sizes result in lower accuracies and higher training losses that are more sensitive to hyperparameter choice. (b)(e): Comparison of test accuracy / training loss distributions over various epochs. Smaller batch sizes provide better solutions that are less sensitive to the choice of hyperparameters. (c)(f): Comparison of test accuracy / training loss distributions over various iteration numbers. Large batch sizes exhibit a trend of providing better solutions that are less sensitive to the choice of hyperparameters.

smaller regime of ideal scaling than SGD, suggesting that K-FAC’s scalability to large batch training is no better than SGD’s. Finally, we find that K-FAC exhibits a similar relationship between budget and robustness as SGD, in which, at larger batch sizes, K-FAC is less robust to tuning under epoch budgets but more robust to tuning under iteration budgets, mirroring the findings of similar work in the literature for SGD [8]. Taken as a whole, our results suggest that, although K-FAC has been applied to large batch training scenarios, it encounters the same large-batch issues to an equal or greater extent than SGD. It remains to be seen whether other variants of sub-sampled Newton methods [29, 30] can lead to improved results or whether this is a more ubiquitous aspect of stochastic optimization algorithms applied to non-convex optimization problems of interest in machine learning.

Acknowledgments. This work was supported by a gracious fund from Intel Corporation. We would like to thank the Intel VLAB team for providing us with access to their computing cluster. We also gratefully acknowledge the support of NVIDIA Corporation for their donation of the Titan Xp GPU used for this research. MWM would also like to acknowledge ARO, DARPA, NSF, ONR, and Intel for providing partial support of this work.


REFERENCES

[1] A. Gholami, A. Azad, P. Jin, K. Keutzer, and A. Buluc, “Integrated model, batch and domain parallelism in training neural networks,” ACM Symposium on Parallelism in Algorithms and Architectures (SPAA’18), 2018.

[2] S. Ma, R. Bassily, and M. Belkin, “The power of interpolation: Understanding the effectiveness of SGD in modern over-parametrized learning,” arXiv preprint arXiv:1712.06559, 2017.

[3] C. H. Martin and M. W. Mahoney, “Implicit self-regularization in deep neural networks: Evidence from random matrix theory and implications for learning,” Tech. Rep. Preprint: arXiv:1810.01075, 2018.

[4] C. H. Martin and M. W. Mahoney, “Traditional and heavy-tailed self regularization in neural network models,” in Proceedings of the 36th International Conference on Machine Learning, pp. 4284–4293, 2019.

[5] Z. Yao, A. Gholami, Q. Lei, K. Keutzer, and M. W. Mahoney, “Hessian-based analysis of large batch training and robustness to adversaries,” arXiv preprint arXiv:1802.08241, 2018.

[6] N. Golmant, N. Vemuri, Z. Yao, V. Feinberg, A. Gholami, K. Rothauge, M. W. Mahoney, and J. Gonzalez, “On the computational inefficiency of large batch sizes for stochastic gradient descent,” CoRR, vol. abs/1811.12941, 2018.

[7] S. McCandlish, J. Kaplan, D. Amodei, and O. D. Team, “An empirical model of large-batch training,” CoRR, vol. abs/1812.06162, 2018.

[8] C. J. Shallue, J. Lee, J. M. Antognini, J. Sohl-Dickstein, R. Frostig, and G. E. Dahl, “Measuring the effects of data parallelism on neural network training,” CoRR, vol. abs/1811.03600, 2018.

[9] N. S. Keskar, D. Mudigere, J. Nocedal, M. Smelyanskiy, and P. T. P. Tang, “On large-batch training for deep learning: Generalization gap and sharp minima,” arXiv preprint arXiv:1609.04836, 2016.

[10] P. Goyal, P. Dollár, R. B. Girshick, P. Noordhuis, L. Wesolowski, A. Kyrola, A. Tulloch, Y. Jia, and K. He, “Accurate, large minibatch SGD: Training imagenet in 1 hour,” CoRR, vol. abs/1706.02677, 2017.

[11] S. L. Smith, P.-J. Kindermans, C. Ying, and Q. V. Le, “Don’t decay the learning rate, increase the batch size,” arXiv preprint arXiv:1711.00489, 2017.

[12] Z. Yao, A. Gholami, K. Keutzer, and M. W. Mahoney, “Large batch size training of neural networks with adversarial training and second-order information,” arXiv preprint arXiv:1810.01021, 2018.

[13] N. Mu, Z. Yao, A. Gholami, K. Keutzer, and M. W. Mahoney, “Parameter re-initialization through cyclical batch size schedules,” arXiv preprint arXiv:1812.01216, 2018.

[14] K. Osawa, Y. Tsuji, Y. Ueno, A. Naruse, R. Yokota, and S. Matsuoka, “Second-order optimization method for large mini-batch: Training resnet-50 on imagenet in 35 epochs,” arXiv preprint arXiv:1811.12019, 2018.

[15] B. Ginsburg, I. Gitman, and Y. You, “Large batch training of convolutional networks with layer-wise adaptive rate scaling,” 2018.

[16] Y. You, Z. Zhang, C. Hsieh, and J. Demmel, “100-epoch imagenet training with alexnet in 24 minutes,” CoRR, vol. abs/1709.05011, 2017.

[17] X. Jia, S. Song, W. He, Y. Wang, H. Rong, F. Zhou, L. Xie, Z. Guo, Y. Yang, L. Yu, et al., “Highly scalable deep learning training system with mixed-precision: Training imagenet in four minutes,” arXiv preprint arXiv:1807.11205, 2018.

[18] J. Ba, R. Grosse, and J. Martens, “Distributed second-order optimization using Kronecker-factored Approximations,” in International Conference on Learning Representations, 2017.

[19] J. Martens and R. B. Grosse, “Optimizing neural networks with Kronecker-factored Approximate Curvature,” CoRR, vol. abs/1503.05671, 2015.

[20] A. Krizhevsky and G. Hinton, “Learning multiple layers of features from tiny images,” tech. rep., Citeseer, 2009.

[21] Y. Netzer, T. Wang, A. Coates, A. Bissacco, B. Wu, and A. Y. Ng, “Reading digits in natural images with unsupervised feature learning,” 2011.

[22] K. He, X. Zhang, S. Ren, and J. Sun, “Deep residual learning for image recognition,” in Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 770–778, 2016.

[23] A. Krizhevsky, I. Sutskever, and G. E. Hinton, “Imagenet classification with deep convolutional neural networks,” in Advances in neural information processing systems, pp. 1097–1105, 2012.

[24] S.-I. Amari, “Natural gradient works efficiently in learning,” Neural computation, vol. 10, no. 2, pp. 251–276, 1998.

[25] A. Devarakonda, M. Naumov, and M. Garland, “Adabatch: Adaptive batch sizes for training deep neural networks,” arXiv preprint arXiv:1712.02029, 2017.

[26] G. Zhang, C. Wang, B. Xu, and R. Grosse, “Three mechanisms of weight decay regularization,” in International Conference on Learning Representations, 2019.

[27] E. Hoffer, I. Hubara, and D. Soudry, “Train longer, generalize better: Closing the generalization gap in large batch training of neural networks,” in Advances in Neural Information Processing Systems, pp. 1731–1741, 2017.

[28] R. Grosse and J. Martens, “A Kronecker-factored Approximate Fisher matrix for convolution layers,” in International Conference on Machine Learning, pp. 573–582, 2016.

[29] F. Roosta-Khorasani and M. W. Mahoney, “Sub-sampled Newton methods I: Globally convergent algorithms,” tech. rep., 2016. Preprint: arXiv:1601.04737.

[30] F. Roosta-Khorasani and M. W. Mahoney, “Sub-sampled Newton methods II: Local convergence rates,” tech. rep., 2016. Preprint: arXiv:1601.04738.


Appendix A. Appendix. In this appendix, we provide additional details about our main results.

A.1. Scaled Learning Rate Decay Schedule. Here, we compare training trajectories under fixed and scaled learning rate schedules to provide validation of our experimental setup. We evaluate the scaled learning rate schedule discussed in Section 3 against a fixed learning rate schedule using ResNet20 on the CIFAR-10 dataset. Extensive grid search is applied for both schedules, giving rise to the best-performing runs, which are illustrated in Fig. 5.

In Fig. 5a and 5c, we plot the highest-accuracy and lowest-loss training runs (chosen across all hyperparameter configurations) for the fixed learning rate schedule versus the scaled learning rate schedule for SGD. Across all batch sizes, training with the scaled learning rate schedule gives rise to a higher final test accuracy and lower final training loss. For large batch sizes (e.g., 16K), the difference is more pronounced, with fixed-schedule SGD making much slower progress than its scaled-schedule counterpart. In particular, for the 16K batch size with the fixed learning rate schedule, it can be observed that after 80 epochs the early decay severely slows the previously rapid progress, since the learning rate becomes too small. As a result, fixed-schedule large batch SGD slows its climb in accuracy, while the scaled-schedule variant surges ahead.

In Fig. 5b and 5d, we plot the analogous runs for K-FAC, with decays following a fixed versus a scaled learning rate schedule. Similarly, K-FAC with the scaled learning rate schedule demonstrates higher end-of-training accuracy and lower loss across all batch sizes. Based on these observations, we argue that our experimental setup is geared towards boosting large-batch scalability through its use of a scaled schedule for learning rate decay.
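The exact schedules are defined in Section 3 and are not reproduced here. The sketch below only illustrates the general difference between keeping decay epochs fixed and stretching them with the epoch budget. The milestone epochs, budgets, decay factor, and function names are all hypothetical and should not be read as the values used in the experiments.

```python
def lr_at_epoch(epoch, base_lr, milestones, gamma=0.1):
    """Step decay: multiply the learning rate by `gamma` for each milestone already passed."""
    passed = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** passed

def scaled_milestones(fixed_milestones, fixed_budget, scaled_budget):
    """Stretch the decay epochs in proportion to the total epoch budget."""
    ratio = scaled_budget / fixed_budget
    return [round(m * ratio) for m in fixed_milestones]

# Hypothetical numbers for illustration only.
fixed = [80, 120]  # decay epochs for a 160-epoch run
scaled = scaled_milestones(fixed, fixed_budget=160, scaled_budget=640)

for epoch in (50, 100, 400, 600):
    print(f"epoch {epoch:3d}: fixed lr = {lr_at_epoch(epoch, 0.1, fixed):.4f}, "
          f"scaled lr = {lr_at_epoch(epoch, 0.1, scaled):.4f}")
```

With the fixed milestones, a long large-batch run spends most of its budget at an already decayed learning rate, which matches the slowdown described for the 16K batch size; the stretched milestones keep the larger learning rate for a proportionally longer portion of training.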

A.2. Normal versus Approximated Damping. Here, we explain the differences between two preconditioning mechanisms for K-FAC: normal damping, as written in Eqn. (2.4); and approximated damping, as recommended by [19]. We also provide an empirical evaluation of the two methods, and we observe that approximated damping exhibits hyperparameter sensitivity that is comparable to normal damping. First, we briefly discuss how to calculate directly the inverse of the preconditioned Fisher Information Matrix (preconditioned FIM), denoted $(F + \lambda I)^{-1}$. The inverse of the preconditioned FIM block can be calculated as

\[
\begin{aligned}
(F_i + \lambda I)^{-1} &= \left[(Q_A D_A Q_A^\top) \otimes (Q_G D_G Q_G^\top) + \lambda I\right]^{-1} \\
&= \left[(Q_A \otimes Q_G)(D_A \otimes D_G + \lambda I)(Q_A^\top \otimes Q_G^\top)\right]^{-1} \\
&= (Q_A \otimes Q_G)(D_A \otimes D_G + \lambda I)^{-1}(Q_A^\top \otimes Q_G^\top),
\end{aligned}
\]

where $Q_A D_A Q_A^\top$ is the eigendecomposition of $\mathbb{E}[A_{i-1} A_{i-1}^\top]$ and $Q_G D_G Q_G^\top$ is the eigendecomposition of $\mathbb{E}[G_i G_i^\top]$. The last step uses the fact that $Q_A \otimes Q_G$ is orthogonal, so its inverse is its transpose. A similar derivation can be found in Appendix A.2 in [28]. In our paper, we call the use of this damping normal damping.

For efficiency reasons, [19] proposes an alternative calculation to alleviate the burden of eigendecomposition by preconditioning on the Kronecker-factored Fisher blocks first:

\[
(F_i + \lambda I)^{-1} \approx \left(\mathbb{E}[A_{i-1} A_{i-1}^\top] + \sqrt{\lambda}\, I\right)^{-1} \otimes \left(\mathbb{E}[G_i G_i^\top] + \sqrt{\lambda}\, I\right)^{-1},
\]

where $\lambda$ is a compound term that allows for complicated maneuvers of adaptive regularization (detailed in Section 6.2 of [19]). In our paper, we call this damping formulation approximated damping.
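The two formulations can be compared numerically on a small example. The NumPy sketch below uses random positive semi-definite matrices as stand-ins for the two Kronecker factors (an assumption for illustration, not data from the experiments). It checks that inverting through the two small eigendecompositions reproduces a dense inverse of the damped block, and it shows that the approximated form gives a different matrix. The variable names and the damping value are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

def random_psd(n):
    """Random symmetric positive semi-definite matrix standing in for a Kronecker factor."""
    x = rng.standard_normal((n, 2 * n))
    return x @ x.T / (2 * n)

A = random_psd(3)  # stand-in for E[A_{i-1} A_{i-1}^T]
G = random_psd(4)  # stand-in for E[G_i G_i^T]
lam = 1e-2         # damping value, chosen arbitrarily

# Normal damping: invert (A kron G + lam * I) using the two small eigendecompositions.
dA, QA = np.linalg.eigh(A)
dG, QG = np.linalg.eigh(G)
Q = np.kron(QA, QG)   # eigenvectors of A kron G
d = np.kron(dA, dG)   # matching eigenvalues (diagonal of D_A kron D_G)
normal_inv = Q @ np.diag(1.0 / (d + lam)) @ Q.T

# Reference: dense inverse of the damped Fisher block.
F = np.kron(A, G)
dense_inv = np.linalg.inv(F + lam * np.eye(F.shape[0]))

# Approximated damping: damp each factor with sqrt(lam), then take the Kronecker product.
approx_inv = np.kron(
    np.linalg.inv(A + np.sqrt(lam) * np.eye(A.shape[0])),
    np.linalg.inv(G + np.sqrt(lam) * np.eye(G.shape[0])),
)

print("normal vs dense, max abs difference:", np.abs(normal_inv - dense_inv).max())
print("approx vs dense, max abs difference:", np.abs(approx_inv - dense_inv).max())
```

The gap between the approximated and exact inverses comes from the cross terms: expanding $(A + \sqrt{\lambda} I) \otimes (G + \sqrt{\lambda} I)$ gives $A \otimes G + \lambda I$ plus $\sqrt{\lambda}\,(A \otimes I + I \otimes G)$. In exchange, the approximated form requires only inverses of the two small factors rather than their eigendecompositions.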

We compare normal damping and approximated damping for K-FAC with ResNet20 on CIFAR-10, examining their effects for batch sizes 128 and 8,192. In Fig. 6 we compare the accuracy and training loss distributions of the two damping methods at these batch sizes. For both damping methods, training appears to be more sensitive to hyperparameter tuning under large batch (batch size 8,192). More interestingly, we observe that the distributions generated with normal damping tend to be more robust. We can see, for instance, that for batch size 8,192, approximated damping has a high concentration of hyperparameter configurations yielding favorable values of loss and accuracy.

A.3. Training Loss/Test Accuracy Trajectories. Here, we show the detailed training curves for K-FAC and SGD on CIFAR-10 with ResNet20. We plot the training loss and test accuracy over time for the optimal runs of each batch size for both SGD and K-FAC to support Fig. 1 and 2 in the main text. Time is measured in terms of both epochs (Fig. 7) and iterations (Fig. 8).

The best performance of each curve in Fig. 7 helps to explain the shape of Fig. 1, in which we plot the best test accuracy and training loss for each batch size for SGD and K-FAC. The best achieved accuracy can be seen to drop for both K-FAC and SGD beyond a certain batch size. The lowest loss is achieved at higher


[Figure 5: four panels over epochs 0 to 800, each comparing a fixed schedule and a scaling schedule for batch sizes 256, 1,024, 4,096, and 16,384. (a) SGD: test accuracy vs. epoch. (b) K-FAC: test accuracy vs. epoch. (c) SGD: training loss vs. epoch. (d) K-FAC: training loss vs. epoch.]

Fig. 5: Learning rate schedule comparison for ResNet20 on CIFAR-10. (a)(b): Test accuracy versus epoch for SGD and K-FAC on a fixed learning rate schedule (dashed) and scaled learning rate schedule (solid). (c)(d): Training loss versus epoch for SGD and K-FAC on a fixed learning rate schedule (dashed) and scaled learning rate schedule (solid). For each batch size and schedule, other hyperparameters are chosen separately to maximize test accuracy or minimize loss. The scaled learning rate schedule significantly favors large batch sizes.

batch sizes for SGD than for K-FAC.

A.4. Large Batch Scalabilities: Iterations to Target. Here, we present the relationship between iterations and batch size to support our results in Section 4.2. Fig. 8 gives clues regarding the large-batch training speedup of both K-FAC and SGD. We can see from the figure that as batch size grows, fewer iterations are necessary to reach specific performance targets, represented by horizontal dotted lines. In the ideal scaling case, the intersections of training loss curves for each batch size would fall along the dotted target lines in the pattern of a geometric sequence with common ratio 1/2.

Iterations-to-target for each batch size are plotted directly in Fig. 9. Dotted lines denote the ideal scaling relationship between batch size and iterations. The large-batch scalability behavior is captured in the slope of


[Figure 6: two panels comparing normal damping and approximated damping at batch sizes 128 and 8,192. (a) K-FAC damping comparison: accuracy distribution vs. batch size. (b) K-FAC damping comparison: training loss distribution vs. batch size.]

Fig. 6: (a): Accuracy distribution versus batch size for both normal damping and approximated damping. (b): Training loss distribution versus batch size for both normal damping and approximated damping. All distributions are formed over the 64 runs of the hyperparameter grid search over damping and learning rate.

each line in Fig. 9. For both K-FAC and SGD, diminishing-return effects are present. In all examined cases, K-FAC deviates from ideal scaling (dotted lines) to a greater extent than SGD as batch size increases. This difference explains why, in Fig. 1, SGD is increasingly able to outperform K-FAC for large batches given a fixed budget.
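The ideal-scaling reference can be written down directly: each doubling of the batch size should halve the iterations needed to reach a target, so the speedup relative to the smallest batch size should equal the ratio of batch sizes. The sketch below assumes that speedup is measured as a ratio of iterations to target relative to the smallest batch size (the precise definition is given in Section 4.2). The iteration counts are hypothetical placeholders, not values read from Fig. 9, and the function names are illustrative.

```python
def ideal_iterations(base_iters, base_batch, batch):
    """Iterations to target under ideal scaling: halves each time the batch size doubles."""
    return base_iters * base_batch / batch

def speedup_ratio(base_iters, iters):
    """Speedup relative to the smallest batch size, in iterations to target."""
    return base_iters / iters

# Hypothetical iterations-to-target, for illustration only (not values from the paper).
measured = {128: 32_000, 256: 16_500, 512: 9_000, 1_024: 5_500, 2_048: 4_000}

base_batch = min(measured)
base_iters = measured[base_batch]
for batch, iters in measured.items():
    ideal_speedup = batch / base_batch
    actual_speedup = speedup_ratio(base_iters, iters)
    ideal_iters = ideal_iterations(base_iters, base_batch, batch)
    print(f"B={batch:5d}  ideal iters={ideal_iters:8.0f}  measured iters={iters:6d}  "
          f"ideal speedup={ideal_speedup:5.1f}  measured speedup={actual_speedup:5.2f}")
```

The growing gap between the measured and ideal speedup columns is what diminishing returns look like in this bookkeeping; on the log-log axes of Fig. 9 it appears as a curve flattening away from the dotted line.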

A.5. K-FAC Heatmaps. Here, we show a complete list of heatmaps. Fig. 10 displays the complete accuracy heatmaps for experiments on CIFAR-10 with ResNet20. For each hyperparameter configuration, training was run until terminated according to the adjusted epoch budget. We can observe a shrinkage of the high-accuracy region when the batch size exceeds 4,096. In addition, we extend the hyperparameter tuning spaces for both ResNet20 and ResNet32 experiments under batch size 16,384, and the shrinkage of the high-accuracy region still exists, as is shown in Fig. 11. This strengthens the argument that K-FAC is more sensitive to hyperparameters under large batch sizes.

A.6. Hyperparameter Sensitivity of K-FAC on Other Models and Datasets. Here, we show the hyperparameter sensitivity behavior of K-FAC on CIFAR-10 with ResNet32 and on SVHN with AlexNet to strengthen the argument that K-FAC is more sensitive to hyperparameter tuning under large batch sizes. Fig. 12 exhibits training loss and test accuracy distributions at the end of training by box plots, where each box contains 64 hyperparameter configurations. Fig. 12a shows the results over 3 batch sizes ($2^{7}$, $2^{10}$, $2^{14}$) on CIFAR-10 with ResNet32. Fig. 12b shows the same range of batch sizes on SVHN with AlexNet. We can see an increase of hyperparameter sensitivity as the batch size increases.


[Figure 7: four panels over epochs 0 to 800, each with one curve per batch size (128, 256, 512, 1,024, 2,048, 4,096, 8,192, 16,384). (a) K-FAC: test accuracy vs. epoch. (b) K-FAC: training loss vs. epoch. (c) SGD: test accuracy vs. epoch. (d) SGD: training loss vs. epoch.]

Fig. 7: (a)(c): Test accuracy versus epoch for the accuracy-maximizing runs of each batch size, with K-FAC above and SGD below. (b)(d): Training loss versus epoch for the loss-minimizing runs of each batch size, with K-FAC above and SGD below. For each plot, the star denotes the best (maximal accuracy or minimal loss) performance achieved. Horizontal dotted lines show the target values for accuracy and loss used in Fig. 9 and Fig. 2.


[Figure 8: four panels over iterations 0k to 40k, each with one curve per batch size (128, 256, 512, 1,024, 2,048, 4,096, 8,192, 16,384). (a) K-FAC: test accuracy vs. iterations. (b) K-FAC: training loss vs. iterations. (c) SGD: test accuracy vs. iterations. (d) SGD: training loss vs. iterations.]

Fig. 8: (a)(c): Test accuracy versus iteration for the accuracy-maximizing runs of each batch size, with K-FAC above and SGD below. (b)(d): Training loss versus iteration for the accuracy-maximizing runs of each batch size, with K-FAC above and SGD below. For each plot, the star denotes the best performance achieved. Horizontal dotted lines show the target values for accuracy and loss used in Fig. 9 and Fig. 2.


[Figure 9 panels, each with batch size (2^7 to 2^14) on the horizontal axis and iterations to target (log scale) on the vertical axis, plotted for both SGD and K-FAC: ResNet20: Iterations to Testing Accuracy (targets 0.60, 0.84, 0.89); ResNet20: Iterations to Training Loss (targets 0.92, 0.28, 0.12); ResNet32: Iterations to Testing Accuracy (targets 0.64, 0.87, 0.89); ResNet32: Iterations to Training Loss (targets 0.94, 0.23, 0.08); AlexNet: Iterations to Testing Accuracy (targets 0.71, 0.91); AlexNet: Iterations to Training Loss (targets 0.61, 0.12).]

Fig. 9: From top to bottom: Iterations to a target training loss / test accuracy versus batch size for both SGD and K-FAC with ResNet20 on CIFAR-10, ResNet32 on CIFAR-10 and AlexNet on SVHN, respectively. The ideal scaling result that relates batch-size to convergence iterations is plotted for each method / target as a dotted line. Note that K-FAC has no better scalability than SGD.
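The ideal-scaling reference lines and the iterations-to-target measurements can be illustrated with a short Python sketch. It assumes that ideal scaling means the iteration count is inversely proportional to the batch size, so that doubling the batch size halves the iterations needed. The function names and the sample numbers are hypothetical and are not taken from the experiments.

```python
def ideal_iterations(base_batch_size, base_iterations, batch_sizes):
    """Iterations to target under ideal scaling (assumed: iterations * batch size is constant)."""
    return {b: base_iterations * base_batch_size / b for b in batch_sizes}


def iterations_to_target(values, target, higher_is_better=True):
    """Return the 1-based index of the first value that reaches the target, or None."""
    for i, v in enumerate(values, start=1):
        reached = v >= target if higher_is_better else v <= target
        if reached:
            return i
    return None


# Hypothetical baseline: 10,000 iterations to reach the target at batch size 128.
ideal = ideal_iterations(128, 10_000, [128, 256, 16_384])

# Hypothetical per-iteration curves (not measured data).
accuracy_curve = [0.50, 0.70, 0.90]
loss_curve = [1.00, 0.50, 0.20]
hit_accuracy = iterations_to_target(accuracy_curve, 0.84)
hit_loss = iterations_to_target(loss_curve, 0.28, higher_is_better=False)
```

A measured curve that lies above the ideal line at large batch sizes indicates that the extra samples per step are no longer reducing the iteration count proportionally.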


[Figure 10 panels: heatmaps of best accuracy. Rows: damping; columns: learning rate.]

ResNet20: batch size 128

damping   0.001  0.003  0.009  0.027  0.081  0.243  0.729  2.187
0.0001    0.89   0.9    0.9    0.88   0.83   0.61   0.35   0.3
0.0003    0.9    0.9    0.89   0.88   0.85   0.58   0.33   0.23
0.0009    0.9    0.91   0.9    0.9    0.85   0.66   0.53   0.31
0.0027    0.89   0.9    0.91   0.9    0.88   0.82   0.75   0.66
0.0081    0.88   0.9    0.91   0.91   0.9    0.89   0.84   0.8
0.0243    0.84   0.89   0.9    0.91   0.92   0.9    0.89   0.85
0.0729    0.77   0.85   0.9    0.91   0.92   0.92   0.91   0.88
0.2187    0.65   0.78   0.87   0.89   0.91   0.92   0.92   0.89

ResNet20: batch size 256

damping   0.001  0.003  0.009  0.027  0.081  0.243  0.729  2.187
0.0001    0.91   0.91   0.9    0.9    0.87   0.43   0.24   0.1
0.0003    0.9    0.92   0.91   0.9    0.88   0.48   0.23   0.2
0.0009    0.91   0.91   0.91   0.91   0.89   0.72   0.32   0.1
0.0027    0.9    0.92   0.92   0.91   0.9    0.87   0.83   0.73
0.0081    0.88   0.91   0.92   0.92   0.91   0.9    0.87   0.83
0.0243    0.85   0.89   0.91   0.92   0.92   0.92   0.9    0.88
0.0729    0.78   0.86   0.89   0.92   0.92   0.93   0.92   0.9
0.2187    0.66   0.8    0.86   0.9    0.92   0.93   0.93   0.9

ResNet20: batch size 512

damping   0.001  0.003  0.009  0.027  0.081  0.243  0.729  2.187
0.0001    0.92   0.92   0.91   0.91   0.88   0.79   0.1    0.1
0.0003    0.91   0.92   0.91   0.9    0.89   0.79   0.1    0.1
0.0009    0.91   0.92   0.92   0.92   0.9    0.69   0.28   0.1
0.0027    0.9    0.91   0.92   0.92   0.9    0.89   0.84   0.8
0.0081    0.88   0.91   0.92   0.92   0.91   0.91   0.89   0.84
0.0243    0.86   0.89   0.91   0.92   0.92   0.91   0.91   0.89
0.0729    0.77   0.85   0.89   0.91   0.92   0.93   0.92   0.91
0.2187    0.64   0.77   0.85   0.89   0.91   0.92   0.93   0.92

ResNet20: batch size 1024

damping   0.001  0.003  0.009  0.027  0.081  0.243  0.729  2.187
0.0001    0.91   0.92   0.91   0.91   0.89   0.82   0.1    0.1
0.0003    0.91   0.92   0.91   0.91   0.89   0.83   0.1    0.1
0.0009    0.91   0.92   0.91   0.91   0.9    0.86   0.41   0.1
0.0027    0.9    0.91   0.92   0.92   0.9    0.89   0.85   0.77
0.0081    0.87   0.9    0.91   0.91   0.91   0.91   0.89   0.87
0.0243    0.83   0.87   0.9    0.91   0.92   0.92   0.91   0.9
0.0729    0.72   0.83   0.87   0.9    0.91   0.92   0.93   0.91
0.2187    0.58   0.72   0.83   0.88   0.9    0.92   0.92   0.92

ResNet20: batch size 2048

damping   0.001  0.003  0.009  0.027  0.081  0.243  0.729  2.187
0.0001    0.91   0.91   0.9    0.91   0.89   0.78   0.42   0.17
0.0003    0.91   0.91   0.91   0.91   0.9    0.81   0.71   0.17
0.0009    0.9    0.91   0.91   0.9    0.9    0.88   0.77   0.25
0.0027    0.89   0.91   0.91   0.91   0.91   0.9    0.85   0.79
0.0081    0.86   0.89   0.91   0.91   0.91   0.9    0.89   0.88
0.0243    0.79   0.86   0.89   0.9    0.92   0.92   0.92   0.89
0.0729    0.66   0.8    0.86   0.89   0.91   0.92   0.92   0.92
0.2187    0.52   0.67   0.79   0.86   0.89   0.91   0.92   0.92

ResNet20: batch size 4096

damping   0.001  0.003  0.009  0.027  0.081  0.243  0.729  2.187
0.0001    0.9    0.9    0.9    0.9    0.87   0.8    0.7    0.16
0.0003    0.9    0.9    0.9    0.9    0.88   0.79   0.72   0.21
0.0009    0.9    0.9    0.9    0.9    0.91   0.85   0.73   0.41
0.0027    0.87   0.9    0.91   0.9    0.9    0.9    0.83   0.79
0.0081    0.83   0.88   0.89   0.91   0.9    0.91   0.9    0.89
0.0243    0.74   0.83   0.88   0.9    0.91   0.91   0.92   0.91
0.0729    0.59   0.73   0.83   0.88   0.9    0.91   0.92   0.92
0.2187    0.43   0.6    0.74   0.84   0.88   0.9    0.92   0.92

ResNet20: batch size 8192

damping   0.001  0.003  0.009  0.027  0.081  0.243  0.729  2.187
0.0001    0.89   0.87   0.88   0.88   0.8    0.61   0.63   0.13
0.0003    0.88   0.87   0.86   0.86   0.82   0.61   0.65   0.16
0.0009    0.87   0.88   0.86   0.85   0.87   0.72   0.62   0.28
0.0027    0.85   0.88   0.88   0.88   0.87   0.86   0.72   0.74
0.0081    0.79   0.85   0.88   0.88   0.89   0.89   0.9    0.87
0.0243    0.66   0.78   0.85   0.87   0.89   0.9    0.91   0.9
0.0729    0.51   0.67   0.79   0.85   0.89   0.9    0.9    0.91
0.2187    0.36   0.51   0.66   0.79   0.85   0.89   0.9    0.91

ResNet20: batch size 16384

damping   0.001  0.003  0.009  0.027  0.081  0.243  0.729  2.187
0.0001    0.82   0.8    0.7    0.8    0.65   0.45   0.36   0.14
0.0003    0.82   0.81   0.8    0.69   0.68   0.43   0.36   0.19
0.0009    0.82   0.82   0.78   0.69   0.76   0.59   0.53   0.25
0.0027    0.79   0.82   0.8    0.81   0.75   0.75   0.51   0.55
0.0081    0.71   0.8    0.83   0.82   0.83   0.83   0.83   0.84
0.0243    0.57   0.72   0.8    0.82   0.83   0.86   0.89   0.89
0.0729    0.42   0.58   0.72   0.8    0.84   0.86   0.88   0.9
0.2187    0.32   0.42   0.58   0.72   0.8    0.85   0.88   0.9

Fig. 10: Best accuracy achieved under adjusted epoch budget versus damping and learning rate for batch sizes from 128 to 16,384 with normal damping on CIFAR-10 with ResNet20. A positive correlation between damping and learning rate is exhibited. When the batch size exceeds 4,096, we observe a shrinking of the high-accuracy region for large batch sizes.
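One rough way to quantify the size of the high-accuracy region is to count the grid cells whose best accuracy reaches a threshold. The Python sketch below rebuilds the 8 x 8 hyperparameter grid from the tick labels of the figure (each value is three times the previous one) and counts such cells for the batch size 128 and 16,384 panels, using the cell values printed in the figure. The 0.90 threshold and the helper names are assumptions chosen for illustration and are not part of the original analysis.

```python
# Grid implied by the tick labels of Fig. 10: each value is 3x the previous one.
learning_rates = [round(0.001 * 3**k, 4) for k in range(8)]
dampings = [round(0.0001 * 3**k, 4) for k in range(8)]

# Cell values as printed in the batch size 128 panel of Fig. 10.
acc_128 = [
    [0.89, 0.9, 0.9, 0.88, 0.83, 0.61, 0.35, 0.3],
    [0.9, 0.9, 0.89, 0.88, 0.85, 0.58, 0.33, 0.23],
    [0.9, 0.91, 0.9, 0.9, 0.85, 0.66, 0.53, 0.31],
    [0.89, 0.9, 0.91, 0.9, 0.88, 0.82, 0.75, 0.66],
    [0.88, 0.9, 0.91, 0.91, 0.9, 0.89, 0.84, 0.8],
    [0.84, 0.89, 0.9, 0.91, 0.92, 0.9, 0.89, 0.85],
    [0.77, 0.85, 0.9, 0.91, 0.92, 0.92, 0.91, 0.88],
    [0.65, 0.78, 0.87, 0.89, 0.91, 0.92, 0.92, 0.89],
]

# Cell values as printed in the batch size 16384 panel of Fig. 10.
acc_16384 = [
    [0.82, 0.8, 0.7, 0.8, 0.65, 0.45, 0.36, 0.14],
    [0.82, 0.81, 0.8, 0.69, 0.68, 0.43, 0.36, 0.19],
    [0.82, 0.82, 0.78, 0.69, 0.76, 0.59, 0.53, 0.25],
    [0.79, 0.82, 0.8, 0.81, 0.75, 0.75, 0.51, 0.55],
    [0.71, 0.8, 0.83, 0.82, 0.83, 0.83, 0.83, 0.84],
    [0.57, 0.72, 0.8, 0.82, 0.83, 0.86, 0.89, 0.89],
    [0.42, 0.58, 0.72, 0.8, 0.84, 0.86, 0.88, 0.9],
    [0.32, 0.42, 0.58, 0.72, 0.8, 0.85, 0.88, 0.9],
]


def high_accuracy_cells(grid, threshold=0.90):
    """Count grid cells with accuracy >= threshold (the threshold is an assumed cutoff)."""
    return sum(value >= threshold for row in grid for value in row)


small_batch_region = high_accuracy_cells(acc_128)
large_batch_region = high_accuracy_cells(acc_16384)
print(small_batch_region, large_batch_region)
```

With this assumed cutoff, the count drops from 27 of 64 cells at batch size 128 to 2 of 64 cells at batch size 16,384, which is one way to express the shrinking region described in the caption. A different threshold would give different counts.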

[Figure 11 panels: heatmaps of best accuracy with damping on the vertical axis and learning rate on the horizontal axis. (a) ResNet20: batch size 16384, learning rate 0.001 to 176.418, damping 0.0001 to 17.6418. (b) ResNet32: batch size 16384, learning rate 1.37e-06 to 19.683, damping 1.1e-05 to 1.9683.]

Fig. 11: Best accuracy achieved versus damping and learning rate under a larger hyperparameter tuning space for batch size under 16,384 on CIFAR-10 with ResNet20 (left) and ResNet32 (right). For both cases, the areas of the high-accuracy region are still smaller than that of small batch training results.


[Figure 12 panels: (a) ResNet32, K-FAC: Accuracy vs. BS and ResNet32, K-FAC: Training Loss vs. BS; (b) AlexNet, K-FAC: Accuracy vs. BS and AlexNet, K-FAC: Training Loss vs. BS. Horizontal axis: batch size (2^7 to 2^14); vertical axes: test accuracy and training loss.]

Fig. 12: (a): Test accuracy/training loss distribution versus batch size for K-FAC on CIFAR-10 with ResNet32 at the end of training under an adjusted epoch budget. (b): Test accuracy/training loss distribution versus batch size for K-FAC on SVHN with AlexNet at the end of training under an adjusted epoch budget. Larger batch sizes result in lower accuracy and higher training losses that are more sensitive to hyperparameter choice.