# Spaces: Sleeping  (page-status banner captured with this listing; not part of the program)
import gradio as gr
import torch
import whisper
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from accelerate.utils import set_module_tensor_to_device
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Whisper checkpoint to load. `whisper.load_model` expects one of the
# openai-whisper model names ("tiny", "base", "small", ...) or a path to a
# local checkpoint file -- not a Hugging Face Hub repo id such as
# "openai/whisper-tiny", which it rejects as an unknown model.
model_name = "tiny"

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading Whisper model '{model_name}' on device: {device}")

# `whisper.load_model` fetches (or reuses the cached) checkpoint, builds the
# model with its real weights and moves it to `device`.
#
# The previous Accelerate-based loading path is intentionally not used:
#   * `init_empty_weights()` creates parameters without storage, so a model
#     built inside it cannot be used for inference as-is;
#   * `load_checkpoint_and_dispatch()` requires a `checkpoint` path argument,
#     which was never supplied;
#   * a separate Hugging Face tokenizer is unnecessary because Whisper's own
#     `transcribe()` handles tokenization internally.
# Half precision is likewise left to `transcribe()`, which selects it through
# its own `fp16` decoding option.
whisper_model = whisper.load_model(model_name, device=device)
# Define the transcription function
def transcribe(audio) -> str:
    """Transcribe a recorded audio clip to text with the loaded Whisper model.

    Args:
        audio: Path to the recorded audio file as passed in by the audio
            input component, or ``None``/empty when nothing was recorded.

    Returns:
        The recognised text, or an empty string when no audio was provided.
    """
    # The UI can invoke this callback without a recording; forwarding a
    # missing path to Whisper would raise instead of producing a result.
    if not audio:
        return ""

    # Perform transcription using the module-level Whisper model.
    result = whisper_model.transcribe(audio)
    return result["text"]
# Create the Gradio interface.
#
# Gradio 4 replaced the Audio component's `source="microphone"` argument with
# `sources=["microphone"]`, so pick the keyword that the installed major
# version accepts instead of hard-coding the Gradio 3 spelling.
if int(gr.__version__.split(".")[0]) >= 4:
    microphone_kwargs = {"sources": ["microphone"]}
else:
    microphone_kwargs = {"source": "microphone"}

# NOTE(review): the title/description below still mention ZeRO; confirm the
# wording matches how the model is actually loaded before changing it.
demo = gr.Interface(
    fn=transcribe,  # Called with the recorded file path on submit
    inputs=gr.Audio(  # Microphone recording, delivered as a file path
        type="filepath",
        label="Speak into the microphone",
        **microphone_kwargs,
    ),
    outputs=gr.Textbox(label="Transcription"),  # Transcribed text
    title="Whisper Speech-to-Text with ZeRO",
    description="Record audio using your microphone and get a transcription using the Whisper model optimized by ZeRO.",
)

# Launch the Gradio interface
demo.launch()