lstm shannonlab

32

Upload: shannon-lab

Post on 15-Apr-2017

1.455 views

Category:

Data & Analytics


2 download

TRANSCRIPT

Page 1: Lstm shannonlab
Page 2: Lstm shannonlab

LSTMとはRNNにおいて長い時系列データを処理する仕組み

Page 3: Lstm shannonlab

LSTMとは?

Long Short-Term Memory「長・短期記憶」のこと。

そもそも単純なRNNでは長期にわたる記憶ができない。LSTMはこの問題点を改善する1手法。

それってどういうこと?そもそもRNNってなんだっけ?

Page 4: Lstm shannonlab

RNNってなんだったっけ・・・?

Recurrent Neural Network(再帰型ニューラルネットワーク)の略。

通常のDNNには時系列データを処理する仕組みがない。

RNNは時系列データを処理できる。

Page 5: Lstm shannonlab

RNNの順伝播における入力のしかた(1)

画像 および を入力する場合

t=1, 2, 3

1 2 1

7 9 8

2 3 4

é

ë

êêê

ù

û

úúú

10 20 30

40 50 60

70 80 90

é

ë

êêê

ù

û

úúú

出力層 入力層中間層

1 2 1

7 9 8

2 3 4

t=1, 2, 3

0.1 0.2 0.2

0.6 0.2 0.7

0.3 0.6 0.1

1つ目の画像(データセット1)

画像1つが訓練データ1つに、縦のピクセルが入力層のノード番号に、横のピクセルが入力順のt番目に、それぞれあたります。

入力順→

ノード番号1

Page 6: Lstm shannonlab

RNNの順伝播における入力のしかた(2)

画像 および を入力する場合

t=1, 2, 3

1 2 1

7 9 8

2 3 4

é

ë

êêê

ù

û

úúú

10 20 30

40 50 60

70 80 90

é

ë

êêê

ù

û

úúú

出力層 入力層中間層

10 20 30

40 50 60

70 80 90

t=1, 2, 3

0.8 0.4 0.1

0.1 0.2 0.9

0.1 0.4 0.0

2つ目の画像(データセット2) 入力順→

ノード番号1

Page 7: Lstm shannonlab

1つのデータセットにおける誤差は

n番目のデータセット における誤差 は

t=1

t=2

0

1

0

1

入力層 出力層中間層

an0

2

an0

1

an11

bn12

an12

bn0

2

bn0

1

bn11

yn0

1

yn11

yn12

yn0

1

0

1

0

1

0

1

0

1

目標値出力値

誤差

入力値

bn0

1 bn0

2

bn11 bn1

2

é

ë

êê

ù

û

úú En

En = f ynkt ,ank

t( )k

åt

å f:誤差関数

Page 8: Lstm shannonlab

ミニバッチの場合の誤差は・・・

全N個のデータセットからm個のデータセットでミニバッチをつくる。

このミニバッチを用いて確率的勾配降下法で学習する。

まずn番目のデータセットの誤差 は

これをm個データセットにわたり足す。

例えば、誤差関数が交差エントロピー であれば

En

En = f ynkt ,ank

t( )k

åt=1,2

å である。

E = Enn=1

m

å = f ynkt ,ank

t( )k

åt=1,2

ån=1

m

å

f ynkt ,ank

t( ) = ankt log ynk

t

E = dnkt log ynk

t

k

åt=1,2

ån=1

m

å となる。

Page 9: Lstm shannonlab

重みを更新するには何を求めればいい?

RNNでは入力層 i 番目ノードから中間層 j 番目ノードへの t 時刻目における重み は

ここで は と変形できます。

そうすると、このδは などと出力側のδの関数として表されます。

よって、δは出力側から再帰的に次々と決まっていきます。

そこでまず、各δを再帰的に求め、その後δから重みwの更新量を求めます。

w jt( )¢ = w j

t -e¶E

¶w jt

w jt

として更新します。

¶E

¶w jt

=¶E

¶u jt

¶u jt

¶w jt

= d jt ¶u j

t

¶w jt

¶E

¶w jt

d jt = f dk

t( )

δを再帰的に求めれば、重みは求まる。これが誤差逆伝播法。

Page 10: Lstm shannonlab

誤差逆伝播でδを求めてみる

t=1における隠れ層ノード 0番目の はt=1の出力層からのδとt=2の隠れ層からのδからなる

d0

1 = wk0

out,1dkout,1 +

k=0,1

å w j '0d j '2

j '=0,1

åæ

èçç

ö

ø÷÷ f ' u0

1( )

d0

1

t=1

t=2

0

入力層 i 出力層 k中間層 j

∵ d jt =

¶E

¶u jt

æ

èçç

ö

ø÷÷

d0

1d0

out,1

d1

out,1

d1

2

d0

211

1

0

0

u0

1

Page 11: Lstm shannonlab

RNNの勾配消失問題とは

tが多くなると、隠れ層ノードのδは何度も掛け算される。

t=1

t=2

j=0

j=1

k=0

k=1

d0

1 = wk0

out,1dkout,1 +

k=0,1

å w j '0d j '2

j '=0,1

åæ

èçç

ö

ø÷÷ f ' u0

1( )

t=3

d0

2 = wk0

out,2dkout,2 +

k=0,1

å w j '0d j '3

j '=0,1

åæ

èçç

ö

ø÷÷ f ' u0

2( )

δが伝播するにつれてwやf’(u)が何度もかけ算される

δは指数関数的に増加、もしくは減少

Page 12: Lstm shannonlab

RNNの勾配消失問題とは

tが多くなると、隠れ層ノードのδにはwやf’(u)が何度も掛け算される

δは指数関数的に増加、もしくは減少する

勾配 も指数関数的に増加、もしくは減少する¶E

¶w= d

¶u

¶w

Page 13: Lstm shannonlab

勾配消失問題を解決する

Constant Error Carousel(CEC)で問題を解決する

入力層 出力層

u jt = w jixi

i

å

s jt = f u j

t( )+ s jt-1

f u jt( ) z j

t = f s jt( )s j

t

Page 14: Lstm shannonlab

CECの順伝播は

CECの順伝播は中央のメモリセルユニットから次の時刻のメモリセルユニットへ伝播する

入力層 出力層

Page 15: Lstm shannonlab

CECの誤差逆伝播はδが消失しない

CECの逆伝播は中央のメモリセルユニットから前の時刻へ伝播する

d0

1 = wk0

out,1dkout,1 +

k=0,1

å w0 j 'd j '2

æ

èçç

ö

ø÷÷ f ' u0

1( ) d0

2 = wk0

out,2dkout,2 +

k=0,1

å w0 j 'd j '3

æ

èçç

ö

ø÷÷ f ' u0

2( )

, なので、δは増加や減少が少ない

勾配消失が起きにくい

f ' u( ) =1

しかし全てのデータを記憶し続けるため、外れ値やノイズも溜め込んでしまう

w0 j =1

Page 16: Lstm shannonlab

CECへの入力を制御する

Imput gateを設けてCECへの入力を制御する

Imput gateは0から1の連続的な値をとり、0で入力を遮断し、1で入力を全て通過させる

入力層 出力層

f

出力層へ流れる値も制御したい

f

0から1の連続的な値をとる

f

Page 17: Lstm shannonlab

CECからの出力を制御する

Output gateを設けてCECからの出力を制御する

Output gateは0から1の連続的な値をとり、0で出力を遮断し、1で出力を全て通過させる

入力層 出力層

f

CECの値がなかなか更新されない

f f

f

Page 18: Lstm shannonlab

CECの値をすぐ更新できるようにする

Forget gateを設けてCECからの再帰出力を制御する

Forget gateは0から1の連続的な値をとり、0で出力を遮断し、1で出力を全て通過させる

入力層 出力層

f

各gateへの入力は、現在の入力系列の値と1時刻前の隠れ層からの出力である。

f ff

f

後者はoutput gateで制御されるため、隠れ状態が反映されない

Page 19: Lstm shannonlab

長期依存を可能とするLSTM

CECの記憶そのものをgateの制御に使う

入力層 出力層

f

ff

f

f

Page 20: Lstm shannonlab

LSTMの順伝播の計算(1)メモリセル

メモリセルからの出力は入力と同じく、忘却ゲートからの出力と入力ゲートからの出力からなる。

ここで は同時刻の入力層からの出力と、1時刻前の中間層からの出力の和である。

入力層出力層

xit, z j '

t-1

s jt = gj

I ,t f ujt( ) + gj

F,ts jt-1

u jt

f

f u jt( )

g jI ,t f uj

t( )s jt

g jF,ts j

t-1

s jt-1

u jt

u jt = w ji

inxit +

i

å w jj 'z j 't-1

j '

å

メモリセル

Page 21: Lstm shannonlab

LSTMの順伝播の計算(2)入力ゲート

入力ゲートへの入力は、同時刻の入力層からの入力、1時刻前の中間層からの出力、1時刻前のメモリユニットからの出力である。

これらの入力をシグモイド関数などの活性化関数で処理したものを出力する。

出力値は0から1の間となり、0はゲートが完全に閉じた状態、1はゲートが完全に開いた状態となる。

入力層出力層

xit, z j '

t-1

u jI ,t

f u jI ,t( )

s jt-1

g jI ,t = f uj

I ,t( ) = f w jiI ,inxi

t +i

å w jj 'I z j '

t-1

j '

å + s jt-1

æ

èçç

ö

ø÷÷

g jI ,t

入力ゲート

Page 22: Lstm shannonlab

LSTMの順伝播の計算(3)忘却ゲート

忘却ゲートへの入力は、同時刻の入力層からの入力、1時刻前の中間層からの出力、1時刻前のメモリユニットからの出力である。

これらの入力をシグモイド関数などの活性化関数で処理したものを出力する。

出力値は0から1の間となり、0はゲートが完全に閉じた状態、1はゲートが完全に開いた状態となる。

入力層出力層

xit, z j '

t-1

f u jF,t( )

s jt-1

g jF,t

忘却ゲート

u jF,t

g jF,t = f u j

I ,t( ) = f w jiF,inxi

t +i

å w jj 'F z j '

t-1

j '

å + s jt-1

æ

èçç

ö

ø÷÷

Page 23: Lstm shannonlab

LSTMの順伝播の計算(4)出力ゲート

出力ゲートへの入力は、同時刻の入力層からの入力、1時刻前の中間層からの出力、同時刻のメモリユニットからの出力である。

これらの入力をシグモイド関数などの活性化関数で処理したものを出力する。

出力値は0から1の間となり、0はゲートが完全に閉じた状態、1はゲートが完全に開いた状態となる。

入力層出力層

f u jO,t( )

s jt

g jO,t

出力ゲート

g jO,t = f u j

O,t( ) = f w jiO,inxi

t +i

å w jj 'O z j '

t-1

j '

å + s jt-1

æ

èçç

ö

ø÷÷

u jO,t

xit, z j '

t-1

Page 24: Lstm shannonlab

LSTMの順伝播の計算(5)全体の出力

メモリユニット全体からの出力は、メモリセルからの出力 を活性化関数で処理したものに対し、出力ゲートを乗じたものとなる。

入力層出力層

z jt = g j

O,t f s jt( )

f s jt( )s j

t

s jt

z jt

Page 25: Lstm shannonlab

LSTMの逆伝播の計算(1)出力ユニット

下図出力ユニットからの出力は同時刻の出力層方向と1時刻後の中間層方向へ順伝播する。よって、出力ユニットの は以下のように展開できる。

入力層t時刻の出力層へ

d jO,t =

¶E

¶ujO,t

=¶E

¶ukout,t

¶ukout,t

¶ujO,t

k

å +¶E

¶uj 't+1

¶u j 't+1

¶ujO,t

j '

å

g jO,t = f u j

O,t( )

出力ユニットu jO,t

f s jt( ) t+1時刻の中間層へ

d jO,t

t時刻の出力層分 t+1時刻の中間層分

Page 26: Lstm shannonlab

¶ukout,t

¶u jO,t

=¶ wkj

outg jO,t f s j

t( )( )¶u j

O,t= wkj

out f s jt( )

¶g jO,t

¶u jO,t

= wkjout f s j

t( ) f ' u jO,t( )

LSTMの逆伝播の計算(1)出力ユニット

よって、 は、t時出力層の 、t+1時中間層の を用いて

d jO,t =

¶E

¶ukout,t

¶ukout,t

¶u jO,t

k

å +¶E

¶u j 't+1

¶u j 't+1

¶u jO,t

j '

å = dkout,t ¶uk

out,t

¶u jO,t

k

å + d j 't+1 ¶uj '

t+1

¶ujO,t

j '

å

d jO,t dk

out,t d j 't+1

である。同様に も計算できるので、結局 は¶u j '

t+1

¶u jO,t

d jO,t = dk

out,twkjout

k

å + d j 't+1w j ' j

j '

åæ

èçç

ö

ø÷÷ f s j

t( ) f ' ujO,t( )

d jO,t

となる。ここで

となる。

Page 27: Lstm shannonlab

LSTM逆伝播の計算(2)活性化ユニット

下図活性化ユニットの出力は同時刻の出力層と1時刻後の中間層へ順伝播する。よって、活性化ユニットの は と同様に以下のようにできる。

入力層t時刻の出力層へ

d jA,t =

¶E

¶s jt

= dkout,twkj

out

k

å + d j 't+1w j ' j

j '

åæ

èçç

ö

ø÷÷g j

O,t f ' s jt( )

活性化ユニットf s j

t( ) t+1時刻の中間層へ

d jA,t

s jt

g jO,t = f u j

O,t( )

d jO,t

Page 28: Lstm shannonlab

LSTM逆伝播の計算(3)メモリセル

メモリセルの は を入力し、 を出力する恒等写像である。

メモリセルの出力先は、外部出力向け、メモリセル自身への帰還、入力ゲート、出力ゲート、忘却ゲートの5つである。

よって、それぞれの を 、 、 、 、 とおくと

入力層

d jcell,t =d j

A,t +gjF,t+1d j

cell,t+1 +d jI ,t+1 +d j

O,t +d jF,t+1

メモリセルf s j

t( )

d jcell,t

s jt

g jO,t = f u j

O,t( )

d jO,t

s jt

s jt

s jt+1

s jt+1

s jt+1

s jt s j

t

d jcell,t+1d j

A,td d j

I ,t+1 d jF,t+1

Page 29: Lstm shannonlab

LSTM逆伝播の計算(4)入力側ユニット

下図入力側ユニットの は を入力し、 を出力する。

その後、入力ゲートを通過してメモリセルへと渡される。よっては

入力層

d jt =

¶E

¶u jt

=¶E

¶s jt

¶s jt

¶u jt

=d jcell,t ¶s j

t

¶u jt

=d jcell,t

¶ g jI ,t f u j

t( )( )¶u j

t= d j

cell,tg jI ,t f ' u j

t( )

入力側ユニット

d jt

u jt

u jt

f u jt( )

gtI ,t f uj

t( )

f u jt( )

d jt

となる。

Page 30: Lstm shannonlab

LSTM逆伝播の計算(5)忘却ユニット

下図忘却ユニットの出力は忘却ゲートで と掛け合わさり、メモリセルへ順伝播する。よって、忘却ユニットの は以下のようにできる。

入力層

d jF,t =

¶E

¶u jF,t

=¶E

¶s jt

¶s jt

¶ujF,t

= d jcell,t ¶s j

t

¶u jF,t

=d jcell,t

¶ s jt-1 f u j

F,t( )( )¶u j

F,t=d j

cell,ts jt-1 f ' u j

F,t( )

f s jt( )

d jF,t

s jt-1

g jF,t = f u j

F,t( )

忘却ユニット

u jF,t

g jF,t

Page 31: Lstm shannonlab

LSTM逆伝播の計算(6)入力ユニット

下図入力ユニットの出力は入力ゲートで と掛け合わさり、メモリセルへ順伝播する。よって、入力ユニットの は以下のようにできる。

入力層

d jI ,t =

¶E

¶ujI,t

=¶E

¶s jt

¶s jt

¶ujI,t

=d jcell,t ¶s j

t

¶u jI,t

= d jcell,t

¶ f ujt( ) f ujI ,t( )( )¶uj

I ,t= d j

cell,t f ujt( ) f ' u jI ,t( )

d jI ,t

s jt

入力ユニット

u jI ,t

f u jt( )

g jI ,t = f uj

I ,t( )s jt-1

f u jt( )

Page 32: Lstm shannonlab

コンサルティング&開発業務に関するお問い合わせ

アイフォーコム東京株式会社 (電気需要予測に関してこちら)〒221-0835 横浜市神奈川区鶴屋町3-29-11 アイフォーコム横浜ビル電話:045-412-3010(代表)

Shannon Lab株式会社 (機械学習に関してはこちら)〒165-0026 東京都中野区新井5丁目29-1西武信用金庫新井薬師ビル503電話: 042-644-0013(代表)

本事業はアイフォーコム東京株式会社様の電気需要予測プロジェクトの一環です。