"""FastAPI service exposing speech-to-text endpoints backed by NeMo ASR models.

The module loads a set of AI4Bharat IndicConformer models at import time and
serves two endpoints: ``/transcribe/`` for a single upload and
``/transcribe_batch/`` for several uploads sharing one language.
"""

import argparse
import asyncio
import io
import logging
import os
import shutil
import subprocess
import tempfile
from logging.handlers import RotatingFileHandler
from time import time
from typing import List

import torch
import nemo.collections.asr as nemo_asr
import uvicorn
from fastapi import FastAPI, File, UploadFile, HTTPException, Query
from fastapi.responses import RedirectResponse, JSONResponse
from pydantic import BaseModel
from pydub import AudioSegment
from pydub.exceptions import CouldntDecodeError

# Configure logging with log rotation
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        RotatingFileHandler("transcription_api.log", maxBytes=10*1024*1024, backupCount=5),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)

# Language IDs whose models are loaded eagerly when no explicit list is given.
DEFAULT_LANGUAGES = ("kn", "hi", "ta", "te", "ml")
# Upload extensions accepted by the endpoints.
SUPPORTED_EXTENSIONS = ("wav", "mp3")
# Uploads are resampled to this rate (Hz) and down-mixed to mono before transcription.
TARGET_SAMPLE_RATE = 16000
# Directory used by split_audio()/cleanup() when the caller supplies no output_dir.
DEFAULT_CHUNK_DIR = "audio_chunks"
# Language ID used when a requested language name / ID is not recognised.
FALLBACK_LANGUAGE_ID = "kn"


def _export_wav(segment, path):
    """Write *segment* to *path* as WAV and close the handle that pydub returns."""
    handle = segment.export(path, format="wav")
    # pydub hands back the (still open) output file object; close it so the
    # descriptor is not leaked and the file can be deleted on every platform.
    if handle is not None:
        handle.close()


class ASRModelManager:
    """Loads, caches and serves NeMo ASR models keyed by language ID.

    Attributes are plain dictionaries:
    ``model_language`` maps language names to language IDs,
    ``config_models`` maps language IDs to pretrained model names and
    ``models`` holds the models that have been loaded so far.
    """

    def __init__(self, languages_to_load=None, device_type="cuda"):
        """Build the lookup tables and eagerly load the requested models.

        Args:
            languages_to_load: Iterable of language IDs to load immediately.
                ``None`` (the default) loads ``DEFAULT_LANGUAGES``.
            device_type: Requested device string; CUDA is only used when
                ``torch.cuda.is_available()`` reports it, otherwise CPU.
        """
        # A None sentinel replaces the former mutable list default, which was
        # shared between every call that relied on it.
        if languages_to_load is None:
            languages_to_load = DEFAULT_LANGUAGES

        self.device_type = device_type
        self.model_language = {
            "kannada": "kn",
            "hindi": "hi",
            "malayalam": "ml",
            "assamese": "as",
            "bengali": "bn",
            "bodo": "brx",
            "dogri": "doi",
            "gujarati": "gu",
            "kashmiri": "ks",
            "konkani": "kok",
            "maithili": "mai",
            "manipuri": "mni",
            "marathi": "mr",
            "nepali": "ne",
            "odia": "or",
            "punjabi": "pa",
            "sanskrit": "sa",
            "santali": "sat",
            "sindhi": "sd",
            "tamil": "ta",
            "telugu": "te",
            "urdu": "ur",
        }
        self.config_models = {
            "as": "ai4bharat/indicconformer_stt_as_hybrid_rnnt_large",
            "bn": "ai4bharat/indicconformer_stt_bn_hybrid_rnnt_large",
            "brx": "ai4bharat/indicconformer_stt_brx_hybrid_rnnt_large",
            "doi": "ai4bharat/indicconformer_stt_doi_hybrid_rnnt_large",
            "gu": "ai4bharat/indicconformer_stt_gu_hybrid_rnnt_large",
            "hi": "ai4bharat/indicconformer_stt_hi_hybrid_rnnt_large",
            "kn": "ai4bharat/indicconformer_stt_kn_hybrid_rnnt_large",
            "ks": "ai4bharat/indicconformer_stt_ks_hybrid_rnnt_large",
            "kok": "ai4bharat/indicconformer_stt_kok_hybrid_rnnt_large",
            "mai": "ai4bharat/indicconformer_stt_mai_hybrid_rnnt_large",
            "ml": "ai4bharat/indicconformer_stt_ml_hybrid_rnnt_large",
            "mni": "ai4bharat/indicconformer_stt_mni_hybrid_rnnt_large",
            "mr": "ai4bharat/indicconformer_stt_mr_hybrid_rnnt_large",
            "ne": "ai4bharat/indicconformer_stt_ne_hybrid_rnnt_large",
            "or": "ai4bharat/indicconformer_stt_or_hybrid_rnnt_large",
            "pa": "ai4bharat/indicconformer_stt_pa_hybrid_rnnt_large",
            "sa": "ai4bharat/indicconformer_stt_sa_hybrid_rnnt_large",
            "sat": "ai4bharat/indicconformer_stt_sat_hybrid_rnnt_large",
            "sd": "ai4bharat/indicconformer_stt_sd_hybrid_rnnt_large",
            "ta": "ai4bharat/indicconformer_stt_ta_hybrid_rnnt_large",
            "te": "ai4bharat/indicconformer_stt_te_hybrid_rnnt_large",
            "ur": "ai4bharat/indicconformer_stt_ur_hybrid_rnnt_large",
        }

        # Load models for specified languages on startup
        self.models = {}
        self.load_initial_models(languages_to_load)

    def _resolve_device(self):
        """Return the torch device to use, falling back to CPU without CUDA."""
        # startswith() lets indexed devices such as "cuda:1" through; an
        # equality test against "cuda" silently sent them to the CPU.
        wants_cuda = str(self.device_type).startswith("cuda")
        if wants_cuda and torch.cuda.is_available():
            return torch.device(self.device_type)
        return torch.device("cpu")

    def load_initial_models(self, languages):
        """Load every model in *languages*, skipping unknown IDs and failures.

        Loading is best-effort: a model that fails to load is logged and
        skipped so the remaining languages still become available.
        """
        device = self._resolve_device()
        logger.info(f"Loading models on device: {device}")
        for lang_id in languages:
            if lang_id not in self.config_models:
                logger.warning(f"No model available for language ID: {lang_id}. Skipping.")
                continue
            try:
                model_name = self.config_models[lang_id]
                logger.info(f"Loading model for {lang_id}: {model_name}")
                model = nemo_asr.models.ASRModel.from_pretrained(model_name)
                model.freeze()  # Set to inference mode
                self.models[lang_id] = model.to(device)
                logger.info(f"Successfully loaded model for {lang_id}")
            except Exception as e:
                # logger.exception keeps the traceback alongside the message.
                logger.exception(f"Failed to load model for {lang_id}: {str(e)}")

    def get_model(self, language_id):
        """Return the cached model for *language_id*, loading it on first use."""
        if language_id not in self.models:
            logger.warning(f"Model for {language_id} not pre-loaded. Loading now...")
            self.models[language_id] = self.load_model(language_id)
        return self.models[language_id]

    def load_model(self, language_id):
        """Load, freeze and place the model for *language_id* on the device.

        An unknown *language_id* falls back to the Kannada model, as before,
        but the substitution is now logged instead of happening silently.
        """
        model_name = self.config_models.get(language_id)
        if model_name is None:
            logger.warning(
                "No model configured for language ID %r; falling back to %r.",
                language_id,
                FALLBACK_LANGUAGE_ID,
            )
            model_name = self.config_models[FALLBACK_LANGUAGE_ID]
        model = nemo_asr.models.ASRModel.from_pretrained(model_name)
        model.freeze()
        return model.to(self._resolve_device())

    def set_device(self, device_type):
        """Switch to *device_type* and move every already-loaded model there."""
        self.device_type = device_type
        device = self._resolve_device()
        logger.info("Moving %d loaded model(s) to device: %s", len(self.models), device)
        for lang_id, model in list(self.models.items()):
            self.models[lang_id] = model.to(device)

    def split_audio(self, file_path, chunk_duration_ms=15000, output_dir=None):
        """Split an audio file into consecutive chunks of bounded duration.

        Args:
            file_path: Path of the audio file to split.
            chunk_duration_ms: Maximum chunk length in milliseconds.
            output_dir: Directory that receives the chunk files. ``None``
                keeps the historical shared ``audio_chunks`` directory;
                concurrent callers should pass a private directory so their
                ``chunk_{i}.wav`` files cannot overwrite each other.

        Returns:
            ``[file_path]`` when the audio is no longer than one chunk,
            otherwise the list of written chunk file paths, in order.

        Raises:
            ValueError: If *chunk_duration_ms* is not positive.
        """
        if chunk_duration_ms <= 0:
            raise ValueError("chunk_duration_ms must be a positive number of milliseconds")

        audio = AudioSegment.from_file(file_path)
        duration_ms = len(audio)
        if duration_ms <= chunk_duration_ms:
            return [file_path]

        if output_dir is None:
            output_dir = DEFAULT_CHUNK_DIR
        os.makedirs(output_dir, exist_ok=True)

        chunk_file_paths = []
        for index, start_ms in enumerate(range(0, duration_ms, chunk_duration_ms)):
            end_ms = min(start_ms + chunk_duration_ms, duration_ms)
            chunk_file_path = os.path.join(output_dir, f"chunk_{index}.wav")
            _export_wav(audio[start_ms:end_ms], chunk_file_path)
            chunk_file_paths.append(chunk_file_path)
        return chunk_file_paths

    def cleanup(self):
        """Remove the default shared chunk directory, if it exists."""
        if os.path.exists(DEFAULT_CHUNK_DIR):
            shutil.rmtree(DEFAULT_CHUNK_DIR)


app = FastAPI()
asr_manager = ASRModelManager(languages_to_load=["kn", "hi", "ta", "te", "ml"])  # Load Kannada, Hindi, Tamil, Telugu, Malayalam


class TranscriptionResponse(BaseModel):
    """Response body of the single-file endpoint."""

    text: str


class BatchTranscriptionResponse(BaseModel):
    """Response body of the batch endpoint: one transcription per upload."""

    transcriptions: List[str]


def _require_supported_extension(file, detail):
    """Return the lower-cased extension of *file* or raise a 400 HTTPException."""
    # UploadFile.filename may be None when the client sends no filename.
    filename = file.filename or ""
    file_extension = filename.split(".")[-1].lower()
    if file_extension not in SUPPORTED_EXTENSIONS:
        logger.warning(f"Unsupported file format: {file_extension}")
        raise HTTPException(status_code=400, detail=detail)
    return file_extension


def _decode_audio(file_content, file_extension):
    """Decode uploaded bytes and normalise them to 16 kHz mono."""
    buffer = io.BytesIO(file_content)
    try:
        if file_extension == "mp3":
            audio = AudioSegment.from_mp3(buffer)
        else:
            audio = AudioSegment.from_wav(buffer)
    except CouldntDecodeError as err:
        # A corrupt or mislabelled upload is a client error, not a server fault.
        raise HTTPException(
            status_code=400,
            detail="Could not decode the uploaded audio file.",
        ) from err

    # Sample rate and channel count are checked independently: previously a
    # stereo file that was already at 16 kHz was never down-mixed to mono.
    if audio.frame_rate != TARGET_SAMPLE_RATE:
        audio = audio.set_frame_rate(TARGET_SAMPLE_RATE)
    if audio.channels != 1:
        audio = audio.set_channels(1)
    return audio


def _result_to_text(result):
    """Reduce one per-file entry of ``model.transcribe`` output to a string."""
    if isinstance(result, (list, tuple)):
        if not result:
            # An empty result used to be appended as a list and broke str.join.
            return ""
        result = result[0]
    # NOTE(review): some NeMo releases appear to return hypothesis objects
    # exposing ``.text`` rather than plain strings - confirm for the pinned version.
    text = getattr(result, "text", result)
    return text if isinstance(text, str) else str(text)


def _transcribe_segment(audio, language):
    """Transcribe a normalised AudioSegment and return the joined chunk texts."""
    language_id = asr_manager.model_language.get(language)
    if language_id is None:
        logger.warning(
            "Unknown language %r; falling back to %r.", language, FALLBACK_LANGUAGE_ID
        )
        language_id = FALLBACK_LANGUAGE_ID

    model = asr_manager.get_model(language_id)
    model.cur_decoder = "rnnt"

    # Every request works in its own temporary directory, so concurrent
    # requests cannot clobber each other's chunk files and everything is
    # removed even when transcription raises.
    with tempfile.TemporaryDirectory(prefix="asr_") as work_dir:
        wav_path = os.path.join(work_dir, "input.wav")
        _export_wav(audio, wav_path)
        chunk_file_paths = asr_manager.split_audio(
            wav_path, output_dir=os.path.join(work_dir, "chunks")
        )
        transcriptions = []
        for chunk_file_path in chunk_file_paths:
            # NOTE(review): this call is synchronous and runs inside the event
            # loop; offloading it to a worker thread needs confirmation that
            # the model is safe to share between threads.
            result = model.transcribe(
                [chunk_file_path], batch_size=1, language_id=language_id
            )[0]
            transcriptions.append(_result_to_text(result))
    return ' '.join(transcriptions)


@app.post("/transcribe/", response_model=TranscriptionResponse)
async def transcribe_audio(
    file: UploadFile = File(...),
    language: str = Query(..., enum=list(asr_manager.model_language.keys())),
):
    """Transcribe one uploaded WAV/MP3 file in the given language.

    Raises:
        HTTPException: 400 for an unsupported or undecodable upload,
            500 for any other failure.
    """
    start_time = time()
    try:
        file_extension = _require_supported_extension(
            file, "Unsupported file format. Please upload a WAV or MP3 file."
        )
        file_content = await file.read()
        audio = _decode_audio(file_content, file_extension)
        joined_transcriptions = _transcribe_segment(audio, language)
        end_time = time()
        logger.info(f"Transcription completed in {end_time - start_time:.2f} seconds")
        return JSONResponse(content={"text": joined_transcriptions})
    except HTTPException as e:
        logger.error(f"HTTPException: {str(e)}")
        raise
    except Exception as e:
        logger.exception(f"An unexpected error occurred: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"An unexpected error occurred: {str(e)}"
        ) from e


@app.get("/")
async def home():
    """Redirect the root URL to the interactive API documentation."""
    return RedirectResponse(url="/docs")


@app.post("/transcribe_batch/", response_model=BatchTranscriptionResponse)
async def transcribe_audio_batch(
    files: List[UploadFile] = File(...),
    language: str = Query(..., enum=list(asr_manager.model_language.keys())),
):
    """Transcribe several uploaded WAV/MP3 files that share one language.

    Raises:
        HTTPException: 400 for an unsupported or undecodable upload,
            500 for any other failure.
    """
    start_time = time()
    try:
        # Validate every upload before doing any transcription so a bad file
        # late in the batch does not waste work on the earlier ones.
        file_extensions = [
            _require_supported_extension(
                file, "Unsupported file format. Please upload WAV or MP3 files."
            )
            for file in files
        ]

        all_transcriptions = []
        for file, file_extension in zip(files, file_extensions):
            file_content = await file.read()
            audio = _decode_audio(file_content, file_extension)
            all_transcriptions.append(_transcribe_segment(audio, language))

        end_time = time()
        logger.info(f"Batch transcription completed in {end_time - start_time:.2f} seconds")
        return JSONResponse(content={"transcriptions": all_transcriptions})
    except HTTPException as e:
        logger.error(f"HTTPException: {str(e)}")
        raise
    except Exception as e:
        logger.exception(f"An unexpected error occurred: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"An unexpected error occurred: {str(e)}"
        ) from e


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the FastAPI server for ASR.")
    parser.add_argument("--port", type=int, default=8888, help="Port to run the server on.")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on.")
    parser.add_argument("--device", type=str, default="cuda", help="Device type to run the model on (cuda or cpu).")
    args = parser.parse_args()

    # The models were already loaded when the module-level manager was built;
    # move them instead of constructing a second manager that loads them again.
    if args.device != asr_manager.device_type:
        asr_manager.set_device(args.device)

    uvicorn.run(app, host=args.host, port=args.port)