Sergidev commited on
Commit
05e7c02
·
verified ·
1 Parent(s): ab3f785
Files changed (1) hide show
  1. app.py +72 -290
app.py CHANGED
@@ -38,34 +38,34 @@ torch.backends.cudnn.benchmark = False
38
 
39
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
40
 
 
41
 
42
- def load_pipeline(model_name):
43
- vae = AutoencoderKL.from_pretrained(
44
- "madebyollin/sdxl-vae-fp16-fix",
45
- torch_dtype=torch.float16,
46
- )
47
- pipeline = (
48
- StableDiffusionXLPipeline.from_single_file
49
- if MODEL.endswith(".safetensors")
50
- else StableDiffusionXLPipeline.from_pretrained
51
- )
52
-
53
- pipe = pipeline(
54
- model_name,
55
- vae=vae,
56
- torch_dtype=torch.float16,
57
- custom_pipeline="lpw_stable_diffusion_xl",
58
- use_safetensors=True,
59
- add_watermarker=False,
60
- use_auth_token=HF_TOKEN,
61
- variant="fp16",
62
- )
63
-
64
- pipe.to(device)
65
- return pipe
66
-
67
 
68
- @spaces.GPU
69
  def generate(
70
  prompt: str,
71
  negative_prompt: str = "",
@@ -81,283 +81,65 @@ def generate(
81
  upscale_by: float = 1.5,
82
  progress=gr.Progress(track_tqdm=True),
83
  ) -> Image:
84
- generator = utils.seed_everything(seed)
85
-
86
- width, height = utils.aspect_ratio_handler(
87
- aspect_ratio_selector,
88
- custom_width,
89
- custom_height,
90
- )
91
 
92
- width, height = utils.preprocess_image_dimensions(width, height)
 
 
 
93
 
94
- backup_scheduler = pipe.scheduler
95
- pipe.scheduler = utils.get_scheduler(pipe.scheduler.config, sampler)
96
 
97
- if use_upscaler:
98
- upscaler_pipe = StableDiffusionXLImg2ImgPipeline(**pipe.components)
99
- metadata = {
100
- "prompt": prompt,
101
- "negative_prompt": negative_prompt,
102
- "resolution": f"{width} x {height}",
103
- "guidance_scale": guidance_scale,
104
- "num_inference_steps": num_inference_steps,
105
- "seed": seed,
106
- "sampler": sampler,
107
- }
108
-
109
- if use_upscaler:
110
- new_width = int(width * upscale_by)
111
- new_height = int(height * upscale_by)
112
- metadata["use_upscaler"] = {
113
- "upscale_method": "nearest-exact",
114
- "upscaler_strength": upscaler_strength,
115
- "upscale_by": upscale_by,
116
- "new_resolution": f"{new_width} x {new_height}",
117
- }
118
- else:
119
- metadata["use_upscaler"] = None
120
- logger.info(json.dumps(metadata, indent=4))
121
-
122
- try:
123
- if use_upscaler:
124
- latents = pipe(
125
- prompt=prompt,
126
- negative_prompt=negative_prompt,
127
- width=width,
128
- height=height,
129
- guidance_scale=guidance_scale,
130
- num_inference_steps=num_inference_steps,
131
- generator=generator,
132
- output_type="latent",
133
- ).images
134
- upscaled_latents = utils.upscale(latents, "nearest-exact", upscale_by)
135
- images = upscaler_pipe(
136
- prompt=prompt,
137
- negative_prompt=negative_prompt,
138
- image=upscaled_latents,
139
- guidance_scale=guidance_scale,
140
- num_inference_steps=num_inference_steps,
141
- strength=upscaler_strength,
142
- generator=generator,
143
- output_type="pil",
144
- ).images
145
- else:
146
- images = pipe(
147
- prompt=prompt,
148
- negative_prompt=negative_prompt,
149
- width=width,
150
- height=height,
151
- guidance_scale=guidance_scale,
152
- num_inference_steps=num_inference_steps,
153
- generator=generator,
154
- output_type="pil",
155
- ).images
156
 
157
- if images and IS_COLAB:
158
- for image in images:
159
- filepath = utils.save_image(image, metadata, OUTPUT_DIR)
160
- logger.info(f"Image saved as {filepath} with metadata")
161
 
162
- return images, metadata
163
- except Exception as e:
164
- logger.exception(f"An error occurred: {e}")
165
- raise
166
- finally:
167
- if use_upscaler:
168
- del upscaler_pipe
169
- pipe.scheduler = backup_scheduler
170
- utils.free_memory()
171
 
 
 
 
172
 
173
- if torch.cuda.is_available():
174
- pipe = load_pipeline(MODEL)
175
- logger.info("Loaded on Device!")
176
- else:
177
- pipe = None
178
 
179
- with gr.Blocks(css="style.css") as demo:
180
- title = gr.HTML(
181
- f"""<h1><span>{DESCRIPTION}</span></h1>""",
182
- elem_id="title",
183
- )
184
- gr.Markdown(
185
- f"""Gradio demo for [Pony Diffusion V6](https://civitai.com/models/257749/pony-diffusion-v6-xl/)""",
186
- elem_id="subtitle",
187
- )
188
- gr.DuplicateButton(
189
- value="Duplicate Space for private use",
190
- elem_id="duplicate-button",
191
- visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
192
  )
193
- with gr.Group():
194
- with gr.Row():
195
- prompt = gr.Text(
196
- label="Prompt",
197
- show_label=False,
198
- max_lines=5,
199
- placeholder="Enter your prompt",
200
- container=False,
201
- )
202
- run_button = gr.Button(
203
- "Generate",
204
- variant="primary",
205
- scale=0
206
- )
207
- result = gr.Gallery(
208
- label="Result",
209
- columns=1,
210
- preview=True,
211
- show_label=False
212
- )
213
- with gr.Accordion(label="Advanced Settings", open=False):
214
- negative_prompt = gr.Text(
215
- label="Negative Prompt",
216
- max_lines=5,
217
- placeholder="Enter a negative prompt",
218
- value=""
219
- )
220
- aspect_ratio_selector = gr.Radio(
221
- label="Aspect Ratio",
222
- choices=config.aspect_ratios,
223
- value="1024 x 1024",
224
- container=True,
225
- )
226
- with gr.Group(visible=False) as custom_resolution:
227
- with gr.Row():
228
- custom_width = gr.Slider(
229
- label="Width",
230
- minimum=MIN_IMAGE_SIZE,
231
- maximum=MAX_IMAGE_SIZE,
232
- step=8,
233
- value=1024,
234
- )
235
- custom_height = gr.Slider(
236
- label="Height",
237
- minimum=MIN_IMAGE_SIZE,
238
- maximum=MAX_IMAGE_SIZE,
239
- step=8,
240
- value=1024,
241
- )
242
- use_upscaler = gr.Checkbox(label="Use Upscaler", value=False)
243
- with gr.Row() as upscaler_row:
244
- upscaler_strength = gr.Slider(
245
- label="Strength",
246
- minimum=0,
247
- maximum=1,
248
- step=0.05,
249
- value=0.55,
250
- visible=False,
251
- )
252
- upscale_by = gr.Slider(
253
- label="Upscale by",
254
- minimum=1,
255
- maximum=1.5,
256
- step=0.1,
257
- value=1.5,
258
- visible=False,
259
- )
260
 
261
- sampler = gr.Dropdown(
262
- label="Sampler",
263
- choices=config.sampler_list,
264
- interactive=True,
265
- value="DPM++ 2M SDE Karras",
266
- )
267
- with gr.Row():
268
- seed = gr.Slider(
269
- label="Seed", minimum=0, maximum=utils.MAX_SEED, step=1, value=0
270
- )
271
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
272
- with gr.Group():
273
- with gr.Row():
274
- guidance_scale = gr.Slider(
275
- label="Guidance scale",
276
- minimum=1,
277
- maximum=12,
278
- step=0.1,
279
- value=7.0,
280
- )
281
- num_inference_steps = gr.Slider(
282
- label="Number of inference steps",
283
- minimum=1,
284
- maximum=50,
285
- step=1,
286
- value=28,
287
- )
288
- with gr.Accordion(label="Generation Parameters", open=False):
289
- gr_metadata = gr.JSON(label="Metadata", show_label=False)
290
- gr.Examples(
291
- examples=config.examples,
292
- inputs=prompt,
293
- outputs=[result, gr_metadata],
294
- fn=lambda *args, **kwargs: generate(*args, use_upscaler=True, **kwargs),
295
- cache_examples=CACHE_EXAMPLES,
296
  )
297
- use_upscaler.change(
298
- fn=lambda x: [gr.update(visible=x), gr.update(visible=x)],
299
- inputs=use_upscaler,
300
- outputs=[upscaler_strength, upscale_by],
301
- queue=False,
302
- api_name=False,
303
  )
304
- aspect_ratio_selector.change(
305
- fn=lambda x: gr.update(visible=x == "Custom"),
306
- inputs=aspect_ratio_selector,
307
- outputs=custom_resolution,
308
- queue=False,
309
- api_name=False,
310
  )
311
 
312
- inputs = [
313
- prompt,
314
- negative_prompt,
315
- seed,
316
- custom_width,
317
- custom_height,
318
- guidance_scale,
319
- num_inference_steps,
320
- sampler,
321
- aspect_ratio_selector,
322
- use_upscaler,
323
- upscaler_strength,
324
- upscale_by,
325
- ]
326
 
327
- prompt.submit(
328
- fn=utils.randomize_seed_fn,
329
- inputs=[seed, randomize_seed],
330
- outputs=seed,
331
- queue=False,
332
- api_name=False,
333
- ).then(
334
- fn=generate,
335
- inputs=inputs,
336
- outputs=result,
337
- api_name="run",
338
- )
339
- negative_prompt.submit(
340
- fn=utils.randomize_seed_fn,
341
- inputs=[seed, randomize_seed],
342
- outputs=seed,
343
- queue=False,
344
- api_name=False,
345
- ).then(
346
- fn=generate,
347
- inputs=inputs,
348
- outputs=result,
349
- api_name=False,
350
- )
351
- run_button.click(
352
- fn=utils.randomize_seed_fn,
353
- inputs=[seed, randomize_seed],
354
- outputs=seed,
355
- queue=False,
356
- api_name=False,
357
- ).then(
358
- fn=generate,
359
- inputs=inputs,
360
- outputs=[result, gr_metadata],
361
- api_name=False,
362
- )
363
  demo.queue(max_size=20).launch(debug=IS_COLAB, share=IS_COLAB)
 
38
 
39
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
40
 
41
+ # Load pipeline function remains unchanged
42
 
43
+ def parse_json_parameters(json_str):
44
+ try:
45
+ params = json.loads(json_str)
46
+ return params
47
+ except json.JSONDecodeError:
48
+ return None
49
+
50
+ def apply_json_parameters(json_str):
51
+ params = parse_json_parameters(json_str)
52
+ if params:
53
+ return (
54
+ params.get("prompt", ""),
55
+ params.get("negative_prompt", ""),
56
+ params.get("seed", 0),
57
+ params.get("width", 1024),
58
+ params.get("height", 1024),
59
+ params.get("guidance_scale", 7.0),
60
+ params.get("num_inference_steps", 30),
61
+ params.get("sampler", "DPM++ 2M SDE Karras"),
62
+ params.get("aspect_ratio", "1024 x 1024"),
63
+ params.get("use_upscaler", False),
64
+ params.get("upscaler_strength", 0.55),
65
+ params.get("upscale_by", 1.5),
66
+ )
67
+ return [gr.update()] * 12
68
 
 
69
  def generate(
70
  prompt: str,
71
  negative_prompt: str = "",
 
81
  upscale_by: float = 1.5,
82
  progress=gr.Progress(track_tqdm=True),
83
  ) -> Image:
84
+ # Existing generate function code...
 
 
 
 
 
 
85
 
86
+ # Update history after generation
87
+ history = gr.get_state("history") or []
88
+ history.insert(0, {"prompt": prompt, "image": images[0], "metadata": metadata})
89
+ gr.set_state("history", history[:10]) # Keep only the last 10 entries
90
 
91
+ return images, metadata, gr.update(choices=[h["prompt"] for h in history])
 
92
 
93
+ def get_random_prompt():
94
+ return random.choice(config.examples)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
+ with gr.Blocks(css="style.css") as demo:
97
+ # Existing UI elements...
 
 
98
 
99
+ with gr.Accordion(label="JSON Parameters", open=False):
100
+ json_input = gr.TextArea(label="Input JSON parameters")
101
+ apply_json_button = gr.Button("Apply JSON Parameters")
 
 
 
 
 
 
102
 
103
+ with gr.Row():
104
+ clear_button = gr.Button("Clear All")
105
+ random_prompt_button = gr.Button("Random Prompt")
106
 
107
+ history_dropdown = gr.Dropdown(label="Generation History", choices=[], interactive=True)
 
 
 
 
108
 
109
+ # Connect components
110
+ apply_json_button.click(
111
+ fn=apply_json_parameters,
112
+ inputs=json_input,
113
+ outputs=[prompt, negative_prompt, seed, custom_width, custom_height,
114
+ guidance_scale, num_inference_steps, sampler,
115
+ aspect_ratio_selector, use_upscaler, upscaler_strength, upscale_by]
 
 
 
 
 
 
116
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
+ clear_button.click(
119
+ fn=lambda: (gr.update(value=""), gr.update(value=""), gr.update(value=0),
120
+ gr.update(value=1024), gr.update(value=1024),
121
+ gr.update(value=7.0), gr.update(value=30),
122
+ gr.update(value="DPM++ 2M SDE Karras"),
123
+ gr.update(value="1024 x 1024"), gr.update(value=False),
124
+ gr.update(value=0.55), gr.update(value=1.5)),
125
+ inputs=[],
126
+ outputs=[prompt, negative_prompt, seed, custom_width, custom_height,
127
+ guidance_scale, num_inference_steps, sampler,
128
+ aspect_ratio_selector, use_upscaler, upscaler_strength, upscale_by]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  )
130
+
131
+ random_prompt_button.click(
132
+ fn=get_random_prompt,
133
+ inputs=[],
134
+ outputs=prompt
 
135
  )
136
+
137
+ history_dropdown.change(
138
+ fn=lambda x: gr.update(value=x),
139
+ inputs=history_dropdown,
140
+ outputs=prompt
 
141
  )
142
 
143
+ # Existing event handlers...
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  demo.queue(max_size=20).launch(debug=IS_COLAB, share=IS_COLAB)