CLOOB training (JAX) and inference (JAX and PyTorch)

Pretrained models (PyTorch): from cloob_training import model_pt, pretrained. pretrained.list_configs() returns: ['cloob_laion_400m_vit_b_16_16_epochs', 'cloob_laion_400m_vit_b_16_32_epochs']. The models can be used by: config = pretrained.get_config('cloob_laion_400m_vit_b_16_16_epochs'); model = model_pt.get_pt_model(config); checkpoint = pretrained.download_checkpoint(config); model.load_state_dict(model_pt.get_pt_params(config, checkpoint)); model.eval().requires_grad_(False).to('cuda'). Model class attributes: model.config is the model config dict; model.image_encoder is the image encoder, which expects NCHW batches of normalized images (preprocessed by model.normalize), where C […]
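The loading steps above are all that is needed for inference; the short sketch below shows one way the resulting model might be used for image encoding. It is only a hedged illustration: the 224×224 input size and applying model.normalize to a whole NCHW batch are assumptions, since the excerpt is truncated before those details.

```python
# Hedged sketch of image encoding with the loaded CLOOB model (PyTorch).
# Assumes `model` was loaded as in the excerpt above; the input resolution
# and batch-level use of model.normalize are assumptions.
import torch

images = torch.rand(4, 3, 224, 224, device='cuda')       # NCHW batch in [0, 1]
with torch.no_grad():
    feats = model.image_encoder(model.normalize(images))  # image embeddings
print(feats.shape)
```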

Read more

Implementation of different GANs in JAX/Haiku

This project aims to bring the power of JAX, a Python framework developed by Google and DeepMind, to training Generative Adversarial Networks for image generation. JAX: JAX is a framework developed by DeepMind (Google) that makes it possible to build machine learning models in a more powerful (XLA compilation) and flexible way than its counterpart TensorFlow, using a framework almost entirely based on NumPy's ndarray (but stored on the GPU, or TPU if available). It also provides new utilities for […]
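As a minimal sketch (not taken from the project) of the two JAX properties the excerpt mentions, the snippet below builds a NumPy-style array that lives on the accelerator when one is available and compiles a toy function with XLA via jax.jit; the "generator" layer here is purely illustrative.

```python
# Minimal sketch: NumPy-style arrays on GPU/TPU plus XLA compilation via jit.
import jax
import jax.numpy as jnp

@jax.jit                                   # compiled with XLA on first call
def generator_step(z):
    # Toy, illustrative "generator" layer, not the project's architecture.
    return jnp.tanh(z @ jnp.ones((z.shape[-1], 4)))

z = jax.random.normal(jax.random.PRNGKey(0), (8, 16))  # lives on GPU/TPU if available
print(generator_step(z).shape)             # (8, 4)
```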

Read more

Turning SymPy expressions into JAX functions

Turn SymPy expressions into parametrized, differentiable, vectorizable JAX functions. All SymPy floats become trainable input parameters. SymPy symbols become columns of a passed matrix. Installation: pip install git+https://github.com/MilesCranmer/sympy2jax.git Example: import sympy; from sympy import symbols; import jax; import jax.numpy as jnp; from jax import random; from sympy2jax import sympy2jax. Let's create an expression in SymPy: x, y = symbols('x y'); expression = 1.0 * sympy.cos(x) + […]
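Continuing from the truncated example, the sketch below follows the call pattern described in the project's documentation (a sympy2jax call returning a JAX function plus its trainable parameters); the expression and input shape are illustrative, and the exact signature should be checked against the repository.

```python
# Sketch of the documented sympy2jax usage; the expression is illustrative,
# not the one truncated in the excerpt above.
import sympy
from sympy import symbols
from jax import random
from sympy2jax import sympy2jax

x, y = symbols('x y')
expression = 1.0 * sympy.cos(x) + 2.0 * y        # illustrative expression

# Assumed call pattern: returns a JAX function and the trainable parameters
# extracted from the SymPy floats (here 1.0 and 2.0).
f, params = sympy2jax(expression, [x, y])

X = random.normal(random.PRNGKey(0), (10, 2))    # columns correspond to x and y
out = f(X, params)                               # evaluates the expression row-wise
```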

Read more

A Gaussian process (GP) library built in JAX (with objax)

Newt: Newt is a Gaussian process (GP) library built in JAX (with objax), created and actively maintained by Will Wilkinson. Newt provides a unifying view of approximate Bayesian inference for GPs and allows many models (e.g. GPs, sparse GPs, Markov GPs, sparse Markov GPs) to be combined with the inference method of your choice (VI, EP, Laplace, Linearisation). For a full list of the implemented methods, scroll down to the bottom of this page. Installation: In the top directory […]

Read more

A Mixed Precision library for JAX in Python

Mixed precision training in JAX: Mixed precision training [0] is a technique that mixes the use of full and half precision floating point numbers during training to reduce the memory bandwidth requirements and improve the computational efficiency of a given model. This library implements support for mixed precision training in JAX by providing two key abstractions (mixed precision "policies" and loss scaling). Neural network libraries (such as Haiku) can integrate with jmp and provide "Automatic Mixed Precision (AMP)" support (automating or simplifying applying policies to modules). All […]
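To make the two abstractions concrete, here is a short sketch in the style of jmp's documented API (a policy from jmp.get_policy and a static loss scale); the policy string, scale value, and toy loss are assumptions to verify against the README.

```python
# Sketch of jmp's two abstractions: a mixed precision policy and loss scaling.
# The policy string and scale value follow jmp's documented style but are
# assumptions here; the loss function is a toy example.
import jmp
import jax.numpy as jnp

policy = jmp.get_policy('params=float32,compute=float16,output=float32')

params = {'w': jnp.ones((4, 4))}                  # parameters kept in float32
x = jnp.ones((2, 4))

def loss_fn(params, x):
    p = policy.cast_to_compute(params)            # run the forward pass in float16
    y = policy.cast_to_compute(x) @ p['w']
    return policy.cast_to_output(y).sum()

loss_scale = jmp.StaticLossScale(2.0 ** 15)       # keeps float16 gradients finite
scaled_loss = loss_scale.scale(loss_fn(params, x))
# ...differentiate scaled_loss, then loss_scale.unscale(grads) before the update.
```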

Read more