sachin
commited on
Commit
·
0fb44d7
1
Parent(s):
7fbf9f0
test
Browse files- src/server/main.py +23 -20
src/server/main.py
CHANGED
@@ -14,7 +14,7 @@ from pydantic_settings import BaseSettings
|
|
14 |
from slowapi import Limiter
|
15 |
from slowapi.util import get_remote_address
|
16 |
import torch
|
17 |
-
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoProcessor,
|
18 |
from IndicTransToolkit import IndicProcessor
|
19 |
import json
|
20 |
import asyncio
|
@@ -91,9 +91,10 @@ class LLMManager:
|
|
91 |
async def load(self):
|
92 |
if not self.is_loaded:
|
93 |
try:
|
|
|
94 |
self.model = await asyncio.to_thread(
|
95 |
Gemma3ForConditionalGeneration.from_pretrained,
|
96 |
-
|
97 |
device_map="auto",
|
98 |
quantization_config=quantization_config,
|
99 |
torch_dtype=self.torch_dtype
|
@@ -101,10 +102,10 @@ class LLMManager:
|
|
101 |
self.model.eval()
|
102 |
self.processor = await asyncio.to_thread(
|
103 |
AutoProcessor.from_pretrained,
|
104 |
-
|
105 |
)
|
106 |
self.is_loaded = True
|
107 |
-
logger.info(f"LLM {
|
108 |
except Exception as e:
|
109 |
logger.error(f"Failed to load LLM: {str(e)}")
|
110 |
raise
|
@@ -268,14 +269,15 @@ class TTSManager:
|
|
268 |
|
269 |
async def load(self):
|
270 |
if not self.model:
|
271 |
-
logger.info("Loading TTS model
|
|
|
272 |
self.model = await asyncio.to_thread(
|
273 |
AutoModel.from_pretrained,
|
274 |
-
|
275 |
trust_remote_code=True
|
276 |
)
|
277 |
self.model = self.model.to(self.device_type)
|
278 |
-
logger.info("TTS model
|
279 |
|
280 |
def synthesize(self, text, ref_audio_path, ref_text):
|
281 |
if not self.model:
|
@@ -362,29 +364,29 @@ class TranslateManager:
|
|
362 |
async def load(self):
|
363 |
if not self.tokenizer or not self.model:
|
364 |
if self.src_lang.startswith("eng") and not self.tgt_lang.startswith("eng"):
|
365 |
-
|
366 |
elif not self.src_lang.startswith("eng") and self.tgt_lang.startswith("eng"):
|
367 |
-
|
368 |
elif not self.src_lang.startswith("eng") and not self.tgt_lang.startswith("eng"):
|
369 |
-
|
370 |
else:
|
371 |
raise ValueError("Invalid language combination")
|
372 |
|
373 |
self.tokenizer = await asyncio.to_thread(
|
374 |
AutoTokenizer.from_pretrained,
|
375 |
-
|
376 |
trust_remote_code=True
|
377 |
)
|
378 |
self.model = await asyncio.to_thread(
|
379 |
AutoModelForSeq2SeqLM.from_pretrained,
|
380 |
-
|
381 |
trust_remote_code=True,
|
382 |
torch_dtype=torch.float16,
|
383 |
attn_implementation="flash_attention_2"
|
384 |
)
|
385 |
self.model = self.model.to(self.device_type)
|
386 |
self.model = torch.compile(self.model, mode="reduce-overhead")
|
387 |
-
logger.info(f"Translation model {
|
388 |
|
389 |
class ModelManager:
|
390 |
def __init__(self, device_type=device, use_distilled=True, is_lazy_loading=False):
|
@@ -394,11 +396,11 @@ class ModelManager:
|
|
394 |
self.is_lazy_loading = is_lazy_loading
|
395 |
|
396 |
async def load_model(self, src_lang, tgt_lang, key):
|
397 |
-
logger.info(f"Loading translation model for {src_lang} -> {tgt_lang}
|
398 |
translate_manager = TranslateManager(src_lang, tgt_lang, self.device_type, self.use_distilled)
|
399 |
await translate_manager.load()
|
400 |
self.models[key] = translate_manager
|
401 |
-
logger.info(f"Loaded translation model for {key}
|
402 |
|
403 |
def get_model(self, src_lang, tgt_lang):
|
404 |
key = self._get_model_key(src_lang, tgt_lang)
|
@@ -427,14 +429,15 @@ class ASRModelManager:
|
|
427 |
|
428 |
async def load(self):
|
429 |
if not self.model:
|
430 |
-
logger.info("Loading ASR model asynchronously...")
|
|
|
431 |
self.model = await asyncio.to_thread(
|
432 |
AutoModel.from_pretrained,
|
433 |
-
|
434 |
trust_remote_code=True
|
435 |
)
|
436 |
self.model = self.model.to(self.device_type)
|
437 |
-
logger.info("ASR model loaded asynchronously")
|
438 |
|
439 |
# Global Managers
|
440 |
llm_manager = LLMManager(settings.llm_model_name)
|
@@ -505,12 +508,12 @@ async def lifespan(app: FastAPI):
|
|
505 |
translation_tasks.append(model_manager.load_model(src_lang, tgt_lang, key))
|
506 |
|
507 |
await asyncio.gather(*tasks, *translation_tasks)
|
508 |
-
logger.info("All models loaded successfully
|
509 |
except Exception as e:
|
510 |
logger.error(f"Error loading models: {str(e)}")
|
511 |
raise
|
512 |
|
513 |
-
logger.info("Starting asynchronous model loading...")
|
514 |
await load_all_models()
|
515 |
yield
|
516 |
llm_manager.unload()
|
|
|
14 |
from slowapi import Limiter
|
15 |
from slowapi.util import get_remote_address
|
16 |
import torch
|
17 |
+
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoProcessor, AutoModel, BitsAndBytesConfig, Gemma3ForConditionalGeneration
|
18 |
from IndicTransToolkit import IndicProcessor
|
19 |
import json
|
20 |
import asyncio
|
|
|
91 |
async def load(self):
|
92 |
if not self.is_loaded:
|
93 |
try:
|
94 |
+
local_path = "/app/models/llm_model"
|
95 |
self.model = await asyncio.to_thread(
|
96 |
Gemma3ForConditionalGeneration.from_pretrained,
|
97 |
+
local_path,
|
98 |
device_map="auto",
|
99 |
quantization_config=quantization_config,
|
100 |
torch_dtype=self.torch_dtype
|
|
|
102 |
self.model.eval()
|
103 |
self.processor = await asyncio.to_thread(
|
104 |
AutoProcessor.from_pretrained,
|
105 |
+
local_path
|
106 |
)
|
107 |
self.is_loaded = True
|
108 |
+
logger.info(f"LLM loaded from {local_path} on {self.device}")
|
109 |
except Exception as e:
|
110 |
logger.error(f"Failed to load LLM: {str(e)}")
|
111 |
raise
|
|
|
269 |
|
270 |
async def load(self):
|
271 |
if not self.model:
|
272 |
+
logger.info("Loading TTS model from local path asynchronously...")
|
273 |
+
local_path = "/app/models/tts_model"
|
274 |
self.model = await asyncio.to_thread(
|
275 |
AutoModel.from_pretrained,
|
276 |
+
local_path,
|
277 |
trust_remote_code=True
|
278 |
)
|
279 |
self.model = self.model.to(self.device_type)
|
280 |
+
logger.info("TTS model loaded from local path asynchronously")
|
281 |
|
282 |
def synthesize(self, text, ref_audio_path, ref_text):
|
283 |
if not self.model:
|
|
|
364 |
async def load(self):
|
365 |
if not self.tokenizer or not self.model:
|
366 |
if self.src_lang.startswith("eng") and not self.tgt_lang.startswith("eng"):
|
367 |
+
local_path = "/app/models/trans_en_indic"
|
368 |
elif not self.src_lang.startswith("eng") and self.tgt_lang.startswith("eng"):
|
369 |
+
local_path = "/app/models/trans_indic_en"
|
370 |
elif not self.src_lang.startswith("eng") and not self.tgt_lang.startswith("eng"):
|
371 |
+
local_path = "/app/models/trans_indic_indic"
|
372 |
else:
|
373 |
raise ValueError("Invalid language combination")
|
374 |
|
375 |
self.tokenizer = await asyncio.to_thread(
|
376 |
AutoTokenizer.from_pretrained,
|
377 |
+
local_path,
|
378 |
trust_remote_code=True
|
379 |
)
|
380 |
self.model = await asyncio.to_thread(
|
381 |
AutoModelForSeq2SeqLM.from_pretrained,
|
382 |
+
local_path,
|
383 |
trust_remote_code=True,
|
384 |
torch_dtype=torch.float16,
|
385 |
attn_implementation="flash_attention_2"
|
386 |
)
|
387 |
self.model = self.model.to(self.device_type)
|
388 |
self.model = torch.compile(self.model, mode="reduce-overhead")
|
389 |
+
logger.info(f"Translation model loaded from {local_path} asynchronously")
|
390 |
|
391 |
class ModelManager:
|
392 |
def __init__(self, device_type=device, use_distilled=True, is_lazy_loading=False):
|
|
|
396 |
self.is_lazy_loading = is_lazy_loading
|
397 |
|
398 |
async def load_model(self, src_lang, tgt_lang, key):
|
399 |
+
logger.info(f"Loading translation model for {src_lang} -> {tgt_lang} from local path")
|
400 |
translate_manager = TranslateManager(src_lang, tgt_lang, self.device_type, self.use_distilled)
|
401 |
await translate_manager.load()
|
402 |
self.models[key] = translate_manager
|
403 |
+
logger.info(f"Loaded translation model for {key} from local path")
|
404 |
|
405 |
def get_model(self, src_lang, tgt_lang):
|
406 |
key = self._get_model_key(src_lang, tgt_lang)
|
|
|
429 |
|
430 |
async def load(self):
|
431 |
if not self.model:
|
432 |
+
logger.info("Loading ASR model from local path asynchronously...")
|
433 |
+
local_path = "/app/models/asr_model"
|
434 |
self.model = await asyncio.to_thread(
|
435 |
AutoModel.from_pretrained,
|
436 |
+
local_path,
|
437 |
trust_remote_code=True
|
438 |
)
|
439 |
self.model = self.model.to(self.device_type)
|
440 |
+
logger.info("ASR model loaded from local path asynchronously")
|
441 |
|
442 |
# Global Managers
|
443 |
llm_manager = LLMManager(settings.llm_model_name)
|
|
|
508 |
translation_tasks.append(model_manager.load_model(src_lang, tgt_lang, key))
|
509 |
|
510 |
await asyncio.gather(*tasks, *translation_tasks)
|
511 |
+
logger.info("All models loaded successfully from local paths")
|
512 |
except Exception as e:
|
513 |
logger.error(f"Error loading models: {str(e)}")
|
514 |
raise
|
515 |
|
516 |
+
logger.info("Starting asynchronous model loading from local paths...")
|
517 |
await load_all_models()
|
518 |
yield
|
519 |
llm_manager.unload()
|