Mansuba committed
Commit b180e8e · verified · 1 Parent(s): 184cc4e

Update app.py

Files changed (1): app.py +212 -61
app.py CHANGED
@@ -1,13 +1,17 @@
+# Fetch the BanglaCLIP weights at startup. The committed line used IPython's
+# "!wget ..." syntax, which is not valid in a .py file; this plain-Python
+# equivalent assumes the wget binary exists on the host.
+import subprocess
+subprocess.run(["wget", "-O", "banglaclip_model_epoch_10.pth", "https://huggingface.co/Mansuba/BanglaCLIP13/resolve/main/banglaclip_model_epoch_10.pth"], check=True)
 import torch
-import os
-import requests
-import logging
-from pathlib import Path
 from transformers import CLIPModel, CLIPProcessor, AutoTokenizer, MarianMTModel, MarianTokenizer
 from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
+import numpy as np
+from typing import List, Tuple, Optional, Dict, Any
 import gradio as gr
-from typing import List, Tuple, Optional, Any
+from pathlib import Path
+import json
+import logging
 from dataclasses import dataclass
+import gc
 
 # Configure logging
 logging.basicConfig(
@@ -16,30 +20,6 @@ logging.basicConfig(
 )
 logger = logging.getLogger(__name__)
 
-def download_model(model_url: str, model_path: str):
-    """Download large model file with progress tracking."""
-    if not os.path.exists(model_path):
-        try:
-            logger.info(f"Downloading model from {model_url}...")
-            response = requests.get(model_url, stream=True)
-            response.raise_for_status()
-
-            total_size = int(response.headers.get('content-length', 0))
-            block_size = 1024 * 1024  # 1 MB chunks
-            downloaded_size = 0
-
-            with open(model_path, 'wb') as f:
-                for data in response.iter_content(block_size):
-                    f.write(data)
-                    downloaded_size += len(data)
-                    progress = (downloaded_size / total_size) * 100 if total_size > 0 else 0
-                    logger.info(f"Download progress: {progress:.2f}%")
-
-            logger.info("Model download complete.")
-        except Exception as e:
-            logger.error(f"Model download failed: {e}")
-            raise
-
 @dataclass
 class GenerationConfig:
     num_images: int = 1
@@ -63,6 +43,7 @@ class ModelCache:
 class EnhancedBanglaSDGenerator:
     def __init__(
         self,
+        banglaclip_weights_path: str,
         cache_dir: str,
         device: Optional[torch.device] = None
     ):
@@ -70,21 +51,12 @@ class EnhancedBanglaSDGenerator:
         logger.info(f"Using device: {self.device}")
 
         self.cache = ModelCache(Path(cache_dir))
-        self._initialize_models()
+        self._initialize_models(banglaclip_weights_path)
         self._load_context_data()
 
-    def _load_banglaclip_model(self):
-        """Load BanglaCLIP model from Hugging Face directly"""
-        try:
-            model = CLIPModel.from_pretrained("Mansuba/BanglaCLIP13")
-            return model.to(self.device)
-        except Exception as e:
-            logger.error(f"Failed to load BanglaCLIP model: {str(e)}")
-            raise
-
-    def _initialize_models(self):
+    def _initialize_models(self, banglaclip_weights_path: str):
         try:
-            # Translation models
+            # Initialize translation models
             self.bn2en_model_name = "Helsinki-NLP/opus-mt-bn-en"
             self.translator = self.cache.load_model(
                 self.bn2en_model_name,
@@ -93,17 +65,14 @@
             ).to(self.device)
             self.trans_tokenizer = MarianTokenizer.from_pretrained(self.bn2en_model_name)
 
-            # CLIP models
+            # Initialize CLIP models
             self.clip_model_name = "openai/clip-vit-base-patch32"
             self.bangla_text_model = "csebuetnlp/banglabert"
-
-            # Load BanglaCLIP model directly from Hugging Face
-            self.banglaclip_model = self._load_banglaclip_model()
-
+            self.banglaclip_model = self._load_banglaclip_model(banglaclip_weights_path)
             self.processor = CLIPProcessor.from_pretrained(self.clip_model_name)
             self.tokenizer = AutoTokenizer.from_pretrained(self.bangla_text_model)
 
-            # Stable Diffusion
+            # Initialize Stable Diffusion
             self._initialize_stable_diffusion()
 
         except Exception as e:
@@ -111,8 +80,156 @@
             raise RuntimeError(f"Failed to initialize models: {str(e)}")
 
     def _initialize_stable_diffusion(self):
-        """Load and initialize Stable Diffusion pipeline"""
-        pass  # Your existing code for initializing Stable Diffusion
+        """Initialize Stable Diffusion pipeline with optimized settings."""
+        self.pipe = self.cache.load_model(
+            "runwayml/stable-diffusion-v1-5",
+            lambda model_id: StableDiffusionPipeline.from_pretrained(
+                model_id,
+                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+                safety_checker=None
+            ),
+            "stable_diffusion"
+        )
+
+        self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
+            self.pipe.scheduler.config,
+            use_karras_sigmas=True,
+            algorithm_type="dpmsolver++"
+        )
+        self.pipe = self.pipe.to(self.device)
+
+        # Memory optimization
+        self.pipe.enable_attention_slicing()
+        if torch.cuda.is_available():
+            self.pipe.enable_sequential_cpu_offload()
+
+    def _load_banglaclip_model(self, weights_path: str) -> CLIPModel:
+        try:
+            if not Path(weights_path).exists():
+                raise FileNotFoundError(f"BanglaCLIP weights not found at {weights_path}")
+
+            clip_model = CLIPModel.from_pretrained(self.clip_model_name)
+            state_dict = torch.load(weights_path, map_location=self.device)
+
+            cleaned_state_dict = {
+                k.replace('module.', '').replace('clip.', ''): v
+                for k, v in state_dict.items()
+                if k.replace('module.', '').replace('clip.', '').startswith(('text_model.', 'vision_model.'))
+            }
+
+            clip_model.load_state_dict(cleaned_state_dict, strict=False)
+            return clip_model.to(self.device)
+
+        except Exception as e:
+            logger.error(f"Failed to load BanglaCLIP model: {str(e)}")
+            raise
+
+    def _load_context_data(self):
+        """Load location and scene context data."""
+        self.location_contexts = {
+            'কক্সবাজার': 'Cox\'s Bazar beach, longest natural sea beach in the world, sandy beach',
+            'সেন্টমার্টিন': 'Saint Martin\'s Island, coral island, tropical paradise',
+            'সুন্দরবন': 'Sundarbans mangrove forest, Bengal tigers, riverine forest'
+        }
+
+        self.scene_contexts = {
+            'সৈকত': 'beach, seaside, waves, sandy shore, ocean view',
+            'সমুদ্র': 'ocean, sea waves, deep blue water, horizon',
+            'পাহাড়': 'mountains, hills, valleys, scenic landscape'
+        }
+
+    def _translate_text(self, bangla_text: str) -> str:
+        """Translate Bangla text to English."""
+        inputs = self.trans_tokenizer(bangla_text, return_tensors="pt", padding=True)
+        inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+        with torch.no_grad():
+            outputs = self.translator.generate(**inputs)
+
+        translated = self.trans_tokenizer.decode(outputs[0], skip_special_tokens=True)
+        return translated
+
+    def _get_text_embedding(self, text: str):
+        """Get text embedding from BanglaCLIP model."""
+        inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
+        inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+        with torch.no_grad():
+            outputs = self.banglaclip_model.get_text_features(**inputs)
+
+        return outputs
+
+    def generate_image(
+        self,
+        bangla_text: str,
+        config: Optional[GenerationConfig] = None
+    ) -> Tuple[List[Any], str]:
+        if not bangla_text.strip():
+            raise ValueError("Empty input text")
+
+        config = config or GenerationConfig()
+
+        try:
+            if config.seed is not None:
+                torch.manual_seed(config.seed)
+
+            enhanced_prompt = self._enhance_prompt(bangla_text)
+            negative_prompt = self._get_negative_prompt()
+
+            with torch.autocast(self.device.type):
+                result = self.pipe(
+                    prompt=enhanced_prompt,
+                    negative_prompt=negative_prompt,
+                    num_images_per_prompt=config.num_images,
+                    num_inference_steps=config.num_inference_steps,
+                    guidance_scale=config.guidance_scale
+                )
+
+            return result.images, enhanced_prompt
+
+        except Exception as e:
+            logger.error(f"Error during image generation: {str(e)}")
+            raise
+
+    def _enhance_prompt(self, bangla_text: str) -> str:
+        """Enhance prompt with context and style information."""
+        translated_text = self._translate_text(bangla_text)
+
+        # Gather contexts
+        contexts = []
+        contexts.extend(context for loc, context in self.location_contexts.items() if loc in bangla_text)
+        contexts.extend(context for scene, context in self.scene_contexts.items() if scene in bangla_text)
+
+        # Add photo style
+        photo_style = [
+            "professional photography",
+            "high resolution",
+            "4k",
+            "detailed",
+            "realistic",
+            "beautiful composition"
+        ]
+
+        # Combine all parts
+        all_parts = [translated_text] + contexts + photo_style
+        return ", ".join(dict.fromkeys(all_parts))
+
+    def _get_negative_prompt(self) -> str:
+        return (
+            "blurry, low quality, pixelated, cartoon, anime, illustration, "
+            "painting, drawing, artificial, fake, oversaturated, undersaturated"
+        )
+
+    def cleanup(self):
+        """Clean up GPU memory"""
+        if hasattr(self, 'pipe'):
+            del self.pipe
+        if hasattr(self, 'banglaclip_model'):
+            del self.banglaclip_model
+        if hasattr(self, 'translator'):
+            del self.translator
+        torch.cuda.empty_cache()
+        gc.collect()
 
 def create_gradio_interface():
     """Create and configure the Gradio interface."""
@@ -123,6 +240,7 @@ def create_gradio_interface():
         nonlocal generator
         if generator is None:
             generator = EnhancedBanglaSDGenerator(
+                banglaclip_weights_path="banglaclip_model_epoch_10.pth",
                 cache_dir=str(cache_dir)
             )
         return generator
@@ -155,20 +273,53 @@
             cleanup_generator()
             return None, f"ছবি তৈরি ব্যর্থ হয়েছে: {str(e)}"
 
-    with gr.Blocks() as demo:
-        text_input = gr.Textbox(label="Text", placeholder="Enter your prompt here...")
-        num_images_input = gr.Slider(minimum=1, maximum=5, value=1, label="Number of Images")
-        steps_input = gr.Slider(minimum=1, maximum=100, value=50, label="Steps")
-        guidance_scale_input = gr.Slider(minimum=1, maximum=20, value=7.5, label="Guidance Scale")
-        seed_input = gr.Number(label="Seed", optional=True)
-
-        output_images = gr.Gallery(label="Generated Images")
-
-        generate_button = gr.Button("Generate Images")
-        generate_button.click(generate_images, inputs=[text_input, num_images_input, steps_input, guidance_scale_input, seed_input], outputs=[output_images])
+    # Create Gradio interface
+    demo = gr.Interface(
+        fn=generate_images,
+        inputs=[
+            gr.Textbox(
+                label="বাংলা টেক্সট লিখুন",
+                placeholder="যেকোনো বাংলা টেক্সট লিখুন...",
+                lines=3
+            ),
+            gr.Slider(
+                minimum=1,
+                maximum=4,
+                step=1,
+                value=1,
+                label="ছবির সংখ্যা"
+            ),
+            gr.Slider(
+                minimum=20,
+                maximum=100,
+                step=1,
+                value=50,
+                label="স্টেপস"
+            ),
+            gr.Slider(
+                minimum=1.0,
+                maximum=20.0,
+                step=0.5,
+                value=7.5,
+                label="গাইডেন্স স্কেল"
+            ),
+            gr.Number(
+                label="সীড (ঐচ্ছিক)",
+                precision=0
+            )
+        ],
+        outputs=[
+            gr.Gallery(label="তৈরি করা ছবি"),
+            gr.Textbox(label="ব্যবহৃত প্রম্পট")
+        ],
+        title="বাংলা টেক্সট থেকে ছবি তৈরি",
+        description="যেকোনো বাংলা টেক্সট দিয়ে উচ্চমানের ছবি তৈরি করুন"
+    )
 
     return demo
 
 if __name__ == "__main__":
     demo = create_gradio_interface()
-    demo.queue().launch(share=True, debug=True)
+    # Fixed queue configuration for newer Gradio versions
+    demo.queue().launch(share=True)
 
 
 
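Note on the state-dict cleaning in _load_banglaclip_model: checkpoints saved from a torch.nn.DataParallel-wrapped model prefix every parameter key with "module.", and a training wrapper that stored the CLIP model in an attribute named clip would add a further "clip." prefix; the comprehension strips both and keeps only the text_model./vision_model. weights that CLIPModel recognizes, while strict=False lets any leftover mismatch pass silently. A minimal sketch of the same idea, using a hypothetical checkpoint key layout:

# Hypothetical keys as saved during training (DataParallel + a wrapper attribute "clip")
state_dict = {
    "module.clip.text_model.embeddings.token_embedding.weight": 1,
    "module.clip.vision_model.embeddings.class_embedding": 2,
    "module.projection_head.weight": 3,  # not part of CLIPModel, dropped by the filter
}
cleaned = {
    k.replace("module.", "").replace("clip.", ""): v
    for k, v in state_dict.items()
    if k.replace("module.", "").replace("clip.", "").startswith(("text_model.", "vision_model."))
}
# cleaned == {"text_model.embeddings.token_embedding.weight": 1,
#             "vision_model.embeddings.class_embedding": 2}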
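The ", ".join(dict.fromkeys(all_parts)) idiom in _enhance_prompt deduplicates prompt fragments while preserving first-occurrence order, since Python dicts keep insertion order. For example:

parts = ["beach, seaside", "4k", "beach, seaside"]
", ".join(dict.fromkeys(parts))  # -> 'beach, seaside, 4k'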
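For reference, a minimal sketch of driving the new generator without the Gradio UI. It assumes the weights file fetched at the top of app.py is already present, and that GenerationConfig exposes the num_inference_steps, guidance_scale, and seed fields that generate_image reads; the prompt string and output path are illustrative only:

from app import EnhancedBanglaSDGenerator, GenerationConfig

generator = EnhancedBanglaSDGenerator(
    banglaclip_weights_path="banglaclip_model_epoch_10.pth",
    cache_dir="./model_cache",
)
config = GenerationConfig(num_images=1, num_inference_steps=30, guidance_scale=7.5, seed=42)
images, prompt_used = generator.generate_image("কক্সবাজার সৈকতে সূর্যাস্ত", config)  # "sunset at Cox's Bazar beach"
images[0].save("output.png")
generator.cleanup()  # free GPU memory when done

The sketch uses 30 steps rather than the UI default of 50 because the DPMSolverMultistepScheduler configured above with dpmsolver++ and Karras sigmas typically converges in 20-30 steps.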