
Serving a PyTorch Model

This tutorial shows how to serve a more advanced application: a PyTorch ResNet model. It is closely modeled after a similar tutorial for Ray Serve.

  1. Create a project for your application and set the ANYSCALE_PROJECT_NAME environment variable.

    $ mkdir hello_world_pytorch
    $ cd hello_world_pytorch
    $ anyscale project create
    Name [hello_world_pytorch]: hello_world_pytorch
    Created project with name hello_world_pytorch at https://console.anyscale.com/projects/prj_6SCoYQrJU4BYzpcpj0mKxk.
    Please specify the project id as prj_6SCoYQrJU4BYzpcpj0mKxk when calling Anyscale CLI or SDK commands or Ray Client.

    $ export ANYSCALE_PROJECT_NAME="hello_world_pytorch"
  2. Install necessary dependencies.

    pip install "ray[serve]"
    pip install torch torchvision
  3. Add imports and create a runtime environment with the necessary dependencies.

    import ray
    from ray import serve

    from io import BytesIO
    from PIL import Image
    import requests

    import torch
    from torchvision import transforms
    from torchvision.models import resnet18

    runtime_env = {
        "pip": ["torch", "torchvision", "ray[serve]"]
    }
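If you want reproducible deployments, you can pin exact versions in the runtime environment instead of taking whatever pip resolves at cluster start. A minimal sketch — the version numbers below are illustrative placeholders, not a tested combination; substitute the versions you actually validated:

```python
# Illustrative only: pin dependency versions so Ray workers install a
# reproducible environment. Replace the versions with ones you tested.
runtime_env = {
    "pip": [
        "torch==1.12.1",
        "torchvision==0.13.1",
        "ray[serve]",
    ]
}
```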
  4. Define your service.

    @serve.deployment(route_prefix="/image_predict")
    class ImageModel:
        def __init__(self):
            self.model = resnet18(pretrained=True).eval()
            self.preprocessor = transforms.Compose([
                transforms.Resize(224),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Lambda(lambda t: t[:3, ...]),  # remove alpha channel
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])

        async def __call__(self, starlette_request):
            image_payload_bytes = await starlette_request.body()
            pil_image = Image.open(BytesIO(image_payload_bytes))
            print("[1/3] Parsed image data: {}".format(pil_image))

            pil_images = [pil_image]  # Our current batch size is one
            input_tensor = torch.cat(
                [self.preprocessor(i).unsqueeze(0) for i in pil_images])
            print("[2/3] Images transformed, tensor shape {}".format(
                input_tensor.shape))

            with torch.no_grad():
                output_tensor = self.model(input_tensor)
            print("[3/3] Inference done!")
            return {"class_index": int(torch.argmax(output_tensor[0]))}
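The `transforms.Lambda` step exists because PNGs (like the Ray logo queried below) may carry a fourth alpha channel, while ResNet expects exactly three RGB channels. The `t[:3, ...]` slice keeps only the first three channels of the image tensor; the slicing semantics are the same as plain Python slicing, as this stdlib-only sketch shows:

```python
# Stdlib-only illustration of the t[:3, ...] slice in the preprocessor:
# a 4-channel RGBA image keeps only its first three (RGB) channels.
rgba_channels = ["R", "G", "B", "A"]  # stand-in for a (4, H, W) tensor
rgb_channels = rgba_channels[:3]      # what transforms.Lambda keeps
print(rgb_channels)  # → ['R', 'G', 'B']
```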
  5. Connect to Anyscale and deploy the service.

    ray.init("anyscale://my_cluster", namespace="my_serve_namespace", runtime_env=runtime_env, autosuspend=-1)

    serve.start(detached=True)

    ImageModel.deploy()
  6. Get the user_service_url and user_service_token from the Cluster model returned by the SDK and query the endpoint!

    from anyscale import AnyscaleSDK

    sdk = AnyscaleSDK()
    cluster = sdk.search_clusters(clusters_query={"name": {"equals": "my_cluster"}}).results[0]

    ray_logo_bytes = requests.get(
        "https://github.com/ray-project/ray/raw/"
        "master/doc/source/images/ray_header_logo.png").content

    resp = requests.post(
        f"{cluster.user_service_url}/image_predict",
        headers={"Authorization": f"Bearer {cluster.user_service_token}"},
        data=ray_logo_bytes)
    print(resp.json())
    # Output
    # {'class_index': 919}
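The returned class_index is an index into the 1,000 ImageNet classes, so you'll usually want to translate it into a human-readable label. A sketch, using a tiny hand-written stand-in for the full class list (in practice you would load all 1,000 labels from an ImageNet class-list file; index 919 is commonly listed as "street sign" in the standard ImageNet-1k ordering):

```python
# Sketch: map a predicted class index to a human-readable label.
# This tiny dict is an illustrative stand-in for the full 1,000-entry
# ImageNet class list, which you would normally load from a labels file.
IMAGENET_LABELS = {
    919: "street sign",
    207: "golden retriever",
}

def label_for(class_index: int) -> str:
    # Fall back to the raw index when the label is not in our subset.
    return IMAGENET_LABELS.get(class_index, f"class {class_index}")

print(label_for(919))  # street sign
```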