Mansuba committed
Commit d47dd8d · verified · 1 Parent(s): e8c275d

Update app.py

Files changed (1): app.py (+34 −16)
app.py CHANGED
@@ -71,7 +71,7 @@ class EnhancedBanglaSDGenerator:
             self.processor = CLIPProcessor.from_pretrained(self.clip_model_name)
             self.tokenizer = AutoTokenizer.from_pretrained(self.bangla_text_model)

-            # Initialize Stable Diffusion
+            # Initialize Stable Diffusion with optimizations
             self._initialize_stable_diffusion()

         except Exception as e:
@@ -79,28 +79,38 @@ class EnhancedBanglaSDGenerator:
             raise RuntimeError(f"Failed to initialize models: {str(e)}")

     def _initialize_stable_diffusion(self):
-        """Initialize Stable Diffusion pipeline with optimized settings."""
+        """Initialize Stable Diffusion pipeline with CPU performance optimizations."""
         self.pipe = self.cache.load_model(
             "runwayml/stable-diffusion-v1-5",
             lambda model_id: StableDiffusionPipeline.from_pretrained(
                 model_id,
-                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-                safety_checker=None
+                torch_dtype=torch.float32,
+                safety_checker=None,
+                use_safetensors=True,
+                use_memory_efficient_attention=True,
+                local_files_only=True
             ),
             "stable_diffusion"
         )

+        # Optimize scheduler
         self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
             self.pipe.scheduler.config,
             use_karras_sigmas=True,
             algorithm_type="dpmsolver++"
         )
-        self.pipe = self.pipe.to(self.device)

-        # Memory optimization
-        self.pipe.enable_attention_slicing()
-        if torch.cuda.is_available():
-            self.pipe.enable_sequential_cpu_offload()
+        # CPU optimizations
+        self.pipe.enable_attention_slicing(slice_size=1)
+        self.pipe.enable_vae_slicing()
+        self.pipe.enable_sequential_cpu_offload()
+
+        # Component-level optimizations
+        for component in [self.pipe.text_encoder, self.pipe.vae, self.pipe.unet]:
+            if hasattr(component, 'enable_model_cpu_offload'):
+                component.enable_model_cpu_offload()
+
+        self.pipe = self.pipe.to(self.device)

     def _load_banglaclip_model(self, weights_path: str) -> CLIPModel:
         try:
@@ -175,15 +185,27 @@ class EnhancedBanglaSDGenerator:
             enhanced_prompt = self._enhance_prompt(bangla_text)
             negative_prompt = self._get_negative_prompt()

-            with torch.autocast(self.device.type):
+            # Pre-generation optimization
+            torch.set_num_threads(max(4, torch.get_num_threads()))
+            gc.collect()
+            torch.cuda.empty_cache() if torch.cuda.is_available() else None
+
+            # Memory-optimized generation
+            with torch.inference_mode():
                 result = self.pipe(
                     prompt=enhanced_prompt,
                     negative_prompt=negative_prompt,
                     num_images_per_prompt=config.num_images,
                     num_inference_steps=config.num_inference_steps,
-                    guidance_scale=config.guidance_scale
+                    guidance_scale=config.guidance_scale,
+                    use_memory_efficient_attention=True,
+                    use_memory_efficient_cross_attention=True
                 )

+            # Post-generation cleanup
+            gc.collect()
+            torch.cuda.empty_cache() if torch.cuda.is_available() else None
+
             return result.images, enhanced_prompt

         except Exception as e:
@@ -194,12 +216,10 @@ class EnhancedBanglaSDGenerator:
         """Enhance prompt with context and style information."""
         translated_text = self._translate_text(bangla_text)

-        # Gather contexts
         contexts = []
         contexts.extend(context for loc, context in self.location_contexts.items() if loc in bangla_text)
         contexts.extend(context for scene, context in self.scene_contexts.items() if scene in bangla_text)

-        # Add photo style
         photo_style = [
             "professional photography",
             "high resolution",
@@ -209,7 +229,6 @@ class EnhancedBanglaSDGenerator:
             "beautiful composition"
         ]

-        # Combine all parts
         all_parts = [translated_text] + contexts + photo_style
         return ", ".join(dict.fromkeys(all_parts))

@@ -319,5 +338,4 @@ def create_gradio_interface():

 if __name__ == "__main__":
     demo = create_gradio_interface()
-    # Fixed queue configuration for newer Gradio versions
-    demo.queue().launch(share=True)
+    demo.queue().launch(share=True)
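The rewritten loader pins the pipeline to float32, which suits CPU inference. One flag deserves scrutiny: use_memory_efficient_attention is not a documented StableDiffusionPipeline.from_pretrained() argument, so diffusers will most likely ignore it or warn about an unused kwarg. A minimal sketch of the loading step with that flag dropped; it assumes the weights are already in the local cache, since local_files_only=True disables downloads:

import torch
from diffusers import StableDiffusionPipeline

def load_pipeline(model_id: str = "runwayml/stable-diffusion-v1-5"):
    # Sketch of the CPU-oriented load, minus the undocumented attention flag.
    return StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float32,  # full precision; float16 is a CUDA-oriented choice
        safety_checker=None,        # skip the safety checker to save memory
        use_safetensors=True,       # prefer .safetensors weights when available
        local_files_only=True,      # fail fast rather than download (assumes a warm cache)
    )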
 
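The optimization block needs care around device placement. The diffusers docs advise against moving a pipeline with .to() after enable_sequential_cpu_offload() has been called, since offload installs hooks that manage placement themselves and targets a CUDA execution device (with accelerate installed). The per-component loop is also effectively a no-op: enable_model_cpu_offload is a method of the pipeline, not of its submodules, so the hasattr guard never passes. A sketch that keeps the scheduler swap and slicing but picks one placement strategy per device; optimize_pipeline is a hypothetical helper name:

import torch
from diffusers import DPMSolverMultistepScheduler

def optimize_pipeline(pipe, device: torch.device):
    # Swap in DPM-Solver++ multistep with Karras sigmas, as in the commit.
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(
        pipe.scheduler.config,
        use_karras_sigmas=True,
        algorithm_type="dpmsolver++",
    )
    pipe.enable_attention_slicing(slice_size=1)  # lower peak memory, slightly slower attention
    pipe.enable_vae_slicing()                    # decode multi-image batches slice by slice
    if device.type == "cuda":
        # Offload manages placement via hooks; do not call .to() afterwards.
        pipe.enable_sequential_cpu_offload()
    else:
        pipe = pipe.to(device)                   # plain CPU placement
    return pipe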
 
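In the generation path, the switch from torch.autocast(self.device.type) to torch.inference_mode() is sound for float32 CPU inference, since CPU autocast implies bfloat16. However, use_memory_efficient_attention and use_memory_efficient_cross_attention are not parameters of StableDiffusionPipeline.__call__, so this call can raise a TypeError; memory-efficient attention is normally enabled once on the pipeline (for example via enable_xformers_memory_efficient_attention()), not per call. The gc.collect() calls also presuppose an import gc at the top of app.py that the diff does not show. A sketch of the generation step with the unsupported kwargs removed; generate_images is a hypothetical helper, and the config fields follow the diff:

import gc
import torch

def generate_images(pipe, enhanced_prompt: str, negative_prompt: str, config):
    torch.set_num_threads(max(4, torch.get_num_threads()))  # keep at least 4 CPU threads
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    with torch.inference_mode():  # no autograd bookkeeping during sampling
        result = pipe(
            prompt=enhanced_prompt,
            negative_prompt=negative_prompt,
            num_images_per_prompt=config.num_images,
            num_inference_steps=config.num_inference_steps,
            guidance_scale=config.guidance_scale,
        )

    gc.collect()  # release intermediate tensors promptly
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return result.images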
 
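A detail worth keeping from _enhance_prompt: it deduplicates prompt fragments with dict.fromkeys, which preserves insertion order on Python 3.7+, so repeated style terms collapse without reordering. For illustration:

# dict.fromkeys keeps first occurrences in insertion order, so duplicate
# style fragments drop out without reshuffling the prompt.
parts = ["boat on the river", "professional photography",
         "high resolution", "professional photography"]
print(", ".join(dict.fromkeys(parts)))
# boat on the river, professional photography, high resolution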
 
 
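Finally, the launch block drops its stale comment and keeps demo.queue().launch(share=True), which matches current Gradio: queue() returns the app, so the call chains. A self-contained sketch of the pattern; create_demo is a hypothetical stand-in for the app's create_gradio_interface:

import gradio as gr

def create_demo() -> gr.Interface:
    # Trivial echo app standing in for the real Bangla-to-image interface.
    return gr.Interface(fn=lambda text: text, inputs="text", outputs="text")

if __name__ == "__main__":
    demo = create_demo()
    # queue() serializes requests so long diffusion runs don't trample each other;
    # share=True exposes a temporary public URL.
    demo.queue().launch(share=True)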