sachin committed on
Commit ff84a71 · 1 Parent(s): b192c58
Files changed (1)
  1. src/server/main.py +493 -463
src/server/main.py CHANGED
@@ -2,7 +2,7 @@ import argparse
2
  import io
3
  import os
4
  from time import time
5
- from typing import List, Dict
6
  import tempfile
7
  import uvicorn
8
  from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, UploadFile, Body, Form
@@ -22,23 +22,25 @@ from contextlib import asynccontextmanager
22
  import soundfile as sf
23
  import numpy as np
24
  import requests
25
- import logging
26
  from starlette.responses import StreamingResponse
27
- from logging_config import logger # Assumed external logging config
28
- from tts_config import SPEED, ResponseFormat, config as tts_config # Assumed external TTS config
29
  import torchaudio
30
- from tenacity import retry, stop_after_attempt, wait_exponential
31
- from torch.cuda.amp import autocast
32
 
33
  # Device setup
34
- device = "cuda:0" if torch.cuda.is_available() else "cpu"
35
- torch_dtype = torch.float16 if device != "cpu" else torch.float32
36
- logger.info(f"Using device: {device} with dtype: {torch_dtype}")
37
 
38
  # Check CUDA availability and version
39
  cuda_available = torch.cuda.is_available()
40
  cuda_version = torch.version.cuda if cuda_available else None
41
- if cuda_available:
 
42
  device_idx = torch.cuda.current_device()
43
  capability = torch.cuda.get_device_capability(device_idx)
44
  compute_capability_float = float(f"{capability[0]}.{capability[1]}")
@@ -75,46 +77,33 @@ quantization_config = BitsAndBytesConfig(
75
  bnb_4bit_compute_dtype=torch.bfloat16
76
  )
77
 
78
- # Request queue for concurrency control
79
- request_queue = asyncio.Queue(maxsize=10)
80
-
81
- # Logging optimization
82
- logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
83
-
84
- # LLM Manager with batching
85
  class LLMManager:
86
- def __init__(self, model_name: str, device: str = device):
87
  self.model_name = model_name
88
  self.device = torch.device(device)
89
- self.torch_dtype = torch.float16 if self.device.type != "cpu" else torch.float32
90
  self.model = None
91
  self.processor = None
92
  self.is_loaded = False
93
- self.token_cache = {}
94
- self.load()
95
  logger.info(f"LLMManager initialized with model {model_name} on {self.device}")
96
 
97
  def load(self):
98
  if not self.is_loaded:
99
  try:
100
- if self.device.type == "cuda":
101
- torch.set_float32_matmul_precision('high')
102
- logger.info("Enabled TF32 matrix multiplication for improved GPU performance")
103
  self.model = Gemma3ForConditionalGeneration.from_pretrained(
104
  self.model_name,
105
  device_map="auto",
106
  quantization_config=quantization_config,
107
  torch_dtype=self.torch_dtype
108
- ).eval()
 
109
  self.processor = AutoProcessor.from_pretrained(self.model_name)
110
- dummy_input = self.processor("test", return_tensors="pt").to(self.device)
111
- with torch.no_grad():
112
- self.model.generate(**dummy_input, max_new_tokens=10)
113
  self.is_loaded = True
114
- logger.info(f"LLM {self.model_name} loaded and warmed up on {self.device}")
115
  except Exception as e:
116
  logger.error(f"Failed to load LLM: {str(e)}")
117
- self.is_loaded = False
118
 
119
  def unload(self):
120
  if self.is_loaded:
@@ -122,72 +111,74 @@ class LLMManager:
122
  del self.processor
123
  if self.device.type == "cuda":
124
  torch.cuda.empty_cache()
125
- logger.info(f"GPU memory cleared: {torch.cuda.memory_allocated()} bytes allocated")
126
  self.is_loaded = False
127
- self.token_cache.clear()
128
- logger.info(f"LLM {self.model_name} unloaded")
129
 
130
- async def generate(self, prompt: str, max_tokens: int = settings.max_tokens, temperature: float = 0.7) -> str:
131
  if not self.is_loaded:
132
- logger.warning("LLM not loaded; attempting reload")
133
  self.load()
134
- if not self.is_loaded:
135
- raise HTTPException(status_code=503, detail="LLM model unavailable")
136
 
137
- cache_key = f"{prompt}:{max_tokens}:{temperature}"
138
- if cache_key in self.token_cache:
139
- logger.info("Using cached response")
140
- return self.token_cache[cache_key]
141
-
142
- future = asyncio.Future()
143
- await request_queue.put({"prompt": prompt, "max_tokens": max_tokens, "temperature": temperature, "future": future})
144
- response = await future
145
- self.token_cache[cache_key] = response
146
- logger.info(f"Generated response: {response}")
147
- return response
148
-
149
- async def batch_generate(self, prompts: List[Dict]) -> List[str]:
150
- messages_batch = [
151
- [
152
- {"role": "system", "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state. Provide a concise response in one sentence maximum."}]},
153
- {"role": "user", "content": [{"type": "text", "text": prompt["prompt"]}]}
154
- ]
155
- for prompt in prompts
156
  ]
 
157
  try:
158
  inputs_vlm = self.processor.apply_chat_template(
159
- messages_batch,
160
  add_generation_prompt=True,
161
  tokenize=True,
162
  return_dict=True,
163
- return_tensors="pt",
164
- padding=True
165
  ).to(self.device, dtype=torch.bfloat16)
166
- with autocast(), torch.no_grad():
167
- outputs = self.model.generate(
168
- **inputs_vlm,
169
- max_new_tokens=max(prompt["max_tokens"] for prompt in prompts),
170
- do_sample=True,
171
- top_p=0.9,
172
- temperature=max(prompt["temperature"] for prompt in prompts)
173
- )
174
- responses = [
175
- self.processor.decode(output[input_len:], skip_special_tokens=True)
176
- for output, input_len in zip(outputs, inputs_vlm["input_ids"].shape[1])
177
- ]
178
- logger.info(f"Batch generated {len(responses)} responses")
179
- return responses
180
  except Exception as e:
181
- logger.error(f"Error in batch generation: {str(e)}")
182
- raise HTTPException(status_code=500, detail=f"Batch generation failed: {str(e)}")
183
 
184
  async def vision_query(self, image: Image.Image, query: str) -> str:
185
  if not self.is_loaded:
186
  self.load()
 
187
  messages_vlm = [
188
- {"role": "system", "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Summarize your answer in maximum 1 sentence."}]},
189
- {"role": "user", "content": [{"type": "text", "text": query}] + ([{"type": "image", "image": image}] if image and image.size[0] > 0 and image.size[1] > 0 else [])}
190
  ]
191
  try:
192
  inputs_vlm = self.processor.apply_chat_template(
193
  messages_vlm,
@@ -199,10 +190,18 @@ class LLMManager:
199
  except Exception as e:
200
  logger.error(f"Error in apply_chat_template: {str(e)}")
201
  raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
 
202
  input_len = inputs_vlm["input_ids"].shape[-1]
 
203
  with torch.inference_mode():
204
- generation = self.model.generate(**inputs_vlm, max_new_tokens=512, do_sample=True, temperature=0.7)
205
  generation = generation[0][input_len:]
 
206
  decoded = self.processor.decode(generation, skip_special_tokens=True)
207
  logger.info(f"Vision query response: {decoded}")
208
  return decoded
@@ -210,10 +209,25 @@ class LLMManager:
210
  async def chat_v2(self, image: Image.Image, query: str) -> str:
211
  if not self.is_loaded:
212
  self.load()
 
213
  messages_vlm = [
214
- {"role": "system", "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state."}]},
215
- {"role": "user", "content": [{"type": "text", "text": query}] + ([{"type": "image", "image": image}] if image and image.size[0] > 0 and image.size[1] > 0 else [])}
216
  ]
217
  try:
218
  inputs_vlm = self.processor.apply_chat_template(
219
  messages_vlm,
@@ -225,10 +239,18 @@ class LLMManager:
225
  except Exception as e:
226
  logger.error(f"Error in apply_chat_template: {str(e)}")
227
  raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
 
228
  input_len = inputs_vlm["input_ids"].shape[-1]
 
229
  with torch.inference_mode():
230
- generation = self.model.generate(**inputs_vlm, max_new_tokens=512, do_sample=True, temperature=0.7)
231
  generation = generation[0][input_len:]
 
232
  decoded = self.processor.decode(generation, skip_special_tokens=True)
233
  logger.info(f"Chat_v2 response: {decoded}")
234
  return decoded
@@ -236,42 +258,101 @@ class LLMManager:
236
  # TTS Manager
237
  class TTSManager:
238
  def __init__(self, device_type=device):
239
- self.device_type = torch.device(device_type)
240
  self.model = None
241
  self.repo_id = "ai4bharat/IndicF5"
242
- self.load()
243
 
244
  def load(self):
245
  if not self.model:
246
- logger.info(f"Loading TTS model {self.repo_id} on {self.device_type}...")
247
- self.model = AutoModel.from_pretrained(self.repo_id, trust_remote_code=True).to(self.device_type)
248
- logger.info("TTS model loaded")
249
-
250
- def unload(self):
251
- if self.model:
252
- del self.model
253
- if self.device_type.type == "cuda":
254
- torch.cuda.empty_cache()
255
- logger.info(f"TTS GPU memory cleared: {torch.cuda.memory_allocated()} bytes allocated")
256
- self.model = None
257
- logger.info("TTS model unloaded")
258
 
259
  def synthesize(self, text, ref_audio_path, ref_text):
260
  if not self.model:
261
  raise ValueError("TTS model not loaded")
262
- with autocast():
263
- return self.model(text, ref_audio_path=ref_audio_path, ref_text=ref_text)
264
 
265
  # Translation Manager
266
  class TranslateManager:
267
  def __init__(self, src_lang, tgt_lang, device_type=device, use_distilled=True):
268
- self.device_type = torch.device(device_type)
269
  self.tokenizer = None
270
  self.model = None
271
  self.src_lang = src_lang
272
  self.tgt_lang = tgt_lang
273
  self.use_distilled = use_distilled
274
- self.load()
275
 
276
  def load(self):
277
  if not self.tokenizer or not self.model:
@@ -283,17 +364,21 @@ class TranslateManager:
283
  model_name = "ai4bharat/indictrans2-indic-indic-dist-320M" if self.use_distilled else "ai4bharat/indictrans2-indic-indic-1B"
284
  else:
285
  raise ValueError("Invalid language combination")
286
- self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
287
  self.model = AutoModelForSeq2SeqLM.from_pretrained(
288
  model_name,
289
  trust_remote_code=True,
290
  torch_dtype=torch.float16,
291
  attn_implementation="flash_attention_2"
292
- ).to(self.device_type)
 
293
  self.model = torch.compile(self.model, mode="reduce-overhead")
294
  logger.info(f"Translation model {model_name} loaded")
295
 
296
- # Model Manager
297
  class ModelManager:
298
  def __init__(self, device_type=device, use_distilled=True, is_lazy_loading=False):
299
  self.models = {}
@@ -304,14 +389,18 @@ class ModelManager:
304
  def load_model(self, src_lang, tgt_lang, key):
305
  logger.info(f"Loading translation model for {src_lang} -> {tgt_lang}")
306
  translate_manager = TranslateManager(src_lang, tgt_lang, self.device_type, self.use_distilled)
 
307
  self.models[key] = translate_manager
308
  logger.info(f"Loaded translation model for {key}")
309
 
310
  def get_model(self, src_lang, tgt_lang):
311
  key = self._get_model_key(src_lang, tgt_lang)
312
- if key not in self.models and self.is_lazy_loading:
313
- self.load_model(src_lang, tgt_lang, key)
314
- return self.models.get(key) or (self.load_model(src_lang, tgt_lang, key) or self.models[key])
315
 
316
  def _get_model_key(self, src_lang, tgt_lang):
317
  if src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
@@ -324,30 +413,21 @@ class ModelManager:
324
 
325
  # ASR Manager
326
  class ASRModelManager:
327
- def __init__(self, device_type=device):
328
- self.device_type = torch.device(device_type)
329
  self.model = None
330
  self.model_language = {"kannada": "kn"}
331
- self.load()
332
 
333
  def load(self):
334
  if not self.model:
335
- logger.info(f"Loading ASR model on {self.device_type}...")
336
  self.model = AutoModel.from_pretrained(
337
  "ai4bharat/indic-conformer-600m-multilingual",
338
  trust_remote_code=True
339
- ).to(self.device_type)
 
340
  logger.info("ASR model loaded")
341
 
342
- def unload(self):
343
- if self.model:
344
- del self.model
345
- if self.device_type.type == "cuda":
346
- torch.cuda.empty_cache()
347
- logger.info(f"ASR GPU memory cleared: {torch.cuda.memory_allocated()} bytes allocated")
348
- self.model = None
349
- logger.info("ASR model unloaded")
350
-
351
  # Global Managers
352
  llm_manager = LLMManager(settings.llm_model_name)
353
  model_manager = ModelManager()
@@ -355,31 +435,7 @@ asr_manager = ASRModelManager()
355
  tts_manager = TTSManager()
356
  ip = IndicProcessor(inference=True)
357
 
358
- # TTS Constants
359
- EXAMPLES = [
360
- {
361
- "audio_name": "KAN_F (Happy)",
362
- "audio_url": "https://github.com/AI4Bharat/IndicF5/raw/refs/heads/main/prompts/KAN_F_HAPPY_00001.wav",
363
- "ref_text": "ನಮ್‌ ಫ್ರಿಜ್ಜಲ್ಲಿ ಕೂಲಿಂಗ್‌ ಸಮಸ್ಯೆ ಆಗಿ ನಾನ್‌ ಭಾಳ ದಿನದಿಂದ ಒದ್ದಾಡ್ತಿದ್ದೆ, ಆದ್ರೆ ಅದ್ನೀಗ ಮೆಕಾನಿಕ್ ಆಗಿರೋ ನಿಮ್‌ ಸಹಾಯ್ದಿಂದ ಬಗೆಹರಿಸ್ಕೋಬೋದು ಅಂತಾಗಿ ನಿರಾಳ ಆಯ್ತು ನಂಗೆ।",
364
- "synth_text": "ಚೆನ್ನೈನ ಶೇರ್ ಆಟೋ ಪ್ರಯಾಣಿಕರ ನಡುವೆ ಆಹಾರವನ್ನು ಹಂಚಿಕೊಂಡು ತಿನ್ನುವುದು ನನಗೆ ಮನಸ್ಸಿಗೆ ತುಂಬಾ ಒಳ್ಳೆಯದೆನಿಸುವ ವಿಷಯ."
365
- },
366
- ]
367
-
368
  # Pydantic Models
369
- class SynthesizeRequest(BaseModel):
370
- text: str
371
- ref_audio_name: str
372
- ref_text: str = None
373
-
374
- class KannadaSynthesizeRequest(BaseModel):
375
- text: str
376
-
377
- @field_validator("text")
378
- def text_must_be_valid(cls, v):
379
- if len(v) > 500:
380
- raise ValueError("Text cannot exceed 500 characters")
381
- return v.strip()
382
-
383
  class ChatRequest(BaseModel):
384
  prompt: str
385
  src_lang: str = "kan_Knda"
@@ -397,6 +453,7 @@ class ChatRequest(BaseModel):
397
  raise ValueError(f"Unsupported language code: {v}. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
398
  return v
399
 
400
  class ChatResponse(BaseModel):
401
  response: str
402
 
@@ -411,149 +468,71 @@ class TranscriptionResponse(BaseModel):
411
  class TranslationResponse(BaseModel):
412
  translations: List[str]
413
 
414
- # TTS Functions
415
- @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=10))
416
- def load_audio_from_url(url: str):
417
- response = requests.get(url)
418
- if response.status_code == 200:
419
- audio_data, sample_rate = sf.read(io.BytesIO(response.content))
420
- return sample_rate, audio_data
421
- raise HTTPException(status_code=500, detail="Failed to load reference audio from URL after retries")
422
-
423
- async def synthesize_speech(tts_manager: TTSManager, text: str, ref_audio_name: str, ref_text: str) -> io.BytesIO:
424
- async with request_queue:
425
- ref_audio_url = next((ex["audio_url"] for ex in EXAMPLES if ex["audio_name"] == ref_audio_name), None)
426
- if not ref_audio_url:
427
- raise HTTPException(status_code=400, detail="Invalid reference audio name.")
428
- if not text.strip() or not ref_text.strip():
429
- raise HTTPException(status_code=400, detail="Text or reference text cannot be empty.")
430
-
431
- logger.info(f"Synthesizing speech for text: {text[:50]}... with ref_audio: {ref_audio_name}")
432
- loop = asyncio.get_running_loop()
433
- sample_rate, audio_data = await loop.run_in_executor(None, load_audio_from_url, ref_audio_url)
434
-
435
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_audio:
436
- await loop.run_in_executor(None, sf.write, temp_audio.name, audio_data, sample_rate, "WAV")
437
- temp_audio.flush()
438
- audio = tts_manager.synthesize(text, temp_audio.name, ref_text)
439
-
440
- buffer = io.BytesIO()
441
- await loop.run_in_executor(None, sf.write, buffer, audio.astype(np.float32) / 32768.0 if audio.dtype == np.int16 else audio, 24000, "WAV")
442
- buffer.seek(0)
443
- logger.info("Speech synthesis completed")
444
- return buffer
445
-
446
- # Supported Languages
447
- SUPPORTED_LANGUAGES = {
448
- "asm_Beng", "kas_Arab", "pan_Guru", "ben_Beng", "kas_Deva", "san_Deva",
449
- "brx_Deva", "mai_Deva", "sat_Olck", "doi_Deva", "mal_Mlym", "snd_Arab",
450
- "eng_Latn", "mar_Deva", "snd_Deva", "gom_Deva", "mni_Beng", "tam_Taml",
451
- "guj_Gujr", "mni_Mtei", "tel_Telu", "hin_Deva", "npi_Deva", "urd_Arab",
452
- "kan_Knda", "ory_Orya",
453
- "deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn",
454
- "por_Latn", "rus_Cyrl", "pol_Latn"
455
- }
456
-
457
  # Dependency
458
  def get_translate_manager(src_lang: str, tgt_lang: str) -> TranslateManager:
459
  return model_manager.get_model(src_lang, tgt_lang)
460
 
461
- # Translation Function
462
- async def perform_internal_translation(sentences: List[str], src_lang: str, tgt_lang: str) -> List[str]:
463
- try:
464
- translate_manager = model_manager.get_model(src_lang, tgt_lang)
465
- except ValueError as e:
466
- logger.info(f"Model not preloaded: {str(e)}, loading now...")
467
- key = model_manager._get_model_key(src_lang, tgt_lang)
468
- model_manager.load_model(src_lang, tgt_lang, key)
469
- translate_manager = model_manager.get_model(src_lang, tgt_lang)
470
-
471
- batch = ip.preprocess_batch(sentences, src_lang=src_lang, tgt_lang=tgt_lang)
472
- inputs = translate_manager.tokenizer(batch, truncation=True, padding="longest", return_tensors="pt", return_attention_mask=True).to(translate_manager.device_type)
473
- with torch.no_grad(), autocast():
474
- generated_tokens = translate_manager.model.generate(**inputs, use_cache=True, min_length=0, max_length=256, num_beams=5, num_return_sequences=1)
475
- with translate_manager.tokenizer.as_target_tokenizer():
476
- generated_tokens = translate_manager.tokenizer.batch_decode(generated_tokens.detach().cpu().tolist(), skip_special_tokens=True, clean_up_tokenization_spaces=True)
477
- return ip.postprocess_batch(generated_tokens, lang=tgt_lang)
478
-
479
  # Lifespan Event Handler
480
  translation_configs = []
481
 
482
  @asynccontextmanager
483
  async def lifespan(app: FastAPI):
484
  def load_all_models():
485
- logger.info("Loading LLM model...")
486
- llm_manager.load()
487
- logger.info("Loading TTS model...")
488
- tts_manager.load()
489
- logger.info("Loading ASR model...")
490
- asr_manager.load()
491
- translation_tasks = [
492
- ('eng_Latn', 'kan_Knda', 'eng_indic'),
493
- ('kan_Knda', 'eng_Latn', 'indic_eng'),
494
- ('kan_Knda', 'hin_Deva', 'indic_indic'),
495
- ]
496
- for config in translation_configs:
497
- src_lang = config["src_lang"]
498
- tgt_lang = config["tgt_lang"]
499
- key = model_manager._get_model_key(src_lang, tgt_lang)
500
- translation_tasks.append((src_lang, tgt_lang, key))
501
- for src_lang, tgt_lang, key in translation_tasks:
502
- logger.info(f"Loading translation model for {src_lang} -> {tgt_lang}...")
503
- model_manager.load_model(src_lang, tgt_lang, key)
504
- logger.info("All models loaded successfully")
505
 
506
- logger.info("Starting server with preloaded models...")
507
  load_all_models()
508
- batch_task = asyncio.create_task(batch_worker())
509
  yield
510
- batch_task.cancel()
511
  llm_manager.unload()
512
- tts_manager.unload()
513
- asr_manager.unload()
514
- for model in model_manager.models.values():
515
- model.unload()
516
- logger.info("Server shutdown complete; all models unloaded")
517
-
518
- # Batch Worker
519
- async def batch_worker():
520
- while True:
521
- batch = []
522
- last_request_time = time()
523
- try:
524
- while len(batch) < 4:
525
- try:
526
- request = await asyncio.wait_for(request_queue.get(), timeout=1.0)
527
- batch.append(request)
528
- current_time = time()
529
- if current_time - last_request_time > 1.0 and batch:
530
- break
531
- last_request_time = current_time
532
- except asyncio.TimeoutError:
533
- if batch:
534
- break
535
- continue
536
- if batch:
537
- start_time = time()
538
- responses = await llm_manager.batch_generate(batch)
539
- duration = time() - start_time
540
- logger.info(f"Batch of {len(batch)} requests processed in {duration:.3f} seconds")
541
- for request, response in zip(batch, responses):
542
- request["future"].set_result(response)
543
- except Exception as e:
544
- logger.error(f"Batch worker error: {str(e)}")
545
- for request in batch:
546
- request["future"].set_exception(e)
547
 
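The batch worker removed above is a standard asyncio fan-in pattern: callers enqueue a prompt paired with a Future, a single consumer drains the queue into small batches, runs one generation call per batch, and resolves each Future with its result. A minimal self-contained sketch of the pattern (names and the echo "model" are illustrative, not code from this commit):

import asyncio

async def batch_worker(queue: asyncio.Queue, batch_size: int = 4, timeout: float = 1.0):
    while True:
        batch = []
        while len(batch) < batch_size:
            try:
                batch.append(await asyncio.wait_for(queue.get(), timeout=timeout))
            except asyncio.TimeoutError:
                if batch:
                    break  # flush a partial batch after a quiet interval
                continue  # nothing queued yet; keep waiting
        # One "model call" serves every request in the batch.
        results = [f"echo: {item['prompt']}" for item in batch]
        for item, result in zip(batch, results):
            item["future"].set_result(result)

async def submit(queue: asyncio.Queue, prompt: str) -> str:
    future = asyncio.get_running_loop().create_future()
    await queue.put({"prompt": prompt, "future": future})
    return await future

async def main():
    queue = asyncio.Queue(maxsize=10)
    worker = asyncio.create_task(batch_worker(queue))
    print(await asyncio.gather(*(submit(queue, p) for p in ("hi", "hello"))))
    worker.cancel()

asyncio.run(main())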
548
  # FastAPI App
549
  app = FastAPI(
550
- title="Optimized Dhwani API",
551
- description="AI Chat API supporting Indian languages with performance enhancements",
552
  version="1.0.0",
553
  redirect_slashes=False,
554
  lifespan=lifespan
555
  )
556
 
557
  app.add_middleware(
558
  CORSMiddleware,
559
  allow_origins=["*"],
@@ -562,11 +541,13 @@ app.add_middleware(
562
  allow_headers=["*"],
563
  )
564
 
565
  @app.middleware("http")
566
  async def add_request_timing(request: Request, call_next):
567
  start_time = time()
568
  response = await call_next(request)
569
- duration = time() - start_time
 
570
  logger.info(f"Request to {request.url.path} took {duration:.3f} seconds")
571
  response.headers["X-Response-Time"] = f"{duration:.3f}"
572
  return response
@@ -574,7 +555,7 @@ async def add_request_timing(request: Request, call_next):
574
  limiter = Limiter(key_func=get_remote_address)
575
  app.state.limiter = limiter
576
 
577
- # Endpoints
578
  @app.post("/audio/speech", response_class=StreamingResponse)
579
  async def synthesize_kannada(request: KannadaSynthesizeRequest):
580
  if not tts_manager.model:
@@ -582,7 +563,14 @@ async def synthesize_kannada(request: KannadaSynthesizeRequest):
582
  kannada_example = next(ex for ex in EXAMPLES if ex["audio_name"] == "KAN_F (Happy)")
583
  if not request.text.strip():
584
  raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")
585
- audio_buffer = await synthesize_speech(tts_manager, request.text, "KAN_F (Happy)", kannada_example["ref_text"])
586
  return StreamingResponse(
587
  audio_buffer,
588
  media_type="audio/wav",
@@ -591,69 +579,61 @@ async def synthesize_kannada(request: KannadaSynthesizeRequest):
591
 
592
  @app.post("/translate", response_model=TranslationResponse)
593
  async def translate(request: TranslationRequest, translate_manager: TranslateManager = Depends(get_translate_manager)):
594
- if not request.sentences:
595
  raise HTTPException(status_code=400, detail="Input sentences are required")
596
- batch = ip.preprocess_batch(request.sentences, src_lang=request.src_lang, tgt_lang=request.tgt_lang)
597
- inputs = translate_manager.tokenizer(batch, truncation=True, padding="longest", return_tensors="pt", return_attention_mask=True).to(translate_manager.device_type)
598
- with torch.no_grad(), autocast():
599
- generated_tokens = translate_manager.model.generate(**inputs, use_cache=True, min_length=0, max_length=256, num_beams=5, num_return_sequences=1)
600
  with translate_manager.tokenizer.as_target_tokenizer():
601
- generated_tokens = translate_manager.tokenizer.batch_decode(generated_tokens.detach().cpu().tolist(), skip_special_tokens=True, clean_up_tokenization_spaces=True)
602
- translations = ip.postprocess_batch(generated_tokens, lang=request.tgt_lang)
603
  return TranslationResponse(translations=translations)
604
 
605
  @app.get("/v1/health")
606
  async def health_check():
607
- memory_usage = torch.cuda.memory_allocated() / (24 * 1024**3) if cuda_available else 0
608
- if memory_usage > 0.9:
609
- logger.warning("GPU memory usage exceeds 90%; consider unloading models")
610
- llm_status = "unhealthy"
611
- llm_latency = None
612
- if llm_manager.is_loaded:
613
- start = time()
614
- try:
615
- llm_test = await llm_manager.generate("What is the capital of Karnataka?", max_tokens=10)
616
- llm_latency = time() - start
617
- llm_status = "healthy" if llm_test else "unhealthy"
618
- except Exception as e:
619
- logger.error(f"LLM health check failed: {str(e)}")
620
- tts_status = "unhealthy"
621
- tts_latency = None
622
- if tts_manager.model:
623
- start = time()
624
- try:
625
- audio_buffer = await synthesize_speech(tts_manager, "Test", "KAN_F (Happy)", EXAMPLES[0]["ref_text"])
626
- tts_latency = time() - start
627
- tts_status = "healthy" if audio_buffer else "unhealthy"
628
- except Exception as e:
629
- logger.error(f"TTS health check failed: {str(e)}")
630
- asr_status = "unhealthy"
631
- asr_latency = None
632
- if asr_manager.model:
633
- start = time()
634
- try:
635
- dummy_audio = np.zeros(16000, dtype=np.float32)
636
- wav = torch.tensor(dummy_audio).unsqueeze(0).to(device)
637
- with autocast(), torch.no_grad():
638
- asr_test = asr_manager.model(wav, asr_manager.model_language["kannada"], "rnnt")
639
- asr_latency = time() - start
640
- asr_status = "healthy" if asr_test else "unhealthy"
641
- except Exception as e:
642
- logger.error(f"ASR health check failed: {str(e)}")
643
- status = {
644
- "status": "healthy" if llm_status == "healthy" and tts_status == "healthy" and asr_status == "healthy" else "degraded",
645
- "model": settings.llm_model_name,
646
- "llm_status": llm_status,
647
- "llm_latency": f"{llm_latency:.3f}s" if llm_latency else "N/A",
648
- "tts_status": tts_status,
649
- "tts_latency": f"{tts_latency:.3f}s" if tts_latency else "N/A",
650
- "asr_status": asr_status,
651
- "asr_latency": f"{asr_latency:.3f}s" if asr_latency else "N/A",
652
- "translation_models": list(model_manager.models.keys()),
653
- "gpu_memory_usage": f"{memory_usage:.2%}"
654
- }
655
- logger.info("Health check completed")
656
- return status
657
 
658
  @app.get("/")
659
  async def home():
@@ -664,10 +644,6 @@ async def unload_all_models():
664
  try:
665
  logger.info("Starting to unload all models...")
666
  llm_manager.unload()
667
- tts_manager.unload()
668
- asr_manager.unload()
669
- for model in model_manager.models.values():
670
- model.unload()
671
  logger.info("All models unloaded successfully")
672
  return {"status": "success", "message": "All models unloaded"}
673
  except Exception as e:
@@ -679,15 +655,6 @@ async def load_all_models():
679
  try:
680
  logger.info("Starting to load all models...")
681
  llm_manager.load()
682
- tts_manager.load()
683
- asr_manager.load()
684
- for src_lang, tgt_lang, key in [
685
- ('eng_Latn', 'kan_Knda', 'eng_indic'),
686
- ('kan_Knda', 'eng_Latn', 'indic_eng'),
687
- ('kan_Knda', 'hin_Deva', 'indic_indic'),
688
- ]:
689
- if key not in model_manager.models:
690
- model_manager.load_model(src_lang, tgt_lang, key)
691
  logger.info("All models loaded successfully")
692
  return {"status": "success", "message": "All models loaded"}
693
  except Exception as e:
@@ -698,7 +665,11 @@ async def load_all_models():
698
  async def translate_endpoint(request: TranslationRequest):
699
  logger.info(f"Received translation request: {request.dict()}")
700
  try:
701
- translations = await perform_internal_translation(request.sentences, request.src_lang, request.tgt_lang)
702
  logger.info(f"Translation successful: {translations}")
703
  return TranslationResponse(translations=translations)
704
  except Exception as e:
@@ -708,32 +679,44 @@ async def translate_endpoint(request: TranslationRequest):
708
  @app.post("/v1/chat", response_model=ChatResponse)
709
  @limiter.limit(settings.chat_rate_limit)
710
  async def chat(request: Request, chat_request: ChatRequest):
711
- async with request_queue:
712
- if not chat_request.prompt:
713
- raise HTTPException(status_code=400, detail="Prompt cannot be empty")
714
- logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, tgt_lang: {chat_request.tgt_lang}")
715
- EUROPEAN_LANGUAGES = {"deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn", "por_Latn", "rus_Cyrl", "pol_Latn"}
716
- try:
717
- if chat_request.src_lang != "eng_Latn" and chat_request.src_lang not in EUROPEAN_LANGUAGES:
718
- translated_prompt = await perform_internal_translation([chat_request.prompt], chat_request.src_lang, "eng_Latn")
719
- prompt_to_process = translated_prompt[0]
720
- logger.info(f"Translated prompt to English: {prompt_to_process}")
721
- else:
722
- prompt_to_process = chat_request.prompt
723
- logger.info("Prompt in English or European language, no translation needed")
724
- response = await llm_manager.generate(prompt_to_process, settings.max_tokens)
725
- logger.info(f"Generated English response: {response}")
726
- if chat_request.tgt_lang != "eng_Latn" and chat_request.tgt_lang not in EUROPEAN_LANGUAGES:
727
- translated_response = await perform_internal_translation([response], "eng_Latn", chat_request.tgt_lang)
728
- final_response = translated_response[0]
729
- logger.info(f"Translated response to {chat_request.tgt_lang}: {final_response}")
730
- else:
731
- final_response = response
732
- logger.info(f"Response in {chat_request.tgt_lang}, no translation needed")
733
- return ChatResponse(response=final_response)
734
- except Exception as e:
735
- logger.error(f"Error processing request: {str(e)}")
736
- raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
737
 
738
  @app.post("/v1/visual_query/")
739
  async def visual_query(
@@ -742,31 +725,42 @@ async def visual_query(
742
  src_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
743
  tgt_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
744
  ):
745
- async with request_queue:
746
- try:
747
- image = Image.open(file.file)
748
- if image.size == (0, 0):
749
- raise HTTPException(status_code=400, detail="Uploaded image is empty or invalid")
750
- if src_lang != "eng_Latn":
751
- translated_query = await perform_internal_translation([query], src_lang, "eng_Latn")
752
- query_to_process = translated_query[0]
753
- logger.info(f"Translated query to English: {query_to_process}")
754
- else:
755
- query_to_process = query
756
- logger.info("Query already in English, no translation needed")
757
- answer = await llm_manager.vision_query(image, query_to_process)
758
- logger.info(f"Generated English answer: {answer}")
759
- if tgt_lang != "eng_Latn":
760
- translated_answer = await perform_internal_translation([answer], "eng_Latn", tgt_lang)
761
- final_answer = translated_answer[0]
762
- logger.info(f"Translated answer to {tgt_lang}: {final_answer}")
763
- else:
764
- final_answer = answer
765
- logger.info("Answer kept in English, no translation needed")
766
- return {"answer": final_answer}
767
- except Exception as e:
768
- logger.error(f"Error processing request: {str(e)}")
769
- raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
770
 
771
  @app.post("/v1/chat_v2", response_model=ChatResponse)
772
  @limiter.limit(settings.chat_rate_limit)
@@ -777,70 +771,95 @@ async def chat_v2(
777
  src_lang: str = Form("kan_Knda"),
778
  tgt_lang: str = Form("kan_Knda"),
779
  ):
780
- async with request_queue:
781
- if not prompt:
782
- raise HTTPException(status_code=400, detail="Prompt cannot be empty")
783
- if src_lang not in SUPPORTED_LANGUAGES or tgt_lang not in SUPPORTED_LANGUAGES:
784
- raise HTTPException(status_code=400, detail=f"Unsupported language code. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
785
- logger.info(f"Received prompt: {prompt}, src_lang: {src_lang}, tgt_lang: {tgt_lang}, Image provided: {image is not None}")
786
- try:
787
- if image:
788
- image_data = await image.read()
789
- if not image_data:
790
- raise HTTPException(status_code=400, detail="Uploaded image is empty")
791
- img = Image.open(io.BytesIO(image_data))
792
- if src_lang != "eng_Latn":
793
- translated_prompt = await perform_internal_translation([prompt], src_lang, "eng_Latn")
794
- prompt_to_process = translated_prompt[0]
795
- logger.info(f"Translated prompt to English: {prompt_to_process}")
796
- else:
797
- prompt_to_process = prompt
798
- decoded = await llm_manager.chat_v2(img, prompt_to_process)
799
- logger.info(f"Generated English response: {decoded}")
800
- if tgt_lang != "eng_Latn":
801
- translated_response = await perform_internal_translation([decoded], "eng_Latn", tgt_lang)
802
- final_response = translated_response[0]
803
- logger.info(f"Translated response to {tgt_lang}: {final_response}")
804
- else:
805
- final_response = decoded
806
  else:
807
- if src_lang != "eng_Latn":
808
- translated_prompt = await perform_internal_translation([prompt], src_lang, "eng_Latn")
809
- prompt_to_process = translated_prompt[0]
810
- logger.info(f"Translated prompt to English: {prompt_to_process}")
811
- else:
812
- prompt_to_process = prompt
813
- decoded = await llm_manager.generate(prompt_to_process, settings.max_tokens)
814
- logger.info(f"Generated English response: {decoded}")
815
- if tgt_lang != "eng_Latn":
816
- translated_response = await perform_internal_translation([decoded], "eng_Latn", tgt_lang)
817
- final_response = translated_response[0]
818
- logger.info(f"Translated response to {tgt_lang}: {final_response}")
819
- else:
820
- final_response = decoded
821
- return ChatResponse(response=final_response)
822
- except Exception as e:
823
- logger.error(f"Error processing request: {str(e)}")
824
- raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
825
 
826
  @app.post("/transcribe/", response_model=TranscriptionResponse)
827
  async def transcribe_audio(file: UploadFile = File(...), language: str = Query(..., enum=list(asr_manager.model_language.keys()))):
828
- async with request_queue:
829
- if not asr_manager.model:
830
- raise HTTPException(status_code=503, detail="ASR model not loaded")
831
- try:
832
- wav, sr = torchaudio.load(file.file, backend="cuda" if cuda_available else "cpu")
833
- wav = torch.mean(wav, dim=0, keepdim=True).to(device)
834
- target_sample_rate = 16000
835
- if sr != target_sample_rate:
836
- resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate).to(device)
837
- wav = resampler(wav)
838
- with autocast(), torch.no_grad():
839
- transcription_rnnt = asr_manager.model(wav, asr_manager.model_language[language], "rnnt")
840
- return TranscriptionResponse(text=transcription_rnnt)
841
- except Exception as e:
842
- logger.error(f"Error in transcription: {str(e)}")
843
- raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
844
 
845
  @app.post("/v1/speech_to_speech")
846
  async def speech_to_speech(
@@ -848,20 +867,28 @@ async def speech_to_speech(
848
  file: UploadFile = File(...),
849
  language: str = Query(..., enum=list(asr_manager.model_language.keys())),
850
  ) -> StreamingResponse:
851
- async with request_queue:
852
- if not tts_manager.model:
853
- raise HTTPException(status_code=503, detail="TTS model not loaded")
854
- transcription = await transcribe_audio(file, language)
855
- logger.info(f"Transcribed text: {transcription.text}")
856
- chat_request = ChatRequest(prompt=transcription.text, src_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda"), tgt_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda"))
857
- processed_text = await chat(request, chat_request)
858
- logger.info(f"Processed text: {processed_text.response}")
859
- voice_request = KannadaSynthesizeRequest(text=processed_text.response)
860
- audio_response = await synthesize_kannada(voice_request)
861
- return audio_response
862
-
863
- LANGUAGE_TO_SCRIPT = {"kannada": "kan_Knda"}
864
 
 
865
  if __name__ == "__main__":
866
  parser = argparse.ArgumentParser(description="Run the FastAPI server.")
867
  parser.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")
@@ -888,12 +915,15 @@ if __name__ == "__main__":
888
  settings.speech_rate_limit = global_settings["speech_rate_limit"]
889
 
890
  llm_manager = LLMManager(settings.llm_model_name)
 
891
  if selected_config["components"]["ASR"]:
 
892
  asr_manager.model_language[selected_config["language"]] = selected_config["components"]["ASR"]["language_code"]
 
893
  if selected_config["components"]["Translation"]:
894
  translation_configs.extend(selected_config["components"]["Translation"])
895
 
896
  host = args.host if args.host != settings.host else settings.host
897
  port = args.port if args.port != settings.port else settings.port
898
 
899
- uvicorn.run(app, host=host, port=port, workers=2)
 
2
  import io
3
  import os
4
  from time import time
5
+ from typing import List
6
  import tempfile
7
  import uvicorn
8
  from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, UploadFile, Body, Form
 
22
  import soundfile as sf
23
  import numpy as np
24
  import requests
 
25
  from starlette.responses import StreamingResponse
26
+ from logging_config import logger
27
+ from tts_config import SPEED, ResponseFormat, config as tts_config
28
  import torchaudio
 
29
 
30
  # Device setup
31
+ if torch.cuda.is_available():
32
+ device = "cuda:0"
33
+ logger.info("GPU will be used for inference")
34
+ else:
35
+ device = "cpu"
36
+ logger.info("CPU will be used for inference")
37
+ torch_dtype = torch.bfloat16 if device != "cpu" else torch.float32
38
 
39
  # Check CUDA availability and version
40
  cuda_available = torch.cuda.is_available()
41
  cuda_version = torch.version.cuda if cuda_available else None
42
+
43
+ if torch.cuda.is_available():
44
  device_idx = torch.cuda.current_device()
45
  capability = torch.cuda.get_device_capability(device_idx)
46
  compute_capability_float = float(f"{capability[0]}.{capability[1]}")
 
77
  bnb_4bit_compute_dtype=torch.bfloat16
78
  )
79
 
80
+ # LLM Manager
81
  class LLMManager:
82
+ def __init__(self, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
83
  self.model_name = model_name
84
  self.device = torch.device(device)
85
+ self.torch_dtype = torch.bfloat16 if self.device.type != "cpu" else torch.float32
86
  self.model = None
87
  self.processor = None
88
  self.is_loaded = False
89
  logger.info(f"LLMManager initialized with model {model_name} on {self.device}")
90
 
91
  def load(self):
92
  if not self.is_loaded:
93
  try:
94
  self.model = Gemma3ForConditionalGeneration.from_pretrained(
95
  self.model_name,
96
  device_map="auto",
97
  quantization_config=quantization_config,
98
  torch_dtype=self.torch_dtype
99
+ )
100
+ self.model.eval()
101
  self.processor = AutoProcessor.from_pretrained(self.model_name)
102
  self.is_loaded = True
103
+ logger.info(f"LLM {self.model_name} loaded on {self.device}")
104
  except Exception as e:
105
  logger.error(f"Failed to load LLM: {str(e)}")
106
+ raise
107
 
108
  def unload(self):
109
  if self.is_loaded:
 
111
  del self.processor
112
  if self.device.type == "cuda":
113
  torch.cuda.empty_cache()
114
+ logger.info(f"GPU memory allocated after unload: {torch.cuda.memory_allocated()}")
115
  self.is_loaded = False
116
+ logger.info(f"LLM {self.model_name} unloaded from {self.device}")
 
117
 
118
+ async def generate(self, prompt: str, max_tokens: int = 512, temperature: float = 0.7) -> str:
119
  if not self.is_loaded:
 
120
  self.load()
121
 
122
+ messages_vlm = [
123
+ {
124
+ "role": "system",
125
+ "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state. Provide a concise response in one sentence maximum."}]
126
+ },
127
+ {
128
+ "role": "user",
129
+ "content": [{"type": "text", "text": prompt}]
130
+ }
131
  ]
132
+
133
  try:
134
  inputs_vlm = self.processor.apply_chat_template(
135
+ messages_vlm,
136
  add_generation_prompt=True,
137
  tokenize=True,
138
  return_dict=True,
139
+ return_tensors="pt"
 
140
  ).to(self.device, dtype=torch.bfloat16)
141
  except Exception as e:
142
+ logger.error(f"Error in tokenization: {str(e)}")
143
+ raise HTTPException(status_code=500, detail=f"Tokenization failed: {str(e)}")
144
+
145
+ input_len = inputs_vlm["input_ids"].shape[-1]
146
+
147
+ with torch.inference_mode():
148
+ generation = self.model.generate(
149
+ **inputs_vlm,
150
+ max_new_tokens=max_tokens,
151
+ do_sample=True,
152
+ temperature=temperature
153
+ )
154
+ generation = generation[0][input_len:]
155
+
156
+ response = self.processor.decode(generation, skip_special_tokens=True)
157
+ logger.info(f"Generated response: {response}")
158
+ return response
159
 
160
  async def vision_query(self, image: Image.Image, query: str) -> str:
161
  if not self.is_loaded:
162
  self.load()
163
+
164
  messages_vlm = [
165
+ {
166
+ "role": "system",
167
+ "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Summarize your answer in maximum 1 sentence."}]
168
+ },
169
+ {
170
+ "role": "user",
171
+ "content": []
172
+ }
173
  ]
174
+
175
+ messages_vlm[1]["content"].append({"type": "text", "text": query})
176
+ if image and image.size[0] > 0 and image.size[1] > 0:
177
+ messages_vlm[1]["content"].insert(0, {"type": "image", "image": image})
178
+ logger.info(f"Received valid image for processing")
179
+ else:
180
+ logger.info("No valid image provided, processing text only")
181
+
182
  try:
183
  inputs_vlm = self.processor.apply_chat_template(
184
  messages_vlm,
 
190
  except Exception as e:
191
  logger.error(f"Error in apply_chat_template: {str(e)}")
192
  raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
193
+
194
  input_len = inputs_vlm["input_ids"].shape[-1]
195
+
196
  with torch.inference_mode():
197
+ generation = self.model.generate(
198
+ **inputs_vlm,
199
+ max_new_tokens=512,
200
+ do_sample=True,
201
+ temperature=0.7
202
+ )
203
  generation = generation[0][input_len:]
204
+
205
  decoded = self.processor.decode(generation, skip_special_tokens=True)
206
  logger.info(f"Vision query response: {decoded}")
207
  return decoded
 
209
  async def chat_v2(self, image: Image.Image, query: str) -> str:
210
  if not self.is_loaded:
211
  self.load()
212
+
213
  messages_vlm = [
214
+ {
215
+ "role": "system",
216
+ "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state."}]
217
+ },
218
+ {
219
+ "role": "user",
220
+ "content": []
221
+ }
222
  ]
223
+
224
+ messages_vlm[1]["content"].append({"type": "text", "text": query})
225
+ if image and image.size[0] > 0 and image.size[1] > 0:
226
+ messages_vlm[1]["content"].insert(0, {"type": "image", "image": image})
227
+ logger.info(f"Received valid image for processing")
228
+ else:
229
+ logger.info("No valid image provided, processing text only")
230
+
231
  try:
232
  inputs_vlm = self.processor.apply_chat_template(
233
  messages_vlm,
 
239
  except Exception as e:
240
  logger.error(f"Error in apply_chat_template: {str(e)}")
241
  raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
242
+
243
  input_len = inputs_vlm["input_ids"].shape[-1]
244
+
245
  with torch.inference_mode():
246
+ generation = self.model.generate(
247
+ **inputs_vlm,
248
+ max_new_tokens=512,
249
+ do_sample=True,
250
+ temperature=0.7
251
+ )
252
  generation = generation[0][input_len:]
253
+
254
  decoded = self.processor.decode(generation, skip_special_tokens=True)
255
  logger.info(f"Chat_v2 response: {decoded}")
256
  return decoded
 
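vision_query and chat_v2 above assemble the same chat-template payload: a system turn, then a user turn whose content list receives the text part, with a valid image part inserted in front of it. The assembly logic in isolation (a tiny PIL image stands in for an upload; no model or processor required):

from PIL import Image

def build_messages(query: str, image=None) -> list:
    messages = [
        {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
        {"role": "user", "content": [{"type": "text", "text": query}]},
    ]
    # Mirror the diff: only a non-empty image is prepended to the user content.
    if image is not None and image.size[0] > 0 and image.size[1] > 0:
        messages[1]["content"].insert(0, {"type": "image", "image": image})
    return messages

print(build_messages("Describe this image.", Image.new("RGB", (8, 8))))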
258
  # TTS Manager
259
  class TTSManager:
260
  def __init__(self, device_type=device):
261
+ self.device_type = device_type
262
  self.model = None
263
  self.repo_id = "ai4bharat/IndicF5"
 
264
 
265
  def load(self):
266
  if not self.model:
267
+ logger.info("Loading TTS model IndicF5...")
268
+ self.model = AutoModel.from_pretrained(
269
+ self.repo_id,
270
+ trust_remote_code=True
271
+ )
272
+ self.model = self.model.to(self.device_type)
273
+ logger.info("TTS model IndicF5 loaded")
274
 
275
  def synthesize(self, text, ref_audio_path, ref_text):
276
  if not self.model:
277
  raise ValueError("TTS model not loaded")
278
+ return self.model(text, ref_audio_path=ref_audio_path, ref_text=ref_text)
279
+
280
+ # TTS Constants
281
+ EXAMPLES = [
282
+ {
283
+ "audio_name": "KAN_F (Happy)",
284
+ "audio_url": "https://github.com/AI4Bharat/IndicF5/raw/refs/heads/main/prompts/KAN_F_HAPPY_00001.wav",
285
+ "ref_text": "ನಮ್‌ ಫ್ರಿಜ್ಜಲ್ಲಿ ಕೂಲಿಂಗ್‌ ಸಮಸ್ಯೆ ಆಗಿ ನಾನ್‌ ಭಾಳ ದಿನದಿಂದ ಒದ್ದಾಡ್ತಿದ್ದೆ, ಆದ್ರೆ ಅದ್ನೀಗ ಮೆಕಾನಿಕ್ ಆಗಿರೋ ನಿಮ್‌ ಸಹಾಯ್ದಿಂದ ಬಗೆಹರಿಸ್ಕೋಬೋದು ಅಂತಾಗಿ ನಿರಾಳ ಆಯ್ತು ನಂಗೆ.",
286
+ "synth_text": "ಚೆನ್ನೈನ ಶೇರ್ ಆಟೋ ಪ್ರಯಾಣಿಕರ ನಡುವೆ ಆಹಾರವನ್ನು ಹಂಚಿಕೊಂಡು ತಿನ್���ುವುದು ನನಗೆ ಮನಸ್ಸಿಗೆ ತುಂಬಾ ಒಳ್ಳೆಯದೆನಿಸುವ ವಿಷಯ."
287
+ },
288
+ ]
289
+
290
+ # Pydantic models for TTS
291
+ class SynthesizeRequest(BaseModel):
292
+ text: str
293
+ ref_audio_name: str
294
+ ref_text: str = None
295
+
296
+ class KannadaSynthesizeRequest(BaseModel):
297
+ text: str
298
+
299
+ # TTS Functions
300
+ def load_audio_from_url(url: str):
301
+ response = requests.get(url)
302
+ if response.status_code == 200:
303
+ audio_data, sample_rate = sf.read(io.BytesIO(response.content))
304
+ return sample_rate, audio_data
305
+ raise HTTPException(status_code=500, detail="Failed to load reference audio from URL.")
306
+
307
+ def synthesize_speech(tts_manager: TTSManager, text: str, ref_audio_name: str, ref_text: str):
308
+ ref_audio_url = None
309
+ for example in EXAMPLES:
310
+ if example["audio_name"] == ref_audio_name:
311
+ ref_audio_url = example["audio_url"]
312
+ if not ref_text:
313
+ ref_text = example["ref_text"]
314
+ break
315
+
316
+ if not ref_audio_url:
317
+ raise HTTPException(status_code=400, detail="Invalid reference audio name.")
318
+ if not text.strip():
319
+ raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")
320
+ if not ref_text or not ref_text.strip():
321
+ raise HTTPException(status_code=400, detail="Reference text cannot be empty.")
322
+
323
+ sample_rate, audio_data = load_audio_from_url(ref_audio_url)
324
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
325
+ sf.write(temp_audio.name, audio_data, samplerate=sample_rate, format='WAV')
326
+ temp_audio.flush()
327
+ audio = tts_manager.synthesize(text, ref_audio_path=temp_audio.name, ref_text=ref_text)
328
+
329
+ if audio.dtype == np.int16:
330
+ audio = audio.astype(np.float32) / 32768.0
331
+ buffer = io.BytesIO()
332
+ sf.write(buffer, audio, 24000, format='WAV')
333
+ buffer.seek(0)
334
+ return buffer
335
+
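synthesize_speech above converts int16 model output to float32 in the [-1.0, 1.0) range before writing an in-memory WAV at 24 kHz. The same two steps in isolation (a silent int16 buffer stands in for real synthesizer output):

import io
import numpy as np
import soundfile as sf

audio = np.zeros(2400, dtype=np.int16)  # 100 ms of silence at 24 kHz
if audio.dtype == np.int16:
    audio = audio.astype(np.float32) / 32768.0  # int16 full scale -> [-1.0, 1.0)
buffer = io.BytesIO()
sf.write(buffer, audio, 24000, format="WAV")
buffer.seek(0)
print(f"{len(buffer.getvalue())} bytes of WAV data")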
336
+ # Supported languages
337
+ SUPPORTED_LANGUAGES = {
338
+ "asm_Beng", "kas_Arab", "pan_Guru", "ben_Beng", "kas_Deva", "san_Deva",
339
+ "brx_Deva", "mai_Deva", "sat_Olck", "doi_Deva", "mal_Mlym", "snd_Arab",
340
+ "eng_Latn", "mar_Deva", "snd_Deva", "gom_Deva", "mni_Beng", "tam_Taml",
341
+ "guj_Gujr", "mni_Mtei", "tel_Telu", "hin_Deva", "npi_Deva", "urd_Arab",
342
+ "kan_Knda", "ory_Orya",
343
+ "deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn",
344
+ "por_Latn", "rus_Cyrl", "pol_Latn"
345
+ }
346
 
347
  # Translation Manager
348
  class TranslateManager:
349
  def __init__(self, src_lang, tgt_lang, device_type=device, use_distilled=True):
350
+ self.device_type = device_type
351
  self.tokenizer = None
352
  self.model = None
353
  self.src_lang = src_lang
354
  self.tgt_lang = tgt_lang
355
  self.use_distilled = use_distilled
 
356
 
357
  def load(self):
358
  if not self.tokenizer or not self.model:
 
364
  model_name = "ai4bharat/indictrans2-indic-indic-dist-320M" if self.use_distilled else "ai4bharat/indictrans2-indic-indic-1B"
365
  else:
366
  raise ValueError("Invalid language combination")
367
+
368
+ self.tokenizer = AutoTokenizer.from_pretrained(
369
+ model_name,
370
+ trust_remote_code=True
371
+ )
372
  self.model = AutoModelForSeq2SeqLM.from_pretrained(
373
  model_name,
374
  trust_remote_code=True,
375
  torch_dtype=torch.float16,
376
  attn_implementation="flash_attention_2"
377
+ )
378
+ self.model = self.model.to(self.device_type)
379
  self.model = torch.compile(self.model, mode="reduce-overhead")
380
  logger.info(f"Translation model {model_name} loaded")
381
 
 
382
  class ModelManager:
383
  def __init__(self, device_type=device, use_distilled=True, is_lazy_loading=False):
384
  self.models = {}
 
389
  def load_model(self, src_lang, tgt_lang, key):
390
  logger.info(f"Loading translation model for {src_lang} -> {tgt_lang}")
391
  translate_manager = TranslateManager(src_lang, tgt_lang, self.device_type, self.use_distilled)
392
+ translate_manager.load()
393
  self.models[key] = translate_manager
394
  logger.info(f"Loaded translation model for {key}")
395
 
396
  def get_model(self, src_lang, tgt_lang):
397
  key = self._get_model_key(src_lang, tgt_lang)
398
+ if key not in self.models:
399
+ if self.is_lazy_loading:
400
+ self.load_model(src_lang, tgt_lang, key)
401
+ else:
402
+ raise ValueError(f"Model for {key} is not preloaded and lazy loading is disabled.")
403
+ return self.models.get(key)
404
 
405
  def _get_model_key(self, src_lang, tgt_lang):
406
  if src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
 
413
 
414
  # ASR Manager
415
  class ASRModelManager:
416
+ def __init__(self, device_type="cuda"):
417
+ self.device_type = device_type
418
  self.model = None
419
  self.model_language = {"kannada": "kn"}
 
420
 
421
  def load(self):
422
  if not self.model:
423
+ logger.info("Loading ASR model...")
424
  self.model = AutoModel.from_pretrained(
425
  "ai4bharat/indic-conformer-600m-multilingual",
426
  trust_remote_code=True
427
+ )
428
+ self.model = self.model.to(self.device_type)
429
  logger.info("ASR model loaded")
430
 
 
431
  # Global Managers
432
  llm_manager = LLMManager(settings.llm_model_name)
433
  model_manager = ModelManager()
 
435
  tts_manager = TTSManager()
436
  ip = IndicProcessor(inference=True)
437
 
438
  # Pydantic Models
439
  class ChatRequest(BaseModel):
440
  prompt: str
441
  src_lang: str = "kan_Knda"
 
453
  raise ValueError(f"Unsupported language code: {v}. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
454
  return v
455
 
456
+
457
  class ChatResponse(BaseModel):
458
  response: str
459
 
 
468
  class TranslationResponse(BaseModel):
469
  translations: List[str]
470
 
471
  # Dependency
472
  def get_translate_manager(src_lang: str, tgt_lang: str) -> TranslateManager:
473
  return model_manager.get_model(src_lang, tgt_lang)
474
 
475
  # Lifespan Event Handler
476
  translation_configs = []
477
 
478
  @asynccontextmanager
479
  async def lifespan(app: FastAPI):
480
  def load_all_models():
481
+ try:
482
+ # Load LLM model
483
+ logger.info("Loading LLM model...")
484
+ llm_manager.load()
485
+ logger.info("LLM model loaded successfully")
486
+
487
+ # Load TTS model
488
+ logger.info("Loading TTS model...")
489
+ tts_manager.load()
490
+ logger.info("TTS model loaded successfully")
491
+
492
+ # Load ASR model
493
+ logger.info("Loading ASR model...")
494
+ asr_manager.load()
495
+ logger.info("ASR model loaded successfully")
496
+
497
+ # Load translation models
498
+ translation_tasks = [
499
+ ('eng_Latn', 'kan_Knda', 'eng_indic'),
500
+ ('kan_Knda', 'eng_Latn', 'indic_eng'),
501
+ ('kan_Knda', 'hin_Deva', 'indic_indic'),
502
+ ]
503
+
504
+ for config in translation_configs:
505
+ src_lang = config["src_lang"]
506
+ tgt_lang = config["tgt_lang"]
507
+ key = model_manager._get_model_key(src_lang, tgt_lang)
508
+ translation_tasks.append((src_lang, tgt_lang, key))
509
+
510
+ for src_lang, tgt_lang, key in translation_tasks:
511
+ logger.info(f"Loading translation model for {src_lang} -> {tgt_lang}...")
512
+ model_manager.load_model(src_lang, tgt_lang, key)
513
+ logger.info(f"Translation model for {key} loaded successfully")
514
+
515
+ logger.info("All models loaded successfully")
516
+ except Exception as e:
517
+ logger.error(f"Error loading models: {str(e)}")
518
+ raise
519
 
520
+ logger.info("Starting sequential model loading...")
521
  load_all_models()
 
522
  yield
 
523
  llm_manager.unload()
524
+ logger.info("Server shutdown complete")
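The lifespan handler above hooks model loading and unloading into FastAPI's startup and shutdown phases. The bare pattern, as a runnable sketch:

from contextlib import asynccontextmanager
from fastapi import FastAPI

@asynccontextmanager
async def lifespan(app: FastAPI):
    print("startup: load models here")    # runs once, before requests are served
    yield
    print("shutdown: unload models here")  # runs once, when the server stops

app = FastAPI(lifespan=lifespan)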
525
 
526
  # FastAPI App
527
  app = FastAPI(
528
+ title="Dhwani API",
529
+ description="AI Chat API supporting Indian languages",
530
  version="1.0.0",
531
  redirect_slashes=False,
532
  lifespan=lifespan
533
  )
534
 
535
+ # Add CORS Middleware
536
  app.add_middleware(
537
  CORSMiddleware,
538
  allow_origins=["*"],
 
541
  allow_headers=["*"],
542
  )
543
 
544
+ # Add Timing Middleware
545
  @app.middleware("http")
546
  async def add_request_timing(request: Request, call_next):
547
  start_time = time()
548
  response = await call_next(request)
549
+ end_time = time()
550
+ duration = end_time - start_time
551
  logger.info(f"Request to {request.url.path} took {duration:.3f} seconds")
552
  response.headers["X-Response-Time"] = f"{duration:.3f}"
553
  return response
 
555
  limiter = Limiter(key_func=get_remote_address)
556
  app.state.limiter = limiter
557
 
558
+ # API Endpoints
559
  @app.post("/audio/speech", response_class=StreamingResponse)
560
  async def synthesize_kannada(request: KannadaSynthesizeRequest):
561
  if not tts_manager.model:
 
563
  kannada_example = next(ex for ex in EXAMPLES if ex["audio_name"] == "KAN_F (Happy)")
564
  if not request.text.strip():
565
  raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")
566
+
567
+ audio_buffer = synthesize_speech(
568
+ tts_manager,
569
+ text=request.text,
570
+ ref_audio_name="KAN_F (Happy)",
571
+ ref_text=kannada_example["ref_text"]
572
+ )
573
+
574
  return StreamingResponse(
575
  audio_buffer,
576
  media_type="audio/wav",
 
579
 
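A client receives the synthesized Kannada speech as a streaming WAV body. A minimal call against a locally running instance (the host and port are assumptions, not part of the commit):

import requests

resp = requests.post(
    "http://localhost:8000/audio/speech",
    json={"text": "ನಮಸ್ಕಾರ"},  # KannadaSynthesizeRequest carries a single text field
    timeout=120,
)
resp.raise_for_status()
with open("speech.wav", "wb") as f:
    f.write(resp.content)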
580
  @app.post("/translate", response_model=TranslationResponse)
581
  async def translate(request: TranslationRequest, translate_manager: TranslateManager = Depends(get_translate_manager)):
582
+ input_sentences = request.sentences
583
+ src_lang = request.src_lang
584
+ tgt_lang = request.tgt_lang
585
+
586
+ if not input_sentences:
587
  raise HTTPException(status_code=400, detail="Input sentences are required")
588
+
589
+ batch = ip.preprocess_batch(input_sentences, src_lang=src_lang, tgt_lang=tgt_lang)
590
+ inputs = translate_manager.tokenizer(
591
+ batch,
592
+ truncation=True,
593
+ padding="longest",
594
+ return_tensors="pt",
595
+ return_attention_mask=True,
596
+ ).to(translate_manager.device_type)
597
+
598
+ with torch.no_grad():
599
+ generated_tokens = translate_manager.model.generate(
600
+ **inputs,
601
+ use_cache=True,
602
+ min_length=0,
603
+ max_length=256,
604
+ num_beams=5,
605
+ num_return_sequences=1,
606
+ )
607
+
608
  with translate_manager.tokenizer.as_target_tokenizer():
609
+ generated_tokens = translate_manager.tokenizer.batch_decode(
610
+ generated_tokens.detach().cpu().tolist(),
611
+ skip_special_tokens=True,
612
+ clean_up_tokenization_spaces=True,
613
+ )
614
+
615
+ translations = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
616
  return TranslationResponse(translations=translations)
617
 
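Because get_translate_manager is injected with Depends, FastAPI reads its src_lang and tgt_lang arguments from query parameters, while the sentences travel in the JSON body. A client sketch against a local instance (base URL assumed):

import requests

resp = requests.post(
    "http://localhost:8000/translate",
    params={"src_lang": "eng_Latn", "tgt_lang": "kan_Knda"},
    json={
        "sentences": ["Hello, how are you?"],
        "src_lang": "eng_Latn",
        "tgt_lang": "kan_Knda",
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["translations"])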
618
+ async def perform_internal_translation(sentences: List[str], src_lang: str, tgt_lang: str) -> List[str]:
619
+ try:
620
+ translate_manager = model_manager.get_model(src_lang, tgt_lang)
621
+ except ValueError as e:
622
+ logger.info(f"Model not preloaded: {str(e)}, loading now...")
623
+ key = model_manager._get_model_key(src_lang, tgt_lang)
624
+ model_manager.load_model(src_lang, tgt_lang, key)
625
+ translate_manager = model_manager.get_model(src_lang, tgt_lang)
626
+
627
+ if not translate_manager.model:
628
+ translate_manager.load()
629
+
630
+ request = TranslationRequest(sentences=sentences, src_lang=src_lang, tgt_lang=tgt_lang)
631
+ response = await translate(request, translate_manager)
632
+ return response.translations
633
+
634
  @app.get("/v1/health")
635
  async def health_check():
636
+ return {"status": "healthy", "model": settings.llm_model_name}
637
 
638
  @app.get("/")
639
  async def home():
 
644
  try:
645
  logger.info("Starting to unload all models...")
646
  llm_manager.unload()
647
  logger.info("All models unloaded successfully")
648
  return {"status": "success", "message": "All models unloaded"}
649
  except Exception as e:
 
655
  try:
656
  logger.info("Starting to load all models...")
657
  llm_manager.load()
658
  logger.info("All models loaded successfully")
659
  return {"status": "success", "message": "All models loaded"}
660
  except Exception as e:
 
  @app.post("/v1/translate", response_model=TranslationResponse)  # decorator elided in this diff; route assumed
  async def translate_endpoint(request: TranslationRequest):
      logger.info(f"Received translation request: {request.dict()}")
      try:
+         translations = await perform_internal_translation(
+             sentences=request.sentences,
+             src_lang=request.src_lang,
+             tgt_lang=request.tgt_lang
+         )
          logger.info(f"Translation successful: {translations}")
          return TranslationResponse(translations=translations)
      except Exception as e:
          # (error handling elided in this diff; reconstructed from the pattern used elsewhere)
          logger.error(f"Translation failed: {str(e)}")
          raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}")

  @app.post("/v1/chat", response_model=ChatResponse)
  @limiter.limit(settings.chat_rate_limit)
  async def chat(request: Request, chat_request: ChatRequest):
+     if not chat_request.prompt:
+         raise HTTPException(status_code=400, detail="Prompt cannot be empty")
+     logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, tgt_lang: {chat_request.tgt_lang}")
+
+     EUROPEAN_LANGUAGES = {"deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn", "por_Latn", "rus_Cyrl", "pol_Latn"}
+
+     try:
+         if chat_request.src_lang != "eng_Latn" and chat_request.src_lang not in EUROPEAN_LANGUAGES:
+             translated_prompt = await perform_internal_translation(
+                 sentences=[chat_request.prompt],
+                 src_lang=chat_request.src_lang,
+                 tgt_lang="eng_Latn"
+             )
+             prompt_to_process = translated_prompt[0]
+             logger.info(f"Translated prompt to English: {prompt_to_process}")
+         else:
+             prompt_to_process = chat_request.prompt
+             logger.info("Prompt in English or European language, no translation needed")
+
+         response = await llm_manager.generate(prompt_to_process, settings.max_tokens)
+         logger.info(f"Generated response: {response}")
+
+         if chat_request.tgt_lang != "eng_Latn" and chat_request.tgt_lang not in EUROPEAN_LANGUAGES:
+             translated_response = await perform_internal_translation(
+                 sentences=[response],
+                 src_lang="eng_Latn",
+                 tgt_lang=chat_request.tgt_lang
+             )
+             final_response = translated_response[0]
+             logger.info(f"Translated response to {chat_request.tgt_lang}: {final_response}")
+         else:
+             final_response = response
+             logger.info(f"Response in {chat_request.tgt_lang}, no translation needed")
+
+         return ChatResponse(response=final_response)
+     except Exception as e:
+         logger.error(f"Error processing request: {str(e)}")
+         raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
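
# Illustrative client sketch (host/port assumed). The endpoint accepts a JSON
# body matching ChatRequest and is rate limited per settings.chat_rate_limit:
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/v1/chat",
#       json={"prompt": "ಕರ್ನಾಟಕದ ರಾಜಧಾನಿ ಯಾವುದು?", "src_lang": "kan_Knda", "tgt_lang": "kan_Knda"},
#   )
#   print(resp.json()["response"])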
 
  @app.post("/v1/visual_query/")
  async def visual_query(
      file: UploadFile = File(...),   # parameters elided in this diff; file/query assumed from the body below
      query: str = Form(...),
      src_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
      tgt_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
  ):
+     try:
+         image = Image.open(file.file)
+         if image.size == (0, 0):
+             raise HTTPException(status_code=400, detail="Uploaded image is empty or invalid")
+
+         if src_lang != "eng_Latn":
+             translated_query = await perform_internal_translation(
+                 sentences=[query],
+                 src_lang=src_lang,
+                 tgt_lang="eng_Latn"
+             )
+             query_to_process = translated_query[0]
+             logger.info(f"Translated query to English: {query_to_process}")
+         else:
+             query_to_process = query
+             logger.info("Query already in English, no translation needed")
+
+         answer = await llm_manager.vision_query(image, query_to_process)
+         logger.info(f"Generated English answer: {answer}")
+
+         if tgt_lang != "eng_Latn":
+             translated_answer = await perform_internal_translation(
+                 sentences=[answer],
+                 src_lang="eng_Latn",
+                 tgt_lang=tgt_lang
+             )
+             final_answer = translated_answer[0]
+             logger.info(f"Translated answer to {tgt_lang}: {final_answer}")
+         else:
+             final_answer = answer
+             logger.info("Answer kept in English, no translation needed")
+
+         return {"answer": final_answer}
+     except Exception as e:
+         logger.error(f"Error processing request: {str(e)}")
+         raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
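
# Illustrative client sketch (host/port assumed; the `file`/`query` form fields
# follow the parameters assumed in the signature above):
#
#   import requests
#   with open("photo.jpg", "rb") as f:
#       resp = requests.post(
#           "http://localhost:8000/v1/visual_query/",
#           files={"file": f},
#           data={"query": "What is in this image?"},
#           params={"src_lang": "eng_Latn", "tgt_lang": "kan_Knda"},
#       )
#   print(resp.json()["answer"])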
 
  @app.post("/v1/chat_v2", response_model=ChatResponse)
  @limiter.limit(settings.chat_rate_limit)
  async def chat_v2(
      request: Request,                        # signature elided in this diff; these parameters
      prompt: str = Form(...),                 # are assumed from the handler body below
      image: UploadFile = File(default=None),
      src_lang: str = Form("kan_Knda"),
      tgt_lang: str = Form("kan_Knda"),
  ):
+     if not prompt:
+         raise HTTPException(status_code=400, detail="Prompt cannot be empty")
+     if src_lang not in SUPPORTED_LANGUAGES or tgt_lang not in SUPPORTED_LANGUAGES:
+         raise HTTPException(status_code=400, detail=f"Unsupported language code. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
+
+     logger.info(f"Received prompt: {prompt}, src_lang: {src_lang}, tgt_lang: {tgt_lang}, Image provided: {image is not None}")
+
+     try:
+         if image:
+             image_data = await image.read()
+             if not image_data:
+                 raise HTTPException(status_code=400, detail="Uploaded image is empty")
+             img = Image.open(io.BytesIO(image_data))
+
+             if src_lang != "eng_Latn":
+                 translated_prompt = await perform_internal_translation(
+                     sentences=[prompt],
+                     src_lang=src_lang,
+                     tgt_lang="eng_Latn"
+                 )
+                 prompt_to_process = translated_prompt[0]
+                 logger.info(f"Translated prompt to English: {prompt_to_process}")
+             else:
+                 prompt_to_process = prompt
+                 logger.info("Prompt already in English, no translation needed")
+
+             decoded = await llm_manager.chat_v2(img, prompt_to_process)
+             logger.info(f"Generated English response: {decoded}")
+
+             if tgt_lang != "eng_Latn":
+                 translated_response = await perform_internal_translation(
+                     sentences=[decoded],
+                     src_lang="eng_Latn",
+                     tgt_lang=tgt_lang
+                 )
+                 final_response = translated_response[0]
+                 logger.info(f"Translated response to {tgt_lang}: {final_response}")
+             else:
+                 final_response = decoded
+                 logger.info("Response kept in English, no translation needed")
+         else:
+             if src_lang != "eng_Latn":
+                 translated_prompt = await perform_internal_translation(
+                     sentences=[prompt],
+                     src_lang=src_lang,
+                     tgt_lang="eng_Latn"
+                 )
+                 prompt_to_process = translated_prompt[0]
+                 logger.info(f"Translated prompt to English: {prompt_to_process}")
+             else:
+                 prompt_to_process = prompt
+                 logger.info("Prompt already in English, no translation needed")
+
+             decoded = await llm_manager.generate(prompt_to_process, settings.max_tokens)
+             logger.info(f"Generated English response: {decoded}")
+
+             if tgt_lang != "eng_Latn":
+                 translated_response = await perform_internal_translation(
+                     sentences=[decoded],
+                     src_lang="eng_Latn",
+                     tgt_lang=tgt_lang
+                 )
+                 final_response = translated_response[0]
+                 logger.info(f"Translated response to {tgt_lang}: {final_response}")
+             else:
+                 final_response = decoded
+                 logger.info("Response kept in English, no translation needed")
+
+         return ChatResponse(response=final_response)
+     except Exception as e:
+         logger.error(f"Error processing request: {str(e)}")
+         raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
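
# Illustrative client sketch (host/port assumed; the form field names follow the
# signature assumed above). The image part is optional:
#
#   import requests
#   with open("photo.jpg", "rb") as f:
#       resp = requests.post(
#           "http://localhost:8000/v1/chat_v2",
#           data={"prompt": "Describe this image", "src_lang": "eng_Latn", "tgt_lang": "eng_Latn"},
#           files={"image": f},
#       )
#   print(resp.json()["response"])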
 
  @app.post("/transcribe/", response_model=TranscriptionResponse)
  async def transcribe_audio(file: UploadFile = File(...), language: str = Query(..., enum=list(asr_manager.model_language.keys()))):
+     if not asr_manager.model:
+         raise HTTPException(status_code=503, detail="ASR model not loaded")
+     try:
+         wav, sr = torchaudio.load(file.file)
+         wav = torch.mean(wav, dim=0, keepdim=True)
+         target_sample_rate = 16000
+         if sr != target_sample_rate:
+             resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
+             wav = resampler(wav)
+         transcription_rnnt = asr_manager.model(wav, asr_manager.model_language[language], "rnnt")
+         return TranscriptionResponse(text=transcription_rnnt)
+     except Exception as e:
+         logger.error(f"Error in transcription: {str(e)}")
+         raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
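
# Illustrative client sketch (host/port assumed; `language` must be one of the
# keys registered in asr_manager.model_language, e.g. "kannada"):
#
#   import requests
#   with open("speech.wav", "rb") as f:
#       resp = requests.post(
#           "http://localhost:8000/transcribe/",
#           params={"language": "kannada"},
#           files={"file": f},
#       )
#   print(resp.json()["text"])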
 
 
  @app.post("/v1/speech_to_speech")
  async def speech_to_speech(
      request: Request,   # parameter elided in this diff; assumed, since chat() below requires it
      file: UploadFile = File(...),
      language: str = Query(..., enum=list(asr_manager.model_language.keys())),
  ) -> StreamingResponse:
+     if not tts_manager.model:
+         raise HTTPException(status_code=503, detail="TTS model not loaded")
+     transcription = await transcribe_audio(file, language)
+     logger.info(f"Transcribed text: {transcription.text}")
+
+     chat_request = ChatRequest(
+         prompt=transcription.text,
+         src_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda"),
+         tgt_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda")
+     )
+     processed_text = await chat(request, chat_request)
+     logger.info(f"Processed text: {processed_text.response}")
+
+     voice_request = KannadaSynthesizeRequest(text=processed_text.response)
+     audio_response = await synthesize_kannada(voice_request)
+     return audio_response
+
+ # Module-level mapping; defined after the endpoint above, but resolved at call time, so this is safe
+ LANGUAGE_TO_SCRIPT = {
+     "kannada": "kan_Knda"
+ }
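
# Illustrative client sketch (host/port assumed): one call runs the full
# ASR -> chat -> TTS pipeline and streams back a WAV answer:
#
#   import requests
#   with open("question.wav", "rb") as f:
#       resp = requests.post(
#           "http://localhost:8000/v1/speech_to_speech",
#           params={"language": "kannada"},
#           files={"file": f},
#       )
#   with open("answer.wav", "wb") as f:
#       f.write(resp.content)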
+
+ # Main Execution
  if __name__ == "__main__":
      parser = argparse.ArgumentParser(description="Run the FastAPI server.")
      parser.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")
      # (remaining CLI arguments and config selection are elided in this diff)
      settings.speech_rate_limit = global_settings["speech_rate_limit"]

      llm_manager = LLMManager(settings.llm_model_name)
+
      if selected_config["components"]["ASR"]:
+         asr_model_name = selected_config["components"]["ASR"]["model"]
          asr_manager.model_language[selected_config["language"]] = selected_config["components"]["ASR"]["language_code"]
+
      if selected_config["components"]["Translation"]:
          translation_configs.extend(selected_config["components"]["Translation"])

      host = args.host if args.host != settings.host else settings.host
      port = args.port if args.port != settings.port else settings.port

+     uvicorn.run(app, host=host, port=port)
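
# Example launch (illustrative; a --host flag is assumed since args.host is read
# above, and further flags are elided in this diff):
#
#   python src/server/main.py --host 0.0.0.0 --port 8000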