Mansuba committed on
Commit 04a4097 · verified · 1 Parent(s): 6b4f4ab

Update app.py

Files changed (1): app.py +43 -162

app.py CHANGED
@@ -1,14 +1,15 @@
  import torch
  from transformers import CLIPModel, CLIPProcessor, AutoTokenizer, MarianMTModel, MarianTokenizer
  from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
- import numpy as np
- from typing import List, Tuple, Optional, Dict, Any
  import gradio as gr
- from pathlib import Path
- import json
- import logging
  from dataclasses import dataclass
- import gc

  # Configure logging
  logging.basicConfig(
@@ -17,6 +18,30 @@ logging.basicConfig(
  )
  logger = logging.getLogger(__name__)

  @dataclass
  class GenerationConfig:
      num_images: int = 1
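For orientation, a minimal sketch of how such a config is meant to be used. Only num_images is visible above; the remaining field names (num_inference_steps, guidance_scale, seed) are assumptions inferred from the generate_image() call shown further down in this diff, and the generator variable is a placeholder for an EnhancedBanglaSDGenerator instance:

    # Hypothetical usage; field names other than num_images are inferred from
    # the generate_image() body shown later in this diff.
    config = GenerationConfig(num_images=2, num_inference_steps=30, guidance_scale=7.5, seed=42)
    images, prompt = generator.generate_image("কক্সবাজার সৈকত", config)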
@@ -44,6 +69,12 @@ class EnhancedBanglaSDGenerator:
          cache_dir: str,
          device: Optional[torch.device] = None
      ):
          self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
          logger.info(f"Using device: {self.device}")

@@ -53,7 +84,7 @@ class EnhancedBanglaSDGenerator:

      def _initialize_models(self, banglaclip_weights_path: str):
          try:
-             # Initialize translation models
              self.bn2en_model_name = "Helsinki-NLP/opus-mt-bn-en"
              self.translator = self.cache.load_model(
                  self.bn2en_model_name,
@@ -62,171 +93,21 @@ class EnhancedBanglaSDGenerator:
              ).to(self.device)
              self.trans_tokenizer = MarianTokenizer.from_pretrained(self.bn2en_model_name)

-             # Initialize CLIP models
              self.clip_model_name = "openai/clip-vit-base-patch32"
              self.bangla_text_model = "csebuetnlp/banglabert"
              self.banglaclip_model = self._load_banglaclip_model(banglaclip_weights_path)
              self.processor = CLIPProcessor.from_pretrained(self.clip_model_name)
              self.tokenizer = AutoTokenizer.from_pretrained(self.bangla_text_model)

-             # Initialize Stable Diffusion
              self._initialize_stable_diffusion()

          except Exception as e:
              logger.error(f"Error initializing models: {str(e)}")
              raise RuntimeError(f"Failed to initialize models: {str(e)}")

-     def _initialize_stable_diffusion(self):
-         """Initialize Stable Diffusion pipeline with optimized settings."""
-         self.pipe = self.cache.load_model(
-             "runwayml/stable-diffusion-v1-5",
-             lambda model_id: StableDiffusionPipeline.from_pretrained(
-                 model_id,
-                 torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-                 safety_checker=None
-             ),
-             "stable_diffusion"
-         )
-
-         self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
-             self.pipe.scheduler.config,
-             use_karras_sigmas=True,
-             algorithm_type="dpmsolver++"
-         )
-         self.pipe = self.pipe.to(self.device)
-
-         # Memory optimization
-         self.pipe.enable_attention_slicing()
-         if torch.cuda.is_available():
-             self.pipe.enable_sequential_cpu_offload()
-
-     def _load_banglaclip_model(self, weights_path: str) -> CLIPModel:
-         try:
-             if not Path(weights_path).exists():
-                 raise FileNotFoundError(f"BanglaCLIP weights not found at {weights_path}")
-
-             clip_model = CLIPModel.from_pretrained(self.clip_model_name)
-             state_dict = torch.load(weights_path, map_location=self.device)
-
-             cleaned_state_dict = {
-                 k.replace('module.', '').replace('clip.', ''): v
-                 for k, v in state_dict.items()
-                 if k.replace('module.', '').replace('clip.', '').startswith(('text_model.', 'vision_model.'))
-             }
-
-             clip_model.load_state_dict(cleaned_state_dict, strict=False)
-             return clip_model.to(self.device)
-
-         except Exception as e:
-             logger.error(f"Failed to load BanglaCLIP model: {str(e)}")
-             raise
-
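As a side note on the key-cleaning step above, a tiny self-contained illustration of what the comprehension renames and keeps; the sample keys below are invented for illustration:

    # Illustrative only: made-up checkpoint keys showing the rename/filter rule.
    state_dict = {
        "module.clip.text_model.embeddings.position_ids": 1,
        "clip.vision_model.post_layernorm.weight": 2,
        "logit_scale": 3,
    }
    cleaned = {
        k.replace('module.', '').replace('clip.', ''): v
        for k, v in state_dict.items()
        if k.replace('module.', '').replace('clip.', '').startswith(('text_model.', 'vision_model.'))
    }
    print(cleaned)
    # {'text_model.embeddings.position_ids': 1, 'vision_model.post_layernorm.weight': 2}
    # 'logit_scale' is dropped; load_state_dict(..., strict=False) tolerates the missing keys.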
-     def _load_context_data(self):
-         """Load location and scene context data."""
-         self.location_contexts = {
-             'কক্সবাজার': 'Cox\'s Bazar beach, longest natural sea beach in the world, sandy beach',
-             'সেন্টমার্টিন': 'Saint Martin\'s Island, coral island, tropical paradise',
-             'সুন্দরবন': 'Sundarbans mangrove forest, Bengal tigers, riverine forest'
-         }
-
-         self.scene_contexts = {
-             'সৈকত': 'beach, seaside, waves, sandy shore, ocean view',
-             'সমুদ্র': 'ocean, sea waves, deep blue water, horizon',
-             'পাহাড়': 'mountains, hills, valleys, scenic landscape'
-         }
-
-     def _translate_text(self, bangla_text: str) -> str:
-         """Translate Bangla text to English."""
-         inputs = self.trans_tokenizer(bangla_text, return_tensors="pt", padding=True)
-         inputs = {k: v.to(self.device) for k, v in inputs.items()}
-
-         with torch.no_grad():
-             outputs = self.translator.generate(**inputs)
-
-         translated = self.trans_tokenizer.decode(outputs[0], skip_special_tokens=True)
-         return translated
-
-     def _get_text_embedding(self, text: str):
-         """Get text embedding from BanglaCLIP model."""
-         inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
-         inputs = {k: v.to(self.device) for k, v in inputs.items()}
-
-         with torch.no_grad():
-             outputs = self.banglaclip_model.get_text_features(**inputs)
-
-         return outputs
-
-     def generate_image(
-         self,
-         bangla_text: str,
-         config: Optional[GenerationConfig] = None
-     ) -> Tuple[List[Any], str]:
-         if not bangla_text.strip():
-             raise ValueError("Empty input text")
-
-         config = config or GenerationConfig()
-
-         try:
-             if config.seed is not None:
-                 torch.manual_seed(config.seed)
-
-             enhanced_prompt = self._enhance_prompt(bangla_text)
-             negative_prompt = self._get_negative_prompt()
-
-             with torch.autocast(self.device.type):
-                 result = self.pipe(
-                     prompt=enhanced_prompt,
-                     negative_prompt=negative_prompt,
-                     num_images_per_prompt=config.num_images,
-                     num_inference_steps=config.num_inference_steps,
-                     guidance_scale=config.guidance_scale
-                 )
-
-             return result.images, enhanced_prompt
-
-         except Exception as e:
-             logger.error(f"Error during image generation: {str(e)}")
-             raise
-
-     def _enhance_prompt(self, bangla_text: str) -> str:
-         """Enhance prompt with context and style information."""
-         translated_text = self._translate_text(bangla_text)
-
-         # Gather contexts
-         contexts = []
-         contexts.extend(context for loc, context in self.location_contexts.items() if loc in bangla_text)
-         contexts.extend(context for scene, context in self.scene_contexts.items() if scene in bangla_text)
-
-         # Add photo style
-         photo_style = [
-             "professional photography",
-             "high resolution",
-             "4k",
-             "detailed",
-             "realistic",
-             "beautiful composition"
-         ]
-
-         # Combine all parts
-         all_parts = [translated_text] + contexts + photo_style
-         return ", ".join(dict.fromkeys(all_parts))
-
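The final join above deduplicates with dict.fromkeys, which preserves first-seen order; a standalone illustration:

    # dict.fromkeys keeps insertion order, so repeated prompt fragments collapse
    # without reshuffling the rest.
    all_parts = ["sunset on the beach", "beach, seaside, waves", "4k", "beach, seaside, waves", "detailed"]
    print(", ".join(dict.fromkeys(all_parts)))
    # sunset on the beach, beach, seaside, waves, 4k, detailed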
-     def _get_negative_prompt(self) -> str:
-         return (
-             "blurry, low quality, pixelated, cartoon, anime, illustration, "
-             "painting, drawing, artificial, fake, oversaturated, undersaturated"
-         )
-
-     def cleanup(self):
-         """Clean up GPU memory"""
-         if hasattr(self, 'pipe'):
-             del self.pipe
-         if hasattr(self, 'banglaclip_model'):
-             del self.banglaclip_model
-         if hasattr(self, 'translator'):
-             del self.translator
-         torch.cuda.empty_cache()
-         gc.collect()
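For context, cleanup() above is the kind of method the app calls once generation finishes; a minimal sketch of the call pattern, where the constructor arguments and prompt are placeholders rather than values taken from this diff:

    # Hypothetical call pattern; arguments are placeholders.
    generator = EnhancedBanglaSDGenerator("banglaclip_weights.pth", cache_dir="./cache")
    try:
        images, prompt = generator.generate_image("সুন্দরবন")
    finally:
        generator.cleanup()  # drops the pipeline and models, then clears CUDA memory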
 
  def create_gradio_interface():
      """Create and configure the Gradio interface."""
@@ -270,7 +151,7 @@ def create_gradio_interface():
          cleanup_generator()
          return None, f"ছবি তৈরি ব্যর্থ হয়েছে: {str(e)}"

-     # Create Gradio interface
      demo = gr.Interface(
          fn=generate_images,
          inputs=[
@@ -318,4 +199,4 @@ def create_gradio_interface():
  if __name__ == "__main__":
      demo = create_gradio_interface()
      # Fixed queue configuration for newer Gradio versions
-     demo.queue().launch(share=True)

  import torch
+ import os
+ import requests
+ import logging
+ import gc
+ from pathlib import Path
+
  from transformers import CLIPModel, CLIPProcessor, AutoTokenizer, MarianMTModel, MarianTokenizer
  from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
  import gradio as gr
+ from typing import List, Tuple, Optional, Dict, Any
  from dataclasses import dataclass

  # Configure logging
  logging.basicConfig(

  )
  logger = logging.getLogger(__name__)

+ def download_model(model_url: str, model_path: str):
+     """Download large model file with progress tracking."""
+     if not os.path.exists(model_path):
+         try:
+             logger.info(f"Downloading model from {model_url}...")
+             response = requests.get(model_url, stream=True)
+             response.raise_for_status()
+
+             total_size = int(response.headers.get('content-length', 0))
+             block_size = 1024 * 1024  # 1 MB chunks
+             downloaded_size = 0
+
+             with open(model_path, 'wb') as f:
+                 for data in response.iter_content(block_size):
+                     f.write(data)
+                     downloaded_size += len(data)
+                     progress = (downloaded_size / total_size) * 100 if total_size > 0 else 0
+                     logger.info(f"Download progress: {progress:.2f}%")
+
+             logger.info("Model download complete.")
+         except Exception as e:
+             logger.error(f"Model download failed: {e}")
+             raise
+
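A quick sketch of exercising the new helper on its own; the local file name is an assumption here, and the commit wires the same call into EnhancedBanglaSDGenerator.__init__ below:

    # Standalone use of the helper added in this commit; the destination name is an example.
    weights_path = "banglaclip_model_epoch_10.pth"
    download_model(
        "https://huggingface.co/Mansuba/BanglaCLIP13/resolve/main/banglaclip_model_epoch_10.pth",
        weights_path,
    )
    # Re-running is cheap: os.path.exists() makes the download a one-time cost.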
  @dataclass
  class GenerationConfig:
      num_images: int = 1

          cache_dir: str,
          device: Optional[torch.device] = None
      ):
+         # Download model if not exists
+         download_model(
+             "https://huggingface.co/Mansuba/BanglaCLIP13/resolve/main/banglaclip_model_epoch_10.pth",
+             banglaclip_weights_path
+         )
+
          self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
          logger.info(f"Using device: {self.device}")

      def _initialize_models(self, banglaclip_weights_path: str):
          try:
+             # Translation models
              self.bn2en_model_name = "Helsinki-NLP/opus-mt-bn-en"
              self.translator = self.cache.load_model(
                  self.bn2en_model_name,

              ).to(self.device)
              self.trans_tokenizer = MarianTokenizer.from_pretrained(self.bn2en_model_name)

+             # CLIP models
              self.clip_model_name = "openai/clip-vit-base-patch32"
              self.bangla_text_model = "csebuetnlp/banglabert"
              self.banglaclip_model = self._load_banglaclip_model(banglaclip_weights_path)
              self.processor = CLIPProcessor.from_pretrained(self.clip_model_name)
              self.tokenizer = AutoTokenizer.from_pretrained(self.bangla_text_model)

+             # Stable Diffusion
              self._initialize_stable_diffusion()

          except Exception as e:
              logger.error(f"Error initializing models: {str(e)}")
              raise RuntimeError(f"Failed to initialize models: {str(e)}")

+     # ... [Rest of the previous implementation remains the same] ...

  def create_gradio_interface():
      """Create and configure the Gradio interface."""

          cleanup_generator()
          return None, f"ছবি তৈরি ব্যর্থ হয়েছে: {str(e)}"

+     # Gradio interface configuration
      demo = gr.Interface(
          fn=generate_images,
          inputs=[

  if __name__ == "__main__":
      demo = create_gradio_interface()
      # Fixed queue configuration for newer Gradio versions
+     demo.queue().launch(share=True, debug=True)