Clockwork Variational Autoencoders using JAX and Flax
Clockwork VAEs in JAX/Flax

Implementation of experiments in the paper Clockwork Variational Autoencoders (project website) using JAX and Flax, ported from the official TensorFlow implementation. On a single TPU v3, training is 10x faster than reported in the paper (60h -> 6h on minerl).

Method

Clockwork VAEs are deep generative models that learn long-term dependencies in video by leveraging hierarchies of representations that progress at different clock speeds. In contrast to prior video prediction methods that typically focus on […]
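To make "different clock speeds" concrete, here is a minimal JAX sketch of a clockwork schedule. It assumes a fixed temporal abstraction factor k, so that level l (0-indexed) updates only every k**l steps and holds its state in between. The functions tick_mask and clockwork_rollout are hypothetical illustrations, and the tanh update is a toy stand-in, not the latent cells used by the project or the paper. Coarser levels are updated first at each step so that finer levels see fresh top-down context.

```python
import jax
import jax.numpy as jnp


def tick_mask(num_steps: int, num_levels: int, k: int) -> jnp.ndarray:
    """Boolean mask of shape [num_levels, num_steps].

    Assumption: level l (0-indexed) ticks every k**l steps.
    """
    t = jnp.arange(num_steps)
    periods = k ** jnp.arange(num_levels)  # e.g. k=2 -> [1, 2, 4, ...]
    return (t[None, :] % periods[:, None]) == 0


def clockwork_rollout(x: jnp.ndarray, num_levels: int, k: int, decay: float = 0.5):
    """Run a toy clockwork hierarchy over a scalar input sequence x of shape [T].

    Returns the per-step states, shape [T, num_levels]. A level only changes
    its state on its own ticks; otherwise it carries the previous value.
    """
    mask = tick_mask(x.shape[0], num_levels, k)  # [L, T]

    def step(states, inp):
        x_t, ticks = inp  # ticks: [L] booleans for this time step
        new = states
        # Top-down: update the slowest level first so faster levels
        # condition on the freshest context from the level above.
        for l in reversed(range(num_levels)):
            ctx = new[l + 1] if l + 1 < num_levels else 0.0
            candidate = jnp.tanh(decay * new[l] + x_t + ctx)  # toy update rule
            new = new.at[l].set(jnp.where(ticks[l], candidate, new[l]))
        return new, new

    init = jnp.zeros(num_levels, dtype=x.dtype)
    _, traj = jax.lax.scan(step, init, (x, mask.T))
    return traj


mask = tick_mask(8, 3, 2)
states = clockwork_rollout(jnp.ones(8), num_levels=3, k=2)
```

With k=2 and three levels over eight steps, the fastest level updates at every step, the middle level every second step, and the top level only at steps 0 and 4, so the top level's state stays fixed across each four-step window and can carry slower, longer-range information.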