ajsbsd commited on
Commit
6f168c5
Β·
1 Parent(s): 452b964
Files changed (2) hide show
  1. app.py +287 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import torch
4
+ from diffusers import StableDiffusionXLImg2ImgPipeline
5
+ from diffusers.utils import load_image
6
+ from PIL import Image
7
+ from PIL.PngImagePlugin import PngInfo
8
+ import json
9
+ import gradio as gr
10
+ import tempfile
11
+
12
+ # Set environment variable to reduce memory fragmentation
13
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
14
+
15
+ # Check if CUDA is available, fallback to CPU
16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
17
+ torch_dtype = torch.float16 if device == "cuda" else torch.float32
18
+
19
+ # Load pipeline with error handling for HF Spaces
20
+ try:
21
+ pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
22
+ "stabilityai/stable-diffusion-xl-refiner-1.0",
23
+ torch_dtype=torch_dtype,
24
+ variant="fp16" if device == "cuda" else None,
25
+ use_safetensors=True
26
+ )
27
+
28
+ # Move to device
29
+ pipe = pipe.to(device)
30
+
31
+ # Enable optimizations based on available hardware
32
+ if device == "cuda":
33
+ # Use CPU offloading to reduce VRAM usage on GPU
34
+ pipe.enable_model_cpu_offload()
35
+
36
+ # Try to enable memory efficient attention
37
+ try:
38
+ pipe.enable_xformers_memory_efficient_attention()
39
+ except (ModuleNotFoundError, ImportError):
40
+ print("xformers not available, using attention slicing")
41
+ pipe.enable_attention_slicing()
42
+ else:
43
+ # For CPU inference, enable attention slicing
44
+ pipe.enable_attention_slicing()
45
+
46
+ except Exception as e:
47
+ print(f"Error loading pipeline: {e}")
48
+ pipe = None
49
+
50
+
51
+ def img2img(
52
+ uploaded_image,
53
+ image_url: str,
54
+ prompt: str,
55
+ negative_prompt: str = "",
56
+ strength: float = 0.7,
57
+ guidance_scale: float = 3.5,
58
+ num_inference_steps: int = 50,
59
+ seed: int = -1,
60
+ ):
61
+ if pipe is None:
62
+ return None, "❌ Model failed to load. Please try again later.", None
63
+
64
+ try:
65
+ # Choose image source
66
+ if uploaded_image is not None:
67
+ init_image = Image.open(uploaded_image).convert("RGB")
68
+ elif image_url.strip() != "":
69
+ try:
70
+ init_image = load_image(image_url).convert("RGB")
71
+ except Exception as e:
72
+ return None, f"❌ Failed to load image from URL: {str(e)}", None
73
+ else:
74
+ return None, "❌ Please upload an image or enter a valid URL", None
75
+
76
+ # Resize image (keeping aspect ratio consideration for better results)
77
+ init_image.thumbnail((1024, 1024), Image.Resampling.LANCZOS)
78
+
79
+ # Ensure dimensions are multiples of 8 for SDXL
80
+ width, height = init_image.size
81
+ width = (width // 8) * 8
82
+ height = (height // 8) * 8
83
+ init_image = init_image.resize((width, height))
84
+
85
+ # Set seed and generator
86
+ if seed == -1:
87
+ generator = torch.Generator(device=device)
88
+ else:
89
+ generator = torch.Generator(device=device).manual_seed(seed)
90
+
91
+ # Validate inputs
92
+ if not prompt.strip():
93
+ return None, "❌ Please enter a prompt", None
94
+
95
+ # Run inference with progress tracking
96
+ with torch.inference_mode():
97
+ result = pipe(
98
+ prompt=prompt,
99
+ negative_prompt=negative_prompt if negative_prompt.strip() else None,
100
+ image=init_image,
101
+ strength=max(0.1, min(1.0, strength)), # Clamp strength
102
+ guidance_scale=max(1.0, min(20.0, guidance_scale)), # Clamp guidance
103
+ num_inference_steps=max(10, min(100, num_inference_steps)), # Clamp steps
104
+ generator=generator
105
+ ).images[0]
106
+
107
+ used_seed = generator.initial_seed()
108
+
109
+ # Create metadata dictionary
110
+ metadata = {
111
+ "prompt": prompt,
112
+ "negative_prompt": negative_prompt,
113
+ "seed": used_seed,
114
+ "model": "stabilityai/stable-diffusion-xl-refiner-1.0",
115
+ "pipeline": "StableDiffusionXLImg2ImgPipeline",
116
+ "guidance_scale": guidance_scale,
117
+ "strength": strength,
118
+ "steps": num_inference_steps,
119
+ "width": result.width,
120
+ "height": result.height,
121
+ "device": device
122
+ }
123
+
124
+ # Save metadata into PNG
125
+ png_info = PngInfo()
126
+ png_info.add_text("parameters", json.dumps(metadata))
127
+
128
+ # Use temporary file for HF Spaces
129
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
130
+ output_path = tmp_file.name
131
+ result.save(output_path, format="PNG", pnginfo=png_info)
132
+
133
+ # Build markdown preview of metadata
134
+ metadata_str = (
135
+ f"**Prompt:** {metadata['prompt']}\n\n"
136
+ f"**Negative Prompt:** {metadata['negative_prompt']}\n\n"
137
+ f"**Seed:** {metadata['seed']}\n\n"
138
+ f"**Model:** {metadata['model']}\n\n"
139
+ f"**Guidance Scale:** {metadata['guidance_scale']}\n\n"
140
+ f"**Strength:** {metadata['strength']}\n\n"
141
+ f"**Steps:** {metadata['steps']}\n\n"
142
+ f"**Dimensions:** {metadata['width']}x{metadata['height']}\n\n"
143
+ f"**Device:** {metadata['device']}"
144
+ )
145
+
146
+ return output_path, f"βœ… **Generation Complete!**\n\n{metadata_str}", output_path
147
+
148
+ except torch.cuda.OutOfMemoryError:
149
+ return None, "❌ GPU out of memory. Try reducing image size or inference steps.", None
150
+ except Exception as e:
151
+ return None, f"❌ Error during generation: {str(e)}", None
152
+
153
+
154
+ # Define UI components with better styling
155
+ title = "🎨 SDXL Image-to-Image Editor"
156
+ description = """
157
+ Transform your images with AI! Upload an image and describe the changes you want to make.
158
+
159
+ **Tips:**
160
+ - Use detailed prompts for better results
161
+ - Lower strength values preserve more of the original image
162
+ - Higher guidance scale follows your prompt more closely
163
+ """
164
+
165
+ # Custom CSS for better appearance
166
+ css = """
167
+ .gradio-container {
168
+ font-family: 'IBM Plex Sans', sans-serif;
169
+ }
170
+ .gr-button {
171
+ color: white;
172
+ background: linear-gradient(90deg, #4f46e5, #7c3aed);
173
+ border: none;
174
+ }
175
+ .gr-button:hover {
176
+ background: linear-gradient(90deg, #4338ca, #6d28d9);
177
+ }
178
+ """
179
+
180
+ with gr.Blocks(title=title, css=css, theme=gr.themes.Soft()) as demo:
181
+ gr.Markdown(f"# {title}")
182
+ gr.Markdown(description)
183
+
184
+ with gr.Row():
185
+ with gr.Column(scale=1):
186
+ gr.Markdown("### πŸ“Έ Input Image")
187
+ uploaded_image = gr.Image(
188
+ label="Upload Image",
189
+ type="filepath",
190
+ height=300
191
+ )
192
+
193
+ gr.Markdown("**Or**")
194
+ image_url = gr.Textbox(
195
+ label="Image URL",
196
+ placeholder="https://example.com/image.jpg",
197
+ info="Paste a direct link to an image"
198
+ )
199
+
200
+ gr.Markdown("### ✍️ Prompts")
201
+ prompt = gr.Textbox(
202
+ label="Prompt",
203
+ placeholder="a beautiful sunset over mountains, photorealistic, detailed",
204
+ lines=3,
205
+ info="Describe what you want to see"
206
+ )
207
+ negative_prompt = gr.Textbox(
208
+ label="Negative Prompt",
209
+ placeholder="blurry, low quality, distorted",
210
+ lines=2,
211
+ info="What to avoid in the image"
212
+ )
213
+
214
+ gr.Markdown("### βš™οΈ Settings")
215
+ with gr.Row():
216
+ strength = gr.Slider(
217
+ minimum=0.1, maximum=1.0, value=0.7, step=0.05,
218
+ label="Transformation Strength",
219
+ info="0.1 = subtle changes, 1.0 = major changes"
220
+ )
221
+ guidance_scale = gr.Slider(
222
+ minimum=1.0, maximum=20.0, value=7.5, step=0.5,
223
+ label="Guidance Scale",
224
+ info="How closely to follow the prompt"
225
+ )
226
+
227
+ with gr.Row():
228
+ num_inference_steps = gr.Slider(
229
+ minimum=10, maximum=50, step=5, value=30,
230
+ label="Quality Steps",
231
+ info="More steps = higher quality but slower"
232
+ )
233
+ seed = gr.Slider(
234
+ minimum=-1, maximum=999999, step=1, value=-1,
235
+ label="Seed",
236
+ info="-1 for random"
237
+ )
238
+
239
+ submit_btn = gr.Button("πŸš€ Generate Image", variant="primary", size="lg")
240
+
241
+ with gr.Column(scale=1):
242
+ gr.Markdown("### πŸ–ΌοΈ Result")
243
+ image_output = gr.Image(label="Generated Image", height=400)
244
+ download_button = gr.File(label="πŸ“₯ Download Full Resolution", visible=False)
245
+
246
+ gr.Markdown("### πŸ“Š Generation Details")
247
+ metadata_output = gr.Markdown()
248
+
249
+ # Event handlers
250
+ submit_btn.click(
251
+ fn=img2img,
252
+ inputs=[
253
+ uploaded_image,
254
+ image_url,
255
+ prompt,
256
+ negative_prompt,
257
+ strength,
258
+ guidance_scale,
259
+ num_inference_steps,
260
+ seed
261
+ ],
262
+ outputs=[image_output, metadata_output, download_button]
263
+ ).then(
264
+ lambda x: gr.update(visible=x is not None),
265
+ inputs=[image_output],
266
+ outputs=[download_button]
267
+ )
268
+
269
+ # Examples
270
+ gr.Markdown("### 🎯 Examples")
271
+ gr.Examples(
272
+ examples=[
273
+ ["", "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png", "make it a van gogh painting", "blurry, low quality", 0.8, 7.5, 30, 42],
274
+ ["", "", "turn into a cyberpunk cityscape", "blurry, distorted", 0.9, 8.0, 30, 123],
275
+ ],
276
+ inputs=[uploaded_image, image_url, prompt, negative_prompt, strength, guidance_scale, num_inference_steps, seed],
277
+ )
278
+
279
+ # Launch configuration for HF Spaces
280
+ if __name__ == "__main__":
281
+ demo.queue(max_size=20) # Enable queuing for better performance
282
+ demo.launch(
283
+ show_error=True,
284
+ share=False, # Don't create gradio.live links in HF Spaces
285
+ inbrowser=False, # Don't try to open browser in cloud environment
286
+ quiet=False
287
+ )
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ diffusers>=0.24.0
3
+ transformers>=4.25.0
4
+ accelerate>=0.20.0
5
+ gradio>=4.0.0
6
+ Pillow>=9.0.0
7
+ xformers>=0.0.20