

Stiffness: A New Perspective on Generalization in Neural Networks

Stanislav Fort 1 2 3 Paweł Krzysztof Nowak 1 Stanisław Jastrzebski 1 4 5 Srini Narayanan 1

Abstract

In this paper we develop a new perspective on generalization of neural networks by proposing and investigating the concept of neural network stiffness. We measure how stiff a network is by looking at how a small gradient step in the network's parameters on one example affects the loss on another example. Higher stiffness suggests that a network is learning features that generalize. In particular, we study how stiffness depends on 1) class membership, 2) distance between data points in the input space, 3) training iteration, and 4) learning rate. We present experiments on MNIST, FASHION MNIST, and CIFAR-10/100 using fully-connected and convolutional neural networks, as well as on a transformer-based NLP model. We demonstrate the connection between stiffness and generalization, and observe its dependence on learning rate. When training on CIFAR-100, the stiffness matrix exhibits a coarse-grained behavior indicative of the model's awareness of super-class membership. In addition, we measure how stiffness between two data points depends on their mutual input-space distance, and establish the concept of a dynamical critical length – a distance below which a parameter update based on a data point influences its neighbors.

1. Introduction

Neural networks are a class of highly expressive function approximators that proved to be successful in many domains such as vision, natural language understanding, and game-play. The specific details that lead to their expressive power have recently been studied in Montúfar et al. (2014); Raghu et al. (2017); Poole et al. (2016).

1Google Research, Zurich, Switzerland 2Stanford University, Stanford, CA, USA 3This work has been done as a part of the Google AI Residency program 4New York University, New York, USA 5A part of this work was done while an intern in Google Research, Zurich, Switzerland. Correspondence to: Stanislav Fort <[email protected]>.

Preliminary work. Under review by the International Conference on Machine Learning (ICML). Do not distribute.

Figure 1. A diagram illustrating a) the concept of stiffness and b) the dynamical critical length ξ. A small gradient update to the network's weights based on example X1 decreases loss on some examples (stiff), and increases it on others (anti-stiff). We call the characteristic distance over which datapoints are stiff the dynamical critical length ξ.

Empirically, neural networks have been extremely successful at generalizing to new data despite their over-parametrization for the task at hand, as well as their ability to fit arbitrary random data perfectly (Zhang et al., 2016; Arpit et al., 2017).

The fact that gradient descent is able to find good solutions given the highly over-parametrized family of functions has been studied theoretically in Arora et al. (2018) and explored empirically in Li et al. (2018), where the effective low-dimensional nature of many common learning problems is shown. Fort and Scherlis (2019) extend the analysis in Li et al. (2018) to demonstrate the role of initialization on the effective dimensionality, and Fort and Jastrzebski (2019) use the result to build a phenomenological model of the loss landscape.

Du et al. (2018a) and Du et al. (2018b) use a Gram matrix to study convergence in neural network empirical loss. Pennington and Worah (2017) study the concentration properties of a similar covariance matrix formed from the output of the network. Ghorbani et al. (2019); Jastrzębski et al. (2019) investigate the Hessian eigenspectrum, and Papyan (2019) shows how it is related to the gradient covariance. The clustering of logit gradients is studied e.g. in Fort and Ganguli (2019). Both concepts are closely related to our definition of stiffness.

To explain the remarkable generalization properties of neural networks, it has been proposed (Rahaman et al., 2018) that the function family is biased towards low-frequency functions.



The similarity of neural network outputs on similar inputs has been studied in Schoenholz et al. (2016) for random initializations, explored empirically in Novak et al. (2018), and plays a crucial role in the Neural Tangent Kernel work started by Jacot et al. (2018).

1.1. Our contribution

In this paper, we study generalization through the lens of stiffness. We measure how stiff a neural network is by analyzing how a small gradient step based on one input example affects the loss on another input example. Stiffness captures the resistance of the learned functional approximation to deformation by gradient steps.

We find the concept of stiffness useful 1) in diagnosing and characterizing generalization, 2) in uncovering semantically meaningful groups of datapoints beyond classes, 3) in exploring the effect of learning rate on the function learned, and 4) in defining and measuring the dynamical critical length ξ. We show that higher learning rates bias the functions learned towards lower, more local stiffness, making them easier to bend on gradient updates. We demonstrate these findings on vision tasks (MNIST, FASHION MNIST, CIFAR-10/100) with fully-connected and convolutional architectures (including ResNets), and on a natural language task using a fine-tuned BERT model.

We find stiffness to be a useful concept to study because it a) relates directly to generalization, b) is sensitive to the semantically meaningful content of the inputs, and c) at the same time relates to the loss landscape Hessian and therefore to the research on its curvature and optimization. As such, it bridges the gap between multiple research directions simultaneously.

This paper is structured as follows: we introduce the concept of stiffness and the relevant theory in Section 2. We describe our experimental setup in Section 3, present the results in Section 4, and discuss their implications in Section 5. We conclude with Section 6.

2. Theoretical background

2.1. Stiffness – definitions

Let a functional approximation (e.g. a neural network) f be parametrized by tunable parameters W. Let us assume a classification task and let a data point X have the ground truth label y. A loss L(f_W(X), y) gives us the amount of mismatch between the function's output at input X and the ground truth y. The gradient of the loss with respect to the parameters

\vec{g} = \nabla_W L(f_W(X), y)    (1)

is the direction in which, if we were to change the parameters W, the loss would change most rapidly (for infinitesimal step sizes).

Gradient descent uses the negative of this step to update the weights and gradually tune the functional approximation to better correspond to the desired outputs on the training dataset inputs.

Figure 2. A diagram illustrating the concept of stiffness. It can be viewed in two equivalent ways: a) as the change in loss at a datapoint induced by the application of a gradient update based on another datapoint, and b) as the alignment of loss gradients computed at the two datapoints. These two descriptions are mathematically equivalent.

Let us consider two data points with their ground truth labels, (X_1, y_1) and (X_2, y_2). We construct a gradient with respect to example 1 as \vec{g}_1 = \nabla_W L(f_W(X_1), y_1) and ask how the losses on data points 1 and 2 change as the result of a small change of W in the direction −\vec{g}_1, i.e. what is

\Delta L_1 = L(f_{W - \varepsilon \vec{g}_1}(X_1), y_1) - L(f_W(X_1), y_1),    (2)

which is equivalent to

\Delta L_1 = -\varepsilon\, \vec{g}_1 \cdot \nabla_W L(f_W(X_1), y_1) + O(\varepsilon^2) = -\varepsilon\, \vec{g}_1 \cdot \vec{g}_1 + O(\varepsilon^2).    (3)

The change in loss on input 2 due to the same gradient step based on input 1 is, equivalently, ΔL_2 = −ε g_1 · g_2 + O(ε²). We are interested in the correlation between the loss changes ΔL_1 and ΔL_2, and will work in the limit of ε → 0, i.e. an infinitesimal step size. We know that ΔL_1 < 0, since we constructed the gradient update accordingly. We define positive stiffness to mean ΔL_2 < 0 as well, i.e. that the losses at both inputs went down. We assign a stiffness of 0 for ΔL_2 = 0. If ΔL_2 > 0, the two inputs are anti-stiff (negative stiffness). The equations above show that this can equivalently be thought of as the overlap between the two gradients, g_1 · g_2, being positive for positive stiffness, and negative for negative stiffness. We illustrate this in Figure 1 and Figure 2.

The above indicates that what we initially conceived of as a change in loss due to the application of a small gradient update based on one input, measured at another, is in fact equivalent to analyzing gradient alignment between different datapoints.
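As a quick numerical sanity check of this equivalence, the first-order relation ΔL_2 ≈ −ε g_1 · g_2 can be verified on a toy problem. The snippet below is a minimal illustration that uses a simple quadratic loss in place of a network loss; the quadratic form and all variable names are ours and not part of the paper's experiments.

```python
import numpy as np

# Toy per-example "losses" L_i(W) = 0.5 * ||W - t_i||^2, whose gradients are W - t_i.
rng = np.random.default_rng(0)
W = rng.normal(size=10)
t1, t2 = rng.normal(size=10), rng.normal(size=10)
g1, g2 = W - t1, W - t2          # per-example loss gradients at W

def loss(w, t):
    return 0.5 * np.sum((w - t) ** 2)

eps = 1e-4
delta_L2 = loss(W - eps * g1, t2) - loss(W, t2)   # actual change in loss on example 2
print(delta_L2, -eps * np.dot(g1, g2))            # the two numbers agree up to O(eps^2)
```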


We will be using two different definitions of stiffness: the sign stiffness and the cosine stiffness. We define the sign stiffness to be the expected sign of g_1 · g_2 (or equivalently the expected sign of ΔL_1 ΔL_2) as

S_{\mathrm{sign}}((X_1, y_1), (X_2, y_2); f) = \mathbb{E}\left[ \operatorname{sign}(\vec{g}_1 \cdot \vec{g}_2) \right],    (4)

where the stiffness depends on the dataset from which X_1 and X_2 are drawn. The cosine stiffness is

S_{\cos}((X_1, y_1), (X_2, y_2); f) = \mathbb{E}\left[ \cos(\vec{g}_1, \vec{g}_2) \right],    (5)

where cos(g_1, g_2) = (g_1/|g_1|) · (g_2/|g_2|). We use both versions of stiffness as they are suitable for highlighting different phenomena – the sign stiffness shows the stiffness between classes more clearly, while the cosine stiffness is more useful for within-class stiffness.

Based on our definitions, the stiffness between a datapoint X and itself is 1.
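For a single pair of examples, both stiffness metrics reduce to simple operations on the two loss gradients. The following sketch (our own helpers; it assumes flattened per-example gradients are already available) computes the quantities inside the expectations in equations (4) and (5):

```python
import numpy as np

def sign_stiffness(g1, g2):
    # Sign of the gradient overlap g1 . g2, cf. eq. (4).
    return np.sign(np.dot(g1, g2))

def cosine_stiffness(g1, g2):
    # Cosine of the angle between the two gradients, cf. eq. (5).
    return np.dot(g1, g2) / (np.linalg.norm(g1) * np.linalg.norm(g2))
```

The expectations in equations (4) and (5) are then estimated by averaging these per-pair quantities over many pairs drawn from the dataset of interest.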

2.2. Train-train, train-val, and val-val

When measuring stiffness between two datapoints, we have 3 options: 1) choosing both datapoints from the training set (we call this train-train), 2) choosing one from the training set and the other from the validation set (train-val), and 3) choosing both from the validation set (val-val). The train-val stiffness is directly related to generalization, as it corresponds to the amount of improvement on the training set transferring to improvement on the validation set. We empirically observe that all 3 options behave remarkably similarly, which gives us confidence that they are all related to generalization. This is further supported by observing their behavior as a function of epoch in Figures 3, 9, and 10.

2.3. Stiffness based on class membership

A natural question to ask is whether a gradient taken with respect to an input X_1 in class c_1 will also decrease the loss for example X_2 in class c_2. In particular, we define the class stiffness matrix

C(c_a, c_b) = \mathbb{E}_{X_1 \in c_a,\, X_2 \in c_b,\, X_1 \neq X_2} \left[ S((X_1, y_1), (X_2, y_2)) \right].    (6)

The on-diagonal elements of this matrix correspond to the suitability of the current gradient update to the members of the class itself. In particular, they correspond to within-class generalizability (we avoid counting the stiffness of an example with itself, as that is by definition 1). The off-diagonal elements, on the other hand, express the amount of improvement transferred from one class to another. They therefore directly diagnose the generality of the features currently being improved.

A consistent summary of generalization between classes is the off-diagonal sum of the class stiffness matrix

S_{\text{between classes}} = \frac{1}{N_c (N_c - 1)} \sum_{c_1} \sum_{c_2 \neq c_1} C(c_1, c_2).    (7)

In our experiments, we track this value as a function of learning rate once we have reached a fixed training loss. The quantity is related to how generally applicable the learned features are, i.e. how well they transfer from one class to another. For example, for CNNs, learning good edge detectors in the initial layers typically benefits all downstream tasks, regardless of the particular class in question. We do the equivalent for the within-class stiffness (the on-diagonal elements). When the within-class stiffness drops below 1, the generality of the improved features does not extend even to the class itself, suggesting an onset of overfitting (see Figure 3).
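A possible way to estimate the class stiffness matrix of equation (6) and the between-classes summary of equation (7) from a set of per-example gradients is sketched below. This is a schematic of our own, using the cosine stiffness; `grads` and `labels` are assumed to come from a fixed evaluation subset.

```python
import numpy as np

def class_stiffness_matrix(grads, labels, num_classes):
    # grads:  (n_examples, n_params) flattened per-example loss gradients,
    # labels: (n_examples,) integer class labels.
    # Implements eq. (6) with cosine stiffness, excluding the self-stiffness
    # of 1 on the diagonal blocks.
    normed = grads / np.linalg.norm(grads, axis=1, keepdims=True)
    pairwise = normed @ normed.T                  # cosine stiffness of all pairs
    C = np.zeros((num_classes, num_classes))
    for ca in range(num_classes):
        for cb in range(num_classes):
            block = pairwise[np.ix_(labels == ca, labels == cb)]
            if ca == cb:
                block = block[~np.eye(block.shape[0], dtype=bool)]
            C[ca, cb] = block.mean()
    return C

def between_class_stiffness(C):
    # Off-diagonal average of the class stiffness matrix, cf. eq. (7).
    nc = C.shape[0]
    return C[~np.eye(nc, dtype=bool)].sum() / (nc * (nc - 1))
```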

2.4. Higher order classes

For some datasets (in our case CIFAR-100), there exist large, semantically meaningful super-classes to which datapoints belong. We observe how stiffness depends on the datapoints' membership in those super-classes and, going a step further, define groups of super-classes that we call super-super-classes. We observe that the stiffness within those semantically meaningful groups is higher than would be expected on average, suggesting that stiffness is a measure that is aware of the semantic content of the datapoints (Figure 8).

2.5. Stiffness as a function of distance

We investigate how stiff two datapoints are, on average, as a function of how far away from each other they are in the input space. We can think of neural networks as a form of kernel learning (in some limits, this is exact, e.g. Jacot et al. (2018)) and here we are investigating the particular form of the kernel learned, in particular its direction-averaged sensitivity. This links our results to the work on spectral bias (towards slowly-varying, low-frequency functions) in Rahaman et al. (2018). We are able to directly measure the characteristic size of the stiff regions in neural networks trained on real tasks, which we call the dynamical critical length ξ. The concept is illustrated in Figure 1. We work with data normalized to the unit sphere, |X| = 1, and use the mutual cosine of two datapoints to define their distance as

\mathrm{distance}(\vec{X}_1, \vec{X}_2) = 1 - \frac{\vec{X}_1 \cdot \vec{X}_2}{|\vec{X}_1|\, |\vec{X}_2|}.    (8)

This has the advantage of being bounded between 0 and 2. For X_1 = X_2, i.e. distance = 0, stiffness is 1 by definition. We track the threshold distance ξ at which stiffness, on average, goes to 0. We observe ξ as a function of training and learning rate to estimate the characteristic size of the stiff regions of a neural net.
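The distance of equation (8) is straightforward to compute; a minimal sketch (our own helper, assuming the inputs are given as arrays that can be flattened) is:

```python
import numpy as np

def input_distance(x1, x2):
    # Eq. (8): one minus the cosine similarity of the two inputs,
    # bounded in [0, 2]; identical inputs have distance 0.
    x1, x2 = x1.ravel(), x2.ravel()
    return 1.0 - np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))
```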


3. Methods

Figure 3. The evolution of training and validation loss (top panel), within-class stiffness (central panel), and between-classes stiffness (bottom panel) during training. The onset of over-fitting is marked in orange. After that, both within-class and between-classes stiffness regress to 0. The same effect is visible in stiffness measured between two training set datapoints, between one training and one validation datapoint, and between two validation set datapoints.

3.1. Experimental setup

We ran a large number of experiments with fully-connected (FC) and convolutional neural networks (CNN) (both small, custom-made networks, as well as ResNet (He et al., 2015)) on 4 classification datasets: MNIST (LeCun and Cortes, 2010), FASHION MNIST (Xiao et al., 2017), CIFAR-10 (Krizhevsky, 2009), and CIFAR-100 (Krizhevsky, 2009). In addition, we study the behavior of a BERT model (Devlin et al., 2018) fine-tuned on the MNLI dataset (Williams et al., 2017).

Using those experiments, we investigated the behavior of stiffness as a function of 1) training epoch, 2) the choice of learning rate, 3) class membership, and 4) the input-space distance between images.

For experiments with fully-connected neural networks, we used a 3-layer ReLU network of the form X → 500 → 300 → 100 → y. For experiments with convolutional neural networks, we used a 3-layer network with filter size 3 and 16, 32, and 32 channels after the respective convolutional layers, each followed by 2×2 max pooling. The final layer was fully-connected. For ResNets, we used the ResNet20v1 implementation from Chollet et al. (2015). No batch normalization was used.
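For concreteness, the two small architectures could look roughly as follows in Keras. This is a sketch under our own assumptions – padding, hidden-layer activations of the CNN, and the softmax output are not specified in the text, and this is not necessarily the authors' exact implementation.

```python
import tensorflow as tf
from tensorflow.keras import layers, models

def make_fc(input_shape, num_classes):
    # 3-hidden-layer ReLU MLP: input -> 500 -> 300 -> 100 -> classes
    return models.Sequential([
        layers.Flatten(input_shape=input_shape),
        layers.Dense(500, activation="relu"),
        layers.Dense(300, activation="relu"),
        layers.Dense(100, activation="relu"),
        layers.Dense(num_classes, activation="softmax"),
    ])

def make_cnn(input_shape, num_classes):
    # 3 conv layers (3x3 kernels, 16/32/32 channels), each followed by 2x2 max
    # pooling; "same" padding and ReLU activations are our assumptions.
    return models.Sequential([
        layers.Conv2D(16, 3, activation="relu", padding="same", input_shape=input_shape),
        layers.MaxPooling2D(2),
        layers.Conv2D(32, 3, activation="relu", padding="same"),
        layers.MaxPooling2D(2),
        layers.Conv2D(32, 3, activation="relu", padding="same"),
        layers.MaxPooling2D(2),
        layers.Flatten(),
        layers.Dense(num_classes, activation="softmax"),
    ])
```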

We pre-processed the network inputs to have zero mean and unit variance, and normalized all data to the unit sphere, |X| = 1. We used Adam with different (constant) learning rates as our optimizer and the default batch size of 32. All experiments took at most a few hours on a single GPU.
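A sketch of the corresponding preprocessing follows; it reflects our own reading of the description, and whether the standardization is applied per feature or per example is not stated in the text, so we standardize per feature here.

```python
import numpy as np

def preprocess(x):
    # x: array of inputs, shape (n_examples, ...). Standardize each feature to
    # zero mean / unit variance over the dataset (our assumption), then scale
    # each example so that its flattened vector lies on the unit sphere |x| = 1.
    x = x.astype(np.float32)
    x = (x - x.mean(axis=0)) / (x.std(axis=0) + 1e-8)
    norms = np.linalg.norm(x.reshape(len(x), -1), axis=1)
    return x / norms.reshape((-1,) + (1,) * (x.ndim - 1))
```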

3.2. Training and stiffness evaluation

We first trained a network on a classification task to the stage of training at which we wanted to measure stiffness. We then froze the weights and evaluated the loss gradients on the relevant set of images.

We evaluated stiffness properties between pairs of datapoints from the training set, between one datapoint from the training set and one from the validation set, and between pairs of points from the validation set. We used the full training set to train our model. The procedure was as follows: 1) Train for a number of steps on the training set and update the network weights accordingly. 2) For each of the modes {train-train, train-val, val-val}, go through tuples of images coming from the respective datasets. 3) For each tuple, calculate the loss gradients g_1 and g_2, and compute sign(g_1 · g_2) and cos(g_1, g_2). 4) Log the input-space distance between the images as well as other relevant features. In our experiments, we used a fixed subset (typically of ≈ 500 images for experiments with 10 classes, and ≈ 3,000 images for 100 classes) of the training and validation sets to evaluate the stiffness properties on. We convinced ourselves that such a subset is sufficiently large to provide measurements with small enough statistical uncertainties, which we overlay in our figures.
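The evaluation loop could be implemented roughly as follows. This is a sketch assuming a Keras classifier with a softmax output, sparse integer labels, and inputs already normalized to the unit sphere; the helper names and the choice of loss object are ours.

```python
import numpy as np
import tensorflow as tf

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()  # assumed loss

def per_example_gradient(model, x, y):
    # Flattened gradient of the loss on a single example w.r.t. all trainable weights.
    x = tf.convert_to_tensor(x[None])          # add a batch dimension of 1
    y = tf.convert_to_tensor([y])
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, training=False))
    grads = tape.gradient(loss, model.trainable_variables)
    return tf.concat([tf.reshape(g, [-1]) for g in grads], axis=0).numpy()

def stiffness_records(model, xs, ys, num_pairs=1000, seed=0):
    # Sample random pairs from the evaluation subset and log sign stiffness,
    # cosine stiffness, and the input-space distance for each pair.
    rng = np.random.default_rng(seed)
    records = []
    for _ in range(num_pairs):
        i, j = rng.choice(len(xs), size=2, replace=False)
        g1 = per_example_gradient(model, xs[i], ys[i])
        g2 = per_example_gradient(model, xs[j], ys[j])
        dot = float(np.dot(g1, g2))
        cos = dot / (np.linalg.norm(g1) * np.linalg.norm(g2))
        dist = 1.0 - float(np.dot(xs[i].ravel(), xs[j].ravel()))  # inputs on the unit sphere
        records.append({"sign": float(np.sign(dot)), "cos": cos, "distance": dist})
    return records
```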

3.3. Learning rate dependence

We investigated how stiffness properties depend on the learning rate used in training. To do that, we first looked at the dynamical critical length ξ at different stages of training of our network, and then compared runs based on the epoch and training loss at the time, in order to compare training runs with different learning rates fairly. The results are shown in Figures 5 and 12, concluding that higher learning rates lead to functions with a more localized stiffness response, i.e. lower ξ.

4. Results

We will summarize our key results here and provide additional figures in the Supplementary material.

4.1. Stiffness properties based on class membership

We explored the stiffness properties based on class membership at different stages of training. Figure 4 shows results for a CNN on FASHION MNIST at initialization, very early in training, around epoch 1, and at a late stage of training.


Figure 4. Class-membership dependence of stiffness for a CNN on FASHION MNIST at 4 different stages of training. The figure shows stiffness between train-train, train-val, and val-val pairs of images, as well as both the sign and cosine metrics.

We provide additional results in the Supplementary material. The within-class (on-diagonal) and between-classes results are summarized in Figure 5.

Figure 5. Stiffness as a function of epoch. The plots summarize the evolution of the within-class and between-classes stiffness measures as a function of the epoch of training for a CNN on FASHION MNIST.

Initially, an improvement based on an input from a particular class benefits only members of the same class. Intuitively, this could be due to some crude features shared within a class (such as the typical overall intensity, or the average color) being learned. There is no consistent stiffness between different classes at initialization. As training progresses, within-class stiffness stays high. In addition, stiffness between classes increases as well, given the model is powerful enough to learn the dataset. With the onset of overfitting, as shown in Figure 3, the model becomes increasingly less stiff until even stiffness for inputs within the same class is lost.

At that stage, an improvement on one example does not lead to a consistent improvement on other examples, as ever more detailed, specific features have to be learned. Figure 7 shows the class-class stiffness matrix at 4 stages of training for a ResNet20v1 on CIFAR-100.

4.2. Class, super-class, and super-super-class structure of stiffness on CIFAR-100

We studied the class-class stiffness between images from the CIFAR-100 dataset for a ResNet20v1 trained on the fine labels (100 classes) to approximately 60% (top-1) test accuracy. As expected, the class-class stiffness was high for images from the same class (excluding the stiffness between an image and itself to prevent artificially increasing the value), as shown in Figure 6.

Figure 6. Stiffness between images of all pairs of classes for a ResNet20v1 trained on CIFAR-100. The within-class stiffness (the diagonal) is noticeably higher than the rest.

We sorted the classes according to their super-classes (groups of 5 semantically connected classes) and visualized them in Figure 8. The groups of classes belonging to the same super-class are mutually stiff, without the network ever being explicitly trained to classify according to the super-class. Stiffness within images of the same super-class is noticeably higher than the typical value, and while it is not as high as the within-class stiffness itself, this reveals that the network is aware of higher-order, semantically meaningful categories to which the images belong. Visually inspecting Figure 8, we see that there is an additional structure beyond the super-classes, which roughly corresponds to the living / non-living distinction.

To quantify this hierarchy beyond a visual inspection of Figures 6 and 8, we calculated the following 4 measures: 1) the average stiffness between two different images from the same class, 2) the average stiffness between two different images from two different classes from the same super-class, 3) the average stiffness between two different images from two different classes from two different super-classes


Figure 7. Stiffness between images of all pairs of classes for a ResNet20v1 trained on CIFAR-100. The within-class stiffness (the diagonal) is noticeably higher than the rest. The four panels show the evolution of the matrix through training, starting at the upper left corner and ending at the lower right corner.

Figure 8. Stiffness between images of all pairs of classes for a ResNet20v1 trained on CIFAR-100, with classes sorted based on their super-class membership. The higher stiffness between classes from the same super-class is visible, indicating the sensitivity of stiffness to the semantically meaningful content of images. Higher-order groups of super-classes are also visible.

belonging to the same super-super-class, which we defined as either living or non-living, and 4) the average stiffness between examples from different classes, which served as a baseline.

Each of these averages measures the amount of semantically meaningful similarity between the reactions of images to gradient updates at gradually more abstract levels.

Figure 9. Stiffness between examples of the same class, super-class, and super-super-class for a ResNet20v1 trained on the 100 fine classes of CIFAR-100. With increasing epoch, all measures generically regress to 0. The lines corresponding to super- and super-super-classes exceed the baseline, suggesting that the network is aware of higher-order semantic categories of the images throughout training.

4.3. Natural language tasks

To go beyond visual classification tasks, we studied the behavior of a BERT model fine-tuned on the MNLI dataset. The evolution of within-class and between-classes stiffness is shown in Figure 10.

Figure 10. Stiffness as a function of epoch for the natural language task MNLI using a fine-tuned BERT model. Stiffness within the same class starts high and regresses towards 0, while stiffness between classes starts negative and approaches 0 from below. This is similar to our observations on vision tasks.

The behavior is very similar to our results on vision tasks, such as those in Figure 4 and Figure 9.

4.4. Stiffness as a function of distance between datapoints

We investigated stiffness between two datapoints as a function of their distance in the input space in order to measure how large the patches of the learned function that move together under gradient updates are (as illustrated in Figure 1). We focused on datapoints from the same class. Examples of our results are shown in Figure 11 and in the Supplementary material.


Figure 11. Stiffness between images of the same class as a function of their input-space distance for 4 different stages of training of a CNN on FASHION MNIST. We add a linear fit to the data and measure its intersection with the y = 0 line of zero stiffness to estimate the dynamical critical length ξ. ξ decreases with training time, suggesting that gradient updates influence distant datapoints less and less as training progresses, eventually leading to overfitting.

Fitting a linear function to the stiffness vs. distance data, we estimate the typical distance at which stiffness goes to 0 for the first time, and call it the dynamical critical length ξ. For a pair of identical datapoints, the stiffness is by definition 1. Equivalent plots were generated for each epoch of training in each of our experiments in order to measure the ξ used in Figure 12.
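The extraction of ξ from such a plot can be implemented as a simple zero crossing of a linear fit; a minimal sketch of the estimator described above (the function name is ours) is:

```python
import numpy as np

def dynamical_critical_length(distances, stiffness):
    # Fit stiffness vs. input-space distance with a line and return the distance
    # at which the fit first crosses zero; this is the estimate of xi.
    slope, intercept = np.polyfit(distances, stiffness, deg=1)
    if slope >= 0:
        return np.nan        # no downward zero crossing within this fit
    return -intercept / slope
```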

4.5. The dynamical critical length ξ and the role of learning rate

At each epoch of training in each of our experiments, we analyzed the distribution of within-class stiffness between images based on their distance, and extracted the zero crossing, which we call the dynamical critical length ξ. In Figure 12 and in the Supplementary material we summarize the dependence of ξ on the epoch of training as well as on the training loss for 5 different learning rates.

Figure 12. The effect of learning rate on the stiffness dynamical critical length ξ – the input-space distance between images over which stiffness disappears (see Figure 1 for a cartoon illustration). The upper part shows ξ for 5 different learning rates as a function of epoch, while the lower part shows it as a function of training loss in order to compare different learning rates fairly. The larger the learning rate, the smaller the stiff domains, even given the same training loss. This means that the learning rate affects the character of the function learned beyond the speed of training.

We use the training loss to make sure we are comparing runs with different learning rates at equivalent stages of training. We see that the bigger the learning rate, the smaller the domain size ξ. This is a surprising result – higher learning rates do not only lead to faster training, they in fact influence the properties of the learned function beyond that. We observe that higher learning rates lead to functions of lower ξ, i.e. functions that seem easier to bend using gradient updates.

5. Discussion

We explored the concept of neural network stiffness and used it to diagnose and characterize generalization. We studied stiffness for models trained on real datasets, and measured its variation with training iteration, class membership, distance between data points, and the choice of learning rate. We explored stiffness between pairs of data points coming both from the training set, one from the training and one from the validation set, and both from the validation set.


Training-validation stiffness is directly related to the transfer of improvements on the training set to the validation set. We used two different stiffness metrics – the sign stiffness and the cosine stiffness – to highlight different phenomena.

We explored models trained on vision tasks (MNIST, FASHION MNIST, and CIFAR-10/100) and on a natural language task (MNLI) through the lens of stiffness. In essence, stiffness measures the alignment of gradients taken at different input data points, which we show is equivalent to asking whether a weight update based on one input will benefit the loss on another. We demonstrate the connection between stiffness and generalization and show that with the onset of overfitting to the training data, stiffness decreases and eventually reaches 0, where even gradient updates taken with respect to images of a particular class stop benefiting other members of the same class. This happens even within the training set itself, and therefore could potentially be used as a diagnostic for early stopping to prevent overfitting.

Having established the usefulness of stiffness as a diagnostic tool for generalization, we explored its dependence on class membership. We find that in general gradient updates with respect to a member of a class help to improve the loss on data points in the same class, i.e. that members of the same class have high stiffness with respect to each other. This holds at initialization as well as throughout most of the training. The pattern breaks when the model starts overfitting to the training set, after which within-class stiffness eventually reaches 0. We observe this behavior with fully-connected and convolutional neural networks on MNIST, FASHION MNIST, and CIFAR-10/100.

Stiffness between inputs from different classes relates to the generality of the features being learned and to the within-task transfer of improvement from class to class. With the onset of overfitting, the stiffness between different classes regresses to 0, as does within-class stiffness. We extended this analysis to groups of classes (super-classes) on CIFAR-100, and went a step above that to super-super-classes (groups of super-classes). Using these semantically meaningful ways of grouping images together, we verified that the network is aware of more coarse-grained semantic groups of images, going beyond the 100 fine-grained classes it was trained on, and that stiffness is a useful tool to diagnose this.

We also investigated the characteristic size of stiff regions in our trained networks at different stages of training (see Figure 1 for an illustration). By studying stiffness between two datapoints and measuring their distance in the input space, we observed that the farther apart the datapoints and the higher the epoch of training, the less stiffness there is between them on average. This allowed us to define the dynamical critical length ξ – an input-space distance over which stiffness between input points decays to 0. ξ corresponds to the size of the stiff regions – patches of the data space that can move together when a gradient update is applied, provided it is an infinitesimal gradient step.

We investigated the effect of learning rate on stiffness by observing how ξ changes as a function of epoch and of the training loss for different learning rates. We show that the higher the learning rate, the smaller the ξ, i.e. for high learning rates the patches of input space that are improved together are smaller. This holds both as a function of epoch and of training loss, which we used in order to compare runs with different learning rates fairly. This points towards the regularization role of the learning rate on the kind of function we learn. We observe significant differences in the characteristic size of regions of input space that react jointly to gradient updates based on the learning rate used to train them.

6. Conclusion

Why stiffness? We find stiffness to be a concept worthwhile of study since it bridges several different, seemingly unrelated directions of research. 1) It directly relates to generalization, which is one of the key questions of the field, 2) it is (as we have verified) sensitive to the semantic content of the input datapoints beyond the level of classes on which the neural net was trained explicitly, and 3) it relates to the Hessian of the loss with respect to the weights, and therefore to the loss landscape and optimization literature. Understanding its properties therefore holds a potential for linking these areas together.

Empirical observations. In this paper, we have observed the dependence of stiffness on a) class membership, b) distance between datapoints, c) learning rate, and d) stage of training. We did this on a range of vision tasks and a natural language task. We observed that stiffness is sensitive to the semantic content of images going beyond class, to super-class, and even super-super-class level on CIFAR-100. By defining the concept of the dynamical critical length ξ, we showed that higher learning rates lead to learned functions that are easier to bend (i.e. they are less stiff) using gradient updates, even though they perform equally well in terms of accuracy.

Future directions. One immediate extension of the concept of stiffness would be to ascertain the role stiffness might play in architecture search. For instance, we expect locality (as in CNNs) to be reflected in higher stiffness. It is quite possible that stiffness could be a guiding parameter for meta-learning and explorations in the space of architectures; however, this is beyond the scope of this paper and a potential avenue for future work.


References

Guido Montúfar, Razvan Pascanu, Kyunghyun Cho, and Yoshua Bengio. On the number of linear regions of deep neural networks. In NIPS, 2014.

Maithra Raghu, Ben Poole, Jon Kleinberg, Surya Ganguli, and Jascha Sohl-Dickstein. On the expressive power of deep neural networks. In Doina Precup and Yee Whye Teh, editors, Proceedings of the 34th International Conference on Machine Learning, volume 70 of Proceedings of Machine Learning Research, pages 2847–2854, International Convention Centre, Sydney, Australia, 06–11 Aug 2017. PMLR. URL http://proceedings.mlr.press/v70/raghu17a.html.

Ben Poole, Subhaneil Lahiri, Maithreyi Raghu, Jascha Sohl-Dickstein, and Surya Ganguli. Exponential expressivity in deep neural networks through transient chaos. In Advances in Neural Information Processing Systems 29: Annual Conference on Neural Information Processing Systems 2016, December 5-10, 2016, Barcelona, Spain, pages 3360–3368, 2016.

Chiyuan Zhang, Samy Bengio, Moritz Hardt, Benjamin Recht, and Oriol Vinyals. Understanding deep learning requires rethinking generalization. CoRR, abs/1611.03530, 2016.

Devansh Arpit, Stanislaw K. Jastrzebski, Nicolas Ballas, David Krueger, Emmanuel Bengio, Maxinder S. Kanwal, Tegan Maharaj, Asja Fischer, Aaron C. Courville, Yoshua Bengio, and Simon Lacoste-Julien. A closer look at memorization in deep networks. In Proceedings of the 34th International Conference on Machine Learning, ICML 2017, Sydney, NSW, Australia, 6-11 August 2017, pages 233–242, 2017. URL http://proceedings.mlr.press/v70/arpit17a.html.

Sanjeev Arora, Nadav Cohen, and Elad Hazan. On the optimization of deep networks: Implicit acceleration by over-parameterization. CoRR, abs/1802.06509, 2018. URL http://arxiv.org/abs/1802.06509.

Chunyuan Li, Heerad Farkhoor, Rosanne Liu, and Jason Yosinski. Measuring the intrinsic dimension of objective landscapes. CoRR, abs/1804.08838, 2018. URL http://arxiv.org/abs/1804.08838.

Stanislav Fort and Adam Scherlis. The goldilocks zone: Towards better understanding of neural network loss landscapes. Proceedings of the AAAI Conference on Artificial Intelligence, 33:3574–3581, Jul 2019. ISSN 2159-5399. doi: 10.1609/aaai.v33i01.33013574. URL http://dx.doi.org/10.1609/aaai.v33i01.33013574.

Stanislav Fort and Stanislaw Jastrzebski. Large scale structure of neural network loss landscapes, 2019.

Simon S. Du, Jason D. Lee, Haochuan Li, Liwei Wang, and Xiyu Zhai. Gradient descent finds global minima of deep neural networks. arXiv:1811.03804 [cs, math, stat], November 2018a. URL http://arxiv.org/abs/1811.03804.

Simon S. Du, Xiyu Zhai, Barnabas Poczos, and Aarti Singh. Gradient descent provably optimizes over-parameterized neural networks. arXiv:1810.02054 [cs, math, stat], October 2018b. URL http://arxiv.org/abs/1810.02054.

Jeffrey Pennington and Pratik Worah. Nonlinear random matrix theory for deep learning. In I. Guyon, U. V. Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett, editors, Advances in Neural Information Processing Systems 30, pages 2637–2646. Curran Associates, Inc., 2017.

Behrooz Ghorbani, Shankar Krishnan, and Ying Xiao. An investigation into neural net optimization via Hessian eigenvalue density, 2019.

Stanisław Jastrzębski, Zachary Kenton, Nicolas Ballas, Asja Fischer, Yoshua Bengio, and Amos Storkey. On the relation between the sharpest directions of DNN loss and the SGD step length. In International Conference on Learning Representations, 2019. URL https://openreview.net/forum?id=SkgEaj05t7.

Vardan Papyan. Measurements of three-level hierarchical structure in the outliers in the spectrum of deepnet Hessians. In ICML, 2019.

Stanislav Fort and Surya Ganguli. Emergent properties of the local geometry of neural loss landscapes, 2019.

Nasim Rahaman, Aristide Baratin, Devansh Arpit, Felix Draxler, Min Lin, Fred A. Hamprecht, Yoshua Bengio, and Aaron Courville. On the spectral bias of neural networks. arXiv e-prints, arXiv:1806.08734, June 2018.

Samuel S. Schoenholz, Justin Gilmer, Surya Ganguli, and Jascha Sohl-Dickstein. Deep information propagation. CoRR, abs/1611.01232, 2016. URL http://arxiv.org/abs/1611.01232.

Roman Novak, Yasaman Bahri, Daniel A. Abolafia, Jeffrey Pennington, and Jascha Sohl-Dickstein. Sensitivity and generalization in neural networks: an empirical study. CoRR, abs/1802.08760, 2018.

Arthur Jacot, Franck Gabriel, and Clément Hongler. Neural tangent kernel: Convergence and generalization in neural networks, 2018.


Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition, 2015.

Yann LeCun and Corinna Cortes. MNIST handwritten digit database. 2010. URL http://yann.lecun.com/exdb/mnist/.

Han Xiao, Kashif Rasul, and Roland Vollgraf. Fashion-MNIST: a novel image dataset for benchmarking machine learning algorithms. CoRR, abs/1708.07747, 2017.

Alex Krizhevsky. Learning multiple layers of features from tiny images. 2009.

Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. BERT: pre-training of deep bidirectional transformers for language understanding. CoRR, abs/1810.04805, 2018. URL http://arxiv.org/abs/1810.04805.

Adina Williams, Nikita Nangia, and Samuel R. Bowman. A broad-coverage challenge corpus for sentence understanding through inference. CoRR, abs/1704.05426, 2017. URL http://arxiv.org/abs/1704.05426.

Francois Chollet et al. Keras. https://github.com/fchollet/keras, 2015.


Supplementary Material

Figure S1. Class-membership dependence of stiffness for a CNN on MNIST at 4 different stages of training.

Figure S2. Class-membership dependence of stiffness for a CNN on CIFAR-10 at 4 different stages of training.

Figure S3. Stiffness for ResNet20v1 trained on CIFAR-100 at test accuracy 0.02.

Figure S4. Stiffness for ResNet20v1 trained on CIFAR-100 at test accuracy 0.05.


Figure S5. Stiffness for ResNet20v1 trained on CIFAR-100 at test accuracy 0.35.

Figure S6. Stiffness for ResNet20v1 trained on CIFAR-100 at test accuracy 0.60.

Figure S7. Stiffness between images of the same class as a function of their input-space distance for 4 different stages of training of a fully-connected network (FC) on FASHION MNIST.

Figure S8. Stiffness between images of the same class as a function of their input-space distance for 4 different stages of training of a CNN on CIFAR-10.

Figure S9. The effect of learning rate on the stiffness dynamical critical length ξ – the input-space distance between images over which stiffness disappears. The upper part shows ξ for 5 different learning rates as a function of epoch, while the lower part shows it as a function of training loss in order to compare different learning rates fairly. The larger the learning rate, the smaller the stiff domains.
Figure S9. The effect of learning rate on the stiffness dynamicalcritical length ξ – the input space distance between images overwhich stiffness disappears. The upper part shows ξ for 5 differentlearning rates as a function of epoch, while the lower part showsit as a function of training loss in order to be able to comparedifferent learning rates fairly. The larger the learning rate, thesmaller the stiff domains.