import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, BitsAndBytesConfig, AutoTokenizer
from IndicTransToolkit import IndicProcessor
import speech_recognition as sr

# Constants
BATCH_SIZE = 4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
quantization = None


# ---- IndicTrans2 Model Initialization ----
def initialize_model_and_tokenizer(ckpt_dir, quantization):
    """Load a seq2seq checkpoint and its tokenizer, ready for inference.

    Args:
        ckpt_dir: Model identifier or path passed to ``from_pretrained``.
        quantization: ``"4-bit"`` or ``"8-bit"`` to load the model with a
            bitsandbytes quantization config; any other value (including
            ``None``) loads it unquantized.

    Returns:
        A ``(tokenizer, model)`` tuple. The model is in eval mode. When no
        quantization is requested it is moved to ``DEVICE`` and, on CUDA,
        converted to half precision.
    """
    if quantization == "4-bit":
        qconfig = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    elif quantization == "8-bit":
        # Double-quant and compute-dtype are 4-bit-only options; there are no
        # bnb_8bit_* equivalents, so only the 8-bit flag is passed.
        qconfig = BitsAndBytesConfig(load_in_8bit=True)
    else:
        qconfig = None

    tokenizer = AutoTokenizer.from_pretrained(ckpt_dir, trust_remote_code=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        ckpt_dir,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        quantization_config=qconfig,
    )

    # Device placement and dtype are only adjusted here for unquantized models.
    if qconfig is None:
        model = model.to(DEVICE)
        if DEVICE == "cuda":
            model.half()

    model.eval()
    return tokenizer, model


def batch_translate(input_sentences, src_lang, tgt_lang, model, tokenizer, ip):
    """Translate sentences in chunks of ``BATCH_SIZE``.

    Args:
        input_sentences: Sequence of source-language sentences.
        src_lang: Source language code handed to ``ip.preprocess_batch``.
        tgt_lang: Target language code handed to ``ip.preprocess_batch`` and
            ``ip.postprocess_batch``.
        model: Generation model exposing ``generate``.
        tokenizer: Tokenizer matching ``model``.
        ip: Processor providing ``preprocess_batch`` / ``postprocess_batch``.

    Returns:
        A list with the post-processed translations, in input order. An empty
        input yields an empty list.
    """
    translations = []
    for start in range(0, len(input_sentences), BATCH_SIZE):
        batch = input_sentences[start : start + BATCH_SIZE]
        batch = ip.preprocess_batch(batch, src_lang=src_lang, tgt_lang=tgt_lang)

        inputs = tokenizer(
            batch,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)

        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )

        # Decode inside the tokenizer's target-side context.
        with tokenizer.as_target_tokenizer():
            generated_tokens = tokenizer.batch_decode(
                generated_tokens.detach().cpu().tolist(),
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

        translations.extend(ip.postprocess_batch(generated_tokens, lang=tgt_lang))

        # Drop this batch's tensors and release cached CUDA memory before the
        # next batch.
        del inputs
        torch.cuda.empty_cache()

    return translations


# Initialize IndicTrans2
en_indic_ckpt_dir = "ai4bharat/indictrans2-indic-en-1B"
en_indic_tokenizer, en_indic_model = initialize_model_and_tokenizer(
    en_indic_ckpt_dir, quantization
)
ip = IndicProcessor(inference=True)


# ---- Gradio Function ----
def transcribe_and_translate(audio):
    """Transcribe a Malayalam audio file and translate the text to English.

    Args:
        audio: Path to the audio file supplied by the Gradio ``Audio``
            component (``type="filepath"``), or ``None`` when nothing was
            recorded or uploaded.

    Returns:
        A ``(transcription, translation)`` tuple of strings. On failure the
        first element carries an error message and the second is empty.
    """
    # Submitting without recording/uploading yields no file path; opening it
    # would raise instead of showing a message in the UI.
    if not audio:
        return "No audio provided", ""

    recognizer = sr.Recognizer()
    try:
        with sr.AudioFile(audio) as source:
            audio_data = recognizer.record(source)
    except ValueError as e:
        # AudioFile raises ValueError for files it cannot read
        # (e.g. an unsupported upload format).
        return f"Unsupported or unreadable audio file: {e}", ""

    try:
        # Malayalam transcription using Google API
        malayalam_text = recognizer.recognize_google(audio_data, language="ml-IN")
    except sr.UnknownValueError:
        return "Could not understand audio", ""
    except sr.RequestError as e:
        return f"Google API Error: {e}", ""

    # Translation
    src_lang, tgt_lang = "mal_Mlym", "eng_Latn"
    translations = batch_translate(
        [malayalam_text],
        src_lang,
        tgt_lang,
        en_indic_model,
        en_indic_tokenizer,
        ip,
    )
    english_text = translations[0] if translations else ""
    return malayalam_text, english_text


# ---- Gradio Interface ----
iface = gr.Interface(
    fn=transcribe_and_translate,
    inputs=gr.Audio(sources=["microphone", "upload"], type="filepath"),
    outputs=[
        gr.Textbox(label="Malayalam Transcription"),
        gr.Textbox(label="English Translation"),
    ],
    title="Malayalam Speech Recognition & Translation",
    description="Speak in Malayalam → Transcribe using Google Speech Recognition → Translate to English using IndicTrans2.",
)

iface.launch(debug=True)