sachin committed
Commit 0fb44d7 · 1 Parent(s): 7fbf9f0
Files changed (1): src/server/main.py (+23 -20)

src/server/main.py CHANGED
@@ -14,7 +14,7 @@ from pydantic_settings import BaseSettings
 from slowapi import Limiter
 from slowapi.util import get_remote_address
 import torch
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoProcessor, BitsAndBytesConfig, AutoModel, Gemma3ForConditionalGeneration
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoProcessor, AutoModel, BitsAndBytesConfig, Gemma3ForConditionalGeneration
 from IndicTransToolkit import IndicProcessor
 import json
 import asyncio
@@ -91,9 +91,10 @@ class LLMManager:
     async def load(self):
         if not self.is_loaded:
             try:
+                local_path = "/app/models/llm_model"
                 self.model = await asyncio.to_thread(
                     Gemma3ForConditionalGeneration.from_pretrained,
-                    self.model_name,
+                    local_path,
                     device_map="auto",
                     quantization_config=quantization_config,
                     torch_dtype=self.torch_dtype
@@ -101,10 +102,10 @@ class LLMManager:
                 self.model.eval()
                 self.processor = await asyncio.to_thread(
                     AutoProcessor.from_pretrained,
-                    self.model_name
+                    local_path
                 )
                 self.is_loaded = True
-                logger.info(f"LLM {self.model_name} loaded asynchronously on {self.device}")
+                logger.info(f"LLM loaded from {local_path} on {self.device}")
             except Exception as e:
                 logger.error(f"Failed to load LLM: {str(e)}")
                 raise
@@ -268,14 +269,15 @@ class TTSManager:
 
     async def load(self):
         if not self.model:
-            logger.info("Loading TTS model IndicF5 asynchronously...")
+            logger.info("Loading TTS model from local path asynchronously...")
+            local_path = "/app/models/tts_model"
             self.model = await asyncio.to_thread(
                 AutoModel.from_pretrained,
-                self.repo_id,
+                local_path,
                 trust_remote_code=True
             )
             self.model = self.model.to(self.device_type)
-            logger.info("TTS model IndicF5 loaded asynchronously")
+            logger.info("TTS model loaded from local path asynchronously")
 
     def synthesize(self, text, ref_audio_path, ref_text):
         if not self.model:
@@ -362,29 +364,29 @@ class TranslateManager:
     async def load(self):
         if not self.tokenizer or not self.model:
             if self.src_lang.startswith("eng") and not self.tgt_lang.startswith("eng"):
-                model_name = "ai4bharat/indictrans2-en-indic-dist-200M" if self.use_distilled else "ai4bharat/indictrans2-en-indic-1B"
+                local_path = "/app/models/trans_en_indic"
             elif not self.src_lang.startswith("eng") and self.tgt_lang.startswith("eng"):
-                model_name = "ai4bharat/indictrans2-indic-en-dist-200M" if self.use_distilled else "ai4bharat/indictrans2-indic-en-1B"
+                local_path = "/app/models/trans_indic_en"
             elif not self.src_lang.startswith("eng") and not self.tgt_lang.startswith("eng"):
-                model_name = "ai4bharat/indictrans2-indic-indic-dist-320M" if self.use_distilled else "ai4bharat/indictrans2-indic-indic-1B"
+                local_path = "/app/models/trans_indic_indic"
             else:
                 raise ValueError("Invalid language combination")
 
             self.tokenizer = await asyncio.to_thread(
                 AutoTokenizer.from_pretrained,
-                model_name,
+                local_path,
                 trust_remote_code=True
             )
             self.model = await asyncio.to_thread(
                 AutoModelForSeq2SeqLM.from_pretrained,
-                model_name,
+                local_path,
                 trust_remote_code=True,
                 torch_dtype=torch.float16,
                 attn_implementation="flash_attention_2"
             )
             self.model = self.model.to(self.device_type)
             self.model = torch.compile(self.model, mode="reduce-overhead")
-            logger.info(f"Translation model {model_name} loaded asynchronously")
+            logger.info(f"Translation model loaded from {local_path} asynchronously")
 
 class ModelManager:
     def __init__(self, device_type=device, use_distilled=True, is_lazy_loading=False):
@@ -394,11 +396,11 @@ class ModelManager:
         self.is_lazy_loading = is_lazy_loading
 
     async def load_model(self, src_lang, tgt_lang, key):
-        logger.info(f"Loading translation model for {src_lang} -> {tgt_lang} asynchronously")
+        logger.info(f"Loading translation model for {src_lang} -> {tgt_lang} from local path")
         translate_manager = TranslateManager(src_lang, tgt_lang, self.device_type, self.use_distilled)
         await translate_manager.load()
         self.models[key] = translate_manager
-        logger.info(f"Loaded translation model for {key} asynchronously")
+        logger.info(f"Loaded translation model for {key} from local path")
 
     def get_model(self, src_lang, tgt_lang):
         key = self._get_model_key(src_lang, tgt_lang)
@@ -427,14 +429,15 @@ class ASRModelManager:
 
     async def load(self):
         if not self.model:
-            logger.info("Loading ASR model asynchronously...")
+            logger.info("Loading ASR model from local path asynchronously...")
+            local_path = "/app/models/asr_model"
             self.model = await asyncio.to_thread(
                 AutoModel.from_pretrained,
-                "ai4bharat/indic-conformer-600m-multilingual",
+                local_path,
                 trust_remote_code=True
             )
             self.model = self.model.to(self.device_type)
-            logger.info("ASR model loaded asynchronously")
+            logger.info("ASR model loaded from local path asynchronously")
 
 # Global Managers
 llm_manager = LLMManager(settings.llm_model_name)
@@ -505,12 +508,12 @@ async def lifespan(app: FastAPI):
                 translation_tasks.append(model_manager.load_model(src_lang, tgt_lang, key))
 
             await asyncio.gather(*tasks, *translation_tasks)
-            logger.info("All models loaded successfully asynchronously")
+            logger.info("All models loaded successfully from local paths")
         except Exception as e:
             logger.error(f"Error loading models: {str(e)}")
             raise
 
-    logger.info("Starting asynchronous model loading...")
+    logger.info("Starting asynchronous model loading from local paths...")
     await load_all_models()
     yield
     llm_manager.unload()
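
Note: this commit assumes the directories under /app/models/ are already populated before the server starts (for example, baked into the Docker image or mounted as a volume); the diff does not show how they are created. Below is a minimal prefetch sketch, assuming the models are mirrored from the Hugging Face Hub with huggingface_hub.snapshot_download. The IndicTrans2 and ASR repo IDs are taken from the lines removed in this commit; the LLM and TTS repo IDs (previously read from settings.llm_model_name and TTSManager.repo_id) are not visible in the diff, so they are left as commented-out placeholders.

# prefetch_models.py -- hypothetical helper, not part of this commit
from huggingface_hub import snapshot_download

# Repo IDs for the translation and ASR models come from the removed lines above;
# the distilled translation variants are listed since use_distilled defaults to True.
LOCAL_MODELS = {
    "/app/models/trans_en_indic": "ai4bharat/indictrans2-en-indic-dist-200M",
    "/app/models/trans_indic_en": "ai4bharat/indictrans2-indic-en-dist-200M",
    "/app/models/trans_indic_indic": "ai4bharat/indictrans2-indic-indic-dist-320M",
    "/app/models/asr_model": "ai4bharat/indic-conformer-600m-multilingual",
    # The LLM and TTS repo IDs are not shown in this diff; fill them in before use.
    # "/app/models/llm_model": "<settings.llm_model_name>",
    # "/app/models/tts_model": "<TTSManager repo_id>",
}

if __name__ == "__main__":
    for local_dir, repo_id in LOCAL_MODELS.items():
        # Download (or reuse) a full snapshot of each repo into the directory
        # that from_pretrained(local_path) now reads at startup.
        snapshot_download(repo_id=repo_id, local_dir=local_dir)

One side effect of the hard-coded paths: TranslateManager no longer branches on use_distilled, so only whichever translation variant was copied into each directory can be served.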
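
A further hardening that would fit this change (an assumption, not something the commit does): passing local_files_only=True to the from_pretrained calls makes transformers fail fast instead of attempting a Hub download when a local directory is missing or incomplete, e.g.:

            self.model = await asyncio.to_thread(
                AutoModel.from_pretrained,
                local_path,
                trust_remote_code=True,
                local_files_only=True  # raise immediately if /app/models/... is absent
            )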