A Transformer that Ponders, using the scheme from the PonderNet paper

Ponder(ing) Transformer

Implementation of a Transformer that learns to adapt the number of computational steps it takes depending on the difficulty of the input sequence, using the scheme from the PonderNet paper. The repository also aims to abstract out a pondering module that can be used with any block that returns an output together with a halting probability.
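To make the halting idea concrete, here is a minimal pure-Python sketch of the PonderNet-style bookkeeping. It is not the code in this repository: the names `halting_distribution`, `expected_loss`, and `ponder` are hypothetical, and giving the leftover probability mass to the last step is one assumed way to normalize a truncated rollout. The KL regularizer toward a geometric prior that PonderNet adds to its loss is left out.

```python
import random


def halting_distribution(lambdas):
    """Convert conditional halting probabilities into a distribution over steps.

    lambdas[n] is the probability of halting at step n given no earlier halt.
    p[n] = lambdas[n] * prod_{j<n} (1 - lambdas[j]).
    Assumption: the final step absorbs the remaining mass so p sums to 1.
    """
    probs = []
    not_halted = 1.0  # probability of having reached the current step
    for i, lam in enumerate(lambdas):
        is_last = i == len(lambdas) - 1
        probs.append(not_halted if is_last else not_halted * lam)
        not_halted *= 1.0 - lam
    return probs


def expected_loss(step_losses, lambdas):
    """Weight each step's task loss by the probability of halting at that step."""
    probs = halting_distribution(lambdas)
    return sum(p * loss for p, loss in zip(probs, step_losses))


def ponder(block, state, max_steps, rng=random.random):
    """Run a block repeatedly at inference time, sampling when to halt.

    `block` is any callable (hypothetical interface) that maps a state to
    (new_state, output, halting_probability). Halting is forced at max_steps.
    """
    for step in range(1, max_steps + 1):
        state, output, lam = block(state)
        if step == max_steps or rng() < lam:
            return output, step
```

During training every step's loss contributes, weighted by its halting probability; at inference the loop stops as soon as a halt is sampled, so inputs whose halting probabilities rise early use fewer steps.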

This repository would not have been possible without repeated viewings of Yannic’s educational video.

Install

$ pip install ponder-transformer

Usage

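import torch
from ponder_transformer import PonderTransformer

# vocabulary of 20000 tokens, model dimension 512, sequences of up to 512 tokens
model = PonderTransformer(
    num_tokens = 20000,
    dim = 512,
    max_seq_len = 512
)

# boolean mask over the 512 positions (all True here)
mask = torch.ones(1, 512).bool()

# random token ids standing in for an input batch and its labels
x = torch.randint(0, 20000, (1, 512))
y = torch.randint(0, 20000, (1, 512))

# with labels supplied, the forward pass returns a loss to backpropagate
loss = model(x, labels = y, mask = mask)
loss.backward()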
