Update app.py
Browse files
app.py
CHANGED
@@ -65,24 +65,27 @@ class ModelCache:
|
|
65 |
class EnhancedBanglaSDGenerator:
|
66 |
def __init__(
|
67 |
self,
|
68 |
-
banglaclip_weights_path: str,
|
69 |
cache_dir: str,
|
70 |
device: Optional[torch.device] = None
|
71 |
):
|
72 |
-
# Download model if not exists
|
73 |
-
download_model(
|
74 |
-
"https://huggingface.co/Mansuba/BanglaCLIP13/resolve/main/banglaclip_model_epoch_10.pth",
|
75 |
-
banglaclip_weights_path
|
76 |
-
)
|
77 |
-
|
78 |
self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
79 |
logger.info(f"Using device: {self.device}")
|
80 |
|
81 |
self.cache = ModelCache(Path(cache_dir))
|
82 |
-
self._initialize_models(
|
83 |
self._load_context_data()
|
84 |
|
85 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
try:
|
87 |
# Translation models
|
88 |
self.bn2en_model_name = "Helsinki-NLP/opus-mt-bn-en"
|
@@ -96,7 +99,10 @@ class EnhancedBanglaSDGenerator:
|
|
96 |
# CLIP models
|
97 |
self.clip_model_name = "openai/clip-vit-base-patch32"
|
98 |
self.bangla_text_model = "csebuetnlp/banglabert"
|
99 |
-
|
|
|
|
|
|
|
100 |
self.processor = CLIPProcessor.from_pretrained(self.clip_model_name)
|
101 |
self.tokenizer = AutoTokenizer.from_pretrained(self.bangla_text_model)
|
102 |
|
@@ -107,7 +113,7 @@ class EnhancedBanglaSDGenerator:
|
|
107 |
logger.error(f"Error initializing models: {str(e)}")
|
108 |
raise RuntimeError(f"Failed to initialize models: {str(e)}")
|
109 |
|
110 |
-
#
|
111 |
|
112 |
def create_gradio_interface():
|
113 |
"""Create and configure the Gradio interface."""
|
@@ -118,7 +124,6 @@ def create_gradio_interface():
|
|
118 |
nonlocal generator
|
119 |
if generator is None:
|
120 |
generator = EnhancedBanglaSDGenerator(
|
121 |
-
banglaclip_weights_path="banglaclip_model_epoch_10.pth",
|
122 |
cache_dir=str(cache_dir)
|
123 |
)
|
124 |
return generator
|
@@ -151,52 +156,8 @@ def create_gradio_interface():
|
|
151 |
cleanup_generator()
|
152 |
return None, f"ছবি তৈরি ব্যর্থ হয়েছে: {str(e)}"
|
153 |
|
154 |
-
# Gradio interface configuration
|
155 |
-
demo = gr.Interface(
|
156 |
-
fn=generate_images,
|
157 |
-
inputs=[
|
158 |
-
gr.Textbox(
|
159 |
-
label="বাংলা টেক্সট লিখুন",
|
160 |
-
placeholder="যেকোনো বাংলা টেক্সট লিখুন...",
|
161 |
-
lines=3
|
162 |
-
),
|
163 |
-
gr.Slider(
|
164 |
-
minimum=1,
|
165 |
-
maximum=4,
|
166 |
-
step=1,
|
167 |
-
value=1,
|
168 |
-
label="ছবির সংখ্যা"
|
169 |
-
),
|
170 |
-
gr.Slider(
|
171 |
-
minimum=20,
|
172 |
-
maximum=100,
|
173 |
-
step=1,
|
174 |
-
value=50,
|
175 |
-
label="স্টেপস"
|
176 |
-
),
|
177 |
-
gr.Slider(
|
178 |
-
minimum=1.0,
|
179 |
-
maximum=20.0,
|
180 |
-
step=0.5,
|
181 |
-
value=7.5,
|
182 |
-
label="গাইডেন্স স্কেল"
|
183 |
-
),
|
184 |
-
gr.Number(
|
185 |
-
label="সীড (ঐচ্ছিক)",
|
186 |
-
precision=0
|
187 |
-
)
|
188 |
-
],
|
189 |
-
outputs=[
|
190 |
-
gr.Gallery(label="তৈরি করা ছবি"),
|
191 |
-
gr.Textbox(label="ব্যবহৃত প্রম্পট")
|
192 |
-
],
|
193 |
-
title="বাংলা টেক্সট থেকে ছবি তৈরি",
|
194 |
-
description="যেকোনো বাংলা টেক্সট দিয়ে উচ্চমানের ছবি তৈরি করুন"
|
195 |
-
)
|
196 |
-
|
197 |
-
return demo
|
198 |
|
199 |
if __name__ == "__main__":
|
200 |
demo = create_gradio_interface()
|
201 |
-
# Fixed queue configuration for newer Gradio versions
|
202 |
demo.queue().launch(share=True, debug=True)
|
|
|
65 |
class EnhancedBanglaSDGenerator:
|
66 |
def __init__(
|
67 |
self,
|
|
|
68 |
cache_dir: str,
|
69 |
device: Optional[torch.device] = None
|
70 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
72 |
logger.info(f"Using device: {self.device}")
|
73 |
|
74 |
self.cache = ModelCache(Path(cache_dir))
|
75 |
+
self._initialize_models()
|
76 |
self._load_context_data()
|
77 |
|
78 |
+
def _load_banglaclip_model(self):
|
79 |
+
"""Load BanglaCLIP model from Hugging Face directly"""
|
80 |
+
try:
|
81 |
+
# Use the Hugging Face model directly instead of local weight file
|
82 |
+
model = CLIPModel.from_pretrained("Mansuba/BanglaCLIP13")
|
83 |
+
return model.to(self.device)
|
84 |
+
except Exception as e:
|
85 |
+
logger.error(f"Failed to load BanglaCLIP model: {str(e)}")
|
86 |
+
raise
|
87 |
+
|
88 |
+
def _initialize_models(self):
|
89 |
try:
|
90 |
# Translation models
|
91 |
self.bn2en_model_name = "Helsinki-NLP/opus-mt-bn-en"
|
|
|
99 |
# CLIP models
|
100 |
self.clip_model_name = "openai/clip-vit-base-patch32"
|
101 |
self.bangla_text_model = "csebuetnlp/banglabert"
|
102 |
+
|
103 |
+
# Load BanglaCLIP model directly from Hugging Face
|
104 |
+
self.banglaclip_model = self._load_banglaclip_model()
|
105 |
+
|
106 |
self.processor = CLIPProcessor.from_pretrained(self.clip_model_name)
|
107 |
self.tokenizer = AutoTokenizer.from_pretrained(self.bangla_text_model)
|
108 |
|
|
|
113 |
logger.error(f"Error initializing models: {str(e)}")
|
114 |
raise RuntimeError(f"Failed to initialize models: {str(e)}")
|
115 |
|
116 |
+
# [Rest of the existing implementation remains the same]
|
117 |
|
118 |
def create_gradio_interface():
|
119 |
"""Create and configure the Gradio interface."""
|
|
|
124 |
nonlocal generator
|
125 |
if generator is None:
|
126 |
generator = EnhancedBanglaSDGenerator(
|
|
|
127 |
cache_dir=str(cache_dir)
|
128 |
)
|
129 |
return generator
|
|
|
156 |
cleanup_generator()
|
157 |
return None, f"ছবি তৈরি ব্যর্থ হয়েছে: {str(e)}"
|
158 |
|
159 |
+
# [Gradio interface configuration remains the same]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
160 |
|
161 |
if __name__ == "__main__":
|
162 |
demo = create_gradio_interface()
|
|
|
163 |
demo.queue().launch(share=True, debug=True)
|