Mansuba committed · Commit 8a74117 · verified · 1 Parent(s): 113e1ee

Update app.py

Files changed (1): app.py +9 -7
app.py CHANGED
@@ -1,6 +1,5 @@
 
 
-
 import torch
 from transformers import CLIPModel, CLIPProcessor, AutoTokenizer, MarianMTModel, MarianTokenizer
 from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
@@ -105,11 +104,16 @@ class EnhancedBanglaSDGenerator:
 
     def _load_banglaclip_model(self, weights_path: str) -> CLIPModel:
         try:
-            if not Path(weights_path).exists():
-                raise FileNotFoundError(f"BanglaCLIP weights not found at {weights_path}")
+            model_file_path = Path("model_cache/banglaclip_model_epoch_10.pth")
+            if not model_file_path.exists():
+                logger.info("Downloading BanglaCLIP weights...")
+                torch.hub.download_url_to_file(
+                    'https://huggingface.co/Mansuba/Bangla_text_to_image_app/resolve/main/banglaclip_model_epoch_10.pth',
+                    model_file_path
+                )
 
             clip_model = CLIPModel.from_pretrained(self.clip_model_name)
-            state_dict = torch.load(weights_path, map_location=self.device)
+            state_dict = torch.load(model_file_path, map_location=self.device)
 
             cleaned_state_dict = {
                 k.replace('module.', '').replace('clip.', ''): v
@@ -240,7 +244,7 @@ def create_gradio_interface():
         nonlocal generator
         if generator is None:
             generator = EnhancedBanglaSDGenerator(
-                banglaclip_weights_path="banglaclip_model_epoch_10.pth",
+                banglaclip_weights_path="model_cache/banglaclip_model_epoch_10.pth",
                 cache_dir=str(cache_dir)
             )
         return generator
@@ -320,6 +324,4 @@ def create_gradio_interface():
 
 if __name__ == "__main__":
     demo = create_gradio_interface()
-    # Fixed queue configuration for newer Gradio versions
     demo.queue().launch(share=True)
-
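
Note on the new download step: torch.hub.download_url_to_file writes to the given destination but, as far as I know, does not create missing parent directories, so the hard-coded model_cache/ folder must already exist when _load_banglaclip_model runs. A minimal, more defensive sketch, assuming the same URL and cache path as the commit (the helper name ensure_banglaclip_weights is hypothetical):

import logging
from pathlib import Path

import torch

logger = logging.getLogger(__name__)

# Same weights URL the commit downloads from.
WEIGHTS_URL = (
    "https://huggingface.co/Mansuba/Bangla_text_to_image_app/"
    "resolve/main/banglaclip_model_epoch_10.pth"
)

def ensure_banglaclip_weights(
    path: str = "model_cache/banglaclip_model_epoch_10.pth",
) -> Path:
    """Download the BanglaCLIP checkpoint once and reuse the cached copy."""
    model_file_path = Path(path)
    # Assumption: download_url_to_file will not create model_cache/ itself.
    model_file_path.parent.mkdir(parents=True, exist_ok=True)
    if not model_file_path.exists():
        logger.info("Downloading BanglaCLIP weights...")
        # str() because download_url_to_file documents dst as a string path.
        torch.hub.download_url_to_file(WEIGHTS_URL, str(model_file_path))
    return model_file_path

Since the checkpoint lives in a Hugging Face repo, huggingface_hub.hf_hub_download(repo_id="Mansuba/Bangla_text_to_image_app", filename="banglaclip_model_epoch_10.pth") would be an alternative that also handles caching for you.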
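
For context on the cleaned_state_dict lines that appear unchanged in the second hunk: the checkpoint keys evidently carry wrapper prefixes ('module.' from DataParallel, 'clip.' from a wrapper class), which are stripped before loading into a plain CLIPModel. A self-contained sketch of that pattern, assuming clip_model_name and device match the class attributes; strict=False is my assumption, to tolerate any leftover non-CLIP keys:

import torch
from transformers import CLIPModel

def load_cleaned_clip(weights_path: str, clip_model_name: str, device: str) -> CLIPModel:
    clip_model = CLIPModel.from_pretrained(clip_model_name)
    state_dict = torch.load(weights_path, map_location=device)
    # Strip wrapper prefixes so keys line up with CLIPModel's own names.
    cleaned_state_dict = {
        k.replace("module.", "").replace("clip.", ""): v
        for k, v in state_dict.items()
    }
    clip_model.load_state_dict(cleaned_state_dict, strict=False)
    return clip_model.to(device)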
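
On the final hunk: the deleted comment suggests the queue call once needed version-specific arguments; a bare demo.queue() defers to Gradio's defaults and works across recent releases (my understanding is that older kwargs such as concurrency_count were removed in Gradio 4.x). A minimal usage sketch with a placeholder interface:

import gradio as gr

# Placeholder interface; the real app builds demo in create_gradio_interface().
demo = gr.Interface(fn=lambda text: text, inputs="text", outputs="text")

if __name__ == "__main__":
    # queue() with no arguments leaves all tuning to Gradio's defaults.
    demo.queue().launch(share=True)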