lstm shannonlab
TRANSCRIPT
LSTMとはRNNにおいて長い時系列データを処理する仕組み
LSTMとは?
Long Short-Term Memory「長・短期記憶」のこと。
そもそも単純なRNNでは長期にわたる記憶ができない。LSTMはこの問題点を改善する1手法。
それってどういうこと?そもそもRNNってなんだっけ?
RNNってなんだったっけ・・・?
Recurrent Neural Network(再帰型ニューラルネットワーク)の略。
通常のDNNには時系列データを処理する仕組みがない。
RNNは時系列データを処理できる。
RNNの順伝播における入力のしかた(1)
画像 および を入力する場合
t=1, 2, 3
1 2 1
7 9 8
2 3 4
é
ë
êêê
ù
û
úúú
10 20 30
40 50 60
70 80 90
é
ë
êêê
ù
û
úúú
出力層 入力層中間層
1 2 1
7 9 8
2 3 4
t=1, 2, 3
0.1 0.2 0.2
0.6 0.2 0.7
0.3 0.6 0.1
1つ目の画像(データセット1)
画像1つが訓練データ1つに、縦のピクセルが入力層のノード番号に、横のピクセルが入力順のt番目に、それぞれあたります。
入力順→
ノード番号1
2
3
RNNの順伝播における入力のしかた(2)
画像 および を入力する場合
t=1, 2, 3
1 2 1
7 9 8
2 3 4
é
ë
êêê
ù
û
úúú
10 20 30
40 50 60
70 80 90
é
ë
êêê
ù
û
úúú
出力層 入力層中間層
10 20 30
40 50 60
70 80 90
t=1, 2, 3
0.8 0.4 0.1
0.1 0.2 0.9
0.1 0.4 0.0
2つ目の画像(データセット2) 入力順→
ノード番号1
2
3
1つのデータセットにおける誤差は
n番目のデータセット における誤差 は
t=1
t=2
0
1
0
1
入力層 出力層中間層
an0
2
an0
1
an11
bn12
an12
bn0
2
bn0
1
bn11
yn0
1
yn11
yn12
yn0
1
0
1
0
1
0
1
0
1
目標値出力値
誤差
入力値
bn0
1 bn0
2
bn11 bn1
2
é
ë
êê
ù
û
úú En
En = f ynkt ,ank
t( )k
åt
å f:誤差関数
ミニバッチの場合の誤差は・・・
全N個のデータセットからm個のデータセットでミニバッチをつくる。
このミニバッチを用いて確率的勾配降下法で学習する。
まずn番目のデータセットの誤差 は
これをm個データセットにわたり足す。
例えば、誤差関数が交差エントロピー であれば
En
En = f ynkt ,ank
t( )k
åt=1,2
å である。
E = Enn=1
m
å = f ynkt ,ank
t( )k
åt=1,2
ån=1
m
å
f ynkt ,ank
t( ) = ankt log ynk
t
E = dnkt log ynk
t
k
åt=1,2
ån=1
m
å となる。
重みを更新するには何を求めればいい?
RNNでは入力層 i 番目ノードから中間層 j 番目ノードへの t 時刻目における重み は
ここで は と変形できます。
そうすると、このδは などと出力側のδの関数として表されます。
よって、δは出力側から再帰的に次々と決まっていきます。
そこでまず、各δを再帰的に求め、その後δから重みwの更新量を求めます。
w jt( )¢ = w j
t -e¶E
¶w jt
w jt
として更新します。
¶E
¶w jt
=¶E
¶u jt
¶u jt
¶w jt
= d jt ¶u j
t
¶w jt
¶E
¶w jt
d jt = f dk
t( )
δを再帰的に求めれば、重みは求まる。これが誤差逆伝播法。
誤差逆伝播でδを求めてみる
t=1における隠れ層ノード 0番目の はt=1の出力層からのδとt=2の隠れ層からのδからなる
d0
1 = wk0
out,1dkout,1 +
k=0,1
å w j '0d j '2
j '=0,1
åæ
èçç
ö
ø÷÷ f ' u0
1( )
d0
1
t=1
t=2
0
入力層 i 出力層 k中間層 j
∵ d jt =
¶E
¶u jt
æ
èçç
ö
ø÷÷
d0
1d0
out,1
d1
out,1
d1
2
d0
211
1
0
0
u0
1
RNNの勾配消失問題とは
tが多くなると、隠れ層ノードのδは何度も掛け算される。
t=1
t=2
j=0
j=1
k=0
k=1
d0
1 = wk0
out,1dkout,1 +
k=0,1
å w j '0d j '2
j '=0,1
åæ
èçç
ö
ø÷÷ f ' u0
1( )
t=3
d0
2 = wk0
out,2dkout,2 +
k=0,1
å w j '0d j '3
j '=0,1
åæ
èçç
ö
ø÷÷ f ' u0
2( )
δが伝播するにつれてwやf’(u)が何度もかけ算される
δは指数関数的に増加、もしくは減少
RNNの勾配消失問題とは
tが多くなると、隠れ層ノードのδにはwやf’(u)が何度も掛け算される
δは指数関数的に増加、もしくは減少する
勾配 も指数関数的に増加、もしくは減少する¶E
¶w= d
¶u
¶w
勾配消失問題を解決する
Constant Error Carousel(CEC)で問題を解決する
入力層 出力層
u jt = w jixi
i
å
s jt = f u j
t( )+ s jt-1
f u jt( ) z j
t = f s jt( )s j
t
CECの順伝播は
CECの順伝播は中央のメモリセルユニットから次の時刻のメモリセルユニットへ伝播する
入力層 出力層
CECの誤差逆伝播はδが消失しない
CECの逆伝播は中央のメモリセルユニットから前の時刻へ伝播する
d0
1 = wk0
out,1dkout,1 +
k=0,1
å w0 j 'd j '2
æ
èçç
ö
ø÷÷ f ' u0
1( ) d0
2 = wk0
out,2dkout,2 +
k=0,1
å w0 j 'd j '3
æ
èçç
ö
ø÷÷ f ' u0
2( )
, なので、δは増加や減少が少ない
勾配消失が起きにくい
f ' u( ) =1
しかし全てのデータを記憶し続けるため、外れ値やノイズも溜め込んでしまう
w0 j =1
CECへの入力を制御する
Imput gateを設けてCECへの入力を制御する
Imput gateは0から1の連続的な値をとり、0で入力を遮断し、1で入力を全て通過させる
入力層 出力層
f
出力層へ流れる値も制御したい
f
0から1の連続的な値をとる
f
CECからの出力を制御する
Output gateを設けてCECからの出力を制御する
Output gateは0から1の連続的な値をとり、0で出力を遮断し、1で出力を全て通過させる
入力層 出力層
f
CECの値がなかなか更新されない
f f
f
CECの値をすぐ更新できるようにする
Forget gateを設けてCECからの再帰出力を制御する
Forget gateは0から1の連続的な値をとり、0で出力を遮断し、1で出力を全て通過させる
入力層 出力層
f
各gateへの入力は、現在の入力系列の値と1時刻前の隠れ層からの出力である。
f ff
f
後者はoutput gateで制御されるため、隠れ状態が反映されない
長期依存を可能とするLSTM
CECの記憶そのものをgateの制御に使う
入力層 出力層
f
ff
f
f
LSTMの順伝播の計算(1)メモリセル
メモリセルからの出力は入力と同じく、忘却ゲートからの出力と入力ゲートからの出力からなる。
ここで は同時刻の入力層からの出力と、1時刻前の中間層からの出力の和である。
入力層出力層
xit, z j '
t-1
s jt = gj
I ,t f ujt( ) + gj
F,ts jt-1
u jt
f
f u jt( )
g jI ,t f uj
t( )s jt
g jF,ts j
t-1
s jt-1
u jt
u jt = w ji
inxit +
i
å w jj 'z j 't-1
j '
å
メモリセル
LSTMの順伝播の計算(2)入力ゲート
入力ゲートへの入力は、同時刻の入力層からの入力、1時刻前の中間層からの出力、1時刻前のメモリユニットからの出力である。
これらの入力をシグモイド関数などの活性化関数で処理したものを出力する。
出力値は0から1の間となり、0はゲートが完全に閉じた状態、1はゲートが完全に開いた状態となる。
入力層出力層
xit, z j '
t-1
u jI ,t
f u jI ,t( )
s jt-1
g jI ,t = f uj
I ,t( ) = f w jiI ,inxi
t +i
å w jj 'I z j '
t-1
j '
å + s jt-1
æ
èçç
ö
ø÷÷
g jI ,t
入力ゲート
LSTMの順伝播の計算(3)忘却ゲート
忘却ゲートへの入力は、同時刻の入力層からの入力、1時刻前の中間層からの出力、1時刻前のメモリユニットからの出力である。
これらの入力をシグモイド関数などの活性化関数で処理したものを出力する。
出力値は0から1の間となり、0はゲートが完全に閉じた状態、1はゲートが完全に開いた状態となる。
入力層出力層
xit, z j '
t-1
f u jF,t( )
s jt-1
g jF,t
忘却ゲート
u jF,t
g jF,t = f u j
I ,t( ) = f w jiF,inxi
t +i
å w jj 'F z j '
t-1
j '
å + s jt-1
æ
èçç
ö
ø÷÷
LSTMの順伝播の計算(4)出力ゲート
出力ゲートへの入力は、同時刻の入力層からの入力、1時刻前の中間層からの出力、同時刻のメモリユニットからの出力である。
これらの入力をシグモイド関数などの活性化関数で処理したものを出力する。
出力値は0から1の間となり、0はゲートが完全に閉じた状態、1はゲートが完全に開いた状態となる。
入力層出力層
f u jO,t( )
s jt
g jO,t
出力ゲート
g jO,t = f u j
O,t( ) = f w jiO,inxi
t +i
å w jj 'O z j '
t-1
j '
å + s jt-1
æ
èçç
ö
ø÷÷
u jO,t
xit, z j '
t-1
LSTMの順伝播の計算(5)全体の出力
メモリユニット全体からの出力は、メモリセルからの出力 を活性化関数で処理したものに対し、出力ゲートを乗じたものとなる。
入力層出力層
z jt = g j
O,t f s jt( )
f s jt( )s j
t
s jt
z jt
LSTMの逆伝播の計算(1)出力ユニット
下図出力ユニットからの出力は同時刻の出力層方向と1時刻後の中間層方向へ順伝播する。よって、出力ユニットの は以下のように展開できる。
入力層t時刻の出力層へ
d jO,t =
¶E
¶ujO,t
=¶E
¶ukout,t
¶ukout,t
¶ujO,t
k
å +¶E
¶uj 't+1
¶u j 't+1
¶ujO,t
j '
å
g jO,t = f u j
O,t( )
出力ユニットu jO,t
f s jt( ) t+1時刻の中間層へ
d jO,t
t時刻の出力層分 t+1時刻の中間層分
¶ukout,t
¶u jO,t
=¶ wkj
outg jO,t f s j
t( )( )¶u j
O,t= wkj
out f s jt( )
¶g jO,t
¶u jO,t
= wkjout f s j
t( ) f ' u jO,t( )
LSTMの逆伝播の計算(1)出力ユニット
よって、 は、t時出力層の 、t+1時中間層の を用いて
d jO,t =
¶E
¶ukout,t
¶ukout,t
¶u jO,t
k
å +¶E
¶u j 't+1
¶u j 't+1
¶u jO,t
j '
å = dkout,t ¶uk
out,t
¶u jO,t
k
å + d j 't+1 ¶uj '
t+1
¶ujO,t
j '
å
d jO,t dk
out,t d j 't+1
である。同様に も計算できるので、結局 は¶u j '
t+1
¶u jO,t
d jO,t = dk
out,twkjout
k
å + d j 't+1w j ' j
j '
åæ
èçç
ö
ø÷÷ f s j
t( ) f ' ujO,t( )
d jO,t
となる。ここで
となる。
LSTM逆伝播の計算(2)活性化ユニット
下図活性化ユニットの出力は同時刻の出力層と1時刻後の中間層へ順伝播する。よって、活性化ユニットの は と同様に以下のようにできる。
入力層t時刻の出力層へ
d jA,t =
¶E
¶s jt
= dkout,twkj
out
k
å + d j 't+1w j ' j
j '
åæ
èçç
ö
ø÷÷g j
O,t f ' s jt( )
活性化ユニットf s j
t( ) t+1時刻の中間層へ
d jA,t
s jt
g jO,t = f u j
O,t( )
d jO,t
LSTM逆伝播の計算(3)メモリセル
メモリセルの は を入力し、 を出力する恒等写像である。
メモリセルの出力先は、外部出力向け、メモリセル自身への帰還、入力ゲート、出力ゲート、忘却ゲートの5つである。
よって、それぞれの を 、 、 、 、 とおくと
入力層
d jcell,t =d j
A,t +gjF,t+1d j
cell,t+1 +d jI ,t+1 +d j
O,t +d jF,t+1
メモリセルf s j
t( )
d jcell,t
s jt
g jO,t = f u j
O,t( )
d jO,t
s jt
s jt
s jt+1
s jt+1
s jt+1
s jt s j
t
d jcell,t+1d j
A,td d j
I ,t+1 d jF,t+1
LSTM逆伝播の計算(4)入力側ユニット
下図入力側ユニットの は を入力し、 を出力する。
その後、入力ゲートを通過してメモリセルへと渡される。よっては
入力層
d jt =
¶E
¶u jt
=¶E
¶s jt
¶s jt
¶u jt
=d jcell,t ¶s j
t
¶u jt
=d jcell,t
¶ g jI ,t f u j
t( )( )¶u j
t= d j
cell,tg jI ,t f ' u j
t( )
入力側ユニット
d jt
u jt
u jt
f u jt( )
gtI ,t f uj
t( )
f u jt( )
d jt
となる。
LSTM逆伝播の計算(5)忘却ユニット
下図忘却ユニットの出力は忘却ゲートで と掛け合わさり、メモリセルへ順伝播する。よって、忘却ユニットの は以下のようにできる。
入力層
d jF,t =
¶E
¶u jF,t
=¶E
¶s jt
¶s jt
¶ujF,t
= d jcell,t ¶s j
t
¶u jF,t
=d jcell,t
¶ s jt-1 f u j
F,t( )( )¶u j
F,t=d j
cell,ts jt-1 f ' u j
F,t( )
f s jt( )
d jF,t
s jt-1
g jF,t = f u j
F,t( )
忘却ユニット
u jF,t
g jF,t
LSTM逆伝播の計算(6)入力ユニット
下図入力ユニットの出力は入力ゲートで と掛け合わさり、メモリセルへ順伝播する。よって、入力ユニットの は以下のようにできる。
入力層
d jI ,t =
¶E
¶ujI,t
=¶E
¶s jt
¶s jt
¶ujI,t
=d jcell,t ¶s j
t
¶u jI,t
= d jcell,t
¶ f ujt( ) f ujI ,t( )( )¶uj
I ,t= d j
cell,t f ujt( ) f ' u jI ,t( )
d jI ,t
s jt
入力ユニット
u jI ,t
f u jt( )
g jI ,t = f uj
I ,t( )s jt-1
f u jt( )
コンサルティング&開発業務に関するお問い合わせ
アイフォーコム東京株式会社 (電気需要予測に関してこちら)〒221-0835 横浜市神奈川区鶴屋町3-29-11 アイフォーコム横浜ビル電話:045-412-3010(代表)
Shannon Lab株式会社 (機械学習に関してはこちら)〒165-0026 東京都中野区新井5丁目29-1西武信用金庫新井薬師ビル503電話: 042-644-0013(代表)
本事業はアイフォーコム東京株式会社様の電気需要予測プロジェクトの一環です。