Train a model with JAX on GPUs
This example uses JAX to train a small model on 16 NVIDIA T4 GPUs.
Install the Anyscale CLI
pip install -U anyscale
anyscale login
Submit the job
Clone this repository
Clone the example from GitHub.
git clone https://github.com/anyscale/examples.git
cd examples/jax_training
Submit the job with:
anyscale job submit -f job.yaml
Understanding the example
- The Dockerfile installs a nightly build of Ray because Ray Train's GPU support for JAX is very recent and may not yet be available in a stable release.
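The repository's training script isn't reproduced here. For orientation only, here is a minimal sketch of one common way to do data-parallel training in plain JAX: the batch is sharded across devices with `jax.sharding`, the parameters are replicated, and `jax.jit` compiles a step in which the compiler inserts the cross-device communication needed for the gradient. This sketch does not use Ray Train, and it runs in a single process over the devices that process can see. Coordinating 16 GPUs across several machines needs extra setup that is omitted here. The linear model, synthetic data, learning rate, and names such as `train_step` are hypothetical and not taken from the example.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One mesh axis, "data", spanning every device visible to this process
# (a single CPU device also works for a dry run).
devices = jax.devices()
n_dev = len(devices)
mesh = Mesh(np.array(devices), axis_names=("data",))
data_sharding = NamedSharding(mesh, P("data"))  # split the batch dimension
replicated = NamedSharding(mesh, P())            # full copy on every device


def loss_fn(params, x, y):
    # Hypothetical "small model": linear regression with mean squared error.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit
def train_step(params, x, y):
    # x and y are sharded along the batch axis, so the compiler inserts the
    # cross-device reduction needed to differentiate the global-batch mean.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Plain SGD with a hypothetical learning rate of 0.1.
    params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    return params, loss


# Synthetic data; the global batch must divide evenly across devices.
rng = np.random.default_rng(0)
in_dim = 4
batch = 32 * n_dev
x_np = rng.normal(size=(batch, in_dim)).astype(np.float32)
true_w = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
y_np = x_np @ true_w + 0.5

# Place the data shards and the replicated parameters on the devices.
x = jax.device_put(x_np, data_sharding)
y = jax.device_put(y_np, data_sharding)
params = jax.device_put(
    {"w": jnp.zeros(in_dim, dtype=jnp.float32), "b": jnp.zeros((), dtype=jnp.float32)},
    replicated,
)

losses = []
for step in range(200):
    params, loss = train_step(params, x, y)
    losses.append(float(loss))

print(f"devices={n_dev} first_loss={losses[0]:.4f} final_loss={losses[-1]:.6f}")
```

Sharding the batch while replicating the parameters is the simplest data-parallel layout. Each device computes gradients on its slice, and the compiled step combines them, so every replica applies the same update.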