
Fast model loading

note

This is an alpha feature that is subject to change. Contact Anyscale support with any feedback.

Anyscale provides a library to download model weights saved in safetensors format in remote storage directly to a GPU. The model weights are streamed chunk-by-chunk, avoiding a synchronous download to disk and speeding up the end-to-end download time for large models. For more technical details and benchmarks, see this blog post.

This guide covers how to use Anyscale's optimized streaming safetensors client to load model weights, including:

  • Requirements to use the safetensors client.
  • Download API reference and usage.
  • How to load a custom PyTorch model in a Serve app.

Requirements

To use the Anyscale safetensors client:

  • Use an optimized Ray image on the Anyscale platform with version >= 2.36.
  • Save your model weights in a single file in safetensors format. Note that safetensors files only include tensors and not model code, which must be initialized separately.
  • Save your model weights in remote storage accessible to the cluster using HTTP(S). This includes AWS S3 or GCP GS buckets, which are supported natively.

API and Usage

Download API

def load_file(
    file_uri: str,
    device: str = "cpu",
    region: Optional[str] = None,
) -> Dict[str, torch.Tensor]:
    """Load a safetensors file from the given URI and return a state_dict.

    Arguments:
        file_uri: a remote URI specifying a safetensors file. You can use the 'anyscale://'
            prefix when running in an Anyscale cluster to access files in Anyscale-managed
            artifact storage.
        device: device to load the tensors to ("cpu" or "cuda").
        region: region of the bucket when using an 's3://' URI.

    Returns:
        A PyTorch state_dict.
    """

Supported Storage URIs

  • gs:// - load from a Google Cloud Storage bucket.
  • s3:// - load from an AWS S3 bucket. You must provide the region of the bucket (see the sketch after this list).
  • anyscale:// - load from the artifact storage for an Anyscale cloud.
  • http:// or https:// - load from any HTTP file server that supports the Range header. The client doesn't support authentication.
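
For example, a minimal sketch of loading from different storage backends (the bucket name, object path, and region below are placeholders):

from ray.anyscale.safetensors.torch import load_file

# Load from S3: the bucket's region must be provided.
state_dict = load_file(
    "s3://my-model-bucket/model.safetensors",
    device="cuda",
    region="us-west-2",
)

# Load from Anyscale-managed artifact storage when running inside an Anyscale cluster.
state_dict = load_file("anyscale://model.safetensors", device="cuda")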

Usage

Use ray.anyscale.safetensors.torch.load_file to download model weights from remote storage in a streaming fashion. This API returns a state_dict that is used to initialize a PyTorch model.

import torch
from typing import Dict

from ray.anyscale.safetensors.torch import load_file

# Stream model weights from remote storage onto the GPU.
state_dict: Dict[str, torch.Tensor] = load_file(
    "s3://my_bucket/model.safetensors", device="cuda",
)

# Initialize PyTorch model using the downloaded model weights.
model: torch.nn.Module = torch.nn.utils.skip_init(MyTorchModel)
model.load_state_dict(state_dict, assign=True)
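
Constructing the model with torch.nn.utils.skip_init and passing assign=True to load_state_dict avoids doing the work twice: skip_init creates the module without running its parameter initialization, and assign=True makes load_state_dict assign the downloaded tensors to the module directly instead of copying them into the existing parameters.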

Local File Caching

By default, the client streams model weights directly to the target device without writing the contents of the safetensors file to disk. However, if the same model might be loaded multiple times on the same machine, you can enable local file caching to write the file to disk and speed up subsequent loads. The file is written asynchronously, so the write doesn't block usage of the downloaded tensors.

from typing import Dict

import torch

from ray.anyscale.safetensors import set_local_cache_dir
from ray.anyscale.safetensors.torch import load_file

# Enable local caching to a provided directory.
set_local_cache_dir("/mnt/local_storage/cache/safetensors/")

# The first download streams the model weights to the target device *and* writes them to the local cache directory.
# Subsequent downloads of the same URI use the file saved in the local cache directory.
state_dict: Dict[str, torch.Tensor] = load_file("s3://my_bucket/model.safetensors")

# Pass None to disable local caching.
set_local_cache_dir(None)

note

Call set_local_cache_dir inside of your task or actor code (not in global scope) when running inside a Ray application.
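
For example, a minimal sketch that enables caching from inside a Ray actor's constructor (the actor shape and weight-loading logic are hypothetical):

import ray

from ray.anyscale.safetensors import set_local_cache_dir
from ray.anyscale.safetensors.torch import load_file


@ray.remote(num_gpus=1)
class ModelWorker:
    def __init__(self, model_weights_uri: str):
        # Configure the cache directory in the worker process that performs the
        # download, not at module import time.
        set_local_cache_dir("/mnt/local_storage/cache/safetensors/")
        self._state_dict = load_file(model_weights_uri, device="cuda")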

Known limitations

Contact Anyscale support if you encounter any issues or are blocked by any of the limitations below.

The safetensors client:

  • Only supports PyTorch tensors. Support for other frameworks can be added if needed.
  • Only supports "cpu" and "cuda" as target devices. Loading to a specific CUDA device is not yet supported.
  • Allocates CPU memory equal to the full size of the safetensors file during the download. Ensure that the instance type you're using has enough memory to accommodate this.
  • Does not work for PyTorch models that use shared tensors. See the safetensors documentation for an explanation, and the sketch after this list for a quick way to check your model.
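
If you're unsure whether a model uses shared tensors, a minimal sketch that looks for parameters sharing the same underlying storage (the helper name is illustrative):

from collections import defaultdict

import torch


def find_shared_tensors(model: torch.nn.Module) -> list[list[str]]:
    """Return groups of state_dict keys whose tensors share the same storage."""
    by_storage = defaultdict(list)
    for name, tensor in model.state_dict().items():
        # Tensors that share memory report the same underlying storage pointer.
        by_storage[tensor.untyped_storage().data_ptr()].append(name)
    return [names for names in by_storage.values() if len(names) > 1]

If this returns any groups, one option is to save a de-duplicated copy with safetensors.torch.save_model, which drops shared duplicates, and re-tie the weights after loading.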

Example: loading a custom PyTorch model

This example loads weights for a generic torch model. For simplicity, it uses a model and pretrained weights from Hugging Face, but you can replace these with your own custom model and weights.

Setup: Start a workspace and install dependencies

Start an Anyscale workspace, making sure to use the following options:

  • Select the "auto select worker nodes" option in the compute config.
  • Use a CUDA image with a Ray version >= 2.36, such as anyscale/ray:2.36.0-slim-py312-cu123.

Then install required dependencies:

pip install -U accelerate safetensors torch transformers

The example uses Mistral-7B-Instruct-v0.1 from Hugging Face. To use the model, you need to accept the terms in the model repository and retrieve your access token. Then, export it in the workspace:

export HUGGING_FACE_HUB_TOKEN=<YOUR_TOKEN_HERE>

note

Also add the HUGGING_FACE_HUB_TOKEN to the "Dependencies" tab of your workspace so Anyscale picks it up when you run the Serve app below.

Step 1: Save model weights in safetensors format

First, download the weights from Hugging Face and save them to a single safetensors file in cluster local storage. For a custom model, this would likely be the output of the training step.

import torch
from safetensors.torch import save_file
from transformers import AutoModelForCausalLM

# Download the pretrained model from Hugging Face.
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.1", torch_dtype=torch.float16
)
# Save the full state_dict to a single safetensors file in cluster local storage.
save_file(model.state_dict(), "/mnt/local_storage/Mistral-7B-Instruct-v0.1.safetensors")

Save this code in a local file named download_model.py and run it:

$ python download_model.py
Loading checkpoint shards: 100%|██████████████████████████| 2/2 [00:24<00:00, 12.35s/it]

Step 2: Upload model weights to remote storage

Now upload the safetensors file to the artifact storage bucket for the cloud your workspace is running in.

$ aws s3 cp /mnt/local_storage/Mistral-7B-Instruct-v0.1.safetensors $ANYSCALE_ARTIFACT_STORAGE/
upload: ../../../mnt/local_storage/Mistral-7B-Instruct-v0.1.safetensors to s3://anyscale-test-data-cld-i2w99rzq8b6lbjkke9y94vi5/org_7c1Kalm9WcX2bNIjW53GUT/cld_kvedZWag2qA8i5BjxUevf5i7/artifact_storage/Mistral-7B-Instruct-v0.1.safetensors

Step 3: Construct the model and load its weights in a Serve app

import torch
from accelerate import init_empty_weights
from fastapi import FastAPI
from transformers import AutoTokenizer, MistralConfig, MistralForCausalLM
from typing import Dict

from ray import serve
from ray.anyscale.safetensors.torch import load_file

fastapi_app = FastAPI()


@serve.deployment(
    # Configure the replica to use an Nvidia T4 GPU.
    ray_actor_options={"num_gpus": 1, "accelerator_type": "T4"},
)
@serve.ingress(fastapi_app)
class Mistral7BApp:
    def __init__(self, model_weights_uri: str):
        # IMPORTANT: Initialize the model with *empty weights*.
        # When using your own `torch.nn.Module`, you can use torch.nn.utils.skip_init, see:
        # https://pytorch.org/tutorials/prototype/skip_param_init.html
        with init_empty_weights():
            self._model = MistralForCausalLM(
                MistralConfig.from_pretrained(
                    "mistralai/Mistral-7B-Instruct-v0.1", torch_dtype=torch.float16
                )
            )

        # Download the model weights directly to the GPU.
        state_dict: Dict[str, torch.Tensor] = load_file(model_weights_uri, device="cuda")

        # Populate the weights in the model class.
        self._model.load_state_dict(state_dict, assign=True)
        self._model.to("cuda")

        # Load the tokenizer for the model.
        self._tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")

    @fastapi_app.get("/")
    def predict(self, prompt: str) -> str:
        messages = [{"role": "user", "content": prompt}]
        encodeds = self._tokenizer.apply_chat_template(messages, return_tensors="pt").to("cuda")
        generated_ids = self._model.generate(encodeds, max_new_tokens=1000, do_sample=True)
        decoded = self._tokenizer.batch_decode(generated_ids.to("cpu"))
        return decoded[0]


# Pass the URI to the uploaded model weights.
# The 'anyscale://' URI resolves to Anyscale-managed artifact storage when running inside an Anyscale cluster.
app = Mistral7BApp.bind("anyscale://Mistral-7B-Instruct-v0.1.safetensors")

To run the serve app, save the Python file locally as main.py and use serve run:

(base) ray@ip-10-0-58-159:~/default$ serve run main:app
2024-09-20 14:50:16,217 INFO scripts.py:499 -- Running import path: 'main:app'.
2024-09-20 14:50:18,804 INFO worker.py:1601 -- Connecting to existing Ray cluster at address: 10.0.58.159:6379...
2024-09-20 14:50:18,810 INFO worker.py:1777 -- Connected to Ray cluster. View the dashboard at https://session-nf13dmw1fmld6bei4v7czihhhh.i.anyscaleuserdata.com
2024-09-20 14:50:18,812 INFO packaging.py:359 -- Pushing file package 'gcs://_ray_pkg_5cbc1f5ea201895259da328911c3241cb56311e9.zip' (0.02MiB) to Ray cluster...
2024-09-20 14:50:18,812 INFO packaging.py:372 -- Successfully pushed file package 'gcs://_ray_pkg_5cbc1f5ea201895259da328911c3241cb56311e9.zip'.
(ProxyActor pid=49865) INFO 2024-09-20 14:50:22,805 proxy 10.0.58.159 proxy.py:1235 - Proxy starting on node 4c64508eea53f54920fdf1be0286974e95a3945af03178e0160b6d39 (HTTP port: 8000).
INFO 2024-09-20 14:50:22,828 serve 49656 api.py:277 - Started Serve in namespace "serve".
INFO 2024-09-20 14:50:22,828 serve 49656 api.py:259 - Connecting to existing Serve app in namespace "serve". New http options will not be applied.
(ServeController pid=49775) INFO 2024-09-20 14:50:22,920 controller 49775 deployment_state.py:1598 - Deploying new version of Deployment(name='Mistral7BApp', app='default') (initial target replicas: 1).
(ServeController pid=49775) INFO 2024-09-20 14:50:23,023 controller 49775 deployment_state.py:1844 - Adding 1 replica to Deployment(name='Mistral7BApp', app='default').
(ServeReplica:default:Mistral7BApp pid=49960) 2024-09-20 14:50:27,987 anytensor INFO - Got 1 file to download (13.49 GB total)
(ServeReplica:default:Mistral7BApp pid=49960) 2024-09-20 14:50:33,464 anytensor INFO - Downloaded 7.16/13.49 GB (53.1%, 1.43 GB/s)
(ServeReplica:default:Mistral7BApp pid=49960) 2024-09-20 14:50:37,716 anytensor INFO - Downloaded 13.49/13.49 GB (100.0%, 1.45 GB/s)
(ServeReplica:default:Mistral7BApp pid=49960) 2024-09-20 14:50:37,788 anytensor INFO - Finished download in 9.41s (1.43 GB/s)
INFO 2024-09-20 14:50:38,879 serve 49656 client.py:492 - Deployment 'Mistral7BApp:tbd5s25i' is ready at `http://127.0.0.1:8000/`. component=serve deployment=Mistral7BApp
INFO 2024-09-20 14:50:38,882 serve 49656 api.py:549 - Deployed app 'default' successfully.

Open another terminal and test the endpoint:

$ curl -X GET http://localhost:8000?prompt=hello
"<s> [INST] hello [/INST] Hello! It's great to have you here. Is there anything you would like to ask or talk about?</s>"