File size: 5,323 Bytes
3003400
 
dd3ef50
 
 
3003400
dd3ef50
3003400
dd3ef50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3003400
dd3ef50
 
 
 
 
 
 
3003400
dd3ef50
 
3003400
dd3ef50
 
 
 
 
 
3003400
dd3ef50
 
 
 
 
 
3003400
dd3ef50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3003400
dd3ef50
 
 
 
 
 
 
 
 
 
 
 
 
 
3003400
 
 
dd3ef50
 
3003400
dd3ef50
 
3003400
dd3ef50
 
3003400
 
 
dd3ef50
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import torch
import torchaudio
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, BitsAndBytesConfig
import gradio as gr
import os
import time
import numpy as np

# --- One-time initialisation: executed when the module is imported/run. ---
# Everything below is module-global state used by the transcription function.
model_name = "ibm-granite/granite-speech-3.2-8b"

# Target device for model inputs; the model itself is placed by device_map="auto".
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}")

# The processor bundles the tokenizer and the audio feature extractor.
print("Loading processor...")
speech_granite_processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
tokenizer = speech_granite_processor.tokenizer

# 4-bit NF4 weights with double quantization; matmuls run in float16.
# NOTE(review): bitsandbytes 4-bit loading has traditionally required a CUDA
# GPU even though `device` may be "cpu" here -- confirm on CPU-only hosts.
print("Loading model with 4-bit quantization...")
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

speech_granite = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_name,
    trust_remote_code=True,
    device_map="auto",
    quantization_config=quantization_config,
)
print("Model loaded successfully")

def _load_waveform(audio_input, logs):
    """Decode ``audio_input`` into a mono, 16 kHz float32 waveform.

    Args:
        audio_input: Either a ``(sample_rate, numpy_array)`` tuple (Gradio's
            numpy audio format) or a path accepted by ``torchaudio.load``.
        logs: List of progress messages; appended to in place.

    Returns:
        ``(wav, sr)`` where ``wav`` is a float32 tensor of shape ``(1, frames)``
        and ``sr`` is 16000.

    Raises:
        ValueError: If the decoded audio contains no samples.
    """
    if isinstance(audio_input, tuple) and len(audio_input) == 2:
        # Microphone/numpy input: (sample_rate, numpy_array). The array is laid
        # out as (samples,) or (samples, channels) and is usually integer PCM.
        logs.append("Processing microphone input")
        sr, wav_np = audio_input
        sr = int(sr)
        wav_np = np.asarray(wav_np)
        if np.issubdtype(wav_np.dtype, np.integer):
            # Scale integer PCM into [-1, 1] so it matches the float range that
            # torchaudio.load produces for files.
            info = np.iinfo(wav_np.dtype)
            wav_np = wav_np.astype(np.float32) / float(max(abs(info.min), info.max))
        wav = torch.as_tensor(wav_np, dtype=torch.float32)
        # Reorder to torchaudio's (channels, frames) layout so the mono
        # down-mix below sees the channel axis first.
        wav = wav.unsqueeze(0) if wav.ndim == 1 else wav.T.contiguous()
    else:
        # File input: filepath string
        logs.append(f"Processing file input: {audio_input}")
        wav, sr = torchaudio.load(audio_input)
        logs.append(f"Loaded audio file with sample rate {sr}Hz and shape {wav.shape}")

    # Fail early with a clear message instead of an opaque min()/max() error.
    if wav.numel() == 0:
        raise ValueError("Audio input contains no samples")

    # Convert to mono if stereo
    if wav.shape[0] > 1:
        wav = torch.mean(wav, dim=0, keepdim=True)
        logs.append("Converted stereo to mono")

    # Resample to 16kHz if needed
    if sr != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
        wav = resampler(wav)
        sr = 16000
        logs.append(f"Resampled to {sr}Hz")

    return wav, sr


def transcribe_audio(audio_input):
    """Process audio input and return transcription.

    Args:
        audio_input: A file path, or a ``(sample_rate, numpy_array)`` tuple,
            or ``None`` when nothing was provided.

    Returns:
        ``(text, seconds)``. On success, ``text`` is the upper-cased
        transcription and ``seconds`` is the processing time rounded to two
        decimals. On failure, ``text`` starts with ``"Error:"``, followed by the
        collected progress logs, and ``seconds`` is ``0.0``.
    """
    start_time = time.perf_counter()

    logs = [f"Audio input received: {type(audio_input)}"]

    if audio_input is None:
        return "Error: No audio provided.", 0.0

    try:
        wav, sr = _load_waveform(audio_input, logs)

        logs.append(f"Final audio: sample rate {sr}Hz, shape {wav.shape}, min: {wav.min().item()}, max: {wav.max().item()}")

        # Create text prompt; "<|audio|>" marks where the audio embeddings go.
        chat = [
            {
                "role": "system",
                "content": "Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant",
            },
            {
                "role": "user",
                "content": "<|audio|>can you transcribe the speech into a written format?",
            }
        ]

        text = tokenizer.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=True
        )

        # Compute audio embeddings
        logs.append("Preparing model inputs")
        model_inputs = speech_granite_processor(
            text=text,
            audio=wav.numpy().squeeze(),  # 1-D float array for the processor
            sampling_rate=sr,
            return_tensors="pt",
        ).to(device)

        # Generate transcription (deterministic beam search).
        logs.append("Generating transcription")
        model_outputs = speech_granite.generate(
            **model_inputs,
            max_new_tokens=1000,
            num_beams=4,
            do_sample=False,
            min_length=1,
            top_p=1.0,
            repetition_penalty=3.0,
            length_penalty=1.0,
            temperature=1.0,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )

        # generate() returns prompt + continuation; keep only the new tokens.
        logs.append("Processing output")
        num_input_tokens = model_inputs["input_ids"].shape[-1]
        new_tokens = torch.unsqueeze(model_outputs[0, num_input_tokens:], dim=0)

        output_text = tokenizer.batch_decode(
            new_tokens, add_special_tokens=False, skip_special_tokens=True
        )

        transcription = output_text[0].strip().upper()
        logs.append(f"Transcription complete: {transcription[:50]}...")

    except Exception as e:
        # UI boundary: report any failure to the user instead of crashing.
        import traceback
        error_trace = traceback.format_exc()
        print(error_trace)
        print("\n".join(logs))
        return f"Error: {str(e)}\n\nLogs:\n" + "\n".join(logs), 0.0

    processing_time = round(time.perf_counter() - start_time, 2)
    return transcription, processing_time

# --- Gradio UI: one audio input (upload or microphone) -> transcript + timing ---
title = "IBM Granite Speech-to-Text (8B Quantized)"
description = """
Transcribe speech using IBM's Granite Speech 3.2 8B model (loaded in 4-bit).
Upload an audio file or use your microphone to record speech.
"""

# The audio component hands the callback a file path (type="filepath").
_audio_component = gr.Audio(sources=["upload", "microphone"], type="filepath")
_transcript_box = gr.Textbox(label="Transcription", lines=5)
_timing_box = gr.Number(label="Processing Time (seconds)")

iface = gr.Interface(
    fn=transcribe_audio,
    inputs=_audio_component,
    outputs=[_transcript_box, _timing_box],
    title=title,
    description=description,
)


def main():
    """Start the Gradio server for the interface defined above."""
    iface.launch()


if __name__ == "__main__":
    main()