Graph Transformer – Pytorch
Implementation of Graph Transformer in Pytorch, for potential use in replicating Alphafold2. This was recently used by both Costa et al. and Baker's lab for transforming MSA and pairwise embeddings into 3D coordinates.
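To make the idea more concrete, here is a minimal single-head sketch of how a graph transformer layer can let pairwise (edge) embeddings modulate attention between nodes, followed by a gated residual in the spirit of the `gated_residual` option in the usage example. The class name `EdgeAwareAttention`, the choice of adding projected edge embeddings to the keys and values, and the exact gate formula are assumptions made for illustration. This is not the graph-transformer-pytorch source.

```python
import torch
from torch import nn


class EdgeAwareAttention(nn.Module):
    """Hypothetical single-head attention where pairwise edge embeddings
    modulate keys and values (illustrative sketch, not the library's code)."""

    def __init__(self, dim, edge_dim=None):
        super().__init__()
        # assumption: edge dim defaults to the node dim when not given
        edge_dim = edge_dim if edge_dim is not None else dim
        self.scale = dim ** -0.5
        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_k = nn.Linear(dim, dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)
        self.edge_proj = nn.Linear(edge_dim, dim, bias=False)
        # scalar gate per node, computed from (update, input, update - input)
        self.gate = nn.Linear(dim * 3, 1)

    def forward(self, nodes, edges, mask=None):
        # nodes: (b, n, dim), edges: (b, n, n, edge_dim), mask: (b, n) bool
        q = self.to_q(nodes)                  # (b, n, d)
        k = self.to_k(nodes).unsqueeze(1)     # (b, 1, n, d), broadcast over i
        v = self.to_v(nodes).unsqueeze(1)     # (b, 1, n, d)
        e = self.edge_proj(edges)             # (b, n, n, d)
        # key/value of node j as seen from node i carries edge (i, j) info
        k = k + e                             # (b, n, n, d)
        v = v + e                             # (b, n, n, d)
        sim = torch.einsum('bid,bijd->bij', q, k) * self.scale
        if mask is not None:
            # masked-out nodes cannot be attended to
            sim = sim.masked_fill(~mask[:, None, :], torch.finfo(sim.dtype).min)
        attn = sim.softmax(dim=-1)
        out = torch.einsum('bij,bijd->bid', attn, v)
        # gated residual: each node chooses how much of the update to keep
        g = torch.sigmoid(self.gate(torch.cat((out, nodes, out - nodes), dim=-1)))
        return g * out + (1 - g) * nodes
```

Adding edge features to the keys and values gives every node pair (i, j) its own contribution to both the attention score and the aggregated message. The learned gate lets each node interpolate between its previous state and the attention update, which is one way to limit over-smoothing as layers are stacked.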
Todo
- add rotary embeddings for injecting adjacency information
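For reference, the Todo item names rotary embeddings. The sketch below shows only the standard rotary technique: it rotates pairs of query/key channels by position-dependent angles, so the dot product of a rotated query and a rotated key depends only on their relative offset. The function names are hypothetical, and plain integer positions stand in for adjacency information, since how the repository would derive positions from graph structure is not specified here.

```python
import torch


def rotary_frequencies(seq_len, dim, base=10000.0):
    # one frequency per channel pair: theta_k = base^(-2k/dim)
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))  # (dim/2,)
    positions = torch.arange(seq_len).float()                             # (n,)
    freqs = torch.einsum('i,j->ij', positions, inv_freq)                 # (n, dim/2)
    # duplicate so channel i and channel i + dim/2 share the same angle
    return torch.cat((freqs, freqs), dim=-1)                             # (n, dim)


def rotate_half(x):
    # split channels into halves (x1, x2) and return (-x2, x1)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary(x, freqs):
    # rotate each (x_i, x_{i + dim/2}) pair by its position-dependent angle
    return x * freqs.cos() + rotate_half(x) * freqs.sin()
```

Because each channel pair is rotated rather than shifted, vector norms are preserved, and rotating a query by position m and a key by position n yields a score that depends only on n - m.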
Install
$ pip install graph-transformer-pytorch
Usage
import torch
from graph_transformer_pytorch import GraphTransformer
model = GraphTransformer(
    dim = 256,
    depth = 6,
    edge_dim = 512,             # optional - if left out, the edge dimension is assumed to be the same as the node dimension above
    with_feedforwards = True,   # whether to add a feedforward after each attention layer, which the literature suggests is needed
    gated_residual = True       # whether to use a gated residual to prevent over-smoothing
)
nodes =