(Preview) Fine-tuning and serving with the Anyscale Models SDK/CLI
This guide walks through a new preview feature: the Anyscale LLM Models SDK/CLI, which enables programmatically fine-tuning and serving custom models. Review the basic fine-tuning and serving examples before working through this guide.
This example requires llmforge>=0.5.4 and anyscale>=0.24.61.
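The anyscale SDK/CLI can be installed or upgraded from PyPI:
pip install -U "anyscale>=0.24.61"
llmforge runs inside the llm-forge container image used for the job below, so no local installation of llmforge is needed for this example.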
Example: Serverless fine-tuning and serving a custom model on Anyscale
In this example, we fine-tune a Llama 3 8B model on a math word problem dataset using an Anyscale Job. Then, we serve the custom model on Anyscale using rayllm.
Step 1: Fine-tuning
Assume the following directory structure:
├── configs
│   ├── llama-3-8b.yaml
│   └── zero_3.json
Here's an example fine-tuning config llama-3-8b.yaml:
model_id: meta-llama/Meta-Llama-3-8B-Instruct
train_path: s3://air-example-data/gms8k/train.jsonl
valid_path: s3://air-example-data/gms8k/valid.jsonl
num_devices: 4
num_epochs: 2
context_length: 512
worker_resources:
  accelerator_type:A10G: 0.001
deepspeed:
  config_path: configs/zero_3.json
generation_config:
  prompt_format:
    system: "{instruction}"
    user: "{instruction}"
    assistant: "{instruction} </s>"
    trailing_assistant: ""
    bos: ""
  stopping_sequences: ["</s>"]
lora_config:
  r: 8
  lora_alpha: 16
  lora_dropout: 0.05
  target_modules:
    - q_proj
    - v_proj
    - k_proj
    - o_proj
    - gate_proj
    - up_proj
    - down_proj
    - embed_tokens
    - lm_head
  modules_to_save: []
llmforge supports any HuggingFace model, so you can use smaller test models from the Hub to quickly iterate and assess new configurations or datasets, as sketched below.
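For example, one way to smoke-test a dataset or config change is to copy the config and swap in a small model. This is only a sketch; the model ID and output file name below are illustrative, not part of this guide:

import yaml

# Load the fine-tuning config shown above.
with open("configs/llama-3-8b.yaml") as f:
    config = yaml.safe_load(f)

# Hypothetical small model for a quick end-to-end test of the data and config.
config["model_id"] = "hf-internal-testing/tiny-random-LlamaForCausalLM"
config["num_devices"] = 1

# Write the test config next to the original one.
with open("configs/tiny-test.yaml", "w") as f:
    yaml.safe_dump(config, f)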
You can launch the fine-tuning run as an Anyscale Job and wait until the job is done:
import anyscale
from anyscale.job.models import JobConfig

job_id: str = anyscale.job.submit(
    JobConfig(
        name="llmforge-fine-tuning-job",
        # Point the entrypoint at the fine-tuning config from the directory structure above.
        entrypoint="llmforge anyscale finetune configs/llama-3-8b.yaml",
        working_dir=".",
        image_uri="localhost:5555/anyscale/llm-forge:0.5.4",
    ),
)

# Wait until the job succeeds, with a 5-hour timeout for the wait operation.
# See the Job API reference for more details: https://docs.anyscale.com/reference/job-api
anyscale.job.wait(id=job_id, timeout_s=18000)
print(f"Job {job_id} succeeded!")
The above job runs in the default cloud and the default project. For the full set of config parameters, see the Anyscale Job API reference.
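If you need to target a different cloud or project, you can set them on the JobConfig. The sketch below assumes cloud and project fields on JobConfig (confirm the exact field names in the Job API reference); the values are placeholders:

from anyscale.job.models import JobConfig

# Hypothetical example: override the default cloud and project.
job_config = JobConfig(
    name="llmforge-fine-tuning-job",
    entrypoint="llmforge anyscale finetune configs/llama-3-8b.yaml",
    working_dir=".",
    image_uri="localhost:5555/anyscale/llm-forge:0.5.4",
    cloud="my-cloud-name",      # assumption: name of an Anyscale cloud you can access
    project="my-project-name",  # assumption: name of an existing project
)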
Once the job is complete, we can retrieve the model info with anyscale.llm.models.get:
model_info = anyscale.llm.models.get(job_id=job_id).to_dict()
print(model_info)
This is what the model metadata looks like:
{
  'base_model_id': 'meta-llama/Meta-Llama-3-8B-Instruct',
  'cloud_id': 'cld_123',
  'created_at': datetime.datetime(2024, 8, 26, 21, 21, 54, 213160, tzinfo=tzlocal()),
  'creator_id': 'usr_123',
  'ft_type': 'LORA',
  'generation_config': {
    'prompt_format': {
      'add_system_tags_even_if_message_is_empty': False,
      'assistant': '{instruction} </s>',
      'bos': '<s>',
      'default_system_message': '',
      'strip_whitespace': True,
      'system': '{instruction}',
      'system_in_last_user': False,
      'system_in_user': False,
      'trailing_assistant': '',
      'user': '{instruction}'
    },
    'stopping_sequences': ['</s>']
  },
  'id': 'meta-llama/Meta-Llama-3-8B-Instruct:usern:deyoq',
  'job_id': 'prodjob_123',
  'project_id': 'prj_123',
  'storage_uri': 's3://org_123/cld_123/artifact_storage/username/llmforge-finetuning/meta-llama/Meta-Llama-3-8B-Instruct/TorchTrainer_2024-08-26_14-18-46/epoch-2',
  'workspace_id': None
}
Some of the important fields are id (the model tag), base_model_id (the base model used for fine-tuning), ft_type (the fine-tuning type), storage_uri (the storage path for the best checkpoint), and generation_config (chat-templating parameters and stopping sequences for inference).
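For example, the rest of this guide only needs a few of these fields, which you can pull directly out of the model_info dictionary printed above:

finetuned_model_id = model_info["id"]                # model tag used to query the served model
base_model_id = model_info["base_model_id"]          # base model that was fine-tuned
ft_type = model_info["ft_type"]                      # e.g. "LORA"
checkpoint_uri = model_info["storage_uri"]           # location of the best checkpoint
generation_config = model_info["generation_config"]  # prompt format and stopping sequences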
With LoRA training, Anyscale forwards all LoRA weights to a shared location for convenience: $ANYSCALE_ARTIFACT_STORAGE/lora_fine_tuning is the common storage path used for all LoRA checkpoints (it corresponds to the dynamic_lora_loading_path for serving). The Models SDK is still useful here because you can retrieve other parameters, such as id, from just the job ID.
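For instance, from a workspace or job where $ANYSCALE_ARTIFACT_STORAGE is set, you could list the shared LoRA checkpoints. This is only a sketch and assumes an AWS-backed cloud with the aws CLI available; adjust it for your cloud's object store:

import os
import subprocess

# Common storage path for all LoRA checkpoints (see above).
lora_root = os.environ["ANYSCALE_ARTIFACT_STORAGE"] + "/lora_fine_tuning/"

# Assumes an S3-backed artifact store and the aws CLI; use `gsutil ls` on GCP-backed clouds.
subprocess.run(["aws", "s3", "ls", lora_root], check=True)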
If you already have the model id (either from the "Models" page on the platform or from the fine-tuning logs) and wish to know more about the model, you can use the llm.models.get method again, but now specify the id:
model_info = anyscale.llm.models.get(model_id="meta-llama/Meta-Llama-3-8B-Instruct:usern:deyoq")
The artifact storage path is specific to your Anyscale cloud and organization. It is available in a workspace or a job environment as the $ANYSCALE_ARTIFACT_STORAGE environment variable. For more details, see the storage guide.
To use the Anyscale CLI instead, run anyscale llm models get --job-id JOB_ID or anyscale llm models get --model-id MODEL_ID.
Step 2: Serving
We can now serve the fine-tuned model on the Anyscale Platform using rayllm.
To get started quickly, you can auto-generate the serve config and the model config using this template. Make sure to update the model_loading_config, generation_config, and max_request_content_length (and optionally lora_config) using the model_info data.
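As a rough sketch of that update step (the file path below is a placeholder for whatever the template generated, and you should verify the keys against your generated config), you could patch the model config programmatically:

import yaml

# Hypothetical path to the model config generated by the template.
MODEL_CONFIG_PATH = "model_config/my-finetuned-model.yaml"

with open(MODEL_CONFIG_PATH) as f:
    model_config = yaml.safe_load(f)

# Copy the fine-tuned model's tag and generation settings from `model_info`.
model_config["model_loading_config"]["model_id"] = model_info["id"]
model_config["generation_config"]["prompt_format"] = model_info["generation_config"]["prompt_format"]
model_config["generation_config"]["stopping_sequences"] = model_info["generation_config"]["stopping_sequences"]

with open(MODEL_CONFIG_PATH, "w") as f:
    yaml.safe_dump(model_config, f)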
We can now launch a service through the Anyscale service SDK or CLI:
service = anyscale.service.deploy(config_file="./serve_TIMESTAMP.yaml")
or
anyscale service deploy -f ./serve_TIMESTAMP.yaml
It's a good idea to use the workspace template once to generate the RayLLM configs for your model. For full-parameter fine-tuning, the same config is applicable across different models (you can change the generation_config as needed based on model_info), and the same applies to LoRA.
If you've had a previous LoRA deployment for the base model (say, meta-llama/Meta-Llama-3-8B-Instruct), then all you need is the id to query the new LoRA checkpoint:
import anyscale
import openai

model_info = anyscale.llm.models.get(job_id=job_id).to_dict()
finetuned_model_id = model_info["id"]

# Use the new model ID in your existing client code.
# Make sure to use the ENDPOINT_URL and ENDPOINT_API_KEY for your Anyscale Service.
client = openai.OpenAI(base_url=ENDPOINT_URL, api_key=ENDPOINT_API_KEY)
response = client.chat.completions.create(
    model=finetuned_model_id,
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ],
    stream=True,
)
for chunk in response:
    print(chunk.choices[0].delta.content or "", end="", flush=True)
Stay tuned for updates and API references for this preview feature!