m1k3wn commited on
Commit
d551427
·
verified ·
1 Parent(s): 02f1f50

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -6
app.py CHANGED
@@ -72,12 +72,24 @@ class ModelManager:
72
  )
73
 
74
  logger.info(f"Loading model {model_name}")
75
- model = T5ForConditionalGeneration.from_pretrained(
76
- model_path,
77
- token=HF_TOKEN,
78
- local_files_only=False,
79
- device_map="auto" # This will handle GPU if available
80
- )
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  cls._instances[model_name] = (model, tokenizer)
83
  logger.info(f"Successfully loaded {model_name}")
 
72
  )
73
 
74
  logger.info(f"Loading model {model_name}")
75
+ # Check if accelerate is available
76
+ try:
77
+ import accelerate
78
+ logger.info("Accelerate package found, using device_map='auto'")
79
+ model = T5ForConditionalGeneration.from_pretrained(
80
+ model_path,
81
+ token=HF_TOKEN,
82
+ local_files_only=False,
83
+ device_map="auto"
84
+ )
85
+ except ImportError:
86
+ logger.warning("Accelerate package not found, falling back to CPU")
87
+ model = T5ForConditionalGeneration.from_pretrained(
88
+ model_path,
89
+ token=HF_TOKEN,
90
+ local_files_only=False
91
+ )
92
+ model = model.cpu()
93
 
94
  cls._instances[model_name] = (model, tokenizer)
95
  logger.info(f"Successfully loaded {model_name}")