sachin committed
Commit 120ac19 · 1 Parent(s): ff84a71
Files changed (2)
  1. src/server/main-v2.py +8 -3
  2. src/server/main.py +467 -492
src/server/main-v2.py CHANGED
@@ -105,8 +105,11 @@ class LLMManager:
                 device_map="auto",
                 quantization_config=quantization_config,
                 torch_dtype=self.torch_dtype
-            ).eval()
-            self.processor = AutoProcessor.from_pretrained(self.model_name)
+            )
+            if self.model is None:
+                raise ValueError(f"Failed to load model {self.model_name}: Model object is None")
+            self.model.eval()
+            self.processor = AutoProcessor.from_pretrained(self.model_name, use_fast=True)
             dummy_input = self.processor("test", return_tensors="pt").to(self.device)
             with torch.no_grad():
                 self.model.generate(**dummy_input, max_new_tokens=10)
@@ -115,6 +118,7 @@ class LLMManager:
         except Exception as e:
             logger.error(f"Failed to load LLM: {str(e)}")
             self.is_loaded = False
+            raise  # Re-raise to ensure failure is caught upstream

     def unload(self):
         if self.is_loaded:
@@ -896,4 +900,5 @@ if __name__ == "__main__":
     host = args.host if args.host != settings.host else settings.host
     port = args.port if args.port != settings.port else settings.port

-    uvicorn.run(app, host=host, port=port, workers=2)
+    # Run Uvicorn with import string to support workers
+    uvicorn.run("main:app", host=host, port=port, workers=2)
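A note on the workers change above: uvicorn can only fork multiple worker processes when it is handed an import string, because each worker must re-import the application module rather than inherit a live object; passing the app object directly with workers=2 falls back to a single process (uvicorn logs a warning). A minimal, hypothetical sketch of the pattern (the module name server is illustrative, not from this commit):

import uvicorn
from fastapi import FastAPI

app = FastAPI()

@app.get("/ping")
async def ping():
    return {"status": "ok"}

if __name__ == "__main__":
    # An import string ("module:attribute") lets uvicorn spawn workers,
    # each of which re-imports this module and builds its own app.
    uvicorn.run("server:app", host="0.0.0.0", port=8000, workers=2)

Since every worker re-imports the module, any models loaded at import time are duplicated per worker, which matters for GPU-resident models like the ones in this repo.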
src/server/main.py CHANGED
@@ -2,7 +2,7 @@ import argparse
 import io
 import os
 from time import time
-from typing import List
+from typing import List, Dict
 import tempfile
 import uvicorn
 from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, UploadFile, Body, Form
@@ -22,25 +22,23 @@ from contextlib import asynccontextmanager
 import soundfile as sf
 import numpy as np
 import requests
+import logging
 from starlette.responses import StreamingResponse
-from logging_config import logger
-from tts_config import SPEED, ResponseFormat, config as tts_config
+from logging_config import logger  # Assumed external logging config
+from tts_config import SPEED, ResponseFormat, config as tts_config  # Assumed external TTS config
 import torchaudio
+from tenacity import retry, stop_after_attempt, wait_exponential
+from torch.cuda.amp import autocast

 # Device setup
-if torch.cuda.is_available():
-    device = "cuda:0"
-    logger.info("GPU will be used for inference")
-else:
-    device = "cpu"
-    logger.info("CPU will be used for inference")
-torch_dtype = torch.bfloat16 if device != "cpu" else torch.float32
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
+torch_dtype = torch.float16 if device != "cpu" else torch.float32
+logger.info(f"Using device: {device} with dtype: {torch_dtype}")

 # Check CUDA availability and version
 cuda_available = torch.cuda.is_available()
 cuda_version = torch.version.cuda if cuda_available else None
-
-if torch.cuda.is_available():
+if cuda_available:
     device_idx = torch.cuda.current_device()
     capability = torch.cuda.get_device_capability(device_idx)
     compute_capability_float = float(f"{capability[0]}.{capability[1]}")
@@ -77,33 +75,50 @@ quantization_config = BitsAndBytesConfig(
     bnb_4bit_compute_dtype=torch.bfloat16
 )

-# LLM Manager
+# Request queue for concurrency control
+request_queue = asyncio.Queue(maxsize=10)
+
+# Logging optimization
+logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
+
+# LLM Manager with batching
 class LLMManager:
-    def __init__(self, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
+    def __init__(self, model_name: str, device: str = device):
         self.model_name = model_name
         self.device = torch.device(device)
-        self.torch_dtype = torch.bfloat16 if self.device.type != "cpu" else torch.float32
+        self.torch_dtype = torch.float16 if self.device.type != "cpu" else torch.float32
         self.model = None
         self.processor = None
         self.is_loaded = False
+        self.token_cache = {}
+        self.load()
         logger.info(f"LLMManager initialized with model {model_name} on {self.device}")

     def load(self):
         if not self.is_loaded:
             try:
+                if self.device.type == "cuda":
+                    torch.set_float32_matmul_precision('high')
+                    logger.info("Enabled TF32 matrix multiplication for improved GPU performance")
                 self.model = Gemma3ForConditionalGeneration.from_pretrained(
                     self.model_name,
                     device_map="auto",
                     quantization_config=quantization_config,
                     torch_dtype=self.torch_dtype
                 )
+                if self.model is None:
+                    raise ValueError(f"Failed to load model {self.model_name}: Model object is None")
                 self.model.eval()
-                self.processor = AutoProcessor.from_pretrained(self.model_name)
+                self.processor = AutoProcessor.from_pretrained(self.model_name, use_fast=True)
+                dummy_input = self.processor("test", return_tensors="pt").to(self.device)
+                with torch.no_grad():
+                    self.model.generate(**dummy_input, max_new_tokens=10)
                 self.is_loaded = True
-                logger.info(f"LLM {self.model_name} loaded on {self.device}")
+                logger.info(f"LLM {self.model_name} loaded and warmed up on {self.device}")
             except Exception as e:
                 logger.error(f"Failed to load LLM: {str(e)}")
-                raise
+                self.is_loaded = False
+                raise  # Re-raise to ensure failure is caught upstream

     def unload(self):
         if self.is_loaded:
@@ -111,74 +126,72 @@ class LLMManager:
             del self.processor
             if self.device.type == "cuda":
                 torch.cuda.empty_cache()
-                logger.info(f"GPU memory allocated after unload: {torch.cuda.memory_allocated()}")
+                logger.info(f"GPU memory cleared: {torch.cuda.memory_allocated()} bytes allocated")
             self.is_loaded = False
-            logger.info(f"LLM {self.model_name} unloaded from {self.device}")
+            self.token_cache.clear()
+            logger.info(f"LLM {self.model_name} unloaded")

-    async def generate(self, prompt: str, max_tokens: int = 512, temperature: float = 0.7) -> str:
+    async def generate(self, prompt: str, max_tokens: int = settings.max_tokens, temperature: float = 0.7) -> str:
         if not self.is_loaded:
+            logger.warning("LLM not loaded; attempting reload")
             self.load()
+            if not self.is_loaded:
+                raise HTTPException(status_code=503, detail="LLM model unavailable")

-        messages_vlm = [
-            {
-                "role": "system",
-                "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state. Provide a concise response in one sentence maximum."}]
-            },
-            {
-                "role": "user",
-                "content": [{"type": "text", "text": prompt}]
-            }
-        ]
+        cache_key = f"{prompt}:{max_tokens}:{temperature}"
+        if cache_key in self.token_cache:
+            logger.info("Using cached response")
+            return self.token_cache[cache_key]

+        future = asyncio.Future()
+        await request_queue.put({"prompt": prompt, "max_tokens": max_tokens, "temperature": temperature, "future": future})
+        response = await future
+        self.token_cache[cache_key] = response
+        logger.info(f"Generated response: {response}")
+        return response
+
+    async def batch_generate(self, prompts: List[Dict]) -> List[str]:
+        messages_batch = [
+            [
+                {"role": "system", "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state. Provide a concise response in one sentence maximum."}]},
+                {"role": "user", "content": [{"type": "text", "text": prompt["prompt"]}]}
+            ]
+            for prompt in prompts
+        ]
         try:
             inputs_vlm = self.processor.apply_chat_template(
-                messages_vlm,
+                messages_batch,
                 add_generation_prompt=True,
                 tokenize=True,
                 return_dict=True,
-                return_tensors="pt"
+                return_tensors="pt",
+                padding=True
             ).to(self.device, dtype=torch.bfloat16)
+            with autocast(), torch.no_grad():
+                outputs = self.model.generate(
+                    **inputs_vlm,
+                    max_new_tokens=max(prompt["max_tokens"] for prompt in prompts),
+                    do_sample=True,
+                    top_p=0.9,
+                    temperature=max(prompt["temperature"] for prompt in prompts)
+                )
+            input_len = inputs_vlm["input_ids"].shape[1]  # padded prompt length, shared across the batch
+            responses = [
+                self.processor.decode(output[input_len:], skip_special_tokens=True)
+                for output in outputs
+            ]
+            logger.info(f"Batch generated {len(responses)} responses")
+            return responses
         except Exception as e:
-            logger.error(f"Error in tokenization: {str(e)}")
-            raise HTTPException(status_code=500, detail=f"Tokenization failed: {str(e)}")
-
-        input_len = inputs_vlm["input_ids"].shape[-1]
-
-        with torch.inference_mode():
-            generation = self.model.generate(
-                **inputs_vlm,
-                max_new_tokens=max_tokens,
-                do_sample=True,
-                temperature=temperature
-            )
-            generation = generation[0][input_len:]
-
-        response = self.processor.decode(generation, skip_special_tokens=True)
-        logger.info(f"Generated response: {response}")
-        return response
+            logger.error(f"Error in batch generation: {str(e)}")
+            raise HTTPException(status_code=500, detail=f"Batch generation failed: {str(e)}")

     async def vision_query(self, image: Image.Image, query: str) -> str:
         if not self.is_loaded:
             self.load()
-
         messages_vlm = [
-            {
-                "role": "system",
-                "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Summarize your answer in maximum 1 sentence."}]
-            },
-            {
-                "role": "user",
-                "content": []
-            }
+            {"role": "system", "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Summarize your answer in maximum 1 sentence."}]},
+            {"role": "user", "content": [{"type": "text", "text": query}] + ([{"type": "image", "image": image}] if image and image.size[0] > 0 and image.size[1] > 0 else [])}
         ]
-
-        messages_vlm[1]["content"].append({"type": "text", "text": query})
-        if image and image.size[0] > 0 and image.size[1] > 0:
-            messages_vlm[1]["content"].insert(0, {"type": "image", "image": image})
-            logger.info(f"Received valid image for processing")
-        else:
-            logger.info("No valid image provided, processing text only")
-
         try:
             inputs_vlm = self.processor.apply_chat_template(
                 messages_vlm,
@@ -190,18 +203,10 @@ class LLMManager:
         except Exception as e:
             logger.error(f"Error in apply_chat_template: {str(e)}")
             raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
-
         input_len = inputs_vlm["input_ids"].shape[-1]
-
         with torch.inference_mode():
-            generation = self.model.generate(
-                **inputs_vlm,
-                max_new_tokens=512,
-                do_sample=True,
-                temperature=0.7
-            )
+            generation = self.model.generate(**inputs_vlm, max_new_tokens=512, do_sample=True, temperature=0.7)
             generation = generation[0][input_len:]
-
         decoded = self.processor.decode(generation, skip_special_tokens=True)
         logger.info(f"Vision query response: {decoded}")
         return decoded
@@ -209,25 +214,10 @@ class LLMManager:
     async def chat_v2(self, image: Image.Image, query: str) -> str:
         if not self.is_loaded:
             self.load()
-
         messages_vlm = [
-            {
-                "role": "system",
-                "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state."}]
-            },
-            {
-                "role": "user",
-                "content": []
-            }
+            {"role": "system", "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state."}]},
+            {"role": "user", "content": [{"type": "text", "text": query}] + ([{"type": "image", "image": image}] if image and image.size[0] > 0 and image.size[1] > 0 else [])}
         ]
-
-        messages_vlm[1]["content"].append({"type": "text", "text": query})
-        if image and image.size[0] > 0 and image.size[1] > 0:
-            messages_vlm[1]["content"].insert(0, {"type": "image", "image": image})
-            logger.info(f"Received valid image for processing")
-        else:
-            logger.info("No valid image provided, processing text only")
-
         try:
             inputs_vlm = self.processor.apply_chat_template(
                 messages_vlm,
@@ -239,18 +229,10 @@ class LLMManager:
         except Exception as e:
             logger.error(f"Error in apply_chat_template: {str(e)}")
             raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
-
         input_len = inputs_vlm["input_ids"].shape[-1]
-
         with torch.inference_mode():
-            generation = self.model.generate(
-                **inputs_vlm,
-                max_new_tokens=512,
-                do_sample=True,
-                temperature=0.7
-            )
+            generation = self.model.generate(**inputs_vlm, max_new_tokens=512, do_sample=True, temperature=0.7)
             generation = generation[0][input_len:]
-
         decoded = self.processor.decode(generation, skip_special_tokens=True)
         logger.info(f"Chat_v2 response: {decoded}")
         return decoded
@@ -258,101 +240,42 @@ class LLMManager:
 # TTS Manager
 class TTSManager:
     def __init__(self, device_type=device):
-        self.device_type = device_type
+        self.device_type = torch.device(device_type)
         self.model = None
         self.repo_id = "ai4bharat/IndicF5"
+        self.load()

     def load(self):
         if not self.model:
-            logger.info("Loading TTS model IndicF5...")
-            self.model = AutoModel.from_pretrained(
-                self.repo_id,
-                trust_remote_code=True
-            )
-            self.model = self.model.to(self.device_type)
-            logger.info("TTS model IndicF5 loaded")
+            logger.info(f"Loading TTS model {self.repo_id} on {self.device_type}...")
+            self.model = AutoModel.from_pretrained(self.repo_id, trust_remote_code=True).to(self.device_type)
+            logger.info("TTS model loaded")
+
+    def unload(self):
+        if self.model:
+            del self.model
+            if self.device_type.type == "cuda":
+                torch.cuda.empty_cache()
+                logger.info(f"TTS GPU memory cleared: {torch.cuda.memory_allocated()} bytes allocated")
+            self.model = None
+            logger.info("TTS model unloaded")

     def synthesize(self, text, ref_audio_path, ref_text):
         if not self.model:
             raise ValueError("TTS model not loaded")
-        return self.model(text, ref_audio_path=ref_audio_path, ref_text=ref_text)
+        with autocast():
+            return self.model(text, ref_audio_path=ref_audio_path, ref_text=ref_text)
-
-# TTS Constants
-EXAMPLES = [
-    {
-        "audio_name": "KAN_F (Happy)",
-        "audio_url": "https://github.com/AI4Bharat/IndicF5/raw/refs/heads/main/prompts/KAN_F_HAPPY_00001.wav",
-        "ref_text": "ನಮ್‌ ಫ್ರಿಜ್ಜಲ್ಲಿ ಕೂಲಿಂಗ್‌ ಸಮಸ್ಯೆ ಆಗಿ ನಾನ್‌ ಭಾಳ ದಿನದಿಂದ ಒದ್ದಾಡ್ತಿದ್ದೆ, ಆದ್ರೆ ಅದ್ನೀಗ ಮೆಕಾನಿಕ್ ಆಗಿರೋ ನಿಮ್‌ ಸಹಾಯ್ದಿಂದ ಬಗೆಹರಿಸ್ಕೋಬೋದು ಅಂತಾಗಿ ನಿರಾಳ ಆಯ್ತು ನಂಗೆ.",
-        "synth_text": "ಚೆನ್ನೈನ ಶೇರ್ ಆಟೋ ಪ್ರಯಾಣಿಕರ ನಡುವೆ ಆಹಾರವನ್ನು ಹಂಚಿಕೊಂಡು ತಿನ್ನುವುದು ನನಗೆ ಮನಸ್ಸಿಗೆ ತುಂಬಾ ಒಳ್ಳೆಯದೆನಿಸುವ ವಿಷಯ."
-    },
-]
-
-# Pydantic models for TTS
-class SynthesizeRequest(BaseModel):
-    text: str
-    ref_audio_name: str
-    ref_text: str = None
-
-class KannadaSynthesizeRequest(BaseModel):
-    text: str
-
-# TTS Functions
-def load_audio_from_url(url: str):
-    response = requests.get(url)
-    if response.status_code == 200:
-        audio_data, sample_rate = sf.read(io.BytesIO(response.content))
-        return sample_rate, audio_data
-    raise HTTPException(status_code=500, detail="Failed to load reference audio from URL.")
-
-def synthesize_speech(tts_manager: TTSManager, text: str, ref_audio_name: str, ref_text: str):
-    ref_audio_url = None
-    for example in EXAMPLES:
-        if example["audio_name"] == ref_audio_name:
-            ref_audio_url = example["audio_url"]
-            if not ref_text:
-                ref_text = example["ref_text"]
-            break
-
-    if not ref_audio_url:
-        raise HTTPException(status_code=400, detail="Invalid reference audio name.")
-    if not text.strip():
-        raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")
-    if not ref_text or not ref_text.strip():
-        raise HTTPException(status_code=400, detail="Reference text cannot be empty.")
-
-    sample_rate, audio_data = load_audio_from_url(ref_audio_url)
-    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
-        sf.write(temp_audio.name, audio_data, samplerate=sample_rate, format='WAV')
-        temp_audio.flush()
-        audio = tts_manager.synthesize(text, ref_audio_path=temp_audio.name, ref_text=ref_text)
-
-    if audio.dtype == np.int16:
-        audio = audio.astype(np.float32) / 32768.0
-    buffer = io.BytesIO()
-    sf.write(buffer, audio, 24000, format='WAV')
-    buffer.seek(0)
-    return buffer
-
-# Supported languages
-SUPPORTED_LANGUAGES = {
-    "asm_Beng", "kas_Arab", "pan_Guru", "ben_Beng", "kas_Deva", "san_Deva",
-    "brx_Deva", "mai_Deva", "sat_Olck", "doi_Deva", "mal_Mlym", "snd_Arab",
-    "eng_Latn", "mar_Deva", "snd_Deva", "gom_Deva", "mni_Beng", "tam_Taml",
-    "guj_Gujr", "mni_Mtei", "tel_Telu", "hin_Deva", "npi_Deva", "urd_Arab",
-    "kan_Knda", "ory_Orya",
-    "deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn",
-    "por_Latn", "rus_Cyrl", "pol_Latn"
-}

 # Translation Manager
 class TranslateManager:
     def __init__(self, src_lang, tgt_lang, device_type=device, use_distilled=True):
-        self.device_type = device_type
+        self.device_type = torch.device(device_type)
         self.tokenizer = None
         self.model = None
         self.src_lang = src_lang
         self.tgt_lang = tgt_lang
         self.use_distilled = use_distilled
+        self.load()

     def load(self):
         if not self.tokenizer or not self.model:
@@ -364,21 +287,17 @@ class TranslateManager:
                 model_name = "ai4bharat/indictrans2-indic-indic-dist-320M" if self.use_distilled else "ai4bharat/indictrans2-indic-indic-1B"
             else:
                 raise ValueError("Invalid language combination")
-
-            self.tokenizer = AutoTokenizer.from_pretrained(
-                model_name,
-                trust_remote_code=True
-            )
+            self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
             self.model = AutoModelForSeq2SeqLM.from_pretrained(
                 model_name,
                 trust_remote_code=True,
                 torch_dtype=torch.float16,
                 attn_implementation="flash_attention_2"
-            )
-            self.model = self.model.to(self.device_type)
+            ).to(self.device_type)
             self.model = torch.compile(self.model, mode="reduce-overhead")
             logger.info(f"Translation model {model_name} loaded")

+# Model Manager
 class ModelManager:
     def __init__(self, device_type=device, use_distilled=True, is_lazy_loading=False):
         self.models = {}
@@ -389,18 +308,14 @@ class ModelManager:
     def load_model(self, src_lang, tgt_lang, key):
         logger.info(f"Loading translation model for {src_lang} -> {tgt_lang}")
         translate_manager = TranslateManager(src_lang, tgt_lang, self.device_type, self.use_distilled)
-        translate_manager.load()
         self.models[key] = translate_manager
         logger.info(f"Loaded translation model for {key}")

     def get_model(self, src_lang, tgt_lang):
         key = self._get_model_key(src_lang, tgt_lang)
-        if key not in self.models:
-            if self.is_lazy_loading:
-                self.load_model(src_lang, tgt_lang, key)
-            else:
-                raise ValueError(f"Model for {key} is not preloaded and lazy loading is disabled.")
-        return self.models.get(key)
+        if key not in self.models and self.is_lazy_loading:
+            self.load_model(src_lang, tgt_lang, key)
+        return self.models.get(key) or (self.load_model(src_lang, tgt_lang, key) or self.models[key])

     def _get_model_key(self, src_lang, tgt_lang):
         if src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
@@ -413,21 +328,30 @@ class ModelManager:

 # ASR Manager
 class ASRModelManager:
-    def __init__(self, device_type="cuda"):
-        self.device_type = device_type
+    def __init__(self, device_type=device):
+        self.device_type = torch.device(device_type)
         self.model = None
         self.model_language = {"kannada": "kn"}
+        self.load()

     def load(self):
         if not self.model:
-            logger.info("Loading ASR model...")
+            logger.info(f"Loading ASR model on {self.device_type}...")
             self.model = AutoModel.from_pretrained(
                 "ai4bharat/indic-conformer-600m-multilingual",
                 trust_remote_code=True
-            )
-            self.model = self.model.to(self.device_type)
+            ).to(self.device_type)
             logger.info("ASR model loaded")

+    def unload(self):
+        if self.model:
+            del self.model
+            if self.device_type.type == "cuda":
+                torch.cuda.empty_cache()
+                logger.info(f"ASR GPU memory cleared: {torch.cuda.memory_allocated()} bytes allocated")
+            self.model = None
+            logger.info("ASR model unloaded")
+
 # Global Managers
 llm_manager = LLMManager(settings.llm_model_name)
 model_manager = ModelManager()
@@ -435,7 +359,31 @@ asr_manager = ASRModelManager()
 tts_manager = TTSManager()
 ip = IndicProcessor(inference=True)

+# TTS Constants
+EXAMPLES = [
+    {
+        "audio_name": "KAN_F (Happy)",
+        "audio_url": "https://github.com/AI4Bharat/IndicF5/raw/refs/heads/main/prompts/KAN_F_HAPPY_00001.wav",
+        "ref_text": "ನಮ್‌ ಫ್ರಿಜ್ಜಲ್ಲಿ ಕೂಲಿಂಗ್‌ ಸಮಸ್ಯೆ ಆಗಿ ನಾನ್‌ ಭಾಳ ದಿನದಿಂದ ಒದ್ದಾಡ್ತಿದ್ದೆ, ಆದ್ರೆ ಅದ್ನೀಗ ಮೆಕಾನಿಕ್ ಆಗಿರೋ ನಿಮ್‌ ಸಹಾಯ್ದಿಂದ ಬಗೆಹರಿಸ್ಕೋಬೋದು ಅಂತಾಗಿ ನಿರಾಳ ಆಯ್ತು ನಂಗೆ।",
+        "synth_text": "ಚೆನ್ನೈನ ಶೇರ್ ಆಟೋ ಪ್ರಯಾಣಿಕರ ನಡುವೆ ಆಹಾರವನ್ನು ಹಂಚಿಕೊಂಡು ತಿನ್ನುವುದು ನನಗೆ ಮನಸ್ಸಿಗೆ ತುಂಬಾ ಒಳ್ಳೆಯದೆನಿಸುವ ವಿಷಯ."
+    },
+]
+
 # Pydantic Models
+class SynthesizeRequest(BaseModel):
+    text: str
+    ref_audio_name: str
+    ref_text: str = None
+
+class KannadaSynthesizeRequest(BaseModel):
+    text: str
+
+    @field_validator("text")
+    def text_must_be_valid(cls, v):
+        if len(v) > 500:
+            raise ValueError("Text cannot exceed 500 characters")
+        return v.strip()
+
 class ChatRequest(BaseModel):
     prompt: str
     src_lang: str = "kan_Knda"
@@ -453,7 +401,6 @@ class ChatRequest(BaseModel):
             raise ValueError(f"Unsupported language code: {v}. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
         return v

-
 class ChatResponse(BaseModel):
     response: str

@@ -468,71 +415,149 @@ class TranscriptionResponse(BaseModel):
 class TranslationResponse(BaseModel):
     translations: List[str]

+# TTS Functions
+@retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=10))
+def load_audio_from_url(url: str):
+    response = requests.get(url)
+    if response.status_code == 200:
+        audio_data, sample_rate = sf.read(io.BytesIO(response.content))
+        return sample_rate, audio_data
+    raise HTTPException(status_code=500, detail="Failed to load reference audio from URL after retries")
+
+async def synthesize_speech(tts_manager: TTSManager, text: str, ref_audio_name: str, ref_text: str) -> io.BytesIO:
+    async with request_queue:
+        ref_audio_url = next((ex["audio_url"] for ex in EXAMPLES if ex["audio_name"] == ref_audio_name), None)
+        if not ref_audio_url:
+            raise HTTPException(status_code=400, detail="Invalid reference audio name.")
+        if not text.strip() or not ref_text.strip():
+            raise HTTPException(status_code=400, detail="Text or reference text cannot be empty.")
+
+        logger.info(f"Synthesizing speech for text: {text[:50]}... with ref_audio: {ref_audio_name}")
+        loop = asyncio.get_running_loop()
+        sample_rate, audio_data = await loop.run_in_executor(None, load_audio_from_url, ref_audio_url)
+
+        with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_audio:
+            await loop.run_in_executor(None, sf.write, temp_audio.name, audio_data, sample_rate, "WAV")
+            temp_audio.flush()
+            audio = tts_manager.synthesize(text, temp_audio.name, ref_text)
+
+        buffer = io.BytesIO()
+        await loop.run_in_executor(None, sf.write, buffer, audio.astype(np.float32) / 32768.0 if audio.dtype == np.int16 else audio, 24000, "WAV")
+        buffer.seek(0)
+        logger.info("Speech synthesis completed")
+        return buffer
+
+# Supported Languages
+SUPPORTED_LANGUAGES = {
+    "asm_Beng", "kas_Arab", "pan_Guru", "ben_Beng", "kas_Deva", "san_Deva",
+    "brx_Deva", "mai_Deva", "sat_Olck", "doi_Deva", "mal_Mlym", "snd_Arab",
+    "eng_Latn", "mar_Deva", "snd_Deva", "gom_Deva", "mni_Beng", "tam_Taml",
+    "guj_Gujr", "mni_Mtei", "tel_Telu", "hin_Deva", "npi_Deva", "urd_Arab",
+    "kan_Knda", "ory_Orya",
+    "deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn",
+    "por_Latn", "rus_Cyrl", "pol_Latn"
+}
+
 # Dependency
 def get_translate_manager(src_lang: str, tgt_lang: str) -> TranslateManager:
     return model_manager.get_model(src_lang, tgt_lang)

+# Translation Function
+async def perform_internal_translation(sentences: List[str], src_lang: str, tgt_lang: str) -> List[str]:
+    try:
+        translate_manager = model_manager.get_model(src_lang, tgt_lang)
+    except ValueError as e:
+        logger.info(f"Model not preloaded: {str(e)}, loading now...")
+        key = model_manager._get_model_key(src_lang, tgt_lang)
+        model_manager.load_model(src_lang, tgt_lang, key)
+        translate_manager = model_manager.get_model(src_lang, tgt_lang)
+
+    batch = ip.preprocess_batch(sentences, src_lang=src_lang, tgt_lang=tgt_lang)
+    inputs = translate_manager.tokenizer(batch, truncation=True, padding="longest", return_tensors="pt", return_attention_mask=True).to(translate_manager.device_type)
+    with torch.no_grad(), autocast():
+        generated_tokens = translate_manager.model.generate(**inputs, use_cache=True, min_length=0, max_length=256, num_beams=5, num_return_sequences=1)
+    with translate_manager.tokenizer.as_target_tokenizer():
+        generated_tokens = translate_manager.tokenizer.batch_decode(generated_tokens.detach().cpu().tolist(), skip_special_tokens=True, clean_up_tokenization_spaces=True)
+    return ip.postprocess_batch(generated_tokens, lang=tgt_lang)
+
 # Lifespan Event Handler
 translation_configs = []

 @asynccontextmanager
 async def lifespan(app: FastAPI):
     def load_all_models():
-        try:
-            # Load LLM model
-            logger.info("Loading LLM model...")
-            llm_manager.load()
-            logger.info("LLM model loaded successfully")
-
-            # Load TTS model
-            logger.info("Loading TTS model...")
-            tts_manager.load()
-            logger.info("TTS model loaded successfully")
-
-            # Load ASR model
-            logger.info("Loading ASR model...")
-            asr_manager.load()
-            logger.info("ASR model loaded successfully")
-
-            # Load translation models
-            translation_tasks = [
-                ('eng_Latn', 'kan_Knda', 'eng_indic'),
-                ('kan_Knda', 'eng_Latn', 'indic_eng'),
-                ('kan_Knda', 'hin_Deva', 'indic_indic'),
-            ]
-
-            for config in translation_configs:
-                src_lang = config["src_lang"]
-                tgt_lang = config["tgt_lang"]
-                key = model_manager._get_model_key(src_lang, tgt_lang)
-                translation_tasks.append((src_lang, tgt_lang, key))
-
-            for src_lang, tgt_lang, key in translation_tasks:
-                logger.info(f"Loading translation model for {src_lang} -> {tgt_lang}...")
-                model_manager.load_model(src_lang, tgt_lang, key)
-                logger.info(f"Translation model for {key} loaded successfully")
-
-            logger.info("All models loaded successfully")
-        except Exception as e:
-            logger.error(f"Error loading models: {str(e)}")
-            raise
+        logger.info("Loading LLM model...")
+        llm_manager.load()
+        logger.info("Loading TTS model...")
+        tts_manager.load()
+        logger.info("Loading ASR model...")
+        asr_manager.load()
+        translation_tasks = [
+            ('eng_Latn', 'kan_Knda', 'eng_indic'),
+            ('kan_Knda', 'eng_Latn', 'indic_eng'),
+            ('kan_Knda', 'hin_Deva', 'indic_indic'),
+        ]
+        for config in translation_configs:
+            src_lang = config["src_lang"]
+            tgt_lang = config["tgt_lang"]
+            key = model_manager._get_model_key(src_lang, tgt_lang)
+            translation_tasks.append((src_lang, tgt_lang, key))
+        for src_lang, tgt_lang, key in translation_tasks:
+            logger.info(f"Loading translation model for {src_lang} -> {tgt_lang}...")
+            model_manager.load_model(src_lang, tgt_lang, key)
+        logger.info("All models loaded successfully")

-    logger.info("Starting sequential model loading...")
+    logger.info("Starting server with preloaded models...")
     load_all_models()
+    batch_task = asyncio.create_task(batch_worker())
     yield
+    batch_task.cancel()
     llm_manager.unload()
-    logger.info("Server shutdown complete")
+    tts_manager.unload()
+    asr_manager.unload()
+    for model in model_manager.models.values():
+        model.unload()
+    logger.info("Server shutdown complete; all models unloaded")
+
+# Batch Worker
+async def batch_worker():
+    while True:
+        batch = []
+        last_request_time = time()
+        try:
+            while len(batch) < 4:
+                try:
+                    request = await asyncio.wait_for(request_queue.get(), timeout=1.0)
+                    batch.append(request)
+                    current_time = time()
+                    if current_time - last_request_time > 1.0 and batch:
+                        break
+                    last_request_time = current_time
+                except asyncio.TimeoutError:
+                    if batch:
+                        break
+                    continue
+            if batch:
+                start_time = time()
+                responses = await llm_manager.batch_generate(batch)
+                duration = time() - start_time
+                logger.info(f"Batch of {len(batch)} requests processed in {duration:.3f} seconds")
+                for request, response in zip(batch, responses):
+                    request["future"].set_result(response)
+        except Exception as e:
+            logger.error(f"Batch worker error: {str(e)}")
+            for request in batch:
+                request["future"].set_exception(e)

 # FastAPI App
 app = FastAPI(
-    title="Dhwani API",
-    description="AI Chat API supporting Indian languages",
+    title="Optimized Dhwani API",
+    description="AI Chat API supporting Indian languages with performance enhancements",
     version="1.0.0",
     redirect_slashes=False,
     lifespan=lifespan
 )

-# Add CORS Middleware
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
@@ -541,13 +566,11 @@ app.add_middleware(
     allow_headers=["*"],
 )

-# Add Timing Middleware
 @app.middleware("http")
 async def add_request_timing(request: Request, call_next):
     start_time = time()
     response = await call_next(request)
-    end_time = time()
-    duration = end_time - start_time
+    duration = time() - start_time
     logger.info(f"Request to {request.url.path} took {duration:.3f} seconds")
     response.headers["X-Response-Time"] = f"{duration:.3f}"
     return response
@@ -555,7 +578,7 @@ async def add_request_timing(request: Request, call_next):
 limiter = Limiter(key_func=get_remote_address)
 app.state.limiter = limiter

-# API Endpoints
+# Endpoints
 @app.post("/audio/speech", response_class=StreamingResponse)
 async def synthesize_kannada(request: KannadaSynthesizeRequest):
     if not tts_manager.model:
@@ -563,14 +586,7 @@ async def synthesize_kannada(request: KannadaSynthesizeRequest):
     kannada_example = next(ex for ex in EXAMPLES if ex["audio_name"] == "KAN_F (Happy)")
     if not request.text.strip():
         raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")
-
-    audio_buffer = synthesize_speech(
-        tts_manager,
-        text=request.text,
-        ref_audio_name="KAN_F (Happy)",
-        ref_text=kannada_example["ref_text"]
-    )
-
+    audio_buffer = await synthesize_speech(tts_manager, request.text, "KAN_F (Happy)", kannada_example["ref_text"])
     return StreamingResponse(
         audio_buffer,
         media_type="audio/wav",
@@ -579,61 +595,69 @@ async def synthesize_kannada(request: KannadaSynthesizeRequest):

 @app.post("/translate", response_model=TranslationResponse)
 async def translate(request: TranslationRequest, translate_manager: TranslateManager = Depends(get_translate_manager)):
-    input_sentences = request.sentences
-    src_lang = request.src_lang
-    tgt_lang = request.tgt_lang
-
-    if not input_sentences:
+    if not request.sentences:
         raise HTTPException(status_code=400, detail="Input sentences are required")
-
-    batch = ip.preprocess_batch(input_sentences, src_lang=src_lang, tgt_lang=tgt_lang)
-    inputs = translate_manager.tokenizer(
-        batch,
-        truncation=True,
-        padding="longest",
-        return_tensors="pt",
-        return_attention_mask=True,
-    ).to(translate_manager.device_type)
-
-    with torch.no_grad():
-        generated_tokens = translate_manager.model.generate(
-            **inputs,
-            use_cache=True,
-            min_length=0,
-            max_length=256,
-            num_beams=5,
-            num_return_sequences=1,
-        )
-
+    batch = ip.preprocess_batch(request.sentences, src_lang=request.src_lang, tgt_lang=request.tgt_lang)
+    inputs = translate_manager.tokenizer(batch, truncation=True, padding="longest", return_tensors="pt", return_attention_mask=True).to(translate_manager.device_type)
+    with torch.no_grad(), autocast():
+        generated_tokens = translate_manager.model.generate(**inputs, use_cache=True, min_length=0, max_length=256, num_beams=5, num_return_sequences=1)
     with translate_manager.tokenizer.as_target_tokenizer():
-        generated_tokens = translate_manager.tokenizer.batch_decode(
-            generated_tokens.detach().cpu().tolist(),
-            skip_special_tokens=True,
-            clean_up_tokenization_spaces=True,
-        )
-
-    translations = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
+        generated_tokens = translate_manager.tokenizer.batch_decode(generated_tokens.detach().cpu().tolist(), skip_special_tokens=True, clean_up_tokenization_spaces=True)
+    translations = ip.postprocess_batch(generated_tokens, lang=request.tgt_lang)
     return TranslationResponse(translations=translations)

-async def perform_internal_translation(sentences: List[str], src_lang: str, tgt_lang: str) -> List[str]:
-    try:
-        translate_manager = model_manager.get_model(src_lang, tgt_lang)
-    except ValueError as e:
-        logger.info(f"Model not preloaded: {str(e)}, loading now...")
-        key = model_manager._get_model_key(src_lang, tgt_lang)
-        model_manager.load_model(src_lang, tgt_lang, key)
-        translate_manager = model_manager.get_model(src_lang, tgt_lang)
-
-    if not translate_manager.model:
-        translate_manager.load()
-
-    request = TranslationRequest(sentences=sentences, src_lang=src_lang, tgt_lang=tgt_lang)
-    response = await translate(request, translate_manager)
-    return response.translations
-
 @app.get("/v1/health")
 async def health_check():
-    return {"status": "healthy", "model": settings.llm_model_name}
+    memory_usage = torch.cuda.memory_allocated() / (24 * 1024**3) if cuda_available else 0
+    if memory_usage > 0.9:
+        logger.warning("GPU memory usage exceeds 90%; consider unloading models")
+    llm_status = "unhealthy"
+    llm_latency = None
+    if llm_manager.is_loaded:
+        start = time()
+        try:
+            llm_test = await llm_manager.generate("What is the capital of Karnataka?", max_tokens=10)
+            llm_latency = time() - start
+            llm_status = "healthy" if llm_test else "unhealthy"
+        except Exception as e:
+            logger.error(f"LLM health check failed: {str(e)}")
+    tts_status = "unhealthy"
+    tts_latency = None
+    if tts_manager.model:
+        start = time()
+        try:
+            audio_buffer = await synthesize_speech(tts_manager, "Test", "KAN_F (Happy)", EXAMPLES[0]["ref_text"])
+            tts_latency = time() - start
+            tts_status = "healthy" if audio_buffer else "unhealthy"
+        except Exception as e:
+            logger.error(f"TTS health check failed: {str(e)}")
+    asr_status = "unhealthy"
+    asr_latency = None
+    if asr_manager.model:
+        start = time()
+        try:
+            dummy_audio = np.zeros(16000, dtype=np.float32)
+            wav = torch.tensor(dummy_audio).unsqueeze(0).to(device)
+            with autocast(), torch.no_grad():
+                asr_test = asr_manager.model(wav, asr_manager.model_language["kannada"], "rnnt")
+            asr_latency = time() - start
+            asr_status = "healthy" if asr_test else "unhealthy"
+        except Exception as e:
+            logger.error(f"ASR health check failed: {str(e)}")
+    status = {
+        "status": "healthy" if llm_status == "healthy" and tts_status == "healthy" and asr_status == "healthy" else "degraded",
+        "model": settings.llm_model_name,
+        "llm_status": llm_status,
+        "llm_latency": f"{llm_latency:.3f}s" if llm_latency else "N/A",
+        "tts_status": tts_status,
+        "tts_latency": f"{tts_latency:.3f}s" if tts_latency else "N/A",
+        "asr_status": asr_status,
+        "asr_latency": f"{asr_latency:.3f}s" if asr_latency else "N/A",
+        "translation_models": list(model_manager.models.keys()),
+        "gpu_memory_usage": f"{memory_usage:.2%}"
+    }
+    logger.info("Health check completed")
+    return status

 @app.get("/")
 async def home():
@@ -644,6 +668,10 @@ async def unload_all_models():
     try:
         logger.info("Starting to unload all models...")
         llm_manager.unload()
+        tts_manager.unload()
+        asr_manager.unload()
+        for model in model_manager.models.values():
+            model.unload()
         logger.info("All models unloaded successfully")
         return {"status": "success", "message": "All models unloaded"}
     except Exception as e:
@@ -655,6 +683,15 @@ async def load_all_models():
     try:
         logger.info("Starting to load all models...")
         llm_manager.load()
+        tts_manager.load()
+        asr_manager.load()
+        for src_lang, tgt_lang, key in [
+            ('eng_Latn', 'kan_Knda', 'eng_indic'),
+            ('kan_Knda', 'eng_Latn', 'indic_eng'),
+            ('kan_Knda', 'hin_Deva', 'indic_indic'),
+        ]:
+            if key not in model_manager.models:
+                model_manager.load_model(src_lang, tgt_lang, key)
         logger.info("All models loaded successfully")
         return {"status": "success", "message": "All models loaded"}
     except Exception as e:
@@ -665,11 +702,7 @@ async def load_all_models():
 async def translate_endpoint(request: TranslationRequest):
     logger.info(f"Received translation request: {request.dict()}")
     try:
-        translations = await perform_internal_translation(
-            sentences=request.sentences,
-            src_lang=request.src_lang,
-            tgt_lang=request.tgt_lang
-        )
+        translations = await perform_internal_translation(request.sentences, request.src_lang, request.tgt_lang)
         logger.info(f"Translation successful: {translations}")
         return TranslationResponse(translations=translations)
     except Exception as e:
@@ -679,44 +712,32 @@ async def translate_endpoint(request: TranslationRequest):
 @app.post("/v1/chat", response_model=ChatResponse)
 @limiter.limit(settings.chat_rate_limit)
 async def chat(request: Request, chat_request: ChatRequest):
-    if not chat_request.prompt:
-        raise HTTPException(status_code=400, detail="Prompt cannot be empty")
-    logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, tgt_lang: {chat_request.tgt_lang}")
-
-    EUROPEAN_LANGUAGES = {"deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn", "por_Latn", "rus_Cyrl", "pol_Latn"}
-
-    try:
-        if chat_request.src_lang != "eng_Latn" and chat_request.src_lang not in EUROPEAN_LANGUAGES:
-            translated_prompt = await perform_internal_translation(
-                sentences=[chat_request.prompt],
-                src_lang=chat_request.src_lang,
-                tgt_lang="eng_Latn"
-            )
-            prompt_to_process = translated_prompt[0]
-            logger.info(f"Translated prompt to English: {prompt_to_process}")
-        else:
-            prompt_to_process = chat_request.prompt
-            logger.info("Prompt in English or European language, no translation needed")
-
-        response = await llm_manager.generate(prompt_to_process, settings.max_tokens)
-        logger.info(f"Generated response: {response}")
-
-        if chat_request.tgt_lang != "eng_Latn" and chat_request.tgt_lang not in EUROPEAN_LANGUAGES:
-            translated_response = await perform_internal_translation(
-                sentences=[response],
-                src_lang="eng_Latn",
-                tgt_lang=chat_request.tgt_lang
-            )
-            final_response = translated_response[0]
-            logger.info(f"Translated response to {chat_request.tgt_lang}: {final_response}")
-        else:
-            final_response = response
-            logger.info(f"Response in {chat_request.tgt_lang}, no translation needed")
-
-        return ChatResponse(response=final_response)
-    except Exception as e:
-        logger.error(f"Error processing request: {str(e)}")
-        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
+    async with request_queue:
+        if not chat_request.prompt:
+            raise HTTPException(status_code=400, detail="Prompt cannot be empty")
+        logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, tgt_lang: {chat_request.tgt_lang}")
[the remaining added lines of this hunk, and of all hunks below, are cut off in the source page]

 @app.post("/v1/visual_query/")
 async def visual_query(
@@ -725,42 +746,31 @@ async def visual_query(
     src_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
     tgt_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
 ):
-    try:
-        image = Image.open(file.file)
-        if image.size == (0, 0):
-            raise HTTPException(status_code=400, detail="Uploaded image is empty or invalid")
-
-        if src_lang != "eng_Latn":
-            translated_query = await perform_internal_translation(
-                sentences=[query],
-                src_lang=src_lang,
-                tgt_lang="eng_Latn"
-            )
-            query_to_process = translated_query[0]
-            logger.info(f"Translated query to English: {query_to_process}")
-        else:
-            query_to_process = query
-            logger.info("Query already in English, no translation needed")
-
-        answer = await llm_manager.vision_query(image, query_to_process)
-        logger.info(f"Generated English answer: {answer}")
-
-        if tgt_lang != "eng_Latn":
-            translated_answer = await perform_internal_translation(
-                sentences=[answer],
-                src_lang="eng_Latn",
-                tgt_lang=tgt_lang
-            )
-            final_answer = translated_answer[0]
-            logger.info(f"Translated answer to {tgt_lang}: {final_answer}")
-        else:
-            final_answer = answer
-            logger.info("Answer kept in English, no translation needed")
-
-        return {"answer": final_answer}
-    except Exception as e:
-        logger.error(f"Error processing request: {str(e)}")
-        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")

 @app.post("/v1/chat_v2", response_model=ChatResponse)
 @limiter.limit(settings.chat_rate_limit)
@@ -771,95 +781,70 @@ async def chat_v2(
     src_lang: str = Form("kan_Knda"),
     tgt_lang: str = Form("kan_Knda"),
 ):
-    if not prompt:
-        raise HTTPException(status_code=400, detail="Prompt cannot be empty")
-    if src_lang not in SUPPORTED_LANGUAGES or tgt_lang not in SUPPORTED_LANGUAGES:
-        raise HTTPException(status_code=400, detail=f"Unsupported language code. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
-
-    logger.info(f"Received prompt: {prompt}, src_lang: {src_lang}, tgt_lang: {tgt_lang}, Image provided: {image is not None}")
-
-    try:
-        if image:
-            image_data = await image.read()
-            if not image_data:
-                raise HTTPException(status_code=400, detail="Uploaded image is empty")
-            img = Image.open(io.BytesIO(image_data))
-
-            if src_lang != "eng_Latn":
-                translated_prompt = await perform_internal_translation(
-                    sentences=[prompt],
-                    src_lang=src_lang,
-                    tgt_lang="eng_Latn"
-                )
-                prompt_to_process = translated_prompt[0]
-                logger.info(f"Translated prompt to English: {prompt_to_process}")
-            else:
-                prompt_to_process = prompt
-                logger.info("Prompt already in English, no translation needed")
-
-            decoded = await llm_manager.chat_v2(img, prompt_to_process)
-            logger.info(f"Generated English response: {decoded}")
-
-            if tgt_lang != "eng_Latn":
-                translated_response = await perform_internal_translation(
-                    sentences=[decoded],
-                    src_lang="eng_Latn",
-                    tgt_lang=tgt_lang
-                )
-                final_response = translated_response[0]
-                logger.info(f"Translated response to {tgt_lang}: {final_response}")
-            else:
-                final_response = decoded
-                logger.info("Response kept in English, no translation needed")
-        else:
-            if src_lang != "eng_Latn":
-                translated_prompt = await perform_internal_translation(
-                    sentences=[prompt],
-                    src_lang=src_lang,
-                    tgt_lang="eng_Latn"
-                )
-                prompt_to_process = translated_prompt[0]
-                logger.info(f"Translated prompt to English: {prompt_to_process}")
-            else:
-                prompt_to_process = prompt
-                logger.info("Prompt already in English, no translation needed")
-
-            decoded = await llm_manager.generate(prompt_to_process, settings.max_tokens)
-            logger.info(f"Generated English response: {decoded}")
-
-            if tgt_lang != "eng_Latn":
-                translated_response = await perform_internal_translation(
-                    sentences=[decoded],
-                    src_lang="eng_Latn",
-                    tgt_lang=tgt_lang
-                )
-                final_response = translated_response[0]
-                logger.info(f"Translated response to {tgt_lang}: {final_response}")
             else:
-                final_response = decoded
-                logger.info("Response kept in English, no translation needed")
-
-        return ChatResponse(response=final_response)
-    except Exception as e:
-        logger.error(f"Error processing request: {str(e)}")
-        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")

 @app.post("/transcribe/", response_model=TranscriptionResponse)
 async def transcribe_audio(file: UploadFile = File(...), language: str = Query(..., enum=list(asr_manager.model_language.keys()))):
-    if not asr_manager.model:
-        raise HTTPException(status_code=503, detail="ASR model not loaded")
-    try:
-        wav, sr = torchaudio.load(file.file)
-        wav = torch.mean(wav, dim=0, keepdim=True)
-        target_sample_rate = 16000
-        if sr != target_sample_rate:
-            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
-            wav = resampler(wav)
-        transcription_rnnt = asr_manager.model(wav, asr_manager.model_language[language], "rnnt")
-        return TranscriptionResponse(text=transcription_rnnt)
-    except Exception as e:
-        logger.error(f"Error in transcription: {str(e)}")
-        raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")

 @app.post("/v1/speech_to_speech")
 async def speech_to_speech(
@@ -867,28 +852,20 @@ async def speech_to_speech(
     file: UploadFile = File(...),
     language: str = Query(..., enum=list(asr_manager.model_language.keys())),
 ) -> StreamingResponse:
-    if not tts_manager.model:
-        raise HTTPException(status_code=503, detail="TTS model not loaded")
-    transcription = await transcribe_audio(file, language)
-    logger.info(f"Transcribed text: {transcription.text}")

-    chat_request = ChatRequest(
-        prompt=transcription.text,
-        src_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda"),
-        tgt_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda")
-    )
-    processed_text = await chat(request, chat_request)
-    logger.info(f"Processed text: {processed_text.response}")
-
-    voice_request = KannadaSynthesizeRequest(text=processed_text.response)
-    audio_response = await synthesize_kannada(voice_request)
-    return audio_response
-
-LANGUAGE_TO_SCRIPT = {
-    "kannada": "kan_Knda"
-}
-
-# Main Execution
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Run the FastAPI server.")
     parser.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")
@@ -915,15 +892,13 @@ if __name__ == "__main__":
     settings.speech_rate_limit = global_settings["speech_rate_limit"]

     llm_manager = LLMManager(settings.llm_model_name)
-
     if selected_config["components"]["ASR"]:
-        asr_model_name = selected_config["components"]["ASR"]["model"]
         asr_manager.model_language[selected_config["language"]] = selected_config["components"]["ASR"]["language_code"]
-
     if selected_config["components"]["Translation"]:
         translation_configs.extend(selected_config["components"]["Translation"])

     host = args.host if args.host != settings.host else settings.host
     port = args.port if args.port != settings.port else settings.port

-    uvicorn.run(app, host=host, port=port)
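A note on the batching pattern this commit introduces in main.py: generate() parks each request on request_queue together with an asyncio.Future, and batch_worker() drains the queue into batches of up to four prompts, serving the whole batch with one batch_generate() call and resolving each caller's future individually. A self-contained sketch of that queue-and-future pattern, with a stubbed batch function and illustrative names (this is not the commit's code):

import asyncio
from typing import Dict, List

request_queue: asyncio.Queue = asyncio.Queue(maxsize=10)

async def fake_batch_generate(prompts: List[Dict]) -> List[str]:
    # Stand-in for LLMManager.batch_generate: one call serves the whole batch.
    await asyncio.sleep(0.1)
    return [f"echo: {p['prompt']}" for p in prompts]

async def generate(prompt: str) -> str:
    # Producer side: enqueue the request with a future and await only it.
    future = asyncio.get_running_loop().create_future()
    await request_queue.put({"prompt": prompt, "future": future})
    return await future

async def batch_worker(max_batch: int = 4, max_wait: float = 1.0) -> None:
    # Consumer side: block for the first request, then collect more until
    # the batch is full or max_wait seconds pass without a new arrival.
    while True:
        batch = [await request_queue.get()]
        try:
            while len(batch) < max_batch:
                batch.append(await asyncio.wait_for(request_queue.get(), timeout=max_wait))
        except asyncio.TimeoutError:
            pass  # a partial batch is fine
        try:
            for req, resp in zip(batch, await fake_batch_generate(batch)):
                req["future"].set_result(resp)
        except Exception as e:
            for req in batch:
                req["future"].set_exception(e)

async def main() -> None:
    worker = asyncio.create_task(batch_worker())
    print(await asyncio.gather(*(generate(f"q{i}") for i in range(3))))
    worker.cancel()

if __name__ == "__main__":
    asyncio.run(main())

Each caller waits only on its own future, so per-request latency is bounded by the collection window plus one batched forward pass, while GPU throughput improves because concurrent prompts share a single generate() call.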
 
 
2
  import io
3
  import os
4
  from time import time
5
+ from typing import List, Dict
6
  import tempfile
7
  import uvicorn
8
  from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, UploadFile, Body, Form
 
22
  import soundfile as sf
23
  import numpy as np
24
  import requests
25
+ import logging
26
  from starlette.responses import StreamingResponse
27
+ from logging_config import logger # Assumed external logging config
28
+ from tts_config import SPEED, ResponseFormat, config as tts_config # Assumed external TTS config
29
  import torchaudio
30
+ from tenacity import retry, stop_after_attempt, wait_exponential
31
+ from torch.cuda.amp import autocast
32
 
33
  # Device setup
34
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
35
+ torch_dtype = torch.float16 if device != "cpu" else torch.float32
36
+ logger.info(f"Using device: {device} with dtype: {torch_dtype}")
 
 
 
 
37
 
38
  # Check CUDA availability and version
39
  cuda_available = torch.cuda.is_available()
40
  cuda_version = torch.version.cuda if cuda_available else None
41
+ if cuda_available:
 
42
  device_idx = torch.cuda.current_device()
43
  capability = torch.cuda.get_device_capability(device_idx)
44
  compute_capability_float = float(f"{capability[0]}.{capability[1]}")
 
75
  bnb_4bit_compute_dtype=torch.bfloat16
76
  )
77
 
78
+ # Request queue for concurrency control
79
+ request_queue = asyncio.Queue(maxsize=10)
80
+
81
+ # Logging optimization
82
+ logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
83
+
84
+ # LLM Manager with batching
85
  class LLMManager:
86
+ def __init__(self, model_name: str, device: str = device):
87
  self.model_name = model_name
88
  self.device = torch.device(device)
89
+ self.torch_dtype = torch.float16 if self.device.type != "cpu" else torch.float32
90
  self.model = None
91
  self.processor = None
92
  self.is_loaded = False
93
+ self.token_cache = {}
94
+ self.load()
95
  logger.info(f"LLMManager initialized with model {model_name} on {self.device}")
96
 
97
  def load(self):
98
  if not self.is_loaded:
99
  try:
100
+ if self.device.type == "cuda":
101
+ torch.set_float32_matmul_precision('high')
102
+ logger.info("Enabled TF32 matrix multiplication for improved GPU performance")
103
  self.model = Gemma3ForConditionalGeneration.from_pretrained(
104
  self.model_name,
105
  device_map="auto",
106
  quantization_config=quantization_config,
107
  torch_dtype=self.torch_dtype
108
  )
109
+ if self.model is None:
110
+ raise ValueError(f"Failed to load model {self.model_name}: Model object is None")
111
  self.model.eval()
112
+ self.processor = AutoProcessor.from_pretrained(self.model_name, use_fast=True)
113
+ dummy_input = self.processor("test", return_tensors="pt").to(self.device)
114
+ with torch.no_grad():
115
+ self.model.generate(**dummy_input, max_new_tokens=10)
116
  self.is_loaded = True
117
+ logger.info(f"LLM {self.model_name} loaded and warmed up on {self.device}")
118
  except Exception as e:
119
  logger.error(f"Failed to load LLM: {str(e)}")
120
+ self.is_loaded = False
121
+ raise # Re-raise to ensure failure is caught upstream
122
 
123
  def unload(self):
124
  if self.is_loaded:
 
126
  del self.processor
127
  if self.device.type == "cuda":
128
  torch.cuda.empty_cache()
129
+ logger.info(f"GPU memory cleared: {torch.cuda.memory_allocated()} bytes allocated")
130
  self.is_loaded = False
131
+ self.token_cache.clear()
132
+ logger.info(f"LLM {self.model_name} unloaded")
133
 
134
+ async def generate(self, prompt: str, max_tokens: int = settings.max_tokens, temperature: float = 0.7) -> str:
135
  if not self.is_loaded:
136
+ logger.warning("LLM not loaded; attempting reload")
137
  self.load()
138
+ if not self.is_loaded:
139
+ raise HTTPException(status_code=503, detail="LLM model unavailable")
140
 
141
+ cache_key = f"{prompt}:{max_tokens}:{temperature}"
142
+ if cache_key in self.token_cache:
143
+ logger.info("Using cached response")
144
+ return self.token_cache[cache_key]
 
 
 
 
 
 
145
 
146
+ future = asyncio.Future()
147
+ await request_queue.put({"prompt": prompt, "max_tokens": max_tokens, "temperature": temperature, "future": future})
148
+ response = await future
149
+ self.token_cache[cache_key] = response
150
+ logger.info(f"Generated response: {response}")
151
+ return response
152
+
153
+ async def batch_generate(self, prompts: List[Dict]) -> List[str]:
154
+ messages_batch = [
155
+ [
156
+ {"role": "system", "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state. Provide a concise response in one sentence maximum."}]},
157
+ {"role": "user", "content": [{"type": "text", "text": prompt["prompt"]}]}
158
+ ]
159
+ for prompt in prompts
160
+ ]
161
  try:
162
  inputs_vlm = self.processor.apply_chat_template(
163
+ messages_batch,
164
  add_generation_prompt=True,
165
  tokenize=True,
166
  return_dict=True,
167
+ return_tensors="pt",
168
+ padding=True
169
  ).to(self.device, dtype=torch.bfloat16)
170
+ with autocast(), torch.no_grad():
171
+ outputs = self.model.generate(
172
+ **inputs_vlm,
173
+ max_new_tokens=max(prompt["max_tokens"] for prompt in prompts),
174
+ do_sample=True,
175
+ top_p=0.9,
176
+ temperature=max(prompt["temperature"] for prompt in prompts)
177
+ )
178
+ responses = [
179
+ self.processor.decode(output[input_len:], skip_special_tokens=True)
180
+ for output, input_len in zip(outputs, inputs_vlm["input_ids"].shape[1])
181
+ ]
182
+ logger.info(f"Batch generated {len(responses)} responses")
183
+ return responses
184
  except Exception as e:
185
+ logger.error(f"Error in batch generation: {str(e)}")
186
+ raise HTTPException(status_code=500, detail=f"Batch generation failed: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
  async def vision_query(self, image: Image.Image, query: str) -> str:
189
  if not self.is_loaded:
190
  self.load()
 
191
  messages_vlm = [
192
+ {"role": "system", "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Summarize your answer in maximum 1 sentence."}]},
193
+ {"role": "user", "content": [{"type": "text", "text": query}] + ([{"type": "image", "image": image}] if image and image.size[0] > 0 and image.size[1] > 0 else [])}
 
 
 
 
 
 
194
  ]
 
 
 
 
 
 
 
 
195
  try:
196
  inputs_vlm = self.processor.apply_chat_template(
197
  messages_vlm,
 
203
  except Exception as e:
204
  logger.error(f"Error in apply_chat_template: {str(e)}")
205
  raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
 
206
  input_len = inputs_vlm["input_ids"].shape[-1]
 
207
  with torch.inference_mode():
208
+ generation = self.model.generate(**inputs_vlm, max_new_tokens=512, do_sample=True, temperature=0.7)
 
 
 
 
 
209
  generation = generation[0][input_len:]
 
210
  decoded = self.processor.decode(generation, skip_special_tokens=True)
211
  logger.info(f"Vision query response: {decoded}")
212
  return decoded
 
214
  async def chat_v2(self, image: Image.Image, query: str) -> str:
215
  if not self.is_loaded:
216
  self.load()
 
217
  messages_vlm = [
218
+ {"role": "system", "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state."}]},
219
+ {"role": "user", "content": [{"type": "text", "text": query}] + ([{"type": "image", "image": image}] if image and image.size[0] > 0 and image.size[1] > 0 else [])}
 
 
 
 
 
 
220
  ]
 
 
 
 
 
 
 
 
221
  try:
222
  inputs_vlm = self.processor.apply_chat_template(
223
  messages_vlm,
 
229
  except Exception as e:
230
  logger.error(f"Error in apply_chat_template: {str(e)}")
231
  raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
 
232
  input_len = inputs_vlm["input_ids"].shape[-1]
 
233
  with torch.inference_mode():
234
+ generation = self.model.generate(**inputs_vlm, max_new_tokens=512, do_sample=True, temperature=0.7)
 
 
 
 
 
235
  generation = generation[0][input_len:]
 
236
  decoded = self.processor.decode(generation, skip_special_tokens=True)
237
  logger.info(f"Chat_v2 response: {decoded}")
238
  return decoded
 
240
  # TTS Manager
241
  class TTSManager:
242
  def __init__(self, device_type=device):
243
+ self.device_type = torch.device(device_type)
244
  self.model = None
245
  self.repo_id = "ai4bharat/IndicF5"
246
+ self.load()
247
 
248
  def load(self):
249
  if not self.model:
250
+ logger.info(f"Loading TTS model {self.repo_id} on {self.device_type}...")
251
+ self.model = AutoModel.from_pretrained(self.repo_id, trust_remote_code=True).to(self.device_type)
252
+ logger.info("TTS model loaded")
253
+
254
+ def unload(self):
255
+ if self.model:
256
+ del self.model
257
+ if self.device_type.type == "cuda":
258
+ torch.cuda.empty_cache()
259
+ logger.info(f"TTS GPU memory cleared: {torch.cuda.memory_allocated()} bytes allocated")
260
+ self.model = None
261
+ logger.info("TTS model unloaded")
262
 
263
  def synthesize(self, text, ref_audio_path, ref_text):
264
  if not self.model:
265
  raise ValueError("TTS model not loaded")
266
+ with autocast():
267
+ return self.model(text, ref_audio_path=ref_audio_path, ref_text=ref_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
 
269
  # Translation Manager
270
  class TranslateManager:
271
  def __init__(self, src_lang, tgt_lang, device_type=device, use_distilled=True):
272
+ self.device_type = torch.device(device_type)
273
  self.tokenizer = None
274
  self.model = None
275
  self.src_lang = src_lang
276
  self.tgt_lang = tgt_lang
277
  self.use_distilled = use_distilled
278
+ self.load()
279
 
280
  def load(self):
281
  if not self.tokenizer or not self.model:
 
287
  model_name = "ai4bharat/indictrans2-indic-indic-dist-320M" if self.use_distilled else "ai4bharat/indictrans2-indic-indic-1B"
288
  else:
289
  raise ValueError("Invalid language combination")
290
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
 
 
 
 
291
  self.model = AutoModelForSeq2SeqLM.from_pretrained(
292
  model_name,
293
  trust_remote_code=True,
294
  torch_dtype=torch.float16,
295
  attn_implementation="flash_attention_2"
296
+ ).to(self.device_type)
 
297
  self.model = torch.compile(self.model, mode="reduce-overhead")
298
  logger.info(f"Translation model {model_name} loaded")
299
 
300
+ # Model Manager
301
  class ModelManager:
302
  def __init__(self, device_type=device, use_distilled=True, is_lazy_loading=False):
303
  self.models = {}
 
308
  def load_model(self, src_lang, tgt_lang, key):
309
  logger.info(f"Loading translation model for {src_lang} -> {tgt_lang}")
310
  translate_manager = TranslateManager(src_lang, tgt_lang, self.device_type, self.use_distilled)
 
311
  self.models[key] = translate_manager
312
  logger.info(f"Loaded translation model for {key}")
313
 
314
  def get_model(self, src_lang, tgt_lang):
315
  key = self._get_model_key(src_lang, tgt_lang)
316
+ if key not in self.models and self.is_lazy_loading:
317
+ self.load_model(src_lang, tgt_lang, key)
318
+ return self.models.get(key) or (self.load_model(src_lang, tgt_lang, key) or self.models[key])
 
 
 
319
 
320
  def _get_model_key(self, src_lang, tgt_lang):
321
  if src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
 
328
 
329
  # ASR Manager
330
  class ASRModelManager:
331
+ def __init__(self, device_type=device):
332
+ self.device_type = torch.device(device_type)
333
  self.model = None
334
  self.model_language = {"kannada": "kn"}
335
+ self.load()
336
 
337
  def load(self):
338
  if not self.model:
339
+ logger.info(f"Loading ASR model on {self.device_type}...")
340
  self.model = AutoModel.from_pretrained(
341
  "ai4bharat/indic-conformer-600m-multilingual",
342
  trust_remote_code=True
343
+ ).to(self.device_type)
 
344
  logger.info("ASR model loaded")
345
 
346
+ def unload(self):
347
+ if self.model:
348
+ del self.model
349
+ if self.device_type.type == "cuda":
350
+ torch.cuda.empty_cache()
351
+ logger.info(f"ASR GPU memory cleared: {torch.cuda.memory_allocated()} bytes allocated")
352
+ self.model = None
353
+ logger.info("ASR model unloaded")
354
+
355
  # Global Managers
356
  llm_manager = LLMManager(settings.llm_model_name)
357
  model_manager = ModelManager()
 
359
  tts_manager = TTSManager()
360
  ip = IndicProcessor(inference=True)
361
 
362
+ # TTS Constants
363
+ EXAMPLES = [
364
+ {
365
+ "audio_name": "KAN_F (Happy)",
366
+ "audio_url": "https://github.com/AI4Bharat/IndicF5/raw/refs/heads/main/prompts/KAN_F_HAPPY_00001.wav",
367
+ "ref_text": "ನಮ್‌ ಫ್ರಿಜ್ಜಲ್ಲಿ ಕೂಲಿಂಗ್‌ ಸಮಸ್ಯೆ ಆಗಿ ನಾನ್‌ ಭಾಳ ದಿನದಿಂದ ಒದ್ದಾಡ್ತಿದ್ದೆ, ಆದ್ರೆ ಅದ್ನೀಗ ಮೆಕಾನಿಕ್ ಆಗಿರೋ ನಿಮ್‌ ಸಹಾಯ್ದಿಂದ ಬಗೆಹರಿಸ್ಕೋಬೋದು ಅಂತಾಗಿ ನಿರಾಳ ಆಯ್ತು ನಂಗೆ।",
368
+ "synth_text": "ಚೆನ್ನೈನ ಶೇರ್ ಆಟೋ ಪ್ರಯಾಣಿಕರ ನಡುವೆ ಆಹಾರವನ್ನು ಹಂಚಿಕೊಂಡು ತಿನ್ನುವುದು ನನಗೆ ಮನಸ್ಸಿಗೆ ತುಂಬಾ ಒಳ್ಳೆಯದೆನಿಸುವ ವಿಷಯ."
369
+ },
370
+ ]
371
+
372
  # Pydantic Models
373
+ class SynthesizeRequest(BaseModel):
374
+ text: str
375
+ ref_audio_name: str
376
+ ref_text: str = None
377
+
378
+ class KannadaSynthesizeRequest(BaseModel):
379
+ text: str
380
+
381
+ @field_validator("text")
382
+ def text_must_be_valid(cls, v):
383
+ if len(v) > 500:
384
+ raise ValueError("Text cannot exceed 500 characters")
385
+ return v.strip()
386
+
387
  class ChatRequest(BaseModel):
388
  prompt: str
389
  src_lang: str = "kan_Knda"
 
401
  raise ValueError(f"Unsupported language code: {v}. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
402
  return v
403
 
 
404
  class ChatResponse(BaseModel):
405
  response: str
406
 
 
415
  class TranslationResponse(BaseModel):
416
  translations: List[str]
417
 
418
+ # TTS Functions
419
+ @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=10))
420
+ def load_audio_from_url(url: str):
421
+ response = requests.get(url)
422
+ if response.status_code == 200:
423
+ audio_data, sample_rate = sf.read(io.BytesIO(response.content))
424
+ return sample_rate, audio_data
425
+ raise HTTPException(status_code=500, detail="Failed to load reference audio from URL after retries")
426
+
427
+ async def synthesize_speech(tts_manager: TTSManager, text: str, ref_audio_name: str, ref_text: str) -> io.BytesIO:
428
+ async with request_queue:
429
+ ref_audio_url = next((ex["audio_url"] for ex in EXAMPLES if ex["audio_name"] == ref_audio_name), None)
430
+ if not ref_audio_url:
431
+ raise HTTPException(status_code=400, detail="Invalid reference audio name.")
432
+ if not text.strip() or not ref_text.strip():
433
+ raise HTTPException(status_code=400, detail="Text or reference text cannot be empty.")
434
+
435
+ logger.info(f"Synthesizing speech for text: {text[:50]}... with ref_audio: {ref_audio_name}")
436
+ loop = asyncio.get_running_loop()
437
+ sample_rate, audio_data = await loop.run_in_executor(None, load_audio_from_url, ref_audio_url)
438
+
439
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_audio:
440
+ await loop.run_in_executor(None, sf.write, temp_audio.name, audio_data, sample_rate, "WAV")
441
+ temp_audio.flush()
442
+ audio = tts_manager.synthesize(text, temp_audio.name, ref_text)
443
+
444
+ buffer = io.BytesIO()
445
+ await loop.run_in_executor(None, sf.write, buffer, audio.astype(np.float32) / 32768.0 if audio.dtype == np.int16 else audio, 24000, "WAV")
446
+ buffer.seek(0)
447
+ logger.info("Speech synthesis completed")
448
+ return buffer
449
+
450
+ # Supported Languages
451
+ SUPPORTED_LANGUAGES = {
452
+ "asm_Beng", "kas_Arab", "pan_Guru", "ben_Beng", "kas_Deva", "san_Deva",
453
+ "brx_Deva", "mai_Deva", "sat_Olck", "doi_Deva", "mal_Mlym", "snd_Arab",
454
+ "eng_Latn", "mar_Deva", "snd_Deva", "gom_Deva", "mni_Beng", "tam_Taml",
455
+ "guj_Gujr", "mni_Mtei", "tel_Telu", "hin_Deva", "npi_Deva", "urd_Arab",
456
+ "kan_Knda", "ory_Orya",
457
+ "deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn",
458
+ "por_Latn", "rus_Cyrl", "pol_Latn"
459
+ }
460
+
461
  # Dependency
462
  def get_translate_manager(src_lang: str, tgt_lang: str) -> TranslateManager:
463
  return model_manager.get_model(src_lang, tgt_lang)
464
 
465
+ # Translation Function
466
+ async def perform_internal_translation(sentences: List[str], src_lang: str, tgt_lang: str) -> List[str]:
467
+ try:
468
+ translate_manager = model_manager.get_model(src_lang, tgt_lang)
469
+ except ValueError as e:
470
+ logger.info(f"Model not preloaded: {str(e)}, loading now...")
471
+ key = model_manager._get_model_key(src_lang, tgt_lang)
472
+ model_manager.load_model(src_lang, tgt_lang, key)
473
+ translate_manager = model_manager.get_model(src_lang, tgt_lang)
474
+
475
+ batch = ip.preprocess_batch(sentences, src_lang=src_lang, tgt_lang=tgt_lang)
476
+ inputs = translate_manager.tokenizer(batch, truncation=True, padding="longest", return_tensors="pt", return_attention_mask=True).to(translate_manager.device_type)
477
+ with torch.no_grad(), autocast():
478
+ generated_tokens = translate_manager.model.generate(**inputs, use_cache=True, min_length=0, max_length=256, num_beams=5, num_return_sequences=1)
479
+ with translate_manager.tokenizer.as_target_tokenizer():
480
+ generated_tokens = translate_manager.tokenizer.batch_decode(generated_tokens.detach().cpu().tolist(), skip_special_tokens=True, clean_up_tokenization_spaces=True)
481
+ return ip.postprocess_batch(generated_tokens, lang=tgt_lang)
482
+
483
  # Lifespan Event Handler
484
  translation_configs = []
485
 
486
  @asynccontextmanager
487
  async def lifespan(app: FastAPI):
488
  def load_all_models():
489
+ logger.info("Loading LLM model...")
490
+ llm_manager.load()
491
+ logger.info("Loading TTS model...")
492
+ tts_manager.load()
493
+ logger.info("Loading ASR model...")
494
+ asr_manager.load()
495
+ translation_tasks = [
496
+ ('eng_Latn', 'kan_Knda', 'eng_indic'),
497
+ ('kan_Knda', 'eng_Latn', 'indic_eng'),
498
+ ('kan_Knda', 'hin_Deva', 'indic_indic'),
499
+ ]
500
+ for config in translation_configs:
501
+ src_lang = config["src_lang"]
502
+ tgt_lang = config["tgt_lang"]
503
+ key = model_manager._get_model_key(src_lang, tgt_lang)
504
+ translation_tasks.append((src_lang, tgt_lang, key))
505
+ for src_lang, tgt_lang, key in translation_tasks:
506
+ logger.info(f"Loading translation model for {src_lang} -> {tgt_lang}...")
507
+ model_manager.load_model(src_lang, tgt_lang, key)
508
+ logger.info("All models loaded successfully")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
509
 
510
+ logger.info("Starting server with preloaded models...")
511
  load_all_models()
512
+ batch_task = asyncio.create_task(batch_worker())
513
  yield
514
+ batch_task.cancel()
515
  llm_manager.unload()
516
+ tts_manager.unload()
517
+ asr_manager.unload()
518
+ for model in model_manager.models.values():
519
+ model.unload()
520
+ logger.info("Server shutdown complete; all models unloaded")
521
+
522
+ # Batch Worker
523
+ async def batch_worker():
524
+ while True:
525
+ batch = []
526
+ last_request_time = time()
527
+ try:
528
+ while len(batch) < 4:
529
+ try:
530
+ request = await asyncio.wait_for(request_queue.get(), timeout=1.0)
531
+ batch.append(request)
532
+ current_time = time()
533
+ if current_time - last_request_time > 1.0 and batch:
534
+ break
535
+ last_request_time = current_time
536
+ except asyncio.TimeoutError:
537
+ if batch:
538
+ break
539
+ continue
540
+ if batch:
541
+ start_time = time()
542
+ responses = await llm_manager.batch_generate(batch)
543
+ duration = time() - start_time
544
+ logger.info(f"Batch of {len(batch)} requests processed in {duration:.3f} seconds")
545
+ for request, response in zip(batch, responses):
546
+ request["future"].set_result(response)
547
+ except Exception as e:
548
+ logger.error(f"Batch worker error: {str(e)}")
549
+ for request in batch:
550
+ request["future"].set_exception(e)
551
 
552
  # FastAPI App
553
  app = FastAPI(
554
+ title="Optimized Dhwani API",
555
+ description="AI Chat API supporting Indian languages with performance enhancements",
556
  version="1.0.0",
557
  redirect_slashes=False,
558
  lifespan=lifespan
559
  )
560
 
 
561
  app.add_middleware(
562
  CORSMiddleware,
563
  allow_origins=["*"],
 
566
  allow_headers=["*"],
567
  )
568
 
 
569
  @app.middleware("http")
570
  async def add_request_timing(request: Request, call_next):
571
  start_time = time()
572
  response = await call_next(request)
573
+ duration = time() - start_time
 
574
  logger.info(f"Request to {request.url.path} took {duration:.3f} seconds")
575
  response.headers["X-Response-Time"] = f"{duration:.3f}"
576
  return response
 
578
  limiter = Limiter(key_func=get_remote_address)
579
  app.state.limiter = limiter
580
 
581
+ # Endpoints
582
  @app.post("/audio/speech", response_class=StreamingResponse)
583
  async def synthesize_kannada(request: KannadaSynthesizeRequest):
584
  if not tts_manager.model:
 
586
  kannada_example = next(ex for ex in EXAMPLES if ex["audio_name"] == "KAN_F (Happy)")
587
  if not request.text.strip():
588
  raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")
589
+ audio_buffer = await synthesize_speech(tts_manager, request.text, "KAN_F (Happy)", kannada_example["ref_text"])
 
 
 
 
 
 
 
590
  return StreamingResponse(
591
  audio_buffer,
592
  media_type="audio/wav",
 
595
 
596
  @app.post("/translate", response_model=TranslationResponse)
597
  async def translate(request: TranslationRequest, translate_manager: TranslateManager = Depends(get_translate_manager)):
598
+ if not request.sentences:
 
 
 
 
599
  raise HTTPException(status_code=400, detail="Input sentences are required")
600
+ batch = ip.preprocess_batch(request.sentences, src_lang=request.src_lang, tgt_lang=request.tgt_lang)
601
+ inputs = translate_manager.tokenizer(batch, truncation=True, padding="longest", return_tensors="pt", return_attention_mask=True).to(translate_manager.device_type)
602
+ with torch.no_grad(), autocast():
603
+ generated_tokens = translate_manager.model.generate(**inputs, use_cache=True, min_length=0, max_length=256, num_beams=5, num_return_sequences=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
604
  with translate_manager.tokenizer.as_target_tokenizer():
605
+ generated_tokens = translate_manager.tokenizer.batch_decode(generated_tokens.detach().cpu().tolist(), skip_special_tokens=True, clean_up_tokenization_spaces=True)
606
+ translations = ip.postprocess_batch(generated_tokens, lang=request.tgt_lang)
 
 
 
 
 
607
  return TranslationResponse(translations=translations)
608
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
  @app.get("/v1/health")
  async def health_check():
+     total_gpu_memory = torch.cuda.get_device_properties(0).total_memory if cuda_available else 0
+     memory_usage = torch.cuda.memory_allocated() / total_gpu_memory if cuda_available else 0
+     if memory_usage > 0.9:
+         logger.warning("GPU memory usage exceeds 90%; consider unloading models")
+     llm_status = "unhealthy"
+     llm_latency = None
+     if llm_manager.is_loaded:
+         start = time()
+         try:
+             llm_test = await llm_manager.generate("What is the capital of Karnataka?", max_tokens=10)
+             llm_latency = time() - start
+             llm_status = "healthy" if llm_test else "unhealthy"
+         except Exception as e:
+             logger.error(f"LLM health check failed: {str(e)}")
+     tts_status = "unhealthy"
+     tts_latency = None
+     if tts_manager.model:
+         start = time()
+         try:
+             audio_buffer = await synthesize_speech(tts_manager, "Test", "KAN_F (Happy)", EXAMPLES[0]["ref_text"])
+             tts_latency = time() - start
+             tts_status = "healthy" if audio_buffer else "unhealthy"
+         except Exception as e:
+             logger.error(f"TTS health check failed: {str(e)}")
+     asr_status = "unhealthy"
+     asr_latency = None
+     if asr_manager.model:
+         start = time()
+         try:
+             dummy_audio = np.zeros(16000, dtype=np.float32)
+             wav = torch.tensor(dummy_audio).unsqueeze(0).to(device)
+             with autocast(), torch.no_grad():
+                 asr_test = asr_manager.model(wav, asr_manager.model_language["kannada"], "rnnt")
+             asr_latency = time() - start
+             asr_status = "healthy" if asr_test else "unhealthy"
+         except Exception as e:
+             logger.error(f"ASR health check failed: {str(e)}")
+     status = {
+         "status": "healthy" if llm_status == "healthy" and tts_status == "healthy" and asr_status == "healthy" else "degraded",
+         "model": settings.llm_model_name,
+         "llm_status": llm_status,
+         "llm_latency": f"{llm_latency:.3f}s" if llm_latency is not None else "N/A",
+         "tts_status": tts_status,
+         "tts_latency": f"{tts_latency:.3f}s" if tts_latency is not None else "N/A",
+         "asr_status": asr_status,
+         "asr_latency": f"{asr_latency:.3f}s" if asr_latency is not None else "N/A",
+         "translation_models": list(model_manager.models.keys()),
+         "gpu_memory_usage": f"{memory_usage:.2%}"
+     }
+     logger.info("Health check completed")
+     return status
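A quick smoke test against the health endpoint (host and port are illustrative; the keys match the status dict above):

import requests

health = requests.get("http://localhost:8000/v1/health").json()
print(health["status"], health["gpu_memory_usage"], health["llm_latency"])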
 
  @app.get("/")
  async def home():
 
      try:
          logger.info("Starting to unload all models...")
          llm_manager.unload()
+         tts_manager.unload()
+         asr_manager.unload()
+         for model in model_manager.models.values():
+             model.unload()
          logger.info("All models unloaded successfully")
          return {"status": "success", "message": "All models unloaded"}
      except Exception as e:
 
      try:
          logger.info("Starting to load all models...")
          llm_manager.load()
+         tts_manager.load()
+         asr_manager.load()
+         for src_lang, tgt_lang, key in [
+             ('eng_Latn', 'kan_Knda', 'eng_indic'),
+             ('kan_Knda', 'eng_Latn', 'indic_eng'),
+             ('kan_Knda', 'hin_Deva', 'indic_indic'),
+         ]:
+             if key not in model_manager.models:
+                 model_manager.load_model(src_lang, tgt_lang, key)
          logger.info("All models loaded successfully")
          return {"status": "success", "message": "All models loaded"}
      except Exception as e:
 
  async def translate_endpoint(request: TranslationRequest):
      logger.info(f"Received translation request: {request.dict()}")
      try:
+         translations = await perform_internal_translation(request.sentences, request.src_lang, request.tgt_lang)
          logger.info(f"Translation successful: {translations}")
          return TranslationResponse(translations=translations)
      except Exception as e:
 
  @app.post("/v1/chat", response_model=ChatResponse)
  @limiter.limit(settings.chat_rate_limit)
  async def chat(request: Request, chat_request: ChatRequest):
+     async with request_queue:
+         if not chat_request.prompt:
+             raise HTTPException(status_code=400, detail="Prompt cannot be empty")
+         logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, tgt_lang: {chat_request.tgt_lang}")
+         EUROPEAN_LANGUAGES = {"deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn", "por_Latn", "rus_Cyrl", "pol_Latn"}
+         try:
+             if chat_request.src_lang != "eng_Latn" and chat_request.src_lang not in EUROPEAN_LANGUAGES:
+                 translated_prompt = await perform_internal_translation([chat_request.prompt], chat_request.src_lang, "eng_Latn")
+                 prompt_to_process = translated_prompt[0]
+                 logger.info(f"Translated prompt to English: {prompt_to_process}")
+             else:
+                 prompt_to_process = chat_request.prompt
+                 logger.info("Prompt in English or European language, no translation needed")
+             response = await llm_manager.generate(prompt_to_process, settings.max_tokens)
+             logger.info(f"Generated English response: {response}")
+             if chat_request.tgt_lang != "eng_Latn" and chat_request.tgt_lang not in EUROPEAN_LANGUAGES:
+                 translated_response = await perform_internal_translation([response], "eng_Latn", chat_request.tgt_lang)
+                 final_response = translated_response[0]
+                 logger.info(f"Translated response to {chat_request.tgt_lang}: {final_response}")
+             else:
+                 final_response = response
+                 logger.info(f"Response in {chat_request.tgt_lang}, no translation needed")
+             return ChatResponse(response=final_response)
+         except Exception as e:
+             logger.error(f"Error processing request: {str(e)}")
+             raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
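A minimal chat call for comparison (host and port are illustrative; the JSON fields mirror ChatRequest):

import requests

resp = requests.post(
    "http://localhost:8000/v1/chat",
    json={"prompt": "What is the capital of Karnataka?", "src_lang": "eng_Latn", "tgt_lang": "kan_Knda"},
)
print(resp.json()["response"])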
 
  @app.post("/v1/visual_query/")
  async def visual_query(

      src_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
      tgt_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
  ):
+     async with request_queue:
+         try:
+             image = Image.open(file.file)
+             if image.size == (0, 0):
+                 raise HTTPException(status_code=400, detail="Uploaded image is empty or invalid")
+             if src_lang != "eng_Latn":
+                 translated_query = await perform_internal_translation([query], src_lang, "eng_Latn")
+                 query_to_process = translated_query[0]
+                 logger.info(f"Translated query to English: {query_to_process}")
+             else:
+                 query_to_process = query
+                 logger.info("Query already in English, no translation needed")
+             answer = await llm_manager.vision_query(image, query_to_process)
+             logger.info(f"Generated English answer: {answer}")
+             if tgt_lang != "eng_Latn":
+                 translated_answer = await perform_internal_translation([answer], "eng_Latn", tgt_lang)
+                 final_answer = translated_answer[0]
+                 logger.info(f"Translated answer to {tgt_lang}: {final_answer}")
+             else:
+                 final_answer = answer
+                 logger.info("Answer kept in English, no translation needed")
+             return {"answer": final_answer}
+         except Exception as e:
+             logger.error(f"Error processing request: {str(e)}")
+             raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
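A client sketch for the visual query endpoint; it assumes the elided file and query parameters are declared as a multipart file and a form field, so adjust if the actual signature differs:

import requests

with open("photo.jpg", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/v1/visual_query/",
        files={"file": f},
        data={"query": "What is in this picture?"},   # assumption: query arrives as form data
        params={"src_lang": "eng_Latn", "tgt_lang": "kan_Knda"},
    )
print(resp.json()["answer"])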
 
  @app.post("/v1/chat_v2", response_model=ChatResponse)
  @limiter.limit(settings.chat_rate_limit)

      src_lang: str = Form("kan_Knda"),
      tgt_lang: str = Form("kan_Knda"),
  ):
+     async with request_queue:
+         if not prompt:
+             raise HTTPException(status_code=400, detail="Prompt cannot be empty")
+         if src_lang not in SUPPORTED_LANGUAGES or tgt_lang not in SUPPORTED_LANGUAGES:
+             raise HTTPException(status_code=400, detail=f"Unsupported language code. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
+         logger.info(f"Received prompt: {prompt}, src_lang: {src_lang}, tgt_lang: {tgt_lang}, Image provided: {image is not None}")
+         try:
+             if image:
+                 image_data = await image.read()
+                 if not image_data:
+                     raise HTTPException(status_code=400, detail="Uploaded image is empty")
+                 img = Image.open(io.BytesIO(image_data))
+                 if src_lang != "eng_Latn":
+                     translated_prompt = await perform_internal_translation([prompt], src_lang, "eng_Latn")
+                     prompt_to_process = translated_prompt[0]
+                     logger.info(f"Translated prompt to English: {prompt_to_process}")
+                 else:
+                     prompt_to_process = prompt
+                 decoded = await llm_manager.chat_v2(img, prompt_to_process)
+                 logger.info(f"Generated English response: {decoded}")
+                 if tgt_lang != "eng_Latn":
+                     translated_response = await perform_internal_translation([decoded], "eng_Latn", tgt_lang)
+                     final_response = translated_response[0]
+                     logger.info(f"Translated response to {tgt_lang}: {final_response}")
+                 else:
+                     final_response = decoded
+             else:
+                 if src_lang != "eng_Latn":
+                     translated_prompt = await perform_internal_translation([prompt], src_lang, "eng_Latn")
+                     prompt_to_process = translated_prompt[0]
+                     logger.info(f"Translated prompt to English: {prompt_to_process}")
+                 else:
+                     prompt_to_process = prompt
+                 decoded = await llm_manager.generate(prompt_to_process, settings.max_tokens)
+                 logger.info(f"Generated English response: {decoded}")
+                 if tgt_lang != "eng_Latn":
+                     translated_response = await perform_internal_translation([decoded], "eng_Latn", tgt_lang)
+                     final_response = translated_response[0]
+                     logger.info(f"Translated response to {tgt_lang}: {final_response}")
+                 else:
+                     final_response = decoded
+             return ChatResponse(response=final_response)
+         except Exception as e:
+             logger.error(f"Error processing request: {str(e)}")
+             raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
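Both branches of chat_v2 repeat the same translate-in / generate / translate-out sequence that /v1/chat uses; a hypothetical helper (not part of this commit) could collapse the duplication:

# Hypothetical refactor sketch; perform_internal_translation is the server's own
# helper, while generate_fn abstracts over llm_manager.generate / llm_manager.chat_v2.
async def translate_generate_translate(prompt: str, src_lang: str, tgt_lang: str, generate_fn) -> str:
    if src_lang != "eng_Latn":
        prompt = (await perform_internal_translation([prompt], src_lang, "eng_Latn"))[0]
    decoded = await generate_fn(prompt)
    if tgt_lang != "eng_Latn":
        decoded = (await perform_internal_translation([decoded], "eng_Latn", tgt_lang))[0]
    return decoded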
 
  @app.post("/transcribe/", response_model=TranscriptionResponse)
  async def transcribe_audio(file: UploadFile = File(...), language: str = Query(..., enum=list(asr_manager.model_language.keys()))):
+     async with request_queue:
+         if not asr_manager.model:
+             raise HTTPException(status_code=503, detail="ASR model not loaded")
+         try:
+             # torchaudio's backend argument selects a decoder ("ffmpeg", "sox", "soundfile"),
+             # not a device; decode on CPU, then move the tensor to the target device
+             wav, sr = torchaudio.load(file.file)
+             wav = torch.mean(wav, dim=0, keepdim=True).to(device)
+             target_sample_rate = 16000
+             if sr != target_sample_rate:
+                 resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate).to(device)
+                 wav = resampler(wav)
+             with autocast(), torch.no_grad():
+                 transcription_rnnt = asr_manager.model(wav, asr_manager.model_language[language], "rnnt")
+             return TranscriptionResponse(text=transcription_rnnt)
+         except Exception as e:
+             logger.error(f"Error in transcription: {str(e)}")
+             raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
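Uploading audio for transcription from a client (host, port, and file name are illustrative):

import requests

with open("sample.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/transcribe/",
        files={"file": f},
        params={"language": "kannada"},
    )
print(resp.json()["text"])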
 
  @app.post("/v1/speech_to_speech")
  async def speech_to_speech(
      request: Request,
      file: UploadFile = File(...),
      language: str = Query(..., enum=list(asr_manager.model_language.keys())),
  ) -> StreamingResponse:
+     async with request_queue:
+         if not tts_manager.model:
+             raise HTTPException(status_code=503, detail="TTS model not loaded")
+         transcription = await transcribe_audio(file, language)
+         logger.info(f"Transcribed text: {transcription.text}")
+         chat_request = ChatRequest(prompt=transcription.text, src_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda"), tgt_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda"))
+         processed_text = await chat(request, chat_request)
+         logger.info(f"Processed text: {processed_text.response}")
+         voice_request = KannadaSynthesizeRequest(text=processed_text.response)
+         audio_response = await synthesize_kannada(voice_request)
+         return audio_response
+
+ # Module-level constant: bound at import time, so it is already available when the endpoint above runs
+ LANGUAGE_TO_SCRIPT = {"kannada": "kan_Knda"}
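End-to-end usage of the speech-to-speech pipeline (host, port, and file names are illustrative):

import requests

with open("question.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/v1/speech_to_speech",
        files={"file": f},
        params={"language": "kannada"},
    )
with open("answer.wav", "wb") as out:
    out.write(resp.content)   # WAV audio streamed back by the TTS stage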
 
  if __name__ == "__main__":
      parser = argparse.ArgumentParser(description="Run the FastAPI server.")
      parser.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")

      settings.speech_rate_limit = global_settings["speech_rate_limit"]

      llm_manager = LLMManager(settings.llm_model_name)
      if selected_config["components"]["ASR"]:
          asr_manager.model_language[selected_config["language"]] = selected_config["components"]["ASR"]["language_code"]
      if selected_config["components"]["Translation"]:
          translation_configs.extend(selected_config["components"]["Translation"])

      host = args.host if args.host != settings.host else settings.host
      port = args.port if args.port != settings.port else settings.port

+     # Run Uvicorn with an import string to support multiple workers
+     uvicorn.run("main:app", host=host, port=port, workers=2)
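Uvicorn will not spawn workers from a bare app object, since each worker process must re-import the application; hence the import string. Note that with two workers each process loads its own copy of every model, roughly doubling GPU memory use, so a single worker may be the safer default on one GPU. A sketch of the alternative (module name main assumed):

import uvicorn

# One worker: single copy of the models on the GPU; raise workers only if memory allows.
uvicorn.run("main:app", host="0.0.0.0", port=8000, workers=1)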