Recording of the talk can be found here.

Overview of Graph Neural Networks

  • Tasks in graph learning
    • node classification (eg: fraud detection)
    • link prediction (eg: recsys)
    • graph classification (eg: drug discovery)
  • graph learning has 2 steps:
    • generate low-dim embedding of node
    • use standard classifiers from there onwards
  • GNNs can learn node, edge, graph embeddings in an end-to-end fashion and are based on message-passing between neighbors
    • aggregation operation needs to be permutation invariant
    • thereby these nets integrate node/edge/graph features as well as topology in a non-linear fashion
  • for graph classification, there'll be a final "readout" layer to compute the overall graph embedding based on embeddings of each node
  • training on large graphs is done via mini-batch training
    • with pruning of neighborhood via sampling to reduce computational complexity
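
The message-passing and readout ideas above can be sketched in plain Python. This is only an illustration, not DGL code: the function names, the toy graph, and the choice of sum aggregation with a mean readout are all assumptions, and the learned weights and non-linearity of a real GNN layer are left out.

```python
# Plain-Python sketch of one message-passing round plus a mean readout.
# All names here are illustrative; this is not DGL code.

def message_passing_round(features, edges):
    """Each node sums its in-neighbors' features and adds its own.

    features: dict mapping node -> list of floats
    edges: iterable of (src, dst) pairs
    """
    dim = len(next(iter(features.values())))
    inbox = {v: [0.0] * dim for v in features}
    for src, dst in edges:
        # Sum is permutation invariant: edge order cannot change the result.
        inbox[dst] = [a + b for a, b in zip(inbox[dst], features[src])]
    # Update: combine the old state with the aggregated message.
    return {v: [a + b for a, b in zip(features[v], inbox[v])] for v in features}


def mean_readout(features):
    """Graph-level embedding: the average of all node embeddings."""
    n = len(features)
    dim = len(next(iter(features.values())))
    return [sum(f[i] for f in features.values()) / n for i in range(dim)]


features = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}
edges = [(0, 1), (2, 1), (1, 0)]

h = message_passing_round(features, edges)
h_shuffled = message_passing_round(features, list(reversed(edges)))
graph_embedding = mean_readout(h)
```

Feeding the edges in reverse order gives the same node embeddings, which is the practical meaning of the permutation-invariance requirement.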

Deep Graph Library - an update

DGL stack

  • already has active customers using DGL via Amazon SageMaker
  • started out with multiple backends

flexible message propagation

Flexible message handling

  • full propagation
  • propagation by graph traversal: sampling on ego-network
  • propagation by random walk

DGL programming interface

DGL API

  • DGLGraph is the core abstraction
    • DGLGraph.ndata["h"] - the node representation matrix
  • simple and flexible message passing APIs
    • active set - set of nodes/edges to trigger computation on
  • three user defined functions
    • $$\phi^v$$ - transformation function on vertices
    • $$\bigoplus$$ - reduction or aggregation function
    • $$\phi^e$$ - transformation function on edges
  • update_all
    • shortcut for send(G.edges()); recv(G.nodes());
    • in other words, do a full propagation
  • heterogeneous graphs are now supported
  • a new sampling API was introduced in the v0.4.3 release
  • next plan is to look at distributed training
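
To make the roles of the three user-defined functions and the active set concrete, here is a rough plain-Python sketch of the send/recv pattern. It mimics the semantics described above and is not DGL's implementation; the helper names (`copy_src`, `replace`) and the scalar toy features are assumptions made for illustration.

```python
# Illustrative sketch of the send/recv pattern; not DGL's implementation.

def send(edges, h, phi_e):
    """Apply the edge function phi_e on each active edge, producing messages."""
    return [(dst, phi_e(h[src], h[dst])) for src, dst in edges]


def recv(nodes, messages, h, reduce_fn, phi_v):
    """Reduce each active node's mailbox, then apply the vertex function phi_v."""
    new_h = dict(h)
    for v in nodes:
        mailbox = [m for dst, m in messages if dst == v]
        if mailbox:  # nodes without messages keep their old state
            new_h[v] = phi_v(h[v], reduce_fn(mailbox))
    return new_h


def update_all(edges, h, phi_e, reduce_fn, phi_v):
    """Full propagation: send on every edge, then recv on every node."""
    return recv(list(h), send(edges, h, phi_e), h, reduce_fn, phi_v)


def copy_src(h_src, h_dst):
    """phi^e: forward the source feature unchanged."""
    return h_src


def replace(h_old, reduced):
    """phi^v: overwrite the node state with the reduced messages."""
    return reduced


h = {0: 1.0, 1: 2.0, 2: 4.0}
edges = [(0, 1), (2, 1), (1, 2)]

full = update_all(edges, h, copy_src, sum, replace)
# Active set restricted to node 1 and its two incoming edges.
partial = recv([1], send([(0, 1), (2, 1)], h, copy_src), h, sum, replace)
```

With the restricted active set only node 1 changes, while `update_all` touches every node that receives at least one message.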

Using GNNs for basic graph tasks

  • using Zachary's karate club network to demo APIs of DGL
  • DGL expects node IDs to be consecutive integers starting from 0
  • dgl.graph is the main graph structure which provides IO and query methods
  • dgl.graph.ndata member is a dict that holds node features as tensors
  • dgl.graph.edata member is a dict that holds edge features as tensors
  • definition of models (and their training) in dgl is similar to pytorch
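
Since node IDs must be consecutive integers starting from 0, raw data with arbitrary labels has to be relabeled before a graph is built. A minimal sketch of such a preprocessing step, in plain Python, might look like the following; the function name and the example labels are hypothetical and not part of DGL.

```python
def relabel_consecutive(edges):
    """Map arbitrary hashable node labels to 0..n-1 in order of first appearance."""
    ids = {}
    relabeled = []
    for src, dst in edges:
        for node in (src, dst):
            if node not in ids:
                ids[node] = len(ids)
        relabeled.append((ids[src], ids[dst]))
    return relabeled, ids


# Hypothetical raw edge list with string labels.
raw_edges = [("alice", "bob"), ("carol", "bob"), ("alice", "carol")]
edges, ids = relabel_consecutive(raw_edges)
```

Keeping the `ids` mapping around makes it possible to translate predictions back to the original labels afterwards.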

How to customize graph-conv using message passing APIs

  • dgl.graph.ndata (and edata) can be locally updated using dgl.graph.local_scope()
  • Message passing APIs in DGL follow the general formulation used in Message Passing Neural Networks (MPNN). The relevant equations are as follows:
    • $$m_{uv}^{(l)} = Message^{(l)}(h_u^{(l-1)}, h_v^{(l-1)}, e_{uv}^{(l)})$$
    • $$m_v^{(l)} = Aggregation_{u \in N(v)}(m_{uv}^{(l)})$$
    • $$h_v^{(l)} = Update^{(l)}(h_v^{(l-1)}, m_v^{(l)})$$
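
As one concrete instance of these equations (an illustrative choice, not necessarily the model shown in the talk), take the message to be the source node's feature, the aggregation to be a mean, and the update to be a linear map followed by a non-linearity. This gives a GraphSAGE-style layer, where $$W^{(l)}$$ is an assumed learnable weight matrix, $$\sigma$$ a non-linearity, and $$\|$$ denotes concatenation:

$$m_{uv}^{(l)} = h_u^{(l-1)}$$

$$m_v^{(l)} = \frac{1}{|N(v)|} \sum_{u \in N(v)} m_{uv}^{(l)}$$

$$h_v^{(l)} = \sigma\left(W^{(l)} \left[h_v^{(l-1)} \,\|\, m_v^{(l)}\right]\right)$$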

Scale GNN to giant graphs using DGL

  • for large graphs it is recommended to use a mini-batch training procedure
  • minibatch generation on graphs
    • sample the target nodes
      • not done inside DGL
      • using numpy.random.choice or torch.utils.data.DataLoader
    • randomly sample the neighbors (multi-hop)
      • dgl.sampling.sample_neighbors for one layer
      • dgl.in_subgraph is similar to the above, but copies all incoming edges instead of sampling them
    • construct the minibatch
      • dgl.to_block
      • renames nodes to be consecutive (for memory efficiency as well as perf)
      • constructs a bipartite graph for message passing (COO format)
  • to_block() has an option include_dst_in_src to help with self-loops during aggregation
  • inference
    • we need to infer for all nodes in each layer
    • thus, inference is typically costlier than training!

DGL on real world applications

  • eg: recommender systems using GCMC
  • introduced traditional collaborative filtering based approaches
  • such a user-item matrix can be converted into bipartite graphs
  • apply_edges() computes edge features
  • heterogeneous graphs:
    • graphs with different types of nodes and/or edges. (eg: user-item graphs)
    • dgl.heterograph for creating such graphs
    • accessing node features is via: g.nodes["ntype"].data["name"]
    • accessing edge features is via: g.edges["etype"].data["name"]
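
As a rough illustration of the user-item conversion and of an edge-wise computation, the sketch below turns a small rating matrix into typed bipartite edges and scores each edge with a dot product of its endpoint embeddings. It is plain Python, not DGL: the function names, the rating matrix, and the embeddings are made up, and the dict keyed by (source type, edge type, destination type) is only assumed to resemble the kind of input dgl.heterograph accepts, so check the documentation for your DGL version.

```python
def ratings_to_bipartite(ratings):
    """Turn a dense user-item matrix into typed edges; 0 means "no rating"."""
    src, dst, values = [], [], []
    for user, row in enumerate(ratings):
        for item, rating in enumerate(row):
            if rating != 0:
                src.append(user)
                dst.append(item)
                values.append(rating)
    return {("user", "rates", "item"): (src, dst)}, values


def score_edges(edges, user_emb, item_emb):
    """Edge-wise computation: dot product of the two endpoint embeddings."""
    src, dst = edges
    return [
        sum(a * b for a, b in zip(user_emb[u], item_emb[i]))
        for u, i in zip(src, dst)
    ]


# Hypothetical 2 users x 3 items rating matrix.
ratings = [[5, 0, 3],
           [0, 4, 0]]
graph_data, rating_values = ratings_to_bipartite(ratings)
edges = graph_data[("user", "rates", "item")]

# Hypothetical 2-dimensional embeddings.
user_emb = [[1.0, 0.0], [0.0, 1.0]]
item_emb = [[2.0, 1.0], [0.5, 3.0], [1.0, 1.0]]
scores = score_edges(edges, user_emb, item_emb)
```

In a recommender setting, the observed `rating_values` would serve as training targets for the per-edge `scores`.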