
Page 1: Transformers are Graph Neural Networks

Transformers are Graph Neural Networks

Chaitanya K. Joshi

Graph Deep Learning Reading Group

Full Blogpost:

https://graphdeeplearning.github.io/post/transformers-are-gnns/

Page 2: Transformers are Graph Neural Networks

Some success stories: GNNs for RecSys

Page 3: Transformers are Graph Neural Networks

Some success stories: GNNs for RecSys

Page 4: Transformers are Graph Neural Networks

Another success story? The Transformer architecture for NLP

Page 5: Transformers are Graph Neural Networks
Page 6: Transformers are Graph Neural Networks

Representation Learning for NLP

Page 7: Transformers are Graph Neural Networks

Breaking down the Transformer

We update the hidden feature h_i of the i'th word in a sentence S from layer ℓ to layer ℓ+1 as follows:

h_i^{ℓ+1} = Σ_{j ∈ S} w_ij ( V^ℓ h_j^ℓ ),   where   w_ij = softmax_j( (Q^ℓ h_i^ℓ) · (K^ℓ h_j^ℓ) )

where j ∈ S denotes the set of words in the sentence and Q, K, V are learnable linear weights (the Query, Key, and Value matrices, respectively).
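To make the update concrete, here is a minimal NumPy sketch of this single-head attention layer; the function names and array shapes are illustrative, not from the slides:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention_update(H, Q, K, V):
    """One attention update for all words at once.
    H: (n_words, d) hidden features h_j of the sentence S at layer l;
    Q, K, V: (d, d) learnable linear weights."""
    queries = H @ Q                           # Q h_i for every word i
    keys    = H @ K                           # K h_j for every word j
    values  = H @ V                           # V h_j for every word j
    w = softmax(queries @ keys.T, axis=-1)    # w_ij = softmax_j(q_i . k_j)
    return w @ values                         # h_i at layer l+1 = sum_j w_ij (V h_j)
```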


Page 10: Transformers are Graph Neural Networks

Multi-head Attention

Bad random initializations can destabilize the learning process of this dot-product attention mechanism. We can ‘hedge our bets’ by concatenating multiple attention ‘heads’, each with its own learnable weights, and projecting the result back down:

h_i^{ℓ+1} = Concat( head_1, …, head_K ) O^ℓ,   where   head_k = Attention( Q^{k,ℓ} h_i^ℓ, K^{k,ℓ} h_j^ℓ, V^{k,ℓ} h_j^ℓ )

and Q^{k,ℓ}, K^{k,ℓ}, V^{k,ℓ} are the learnable weights of the k'th attention head, while O^ℓ is a down-projection that restores the original feature dimension.
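A rough sketch of this multi-head variant, building on the attention_update helper above; the list of heads and the output projection O are illustrative assumptions:

```python
def multi_head_attention(H, heads, O):
    """heads: a list of (Q, K, V) weight triples, one per attention head;
    O: (n_heads * d, d) projection that mixes the concatenated head outputs."""
    head_outputs = [attention_update(H, Q, K, V) for (Q, K, V) in heads]
    return np.concatenate(head_outputs, axis=-1) @ O   # concatenate heads, project back to d
```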

Page 11: Transformers are Graph Neural Networks

Multi-head Attention

Page 12: Transformers are Graph Neural Networks

The Final Picture

Update each word’s features through the Multi-head Attention mechanism as a weighted sum of the features of the other words in the sentence,

+ Scaled dot-product attention

+ Normalization layers

+ Residual links

(all combined in the sketch below)
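Putting these pieces together, a compact sketch of one full layer under the same assumptions as above; the position-wise feed-forward sub-layer of the real Transformer block is omitted here for brevity:

```python
def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    sigma = x.std(axis=-1, keepdims=True)
    return (x - mu) / (sigma + eps)             # normalize the features of each word

def transformer_layer(H, heads, O, d_k):
    """Multi-head attention with scaled dot products, a residual link, and a normalization layer."""
    head_outputs = []
    for (Q, K, V) in heads:
        queries, keys, values = H @ Q, H @ K, H @ V
        w = softmax((queries @ keys.T) / np.sqrt(d_k), axis=-1)   # scaled dot-product attention
        head_outputs.append(w @ values)
    attn = np.concatenate(head_outputs, axis=-1) @ O
    return layer_norm(H + attn)                 # residual link, then normalization
```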


Page 15: Transformers are Graph Neural Networks

Representation Learning on Graphs

Page 16: Transformers are Graph Neural Networks

Graph Neural Networks

GNNs update the hidden features h_i of node i at layer ℓ via a non-linear transformation of the node’s own features added to an aggregation of the features of each neighbouring node j ∈ N(i):

h_i^{ℓ+1} = σ( U^ℓ h_i^ℓ + Σ_{j ∈ N(i)} V^ℓ h_j^ℓ )

where U, V are learnable weight matrices of the GNN layer and σ is a non-linearity.
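For comparison with the attention sketch above, a minimal NumPy version of this sum-aggregation GNN layer; representing the neighbourhoods N(i) with a binary adjacency matrix and taking σ to be ReLU are illustrative assumptions:

```python
def gnn_update(H, A, U, V):
    """H: (n_nodes, d) node features; A: (n_nodes, n_nodes) binary adjacency matrix,
    A[i, j] = 1 iff j is in N(i); U, V: learnable (d, d) weight matrices."""
    neighbour_sum = A @ (H @ V)                    # sum of V h_j over neighbours j in N(i)
    return np.maximum(0.0, H @ U + neighbour_sum)  # sigma taken to be ReLU here
```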

Page 17: Transformers are Graph Neural Networks

Connections between GNNs and Transformers

Consider a sentence as a fully connected graph of words, where every word is connected to every other word in the sentence…

Page 18: Transformers are Graph Neural Networks

Connections between GNNs and Transformers

Page 19: Transformers are Graph Neural Networks
Page 20: Transformers are Graph Neural Networks

Sum over the local neighborhood j ∈ N(i) (GNN)  vs.  sum over all words j ∈ S (Transformer)
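Under this view, the only difference in scope is the neighbourhood. A small usage example with the helpers defined earlier, where the adjacency matrix connects every word to every other word in the sentence (shapes and random features are illustrative):

```python
n_words, d = 6, 16
rng = np.random.default_rng(0)
H = rng.normal(size=(n_words, d))                       # hidden features of the words in S
A_full = np.ones((n_words, n_words)) - np.eye(n_words)  # fully connected graph: N(i) = S \ {i}
U, V = rng.normal(size=(d, d)), rng.normal(size=(d, d))
Q, K = rng.normal(size=(d, d)), rng.normal(size=(d, d))

h_gnn  = gnn_update(H, A_full, U, V)        # unweighted sum over every other word
h_attn = attention_update(H, Q, K, V)       # attention-weighted sum over all words in S
```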

Page 21: Transformers are Graph Neural Networks
Page 22: Transformers are Graph Neural Networks

Simple linear transform and sum (GNN)  vs.  weighted sum via attention (Transformer)

Page 23: Transformers are Graph Neural Networks

Standard Graph NN

Graph Attention Network
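The Graph Attention Network sits between the two: an attention-weighted sum, but restricted to the local neighbourhood. A rough sketch reusing the earlier helpers, with dot-product scores masked by the adjacency matrix; this is an illustrative simplification, not GAT's exact scoring function:

```python
def gat_style_update(H, A, Q, K, V):
    """Attention-weighted aggregation over neighbours only: scores for j outside N(i)
    are masked out before the softmax, so their weights w_ij become (near) zero."""
    scores = (H @ Q) @ (H @ K).T
    scores = np.where(A > 0, scores, -1e9)   # keep only edges present in the graph
    w = softmax(scores, axis=-1)
    return w @ (H @ V)                       # weighted sum over the neighbourhood
```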

Page 24: Transformers are Graph Neural Networks

Standard GNN  →  GAT  →  Transformer

+ Weighted sum aggregation via attention (Standard GNN → GAT)

+ Multi-head mechanism

+ Normalization layers

+ Residual links (GAT → Transformer)