Mansuba committed on
Commit
14d0307
·
verified ·
1 Parent(s): 31723cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -57
app.py CHANGED
@@ -65,24 +65,27 @@ class ModelCache:
65
  class EnhancedBanglaSDGenerator:
66
  def __init__(
67
  self,
68
- banglaclip_weights_path: str,
69
  cache_dir: str,
70
  device: Optional[torch.device] = None
71
  ):
72
- # Download model if not exists
73
- download_model(
74
- "https://huggingface.co/Mansuba/BanglaCLIP13/resolve/main/banglaclip_model_epoch_10.pth",
75
- banglaclip_weights_path
76
- )
77
-
78
  self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
79
  logger.info(f"Using device: {self.device}")
80
 
81
  self.cache = ModelCache(Path(cache_dir))
82
- self._initialize_models(banglaclip_weights_path)
83
  self._load_context_data()
84
 
85
- def _initialize_models(self, banglaclip_weights_path: str):
 
 
 
 
 
 
 
 
 
 
86
  try:
87
  # Translation models
88
  self.bn2en_model_name = "Helsinki-NLP/opus-mt-bn-en"
@@ -96,7 +99,10 @@ class EnhancedBanglaSDGenerator:
96
  # CLIP models
97
  self.clip_model_name = "openai/clip-vit-base-patch32"
98
  self.bangla_text_model = "csebuetnlp/banglabert"
99
- self.banglaclip_model = self._load_banglaclip_model(banglaclip_weights_path)
 
 
 
100
  self.processor = CLIPProcessor.from_pretrained(self.clip_model_name)
101
  self.tokenizer = AutoTokenizer.from_pretrained(self.bangla_text_model)
102
 
@@ -107,7 +113,7 @@ class EnhancedBanglaSDGenerator:
107
  logger.error(f"Error initializing models: {str(e)}")
108
  raise RuntimeError(f"Failed to initialize models: {str(e)}")
109
 
110
- # ... [Rest of the previous implementation remains the same] ...
111
 
112
  def create_gradio_interface():
113
  """Create and configure the Gradio interface."""
@@ -118,7 +124,6 @@ def create_gradio_interface():
118
  nonlocal generator
119
  if generator is None:
120
  generator = EnhancedBanglaSDGenerator(
121
- banglaclip_weights_path="banglaclip_model_epoch_10.pth",
122
  cache_dir=str(cache_dir)
123
  )
124
  return generator
@@ -151,52 +156,8 @@ def create_gradio_interface():
151
  cleanup_generator()
152
  return None, f"ছবি তৈরি ব্যর্থ হয়েছে: {str(e)}"
153
 
154
- # Gradio interface configuration
155
- demo = gr.Interface(
156
- fn=generate_images,
157
- inputs=[
158
- gr.Textbox(
159
- label="বাংলা টেক্সট লিখুন",
160
- placeholder="যেকোনো বাংলা টেক্সট লিখুন...",
161
- lines=3
162
- ),
163
- gr.Slider(
164
- minimum=1,
165
- maximum=4,
166
- step=1,
167
- value=1,
168
- label="ছবির সংখ্যা"
169
- ),
170
- gr.Slider(
171
- minimum=20,
172
- maximum=100,
173
- step=1,
174
- value=50,
175
- label="স্টেপস"
176
- ),
177
- gr.Slider(
178
- minimum=1.0,
179
- maximum=20.0,
180
- step=0.5,
181
- value=7.5,
182
- label="গাইডেন্স স্কেল"
183
- ),
184
- gr.Number(
185
- label="সীড (ঐচ্ছিক)",
186
- precision=0
187
- )
188
- ],
189
- outputs=[
190
- gr.Gallery(label="তৈরি করা ছবি"),
191
- gr.Textbox(label="ব্যবহৃত প্রম্পট")
192
- ],
193
- title="বাংলা টেক্সট থেকে ছবি তৈরি",
194
- description="যেকোনো বাংলা টেক্সট দিয়ে উচ্চমানের ছবি তৈরি করুন"
195
- )
196
-
197
- return demo
198
 
199
  if __name__ == "__main__":
200
  demo = create_gradio_interface()
201
- # Fixed queue configuration for newer Gradio versions
202
  demo.queue().launch(share=True, debug=True)
 
65
  class EnhancedBanglaSDGenerator:
66
  def __init__(
67
  self,
 
68
  cache_dir: str,
69
  device: Optional[torch.device] = None
70
  ):
 
 
 
 
 
 
71
  self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
72
  logger.info(f"Using device: {self.device}")
73
 
74
  self.cache = ModelCache(Path(cache_dir))
75
+ self._initialize_models()
76
  self._load_context_data()
77
 
78
+ def _load_banglaclip_model(self):
79
+ """Load BanglaCLIP model from Hugging Face directly"""
80
+ try:
81
+ # Use the Hugging Face model directly instead of local weight file
82
+ model = CLIPModel.from_pretrained("Mansuba/BanglaCLIP13")
83
+ return model.to(self.device)
84
+ except Exception as e:
85
+ logger.error(f"Failed to load BanglaCLIP model: {str(e)}")
86
+ raise
87
+
88
+ def _initialize_models(self):
89
  try:
90
  # Translation models
91
  self.bn2en_model_name = "Helsinki-NLP/opus-mt-bn-en"
 
99
  # CLIP models
100
  self.clip_model_name = "openai/clip-vit-base-patch32"
101
  self.bangla_text_model = "csebuetnlp/banglabert"
102
+
103
+ # Load BanglaCLIP model directly from Hugging Face
104
+ self.banglaclip_model = self._load_banglaclip_model()
105
+
106
  self.processor = CLIPProcessor.from_pretrained(self.clip_model_name)
107
  self.tokenizer = AutoTokenizer.from_pretrained(self.bangla_text_model)
108
 
 
113
  logger.error(f"Error initializing models: {str(e)}")
114
  raise RuntimeError(f"Failed to initialize models: {str(e)}")
115
 
116
+ # [Rest of the existing implementation remains the same]
117
 
118
  def create_gradio_interface():
119
  """Create and configure the Gradio interface."""
 
124
  nonlocal generator
125
  if generator is None:
126
  generator = EnhancedBanglaSDGenerator(
 
127
  cache_dir=str(cache_dir)
128
  )
129
  return generator
 
156
  cleanup_generator()
157
  return None, f"ছবি তৈরি ব্যর্থ হয়েছে: {str(e)}"
158
 
159
+ # [Gradio interface configuration remains the same]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
  if __name__ == "__main__":
162
  demo = create_gradio_interface()
 
163
  demo.queue().launch(share=True, debug=True)