Analyzing Learned Molecular Representations for Property Prediction
Proposal
Main paper can be found here.
- compares neural nets built on handcrafted molecular fingerprints against GNNs across a wide variety of public and private datasets
- introduces a new GNN architecture called D-MPNN
- this model learns representations by passing messages along bonds (edges) instead of atoms (nodes), the usual approach
Summary
- Authors observe that when a dataset has only a small number of molecules, traditional NN models on fixed fingerprints can outperform GNNs
- They also find that a hybrid representation combining fixed fingerprints with the representations learned by the GNN, together with hyperparameter optimization (HPO) and ensembling, consistently yields the best results
- D-MPNN = Directed Message Passing Neural Network
- D-MPNN uses messages associated with directed edges (aka bonds), instead of vertices (aka atoms).
- This setup prevents a message from being passed straight back to the edge it came from, thereby avoiding noisy loops in the propagation
- propagation equations at the 't'th layer are:
- $$m_{vw}^{t+1} = \sum_{k \in N(v) \setminus \{w\}} M_t(x_v, x_k, h_{kv}^t)$$
- $$h_{vw}^{t+1} = U_t(h_{vw}^t, m_{vw}^{t+1})$$
- $$m_{vw}^{t+1}$$ = message associated with the directed edge from 'v' to 'w' at layer 't+1'
- $$h_{vw}^t$$ = hidden state of the directed edge from 'v' to 'w' at layer 't'
- $$M_t$$ = message passing function at the 't'th layer
- $$U_t$$ = update/transformation function at the 't'th layer
- $$x_v$$ = initial features of atom 'v'
- Note that $$h_{vw}^t \neq h_{wv}^t$$! The two directed edges between a pair of atoms carry distinct hidden states, and they never pass messages directly to each other: the sum over $$N(v) \setminus \{w\}$$ explicitly excludes the reverse edge
- edge hidden states are initialized with: $$h_{vw}^0 = \tau (W_i \, \mathrm{cat}(x_v, e_{vw}))$$
- $$e_{vw}$$ = features of the bond between 'v' and 'w'
- $$\tau$$ = ReLU activation function
- $$W_i$$ = learnable weights matrix
- $$U_t(h_{vw}^t, m_{vw}^{t+1}) = \tau (h_{vw}^0 + W_m m_{vw}^{t+1})$$
- $$W_m$$ = learnable weights matrix
- the $$h_{vw}^0$$ term provides a skip connection back to the initial edge state
- $$M_t(x_v, x_k, h_{kv}^t) = h_{kv}^t$$, i.e., the message function simply forwards the incoming edge state. They are keeping this simple. (A code sketch of the full edge update follows below.)
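Putting the initialization, message, and update equations together: below is a minimal PyTorch sketch of the edge message passing loop. This is my own illustration rather than the authors' chemprop implementation; the tensor layout (`src`, `dst`, `rev` index vectors over directed edges) and the subtract-the-reverse-edge aggregation are assumptions about one reasonable way to realize the math.

```python
import torch
import torch.nn as nn

class DMPNNEncoder(nn.Module):
    """Minimal sketch of D-MPNN edge message passing (shapes/names assumed)."""

    def __init__(self, atom_dim, bond_dim, hidden_dim, T):
        super().__init__()
        self.W_i = nn.Linear(atom_dim + bond_dim, hidden_dim)  # input weights W_i
        self.W_m = nn.Linear(hidden_dim, hidden_dim)           # message weights W_m
        self.T = T  # number of message passing layers

    def forward(self, x, e, src, dst, rev):
        # x:   [num_atoms, atom_dim]   atom features x_v
        # e:   [num_edges, bond_dim]   bond features e_{vw}, one row per DIRECTED edge
        # src: [num_edges]             source atom v of each directed edge v->w
        # dst: [num_edges]             target atom w of each directed edge v->w
        # rev: [num_edges]             index of the reverse edge w->v for each edge v->w
        h0 = torch.relu(self.W_i(torch.cat([x[src], e], dim=1)))  # h^0_{vw}
        h = h0
        for _ in range(self.T):
            # Sum all incoming edge states h^t_{kv} at every atom v ...
            agg = torch.zeros(x.size(0), h.size(1), device=h.device)
            agg.index_add_(0, dst, h)
            # ... then exclude the reverse edge: m^{t+1}_{vw} = agg[v] - h^t_{wv}
            m = agg[src] - h[rev]
            # h^{t+1}_{vw} = ReLU(h^0_{vw} + W_m m^{t+1}_{vw}); h^0 acts as the skip connection
            h = torch.relu(h0 + self.W_m(m))
        return h  # final directed-edge states h^T_{vw}
```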
- finally, to generate the atom representations:
- $$m_v = \sum_{k \in N(v)} h_{kv}^T$$
- $$T$$ = total number of message passing layers
- $$h_v = \tau (W_a \, \mathrm{cat}(x_v, m_v))$$
- $$W_a$$ = learnable weights matrix
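Continuing the same sketch (same assumed tensor names), the edge-to-atom aggregation is a sum of incoming edge states followed by one learned projection:

```python
import torch

# W_a could be, e.g., torch.nn.Linear(atom_dim + hidden_dim, hidden_dim)
def atom_states(x, h_edges, dst, W_a):
    # h_edges: [num_edges, hidden_dim] final edge states h^T from the encoder above
    # m_v = sum_{k in N(v)} h^T_{kv}: sum incoming edge states at each atom v
    m = torch.zeros(x.size(0), h_edges.size(1), device=x.device)
    m.index_add_(0, dst, h_edges)
    # h_v = ReLU(W_a cat(x_v, m_v))
    return torch.relu(W_a(torch.cat([x, m], dim=1)))
```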
- readout phase to generate the final molecular representation:
- $$h = \sum_{v \in G} h_v$$
- final prediction is: $$y' = f(\mathrm{cat}(h, h_f))$$
- $$h_f$$ = other fingerprints/representation of the molecule
- $$f$$ = feedforward NN
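And a sketch of the readout plus the prediction head; the concatenation with a fixed fingerprint vector $$h_f$$ is exactly where the hybrid representation from the summary enters. The dimensions below are assumptions for illustration, not values from the paper.

```python
import torch
import torch.nn as nn

def predict(h_atoms, h_f, ffn):
    # h = sum_{v in G} h_v: sum-pool atom states into one molecule vector
    h = h_atoms.sum(dim=0)
    # y' = f(cat(h, h_f)): append the fixed fingerprint, then apply the FFN
    return ffn(torch.cat([h, h_f], dim=0))

# For a single regression target, f could be a small feed-forward network, e.g.:
ffn = nn.Sequential(
    nn.Linear(300 + 2048, 300),  # assumes hidden_dim=300, 2048-bit fingerprint
    nn.ReLU(),
    nn.Linear(300, 1),
)
```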