sachin committed on
Commit b192c58 · 1 Parent(s): 4eda5de
Files changed (2)
  1. src/server/main-v2.py +192 -148
  2. src/server/main.py +463 -493
src/server/main-v2.py CHANGED
@@ -5,7 +5,7 @@ from time import time
 from typing import List, Dict
 import tempfile
 import uvicorn
-from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, UploadFile, Body, Form, APIRouter
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
 from PIL import Image
@@ -22,17 +22,18 @@ from contextlib import asynccontextmanager
 import soundfile as sf
 import numpy as np
 import requests
 from starlette.responses import StreamingResponse
-from logging_config import logger
-from tts_config import SPEED, ResponseFormat, config as tts_config
 import torchaudio
 from tenacity import retry, stop_after_attempt, wait_exponential
 from torch.cuda.amp import autocast

 # Device setup
 device = "cuda:0" if torch.cuda.is_available() else "cpu"
-torch_dtype = torch.bfloat16 if device != "cpu" else torch.float32
-logger.info(f"{'GPU' if device != 'cpu' else 'CPU'} will be used for inference")

 # Check CUDA availability and version
 cuda_available = torch.cuda.is_available()
@@ -77,12 +78,15 @@ quantization_config = BitsAndBytesConfig(
 # Request queue for concurrency control
 request_queue = asyncio.Queue(maxsize=10)

 # LLM Manager with batching
 class LLMManager:
     def __init__(self, model_name: str, device: str = device):
         self.model_name = model_name
         self.device = torch.device(device)
-        self.torch_dtype = torch.bfloat16 if self.device.type != "cpu" else torch.float32
         self.model = None
         self.processor = None
         self.is_loaded = False
@@ -93,6 +97,9 @@ class LLMManager:
     def load(self):
         if not self.is_loaded:
             try:
                 self.model = Gemma3ForConditionalGeneration.from_pretrained(
                     self.model_name,
                     device_map="auto",
@@ -107,7 +114,7 @@ class LLMManager:
                 logger.info(f"LLM {self.model_name} loaded and warmed up on {self.device}")
             except Exception as e:
                 logger.error(f"Failed to load LLM: {str(e)}")
-                raise

     def unload(self):
         if self.is_loaded:
@@ -115,14 +122,17 @@ class LLMManager:
             del self.processor
             if self.device.type == "cuda":
                 torch.cuda.empty_cache()
-                logger.info(f"GPU memory allocated after unload: {torch.cuda.memory_allocated()}")
             self.is_loaded = False
             self.token_cache.clear()
-            logger.info(f"LLM {self.model_name} unloaded from {self.device}")

     async def generate(self, prompt: str, max_tokens: int = settings.max_tokens, temperature: float = 0.7) -> str:
         if not self.is_loaded:
             self.load()

         cache_key = f"{prompt}:{max_tokens}:{temperature}"
         if cache_key in self.token_cache:
@@ -153,7 +163,6 @@ class LLMManager:
             return_tensors="pt",
             padding=True
         ).to(self.device, dtype=torch.bfloat16)
-
         with autocast(), torch.no_grad():
             outputs = self.model.generate(
                 **inputs_vlm,
@@ -175,12 +184,10 @@ class LLMManager:
     async def vision_query(self, image: Image.Image, query: str) -> str:
         if not self.is_loaded:
             self.load()
-
         messages_vlm = [
             {"role": "system", "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Summarize your answer in maximum 1 sentence."}]},
             {"role": "user", "content": [{"type": "text", "text": query}] + ([{"type": "image", "image": image}] if image and image.size[0] > 0 and image.size[1] > 0 else [])}
         ]
-
         try:
             inputs_vlm = self.processor.apply_chat_template(
                 messages_vlm,
@@ -192,7 +199,6 @@ class LLMManager:
         except Exception as e:
             logger.error(f"Error in apply_chat_template: {str(e)}")
             raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
-
         input_len = inputs_vlm["input_ids"].shape[-1]
         with torch.inference_mode():
             generation = self.model.generate(**inputs_vlm, max_new_tokens=512, do_sample=True, temperature=0.7)
@@ -204,12 +210,10 @@ class LLMManager:
     async def chat_v2(self, image: Image.Image, query: str) -> str:
         if not self.is_loaded:
             self.load()
-
         messages_vlm = [
             {"role": "system", "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state."}]},
             {"role": "user", "content": [{"type": "text", "text": query}] + ([{"type": "image", "image": image}] if image and image.size[0] > 0 and image.size[1] > 0 else [])}
         ]
-
         try:
             inputs_vlm = self.processor.apply_chat_template(
                 messages_vlm,
@@ -221,7 +225,6 @@ class LLMManager:
         except Exception as e:
             logger.error(f"Error in apply_chat_template: {str(e)}")
             raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
-
         input_len = inputs_vlm["input_ids"].shape[-1]
         with torch.inference_mode():
             generation = self.model.generate(**inputs_vlm, max_new_tokens=512, do_sample=True, temperature=0.7)
@@ -240,9 +243,18 @@ class TTSManager:

     def load(self):
         if not self.model:
-            logger.info("Loading TTS model IndicF5...")
             self.model = AutoModel.from_pretrained(self.repo_id, trust_remote_code=True).to(self.device_type)
-            logger.info("TTS model IndicF5 loaded")

     def synthesize(self, text, ref_audio_path, ref_text):
         if not self.model:
@@ -320,13 +332,22 @@ class ASRModelManager:

     def load(self):
         if not self.model:
-            logger.info("Loading ASR model...")
             self.model = AutoModel.from_pretrained(
                 "ai4bharat/indic-conformer-600m-multilingual",
                 trust_remote_code=True
             ).to(self.device_type)
             logger.info("ASR model loaded")

 # Global Managers
 llm_manager = LLMManager(settings.llm_model_name)
 model_manager = ModelManager()
@@ -353,6 +374,12 @@ class SynthesizeRequest(BaseModel):
 class KannadaSynthesizeRequest(BaseModel):
     text: str

 class ChatRequest(BaseModel):
     prompt: str
     src_lang: str = "kan_Knda"
@@ -441,9 +468,6 @@ async def perform_internal_translation(sentences: List[str], src_lang: str, tgt_
         model_manager.load_model(src_lang, tgt_lang, key)
         translate_manager = model_manager.get_model(src_lang, tgt_lang)

-    if not translate_manager.model:
-        translate_manager.load()
-
     batch = ip.preprocess_batch(sentences, src_lang=src_lang, tgt_lang=tgt_lang)
     inputs = translate_manager.tokenizer(batch, truncation=True, padding="longest", return_tensors="pt", return_attention_mask=True).to(translate_manager.device_type)
     with torch.no_grad(), autocast():
@@ -479,13 +503,17 @@ async def lifespan(app: FastAPI):
             model_manager.load_model(src_lang, tgt_lang, key)
         logger.info("All models loaded successfully")

-    logger.info("Starting sequential model loading...")
     load_all_models()
     batch_task = asyncio.create_task(batch_worker())
     yield
     batch_task.cancel()
     llm_manager.unload()
-    logger.info("Server shutdown complete")

 # Batch Worker
 async def batch_worker():
@@ -519,8 +547,8 @@ async def batch_worker():

 # FastAPI App
 app = FastAPI(
-    title="Dhwani API",
-    description="AI Chat API supporting Indian languages",
     version="1.0.0",
     redirect_slashes=False,
     lifespan=lifespan
@@ -577,6 +605,8 @@ async def translate(request: TranslationRequest, translate_manager: TranslateMan
 @app.get("/v1/health")
 async def health_check():
     memory_usage = torch.cuda.memory_allocated() / (24 * 1024**3) if cuda_available else 0
     llm_status = "unhealthy"
     llm_latency = None
     if llm_manager.is_loaded:
@@ -587,7 +617,6 @@ async def health_check():
             llm_status = "healthy" if llm_test else "unhealthy"
         except Exception as e:
             logger.error(f"LLM health check failed: {str(e)}")
-
     tts_status = "unhealthy"
     tts_latency = None
     if tts_manager.model:
@@ -598,7 +627,6 @@ async def health_check():
             tts_status = "healthy" if audio_buffer else "unhealthy"
         except Exception as e:
             logger.error(f"TTS health check failed: {str(e)}")
-
     asr_status = "unhealthy"
     asr_latency = None
     if asr_manager.model:
@@ -612,7 +640,6 @@ async def health_check():
             asr_status = "healthy" if asr_test else "unhealthy"
         except Exception as e:
             logger.error(f"ASR health check failed: {str(e)}")
-
     status = {
         "status": "healthy" if llm_status == "healthy" and tts_status == "healthy" and asr_status == "healthy" else "degraded",
         "model": settings.llm_model_name,
@@ -622,6 +649,7 @@ async def health_check():
         "tts_latency": f"{tts_latency:.3f}s" if tts_latency else "N/A",
         "asr_status": asr_status,
         "asr_latency": f"{asr_latency:.3f}s" if asr_latency else "N/A",
         "gpu_memory_usage": f"{memory_usage:.2%}"
     }
     logger.info("Health check completed")
@@ -636,6 +664,10 @@ async def unload_all_models():
     try:
         logger.info("Starting to unload all models...")
         llm_manager.unload()
         logger.info("All models unloaded successfully")
         return {"status": "success", "message": "All models unloaded"}
     except Exception as e:
@@ -647,6 +679,15 @@ async def load_all_models():
     try:
         logger.info("Starting to load all models...")
         llm_manager.load()
         logger.info("All models loaded successfully")
         return {"status": "success", "message": "All models loaded"}
     except Exception as e:
@@ -667,33 +708,32 @@ async def translate_endpoint(request: TranslationRequest):
 @app.post("/v1/chat", response_model=ChatResponse)
 @limiter.limit(settings.chat_rate_limit)
 async def chat(request: Request, chat_request: ChatRequest):
-    if not chat_request.prompt:
-        raise HTTPException(status_code=400, detail="Prompt cannot be empty")
-    logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, tgt_lang: {chat_request.tgt_lang}")
-    EUROPEAN_LANGUAGES = {"deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn", "por_Latn", "rus_Cyrl", "pol_Latn"}
-    try:
-        if chat_request.src_lang != "eng_Latn" and chat_request.src_lang not in EUROPEAN_LANGUAGES:
-            translated_prompt = await perform_internal_translation([chat_request.prompt], chat_request.src_lang, "eng_Latn")
-            prompt_to_process = translated_prompt[0]
-            logger.info(f"Translated prompt to English: {prompt_to_process}")
-        else:
-            prompt_to_process = chat_request.prompt
-            logger.info("Prompt in English or European language, no translation needed")
-
-        response = await llm_manager.generate(prompt_to_process, settings.max_tokens)
-        logger.info(f"Generated English response: {response}")
-
-        if chat_request.tgt_lang != "eng_Latn" and chat_request.tgt_lang not in EUROPEAN_LANGUAGES:
-            translated_response = await perform_internal_translation([response], "eng_Latn", chat_request.tgt_lang)
-            final_response = translated_response[0]
-            logger.info(f"Translated response to {chat_request.tgt_lang}: {final_response}")
-        else:
-            final_response = response
-            logger.info(f"Response in {chat_request.tgt_lang}, no translation needed")
-        return ChatResponse(response=final_response)
-    except Exception as e:
-        logger.error(f"Error processing request: {str(e)}")
-        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")

 @app.post("/v1/visual_query/")
 async def visual_query(
@@ -702,30 +742,31 @@ async def visual_query(
     src_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
     tgt_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
 ):
-    try:
-        image = Image.open(file.file)
-        if image.size == (0, 0):
-            raise HTTPException(status_code=400, detail="Uploaded image is empty or invalid")
-        if src_lang != "eng_Latn":
-            translated_query = await perform_internal_translation([query], src_lang, "eng_Latn")
-            query_to_process = translated_query[0]
-            logger.info(f"Translated query to English: {query_to_process}")
-        else:
-            query_to_process = query
-            logger.info("Query already in English, no translation needed")
-        answer = await llm_manager.vision_query(image, query_to_process)
-        logger.info(f"Generated English answer: {answer}")
-        if tgt_lang != "eng_Latn":
-            translated_answer = await perform_internal_translation([answer], "eng_Latn", tgt_lang)
-            final_answer = translated_answer[0]
-            logger.info(f"Translated answer to {tgt_lang}: {final_answer}")
-        else:
-            final_answer = answer
-            logger.info("Answer kept in English, no translation needed")
-        return {"answer": final_answer}
-    except Exception as e:
-        logger.error(f"Error processing request: {str(e)}")
-        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")

 @app.post("/v1/chat_v2", response_model=ChatResponse)
 @limiter.limit(settings.chat_rate_limit)
@@ -736,68 +777,70 @@ async def chat_v2(
     src_lang: str = Form("kan_Knda"),
     tgt_lang: str = Form("kan_Knda"),
 ):
-    if not prompt:
-        raise HTTPException(status_code=400, detail="Prompt cannot be empty")
-    if src_lang not in SUPPORTED_LANGUAGES or tgt_lang not in SUPPORTED_LANGUAGES:
-        raise HTTPException(status_code=400, detail=f"Unsupported language code. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
-    logger.info(f"Received prompt: {prompt}, src_lang: {src_lang}, tgt_lang: {tgt_lang}, Image provided: {image is not None}")
-    try:
-        if image:
-            image_data = await image.read()
-            if not image_data:
-                raise HTTPException(status_code=400, detail="Uploaded image is empty")
-            img = Image.open(io.BytesIO(image_data))
-            if src_lang != "eng_Latn":
-                translated_prompt = await perform_internal_translation([prompt], src_lang, "eng_Latn")
-                prompt_to_process = translated_prompt[0]
-                logger.info(f"Translated prompt to English: {prompt_to_process}")
-            else:
-                prompt_to_process = prompt
-            decoded = await llm_manager.chat_v2(img, prompt_to_process)
-            logger.info(f"Generated English response: {decoded}")
-            if tgt_lang != "eng_Latn":
-                translated_response = await perform_internal_translation([decoded], "eng_Latn", tgt_lang)
-                final_response = translated_response[0]
-                logger.info(f"Translated response to {tgt_lang}: {final_response}")
-            else:
-                final_response = decoded
-        else:
-            if src_lang != "eng_Latn":
-                translated_prompt = await perform_internal_translation([prompt], src_lang, "eng_Latn")
-                prompt_to_process = translated_prompt[0]
-                logger.info(f"Translated prompt to English: {prompt_to_process}")
-            else:
-                prompt_to_process = prompt
-            decoded = await llm_manager.generate(prompt_to_process, settings.max_tokens)
-            logger.info(f"Generated English response: {decoded}")
-            if tgt_lang != "eng_Latn":
-                translated_response = await perform_internal_translation([decoded], "eng_Latn", tgt_lang)
-                final_response = translated_response[0]
-                logger.info(f"Translated response to {tgt_lang}: {final_response}")
             else:
-                final_response = decoded
-        return ChatResponse(response=final_response)
-    except Exception as e:
-        logger.error(f"Error processing request: {str(e)}")
-        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")

 @app.post("/transcribe/", response_model=TranscriptionResponse)
 async def transcribe_audio(file: UploadFile = File(...), language: str = Query(..., enum=list(asr_manager.model_language.keys()))):
-    if not asr_manager.model:
-        raise HTTPException(status_code=503, detail="ASR model not loaded")
-    try:
-        wav, sr = torchaudio.load(file.file, backend="cuda" if cuda_available else "cpu")
-        wav = torch.mean(wav, dim=0, keepdim=True).to(device)
-        target_sample_rate = 16000
-        if sr != target_sample_rate:
-            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate).to(device)
-            wav = resampler(wav)
-        with autocast(), torch.no_grad():
-            transcription_rnnt = asr_manager.model(wav, asr_manager.model_language[language], "rnnt")
-        return TranscriptionResponse(text=transcription_rnnt)
-    except Exception as e:
-        logger.error(f"Error in transcription: {str(e)}")
-        raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")

 @app.post("/v1/speech_to_speech")
 async def speech_to_speech(
@@ -805,16 +848,17 @@ async def speech_to_speech(
     file: UploadFile = File(...),
     language: str = Query(..., enum=list(asr_manager.model_language.keys())),
 ) -> StreamingResponse:
-    if not tts_manager.model:
-        raise HTTPException(status_code=503, detail="TTS model not loaded")
-    transcription = await transcribe_audio(file, language)
-    logger.info(f"Transcribed text: {transcription.text}")
-    chat_request = ChatRequest(prompt=transcription.text, src_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda"), tgt_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda"))
-    processed_text = await chat(request, chat_request)
-    logger.info(f"Processed text: {processed_text.response}")
-    voice_request = KannadaSynthesizeRequest(text=processed_text.response)
-    audio_response = await synthesize_kannada(voice_request)
-    return audio_response

 LANGUAGE_TO_SCRIPT = {"kannada": "kan_Knda"}
 
 
 from typing import List, Dict
 import tempfile
 import uvicorn
+from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, UploadFile, Body, Form
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
 from PIL import Image

 import soundfile as sf
 import numpy as np
 import requests
+import logging
 from starlette.responses import StreamingResponse
+from logging_config import logger  # Assumed external logging config
+from tts_config import SPEED, ResponseFormat, config as tts_config  # Assumed external TTS config
 import torchaudio
 from tenacity import retry, stop_after_attempt, wait_exponential
 from torch.cuda.amp import autocast

 # Device setup
 device = "cuda:0" if torch.cuda.is_available() else "cpu"
+torch_dtype = torch.float16 if device != "cpu" else torch.float32
+logger.info(f"Using device: {device} with dtype: {torch_dtype}")

 # Check CUDA availability and version
 cuda_available = torch.cuda.is_available()

 # Request queue for concurrency control
 request_queue = asyncio.Queue(maxsize=10)

+# Logging optimization
+logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
+
 # LLM Manager with batching
 class LLMManager:
     def __init__(self, model_name: str, device: str = device):
         self.model_name = model_name
         self.device = torch.device(device)
+        self.torch_dtype = torch.float16 if self.device.type != "cpu" else torch.float32
         self.model = None
         self.processor = None
         self.is_loaded = False

     def load(self):
         if not self.is_loaded:
             try:
+                if self.device.type == "cuda":
+                    torch.set_float32_matmul_precision('high')
+                    logger.info("Enabled TF32 matrix multiplication for improved GPU performance")
                 self.model = Gemma3ForConditionalGeneration.from_pretrained(
                     self.model_name,
                     device_map="auto",

                 logger.info(f"LLM {self.model_name} loaded and warmed up on {self.device}")
             except Exception as e:
                 logger.error(f"Failed to load LLM: {str(e)}")
+                self.is_loaded = False

     def unload(self):
         if self.is_loaded:

             del self.processor
             if self.device.type == "cuda":
                 torch.cuda.empty_cache()
+                logger.info(f"GPU memory cleared: {torch.cuda.memory_allocated()} bytes allocated")
             self.is_loaded = False
             self.token_cache.clear()
+            logger.info(f"LLM {self.model_name} unloaded")

     async def generate(self, prompt: str, max_tokens: int = settings.max_tokens, temperature: float = 0.7) -> str:
         if not self.is_loaded:
+            logger.warning("LLM not loaded; attempting reload")
             self.load()
+            if not self.is_loaded:
+                raise HTTPException(status_code=503, detail="LLM model unavailable")

         cache_key = f"{prompt}:{max_tokens}:{temperature}"
         if cache_key in self.token_cache:

             return_tensors="pt",
             padding=True
         ).to(self.device, dtype=torch.bfloat16)
         with autocast(), torch.no_grad():
             outputs = self.model.generate(
                 **inputs_vlm,

     async def vision_query(self, image: Image.Image, query: str) -> str:
         if not self.is_loaded:
             self.load()
         messages_vlm = [
             {"role": "system", "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Summarize your answer in maximum 1 sentence."}]},
             {"role": "user", "content": [{"type": "text", "text": query}] + ([{"type": "image", "image": image}] if image and image.size[0] > 0 and image.size[1] > 0 else [])}
         ]

         try:
             inputs_vlm = self.processor.apply_chat_template(
                 messages_vlm,

         except Exception as e:
             logger.error(f"Error in apply_chat_template: {str(e)}")
             raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")

         input_len = inputs_vlm["input_ids"].shape[-1]
         with torch.inference_mode():
             generation = self.model.generate(**inputs_vlm, max_new_tokens=512, do_sample=True, temperature=0.7)

     async def chat_v2(self, image: Image.Image, query: str) -> str:
         if not self.is_loaded:
             self.load()
         messages_vlm = [
             {"role": "system", "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state."}]},
             {"role": "user", "content": [{"type": "text", "text": query}] + ([{"type": "image", "image": image}] if image and image.size[0] > 0 and image.size[1] > 0 else [])}
         ]

         try:
             inputs_vlm = self.processor.apply_chat_template(
                 messages_vlm,

         except Exception as e:
             logger.error(f"Error in apply_chat_template: {str(e)}")
             raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")

         input_len = inputs_vlm["input_ids"].shape[-1]
         with torch.inference_mode():
             generation = self.model.generate(**inputs_vlm, max_new_tokens=512, do_sample=True, temperature=0.7)

     def load(self):
         if not self.model:
+            logger.info(f"Loading TTS model {self.repo_id} on {self.device_type}...")
             self.model = AutoModel.from_pretrained(self.repo_id, trust_remote_code=True).to(self.device_type)
+            logger.info("TTS model loaded")
+
+    def unload(self):
+        if self.model:
+            del self.model
+            if self.device_type.type == "cuda":
+                torch.cuda.empty_cache()
+                logger.info(f"TTS GPU memory cleared: {torch.cuda.memory_allocated()} bytes allocated")
+            self.model = None
+            logger.info("TTS model unloaded")
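Note on the new `unload` methods: in this file `device_type` defaults to the module-level `device`, which is a plain string ("cuda:0" or "cpu"), so `self.device_type.type` would raise `AttributeError`. A minimal sketch, assuming the intent is to accept either a string or a `torch.device`:

```python
import torch

class UnloadableModelMixin:
    """Hypothetical helper, not part of the commit: normalize device_type first."""

    def unload(self):
        if self.model is not None:
            del self.model
            # torch.device() accepts both "cuda:0" strings and torch.device objects.
            if torch.device(self.device_type).type == "cuda":
                torch.cuda.empty_cache()
            self.model = None
```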

     def synthesize(self, text, ref_audio_path, ref_text):
         if not self.model:

     def load(self):
         if not self.model:
+            logger.info(f"Loading ASR model on {self.device_type}...")
             self.model = AutoModel.from_pretrained(
                 "ai4bharat/indic-conformer-600m-multilingual",
                 trust_remote_code=True
             ).to(self.device_type)
             logger.info("ASR model loaded")

+    def unload(self):
+        if self.model:
+            del self.model
+            if self.device_type.type == "cuda":
+                torch.cuda.empty_cache()
+                logger.info(f"ASR GPU memory cleared: {torch.cuda.memory_allocated()} bytes allocated")
+            self.model = None
+            logger.info("ASR model unloaded")
+
 # Global Managers
 llm_manager = LLMManager(settings.llm_model_name)
 model_manager = ModelManager()

 class KannadaSynthesizeRequest(BaseModel):
     text: str

+    @field_validator("text")
+    def text_must_be_valid(cls, v):
+        if len(v) > 500:
+            raise ValueError("Text cannot exceed 500 characters")
+        return v.strip()
+
 class ChatRequest(BaseModel):
     prompt: str
     src_lang: str = "kan_Knda"
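`field_validator` is Pydantic v2 API; it must live inside the model body, and the documented pattern stacks it with `@classmethod`. A self-contained sketch of the request model under that assumption:

```python
from pydantic import BaseModel, field_validator

class KannadaSynthesizeRequest(BaseModel):
    text: str

    @field_validator("text")
    @classmethod
    def text_must_be_valid(cls, v: str) -> str:
        if len(v) > 500:
            raise ValueError("Text cannot exceed 500 characters")
        return v.strip()

# KannadaSynthesizeRequest(text="  hello  ").text == "hello"; longer than 500 chars raises ValidationError
```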
 
         model_manager.load_model(src_lang, tgt_lang, key)
         translate_manager = model_manager.get_model(src_lang, tgt_lang)

     batch = ip.preprocess_batch(sentences, src_lang=src_lang, tgt_lang=tgt_lang)
     inputs = translate_manager.tokenizer(batch, truncation=True, padding="longest", return_tensors="pt", return_attention_mask=True).to(translate_manager.device_type)
     with torch.no_grad(), autocast():

             model_manager.load_model(src_lang, tgt_lang, key)
         logger.info("All models loaded successfully")

+    logger.info("Starting server with preloaded models...")
     load_all_models()
     batch_task = asyncio.create_task(batch_worker())
     yield
     batch_task.cancel()
     llm_manager.unload()
+    tts_manager.unload()
+    asr_manager.unload()
+    for model in model_manager.models.values():
+        model.unload()
+    logger.info("Server shutdown complete; all models unloaded")

 # Batch Worker
 async def batch_worker():

 # FastAPI App
 app = FastAPI(
+    title="Optimized Dhwani API",
+    description="AI Chat API supporting Indian languages with performance enhancements",
     version="1.0.0",
     redirect_slashes=False,
     lifespan=lifespan

 @app.get("/v1/health")
 async def health_check():
     memory_usage = torch.cuda.memory_allocated() / (24 * 1024**3) if cuda_available else 0
+    if memory_usage > 0.9:
+        logger.warning("GPU memory usage exceeds 90%; consider unloading models")
     llm_status = "unhealthy"
     llm_latency = None
     if llm_manager.is_loaded:

             llm_status = "healthy" if llm_test else "unhealthy"
         except Exception as e:
             logger.error(f"LLM health check failed: {str(e)}")

     tts_status = "unhealthy"
     tts_latency = None
     if tts_manager.model:

             tts_status = "healthy" if audio_buffer else "unhealthy"
         except Exception as e:
             logger.error(f"TTS health check failed: {str(e)}")

     asr_status = "unhealthy"
     asr_latency = None
     if asr_manager.model:

             asr_status = "healthy" if asr_test else "unhealthy"
         except Exception as e:
             logger.error(f"ASR health check failed: {str(e)}")

     status = {
         "status": "healthy" if llm_status == "healthy" and tts_status == "healthy" and asr_status == "healthy" else "degraded",
         "model": settings.llm_model_name,

         "tts_latency": f"{tts_latency:.3f}s" if tts_latency else "N/A",
         "asr_status": asr_status,
         "asr_latency": f"{asr_latency:.3f}s" if asr_latency else "N/A",
+        "translation_models": list(model_manager.models.keys()),
         "gpu_memory_usage": f"{memory_usage:.2%}"
     }
     logger.info("Health check completed")
 
     try:
         logger.info("Starting to unload all models...")
         llm_manager.unload()
+        tts_manager.unload()
+        asr_manager.unload()
+        for model in model_manager.models.values():
+            model.unload()
         logger.info("All models unloaded successfully")
         return {"status": "success", "message": "All models unloaded"}
     except Exception as e:

     try:
         logger.info("Starting to load all models...")
         llm_manager.load()
+        tts_manager.load()
+        asr_manager.load()
+        for src_lang, tgt_lang, key in [
+            ('eng_Latn', 'kan_Knda', 'eng_indic'),
+            ('kan_Knda', 'eng_Latn', 'indic_eng'),
+            ('kan_Knda', 'hin_Deva', 'indic_indic'),
+        ]:
+            if key not in model_manager.models:
+                model_manager.load_model(src_lang, tgt_lang, key)
         logger.info("All models loaded successfully")
         return {"status": "success", "message": "All models loaded"}
     except Exception as e:

 @app.post("/v1/chat", response_model=ChatResponse)
 @limiter.limit(settings.chat_rate_limit)
 async def chat(request: Request, chat_request: ChatRequest):
+    async with request_queue:
+        if not chat_request.prompt:
+            raise HTTPException(status_code=400, detail="Prompt cannot be empty")
+        logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, tgt_lang: {chat_request.tgt_lang}")
+        EUROPEAN_LANGUAGES = {"deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn", "por_Latn", "rus_Cyrl", "pol_Latn"}
+        try:
+            if chat_request.src_lang != "eng_Latn" and chat_request.src_lang not in EUROPEAN_LANGUAGES:
+                translated_prompt = await perform_internal_translation([chat_request.prompt], chat_request.src_lang, "eng_Latn")
+                prompt_to_process = translated_prompt[0]
+                logger.info(f"Translated prompt to English: {prompt_to_process}")
+            else:
+                prompt_to_process = chat_request.prompt
+                logger.info("Prompt in English or European language, no translation needed")
+            response = await llm_manager.generate(prompt_to_process, settings.max_tokens)
+            logger.info(f"Generated English response: {response}")
+            if chat_request.tgt_lang != "eng_Latn" and chat_request.tgt_lang not in EUROPEAN_LANGUAGES:
+                translated_response = await perform_internal_translation([response], "eng_Latn", chat_request.tgt_lang)
+                final_response = translated_response[0]
+                logger.info(f"Translated response to {chat_request.tgt_lang}: {final_response}")
+            else:
+                final_response = response
+                logger.info(f"Response in {chat_request.tgt_lang}, no translation needed")
+            return ChatResponse(response=final_response)
+        except Exception as e:
+            logger.error(f"Error processing request: {str(e)}")
+            raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
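A caveat on the `async with request_queue:` guard used in these rewritten endpoints: `asyncio.Queue` does not implement the async context manager protocol, so this line fails at request time. If the goal is to cap concurrency at 10 in-flight requests, an `asyncio.Semaphore` does exactly that; a minimal sketch:

```python
import asyncio

# Sketch: a counting semaphore gives the concurrency cap that the
# maxsize=10 queue appears intended to provide.
request_semaphore = asyncio.Semaphore(10)

async def guarded_handler(prompt: str) -> str:
    async with request_semaphore:  # Semaphore supports `async with`; Queue does not
        # ... translate, generate, translate back ...
        return prompt
```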
 
 @app.post("/v1/visual_query/")
 async def visual_query(

     src_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
     tgt_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
 ):
+    async with request_queue:
+        try:
+            image = Image.open(file.file)
+            if image.size == (0, 0):
+                raise HTTPException(status_code=400, detail="Uploaded image is empty or invalid")
+            if src_lang != "eng_Latn":
+                translated_query = await perform_internal_translation([query], src_lang, "eng_Latn")
+                query_to_process = translated_query[0]
+                logger.info(f"Translated query to English: {query_to_process}")
+            else:
+                query_to_process = query
+                logger.info("Query already in English, no translation needed")
+            answer = await llm_manager.vision_query(image, query_to_process)
+            logger.info(f"Generated English answer: {answer}")
+            if tgt_lang != "eng_Latn":
+                translated_answer = await perform_internal_translation([answer], "eng_Latn", tgt_lang)
+                final_answer = translated_answer[0]
+                logger.info(f"Translated answer to {tgt_lang}: {final_answer}")
+            else:
+                final_answer = answer
+                logger.info("Answer kept in English, no translation needed")
+            return {"answer": final_answer}
+        except Exception as e:
+            logger.error(f"Error processing request: {str(e)}")
+            raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")

 @app.post("/v1/chat_v2", response_model=ChatResponse)
 @limiter.limit(settings.chat_rate_limit)

     src_lang: str = Form("kan_Knda"),
     tgt_lang: str = Form("kan_Knda"),
 ):
+    async with request_queue:
+        if not prompt:
+            raise HTTPException(status_code=400, detail="Prompt cannot be empty")
+        if src_lang not in SUPPORTED_LANGUAGES or tgt_lang not in SUPPORTED_LANGUAGES:
+            raise HTTPException(status_code=400, detail=f"Unsupported language code. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
+        logger.info(f"Received prompt: {prompt}, src_lang: {src_lang}, tgt_lang: {tgt_lang}, Image provided: {image is not None}")
+        try:
+            if image:
+                image_data = await image.read()
+                if not image_data:
+                    raise HTTPException(status_code=400, detail="Uploaded image is empty")
+                img = Image.open(io.BytesIO(image_data))
+                if src_lang != "eng_Latn":
+                    translated_prompt = await perform_internal_translation([prompt], src_lang, "eng_Latn")
+                    prompt_to_process = translated_prompt[0]
+                    logger.info(f"Translated prompt to English: {prompt_to_process}")
+                else:
+                    prompt_to_process = prompt
+                decoded = await llm_manager.chat_v2(img, prompt_to_process)
+                logger.info(f"Generated English response: {decoded}")
+                if tgt_lang != "eng_Latn":
+                    translated_response = await perform_internal_translation([decoded], "eng_Latn", tgt_lang)
+                    final_response = translated_response[0]
+                    logger.info(f"Translated response to {tgt_lang}: {final_response}")
+                else:
+                    final_response = decoded
             else:
+                if src_lang != "eng_Latn":
+                    translated_prompt = await perform_internal_translation([prompt], src_lang, "eng_Latn")
+                    prompt_to_process = translated_prompt[0]
+                    logger.info(f"Translated prompt to English: {prompt_to_process}")
+                else:
+                    prompt_to_process = prompt
+                decoded = await llm_manager.generate(prompt_to_process, settings.max_tokens)
+                logger.info(f"Generated English response: {decoded}")
+                if tgt_lang != "eng_Latn":
+                    translated_response = await perform_internal_translation([decoded], "eng_Latn", tgt_lang)
+                    final_response = translated_response[0]
+                    logger.info(f"Translated response to {tgt_lang}: {final_response}")
+                else:
+                    final_response = decoded
+            return ChatResponse(response=final_response)
+        except Exception as e:
+            logger.error(f"Error processing request: {str(e)}")
+            raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")

 @app.post("/transcribe/", response_model=TranscriptionResponse)
 async def transcribe_audio(file: UploadFile = File(...), language: str = Query(..., enum=list(asr_manager.model_language.keys()))):
+    async with request_queue:
+        if not asr_manager.model:
+            raise HTTPException(status_code=503, detail="ASR model not loaded")
+        try:
+            wav, sr = torchaudio.load(file.file, backend="cuda" if cuda_available else "cpu")
+            wav = torch.mean(wav, dim=0, keepdim=True).to(device)
+            target_sample_rate = 16000
+            if sr != target_sample_rate:
+                resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate).to(device)
+                wav = resampler(wav)
+            with autocast(), torch.no_grad():
+                transcription_rnnt = asr_manager.model(wav, asr_manager.model_language[language], "rnnt")
+            return TranscriptionResponse(text=transcription_rnnt)
+        except Exception as e:
+            logger.error(f"Error in transcription: {str(e)}")
+            raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
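One detail in the transcribe handler: torchaudio's `backend` argument names a decoding library ("ffmpeg", "sox", or "soundfile"), not a compute device, so "cuda"/"cpu" are not valid values there; decoding happens on CPU and the tensor is moved to the GPU afterwards, as the following lines already do. A sketch of the load-and-resample step under that reading:

```python
import torch
import torchaudio

def load_mono_16k(path_or_file, device: str = "cpu") -> torch.Tensor:
    # Decode on CPU; pass backend="ffmpeg"/"sox"/"soundfile" only to pick a decoder.
    wav, sr = torchaudio.load(path_or_file)
    wav = wav.mean(dim=0, keepdim=True).to(device)  # downmix to mono
    if sr != 16000:
        wav = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000).to(device)(wav)
    return wav
```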

 @app.post("/v1/speech_to_speech")
 async def speech_to_speech(

     file: UploadFile = File(...),
     language: str = Query(..., enum=list(asr_manager.model_language.keys())),
 ) -> StreamingResponse:
+    async with request_queue:
+        if not tts_manager.model:
+            raise HTTPException(status_code=503, detail="TTS model not loaded")
+        transcription = await transcribe_audio(file, language)
+        logger.info(f"Transcribed text: {transcription.text}")
+        chat_request = ChatRequest(prompt=transcription.text, src_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda"), tgt_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda"))
+        processed_text = await chat(request, chat_request)
+        logger.info(f"Processed text: {processed_text.response}")
+        voice_request = KannadaSynthesizeRequest(text=processed_text.response)
+        audio_response = await synthesize_kannada(voice_request)
+        return audio_response

 LANGUAGE_TO_SCRIPT = {"kannada": "kan_Knda"}
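For reference, a client-side sketch of the speech-to-speech round trip with `requests` (endpoint path and parameters as defined above; host, port, and filenames are illustrative):

```python
import requests

with open("question.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/v1/speech_to_speech",
        params={"language": "kannada"},
        files={"file": ("question.wav", f, "audio/wav")},
    )
resp.raise_for_status()
with open("answer.wav", "wb") as out:
    out.write(resp.content)  # WAV audio streamed back from the TTS stage
```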
 
src/server/main.py CHANGED
@@ -2,7 +2,7 @@ import argparse
 import io
 import os
 from time import time
-from typing import List
 import tempfile
 import uvicorn
 from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, UploadFile, Body, Form
@@ -22,25 +22,23 @@ from contextlib import asynccontextmanager
 import soundfile as sf
 import numpy as np
 import requests
 from starlette.responses import StreamingResponse
-from logging_config import logger
-from tts_config import SPEED, ResponseFormat, config as tts_config
 import torchaudio

 # Device setup
-if torch.cuda.is_available():
-    device = "cuda:0"
-    logger.info("GPU will be used for inference")
-else:
-    device = "cpu"
-    logger.info("CPU will be used for inference")
-torch_dtype = torch.bfloat16 if device != "cpu" else torch.float32

 # Check CUDA availability and version
 cuda_available = torch.cuda.is_available()
 cuda_version = torch.version.cuda if cuda_available else None
-
-if torch.cuda.is_available():
     device_idx = torch.cuda.current_device()
     capability = torch.cuda.get_device_capability(device_idx)
     compute_capability_float = float(f"{capability[0]}.{capability[1]}")
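Deleting the `if torch.cuda.is_available():` guard while keeping its body as context means `torch.cuda.current_device()` would run unconditionally and raise on CPU-only hosts, unless the incoming side (not shown here) keeps the indented lines under another guard. A guarded sketch of the capability probe:

```python
import torch

cuda_available = torch.cuda.is_available()
compute_capability_float = None
if cuda_available:
    device_idx = torch.cuda.current_device()
    major, minor = torch.cuda.get_device_capability(device_idx)
    compute_capability_float = float(f"{major}.{minor}")  # e.g. 8.6 on Ampere
```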
@@ -77,33 +75,46 @@ quantization_config = BitsAndBytesConfig(
     bnb_4bit_compute_dtype=torch.bfloat16
 )
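Only the tail of the `BitsAndBytesConfig` is visible in this hunk. For orientation, a typical 4-bit setup from `transformers` that ends with the `bnb_4bit_compute_dtype` line above looks like this; the first three fields are an assumption, not read from the diff:

```python
import torch
from transformers import BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # assumed: 4-bit weight quantization
    bnb_4bit_quant_type="nf4",              # assumed: NF4 quantization scheme
    bnb_4bit_use_double_quant=True,         # assumed: double quantization
    bnb_4bit_compute_dtype=torch.bfloat16,  # visible in the hunk above
)
```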

-# LLM Manager
 class LLMManager:
-    def __init__(self, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
         self.model_name = model_name
         self.device = torch.device(device)
-        self.torch_dtype = torch.bfloat16 if self.device.type != "cpu" else torch.float32
         self.model = None
         self.processor = None
         self.is_loaded = False
         logger.info(f"LLMManager initialized with model {model_name} on {self.device}")

     def load(self):
         if not self.is_loaded:
             try:
                 self.model = Gemma3ForConditionalGeneration.from_pretrained(
                     self.model_name,
                     device_map="auto",
                     quantization_config=quantization_config,
                     torch_dtype=self.torch_dtype
-                )
-                self.model.eval()
                 self.processor = AutoProcessor.from_pretrained(self.model_name)
                 self.is_loaded = True
-                logger.info(f"LLM {self.model_name} loaded on {self.device}")
             except Exception as e:
                 logger.error(f"Failed to load LLM: {str(e)}")
-                raise
@@ -111,74 +122,72 @@ class LLMManager:
             del self.processor
             if self.device.type == "cuda":
                 torch.cuda.empty_cache()
-                logger.info(f"GPU memory allocated after unload: {torch.cuda.memory_allocated()}")
             self.is_loaded = False
-        logger.info(f"LLM {self.model_name} unloaded from {self.device}")

-    async def generate(self, prompt: str, max_tokens: int = 512, temperature: float = 0.7) -> str:
         if not self.is_loaded:
             self.load()

-        messages_vlm = [
-            {
-                "role": "system",
-                "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state. Provide a concise response in one sentence maximum."}]
-            },
-            {
-                "role": "user",
-                "content": [{"type": "text", "text": prompt}]
-            }
-        ]
         try:
             inputs_vlm = self.processor.apply_chat_template(
-                messages_vlm,
                 add_generation_prompt=True,
                 tokenize=True,
                 return_dict=True,
-                return_tensors="pt"
             ).to(self.device, dtype=torch.bfloat16)
         except Exception as e:
-            logger.error(f"Error in tokenization: {str(e)}")
-            raise HTTPException(status_code=500, detail=f"Tokenization failed: {str(e)}")
-
-        input_len = inputs_vlm["input_ids"].shape[-1]
-
-        with torch.inference_mode():
-            generation = self.model.generate(
-                **inputs_vlm,
-                max_new_tokens=max_tokens,
-                do_sample=True,
-                temperature=temperature
-            )
-        generation = generation[0][input_len:]
-
-        response = self.processor.decode(generation, skip_special_tokens=True)
-        logger.info(f"Generated response: {response}")
-        return response

     async def vision_query(self, image: Image.Image, query: str) -> str:
         if not self.is_loaded:
             self.load()
-
         messages_vlm = [
-            {
-                "role": "system",
-                "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Summarize your answer in maximum 1 sentence."}]
-            },
-            {
-                "role": "user",
-                "content": []
-            }
         ]
-
-        messages_vlm[1]["content"].append({"type": "text", "text": query})
-        if image and image.size[0] > 0 and image.size[1] > 0:
-            messages_vlm[1]["content"].insert(0, {"type": "image", "image": image})
-            logger.info(f"Received valid image for processing")
-        else:
-            logger.info("No valid image provided, processing text only")
-
         try:
             inputs_vlm = self.processor.apply_chat_template(
                 messages_vlm,
@@ -190,18 +199,10 @@ class LLMManager:
         except Exception as e:
             logger.error(f"Error in apply_chat_template: {str(e)}")
             raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
-
         input_len = inputs_vlm["input_ids"].shape[-1]
-
         with torch.inference_mode():
-            generation = self.model.generate(
-                **inputs_vlm,
-                max_new_tokens=512,
-                do_sample=True,
-                temperature=0.7
-            )
             generation = generation[0][input_len:]
-
         decoded = self.processor.decode(generation, skip_special_tokens=True)
         logger.info(f"Vision query response: {decoded}")
         return decoded
@@ -209,25 +210,10 @@ class LLMManager:
     async def chat_v2(self, image: Image.Image, query: str) -> str:
         if not self.is_loaded:
             self.load()
-
         messages_vlm = [
-            {
-                "role": "system",
-                "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state."}]
-            },
-            {
-                "role": "user",
-                "content": []
-            }
         ]
-
-        messages_vlm[1]["content"].append({"type": "text", "text": query})
-        if image and image.size[0] > 0 and image.size[1] > 0:
-            messages_vlm[1]["content"].insert(0, {"type": "image", "image": image})
-            logger.info(f"Received valid image for processing")
-        else:
-            logger.info("No valid image provided, processing text only")
-
         try:
             inputs_vlm = self.processor.apply_chat_template(
                 messages_vlm,
@@ -239,18 +225,10 @@ class LLMManager:
         except Exception as e:
             logger.error(f"Error in apply_chat_template: {str(e)}")
             raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
-
         input_len = inputs_vlm["input_ids"].shape[-1]
-
         with torch.inference_mode():
-            generation = self.model.generate(
-                **inputs_vlm,
-                max_new_tokens=512,
-                do_sample=True,
-                temperature=0.7
-            )
             generation = generation[0][input_len:]
-
         decoded = self.processor.decode(generation, skip_special_tokens=True)
         logger.info(f"Chat_v2 response: {decoded}")
         return decoded
@@ -258,101 +236,42 @@ class LLMManager:
 # TTS Manager
 class TTSManager:
     def __init__(self, device_type=device):
-        self.device_type = device_type
         self.model = None
         self.repo_id = "ai4bharat/IndicF5"

     def load(self):
         if not self.model:
-            logger.info("Loading TTS model IndicF5...")
-            self.model = AutoModel.from_pretrained(
-                self.repo_id,
-                trust_remote_code=True
-            )
-            self.model = self.model.to(self.device_type)
-            logger.info("TTS model IndicF5 loaded")

     def synthesize(self, text, ref_audio_path, ref_text):
         if not self.model:
             raise ValueError("TTS model not loaded")
-        return self.model(text, ref_audio_path=ref_audio_path, ref_text=ref_text)
-
-# TTS Constants
-EXAMPLES = [
-    {
-        "audio_name": "KAN_F (Happy)",
-        "audio_url": "https://github.com/AI4Bharat/IndicF5/raw/refs/heads/main/prompts/KAN_F_HAPPY_00001.wav",
-        "ref_text": "ನಮ್‌ ಫ್ರಿಜ್ಜಲ್ಲಿ ಕೂಲಿಂಗ್‌ ಸಮಸ್ಯೆ ಆಗಿ ನಾನ್‌ ಭಾಳ ದಿನದಿಂದ ಒದ್ದಾಡ್ತಿದ್ದೆ, ಆದ್ರೆ ಅದ್ನೀಗ ಮೆಕಾನಿಕ್ ಆಗಿರೋ ನಿಮ್‌ ಸಹಾಯ್ದಿಂದ ಬಗೆಹರಿಸ್ಕೋಬೋದು ಅಂತಾಗಿ ನಿರಾಳ ಆಯ್ತು ನಂಗೆ.",
-        "synth_text": "ಚೆನ್ನೈನ ಶೇರ್ ಆಟೋ ಪ್ರಯಾಣಿಕರ ನಡುವೆ ಆಹಾರವನ್ನು ಹಂಚಿಕೊಂಡು ತಿನ್ನುವುದು ನನಗೆ ಮನಸ್ಸಿಗೆ ತುಂಬಾ ಒಳ್ಳೆಯದೆನಿಸುವ ವಿಷಯ."
-    },
-]
-
-# Pydantic models for TTS
-class SynthesizeRequest(BaseModel):
-    text: str
-    ref_audio_name: str
-    ref_text: str = None
-
-class KannadaSynthesizeRequest(BaseModel):
-    text: str
-
-# TTS Functions
-def load_audio_from_url(url: str):
-    response = requests.get(url)
-    if response.status_code == 200:
-        audio_data, sample_rate = sf.read(io.BytesIO(response.content))
-        return sample_rate, audio_data
-    raise HTTPException(status_code=500, detail="Failed to load reference audio from URL.")
-
-def synthesize_speech(tts_manager: TTSManager, text: str, ref_audio_name: str, ref_text: str):
-    ref_audio_url = None
-    for example in EXAMPLES:
-        if example["audio_name"] == ref_audio_name:
-            ref_audio_url = example["audio_url"]
-            if not ref_text:
-                ref_text = example["ref_text"]
-            break
-
-    if not ref_audio_url:
-        raise HTTPException(status_code=400, detail="Invalid reference audio name.")
-    if not text.strip():
-        raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")
-    if not ref_text or not ref_text.strip():
-        raise HTTPException(status_code=400, detail="Reference text cannot be empty.")
-
-    sample_rate, audio_data = load_audio_from_url(ref_audio_url)
-    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
-        sf.write(temp_audio.name, audio_data, samplerate=sample_rate, format='WAV')
-        temp_audio.flush()
-        audio = tts_manager.synthesize(text, ref_audio_path=temp_audio.name, ref_text=ref_text)
-
-    if audio.dtype == np.int16:
-        audio = audio.astype(np.float32) / 32768.0
-    buffer = io.BytesIO()
-    sf.write(buffer, audio, 24000, format='WAV')
-    buffer.seek(0)
-    return buffer
-
-# Supported languages
-SUPPORTED_LANGUAGES = {
-    "asm_Beng", "kas_Arab", "pan_Guru", "ben_Beng", "kas_Deva", "san_Deva",
-    "brx_Deva", "mai_Deva", "sat_Olck", "doi_Deva", "mal_Mlym", "snd_Arab",
-    "eng_Latn", "mar_Deva", "snd_Deva", "gom_Deva", "mni_Beng", "tam_Taml",
-    "guj_Gujr", "mni_Mtei", "tel_Telu", "hin_Deva", "npi_Deva", "urd_Arab",
-    "kan_Knda", "ory_Orya",
-    "deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn",
-    "por_Latn", "rus_Cyrl", "pol_Latn"
-}

 # Translation Manager
 class TranslateManager:
     def __init__(self, src_lang, tgt_lang, device_type=device, use_distilled=True):
-        self.device_type = device_type
         self.tokenizer = None
         self.model = None
         self.src_lang = src_lang
         self.tgt_lang = tgt_lang
         self.use_distilled = use_distilled

     def load(self):
         if not self.tokenizer or not self.model:
@@ -364,21 +283,17 @@ class TranslateManager:
             model_name = "ai4bharat/indictrans2-indic-indic-dist-320M" if self.use_distilled else "ai4bharat/indictrans2-indic-indic-1B"
         else:
             raise ValueError("Invalid language combination")
-
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            model_name,
-            trust_remote_code=True
-        )
         self.model = AutoModelForSeq2SeqLM.from_pretrained(
             model_name,
             trust_remote_code=True,
             torch_dtype=torch.float16,
             attn_implementation="flash_attention_2"
-        )
-        self.model = self.model.to(self.device_type)
         self.model = torch.compile(self.model, mode="reduce-overhead")
         logger.info(f"Translation model {model_name} loaded")

 class ModelManager:
     def __init__(self, device_type=device, use_distilled=True, is_lazy_loading=False):
         self.models = {}
@@ -389,18 +304,14 @@ class ModelManager:
     def load_model(self, src_lang, tgt_lang, key):
         logger.info(f"Loading translation model for {src_lang} -> {tgt_lang}")
         translate_manager = TranslateManager(src_lang, tgt_lang, self.device_type, self.use_distilled)
-        translate_manager.load()
         self.models[key] = translate_manager
         logger.info(f"Loaded translation model for {key}")

     def get_model(self, src_lang, tgt_lang):
         key = self._get_model_key(src_lang, tgt_lang)
-        if key not in self.models:
-            if self.is_lazy_loading:
-                self.load_model(src_lang, tgt_lang, key)
-            else:
-                raise ValueError(f"Model for {key} is not preloaded and lazy loading is disabled.")
-        return self.models.get(key)

     def _get_model_key(self, src_lang, tgt_lang):
         if src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
@@ -413,21 +324,30 @@ class ModelManager:

 # ASR Manager
 class ASRModelManager:
-    def __init__(self, device_type="cuda"):
-        self.device_type = device_type
         self.model = None
         self.model_language = {"kannada": "kn"}

     def load(self):
         if not self.model:
-            logger.info("Loading ASR model...")
             self.model = AutoModel.from_pretrained(
                 "ai4bharat/indic-conformer-600m-multilingual",
                 trust_remote_code=True
-            )
-            self.model = self.model.to(self.device_type)
             logger.info("ASR model loaded")

 # Global Managers
 llm_manager = LLMManager(settings.llm_model_name)
 model_manager = ModelManager()
@@ -435,7 +355,31 @@ asr_manager = ASRModelManager()
 tts_manager = TTSManager()
 ip = IndicProcessor(inference=True)

 # Pydantic Models
 class ChatRequest(BaseModel):
     prompt: str
     src_lang: str = "kan_Knda"
@@ -453,7 +397,6 @@ class ChatRequest(BaseModel):
         raise ValueError(f"Unsupported language code: {v}. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
     return v

-
 class ChatResponse(BaseModel):
     response: str
@@ -468,71 +411,149 @@ class TranscriptionResponse(BaseModel):
 class TranslationResponse(BaseModel):
     translations: List[str]

 # Dependency
 def get_translate_manager(src_lang: str, tgt_lang: str) -> TranslateManager:
     return model_manager.get_model(src_lang, tgt_lang)

 # Lifespan Event Handler
 translation_configs = []

 @asynccontextmanager
 async def lifespan(app: FastAPI):
     def load_all_models():
-        try:
-            # Load LLM model
-            logger.info("Loading LLM model...")
-            llm_manager.load()
-            logger.info("LLM model loaded successfully")
-
-            # Load TTS model
-            logger.info("Loading TTS model...")
-            tts_manager.load()
-            logger.info("TTS model loaded successfully")
-
-            # Load ASR model
-            logger.info("Loading ASR model...")
-            asr_manager.load()
-            logger.info("ASR model loaded successfully")
-
-            # Load translation models
-            translation_tasks = [
-                ('eng_Latn', 'kan_Knda', 'eng_indic'),
-                ('kan_Knda', 'eng_Latn', 'indic_eng'),
-                ('kan_Knda', 'hin_Deva', 'indic_indic'),
-            ]
-
-            for config in translation_configs:
-                src_lang = config["src_lang"]
-                tgt_lang = config["tgt_lang"]
-                key = model_manager._get_model_key(src_lang, tgt_lang)
-                translation_tasks.append((src_lang, tgt_lang, key))
-
-            for src_lang, tgt_lang, key in translation_tasks:
-                logger.info(f"Loading translation model for {src_lang} -> {tgt_lang}...")
-                model_manager.load_model(src_lang, tgt_lang, key)
-                logger.info(f"Translation model for {key} loaded successfully")
-
-            logger.info("All models loaded successfully")
-        except Exception as e:
-            logger.error(f"Error loading models: {str(e)}")
-            raise

-    logger.info("Starting sequential model loading...")
     load_all_models()
     yield
     llm_manager.unload()
-    logger.info("Server shutdown complete")

 # FastAPI App
 app = FastAPI(
-    title="Dhwani API",
-    description="AI Chat API supporting Indian languages",
     version="1.0.0",
     redirect_slashes=False,
     lifespan=lifespan
 )

-# Add CORS Middleware
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
@@ -541,13 +562,11 @@ app.add_middleware(
     allow_headers=["*"],
 )

-# Add Timing Middleware
 @app.middleware("http")
 async def add_request_timing(request: Request, call_next):
     start_time = time()
     response = await call_next(request)
-    end_time = time()
-    duration = end_time - start_time
     logger.info(f"Request to {request.url.path} took {duration:.3f} seconds")
     response.headers["X-Response-Time"] = f"{duration:.3f}"
     return response
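With `end_time` and `duration` deleted, the two kept lines still reference `duration`; unless the incoming side (not shown for this file) reintroduces it, that is a `NameError`. Presumably the pair collapses into one line, e.g.:

```python
from time import time
from fastapi import FastAPI, Request

app = FastAPI()

@app.middleware("http")
async def add_request_timing(request: Request, call_next):
    start_time = time()
    response = await call_next(request)
    duration = time() - start_time  # sketch: one-line replacement for the deleted pair
    response.headers["X-Response-Time"] = f"{duration:.3f}"
    return response
```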
@@ -555,7 +574,7 @@ async def add_request_timing(request: Request, call_next):
555
  limiter = Limiter(key_func=get_remote_address)
556
  app.state.limiter = limiter
557
 
558
- # API Endpoints
559
  @app.post("/audio/speech", response_class=StreamingResponse)
560
  async def synthesize_kannada(request: KannadaSynthesizeRequest):
561
  if not tts_manager.model:
@@ -563,14 +582,7 @@ async def synthesize_kannada(request: KannadaSynthesizeRequest):
563
  kannada_example = next(ex for ex in EXAMPLES if ex["audio_name"] == "KAN_F (Happy)")
564
  if not request.text.strip():
565
  raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")
566
-
567
- audio_buffer = synthesize_speech(
568
- tts_manager,
569
- text=request.text,
570
- ref_audio_name="KAN_F (Happy)",
571
- ref_text=kannada_example["ref_text"]
572
- )
573
-
574
  return StreamingResponse(
575
  audio_buffer,
576
  media_type="audio/wav",
@@ -579,61 +591,69 @@ async def synthesize_kannada(request: KannadaSynthesizeRequest):

  @app.post("/translate", response_model=TranslationResponse)
  async def translate(request: TranslationRequest, translate_manager: TranslateManager = Depends(get_translate_manager)):
- input_sentences = request.sentences
- src_lang = request.src_lang
- tgt_lang = request.tgt_lang
-
- if not input_sentences:
+ if not request.sentences:
  raise HTTPException(status_code=400, detail="Input sentences are required")
-
- batch = ip.preprocess_batch(input_sentences, src_lang=src_lang, tgt_lang=tgt_lang)
- inputs = translate_manager.tokenizer(
- batch,
- truncation=True,
- padding="longest",
- return_tensors="pt",
- return_attention_mask=True,
- ).to(translate_manager.device_type)
-
- with torch.no_grad():
- generated_tokens = translate_manager.model.generate(
- **inputs,
- use_cache=True,
- min_length=0,
- max_length=256,
- num_beams=5,
- num_return_sequences=1,
- )
-
+ batch = ip.preprocess_batch(request.sentences, src_lang=request.src_lang, tgt_lang=request.tgt_lang)
+ inputs = translate_manager.tokenizer(batch, truncation=True, padding="longest", return_tensors="pt", return_attention_mask=True).to(translate_manager.device_type)
+ with torch.no_grad(), autocast():
+ generated_tokens = translate_manager.model.generate(**inputs, use_cache=True, min_length=0, max_length=256, num_beams=5, num_return_sequences=1)
  with translate_manager.tokenizer.as_target_tokenizer():
- generated_tokens = translate_manager.tokenizer.batch_decode(
- generated_tokens.detach().cpu().tolist(),
- skip_special_tokens=True,
- clean_up_tokenization_spaces=True,
- )
-
- translations = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
+ generated_tokens = translate_manager.tokenizer.batch_decode(generated_tokens.detach().cpu().tolist(), skip_special_tokens=True, clean_up_tokenization_spaces=True)
+ translations = ip.postprocess_batch(generated_tokens, lang=request.tgt_lang)
  return TranslationResponse(translations=translations)

- async def perform_internal_translation(sentences: List[str], src_lang: str, tgt_lang: str) -> List[str]:
- try:
- translate_manager = model_manager.get_model(src_lang, tgt_lang)
- except ValueError as e:
- logger.info(f"Model not preloaded: {str(e)}, loading now...")
- key = model_manager._get_model_key(src_lang, tgt_lang)
- model_manager.load_model(src_lang, tgt_lang, key)
- translate_manager = model_manager.get_model(src_lang, tgt_lang)
-
- if not translate_manager.model:
- translate_manager.load()
-
- request = TranslationRequest(sentences=sentences, src_lang=src_lang, tgt_lang=tgt_lang)
- response = await translate(request, translate_manager)
- return response.translations
-
  @app.get("/v1/health")
  async def health_check():
- return {"status": "healthy", "model": settings.llm_model_name}
+ memory_usage = torch.cuda.memory_allocated() / (24 * 1024**3) if cuda_available else 0
+ if memory_usage > 0.9:
+ logger.warning("GPU memory usage exceeds 90%; consider unloading models")
+ llm_status = "unhealthy"
+ llm_latency = None
+ if llm_manager.is_loaded:
+ start = time()
+ try:
+ llm_test = await llm_manager.generate("What is the capital of Karnataka?", max_tokens=10)
+ llm_latency = time() - start
+ llm_status = "healthy" if llm_test else "unhealthy"
+ except Exception as e:
+ logger.error(f"LLM health check failed: {str(e)}")
+ tts_status = "unhealthy"
+ tts_latency = None
+ if tts_manager.model:
+ start = time()
+ try:
+ audio_buffer = await synthesize_speech(tts_manager, "Test", "KAN_F (Happy)", EXAMPLES[0]["ref_text"])
+ tts_latency = time() - start
+ tts_status = "healthy" if audio_buffer else "unhealthy"
+ except Exception as e:
+ logger.error(f"TTS health check failed: {str(e)}")
+ asr_status = "unhealthy"
+ asr_latency = None
+ if asr_manager.model:
+ start = time()
+ try:
+ dummy_audio = np.zeros(16000, dtype=np.float32)
+ wav = torch.tensor(dummy_audio).unsqueeze(0).to(device)
+ with autocast(), torch.no_grad():
+ asr_test = asr_manager.model(wav, asr_manager.model_language["kannada"], "rnnt")
+ asr_latency = time() - start
+ asr_status = "healthy" if asr_test else "unhealthy"
+ except Exception as e:
+ logger.error(f"ASR health check failed: {str(e)}")
+ status = {
+ "status": "healthy" if llm_status == "healthy" and tts_status == "healthy" and asr_status == "healthy" else "degraded",
+ "model": settings.llm_model_name,
+ "llm_status": llm_status,
+ "llm_latency": f"{llm_latency:.3f}s" if llm_latency else "N/A",
+ "tts_status": tts_status,
+ "tts_latency": f"{tts_latency:.3f}s" if tts_latency else "N/A",
+ "asr_status": asr_status,
+ "asr_latency": f"{asr_latency:.3f}s" if asr_latency else "N/A",
+ "translation_models": list(model_manager.models.keys()),
+ "gpu_memory_usage": f"{memory_usage:.2%}"
+ }
+ logger.info("Health check completed")
+ return status

  @app.get("/")
639
  async def home():
@@ -644,6 +664,10 @@ async def unload_all_models():
644
  try:
645
  logger.info("Starting to unload all models...")
646
  llm_manager.unload()
 
 
 
 
647
  logger.info("All models unloaded successfully")
648
  return {"status": "success", "message": "All models unloaded"}
649
  except Exception as e:
@@ -655,6 +679,15 @@ async def load_all_models():
655
  try:
656
  logger.info("Starting to load all models...")
657
  llm_manager.load()
 
 
 
 
 
 
 
 
 
658
  logger.info("All models loaded successfully")
659
  return {"status": "success", "message": "All models loaded"}
660
  except Exception as e:
@@ -665,11 +698,7 @@ async def load_all_models():
  async def translate_endpoint(request: TranslationRequest):
  logger.info(f"Received translation request: {request.dict()}")
  try:
- translations = await perform_internal_translation(
- sentences=request.sentences,
- src_lang=request.src_lang,
- tgt_lang=request.tgt_lang
- )
+ translations = await perform_internal_translation(request.sentences, request.src_lang, request.tgt_lang)
  logger.info(f"Translation successful: {translations}")
  return TranslationResponse(translations=translations)
  except Exception as e:
@@ -679,44 +708,32 @@ async def translate_endpoint(request: TranslationRequest):
  @app.post("/v1/chat", response_model=ChatResponse)
  @limiter.limit(settings.chat_rate_limit)
  async def chat(request: Request, chat_request: ChatRequest):
- if not chat_request.prompt:
- raise HTTPException(status_code=400, detail="Prompt cannot be empty")
- logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, tgt_lang: {chat_request.tgt_lang}")
-
- EUROPEAN_LANGUAGES = {"deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn", "por_Latn", "rus_Cyrl", "pol_Latn"}
-
- try:
- if chat_request.src_lang != "eng_Latn" and chat_request.src_lang not in EUROPEAN_LANGUAGES:
- translated_prompt = await perform_internal_translation(
- sentences=[chat_request.prompt],
- src_lang=chat_request.src_lang,
- tgt_lang="eng_Latn"
- )
- prompt_to_process = translated_prompt[0]
- logger.info(f"Translated prompt to English: {prompt_to_process}")
- else:
- prompt_to_process = chat_request.prompt
- logger.info("Prompt in English or European language, no translation needed")
-
- response = await llm_manager.generate(prompt_to_process, settings.max_tokens)
- logger.info(f"Generated response: {response}")
-
- if chat_request.tgt_lang != "eng_Latn" and chat_request.tgt_lang not in EUROPEAN_LANGUAGES:
- translated_response = await perform_internal_translation(
- sentences=[response],
- src_lang="eng_Latn",
- tgt_lang=chat_request.tgt_lang
- )
- final_response = translated_response[0]
- logger.info(f"Translated response to {chat_request.tgt_lang}: {final_response}")
- else:
- final_response = response
- logger.info(f"Response in {chat_request.tgt_lang}, no translation needed")
-
- return ChatResponse(response=final_response)
- except Exception as e:
- logger.error(f"Error processing request: {str(e)}")
- raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")

  @app.post("/v1/visual_query/")
  async def visual_query(
@@ -725,42 +742,31 @@
  src_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
  tgt_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
  ):
- try:
- image = Image.open(file.file)
- if image.size == (0, 0):
- raise HTTPException(status_code=400, detail="Uploaded image is empty or invalid")
-
- if src_lang != "eng_Latn":
- translated_query = await perform_internal_translation(
- sentences=[query],
- src_lang=src_lang,
- tgt_lang="eng_Latn"
- )
- query_to_process = translated_query[0]
- logger.info(f"Translated query to English: {query_to_process}")
- else:
- query_to_process = query
- logger.info("Query already in English, no translation needed")
-
- answer = await llm_manager.vision_query(image, query_to_process)
- logger.info(f"Generated English answer: {answer}")
-
- if tgt_lang != "eng_Latn":
- translated_answer = await perform_internal_translation(
- sentences=[answer],
- src_lang="eng_Latn",
- tgt_lang=tgt_lang
- )
- final_answer = translated_answer[0]
- logger.info(f"Translated answer to {tgt_lang}: {final_answer}")
- else:
- final_answer = answer
- logger.info("Answer kept in English, no translation needed")
-
- return {"answer": final_answer}
- except Exception as e:
- logger.error(f"Error processing request: {str(e)}")
- raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")

  @app.post("/v1/chat_v2", response_model=ChatResponse)
  @limiter.limit(settings.chat_rate_limit)
@@ -771,95 +777,70 @@ async def chat_v2(
  src_lang: str = Form("kan_Knda"),
  tgt_lang: str = Form("kan_Knda"),
  ):
- if not prompt:
- raise HTTPException(status_code=400, detail="Prompt cannot be empty")
- if src_lang not in SUPPORTED_LANGUAGES or tgt_lang not in SUPPORTED_LANGUAGES:
- raise HTTPException(status_code=400, detail=f"Unsupported language code. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
-
- logger.info(f"Received prompt: {prompt}, src_lang: {src_lang}, tgt_lang: {tgt_lang}, Image provided: {image is not None}")
-
- try:
- if image:
- image_data = await image.read()
- if not image_data:
- raise HTTPException(status_code=400, detail="Uploaded image is empty")
- img = Image.open(io.BytesIO(image_data))
-
- if src_lang != "eng_Latn":
- translated_prompt = await perform_internal_translation(
- sentences=[prompt],
- src_lang=src_lang,
- tgt_lang="eng_Latn"
- )
- prompt_to_process = translated_prompt[0]
- logger.info(f"Translated prompt to English: {prompt_to_process}")
- else:
- prompt_to_process = prompt
- logger.info("Prompt already in English, no translation needed")
-
- decoded = await llm_manager.chat_v2(img, prompt_to_process)
- logger.info(f"Generated English response: {decoded}")
-
- if tgt_lang != "eng_Latn":
- translated_response = await perform_internal_translation(
- sentences=[decoded],
- src_lang="eng_Latn",
- tgt_lang=tgt_lang
- )
- final_response = translated_response[0]
- logger.info(f"Translated response to {tgt_lang}: {final_response}")
- else:
- final_response = decoded
- logger.info("Response kept in English, no translation needed")
- else:
- if src_lang != "eng_Latn":
- translated_prompt = await perform_internal_translation(
- sentences=[prompt],
- src_lang=src_lang,
- tgt_lang="eng_Latn"
- )
- prompt_to_process = translated_prompt[0]
- logger.info(f"Translated prompt to English: {prompt_to_process}")
- else:
- prompt_to_process = prompt
- logger.info("Prompt already in English, no translation needed")
-
- decoded = await llm_manager.generate(prompt_to_process, settings.max_tokens)
- logger.info(f"Generated English response: {decoded}")
-
- if tgt_lang != "eng_Latn":
- translated_response = await perform_internal_translation(
- sentences=[decoded],
- src_lang="eng_Latn",
- tgt_lang=tgt_lang
- )
- final_response = translated_response[0]
- logger.info(f"Translated response to {tgt_lang}: {final_response}")
  else:
- final_response = decoded
- logger.info("Response kept in English, no translation needed")
-
- return ChatResponse(response=final_response)
- except Exception as e:
- logger.error(f"Error processing request: {str(e)}")
- raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")

  @app.post("/transcribe/", response_model=TranscriptionResponse)
  async def transcribe_audio(file: UploadFile = File(...), language: str = Query(..., enum=list(asr_manager.model_language.keys()))):
- if not asr_manager.model:
- raise HTTPException(status_code=503, detail="ASR model not loaded")
- try:
- wav, sr = torchaudio.load(file.file)
- wav = torch.mean(wav, dim=0, keepdim=True)
- target_sample_rate = 16000
- if sr != target_sample_rate:
- resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
- wav = resampler(wav)
- transcription_rnnt = asr_manager.model(wav, asr_manager.model_language[language], "rnnt")
- return TranscriptionResponse(text=transcription_rnnt)
- except Exception as e:
- logger.error(f"Error in transcription: {str(e)}")
- raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")

  @app.post("/v1/speech_to_speech")
  async def speech_to_speech(
@@ -867,28 +848,20 @@ async def speech_to_speech(
  file: UploadFile = File(...),
  language: str = Query(..., enum=list(asr_manager.model_language.keys())),
  ) -> StreamingResponse:
- if not tts_manager.model:
- raise HTTPException(status_code=503, detail="TTS model not loaded")
- transcription = await transcribe_audio(file, language)
- logger.info(f"Transcribed text: {transcription.text}")

- chat_request = ChatRequest(
- prompt=transcription.text,
- src_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda"),
- tgt_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda")
- )
- processed_text = await chat(request, chat_request)
- logger.info(f"Processed text: {processed_text.response}")
-
- voice_request = KannadaSynthesizeRequest(text=processed_text.response)
- audio_response = await synthesize_kannada(voice_request)
- return audio_response
-
- LANGUAGE_TO_SCRIPT = {
- "kannada": "kan_Knda"
- }
-
- # Main Execution
  if __name__ == "__main__":
  parser = argparse.ArgumentParser(description="Run the FastAPI server.")
  parser.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")
@@ -915,15 +888,12 @@ if __name__ == "__main__":
  settings.speech_rate_limit = global_settings["speech_rate_limit"]

  llm_manager = LLMManager(settings.llm_model_name)
-
  if selected_config["components"]["ASR"]:
- asr_model_name = selected_config["components"]["ASR"]["model"]
  asr_manager.model_language[selected_config["language"]] = selected_config["components"]["ASR"]["language_code"]
-
  if selected_config["components"]["Translation"]:
  translation_configs.extend(selected_config["components"]["Translation"])

  host = args.host if args.host != settings.host else settings.host
  port = args.port if args.port != settings.port else settings.port

- uvicorn.run(app, host=host, port=port)
+ uvicorn.run(app, host=host, port=port, workers=2)

src/server/main.py CHANGED
  import io
  import os
  from time import time
+ from typing import List, Dict
  import tempfile
  import uvicorn
  from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, UploadFile, Body, Form

  import soundfile as sf
  import numpy as np
  import requests
+ import logging
  from starlette.responses import StreamingResponse
+ from logging_config import logger  # Assumed external logging config
+ from tts_config import SPEED, ResponseFormat, config as tts_config  # Assumed external TTS config
  import torchaudio
+ from tenacity import retry, stop_after_attempt, wait_exponential
+ from torch.cuda.amp import autocast

  # Device setup
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ torch_dtype = torch.float16 if device != "cpu" else torch.float32
+ logger.info(f"Using device: {device} with dtype: {torch_dtype}")

  # Check CUDA availability and version
  cuda_available = torch.cuda.is_available()
  cuda_version = torch.version.cuda if cuda_available else None
+ if cuda_available:
  device_idx = torch.cuda.current_device()
  capability = torch.cuda.get_device_capability(device_idx)
  compute_capability_float = float(f"{capability[0]}.{capability[1]}")

  bnb_4bit_compute_dtype=torch.bfloat16
  )

+ # Request queue for concurrency control
+ request_queue = asyncio.Queue(maxsize=10)
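+ # asyncio.Queue is not an async context manager, so the "async with" blocks in the
+ # endpoints below cannot use it directly; this semaphore (an editorial fix, matching
+ # the queue's maxsize) caps concurrent request handling, while the queue itself only
+ # feeds batch_worker.
+ request_semaphore = asyncio.Semaphore(10)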
+
+ # Logging optimization
+ logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
+
+ # LLM Manager with batching
  class LLMManager:
+ def __init__(self, model_name: str, device: str = device):
  self.model_name = model_name
  self.device = torch.device(device)
+ self.torch_dtype = torch.float16 if self.device.type != "cpu" else torch.float32
  self.model = None
  self.processor = None
  self.is_loaded = False
+ self.token_cache = {}
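+ # Unbounded in-memory cache keyed by prompt/max_tokens/temperature; entries live until unload() clears them.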
+ self.load()
  logger.info(f"LLMManager initialized with model {model_name} on {self.device}")

  def load(self):
  if not self.is_loaded:
  try:
+ if self.device.type == "cuda":
+ torch.set_float32_matmul_precision('high')
+ logger.info("Enabled TF32 matrix multiplication for improved GPU performance")
  self.model = Gemma3ForConditionalGeneration.from_pretrained(
  self.model_name,
  device_map="auto",
  quantization_config=quantization_config,
  torch_dtype=self.torch_dtype
+ ).eval()
  self.processor = AutoProcessor.from_pretrained(self.model_name)
+ dummy_input = self.processor("test", return_tensors="pt").to(self.device)
+ with torch.no_grad():
+ self.model.generate(**dummy_input, max_new_tokens=10)
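+ # Warm-up: one tiny generation initializes CUDA kernels and allocator state so the first real request avoids one-time startup latency.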
  self.is_loaded = True
+ logger.info(f"LLM {self.model_name} loaded and warmed up on {self.device}")
  except Exception as e:
  logger.error(f"Failed to load LLM: {str(e)}")
+ self.is_loaded = False

  def unload(self):
  if self.is_loaded:
  del self.model
  del self.processor
  if self.device.type == "cuda":
  torch.cuda.empty_cache()
+ logger.info(f"GPU memory cleared: {torch.cuda.memory_allocated()} bytes allocated")
  self.is_loaded = False
+ self.token_cache.clear()
+ logger.info(f"LLM {self.model_name} unloaded")

+ async def generate(self, prompt: str, max_tokens: int = settings.max_tokens, temperature: float = 0.7) -> str:
  if not self.is_loaded:
+ logger.warning("LLM not loaded; attempting reload")
  self.load()
+ if not self.is_loaded:
+ raise HTTPException(status_code=503, detail="LLM model unavailable")

+ cache_key = f"{prompt}:{max_tokens}:{temperature}"
+ if cache_key in self.token_cache:
+ logger.info("Using cached response")
+ return self.token_cache[cache_key]

+ future = asyncio.Future()
+ await request_queue.put({"prompt": prompt, "max_tokens": max_tokens, "temperature": temperature, "future": future})
+ response = await future
+ self.token_cache[cache_key] = response
+ logger.info(f"Generated response: {response}")
+ return response
+
+ async def batch_generate(self, prompts: List[Dict]) -> List[str]:
+ messages_batch = [
+ [
+ {"role": "system", "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state. Provide a concise response in one sentence maximum."}]},
+ {"role": "user", "content": [{"type": "text", "text": prompt["prompt"]}]}
+ ]
+ for prompt in prompts
+ ]
  try:
  inputs_vlm = self.processor.apply_chat_template(
+ messages_batch,
  add_generation_prompt=True,
  tokenize=True,
  return_dict=True,
+ return_tensors="pt",
+ padding=True
  ).to(self.device, dtype=torch.bfloat16)
+ with autocast(), torch.no_grad():
+ outputs = self.model.generate(
+ **inputs_vlm,
+ max_new_tokens=max(prompt["max_tokens"] for prompt in prompts),
+ do_sample=True,
+ top_p=0.9,
+ temperature=max(prompt["temperature"] for prompt in prompts)
+ )
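+ # padding=True gives every sequence in the batch the same length, so one shared input offset applies to every generated row.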
+ input_len = inputs_vlm["input_ids"].shape[1]
+ responses = [
+ self.processor.decode(output[input_len:], skip_special_tokens=True)
+ for output in outputs
+ ]
+ logger.info(f"Batch generated {len(responses)} responses")
+ return responses
  except Exception as e:
+ logger.error(f"Error in batch generation: {str(e)}")
+ raise HTTPException(status_code=500, detail=f"Batch generation failed: {str(e)}")

  async def vision_query(self, image: Image.Image, query: str) -> str:
  if not self.is_loaded:
  self.load()
  messages_vlm = [
+ {"role": "system", "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Summarize your answer in maximum 1 sentence."}]},
+ {"role": "user", "content": [{"type": "text", "text": query}] + ([{"type": "image", "image": image}] if image and image.size[0] > 0 and image.size[1] > 0 else [])}
  ]
  try:
  inputs_vlm = self.processor.apply_chat_template(
  messages_vlm,
  add_generation_prompt=True,
  tokenize=True,
  return_dict=True,
  return_tensors="pt"
  ).to(self.device, dtype=torch.bfloat16)
  except Exception as e:
  logger.error(f"Error in apply_chat_template: {str(e)}")
  raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
  input_len = inputs_vlm["input_ids"].shape[-1]
  with torch.inference_mode():
+ generation = self.model.generate(**inputs_vlm, max_new_tokens=512, do_sample=True, temperature=0.7)
  generation = generation[0][input_len:]
  decoded = self.processor.decode(generation, skip_special_tokens=True)
  logger.info(f"Vision query response: {decoded}")
  return decoded

  async def chat_v2(self, image: Image.Image, query: str) -> str:
  if not self.is_loaded:
  self.load()
  messages_vlm = [
+ {"role": "system", "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state."}]},
+ {"role": "user", "content": [{"type": "text", "text": query}] + ([{"type": "image", "image": image}] if image and image.size[0] > 0 and image.size[1] > 0 else [])}
  ]
  try:
  inputs_vlm = self.processor.apply_chat_template(
  messages_vlm,
  add_generation_prompt=True,
  tokenize=True,
  return_dict=True,
  return_tensors="pt"
  ).to(self.device, dtype=torch.bfloat16)
  except Exception as e:
  logger.error(f"Error in apply_chat_template: {str(e)}")
  raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
  input_len = inputs_vlm["input_ids"].shape[-1]
  with torch.inference_mode():
+ generation = self.model.generate(**inputs_vlm, max_new_tokens=512, do_sample=True, temperature=0.7)
  generation = generation[0][input_len:]
  decoded = self.processor.decode(generation, skip_special_tokens=True)
  logger.info(f"Chat_v2 response: {decoded}")
  return decoded

  # TTS Manager
  class TTSManager:
  def __init__(self, device_type=device):
+ self.device_type = torch.device(device_type)
  self.model = None
  self.repo_id = "ai4bharat/IndicF5"
+ self.load()

  def load(self):
  if not self.model:
+ logger.info(f"Loading TTS model {self.repo_id} on {self.device_type}...")
+ self.model = AutoModel.from_pretrained(self.repo_id, trust_remote_code=True).to(self.device_type)
+ logger.info("TTS model loaded")
+
+ def unload(self):
+ if self.model:
+ del self.model
+ if self.device_type.type == "cuda":
+ torch.cuda.empty_cache()
+ logger.info(f"TTS GPU memory cleared: {torch.cuda.memory_allocated()} bytes allocated")
+ self.model = None
+ logger.info("TTS model unloaded")

  def synthesize(self, text, ref_audio_path, ref_text):
  if not self.model:
  raise ValueError("TTS model not loaded")
+ with autocast():
+ return self.model(text, ref_audio_path=ref_audio_path, ref_text=ref_text)

  # Translation Manager
  class TranslateManager:
  def __init__(self, src_lang, tgt_lang, device_type=device, use_distilled=True):
+ self.device_type = torch.device(device_type)
  self.tokenizer = None
  self.model = None
  self.src_lang = src_lang
  self.tgt_lang = tgt_lang
  self.use_distilled = use_distilled
+ self.load()

  def load(self):
  if not self.tokenizer or not self.model:

  model_name = "ai4bharat/indictrans2-indic-indic-dist-320M" if self.use_distilled else "ai4bharat/indictrans2-indic-indic-1B"
  else:
  raise ValueError("Invalid language combination")
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
  self.model = AutoModelForSeq2SeqLM.from_pretrained(
  model_name,
  trust_remote_code=True,
  torch_dtype=torch.float16,
  attn_implementation="flash_attention_2"
+ ).to(self.device_type)
  self.model = torch.compile(self.model, mode="reduce-overhead")
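  # torch.compile(mode="reduce-overhead") compiles lazily, so the first translation after startup pays the compilation cost.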
  logger.info(f"Translation model {model_name} loaded")

+ # Model Manager
  class ModelManager:
  def __init__(self, device_type=device, use_distilled=True, is_lazy_loading=False):
  self.models = {}

  def load_model(self, src_lang, tgt_lang, key):
  logger.info(f"Loading translation model for {src_lang} -> {tgt_lang}")
  translate_manager = TranslateManager(src_lang, tgt_lang, self.device_type, self.use_distilled)
  self.models[key] = translate_manager
  logger.info(f"Loaded translation model for {key}")

  def get_model(self, src_lang, tgt_lang):
  key = self._get_model_key(src_lang, tgt_lang)
+ if key not in self.models:
+ self.load_model(src_lang, tgt_lang, key)
+ return self.models[key]

  def _get_model_key(self, src_lang, tgt_lang):
  if src_lang.startswith("eng") and not tgt_lang.startswith("eng"):

  # ASR Manager
  class ASRModelManager:
+ def __init__(self, device_type=device):
+ self.device_type = torch.device(device_type)
  self.model = None
  self.model_language = {"kannada": "kn"}
+ self.load()

  def load(self):
  if not self.model:
+ logger.info(f"Loading ASR model on {self.device_type}...")
  self.model = AutoModel.from_pretrained(
  "ai4bharat/indic-conformer-600m-multilingual",
  trust_remote_code=True
+ ).to(self.device_type)
  logger.info("ASR model loaded")

+ def unload(self):
+ if self.model:
+ del self.model
+ if self.device_type.type == "cuda":
+ torch.cuda.empty_cache()
+ logger.info(f"ASR GPU memory cleared: {torch.cuda.memory_allocated()} bytes allocated")
+ self.model = None
+ logger.info("ASR model unloaded")
+
  # Global Managers
  llm_manager = LLMManager(settings.llm_model_name)
  model_manager = ModelManager()
  asr_manager = ASRModelManager()
  tts_manager = TTSManager()
  ip = IndicProcessor(inference=True)

+ # TTS Constants
+ EXAMPLES = [
+ {
+ "audio_name": "KAN_F (Happy)",
+ "audio_url": "https://github.com/AI4Bharat/IndicF5/raw/refs/heads/main/prompts/KAN_F_HAPPY_00001.wav",
+ "ref_text": "ನಮ್‌ ಫ್ರಿಜ್ಜಲ್ಲಿ ಕೂಲಿಂಗ್‌ ಸಮಸ್ಯೆ ಆಗಿ ನಾನ್‌ ಭಾಳ ದಿನದಿಂದ ಒದ್ದಾಡ್ತಿದ್ದೆ, ಆದ್ರೆ ಅದ್ನೀಗ ಮೆಕಾನಿಕ್ ಆಗಿರೋ ನಿಮ್‌ ಸಹಾಯ್ದಿಂದ ಬಗೆಹರಿಸ್ಕೋಬೋದು ಅಂತಾಗಿ ನಿರಾಳ ಆಯ್ತು ನಂಗೆ।",
+ "synth_text": "ಚೆನ್ನೈನ ಶೇರ್ ಆಟೋ ಪ್ರಯಾಣಿಕರ ನಡುವೆ ಆಹಾರವನ್ನು ಹಂಚಿಕೊಂಡು ತಿನ್ನುವುದು ನನಗೆ ಮನಸ್ಸಿಗೆ ತುಂಬಾ ಒಳ್ಳೆಯದೆನಿಸುವ ವಿಷಯ."
+ },
+ ]
+
  # Pydantic Models
+ class SynthesizeRequest(BaseModel):
+ text: str
+ ref_audio_name: str
+ ref_text: str = None
+
+ class KannadaSynthesizeRequest(BaseModel):
+ text: str
+
+ @field_validator("text")
+ def text_must_be_valid(cls, v):
+ if len(v) > 500:
+ raise ValueError("Text cannot exceed 500 characters")
+ return v.strip()
+
  class ChatRequest(BaseModel):
  prompt: str
  src_lang: str = "kan_Knda"

  raise ValueError(f"Unsupported language code: {v}. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
  return v

  class ChatResponse(BaseModel):
  response: str

  class TranslationResponse(BaseModel):
  translations: List[str]

+ # TTS Functions
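+ # The reference clip is downloaded on every synthesis request; tenacity retries the fetch up to 3 times with exponential backoff before failing.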
+ @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=10))
+ def load_audio_from_url(url: str):
+ response = requests.get(url)
+ if response.status_code == 200:
+ audio_data, sample_rate = sf.read(io.BytesIO(response.content))
+ return sample_rate, audio_data
+ raise HTTPException(status_code=500, detail="Failed to load reference audio from URL after retries")
+
+ async def synthesize_speech(tts_manager: TTSManager, text: str, ref_audio_name: str, ref_text: str) -> io.BytesIO:
+ async with request_semaphore:
+ ref_audio_url = next((ex["audio_url"] for ex in EXAMPLES if ex["audio_name"] == ref_audio_name), None)
+ if not ref_audio_url:
+ raise HTTPException(status_code=400, detail="Invalid reference audio name.")
+ if not text.strip() or not ref_text.strip():
+ raise HTTPException(status_code=400, detail="Text or reference text cannot be empty.")
+
+ logger.info(f"Synthesizing speech for text: {text[:50]}... with ref_audio: {ref_audio_name}")
+ loop = asyncio.get_running_loop()
+ sample_rate, audio_data = await loop.run_in_executor(None, load_audio_from_url, ref_audio_url)
+
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_audio:
+ await loop.run_in_executor(None, sf.write, temp_audio.name, audio_data, sample_rate)
+ temp_audio.flush()
+ audio = tts_manager.synthesize(text, temp_audio.name, ref_text)
+
+ buffer = io.BytesIO()
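+ # IndicF5 may return int16 PCM; scale it to [-1, 1] float32 before writing the 24 kHz WAV into the in-memory buffer.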
+ await loop.run_in_executor(None, lambda: sf.write(buffer, audio.astype(np.float32) / 32768.0 if audio.dtype == np.int16 else audio, 24000, format="WAV"))
+ buffer.seek(0)
+ logger.info("Speech synthesis completed")
+ return buffer
+
+ # Supported Languages
+ SUPPORTED_LANGUAGES = {
+ "asm_Beng", "kas_Arab", "pan_Guru", "ben_Beng", "kas_Deva", "san_Deva",
+ "brx_Deva", "mai_Deva", "sat_Olck", "doi_Deva", "mal_Mlym", "snd_Arab",
+ "eng_Latn", "mar_Deva", "snd_Deva", "gom_Deva", "mni_Beng", "tam_Taml",
+ "guj_Gujr", "mni_Mtei", "tel_Telu", "hin_Deva", "npi_Deva", "urd_Arab",
+ "kan_Knda", "ory_Orya",
+ "deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn",
+ "por_Latn", "rus_Cyrl", "pol_Latn"
+ }
+
  # Dependency
  def get_translate_manager(src_lang: str, tgt_lang: str) -> TranslateManager:
  return model_manager.get_model(src_lang, tgt_lang)

+ # Translation Function
+ async def perform_internal_translation(sentences: List[str], src_lang: str, tgt_lang: str) -> List[str]:
+ try:
+ translate_manager = model_manager.get_model(src_lang, tgt_lang)
+ except ValueError as e:
+ logger.info(f"Model not preloaded: {str(e)}, loading now...")
+ key = model_manager._get_model_key(src_lang, tgt_lang)
+ model_manager.load_model(src_lang, tgt_lang, key)
+ translate_manager = model_manager.get_model(src_lang, tgt_lang)
+
+ batch = ip.preprocess_batch(sentences, src_lang=src_lang, tgt_lang=tgt_lang)
+ inputs = translate_manager.tokenizer(batch, truncation=True, padding="longest", return_tensors="pt", return_attention_mask=True).to(translate_manager.device_type)
+ with torch.no_grad(), autocast():
+ generated_tokens = translate_manager.model.generate(**inputs, use_cache=True, min_length=0, max_length=256, num_beams=5, num_return_sequences=1)
+ with translate_manager.tokenizer.as_target_tokenizer():
+ generated_tokens = translate_manager.tokenizer.batch_decode(generated_tokens.detach().cpu().tolist(), skip_special_tokens=True, clean_up_tokenization_spaces=True)
+ return ip.postprocess_batch(generated_tokens, lang=tgt_lang)
+
  # Lifespan Event Handler
  translation_configs = []

  @asynccontextmanager
  async def lifespan(app: FastAPI):
  def load_all_models():
+ logger.info("Loading LLM model...")
+ llm_manager.load()
+ logger.info("Loading TTS model...")
+ tts_manager.load()
+ logger.info("Loading ASR model...")
+ asr_manager.load()
+ translation_tasks = [
+ ('eng_Latn', 'kan_Knda', 'eng_indic'),
+ ('kan_Knda', 'eng_Latn', 'indic_eng'),
+ ('kan_Knda', 'hin_Deva', 'indic_indic'),
+ ]
+ for config in translation_configs:
+ src_lang = config["src_lang"]
+ tgt_lang = config["tgt_lang"]
+ key = model_manager._get_model_key(src_lang, tgt_lang)
+ translation_tasks.append((src_lang, tgt_lang, key))
+ for src_lang, tgt_lang, key in translation_tasks:
+ logger.info(f"Loading translation model for {src_lang} -> {tgt_lang}...")
+ model_manager.load_model(src_lang, tgt_lang, key)
+ logger.info("All models loaded successfully")

+ logger.info("Starting server with preloaded models...")
  load_all_models()
+ batch_task = asyncio.create_task(batch_worker())
  yield
+ batch_task.cancel()
  llm_manager.unload()
+ tts_manager.unload()
+ asr_manager.unload()
+ for model in model_manager.models.values():
+ model.unload()
+ logger.info("Server shutdown complete; all models unloaded")
+
+ # Batch Worker
+ async def batch_worker():
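+ # Greedy batching: pull queued requests until the batch holds 4, or until 1 s passes with nothing new; each batch becomes a single batch_generate call.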
+ while True:
+ batch = []
+ last_request_time = time()
+ try:
+ while len(batch) < 4:
+ try:
+ request = await asyncio.wait_for(request_queue.get(), timeout=1.0)
+ batch.append(request)
+ current_time = time()
+ if current_time - last_request_time > 1.0 and batch:
+ break
+ last_request_time = current_time
+ except asyncio.TimeoutError:
+ if batch:
+ break
+ continue
+ if batch:
+ start_time = time()
+ responses = await llm_manager.batch_generate(batch)
+ duration = time() - start_time
+ logger.info(f"Batch of {len(batch)} requests processed in {duration:.3f} seconds")
+ for request, response in zip(batch, responses):
+ request["future"].set_result(response)
+ except Exception as e:
+ logger.error(f"Batch worker error: {str(e)}")
+ for request in batch:
+ request["future"].set_exception(e)

  # FastAPI App
  app = FastAPI(
+ title="Optimized Dhwani API",
+ description="AI Chat API supporting Indian languages with performance enhancements",
  version="1.0.0",
  redirect_slashes=False,
  lifespan=lifespan
  )

  app.add_middleware(
  CORSMiddleware,
  allow_origins=["*"],

  allow_headers=["*"],
  )

  @app.middleware("http")
  async def add_request_timing(request: Request, call_next):
  start_time = time()
  response = await call_next(request)
+ duration = time() - start_time
  logger.info(f"Request to {request.url.path} took {duration:.3f} seconds")
  response.headers["X-Response-Time"] = f"{duration:.3f}"
  return response

  limiter = Limiter(key_func=get_remote_address)
  app.state.limiter = limiter

+ # Endpoints
  @app.post("/audio/speech", response_class=StreamingResponse)
  async def synthesize_kannada(request: KannadaSynthesizeRequest):
  if not tts_manager.model:

  kannada_example = next(ex for ex in EXAMPLES if ex["audio_name"] == "KAN_F (Happy)")
  if not request.text.strip():
  raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")
+ audio_buffer = await synthesize_speech(tts_manager, request.text, "KAN_F (Happy)", kannada_example["ref_text"])
  return StreamingResponse(
  audio_buffer,
  media_type="audio/wav",

  @app.post("/translate", response_model=TranslationResponse)
  async def translate(request: TranslationRequest, translate_manager: TranslateManager = Depends(get_translate_manager)):
+ if not request.sentences:
  raise HTTPException(status_code=400, detail="Input sentences are required")
+ batch = ip.preprocess_batch(request.sentences, src_lang=request.src_lang, tgt_lang=request.tgt_lang)
+ inputs = translate_manager.tokenizer(batch, truncation=True, padding="longest", return_tensors="pt", return_attention_mask=True).to(translate_manager.device_type)
+ with torch.no_grad(), autocast():
+ generated_tokens = translate_manager.model.generate(**inputs, use_cache=True, min_length=0, max_length=256, num_beams=5, num_return_sequences=1)
  with translate_manager.tokenizer.as_target_tokenizer():
+ generated_tokens = translate_manager.tokenizer.batch_decode(generated_tokens.detach().cpu().tolist(), skip_special_tokens=True, clean_up_tokenization_spaces=True)
+ translations = ip.postprocess_batch(generated_tokens, lang=request.tgt_lang)
  return TranslationResponse(translations=translations)

  @app.get("/v1/health")
  async def health_check():
+ memory_usage = torch.cuda.memory_allocated() / (24 * 1024**3) if cuda_available else 0
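+ # NOTE: the denominator hardcodes a 24 GB GPU, so this usage fraction is only meaningful on such hardware.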
+ if memory_usage > 0.9:
+ logger.warning("GPU memory usage exceeds 90%; consider unloading models")
+ llm_status = "unhealthy"
+ llm_latency = None
+ if llm_manager.is_loaded:
+ start = time()
+ try:
+ llm_test = await llm_manager.generate("What is the capital of Karnataka?", max_tokens=10)
+ llm_latency = time() - start
+ llm_status = "healthy" if llm_test else "unhealthy"
+ except Exception as e:
+ logger.error(f"LLM health check failed: {str(e)}")
+ tts_status = "unhealthy"
+ tts_latency = None
+ if tts_manager.model:
+ start = time()
+ try:
+ audio_buffer = await synthesize_speech(tts_manager, "Test", "KAN_F (Happy)", EXAMPLES[0]["ref_text"])
+ tts_latency = time() - start
+ tts_status = "healthy" if audio_buffer else "unhealthy"
+ except Exception as e:
+ logger.error(f"TTS health check failed: {str(e)}")
+ asr_status = "unhealthy"
+ asr_latency = None
+ if asr_manager.model:
+ start = time()
+ try:
+ dummy_audio = np.zeros(16000, dtype=np.float32)
+ wav = torch.tensor(dummy_audio).unsqueeze(0).to(device)
+ with autocast(), torch.no_grad():
+ asr_test = asr_manager.model(wav, asr_manager.model_language["kannada"], "rnnt")
+ asr_latency = time() - start
+ asr_status = "healthy" if asr_test else "unhealthy"
+ except Exception as e:
+ logger.error(f"ASR health check failed: {str(e)}")
+ status = {
+ "status": "healthy" if llm_status == "healthy" and tts_status == "healthy" and asr_status == "healthy" else "degraded",
+ "model": settings.llm_model_name,
+ "llm_status": llm_status,
+ "llm_latency": f"{llm_latency:.3f}s" if llm_latency else "N/A",
+ "tts_status": tts_status,
+ "tts_latency": f"{tts_latency:.3f}s" if tts_latency else "N/A",
+ "asr_status": asr_status,
+ "asr_latency": f"{asr_latency:.3f}s" if asr_latency else "N/A",
+ "translation_models": list(model_manager.models.keys()),
+ "gpu_memory_usage": f"{memory_usage:.2%}"
+ }
+ logger.info("Health check completed")
+ return status

  @app.get("/")
  async def home():

  try:
  logger.info("Starting to unload all models...")
  llm_manager.unload()
+ tts_manager.unload()
+ asr_manager.unload()
+ for model in model_manager.models.values():
+ model.unload()
  logger.info("All models unloaded successfully")
  return {"status": "success", "message": "All models unloaded"}
  except Exception as e:

  try:
  logger.info("Starting to load all models...")
  llm_manager.load()
+ tts_manager.load()
+ asr_manager.load()
+ for src_lang, tgt_lang, key in [
+ ('eng_Latn', 'kan_Knda', 'eng_indic'),
+ ('kan_Knda', 'eng_Latn', 'indic_eng'),
+ ('kan_Knda', 'hin_Deva', 'indic_indic'),
+ ]:
+ if key not in model_manager.models:
+ model_manager.load_model(src_lang, tgt_lang, key)
  logger.info("All models loaded successfully")
  return {"status": "success", "message": "All models loaded"}
  except Exception as e:

  async def translate_endpoint(request: TranslationRequest):
  logger.info(f"Received translation request: {request.dict()}")
  try:
+ translations = await perform_internal_translation(request.sentences, request.src_lang, request.tgt_lang)
  logger.info(f"Translation successful: {translations}")
  return TranslationResponse(translations=translations)
  except Exception as e:

  @app.post("/v1/chat", response_model=ChatResponse)
  @limiter.limit(settings.chat_rate_limit)
  async def chat(request: Request, chat_request: ChatRequest):
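+ # Flow: non-English, non-European prompts are translated to English, answered by the LLM, and the answer is translated back to tgt_lang.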
+ async with request_semaphore:
+ if not chat_request.prompt:
+ raise HTTPException(status_code=400, detail="Prompt cannot be empty")
+ logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, tgt_lang: {chat_request.tgt_lang}")
+ EUROPEAN_LANGUAGES = {"deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn", "por_Latn", "rus_Cyrl", "pol_Latn"}
+ try:
+ if chat_request.src_lang != "eng_Latn" and chat_request.src_lang not in EUROPEAN_LANGUAGES:
+ translated_prompt = await perform_internal_translation([chat_request.prompt], chat_request.src_lang, "eng_Latn")
+ prompt_to_process = translated_prompt[0]
+ logger.info(f"Translated prompt to English: {prompt_to_process}")
+ else:
+ prompt_to_process = chat_request.prompt
+ logger.info("Prompt in English or European language, no translation needed")
+ response = await llm_manager.generate(prompt_to_process, settings.max_tokens)
+ logger.info(f"Generated English response: {response}")
+ if chat_request.tgt_lang != "eng_Latn" and chat_request.tgt_lang not in EUROPEAN_LANGUAGES:
+ translated_response = await perform_internal_translation([response], "eng_Latn", chat_request.tgt_lang)
+ final_response = translated_response[0]
+ logger.info(f"Translated response to {chat_request.tgt_lang}: {final_response}")
+ else:
+ final_response = response
+ logger.info(f"Response in {chat_request.tgt_lang}, no translation needed")
+ return ChatResponse(response=final_response)
+ except Exception as e:
+ logger.error(f"Error processing request: {str(e)}")
+ raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")

  @app.post("/v1/visual_query/")
  async def visual_query(

  src_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
  tgt_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
  ):
+ async with request_semaphore:
+ try:
+ image = Image.open(file.file)
+ if image.size == (0, 0):
+ raise HTTPException(status_code=400, detail="Uploaded image is empty or invalid")
+ if src_lang != "eng_Latn":
+ translated_query = await perform_internal_translation([query], src_lang, "eng_Latn")
+ query_to_process = translated_query[0]
+ logger.info(f"Translated query to English: {query_to_process}")
+ else:
+ query_to_process = query
+ logger.info("Query already in English, no translation needed")
+ answer = await llm_manager.vision_query(image, query_to_process)
+ logger.info(f"Generated English answer: {answer}")
+ if tgt_lang != "eng_Latn":
+ translated_answer = await perform_internal_translation([answer], "eng_Latn", tgt_lang)
+ final_answer = translated_answer[0]
+ logger.info(f"Translated answer to {tgt_lang}: {final_answer}")
+ else:
+ final_answer = answer
+ logger.info("Answer kept in English, no translation needed")
+ return {"answer": final_answer}
+ except Exception as e:
+ logger.error(f"Error processing request: {str(e)}")
+ raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")

  @app.post("/v1/chat_v2", response_model=ChatResponse)
  @limiter.limit(settings.chat_rate_limit)

  src_lang: str = Form("kan_Knda"),
  tgt_lang: str = Form("kan_Knda"),
  ):
+ async with request_semaphore:
+ if not prompt:
+ raise HTTPException(status_code=400, detail="Prompt cannot be empty")
+ if src_lang not in SUPPORTED_LANGUAGES or tgt_lang not in SUPPORTED_LANGUAGES:
+ raise HTTPException(status_code=400, detail=f"Unsupported language code. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
+ logger.info(f"Received prompt: {prompt}, src_lang: {src_lang}, tgt_lang: {tgt_lang}, Image provided: {image is not None}")
+ try:
+ if image:
+ image_data = await image.read()
+ if not image_data:
+ raise HTTPException(status_code=400, detail="Uploaded image is empty")
+ img = Image.open(io.BytesIO(image_data))
+ if src_lang != "eng_Latn":
+ translated_prompt = await perform_internal_translation([prompt], src_lang, "eng_Latn")
+ prompt_to_process = translated_prompt[0]
+ logger.info(f"Translated prompt to English: {prompt_to_process}")
+ else:
+ prompt_to_process = prompt
+ decoded = await llm_manager.chat_v2(img, prompt_to_process)
+ logger.info(f"Generated English response: {decoded}")
+ if tgt_lang != "eng_Latn":
+ translated_response = await perform_internal_translation([decoded], "eng_Latn", tgt_lang)
+ final_response = translated_response[0]
+ logger.info(f"Translated response to {tgt_lang}: {final_response}")
+ else:
+ final_response = decoded

  else:
+ if src_lang != "eng_Latn":
+ translated_prompt = await perform_internal_translation([prompt], src_lang, "eng_Latn")
+ prompt_to_process = translated_prompt[0]
+ logger.info(f"Translated prompt to English: {prompt_to_process}")
+ else:
+ prompt_to_process = prompt
+ decoded = await llm_manager.generate(prompt_to_process, settings.max_tokens)
+ logger.info(f"Generated English response: {decoded}")
+ if tgt_lang != "eng_Latn":
+ translated_response = await perform_internal_translation([decoded], "eng_Latn", tgt_lang)
+ final_response = translated_response[0]
+ logger.info(f"Translated response to {tgt_lang}: {final_response}")
+ else:
+ final_response = decoded
+ return ChatResponse(response=final_response)
+ except Exception as e:
+ logger.error(f"Error processing request: {str(e)}")
+ raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")

  @app.post("/transcribe/", response_model=TranscriptionResponse)
  async def transcribe_audio(file: UploadFile = File(...), language: str = Query(..., enum=list(asr_manager.model_language.keys()))):
+ async with request_semaphore:
+ if not asr_manager.model:
+ raise HTTPException(status_code=503, detail="ASR model not loaded")
+ try:
+ wav, sr = torchaudio.load(file.file)
+ wav = torch.mean(wav, dim=0, keepdim=True).to(device)
+ target_sample_rate = 16000
+ if sr != target_sample_rate:
+ resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate).to(device)
+ wav = resampler(wav)
+ with autocast(), torch.no_grad():
+ transcription_rnnt = asr_manager.model(wav, asr_manager.model_language[language], "rnnt")
+ return TranscriptionResponse(text=transcription_rnnt)
+ except Exception as e:
+ logger.error(f"Error in transcription: {str(e)}")
+ raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")

846
  async def speech_to_speech(
 
848
  file: UploadFile = File(...),
849
  language: str = Query(..., enum=list(asr_manager.model_language.keys())),
850
  ) -> StreamingResponse:
851
+ async with request_queue:
852
+ if not tts_manager.model:
853
+ raise HTTPException(status_code=503, detail="TTS model not loaded")
854
+ transcription = await transcribe_audio(file, language)
855
+ logger.info(f"Transcribed text: {transcription.text}")
856
+ chat_request = ChatRequest(prompt=transcription.text, src_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda"), tgt_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda"))
857
+ processed_text = await chat(request, chat_request)
858
+ logger.info(f"Processed text: {processed_text.response}")
859
+ voice_request = KannadaSynthesizeRequest(text=processed_text.response)
860
+ audio_response = await synthesize_kannada(voice_request)
861
+ return audio_response
862
+
863
+ LANGUAGE_TO_SCRIPT = {"kannada": "kan_Knda"}
864
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
865
  if __name__ == "__main__":
866
  parser = argparse.ArgumentParser(description="Run the FastAPI server.")
867
  parser.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")
 
888
  settings.speech_rate_limit = global_settings["speech_rate_limit"]
889
 
890
  llm_manager = LLMManager(settings.llm_model_name)
 
891
  if selected_config["components"]["ASR"]:
 
892
  asr_manager.model_language[selected_config["language"]] = selected_config["components"]["ASR"]["language_code"]
 
893
  if selected_config["components"]["Translation"]:
894
  translation_configs.extend(selected_config["components"]["Translation"])
895
 
896
  host = args.host if args.host != settings.host else settings.host
897
  port = args.port if args.port != settings.port else settings.port
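+ # uvicorn can only spawn multiple workers from an app import string (e.g. "main:app"); passing the app object with workers=2, as below, makes it refuse to start.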

+ uvicorn.run(app, host=host, port=port, workers=2)