TorchShard

TorchShard is a lightweight engine for slicing a PyTorch tensor into parallel shards. It can reduce GPU memory usage and scale up training when the model has massive linear layers (e.g., ViT, BERT, and GPT) or a very large number of classes (millions). Its API design mirrors PyTorch's.
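To make the memory claim concrete, here is a rough back-of-the-envelope sketch in plain Python. It is not part of TorchShard; the helper name `linear_param_bytes`, the layer size (1024 input features, one million classes), and the shard count of 8 are illustrative assumptions. It counts only the fp32 weight and bias of a single linear classifier whose output features are split evenly across shards.

```python
def linear_param_bytes(in_features, out_features, bytes_per_param=4, shards=1):
    """Approximate bytes held by ONE shard of a linear layer's weight and bias,
    assuming the output features are split evenly across `shards` (hypothetical helper)."""
    out_per_shard = -(-out_features // shards)  # ceiling division
    return (in_features * out_per_shard + out_per_shard) * bytes_per_param


# Illustrative numbers: a classifier head with one million classes.
full = linear_param_bytes(1024, 1_000_000)
per_shard = linear_param_bytes(1024, 1_000_000, shards=8)

print(f"unsharded: {full / 2**30:.2f} GiB, per shard: {per_shard / 2**30:.2f} GiB")
```

Gradients and optimizer state for the same parameters shrink by roughly the same factor, so the real saving per GPU is typically larger than the parameter count alone suggests.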

Installation

pip install torchshard

See INSTALL.md for more options.

Usage

import torch
import torchshard as ts

ts.init_process_group(group_size=2)                       # init parallel groups

m = torch.nn.Sequential(
    torch.nn.Linear(20, 30, bias=True),               
    ts.nn.ParallelLinear(30, 30, bias=True, dim=None),    # equal to nn.Linear()
    ts.nn.ParallelLinear(30, 30, bias=True, dim=0),       # parallel in row dimension
    ts.nn.ParallelLinear(30, 30, bias=True, dim=1),       # parallel in column dimension
).cuda()

# x: input batch of shape (N, 20); y: class labels of shape (N,)
x = m(x)                                                  # forward
loss = ts.nn.functional.parallel_cross_entropy(x, y)      # parallel loss function
loss.backward()                                           # backward

torch.save(
  ts.collect_state_dict(m, m.state_dict()), 'm.pt')       # save model state
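The two sharding directions used above can be illustrated with plain PyTorch on a single device. The following is a minimal sketch, not TorchShard's implementation: it splits an ordinary `nn.Linear` by output features and by input features and checks that both reproduce the unsharded result. Which of the two corresponds to `dim=0` or `dim=1` in `ts.nn.ParallelLinear` is not stated here, so the sketch names them by feature axis instead. The variable names and the shard count of 2 are assumptions for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

num_shards = 2
x = torch.randn(4, 30)
full = torch.nn.Linear(30, 30, bias=True)

# Reference output from the unsharded layer.
y_full = full(x)

# Split by output features: each shard owns a slice of the weight rows
# (weight has shape [out_features, in_features]) and of the bias.
# Every shard sees the whole input; shard outputs are concatenated.
w_out = torch.chunk(full.weight, num_shards, dim=0)
b_out = torch.chunk(full.bias, num_shards, dim=0)
y_out = torch.cat([F.linear(x, w, b) for w, b in zip(w_out, b_out)], dim=1)

# Split by input features: each shard owns a slice of the weight columns
# and sees the matching slice of the input. Partial results are summed
# (an all-reduce in a real multi-GPU setup) and the bias is added once.
w_in = torch.chunk(full.weight, num_shards, dim=1)
x_in = torch.chunk(x, num_shards, dim=1)
y_in = sum(F.linear(xs, w) for xs, w in zip(x_in, w_in)) + full.bias

print(torch.allclose(y_out, y_full, atol=1e-5))
print(torch.allclose(y_in, y_full, atol=1e-5))
```

In a real multi-GPU run each shard would live on a different device, and the concatenation or sum would become a collective communication step.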

 
