Mansuba committed
Commit 6b4f4ab · verified · 1 Parent(s): fd0bc72

Create app.py

Files changed (1)
app.py +321 -0
app.py ADDED
@@ -0,0 +1,321 @@
import torch
from transformers import CLIPModel, CLIPProcessor, AutoTokenizer, MarianMTModel, MarianTokenizer
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
import numpy as np
from typing import List, Tuple, Optional, Dict, Any, Callable
import gradio as gr
from pathlib import Path
import json
import logging
from dataclasses import dataclass
import gc

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


@dataclass
class GenerationConfig:
    num_images: int = 1
    num_inference_steps: int = 50
    guidance_scale: float = 7.5
    seed: Optional[int] = None


class ModelCache:
    """Thin wrapper around model loading; downloaded weights land in cache_dir."""

    def __init__(self, cache_dir: Path):
        self.cache_dir = cache_dir
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def load_model(self, model_id: str, load_func: Callable, cache_name: str) -> Any:
        try:
            logger.info(f"Loading {cache_name}")
            return load_func(model_id)
        except Exception as e:
            logger.error(f"Error loading model {cache_name}: {str(e)}")
            raise


class EnhancedBanglaSDGenerator:
    def __init__(
        self,
        banglaclip_weights_path: str,
        cache_dir: str,
        device: Optional[torch.device] = None
    ):
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")

        self.cache = ModelCache(Path(cache_dir))
        self._initialize_models(banglaclip_weights_path)
        self._load_context_data()

    def _initialize_models(self, banglaclip_weights_path: str):
        try:
            # Initialize translation models
            self.bn2en_model_name = "Helsinki-NLP/opus-mt-bn-en"
            self.translator = self.cache.load_model(
                self.bn2en_model_name,
                MarianMTModel.from_pretrained,
                "translator"
            ).to(self.device)
            self.trans_tokenizer = MarianTokenizer.from_pretrained(self.bn2en_model_name)

            # Initialize CLIP models
            self.clip_model_name = "openai/clip-vit-base-patch32"
            self.bangla_text_model = "csebuetnlp/banglabert"
            self.banglaclip_model = self._load_banglaclip_model(banglaclip_weights_path)
            self.processor = CLIPProcessor.from_pretrained(self.clip_model_name)
            self.tokenizer = AutoTokenizer.from_pretrained(self.bangla_text_model)

            # Initialize Stable Diffusion
            self._initialize_stable_diffusion()

        except Exception as e:
            logger.error(f"Error initializing models: {str(e)}")
            raise RuntimeError(f"Failed to initialize models: {str(e)}")

    def _initialize_stable_diffusion(self):
        """Initialize the Stable Diffusion pipeline with optimized settings."""
        self.pipe = self.cache.load_model(
            "runwayml/stable-diffusion-v1-5",
            lambda model_id: StableDiffusionPipeline.from_pretrained(
                model_id,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                safety_checker=None
            ),
            "stable_diffusion"
        )

        self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            self.pipe.scheduler.config,
            use_karras_sigmas=True,
            algorithm_type="dpmsolver++"
        )

        # Memory optimizations. Sequential CPU offload manages device
        # placement itself, so only move the pipeline manually when
        # offloading is unavailable (i.e., on CPU).
        self.pipe.enable_attention_slicing()
        if torch.cuda.is_available():
            self.pipe.enable_sequential_cpu_offload()
        else:
            self.pipe = self.pipe.to(self.device)

    def _load_banglaclip_model(self, weights_path: str) -> CLIPModel:
        try:
            if not Path(weights_path).exists():
                raise FileNotFoundError(f"BanglaCLIP weights not found at {weights_path}")

            clip_model = CLIPModel.from_pretrained(self.clip_model_name)
            state_dict = torch.load(weights_path, map_location=self.device)

            # Strip DataParallel ('module.') and wrapper ('clip.') prefixes and
            # keep only the text/vision tower weights.
            cleaned_state_dict = {
                k.replace('module.', '').replace('clip.', ''): v
                for k, v in state_dict.items()
                if k.replace('module.', '').replace('clip.', '').startswith(('text_model.', 'vision_model.'))
            }

            clip_model.load_state_dict(cleaned_state_dict, strict=False)
            return clip_model.to(self.device)

        except Exception as e:
            logger.error(f"Failed to load BanglaCLIP model: {str(e)}")
            raise

    def _load_context_data(self):
        """Load location and scene context data."""
        self.location_contexts = {
            'কক্সবাজার': "Cox's Bazar beach, longest natural sea beach in the world, sandy beach",
            'সেন্টমার্টিন': "Saint Martin's Island, coral island, tropical paradise",
            'সুন্দরবন': 'Sundarbans mangrove forest, Bengal tigers, riverine forest'
        }

        self.scene_contexts = {
            'সৈকত': 'beach, seaside, waves, sandy shore, ocean view',
            'সমুদ্র': 'ocean, sea waves, deep blue water, horizon',
            'পাহাড়': 'mountains, hills, valleys, scenic landscape'
        }

    def _translate_text(self, bangla_text: str) -> str:
        """Translate Bangla text to English."""
        inputs = self.trans_tokenizer(bangla_text, return_tensors="pt", padding=True)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.translator.generate(**inputs)

        return self.trans_tokenizer.decode(outputs[0], skip_special_tokens=True)

    def _get_text_embedding(self, text: str):
        """Get a text embedding from the BanglaCLIP model."""
        inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.banglaclip_model.get_text_features(**inputs)

        return outputs

    def generate_image(
        self,
        bangla_text: str,
        config: Optional[GenerationConfig] = None
    ) -> Tuple[List[Any], str]:
        if not bangla_text.strip():
            raise ValueError("Empty input text")

        config = config or GenerationConfig()

        try:
            if config.seed is not None:
                torch.manual_seed(config.seed)

            enhanced_prompt = self._enhance_prompt(bangla_text)
            negative_prompt = self._get_negative_prompt()

            # Autocast only helps on CUDA; on CPU it can cause dtype
            # mismatches with the float32 pipeline, so disable it there.
            with torch.autocast(self.device.type, enabled=self.device.type == "cuda"):
                result = self.pipe(
                    prompt=enhanced_prompt,
                    negative_prompt=negative_prompt,
                    num_images_per_prompt=config.num_images,
                    num_inference_steps=config.num_inference_steps,
                    guidance_scale=config.guidance_scale
                )

            return result.images, enhanced_prompt

        except Exception as e:
            logger.error(f"Error during image generation: {str(e)}")
            raise

    def _enhance_prompt(self, bangla_text: str) -> str:
        """Enhance the prompt with context and style information."""
        translated_text = self._translate_text(bangla_text)

        # Gather location/scene contexts triggered by keywords in the input
        contexts = []
        contexts.extend(context for loc, context in self.location_contexts.items() if loc in bangla_text)
        contexts.extend(context for scene, context in self.scene_contexts.items() if scene in bangla_text)

        # Add photo style
        photo_style = [
            "professional photography",
            "high resolution",
            "4k",
            "detailed",
            "realistic",
            "beautiful composition"
        ]

        # Combine all parts, de-duplicating while preserving order
        all_parts = [translated_text] + contexts + photo_style
        return ", ".join(dict.fromkeys(all_parts))

    def _get_negative_prompt(self) -> str:
        return (
            "blurry, low quality, pixelated, cartoon, anime, illustration, "
            "painting, drawing, artificial, fake, oversaturated, undersaturated"
        )

    def cleanup(self):
        """Free GPU memory held by the loaded models."""
        if hasattr(self, 'pipe'):
            del self.pipe
        if hasattr(self, 'banglaclip_model'):
            del self.banglaclip_model
        if hasattr(self, 'translator'):
            del self.translator
        torch.cuda.empty_cache()
        gc.collect()


def create_gradio_interface():
    """Create and configure the Gradio interface."""
    cache_dir = Path("model_cache")
    generator = None

    def initialize_generator():
        nonlocal generator
        if generator is None:
            generator = EnhancedBanglaSDGenerator(
                banglaclip_weights_path="banglaclip_model_epoch_10.pth",
                cache_dir=str(cache_dir)
            )
        return generator

    def cleanup_generator():
        nonlocal generator
        if generator is not None:
            generator.cleanup()
            generator = None

    def generate_images(text: str, num_images: int, steps: int, guidance_scale: float, seed: Optional[int]) -> Tuple[List[Any], str]:
        if not text.strip():
            return None, "দয়া করে কিছু টেক্সট লিখুন"  # "Please enter some text"

        try:
            gen = initialize_generator()
            config = GenerationConfig(
                num_images=int(num_images),
                num_inference_steps=int(steps),
                guidance_scale=float(guidance_scale),
                # 'is not None' so that an explicit seed of 0 is honored
                seed=int(seed) if seed is not None else None
            )

            images, prompt = gen.generate_image(text, config)
            # Tear the models down after each request to keep memory usage low
            cleanup_generator()
            return images, prompt

        except Exception as e:
            logger.error(f"Error in Gradio interface: {str(e)}")
            cleanup_generator()
            return None, f"ছবি তৈরি ব্যর্থ হয়েছে: {str(e)}"  # "Image generation failed: ..."

    # Create Gradio interface
    demo = gr.Interface(
        fn=generate_images,
        inputs=[
            gr.Textbox(
                label="বাংলা টেক্সট লিখুন",  # "Enter Bangla text"
                placeholder="যেকোনো বাংলা টেক্সট লিখুন...",  # "Type any Bangla text..."
                lines=3
            ),
            gr.Slider(
                minimum=1,
                maximum=4,
                step=1,
                value=1,
                label="ছবির সংখ্যা"  # "Number of images"
            ),
            gr.Slider(
                minimum=20,
                maximum=100,
                step=1,
                value=50,
                label="স্টেপস"  # "Steps"
            ),
            gr.Slider(
                minimum=1.0,
                maximum=20.0,
                step=0.5,
                value=7.5,
                label="গাইডেন্স স্কেল"  # "Guidance scale"
            ),
            gr.Number(
                label="সীড (ঐচ্ছিক)",  # "Seed (optional)"
                precision=0
            )
        ],
        outputs=[
            gr.Gallery(label="তৈরি করা ছবি"),  # "Generated images"
            gr.Textbox(label="ব্যবহৃত প্রম্পট")  # "Prompt used"
        ],
        title="বাংলা টেক্সট থেকে ছবি তৈরি",  # "Generate images from Bangla text"
        description="যেকোনো বাংলা টেক্সট দিয়ে উচ্চমানের ছবি তৈরি করুন"  # "Create high-quality images from any Bangla text"
    )

    return demo


if __name__ == "__main__":
    demo = create_gradio_interface()
    # Fixed queue configuration for newer Gradio versions
    demo.queue().launch(share=True)
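
For reference, a minimal sketch of driving the generator without the Gradio UI. It assumes the banglaclip_model_epoch_10.pth weights file referenced above sits next to app.py in the working directory; the input text, seed, and output path are illustrative, not part of the commit.

from app import EnhancedBanglaSDGenerator, GenerationConfig

generator = EnhancedBanglaSDGenerator(
    banglaclip_weights_path="banglaclip_model_epoch_10.pth",
    cache_dir="model_cache",
)
try:
    # "কক্সবাজার সৈকত" ("Cox's Bazar beach") triggers both a location
    # and a scene context in _enhance_prompt
    images, prompt = generator.generate_image(
        "কক্সবাজার সৈকত",
        GenerationConfig(num_images=1, num_inference_steps=30, seed=42),
    )
    images[0].save("output.png")
    print(f"Prompt used: {prompt}")
finally:
    generator.cleanup()  # release GPU memory even if generation fails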