import torch
import torchaudio
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, BitsAndBytesConfig
import gradio as gr
import time
import numpy as np
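
# The 4-bit load below requires the bitsandbytes package and a CUDA-capable GPU.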
model_name = "ibm-granite/granite-speech-3.2-8b"
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

print("Loading processor...")
speech_granite_processor = AutoProcessor.from_pretrained(
    model_name, trust_remote_code=True
)
tokenizer = speech_granite_processor.tokenizer
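# The processor bundles the text tokenizer with the model's audio feature extractor.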

print("Loading model with 4-bit quantization...")
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,  # matmuls run in fp16; only weights are stored in 4-bit
    bnb_4bit_quant_type="nf4",  # NormalFloat4, designed for normally distributed weights
    bnb_4bit_use_double_quant=True,  # also quantize the quantization constants
)
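# NF4 shrinks the 8B weights to roughly 5-6 GB of VRAM, versus ~16 GB in fp16.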

speech_granite = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    device_map="auto",  # let accelerate place the quantized weights
    trust_remote_code=True,
)
print("Model loaded successfully")
def transcribe_audio(audio_input):
    """Transcribe an uploaded or recorded audio clip.

    Returns a (transcription, processing_time_seconds) tuple; on failure the
    first element carries the error message and collected logs.
    """
    start_time = time.time()
    logs = [f"Audio input received: {type(audio_input)}"]

    if audio_input is None:
        return "Error: No audio provided.", 0.0

    try:
        if isinstance(audio_input, tuple) and len(audio_input) == 2:
            # Raw (sample_rate, samples) pair, as gr.Audio(type="numpy") would deliver
            logs.append("Processing microphone input")
            sr, wav_np = audio_input
            if wav_np.dtype == np.int16:
                wav_np = wav_np.astype(np.float32) / 32768.0  # 16-bit PCM -> [-1.0, 1.0]
            if wav_np.ndim > 1:
                wav_np = wav_np.mean(axis=1)  # downmix (samples, channels) to mono
            wav = torch.from_numpy(wav_np).float().unsqueeze(0)
        else:
            # Filepath input: the mode the gr.Audio component below actually uses
            logs.append(f"Processing file input: {audio_input}")
            wav, sr = torchaudio.load(audio_input)
            logs.append(f"Loaded audio file with sample rate {sr}Hz and shape {wav.shape}")

        # Granite's speech encoder expects mono 16 kHz input
        if wav.shape[0] > 1:
            wav = torch.mean(wav, dim=0, keepdim=True)
            logs.append("Converted stereo to mono")

        if sr != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
            wav = resampler(wav)
            sr = 16000
            logs.append(f"Resampled to {sr}Hz")

        logs.append(
            f"Final audio: sample rate {sr}Hz, shape {wav.shape}, "
            f"min: {wav.min().item()}, max: {wav.max().item()}"
        )

        # Build the chat prompt; <|audio|> marks where the audio features are injected
        chat = [
            {
                "role": "system",
                "content": "Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant",
            },
            {
                "role": "user",
                "content": "<|audio|>can you transcribe the speech into a written format?",
            },
        ]

        text = tokenizer.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=True
        )

        logs.append("Preparing model inputs")
        model_inputs = speech_granite_processor(
            text=text,
            audio=wav.numpy().squeeze(),
            sampling_rate=sr,
            return_tensors="pt",
        ).to(device)
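        # model_inputs should now hold the tokenized prompt plus the extracted audio features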

        logs.append("Generating transcription")
        model_outputs = speech_granite.generate(
            **model_inputs,
            max_new_tokens=1000,
            num_beams=4,  # deterministic beam search; sampling is disabled
            do_sample=False,
            min_length=1,
            top_p=1.0,
            repetition_penalty=3.0,
            length_penalty=1.0,
            temperature=1.0,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )

        # Keep only the newly generated tokens, dropping the echoed prompt
        logs.append("Processing output")
        num_input_tokens = model_inputs["input_ids"].shape[-1]
        new_tokens = torch.unsqueeze(model_outputs[0, num_input_tokens:], dim=0)

        output_text = tokenizer.batch_decode(
            new_tokens, add_special_tokens=False, skip_special_tokens=True
        )

        transcription = output_text[0].strip().upper()
        logs.append(f"Transcription complete: {transcription[:50]}...")

    except Exception as e:
        import traceback
        print(traceback.format_exc())
        print("\n".join(logs))
        return f"Error: {str(e)}\n\nLogs:\n" + "\n".join(logs), 0.0

    processing_time = round(time.time() - start_time, 2)
    return transcription, processing_time


title = "IBM Granite Speech-to-Text (8B Quantized)"
description = """
Transcribe speech using IBM's Granite Speech 3.2 8B model (loaded in 4-bit).
Upload an audio file or use your microphone to record speech.
"""

iface = gr.Interface(
    fn=transcribe_audio,
    # type="filepath": uploads and microphone recordings both arrive as temp-file paths
    inputs=gr.Audio(sources=["upload", "microphone"], type="filepath"),
    outputs=[
        gr.Textbox(label="Transcription", lines=5),
        gr.Number(label="Processing Time (seconds)"),
    ],
    title=title,
    description=description,
)
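
# Quick sanity check without the UI (hypothetical local file):
#   print(transcribe_audio("sample.wav"))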

if __name__ == "__main__":
    iface.launch()