Leverage Cloud TPUs on GKE
Anyscale supports TPUs only for Anyscale operator deployments on Google Kubernetes Engine (GKE).
Overview
This page describes how to use Cloud TPUs on Anyscale.
Before proceeding, we recommend becoming familiar with how Cloud TPU works with Google Kubernetes Engine and the terminology related to TPUs in GKE.
This walkthrough has been tested with GKE Autopilot. If you use GKE Standard, you must ensure that your GKE cluster is properly configured to use Cloud TPUs, which requires creating one or more TPU node pools.
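For GKE Standard, a TPU node pool can be created with `gcloud`. The following is a sketch only; the cluster, pool, and zone names (`my-cluster`, `tpu-v5e-pool`, `us-east5-b`) are placeholders, and you should confirm the machine type and flags for your TPU generation against the current GKE documentation.

```shell
# Sketch: create a v5e TPU node pool on a GKE Standard cluster.
# ct5lp-hightpu-4t is the 4-chip v5e host machine type; the
# topology here matches the 2x2 single-host example below.
gcloud container node-pools create tpu-v5e-pool \
  --cluster=my-cluster \
  --zone=us-east5-b \
  --machine-type=ct5lp-hightpu-4t \
  --tpu-topology=2x2 \
  --num-nodes=1 \
  --spot
```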
Using Single-Host TPUs
To use single-host TPUs on Anyscale, add an instance type to the Anyscale operator's Helm chart `values.yaml` file as follows. Note that the resource values may need to be adjusted based on the target TPU topology.
```yaml
additionalInstanceTypes:
  # This instance type defines a single-host TPU.
  #
  # Note that the CPU, TPU, and memory values here
  # may need to be adjusted based on the shape of the TPU
  # and the TPU host.
  8CPU-16GB-TPU-V5E-2x2-SINGLEHOST:
    resources:
      CPU: 8
      TPU: 4
      memory: 16Gi
      'accelerator_type:TPU-V5E': 1
      # This is a hint to Anyscale that this is a single-host deployment.
      'anyscale/tpu_hosts': 1
    nodeSelector:
      cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
      cloud.google.com/gke-tpu-topology: 2x2
      # Force this instance type to always request Spot instances.
      # TPUs are generally more available as Spot capacity.
      cloud.google.com/gke-spot: "true"
```
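The resource values scale with the topology: a v5e slice with topology `AxB` contains `A*B` chips, and each host of the 4-chip v5e machine shape assumed on this page contributes 4 chips. The helper below is an illustrative sketch, not part of Anyscale, and assumes 4 chips per host; other TPU machine types differ.

```python
import math

# Chips per host for the v5e machine shape assumed on this page
# (a 4-chip host); other TPU machine types expose different counts.
CHIPS_PER_HOST = 4

def slice_shape(topology: str) -> tuple[int, int]:
    """Return (total_chips, num_hosts) for a v5e topology string like '4x4'."""
    dims = [int(d) for d in topology.split("x")]
    total_chips = math.prod(dims)
    # A slice never spans a fraction of a host; small slices fit on one host.
    num_hosts = max(1, total_chips // CHIPS_PER_HOST)
    return total_chips, num_hosts

print(slice_shape("2x2"))  # (4, 1)  -> single host: TPU: 4, tpu_hosts: 1
print(slice_shape("4x4"))  # (16, 4) -> four hosts:  TPU: 4 each, tpu_hosts: 4
```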
Then, use this instance type in an Anyscale workload.
```yaml
# tpu_singlehost_job.yaml
name: tpu_singlehost_job
working_dir: .
entrypoint: python tpu_singlehost.py
compute_config:
  cloud: CLOUD_NAME
  head_node:
    instance_type: 4CPU-16GB
  worker_nodes:
    - instance_type: 8CPU-16GB-TPU-V5E-2x2-SINGLEHOST
      min_nodes: 1
      max_nodes: 1
      market_type: SPOT
```
```python
# tpu_singlehost.py
import ray
import jax

ray.init()

@ray.remote(resources={"TPU": 4})
def tpu_cores():
    # Report the number of TPU cores visible to JAX on this host.
    return "TPU cores: " + str(jax.device_count())

result = ray.get(tpu_cores.remote())
print(result)
```
Submit the job:

```shell
anyscale job submit -f tpu_singlehost_job.yaml
```
Using Multi-Host TPUs
To use multi-host TPUs on Anyscale, add an instance type to the Anyscale operator's Helm chart `values.yaml` file as follows. Note that the resource values may need to be adjusted based on the target TPU topology.
```yaml
additionalInstanceTypes:
  # This instance type defines a multi-host TPU slice.
  #
  # Note that the CPU, TPU, and memory values here
  # may need to be adjusted based on the shape of the TPU
  # and the TPU hosts.
  8CPU-16GB-TPU-V5E-4x4-MULTIHOST:
    resources:
      CPU: 8
      TPU: 4
      memory: 16Gi
      'accelerator_type:TPU-V5E': 1
      # Hint to Anyscale that this is a multi-host deployment,
      # so the TPU_WORKER_HOSTNAMES environment variable must be
      # set to link together all of the hosts in this TPU slice.
      #
      # Anyscale automatically sets this environment variable on
      # all hosts.
      'anyscale/tpu_hosts': 4
    nodeSelector:
      cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
      cloud.google.com/gke-tpu-topology: 4x4
      # Force this instance type to always request Spot instances.
      # TPUs are generally more available as Spot capacity.
      cloud.google.com/gke-spot: "true"
```
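On each host of a multi-host slice, the platform injects `TPU_WORKER_HOSTNAMES` (set by Anyscale, as noted above) and a per-host `TPU_WORKER_ID`; workload code can inspect them to identify its position in the slice. A minimal sketch, using made-up values since the real ones exist only on TPU hosts:

```python
import os

# Made-up values standing in for what is injected on a real TPU host.
os.environ.setdefault("TPU_WORKER_ID", "0")
os.environ.setdefault("TPU_WORKER_HOSTNAMES", "host-0,host-1,host-2,host-3")

worker_id = int(os.environ["TPU_WORKER_ID"])
hostnames = os.environ["TPU_WORKER_HOSTNAMES"].split(",")
print(f"Worker {worker_id} of {len(hostnames)} hosts: {hostnames}")
```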
Then, use this instance type in an Anyscale workload.
```yaml
# tpu_multihost_job.yaml
name: tpu_multihost_job
working_dir: .
entrypoint: python tpu_multihost.py
compute_config:
  cloud: CLOUD_NAME
  head_node:
    instance_type: 4CPU-16GB
  worker_nodes:
    - instance_type: 8CPU-16GB-TPU-V5E-4x4-MULTIHOST
      min_nodes: 4
      max_nodes: 4
      market_type: SPOT
```
```python
# tpu_multihost.py
import os

import ray
import jax
from jax.experimental import multihost_utils

ray.init()

@ray.remote(resources={"TPU": 4})
def tpu_cores():
    # Block until every host in the slice reaches this point.
    multihost_utils.sync_global_devices("sync")
    cores = "TPU cores: " + str(jax.device_count())
    print("TPU Worker: " + os.environ.get("TPU_WORKER_ID", "unknown"))
    return cores

# Each v5e host in this slice exposes 4 TPU chips, so the number of
# hosts is the total TPU resource count divided by 4.
num_workers = int(ray.available_resources()["TPU"]) // 4
print(f"Number of TPU Workers: {num_workers}")
results = [tpu_cores.remote() for _ in range(num_workers)]
print(ray.get(results))
```
Submit the job:

```shell
anyscale job submit -f tpu_multihost_job.yaml
```