sachin committed
Commit 66eb611 · 1 Parent(s): 8308b9e

download-models

Files changed (3):
  1. Dockerfile +14 -0
  2. download_models.py +1 -1
  3. src/server/main_hfy.py +910 -0
Dockerfile CHANGED
@@ -4,6 +4,20 @@ WORKDIR /app
 COPY dhwani_config.json .
 
 # Create a directory for pre-downloaded models
+
+
+RUN mkdir -p /app/models
+
+# Define build argument for HF_TOKEN
+ARG HF_TOKEN_DOCKER
+
+# Set environment variable for the build process
+ENV HF_TOKEN=$HF_TOKEN_DOCKER
+
+# Copy and run the model download script
+COPY download_models.py .
+RUN python download_models.py
+
 COPY . .
 
 # Set up user
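The new stage expects the Hugging Face token as a build argument, so the weights fetched by download_models.py are baked into the image instead of being downloaded at container start. A typical build invocation (image tag hypothetical) would be `docker build --build-arg HF_TOKEN_DOCKER=$HF_TOKEN -t dhwani-api .`.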
download_models.py CHANGED
@@ -10,7 +10,7 @@ if not hf_token:
 
 # Define the models to download
 models = {
-    'llm_model': ('google/gemma-3-4b-it', Gemma3ForConditionalGeneration, AutoProcessor),
+    #'llm_model': ('google/gemma-3-4b-it', Gemma3ForConditionalGeneration, AutoProcessor),
     'tts_model': ('ai4bharat/IndicF5', AutoModel, None),
     'asr_model': ('ai4bharat/indic-conformer-600m-multilingual', AutoModel, None),
     'trans_en_indic': ('ai4bharat/indictrans2-en-indic-dist-200M', AutoModelForSeq2SeqLM, AutoTokenizer),
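Each entry in `models` pairs a Hugging Face repo id with the model class used to fetch it and an optional processor/tokenizer class. A minimal sketch of the download loop the script presumably runs over this dict (the function name, `token` keyword, and cache directory are assumptions, not shown in the diff):

# Hypothetical sketch; only the `models` entries above come from the diff.
from transformers import AutoModel, AutoTokenizer, AutoModelForSeq2SeqLM

def download_all(models: dict, hf_token: str, cache_dir: str = "/app/models") -> None:
    for name, (repo_id, model_cls, processor_cls) in models.items():
        print(f"Downloading {name} from {repo_id}...")
        # from_pretrained caches the weights locally so runtime loads need no network
        model_cls.from_pretrained(repo_id, token=hf_token, cache_dir=cache_dir, trust_remote_code=True)
        if processor_cls is not None:
            processor_cls.from_pretrained(repo_id, token=hf_token, cache_dir=cache_dir)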
src/server/main_hfy.py ADDED
@@ -0,0 +1,910 @@
import argparse
import io
import os
from time import time
from typing import List, Optional
import tempfile
import uvicorn
from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, UploadFile, Body, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
from PIL import Image
from pydantic import BaseModel, field_validator
from pydantic_settings import BaseSettings
from slowapi import Limiter
from slowapi.util import get_remote_address
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoProcessor, BitsAndBytesConfig, AutoModel, Gemma3ForConditionalGeneration
from IndicTransToolkit import IndicProcessor
import json
import asyncio
from contextlib import asynccontextmanager
import soundfile as sf
import numpy as np
import requests
from logging_config import logger
from tts_config import SPEED, ResponseFormat, config as tts_config
import torchaudio

# Device setup
if torch.cuda.is_available():
    device = "cuda:0"
    logger.info("GPU will be used for inference")
else:
    device = "cpu"
    logger.info("CPU will be used for inference")
torch_dtype = torch.bfloat16 if device != "cpu" else torch.float32

# Check CUDA availability and version
cuda_available = torch.cuda.is_available()
cuda_version = torch.version.cuda if cuda_available else None

if torch.cuda.is_available():
    device_idx = torch.cuda.current_device()
    capability = torch.cuda.get_device_capability(device_idx)
    compute_capability_float = float(f"{capability[0]}.{capability[1]}")
    print(f"CUDA version: {cuda_version}")
    print(f"CUDA Compute Capability: {compute_capability_float}")
else:
    print("CUDA is not available on this system.")

# Settings
class Settings(BaseSettings):
    llm_model_name: str = "google/gemma-3-4b-it"
    max_tokens: int = 512
    host: str = "0.0.0.0"
    port: int = 7860
    chat_rate_limit: str = "100/minute"
    speech_rate_limit: str = "5/minute"

    @field_validator("chat_rate_limit", "speech_rate_limit")
    def validate_rate_limit(cls, v):
        if not v.count("/") == 1 or not v.split("/")[0].isdigit():
            raise ValueError("Rate limit must be in format 'number/period' (e.g., '5/minute')")
        return v

    class Config:
        env_file = ".env"

settings = Settings()

# Quantization config for LLM
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)

# LLM Manager
class LLMManager:
    def __init__(self, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
        self.model_name = model_name
        self.device = torch.device(device)
        self.torch_dtype = torch.bfloat16 if self.device.type != "cpu" else torch.float32
        self.model = None
        self.processor = None
        self.is_loaded = False
        logger.info(f"LLMManager initialized with model {model_name} on {self.device}")

    async def load(self):
        if not self.is_loaded:
            try:
                self.model = await asyncio.to_thread(
                    Gemma3ForConditionalGeneration.from_pretrained,
                    self.model_name,
                    device_map="auto",
                    quantization_config=quantization_config,
                    torch_dtype=self.torch_dtype
                )
                self.model.eval()
                self.processor = await asyncio.to_thread(
                    AutoProcessor.from_pretrained,
                    self.model_name
                )
                self.is_loaded = True
                logger.info(f"LLM {self.model_name} loaded asynchronously on {self.device}")
            except Exception as e:
                logger.error(f"Failed to load LLM: {str(e)}")
                raise

    def unload(self):
        if self.is_loaded:
            del self.model
            del self.processor
            if self.device.type == "cuda":
                torch.cuda.empty_cache()
                logger.info(f"GPU memory allocated after unload: {torch.cuda.memory_allocated()}")
            self.is_loaded = False
            logger.info(f"LLM {self.model_name} unloaded from {self.device}")

    async def generate(self, prompt: str, max_tokens: int = 512, temperature: float = 0.7) -> str:
        if not self.is_loaded:
            await self.load()

        messages_vlm = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state. Provide a concise response in one sentence maximum."}]
            },
            {
                "role": "user",
                "content": [{"type": "text", "text": prompt}]
            }
        ]

        try:
            inputs_vlm = self.processor.apply_chat_template(
                messages_vlm,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt"
            ).to(self.device, dtype=torch.bfloat16)
        except Exception as e:
            logger.error(f"Error in tokenization: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Tokenization failed: {str(e)}")

        input_len = inputs_vlm["input_ids"].shape[-1]

        with torch.inference_mode():
            generation = self.model.generate(
                **inputs_vlm,
                max_new_tokens=max_tokens,
                do_sample=True,
                temperature=temperature
            )
            generation = generation[0][input_len:]

        response = self.processor.decode(generation, skip_special_tokens=True)
        logger.info(f"Generated response: {response}")
        return response

    async def vision_query(self, image: Image.Image, query: str) -> str:
        if not self.is_loaded:
            await self.load()

        messages_vlm = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Summarize your answer in maximum 1 sentence."}]
            },
            {
                "role": "user",
                "content": []
            }
        ]

        messages_vlm[1]["content"].append({"type": "text", "text": query})
        if image and image.size[0] > 0 and image.size[1] > 0:
            messages_vlm[1]["content"].insert(0, {"type": "image", "image": image})
            logger.info("Received valid image for processing")
        else:
            logger.info("No valid image provided, processing text only")

        try:
            inputs_vlm = self.processor.apply_chat_template(
                messages_vlm,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt"
            ).to(self.device, dtype=torch.bfloat16)
        except Exception as e:
            logger.error(f"Error in apply_chat_template: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")

        input_len = inputs_vlm["input_ids"].shape[-1]

        with torch.inference_mode():
            generation = self.model.generate(
                **inputs_vlm,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.7
            )
            generation = generation[0][input_len:]

        decoded = self.processor.decode(generation, skip_special_tokens=True)
        logger.info(f"Vision query response: {decoded}")
        return decoded

    async def chat_v2(self, image: Image.Image, query: str) -> str:
        if not self.is_loaded:
            await self.load()

        messages_vlm = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state."}]
            },
            {
                "role": "user",
                "content": []
            }
        ]

        messages_vlm[1]["content"].append({"type": "text", "text": query})
        if image and image.size[0] > 0 and image.size[1] > 0:
            messages_vlm[1]["content"].insert(0, {"type": "image", "image": image})
            logger.info("Received valid image for processing")
        else:
            logger.info("No valid image provided, processing text only")

        try:
            inputs_vlm = self.processor.apply_chat_template(
                messages_vlm,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt"
            ).to(self.device, dtype=torch.bfloat16)
        except Exception as e:
            logger.error(f"Error in apply_chat_template: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")

        input_len = inputs_vlm["input_ids"].shape[-1]

        with torch.inference_mode():
            generation = self.model.generate(
                **inputs_vlm,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.7
            )
            generation = generation[0][input_len:]

        decoded = self.processor.decode(generation, skip_special_tokens=True)
        logger.info(f"Chat_v2 response: {decoded}")
        return decoded

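# Usage sketch (hypothetical, not part of the original file): a standalone
# driver showing how LLMManager would be exercised on its own. The prompt
# and token budget are illustrative.
async def _demo_llm_generate() -> None:
    manager = LLMManager(settings.llm_model_name)
    reply = await manager.generate("What is the capital of Karnataka?", max_tokens=64)
    print(reply)
    manager.unload()
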
# TTS Manager
class TTSManager:
    def __init__(self, device_type=device):
        self.device_type = device_type
        self.model = None
        self.repo_id = "ai4bharat/IndicF5"

    async def load(self):
        if not self.model:
            logger.info("Loading TTS model IndicF5 asynchronously...")
            self.model = await asyncio.to_thread(
                AutoModel.from_pretrained,
                self.repo_id,
                trust_remote_code=True
            )
            self.model = self.model.to(self.device_type)
            logger.info("TTS model IndicF5 loaded asynchronously")

    def synthesize(self, text, ref_audio_path, ref_text):
        if not self.model:
            raise ValueError("TTS model not loaded")
        return self.model(text, ref_audio_path=ref_audio_path, ref_text=ref_text)

# TTS Constants
EXAMPLES = [
    {
        "audio_name": "KAN_F (Happy)",
        "audio_url": "https://github.com/AI4Bharat/IndicF5/raw/refs/heads/main/prompts/KAN_F_HAPPY_00001.wav",
        "ref_text": "ನಮ್‌ ಫ್ರಿಜ್ಜಲ್ಲಿ ಕೂಲಿಂಗ್‌ ಸಮಸ್ಯೆ ಆಗಿ ನಾನ್‌ ಭಾಳ ದಿನದಿಂದ ಒದ್ದಾಡ್ತಿದ್ದೆ, ಆದ್ರೆ ಅದ್ನೀಗ ಮೆಕಾನಿಕ್ ಆಗಿರೋ ನಿಮ್‌ ಸಹಾಯ್ದಿಂದ ಬಗೆಹರಿಸ್ಕೋಬೋದು ಅಂತಾಗಿ ನಿರಾಳ ಆಯ್ತು ನಂಗೆ.",
        "synth_text": "ಚೆನ್ನೈನ ಶೇರ್ ಆಟೋ ಪ್ರಯಾಣಿಕರ ನಡುವೆ ಆಹಾರವನ್ನು ಹಂಚಿಕೊಂಡು ತಿನ್ನುವುದು ನನಗೆ ಮನಸ್ಸಿಗೆ ತುಂಬಾ ಒಳ್ಳೆಯದೆನಿಸುವ ವಿಷಯ."
    },
]

# Pydantic models for TTS
class SynthesizeRequest(BaseModel):
    text: str
    ref_audio_name: str
    ref_text: Optional[str] = None

class KannadaSynthesizeRequest(BaseModel):
    text: str

# TTS Functions
def load_audio_from_url(url: str):
    response = requests.get(url)
    if response.status_code == 200:
        audio_data, sample_rate = sf.read(io.BytesIO(response.content))
        return sample_rate, audio_data
    raise HTTPException(status_code=500, detail="Failed to load reference audio from URL.")

def synthesize_speech(tts_manager: TTSManager, text: str, ref_audio_name: str, ref_text: str):
    ref_audio_url = None
    for example in EXAMPLES:
        if example["audio_name"] == ref_audio_name:
            ref_audio_url = example["audio_url"]
            if not ref_text:
                ref_text = example["ref_text"]
            break

    if not ref_audio_url:
        raise HTTPException(status_code=400, detail="Invalid reference audio name.")
    if not text.strip():
        raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")
    if not ref_text or not ref_text.strip():
        raise HTTPException(status_code=400, detail="Reference text cannot be empty.")

    sample_rate, audio_data = load_audio_from_url(ref_audio_url)
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
        sf.write(temp_audio.name, audio_data, samplerate=sample_rate, format='WAV')
        temp_audio.flush()
        audio = tts_manager.synthesize(text, ref_audio_path=temp_audio.name, ref_text=ref_text)

    if audio.dtype == np.int16:
        audio = audio.astype(np.float32) / 32768.0
    buffer = io.BytesIO()
    sf.write(buffer, audio, 24000, format='WAV')
    buffer.seek(0)
    return buffer

# Supported languages
SUPPORTED_LANGUAGES = {
    "asm_Beng", "kas_Arab", "pan_Guru", "ben_Beng", "kas_Deva", "san_Deva",
    "brx_Deva", "mai_Deva", "sat_Olck", "doi_Deva", "mal_Mlym", "snd_Arab",
    "eng_Latn", "mar_Deva", "snd_Deva", "gom_Deva", "mni_Beng", "tam_Taml",
    "guj_Gujr", "mni_Mtei", "tel_Telu", "hin_Deva", "npi_Deva", "urd_Arab",
    "kan_Knda", "ory_Orya",
    "deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn",
    "por_Latn", "rus_Cyrl", "pol_Latn"
}

# Translation Manager
class TranslateManager:
    def __init__(self, src_lang, tgt_lang, device_type=device, use_distilled=True):
        self.device_type = device_type
        self.tokenizer = None
        self.model = None
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.use_distilled = use_distilled

    async def load(self):
        if not self.tokenizer or not self.model:
            if self.src_lang.startswith("eng") and not self.tgt_lang.startswith("eng"):
                model_name = "ai4bharat/indictrans2-en-indic-dist-200M" if self.use_distilled else "ai4bharat/indictrans2-en-indic-1B"
            elif not self.src_lang.startswith("eng") and self.tgt_lang.startswith("eng"):
                model_name = "ai4bharat/indictrans2-indic-en-dist-200M" if self.use_distilled else "ai4bharat/indictrans2-indic-en-1B"
            elif not self.src_lang.startswith("eng") and not self.tgt_lang.startswith("eng"):
                model_name = "ai4bharat/indictrans2-indic-indic-dist-320M" if self.use_distilled else "ai4bharat/indictrans2-indic-indic-1B"
            else:
                raise ValueError("Invalid language combination")

            self.tokenizer = await asyncio.to_thread(
                AutoTokenizer.from_pretrained,
                model_name,
                trust_remote_code=True
            )
            self.model = await asyncio.to_thread(
                AutoModelForSeq2SeqLM.from_pretrained,
                model_name,
                trust_remote_code=True,
                torch_dtype=torch.float16,
                attn_implementation="flash_attention_2"
            )
            self.model = self.model.to(self.device_type)
            self.model = torch.compile(self.model, mode="reduce-overhead")
            logger.info(f"Translation model {model_name} loaded asynchronously")

class ModelManager:
    def __init__(self, device_type=device, use_distilled=True, is_lazy_loading=False):
        self.models = {}
        self.device_type = device_type
        self.use_distilled = use_distilled
        self.is_lazy_loading = is_lazy_loading

    async def load_model(self, src_lang, tgt_lang, key):
        logger.info(f"Loading translation model for {src_lang} -> {tgt_lang} asynchronously")
        translate_manager = TranslateManager(src_lang, tgt_lang, self.device_type, self.use_distilled)
        await translate_manager.load()
        self.models[key] = translate_manager
        logger.info(f"Loaded translation model for {key} asynchronously")

    def get_model(self, src_lang, tgt_lang):
        key = self._get_model_key(src_lang, tgt_lang)
        if key not in self.models:
            if self.is_lazy_loading:
                asyncio.create_task(self.load_model(src_lang, tgt_lang, key))
            else:
                raise ValueError(f"Model for {key} is not preloaded and lazy loading is disabled.")
        return self.models.get(key)

    def _get_model_key(self, src_lang, tgt_lang):
        if src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
            return 'eng_indic'
        elif not src_lang.startswith("eng") and tgt_lang.startswith("eng"):
            return 'indic_eng'
        elif not src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
            return 'indic_indic'
        raise ValueError("Invalid language combination")

# ASR Manager
class ASRModelManager:
    def __init__(self, device_type=device):
        self.device_type = device_type
        self.model = None
        self.model_language = {"kannada": "kn"}

    async def load(self):
        if not self.model:
            logger.info("Loading ASR model asynchronously...")
            self.model = await asyncio.to_thread(
                AutoModel.from_pretrained,
                "ai4bharat/indic-conformer-600m-multilingual",
                trust_remote_code=True
            )
            self.model = self.model.to(self.device_type)
            logger.info("ASR model loaded asynchronously")

# Global Managers
llm_manager = LLMManager(settings.llm_model_name)
model_manager = ModelManager()
asr_manager = ASRModelManager()
tts_manager = TTSManager()
ip = IndicProcessor(inference=True)

# Pydantic Models
class ChatRequest(BaseModel):
    prompt: str
    src_lang: str = "kan_Knda"
    tgt_lang: str = "kan_Knda"

    @field_validator("prompt")
    def prompt_must_be_valid(cls, v):
        if len(v) > 1000:
            raise ValueError("Prompt cannot exceed 1000 characters")
        return v.strip()

    @field_validator("src_lang", "tgt_lang")
    def validate_language(cls, v):
        if v not in SUPPORTED_LANGUAGES:
            raise ValueError(f"Unsupported language code: {v}. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
        return v

class ChatResponse(BaseModel):
    response: str

class TranslationRequest(BaseModel):
    sentences: List[str]
    src_lang: str
    tgt_lang: str

class TranscriptionResponse(BaseModel):
    text: str

class TranslationResponse(BaseModel):
    translations: List[str]

# Dependency
def get_translate_manager(src_lang: str, tgt_lang: str) -> TranslateManager:
    return model_manager.get_model(src_lang, tgt_lang)

# Lifespan Event Handler
translation_configs = []

@asynccontextmanager
async def lifespan(app: FastAPI):
    async def load_all_models():
        try:
            tasks = [
                llm_manager.load(),
                tts_manager.load(),
                asr_manager.load(),
            ]

            translation_tasks = [
                model_manager.load_model('eng_Latn', 'kan_Knda', 'eng_indic'),
                model_manager.load_model('kan_Knda', 'eng_Latn', 'indic_eng'),
                model_manager.load_model('kan_Knda', 'hin_Deva', 'indic_indic'),
            ]

            for config in translation_configs:
                src_lang = config["src_lang"]
                tgt_lang = config["tgt_lang"]
                key = model_manager._get_model_key(src_lang, tgt_lang)
                translation_tasks.append(model_manager.load_model(src_lang, tgt_lang, key))

            await asyncio.gather(*tasks, *translation_tasks)
            logger.info("All models loaded successfully asynchronously")
        except Exception as e:
            logger.error(f"Error loading models: {str(e)}")
            raise

    logger.info("Starting asynchronous model loading...")
    await load_all_models()
    yield
    llm_manager.unload()
    logger.info("Server shutdown complete")

# FastAPI App
app = FastAPI(
    title="Dhwani API",
    description="AI Chat API supporting Indian languages",
    version="1.0.0",
    redirect_slashes=False,
    lifespan=lifespan
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter

# API Endpoints
@app.post("/audio/speech", response_class=StreamingResponse)
async def synthesize_kannada(request: KannadaSynthesizeRequest):
    if not tts_manager.model:
        raise HTTPException(status_code=503, detail="TTS model not loaded")
    kannada_example = next(ex for ex in EXAMPLES if ex["audio_name"] == "KAN_F (Happy)")
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")

    audio_buffer = synthesize_speech(
        tts_manager,
        text=request.text,
        ref_audio_name="KAN_F (Happy)",
        ref_text=kannada_example["ref_text"]
    )

    return StreamingResponse(
        audio_buffer,
        media_type="audio/wav",
        headers={"Content-Disposition": "attachment; filename=synthesized_kannada_speech.wav"}
    )

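# Client sketch (hypothetical; not part of the service): calling the
# /audio/speech endpoint above with `requests` and saving the streamed WAV.
# The base URL assumes the default host/port from Settings.
def _example_tts_client(base_url: str = "http://localhost:7860") -> None:
    resp = requests.post(f"{base_url}/audio/speech", json={"text": "ನಮಸ್ಕಾರ"})
    resp.raise_for_status()
    with open("synthesized_kannada_speech.wav", "wb") as f:
        f.write(resp.content)
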
@app.post("/translate", response_model=TranslationResponse)
async def translate(request: TranslationRequest, translate_manager: TranslateManager = Depends(get_translate_manager)):
    input_sentences = request.sentences
    src_lang = request.src_lang
    tgt_lang = request.tgt_lang

    if not input_sentences:
        raise HTTPException(status_code=400, detail="Input sentences are required")

    batch = ip.preprocess_batch(input_sentences, src_lang=src_lang, tgt_lang=tgt_lang)
    inputs = translate_manager.tokenizer(
        batch,
        truncation=True,
        padding="longest",
        return_tensors="pt",
        return_attention_mask=True,
    ).to(translate_manager.device_type)

    with torch.no_grad():
        generated_tokens = translate_manager.model.generate(
            **inputs,
            use_cache=True,
            min_length=0,
            max_length=256,
            num_beams=5,
            num_return_sequences=1,
        )

    with translate_manager.tokenizer.as_target_tokenizer():
        generated_tokens = translate_manager.tokenizer.batch_decode(
            generated_tokens.detach().cpu().tolist(),
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

    translations = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
    return TranslationResponse(translations=translations)

async def perform_internal_translation(sentences: List[str], src_lang: str, tgt_lang: str) -> List[str]:
    try:
        translate_manager = model_manager.get_model(src_lang, tgt_lang)
    except ValueError as e:
        logger.info(f"Model not preloaded: {str(e)}, loading now...")
        key = model_manager._get_model_key(src_lang, tgt_lang)
        await model_manager.load_model(src_lang, tgt_lang, key)
        translate_manager = model_manager.get_model(src_lang, tgt_lang)

    if not translate_manager.model:
        await translate_manager.load()

    request = TranslationRequest(sentences=sentences, src_lang=src_lang, tgt_lang=tgt_lang)
    response = await translate(request, translate_manager)
    return response.translations

@app.get("/v1/health")
async def health_check():
    return {"status": "healthy", "model": settings.llm_model_name}

@app.get("/")
async def home():
    return RedirectResponse(url="/docs")

@app.post("/v1/unload_all_models")
async def unload_all_models():
    try:
        logger.info("Starting to unload all models...")
        llm_manager.unload()
        logger.info("All models unloaded successfully")
        return {"status": "success", "message": "All models unloaded"}
    except Exception as e:
        logger.error(f"Error unloading models: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to unload models: {str(e)}")

@app.post("/v1/load_all_models")
async def load_all_models():
    try:
        logger.info("Starting to load all models...")
        await llm_manager.load()
        logger.info("All models loaded successfully")
        return {"status": "success", "message": "All models loaded"}
    except Exception as e:
        logger.error(f"Error loading models: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to load models: {str(e)}")

@app.post("/v1/translate", response_model=TranslationResponse)
async def translate_endpoint(request: TranslationRequest):
    logger.info(f"Received translation request: {request.model_dump()}")
    try:
        translations = await perform_internal_translation(
            sentences=request.sentences,
            src_lang=request.src_lang,
            tgt_lang=request.tgt_lang
        )
        logger.info(f"Translation successful: {translations}")
        return TranslationResponse(translations=translations)
    except Exception as e:
        logger.error(f"Unexpected error during translation: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}")

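# Client sketch (hypothetical): exercising /v1/translate; language codes
# must be members of SUPPORTED_LANGUAGES above.
def _example_translate_client(base_url: str = "http://localhost:7860") -> List[str]:
    resp = requests.post(
        f"{base_url}/v1/translate",
        json={
            "sentences": ["Hello, how are you?"],
            "src_lang": "eng_Latn",
            "tgt_lang": "kan_Knda",
        },
    )
    resp.raise_for_status()
    return resp.json()["translations"]
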
@app.post("/v1/chat", response_model=ChatResponse)
@limiter.limit(settings.chat_rate_limit)
async def chat(request: Request, chat_request: ChatRequest):
    if not chat_request.prompt:
        raise HTTPException(status_code=400, detail="Prompt cannot be empty")
    logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, tgt_lang: {chat_request.tgt_lang}")

    EUROPEAN_LANGUAGES = {"deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn", "por_Latn", "rus_Cyrl", "pol_Latn"}

    try:
        if chat_request.src_lang != "eng_Latn" and chat_request.src_lang not in EUROPEAN_LANGUAGES:
            translated_prompt = await perform_internal_translation(
                sentences=[chat_request.prompt],
                src_lang=chat_request.src_lang,
                tgt_lang="eng_Latn"
            )
            prompt_to_process = translated_prompt[0]
            logger.info(f"Translated prompt to English: {prompt_to_process}")
        else:
            prompt_to_process = chat_request.prompt
            logger.info("Prompt in English or European language, no translation needed")

        response = await llm_manager.generate(prompt_to_process, settings.max_tokens)
        logger.info(f"Generated response: {response}")

        if chat_request.tgt_lang != "eng_Latn" and chat_request.tgt_lang not in EUROPEAN_LANGUAGES:
            translated_response = await perform_internal_translation(
                sentences=[response],
                src_lang="eng_Latn",
                tgt_lang=chat_request.tgt_lang
            )
            final_response = translated_response[0]
            logger.info(f"Translated response to {chat_request.tgt_lang}: {final_response}")
        else:
            final_response = response
            logger.info(f"Response in {chat_request.tgt_lang}, no translation needed")

        return ChatResponse(response=final_response)
    except Exception as e:
        logger.error(f"Error processing request: {str(e)}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")

@app.post("/v1/visual_query/")
async def visual_query(
    file: UploadFile = File(...),
    query: str = Body(...),
    src_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
    tgt_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
):
    try:
        image = Image.open(file.file)
        if image.size == (0, 0):
            raise HTTPException(status_code=400, detail="Uploaded image is empty or invalid")

        if src_lang != "eng_Latn":
            translated_query = await perform_internal_translation(
                sentences=[query],
                src_lang=src_lang,
                tgt_lang="eng_Latn"
            )
            query_to_process = translated_query[0]
            logger.info(f"Translated query to English: {query_to_process}")
        else:
            query_to_process = query
            logger.info("Query already in English, no translation needed")

        answer = await llm_manager.vision_query(image, query_to_process)
        logger.info(f"Generated English answer: {answer}")

        if tgt_lang != "eng_Latn":
            translated_answer = await perform_internal_translation(
                sentences=[answer],
                src_lang="eng_Latn",
                tgt_lang=tgt_lang
            )
            final_answer = translated_answer[0]
            logger.info(f"Translated answer to {tgt_lang}: {final_answer}")
        else:
            final_answer = answer
            logger.info("Answer kept in English, no translation needed")

        return {"answer": final_answer}
    except Exception as e:
        logger.error(f"Error processing request: {str(e)}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")

@app.post("/v1/chat_v2", response_model=ChatResponse)
@limiter.limit(settings.chat_rate_limit)
async def chat_v2(
    request: Request,
    prompt: str = Form(...),
    image: UploadFile = File(default=None),
    src_lang: str = Form("kan_Knda"),
    tgt_lang: str = Form("kan_Knda"),
):
    if not prompt:
        raise HTTPException(status_code=400, detail="Prompt cannot be empty")
    if src_lang not in SUPPORTED_LANGUAGES or tgt_lang not in SUPPORTED_LANGUAGES:
        raise HTTPException(status_code=400, detail=f"Unsupported language code. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")

    logger.info(f"Received prompt: {prompt}, src_lang: {src_lang}, tgt_lang: {tgt_lang}, Image provided: {image is not None}")

    try:
        if image:
            image_data = await image.read()
            if not image_data:
                raise HTTPException(status_code=400, detail="Uploaded image is empty")
            img = Image.open(io.BytesIO(image_data))

            if src_lang != "eng_Latn":
                translated_prompt = await perform_internal_translation(
                    sentences=[prompt],
                    src_lang=src_lang,
                    tgt_lang="eng_Latn"
                )
                prompt_to_process = translated_prompt[0]
                logger.info(f"Translated prompt to English: {prompt_to_process}")
            else:
                prompt_to_process = prompt
                logger.info("Prompt already in English, no translation needed")

            decoded = await llm_manager.chat_v2(img, prompt_to_process)
            logger.info(f"Generated English response: {decoded}")

            if tgt_lang != "eng_Latn":
                translated_response = await perform_internal_translation(
                    sentences=[decoded],
                    src_lang="eng_Latn",
                    tgt_lang=tgt_lang
                )
                final_response = translated_response[0]
                logger.info(f"Translated response to {tgt_lang}: {final_response}")
            else:
                final_response = decoded
                logger.info("Response kept in English, no translation needed")
        else:
            if src_lang != "eng_Latn":
                translated_prompt = await perform_internal_translation(
                    sentences=[prompt],
                    src_lang=src_lang,
                    tgt_lang="eng_Latn"
                )
                prompt_to_process = translated_prompt[0]
                logger.info(f"Translated prompt to English: {prompt_to_process}")
            else:
                prompt_to_process = prompt
                logger.info("Prompt already in English, no translation needed")

            decoded = await llm_manager.generate(prompt_to_process, settings.max_tokens)
            logger.info(f"Generated English response: {decoded}")

            if tgt_lang != "eng_Latn":
                translated_response = await perform_internal_translation(
                    sentences=[decoded],
                    src_lang="eng_Latn",
                    tgt_lang=tgt_lang
                )
                final_response = translated_response[0]
                logger.info(f"Translated response to {tgt_lang}: {final_response}")
            else:
                final_response = decoded
                logger.info("Response kept in English, no translation needed")

        return ChatResponse(response=final_response)
    except Exception as e:
        logger.error(f"Error processing request: {str(e)}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")

@app.post("/transcribe/", response_model=TranscriptionResponse)
async def transcribe_audio(file: UploadFile = File(...), language: str = Query(..., enum=list(asr_manager.model_language.keys()))):
    if not asr_manager.model:
        raise HTTPException(status_code=503, detail="ASR model not loaded")
    try:
        wav, sr = torchaudio.load(file.file)
        wav = torch.mean(wav, dim=0, keepdim=True)
        target_sample_rate = 16000
        if sr != target_sample_rate:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
            wav = resampler(wav)
        transcription_rnnt = asr_manager.model(wav, asr_manager.model_language[language], "rnnt")
        return TranscriptionResponse(text=transcription_rnnt)
    except Exception as e:
        logger.error(f"Error in transcription: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")

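# Client sketch (hypothetical): uploading a WAV file to /transcribe/;
# "kannada" is the only language key registered by default.
def _example_transcribe_client(path: str, base_url: str = "http://localhost:7860") -> str:
    with open(path, "rb") as f:
        resp = requests.post(
            f"{base_url}/transcribe/",
            params={"language": "kannada"},
            files={"file": f},
        )
    resp.raise_for_status()
    return resp.json()["text"]
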
@app.post("/v1/speech_to_speech")
async def speech_to_speech(
    request: Request,
    file: UploadFile = File(...),
    language: str = Query(..., enum=list(asr_manager.model_language.keys())),
) -> StreamingResponse:
    if not tts_manager.model:
        raise HTTPException(status_code=503, detail="TTS model not loaded")
    transcription = await transcribe_audio(file, language)
    logger.info(f"Transcribed text: {transcription.text}")

    chat_request = ChatRequest(
        prompt=transcription.text,
        src_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda"),
        tgt_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda")
    )
    processed_text = await chat(request, chat_request)
    logger.info(f"Processed text: {processed_text.response}")

    voice_request = KannadaSynthesizeRequest(text=processed_text.response)
    audio_response = await synthesize_kannada(voice_request)
    return audio_response

LANGUAGE_TO_SCRIPT = {
    "kannada": "kan_Knda"
}

# Main Execution
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the FastAPI server.")
    parser.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")
    parser.add_argument("--host", type=str, default=settings.host, help="Host to run the server on.")
    parser.add_argument("--config", type=str, default="config_one", help="Configuration to use")
    args = parser.parse_args()

    def load_config(config_path="dhwani_config.json"):
        with open(config_path, "r") as f:
            return json.load(f)

    config_data = load_config()
    if args.config not in config_data["configs"]:
        raise ValueError(f"Invalid config: {args.config}. Available: {list(config_data['configs'].keys())}")

    selected_config = config_data["configs"][args.config]
    global_settings = config_data["global_settings"]

    settings.llm_model_name = selected_config["components"]["LLM"]["model"]
    settings.max_tokens = selected_config["components"]["LLM"]["max_tokens"]
    settings.host = global_settings["host"]
    settings.port = global_settings["port"]
    settings.chat_rate_limit = global_settings["chat_rate_limit"]
    settings.speech_rate_limit = global_settings["speech_rate_limit"]

    llm_manager = LLMManager(settings.llm_model_name)

    if selected_config["components"].get("ASR"):
        asr_model_name = selected_config["components"]["ASR"]["model"]
        asr_manager.model_language[selected_config["language"]] = selected_config["components"]["ASR"]["language_code"]

    if selected_config["components"].get("Translation"):
        translation_configs.extend(selected_config["components"]["Translation"])

    # The parsed CLI values (which default to the pre-config Settings) decide the bind address
    host = args.host
    port = args.port

    uvicorn.run(app, host=host, port=port)
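With dhwani_config.json present, the server can be started with, for example, `python src/server/main_hfy.py --config config_one`; `--host` and `--port` default to the Settings values (0.0.0.0:7860).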