Batch inference: transcribe audio

This guide shows you how to transcribe audio using OpenAI Whisper and Ray Data. If you want to transcribe terabytes of audio files in a distributed fashion, this guide is for you.

You'll read audio files, preprocess amplitudes, run an automatic speech recognition system, and save transcriptions.

tip

Want faster and cheaper offline batch inference? Fill out this form.

Prerequisites

Before you begin, complete the following steps:

  1. Onboard onto Anyscale.
  2. Configure write access to an S3 bucket.
  3. Create a workspace with the ML image.

1. Install soundfile and OpenAI Whisper

The tutorial depends on soundfile and OpenAI Whisper to load and transcribe audio. To install them, run the following command:

pip install --user soundfile openai-whisper

To learn more about installing dependencies into your environment, see Anyscale environment.
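As a quick sanity check, you can confirm that both packages import before moving on:

# Both imports should succeed if the installation worked.
import soundfile
import whisper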

2. Read audio files

AudioDatasource is the primary API for loading audio files. It loads amplitude data into a Dataset, where each row contains an array of amplitudes representing one audio file.

Call ray.data.read_datasource() and pass in an AudioDatasource.

import ray
from ray.anyscale.data import AudioDatasource
from ray.data.datasource import FileExtensionFilter

uri = (
    "s3://air-example-data-2/6G-audio-data-LibriSpeech-train-clean-100-flac/"
    "train-clean-100/103/1241/"
)

dataset = ray.data.read_datasource(
    AudioDatasource(),
    paths=uri,
    # Filter out files that don't have the .flac extension.
    partition_filter=FileExtensionFilter("flac"),
    # To determine which file a row corresponds to, set `include_paths=True`.
    include_paths=True,
)

Next, call Dataset.take() to inspect rows.

rows = dataset.take(1)

Each row should look like this:

{'amplitude': array([[...]], dtype=float32), 'path': 'air-example-data-2/.../103-1241-0010.flac'}
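To get a feel for the data, you can compute a clip's duration from the amplitude array. This sketch assumes the LibriSpeech sampling rate of 16 kHz:

row = rows[0]

# Shape is (channels, samples); LibriSpeech clips are mono at 16 kHz.
num_samples = row["amplitude"].shape[1]
print(f"{row['path']}: {num_samples / 16000:.1f} seconds")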

3. Preprocess amplitude

Call Dataset.map() to preprocess your dataset.

The preprocess function below converts the amplitude to a representation that the model understands. To learn more, see the OpenAI Whisper README.

from typing import Dict, Any
import numpy as np
import whisper

def preprocess(row: Dict[str, Any]) -> Dict[str, Any]:
    # The audio is mono, so remove the channels axis:
    # (channels, amplitude) -> (amplitude,)
    amplitude = np.squeeze(row["amplitude"], axis=0)
    processed_amplitude = whisper.pad_or_trim(amplitude)
    row["mel"] = whisper.log_mel_spectrogram(processed_amplitude)
    return row

dataset = dataset.map(preprocess)
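To verify the preprocessing, you can inspect one transformed row. whisper.pad_or_trim pads or trims each clip to 30 seconds, so every mel spectrogram should share a fixed shape, (80, 3000) for the base model:

sample = dataset.take(1)[0]

# 80 mel bins over 3,000 frames, covering 30 seconds of audio.
print(sample["mel"].shape)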

4. Transcribe audio

Implement a callable that performs inference. Set up your model in __init__ and invoke the model in __call__.

import torch

class TranscribeText:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = whisper.load_model("base", self.device)
        self.options = whisper.DecodingOptions()

    def __call__(self, batch: Dict[str, np.ndarray]):
        inputs = torch.as_tensor(batch["mel"], device=self.device)
        outputs = whisper.decode(self.model, inputs, self.options)
        return {
            "path": batch["path"],
            "transcription": [result.text for result in outputs],
        }

Then, call Dataset.map_batches(). Set num_gpus_in_cluster to the number of GPUs available in your cluster.

num_gpus_in_cluster = 1
results = dataset.map_batches(
    TranscribeText,
    compute=ray.data.ActorPoolStrategy(size=num_gpus_in_cluster),
    batch_size=110,  # Choose the largest batch size that fits in GPU memory.
    num_gpus=1,  # Number of GPUs per worker.
)
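Instead of hard-coding num_gpus_in_cluster, you can query Ray for the cluster's GPU count. This sketch uses ray.cluster_resources(), which reports the total logical resources Ray sees:

# Size the actor pool to match the number of GPUs in the cluster.
num_gpus_in_cluster = int(ray.cluster_resources().get("GPU", 0))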

5. Inspect results

Call Dataset.take() to inspect the inference results.

rows = results.take(1)

Each result row should look like this:

{'path': '...', 'transcription': 'might have seen that the chin was very pointed and pronounced, that the big eyes were full of spirit and vivacity, that the mouth was sweet-lipped and expressive, that the forehead was broad and full. In short, our discerning extraordinary observer might have concluded'}
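take() materializes rows on the driver, so to scan more results without collecting everything at once, you can stream them instead. This sketch uses Dataset.iter_rows():

# Stream rows to the driver one at a time instead of collecting them all.
for row in results.iter_rows():
    print(row["path"], row["transcription"][:80])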

6. Write results to S3

Call Dataset.write_parquet() and pass in a URI pointing to a folder in S3. Your nodes must have write access to the folder. To write results to other formats, see Input/Output.

results.write_parquet("s3://sample-bucket/my-inference-results")
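To confirm the write succeeded, you can read the Parquet files back into a Dataset. This sketch assumes the same placeholder URI as above:

# Read the transcriptions back and display one row.
saved = ray.data.read_parquet("s3://sample-bucket/my-inference-results")
saved.show(1)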