sachin
commited on
Commit
·
e5a6062
1
Parent(s):
e0b5384
updat
Browse files- src/server/main.py +63 -0
src/server/main.py
CHANGED
@@ -733,6 +733,69 @@ async def chat_v2(
|
|
733 |
logger.error(f"Error processing request: {str(e)}")
|
734 |
raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
|
735 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
736 |
if __name__ == "__main__":
|
737 |
parser = argparse.ArgumentParser(description="Run the FastAPI server.")
|
738 |
parser.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")
|
|
|
733 |
logger.error(f"Error processing request: {str(e)}")
|
734 |
raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
|
735 |
|
736 |
+
class TranscriptionResponse(BaseModel):
|
737 |
+
text: str
|
738 |
+
|
739 |
+
|
740 |
+
class ASRModelManager:
|
741 |
+
def __init__(self, device_type="cuda"):
|
742 |
+
self.device_type = device_type
|
743 |
+
self.model_language = {
|
744 |
+
"kannada": "kn", "hindi": "hi", "malayalam": "ml", "assamese": "as", "bengali": "bn",
|
745 |
+
"bodo": "brx", "dogri": "doi", "gujarati": "gu", "kashmiri": "ks", "konkani": "kok",
|
746 |
+
"maithili": "mai", "manipuri": "mni", "marathi": "mr", "nepali": "ne", "odia": "or",
|
747 |
+
"punjabi": "pa", "sanskrit": "sa", "santali": "sat", "sindhi": "sd", "tamil": "ta",
|
748 |
+
"telugu": "te", "urdu": "ur"
|
749 |
+
}
|
750 |
+
|
751 |
+
|
752 |
+
from fastapi import FastAPI, UploadFile
|
753 |
+
import torch
|
754 |
+
import torchaudio
|
755 |
+
from transformers import AutoModel
|
756 |
+
import argparse
|
757 |
+
import uvicorn
|
758 |
+
from pydantic import BaseModel
|
759 |
+
from pydub import AudioSegment
|
760 |
+
from fastapi import FastAPI, File, UploadFile, HTTPException, Query
|
761 |
+
from fastapi.responses import RedirectResponse, JSONResponse
|
762 |
+
from typing import List
|
763 |
+
|
764 |
+
# Load the model
|
765 |
+
model = AutoModel.from_pretrained("ai4bharat/indic-conformer-600m-multilingual", trust_remote_code=True)
|
766 |
+
|
767 |
+
asr_manager = ASRModelManager() # Load Kannada, Hindi, Tamil, Telugu, Malayalam
|
768 |
+
|
769 |
+
|
770 |
+
#asr_manager = ASRModelManager(device_type="")
|
771 |
+
|
772 |
+
@app.post("/transcribe/", response_model=TranscriptionResponse)
|
773 |
+
async def transcribe_audio(file: UploadFile = File(...), language: str = Query(..., enum=list(asr_manager.model_language.keys()))):
|
774 |
+
# Load the uploaded audio file
|
775 |
+
wav, sr = torchaudio.load(file.file)
|
776 |
+
wav = torch.mean(wav, dim=0, keepdim=True)
|
777 |
+
|
778 |
+
# Resample if necessary
|
779 |
+
target_sample_rate = 16000 # Expected sample rate
|
780 |
+
if sr != target_sample_rate:
|
781 |
+
resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
|
782 |
+
wav = resampler(wav)
|
783 |
+
|
784 |
+
# Perform ASR with CTC decoding
|
785 |
+
#transcription_ctc = model(wav, "kn", "ctc")
|
786 |
+
|
787 |
+
# Perform ASR with RNNT decoding
|
788 |
+
transcription_rnnt = model(wav, "kn", "rnnt")
|
789 |
+
|
790 |
+
return JSONResponse(content={"text": transcription_rnnt})
|
791 |
+
|
792 |
+
|
793 |
+
|
794 |
+
class BatchTranscriptionResponse(BaseModel):
|
795 |
+
transcriptions: List[str]
|
796 |
+
|
797 |
+
|
798 |
+
|
799 |
if __name__ == "__main__":
|
800 |
parser = argparse.ArgumentParser(description="Run the FastAPI server.")
|
801 |
parser.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")
|