Learning Graph Neural Networks with Deep Graph Library
A recording of the talk can be found here.
Overview of Graph Neural Networks
- Tasks in graph learning
- node classification (e.g., fraud detection)
- link prediction (e.g., recommender systems)
- graph classification (e.g., drug discovery)
- traditional graph learning has two steps:
- generate a low-dimensional embedding for each node
- use standard classifiers on those embeddings from there onwards
- GNNs can learn node, edge, and graph embeddings in an end-to-end fashion and are
based on message-passing between neighbors
- aggregation operation needs to be permutation invariant
- these networks thereby integrate node/edge/graph features as well as topology in a non-linear fashion
- for graph classification, there is a final "readout" layer that computes the overall graph embedding from the embeddings of each node (see the sketch after this list)
- training on large graphs is done via mini-batch training
- with pruning of the neighborhood via sampling to reduce computational complexity
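A minimal sketch of the readout step mentioned above, using DGL's built-in `dgl.mean_nodes` readout; the toy graph, feature sizes, and the `"h"` key are illustrative assumptions, not from the talk:

```python
import dgl
import torch

g = dgl.rand_graph(10, 30)             # toy graph, for illustration only
g.ndata["h"] = torch.randn(10, 16)     # stand-in per-node embeddings from a GNN

# Permutation-invariant readout: average all node embeddings into a single
# graph-level embedding, which can then be fed to a standard classifier.
hg = dgl.mean_nodes(g, "h")            # shape: (1, 16)
```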
Deep Graph Library - an update
- there are already active customers using DGL via AWS SageMaker
- started out with multiple backends
- flexible message propagation:
- full propagation
- propagation by graph traversal: sampling on ego-network
- propagation by random walk
DGL programming interface
- `DGLGraph` is the core abstraction
- `DGLGraph.ndata["h"]` is the node representation matrix
- simple and flexible message passing APIs
- active set - set of nodes/edges to trigger computation on
- three user defined functions
- $$\phi^v$$ - transformation function on vertices
- $$\bigoplus$$ - reduction or aggregation function
- $$\phi^e$$ - transformation function on edges
- `update_all` is a shortcut for `send(G.edges()); recv(G.nodes())`
- in other words, it does a full propagation (see the sketch after this list)
- now heterogeneous graph is supported
- a new sampling API is introduced in the v0.4.3 release
- next plan is to look at distributed training
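A minimal sketch of `update_all` with user-defined message and reduce functions, corresponding to the $$\phi^e$$ and $$\bigoplus$$ functions above; the toy graph and feature values are assumptions for illustration:

```python
import dgl
import torch

g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])))  # toy 3-cycle
g.ndata["h"] = torch.eye(3)

# phi^e: message function, computed on edges
def message(edges):
    return {"m": edges.src["h"]}

# the reduction/aggregation function, computed on nodes;
# summing over the mailbox is permutation invariant
def reduce(nodes):
    return {"h": nodes.mailbox["m"].sum(dim=1)}

# Full propagation: send on all edges, then receive on all nodes.
# phi^v could be passed as an optional third apply-node function.
g.update_all(message, reduce)
print(g.ndata["h"])
```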
Using GNNs for basic graph tasks
- using Zachary's karate club network to demo the APIs of DGL
- DGL expects node IDs to be consecutive integers starting from 0
- `dgl.graph` is the main graph structure, which provides IO and query methods
- the `dgl.graph.ndata` member is a dict that holds node features as tensors
- the `dgl.graph.edata` member is a dict that holds edge features as tensors
- definition of models (and their training) in DGL is similar to PyTorch
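A minimal sketch of these APIs; the edge lists, feature names, and sizes are made up for illustration:

```python
import dgl
import torch

# Node IDs must be consecutive integers starting from 0
src = torch.tensor([0, 0, 1, 2])
dst = torch.tensor([1, 2, 2, 3])
g = dgl.graph((src, dst))

# ndata/edata are dicts holding node/edge features as tensors
g.ndata["feat"] = torch.randn(g.number_of_nodes(), 5)
g.edata["weight"] = torch.ones(g.number_of_edges())

print(g)  # query methods: number_of_nodes(), number_of_edges(), ...
```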
How to customize graph-conv using message passing APIs
- `dgl.graph.ndata` (and `edata`) can be locally updated using `dgl.graph.local_scope()`
- message passing APIs in DGL are a generalization of those found in Message Passing Neural Networks (MPNN). The relevant equations are as follows:
- $$m_{uv}^{(l)} = Message^{(l)}(h_u^{(l-1)}, h_v^{(l-1)}, e_{uv}^{(l)})$$
- $$m_v^{(l)} = Aggregation_{u \in N(v)}(m_{uv}^{(l)})$$
- $$h_v^{(l)} = Update^{(l)}(h_v^{(l-1)}, m_v^{(l)})$$
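A sketch of a custom graph convolution built from these three equations with DGL's message passing APIs; the layer name, feature sizes, and the mean/concat choices are assumptions, not from the talk:

```python
import torch
import torch.nn as nn
import dgl.function as fn

class CustomGraphConv(nn.Module):
    def __init__(self, in_feats, out_feats):
        super().__init__()
        self.linear = nn.Linear(in_feats * 2, out_feats)

    def forward(self, g, h):
        with g.local_scope():  # temporary features stay inside this function
            g.ndata["h"] = h
            # Message + Aggregation: built-in copy-from-source message,
            # mean reduce over each node's neighborhood
            g.update_all(fn.copy_u("h", "m"), fn.mean("m", "h_agg"))
            # Update: combine a node's own and aggregated representations
            h_new = torch.cat([g.ndata["h"], g.ndata["h_agg"]], dim=1)
            return torch.relu(self.linear(h_new))
```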
Scale GNN to giant graphs using DGL
- for large graphs it is recommended to use a mini-batch training procedure
- minibatch generation on graphs
- sample the target nodes
- not done inside DGL
- e.g., using `numpy.random.choice` or `torch.utils.data.DataLoader`
- randomly sample the neighbors (multi-hop)
- `dgl.sampling.sample_neighbors` for one layer
- `dgl.in_subgraph` is similar to the above, but copies all the edges (no sampling)
- construct the minibatch
- `dgl.to_block`
- renames nodes to be consecutive (for memory efficiency as well as performance)
- constructs a bipartite graph for message passing (COO format)
- `to_block()` has an option `include_dst_in_src` to help with self-loops during aggregation
- inference
- we need to infer for all nodes in each layer
- thus, inference is typically costlier than training!
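A sketch of this minibatch pipeline for a 2-layer GNN; the random graph, fanout, and batch size are assumptions for illustration:

```python
import dgl
import torch

g = dgl.rand_graph(1000, 5000)             # toy graph for illustration
seeds = torch.randint(0, 1000, (64,))      # target nodes, sampled outside DGL

blocks = []
nodes = seeds
for _ in range(2):                         # one hop per GNN layer
    # Sample up to 10 incoming neighbors per node for this layer
    frontier = dgl.sampling.sample_neighbors(g, nodes, 10)
    # Convert to a bipartite block; renames nodes to consecutive IDs
    block = dgl.to_block(frontier, dst_nodes=nodes)
    nodes = block.srcdata[dgl.NID]         # input nodes needed one hop further out
    blocks.insert(0, block)                # blocks are consumed outermost-first
```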
DGL on real world applications
- e.g., recommender systems using GCMC (Graph Convolutional Matrix Completion)
- introduced traditional collaborative-filtering-based approaches first
- such a user-item matrix can be converted into a bipartite graph
- `apply_edges()` computes edge features
- heterogeneous graphs:
- graphs with different types of nodes and/or edges (e.g., user-item graphs)
- `dgl.heterograph` for creating such graphs
- node features are accessed via `g.nodes["ntype"].data["name"]`
- edge features are accessed via `g.edges["etype"].data["name"]`
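A sketch of a tiny user-item heterograph with per-type features and an `apply_edges()` call, assuming a recent DGL version; the node/edge type names, features, and the dot-product score are made up for illustration:

```python
import dgl
import torch

# Two users, two items; edge lists are made up
g = dgl.heterograph({
    ("user", "rates", "item"):
        (torch.tensor([0, 0, 1]), torch.tensor([0, 1, 1])),
    ("item", "rated-by", "user"):
        (torch.tensor([0, 1, 1]), torch.tensor([0, 0, 1])),
})

# Per-type node and edge features
g.nodes["user"].data["h"] = torch.randn(2, 4)
g.nodes["item"].data["h"] = torch.randn(2, 4)
g.edges["rates"].data["rating"] = torch.tensor([5.0, 3.0, 4.0])

# apply_edges() computes edge features, here a dot-product score per edge
g.apply_edges(
    lambda edges: {"score": (edges.src["h"] * edges.dst["h"]).sum(dim=1)},
    etype="rates",
)
```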