TRANSCRIPT
k-NN and decision trees
CS 446
Overview
Today we’ll cover two standard machine learning methods.
[Figure: scatter plot of sample data over axes x1 and x2.]
Nearest neighbors. Decision trees.
1. Nearest neighbor rule
Example: OCR for digits
1. Classify images of handwritten digits by the actual digits they represent.
2. Classification problem: Y = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9} (a discrete set).
Digits from the standard MNIST dataset (LeCun, Cortes, Burges).
Nearest neighbor (NN) classifier
Given: labeled examples D := {(x_i, y_i)}_{i=1}^n
Predictor: fD : X → Y
On input x,
1. Find the point x_i among {x_i}_{i=1}^n that is "closest" to x (the nearest neighbor).
2. Return yi.
Remark. For nearest neighbor, there is no "fitting" procedure, other than memorizing the training data!
How to measure distance?
A default choice for distance between points in R^d is the Euclidean distance (also called the ℓ2 distance):
‖u − v‖_2 := sqrt( Σ_{i=1}^d (u_i − v_i)^2 )

(where u = (u_1, u_2, . . . , u_d) and v = (v_1, v_2, . . . , v_d)).
Grayscale 28×28 pixel images.
Treat as vectors (of 784 real-valued features) that live in [0, 1]^784.
Note. Spatial information lost!
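A minimal sketch of this rule in Python/NumPy (the names nn_predict, X_train, y_train are our own, not from the lecture):

```python
import numpy as np

def nn_predict(X_train, y_train, x):
    """1-NN: return the label of the training point closest to x."""
    # Squared Euclidean distances to all n training points; sqrt is monotone,
    # so it can be skipped when only the argmin is needed.
    dists = np.sum((X_train - x) ** 2, axis=1)
    return y_train[np.argmin(dists)]
```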
2. Evaluation
Example: OCR for digits with NN classifier
- Classify images of handwritten digits by the digits they depict.
- X = R^784, Y = {0, 1, . . . , 9}.
- Given: labeled examples D := {(x_i, y_i)}_{i=1}^n ⊂ X × Y.
- Construct NN classifier fD using D.
- Question: Is this classifier any good?
Error rate
- Error rate of classifier f on a set of labeled examples D:

  err(f; D) := (# of (x, y) ∈ D such that f(x) ≠ y) / |D|

  (i.e., the fraction of D on which f disagrees with the paired label).
- Question: What is err(fD; D)?
- Note. We can also write

  err(f; D) = (1/|D|) Σ_{(x,y)∈D} 1[f(x) ≠ y] = (1/|D|) Σ_{(x,y)∈D} ℓ_zo(f(x), y),

  where ℓ_zo(ŷ, y) = 1[ŷ ≠ y] is the zero-one (classification) loss.
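As a sketch, the error rate is a one-liner in NumPy (the names are illustrative):

```python
import numpy as np

def error_rate(f, X, y):
    """Fraction of labeled examples (x, y) on which f(x) != y."""
    predictions = np.array([f(x) for x in X])
    return np.mean(predictions != y)
```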
A better way to evaluate the classifier
- Split the labeled examples {(x_i, y_i)}_{i=1}^n into two sets (randomly):
  - Training data S.
  - Test data T.
- Only use training data S to construct NN classifier fS.
  - Training error rate of fS: err(fS; S) = 0%.
- Use test data T to evaluate accuracy of fS.
  - Test error rate of fS: err(fS; T) = 3.09%.
Is this good?
This gap err(fS; T) − err(fS; S) = 3.09% means we have overfit.
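A sketch of the random split (the 80/20 fraction and the fixed seed are illustrative choices, not from the lecture):

```python
import numpy as np

def train_test_split(X, y, test_fraction=0.2, seed=0):
    """Randomly partition labeled examples into training and test sets."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(len(X))
    n_test = int(test_fraction * len(X))
    test_idx, train_idx = perm[:n_test], perm[n_test:]
    return X[train_idx], y[train_idx], X[test_idx], y[test_idx]
```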
Why does NN work on new data?
Consider any new point x. As training set size increases, the kth nearest neighbor becomes closer and closer.
[Figure: 2-D sample with the neighborhood of a query point shrinking as the training set grows.]
Note. Can prove this so long as k/n → 0!
3. Upgrading the nearest neighbor rule
Diagnostics
- Some mistakes made by the NN classifier (test point in T, nearest neighbor in S):
[Figure: misclassified test points shown alongside their nearest neighbors.]
- The first mistake (correct label is "2") could've been avoided by looking at the three nearest neighbors (whose labels are "8", "2", and "2").
[Figure: the test point and its three nearest neighbors.]
k-nearest neighbors classifier
Given: labeled examples D := {(x_i, y_i)}_{i=1}^n
Predictor: fD,k : X → Y.
On input x,
1. Find the k points x_{i_1}, x_{i_2}, . . . , x_{i_k} among {x_i}_{i=1}^n "closest" to x (the k nearest neighbors).
2. Return the plurality of y_{i_1}, y_{i_2}, . . . , y_{i_k}.
(Break ties in both steps arbitrarily.)
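A sketch of the rule, extending nn_predict from earlier (breaking ties by Counter's ordering is one arbitrary choice):

```python
import numpy as np
from collections import Counter

def knn_predict(X_train, y_train, x, k):
    """k-NN: plurality label among the k training points closest to x."""
    dists = np.sum((X_train - x) ** 2, axis=1)
    nearest = np.argsort(dists)[:k]              # indices of the k nearest neighbors
    votes = Counter(y_train[i] for i in nearest)
    return votes.most_common(1)[0][0]            # plurality label
```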
Effect of k
- Smaller k: smaller training error rate.
- Larger k: higher training error rate, but predictions are more "stable" due to voting.

OCR digits classification:
k               1       3       5       7       9
Test error rate 0.0309  0.0295  0.0312  0.0306  0.0341
Choosing k
The hold-out set approach
1. Pick a subset V ⊂ S (hold-out set, a.k.a. validation set).
2. For each k ∈ {1, 3, 5, . . . }:
  - Construct k-NN classifier fS\V,k using S \ V.
  - Compute the error rate of fS\V,k on V (the "hold-out error rate").
3. Pick the k that gives the smallest hold-out error rate.
(There are many other approaches.)
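A sketch of the hold-out loop, reusing knn_predict from above (the candidate grid is illustrative):

```python
import numpy as np

def choose_k(X_train, y_train, X_val, y_val, candidates=(1, 3, 5, 7, 9)):
    """Return the k whose k-NN classifier has the smallest hold-out error."""
    best_k, best_err = None, float("inf")
    for k in candidates:
        preds = np.array([knn_predict(X_train, y_train, x, k) for x in X_val])
        err = np.mean(preds != y_val)
        if err < best_err:
            best_k, best_err = k, err
    return best_k
```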
4. Other issues with nearest neighbor prediction
Better distance functions
- Strings: edit distance
  dist(u, v) = # insertions/deletions/mutations needed to change u to v.
- Images: shape context distance
  dist(u, v) = how much "warping" is required to change u to v.
- Audio waveforms: dynamic time warping
- Etc.

OCR digits classification:
Distance        ℓ2     ℓ3     Tangent  Shape
Test error rate 3.09%  2.83%  1.10%    0.63%
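For concreteness, edit distance admits a classic dynamic-programming computation; a sketch:

```python
def edit_distance(u, v):
    """Minimum number of insertions/deletions/mutations turning u into v."""
    m, n = len(u), len(v)
    # dp[i][j] = edit distance between the prefixes u[:i] and v[:j]
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(m + 1):
        dp[i][0] = i                              # delete all of u[:i]
    for j in range(n + 1):
        dp[0][j] = j                              # insert all of v[:j]
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            cost = 0 if u[i - 1] == v[j - 1] else 1
            dp[i][j] = min(dp[i - 1][j] + 1,          # deletion
                           dp[i][j - 1] + 1,          # insertion
                           dp[i - 1][j - 1] + cost)   # mutation (or match)
    return dp[m][n]
```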
Bad features
Caution: nearest neighbor classifier can be broken by bad/noisy features!
[Figure: with Feature 1 alone, the classes y = 0 and y = 1 are well separated; adding a noisy Feature 2 can scramble which points are nearest.]

Curse of dimension. Given poly(d) random unit-norm points in R^d, with probability > 99%, each is squared distance 2 ± O(1/√d) from all others.

Looking ahead: if we train a deep network on other data, we can chop it in the middle, and those (k ≪ d) outputs become NN features!
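A quick simulation of the concentration claim above (illustrative; 500 points per dimension, using ‖u − v‖² = 2 − 2⟨u, v⟩ for unit vectors):

```python
import numpy as np

rng = np.random.default_rng(0)
for d in (2, 20, 200, 2000):
    X = rng.standard_normal((500, d))
    X /= np.linalg.norm(X, axis=1, keepdims=True)   # random unit-norm points
    sq = 2.0 - 2.0 * (X @ X.T)                      # pairwise squared distances
    off = sq[~np.eye(len(X), dtype=bool)]           # drop the zero diagonal
    print(f"d={d:4d}  mean={off.mean():.3f}  std={off.std():.3f}")
# The mean hovers near 2 while the spread shrinks on the order of 1/sqrt(d).
```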
Computation
- Naïve method for computing NN predictions: O(n) distance computations.
- Better: organize training data in a data structure to improve look-up time.
  - Space: O(nd) for n points in R^d.
  - Query time: O(2^d log n) in the worst case.
- Finding an "approximate" NN can be more efficient.
  E.g., how to quickly find a point among the top 1% closest points?
- Popular technique: locality-sensitive hashing (LSH).
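One classical LSH family, random hyperplanes for angular distance, as a sketch; this illustrates the idea and is not necessarily the variant the lecture has in mind:

```python
import numpy as np

def lsh_buckets(points, planes):
    """Hash each point to a bit pattern: its signs against random hyperplanes."""
    bits = points @ planes.T > 0           # (n, b) boolean sign pattern
    return [tuple(row) for row in bits]

rng = np.random.default_rng(0)
d, b = 784, 16                             # 16 hyperplanes -> up to 2^16 buckets
planes = rng.standard_normal((b, d))
# Nearby points (small angle) tend to land in the same bucket, so a query
# scans only its own bucket instead of all n training points.
```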
5. Decision trees
Decision trees
Directly optimize tree structure for good classification.
A decision tree is a function f : X → Y, represented by a binary tree in which:
- Each tree node is associated with a splitting rule g : X → {0, 1}.
- Each leaf node is associated with a label y ∈ Y.

When X = R^d, typically one only considers splitting rules of the form

g(x) = 1{x_i > t}

for some i ∈ [d] and t ∈ R. These are called axis-aligned or coordinate splits.
(Notation: [d] := {1, 2, . . . , d}.)

[Example tree: the root tests x1 > 1.7; one child is the leaf y = 1, the other tests x2 > 2.8 and has leaves y = 2 and y = 3.]
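A sketch of this representation in Python (the field names and the left/right convention are our own):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Node:
    feature: Optional[int] = None    # None marks a leaf
    t: float = 0.0                   # threshold for the split 1{x[feature] > t}
    left: Optional["Node"] = None    # child taken when x[feature] <= t
    right: Optional["Node"] = None   # child taken when x[feature] > t
    label: Optional[int] = None      # prediction stored at a leaf

def tree_predict(node, x):
    """Traverse from the root to a leaf; return the leaf's label."""
    while node.feature is not None:
        node = node.right if x[node.feature] > node.t else node.left
    return node.label

# The example tree above:
tree = Node(feature=0, t=1.7,
            left=Node(label=1),
            right=Node(feature=1, t=2.8,
                       left=Node(label=2), right=Node(label=3)))
```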
Decision tree example

[Figure: iris data plotted as sepal length/width (x1) versus petal length/width (x2), with the splits drawn in as the tree grows.]

Classifying irises by sepal and petal measurements:
- X = R^2, Y = {1, 2, 3}
- x1 = ratio of sepal length to width
- x2 = ratio of petal length to width

The tree is grown one split at a time: start with a single leaf (y = 2), split on x1 > 1.7 (leaves y = 1 and y = 3), then split again on x2 > 2.8, yielding the final tree with leaves y = 1, y = 2, and y = 3.
Basic decision tree learning algorithm
[Progression: a single leaf (y = 2) → one split on x1 > 1.7 (leaves y = 1 and y = 3) → a further split on x2 > 2.8 (leaves y = 2 and y = 3).]

Basic "top-down" greedy algorithm
- Initially, the tree is a single leaf node containing all (training) data.
- Loop:
  - Pick the leaf ℓ and rule h that maximally reduce uncertainty.
  - Split the data in ℓ using h, and grow the tree accordingly.
  . . . until some stopping criterion is satisfied.
[Label of a leaf is the plurality label among the data contained in the leaf.]
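A sketch of the inner step, scoring every coordinate split on one leaf's data; `uncertainty` is any of the notions defined on the next slides:

```python
import numpy as np

def best_split(X, y, uncertainty):
    """Best axis-aligned split 1{x_i > t}, scored by reduction in uncertainty."""
    base = uncertainty(y)
    best = (None, None, 0.0)                  # (feature i, threshold t, reduction)
    for i in range(X.shape[1]):
        for t in np.unique(X[:, i]):
            mask = X[:, i] > t
            if mask.all() or not mask.any():
                continue                      # everything falls on one side
            # Children's uncertainty, weighted by their shares of the data.
            u = (mask.mean() * uncertainty(y[mask])
                 + (~mask).mean() * uncertainty(y[~mask]))
            if base - u > best[2]:
                best = (i, t, base - u)
    return best
```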
Notions of uncertainty
Notions of uncertainty: binary case (Y = {0, 1})

Suppose in a set of examples S ⊆ X × {0, 1}, a p fraction are labeled as 1.

1. Classification error:
   u(S) := min{p, 1 − p}
2. Gini index:
   u(S) := 2p(1 − p)
3. Entropy:
   u(S) := p log(1/p) + (1 − p) log(1/(1 − p))

[Figure: the three uncertainty measures plotted as functions of p on [0, 1].]

Gini index and entropy (after some rescaling) are concave upper bounds on classification error.
Notions of uncertainty
Notions of uncertainty: general case (Y = {1, 2, . . . , K})

Suppose in S ⊆ X × Y, a p_y fraction are labeled as y (for each y ∈ Y).

1. Classification error:
   u(S) := 1 − max_{y∈Y} p_y
2. Gini index:
   u(S) := 1 − Σ_{y∈Y} p_y^2
3. Entropy:
   u(S) := Σ_{y∈Y} p_y log(1/p_y)

Each is maximized when p_y = 1/K for all y ∈ Y (i.e., equal numbers of each label in S).
Each is minimized when p_y = 1 for a single label y ∈ Y (so S is pure in label).
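Sketches of the three notions, operating on an array of labels (these plug directly into the best_split sketch above):

```python
import numpy as np

def _label_fractions(y):
    _, counts = np.unique(y, return_counts=True)
    return counts / counts.sum()              # the p_y for labels present in y

def classification_error(y):
    return 1.0 - _label_fractions(y).max()    # 1 - max_y p_y

def gini(y):
    return 1.0 - np.sum(_label_fractions(y) ** 2)   # 1 - sum_y p_y^2

def entropy(y):
    p = _label_fractions(y)                   # absent labels contribute 0
    return -np.sum(p * np.log(p))             # sum_y p_y log(1/p_y)
```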
Limitations of uncertainty notions
Suppose X = R^2 and Y = {red, blue}, and the data is as follows:

[Figure: points over axes x1 and x2 arranged so that no single coordinate threshold separates red from blue.]

Every split of the form 1{x_i > t} provides no reduction in uncertainty (whether based on classification error, Gini index, or entropy).

Upshot: zero reduction in uncertainty may not be a good stopping condition.
Stopping criterion
Many alternatives; two common choices are:
1. Stop when the tree reaches a pre-specified size.
Involves setting additional “tuning parameters” (similar to k in k-NN).
2. Stop when every leaf is pure. (More common.)
Serious danger of overfitting spurious structure due to sampling.
[Figure: example data over axes x1 and x2 illustrating spurious structure fit by pure leaves.]
Overfitting
[Figure: error versus number of nodes in the tree, showing the training error and true error curves.]

- Training error goes to zero as the number of nodes in the tree increases.
- True error decreases initially, but eventually increases due to overfitting.
What can be done about overfitting?
Preventing overfitting

Split training data S into two parts, S′ and S′′:
- Use the first part S′ to grow the tree until all leaves are pure.
- Use the second part S′′ to choose a good pruning of the tree.

Pruning algorithm

Loop:
- Replace any tree node by a leaf node if doing so improves the error on S′′.
. . . until no more such improvements are possible.

This can be done efficiently using dynamic programming (a bottom-up traversal of the tree).

The independence of S′ and S′′ makes it unlikely for spurious structures in each to perfectly align.
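A recursive sketch of the pruning pass. It assumes every node, internal ones included, stores in `label` the plurality training label of the data that reached it while growing; Node and tree_predict are from the earlier sketch:

```python
import numpy as np

def prune(node, X_val, y_val):
    """Bottom-up: collapse a subtree into a leaf whenever that does not
    increase the number of mistakes on the held-out data S''."""
    if node.feature is None or len(y_val) == 0:
        return node
    mask = X_val[:, node.feature] > node.t
    node.left = prune(node.left, X_val[~mask], y_val[~mask])
    node.right = prune(node.right, X_val[mask], y_val[mask])
    subtree_errors = sum(tree_predict(node, x) != y for x, y in zip(X_val, y_val))
    leaf_errors = np.sum(y_val != node.label)
    if leaf_errors <= subtree_errors:
        return Node(label=node.label)         # replace the subtree by a leaf
    return node
```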
Overfitting and model complexity

With both k-NN and decision trees, overfitting (a large test-vs-train gap) occurs when:
- k in k-NN is too small, or
- the decision tree's number of leaves (or a surrogate like depth) is too large.

More generally, high model complexity leads to overfitting.
Looking forward. In neural networks, the right notion of complexity is less obvious.
Example: Spam filtering
Data
- 4601 e-mail messages, 39.4% of which are spam.
- Y = {spam, not spam}
- E-mails represented by 57 features:
  - 48: percentage of e-mail words that are a specific word (e.g., "free", "business")
  - 6: percentage of e-mail characters that are a specific character (e.g., "!")
  - 3: other features (e.g., average length of ALL-CAPS words)

Results

Using a variant of the greedy algorithm to grow the tree, then pruning with a validation set, the chosen tree has just 17 leaves; test error is 9.3%.

                 predicted not spam   predicted spam
true not spam         57.3%                4.0%
true spam              5.3%               33.4%
Example: Spam filtering

[Figure 9.5 from The Elements of Statistical Learning (2nd ed., Hastie, Tibshirani & Friedman, 2009): the pruned tree for the spam example. The split variables (e.g., ch$ < 0.0555, remove < 0.06, hp < 0.405, CAPAVE < 2.7505) are shown on the branches, the classification is shown in every node, and the numbers under the terminal nodes indicate misclassification rates on the test data.]
6. Final remarks
Nearest neighbors and decision trees
Today we covered two standard machine learning methods.
[Figure: scatter plot of sample data over axes x1 and x2.]

Nearest neighbors.
Training/fitting: memorize data.
Testing/predicting: find the k closest memorized points, return the plurality label.
Overfitting? Vary k.

Decision trees.
Training/fitting: greedily partition space, reducing "uncertainty".
Testing/predicting: traverse tree, output leaf label.
Overfitting? Limit or prune the tree.

Note: both methods can also output real numbers (regression, not classification); return the median/mean of the neighbors (k-NN) or of the points reaching the leaf (trees).
Key takeaways
1. Two methods: k-NN and decision trees.
  - For k-NN, we have specified the complete method for fixed k.
  - To choose k, we've given the heuristic of using a validation set.
  - For decision trees, we have not fully specified the training procedure, only how to make the splitting decisions.
2. Training error, test error, overfitting.
3. Features.