Update app.py
Browse files
app.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1 |
|
2 |
|
3 |
-
|
4 |
import torch
|
5 |
from transformers import CLIPModel, CLIPProcessor, AutoTokenizer, MarianMTModel, MarianTokenizer
|
6 |
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
|
@@ -105,11 +104,16 @@ class EnhancedBanglaSDGenerator:
|
|
105 |
|
106 |
def _load_banglaclip_model(self, weights_path: str) -> CLIPModel:
|
107 |
try:
|
108 |
-
|
109 |
-
|
|
|
|
|
|
|
|
|
|
|
110 |
|
111 |
clip_model = CLIPModel.from_pretrained(self.clip_model_name)
|
112 |
-
state_dict = torch.load(
|
113 |
|
114 |
cleaned_state_dict = {
|
115 |
k.replace('module.', '').replace('clip.', ''): v
|
@@ -240,7 +244,7 @@ def create_gradio_interface():
|
|
240 |
nonlocal generator
|
241 |
if generator is None:
|
242 |
generator = EnhancedBanglaSDGenerator(
|
243 |
-
banglaclip_weights_path="banglaclip_model_epoch_10.pth",
|
244 |
cache_dir=str(cache_dir)
|
245 |
)
|
246 |
return generator
|
@@ -320,6 +324,4 @@ def create_gradio_interface():
|
|
320 |
|
321 |
if __name__ == "__main__":
|
322 |
demo = create_gradio_interface()
|
323 |
-
# Fixed queue configuration for newer Gradio versions
|
324 |
demo.queue().launch(share=True)
|
325 |
-
|
|
|
1 |
|
2 |
|
|
|
3 |
import torch
|
4 |
from transformers import CLIPModel, CLIPProcessor, AutoTokenizer, MarianMTModel, MarianTokenizer
|
5 |
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
|
|
|
104 |
|
105 |
def _load_banglaclip_model(self, weights_path: str) -> CLIPModel:
|
106 |
try:
|
107 |
+
model_file_path = Path("model_cache/banglaclip_model_epoch_10.pth")
|
108 |
+
if not model_file_path.exists():
|
109 |
+
logger.info("Downloading BanglaCLIP weights...")
|
110 |
+
torch.hub.download_url_to_file(
|
111 |
+
'https://huggingface.co/Mansuba/Bangla_text_to_image_app/resolve/main/banglaclip_model_epoch_10.pth',
|
112 |
+
model_file_path
|
113 |
+
)
|
114 |
|
115 |
clip_model = CLIPModel.from_pretrained(self.clip_model_name)
|
116 |
+
state_dict = torch.load(model_file_path, map_location=self.device)
|
117 |
|
118 |
cleaned_state_dict = {
|
119 |
k.replace('module.', '').replace('clip.', ''): v
|
|
|
244 |
nonlocal generator
|
245 |
if generator is None:
|
246 |
generator = EnhancedBanglaSDGenerator(
|
247 |
+
banglaclip_weights_path="model_cache/banglaclip_model_epoch_10.pth",
|
248 |
cache_dir=str(cache_dir)
|
249 |
)
|
250 |
return generator
|
|
|
324 |
|
325 |
if __name__ == "__main__":
|
326 |
demo = create_gradio_interface()
|
|
|
327 |
demo.queue().launch(share=True)
|
|