Train a model with JAX on GPUs

This example uses JAX to train a small model on 16 T4 GPUs.

Install the Anyscale CLI

pip install -U anyscale
anyscale login

Submit the job

Clone the repository

Clone the examples repository from GitHub and change into the example directory.

git clone https://github.com/anyscale/examples.git
cd examples/jax_training

Then submit the job:

anyscale job submit -f job.yaml

Understanding the example

  • The Dockerfile installs a nightly build of Ray because Ray Train's GPU support for JAX is very recent and may not yet be available in a stable release.
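
For orientation, here is a minimal sketch of what a JAX training step can look like on a single device. It is not the code from the repository: the linear model, the synthetic data, and the names `init_params`, `loss_fn`, `train_step`, `train`, and `LEARNING_RATE` are all hypothetical and chosen only for illustration. The actual example distributes training across GPUs with Ray Train, which this sketch does not attempt to show.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical value, picked for this toy problem


def init_params(key, in_dim):
    """Create weights and a bias for a hypothetical linear model."""
    return {"w": jax.random.normal(key, (in_dim,)) * 0.01, "b": jnp.zeros(())}


def loss_fn(params, x, y):
    """Mean squared error of the linear model's predictions."""
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit
def train_step(params, x, y):
    """One step of plain gradient descent; returns updated params and the loss."""
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss


def train(num_steps=200, seed=0):
    """Fit the toy model to synthetic data and return params and loss history."""
    data_key, param_key = jax.random.split(jax.random.PRNGKey(seed))
    x = jax.random.normal(data_key, (256, 3))
    y = x @ jnp.array([1.0, -2.0, 0.5]) + 0.3  # synthetic targets
    params = init_params(param_key, 3)
    losses = []
    for _ in range(num_steps):
        params, loss = train_step(params, x, y)
        losses.append(float(loss))
    return params, losses


if __name__ == "__main__":
    print(jax.devices())  # lists GPU devices when a GPU build of JAX is installed
    _, losses = train()
    print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.6f}")
```

The same script runs on CPU or GPU without changes, because JAX places arrays on the default device it finds. Checking `jax.devices()` is a quick way to confirm that the GPUs are visible before starting a longer run.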