Proposal

The main paper can be found here.

  • compares neural nets built on handcrafted molecular fingerprints against GNNs across a wide variety of public and private datasets
  • introduces a new GNN architecture called D-MPNN
  • this model learns representations by passing messages along bonds, instead of the usual approach of passing them between atoms

Summary

  • The authors observe that when a dataset has a small number of molecules, traditional NN models built on fixed fingerprints outperform GNNs
  • They also find that a hybrid representation combining fixed fingerprints with the GNN's learned representations, together with hyperparameter optimization and ensembling, consistently yields the best results
  • D-MPNN = Directed Message Passing Neural Network
  • D-MPNN uses messages associated with directed edges (aka bonds), instead of vertices (aka atoms).
  • This setup prevents a message from being sent straight back to the node it came from, which would otherwise create unnecessary loops in the information flow
  • propagation equations at the 't'th layer are (a PyTorch sketch of this step follows the symbol definitions below):
    • $$m_{vw}^{t+1} = \sum_{k \in N(v) \setminus \{w\}} M_t(x_v, x_k, h_{kv}^t)$$
    • $$h_{vw}^{t+1} = U_t(h_{vw}^t, m_{vw}^{t+1})$$
  • $$m_{vw}^{t+1}$$ = message for the directed edge from 'v' to 'w' at the 't'th layer, aggregated over the edges entering 'v' (excluding the one coming from 'w')
  • $$h_{vw}^t$$ = hidden representation of the directed edge from 'v' to 'w' at the 't'th layer
  • $$M_t$$ = message passing function at 't'th layer
  • $$U_t$$ = update/transformation function at 't'th layer
  • $$x_v$$ = node (atom) embeddings for vertex 'v'
  • Note that $$h_{vw}^t \neq h_{wv}^t$$! The two directions of the same bond keep separate hidden states and never pass messages directly to each other
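
A minimal PyTorch sketch of this aggregation step, assuming a flat directed-edge representation (the names `aggregate_messages`, `edge_dst`, and `rev_edge` are hypothetical, not from the paper's code). Since the paper's message function simply forwards the incoming hidden state (see below), the sum over $$N(v) \setminus \{w\}$$ can be computed as the full sum of incoming edge states at 'v' minus the reverse edge's state:

```python
import torch

def aggregate_messages(h, edge_dst, rev_edge, num_nodes):
    """For every directed edge v->w, sum h_{kv}^t over k in N(v), excluding k = w.

    h        : (E, d) hidden states of the E directed edges
    edge_dst : (E,)   destination node of each directed edge
    rev_edge : (E,)   index of each edge's reverse (w->v for v->w)
    """
    # Total incoming message at every node: a_v = sum over k in N(v) of h_{kv}
    a = torch.zeros(num_nodes, h.size(1), dtype=h.dtype, device=h.device)
    a.index_add_(0, edge_dst, h)
    # Edge v->w starts at v, which is the destination of its reverse edge w->v.
    edge_src = edge_dst[rev_edge]
    # Subtract the reverse edge's state to exclude 'w' from the neighbourhood sum.
    return a[edge_src] - h[rev_edge]
```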
  • edge hidden states are initialized with: $$h_{vw}^0 = \tau (W_i cat(x_v, e_{vw}))$$
    • $$e_{vw}$$ = edge embeddings (bond features) for the edge between 'v' and 'w'
    • $$\tau$$ = ReLU activation function
    • $$W_i$$ = learnable weights matrix
  • $$U_t(h_{vw}^t, m_{vw}^{t+1}) = \tau (h_{vw}^0 + W_m m_{vw}^{t+1})$$
    • $$W_m$$ = learnable weights matrix
    • the $$h_{vw}^0$$ term effectively acts as a skip connection back to the initial edge state
  • $$M_t(x_v, x_k, h_{kv}^t) = h_{kv}^t$$, i.e., the message function simply forwards the hidden state of the incoming edge. They are keeping this simple.
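
Putting the initialization, message, and update functions together, a hedged sketch of the full edge-update loop (reusing the hypothetical `aggregate_messages` helper above; the class and argument names are illustrative, not the paper's implementation):

```python
import torch
import torch.nn as nn

class DMPNNCore(nn.Module):
    """Edge-update loop for the equations above; assumes the
    aggregate_messages sketch from earlier is in scope."""
    def __init__(self, node_dim, edge_dim, hidden_dim, T):
        super().__init__()
        self.W_i = nn.Linear(node_dim + edge_dim, hidden_dim)  # init weights W_i
        self.W_m = nn.Linear(hidden_dim, hidden_dim)           # message weights W_m
        self.T = T                                             # number of layers

    def forward(self, x, e, edge_dst, rev_edge):
        # For edge v->w, v is the destination of its reverse edge w->v.
        edge_src = edge_dst[rev_edge]
        # h^0_{vw} = ReLU(W_i cat(x_v, e_{vw}))
        h0 = torch.relu(self.W_i(torch.cat([x[edge_src], e], dim=1)))
        h = h0
        for _ in range(self.T):
            # m^{t+1}_{vw}: sum of incoming edge states at v, excluding the reverse edge
            m = aggregate_messages(h, edge_dst, rev_edge, x.size(0))
            # h^{t+1}_{vw} = ReLU(h^0_{vw} + W_m m^{t+1}_{vw}); h^0 acts as a skip connection
            h = torch.relu(h0 + self.W_m(m))
        return h
```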
  • finally, to generate the atom representations
    • $$m_v = \sum_{k \in N(v)} h_{kv}^T$$
    • $$T$$ = total number of layers in the network
    • $$h_v = \tau (W_a cat(x_v, m_v))$$
    • $$W_a$$ = learnable weights matrix
  • the readout phase to generate the final molecular representation is
    • $$h = \sum_{v \in G} h_v$$
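
Under the same hypothetical edge representation, the two readout steps might look like this sketch (`W_a` passed in as an `nn.Linear`; single-molecule case for simplicity):

```python
import torch

def readout(h, x, edge_dst, W_a):
    """h: final edge states (E, d); x: node features (V, node_dim); W_a: nn.Linear."""
    # m_v = sum over k in N(v) of h_{kv}^T : sum final incoming edge states per atom
    m = torch.zeros(x.size(0), h.size(1), dtype=h.dtype, device=h.device)
    m.index_add_(0, edge_dst, h)
    # h_v = ReLU(W_a cat(x_v, m_v))
    h_atom = torch.relu(W_a(torch.cat([x, m], dim=1)))
    # h = sum over v in G of h_v  (one molecule per graph here)
    return h_atom.sum(dim=0)
```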
  • final prediction is: $$y' = f(cat(h, h_f))$$
    • $$h_f$$ = other fixed fingerprints/representations of the molecule
    • $$f$$ = feedforward NN
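
A minimal sketch of this hybrid prediction head (the sizes and the choice of fingerprint are illustrative assumptions, not the paper's exact configuration):

```python
import torch
import torch.nn as nn

hidden_dim, fp_dim = 300, 2048      # illustrative sizes
h = torch.randn(hidden_dim)         # learned molecule vector from the readout phase
h_f = torch.randn(fp_dim)           # fixed fingerprint, e.g. Morgan/RDKit bits

# y' = f(cat(h, h_f)), with f a small feedforward network
f = nn.Sequential(
    nn.Linear(hidden_dim + fp_dim, hidden_dim),
    nn.ReLU(),
    nn.Linear(hidden_dim, 1),
)
y_hat = f(torch.cat([h, h_f], dim=-1))
```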