Fine-tuning Llama 2 70B using PyTorch FSDP
In this blog post, we will look at how to fine-tune Llama 2 70B using PyTorch FSDP, along with related best practices. We will be leveraging Hugging Face Transformers, Accelerate and TRL. We will also learn how to use Accelerate with SLURM.
Fully Sharded Data Parallelism (FSDP) is a paradigm in which the optimizer states, gradients and parameters are sharded across devices. During the forward pass, each FSDP unit performs an all-gather operation to obtain the complete weights, runs its computation, and then discards the shards gathered from other devices. After the forward pass, the loss is computed followed by