tejani committed
Commit 5b16c73 · verified · 1 Parent(s): fa5c881

Upload 75 files

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. __init__.py +0 -0
  2. app.py +554 -0
  3. app_settings.py +124 -0
  4. backend/__init__.py +0 -0
  5. backend/annotators/canny_control.py +15 -0
  6. backend/annotators/control_interface.py +12 -0
  7. backend/annotators/depth_control.py +15 -0
  8. backend/annotators/image_control_factory.py +31 -0
  9. backend/annotators/lineart_control.py +11 -0
  10. backend/annotators/mlsd_control.py +10 -0
  11. backend/annotators/normal_control.py +10 -0
  12. backend/annotators/pose_control.py +10 -0
  13. backend/annotators/shuffle_control.py +10 -0
  14. backend/annotators/softedge_control.py +10 -0
  15. backend/api/mcp_server.py +97 -0
  16. backend/api/models/response.py +16 -0
  17. backend/api/web.py +112 -0
  18. backend/base64_image.py +21 -0
  19. backend/controlnet.py +90 -0
  20. backend/device.py +23 -0
  21. backend/gguf/gguf_diffusion.py +319 -0
  22. backend/gguf/sdcpp_types.py +104 -0
  23. backend/image_saver.py +75 -0
  24. backend/lcm_text_to_image.py +577 -0
  25. backend/lora.py +136 -0
  26. backend/models/device.py +9 -0
  27. backend/models/gen_images.py +17 -0
  28. backend/models/lcmdiffusion_setting.py +76 -0
  29. backend/models/upscale.py +9 -0
  30. backend/openvino/custom_ov_model_vae_decoder.py +21 -0
  31. backend/openvino/flux_pipeline.py +36 -0
  32. backend/openvino/ov_hc_stablediffusion_pipeline.py +93 -0
  33. backend/openvino/ovflux.py +675 -0
  34. backend/openvino/pipelines.py +75 -0
  35. backend/openvino/stable_diffusion_engine.py +1817 -0
  36. backend/pipelines/lcm.py +122 -0
  37. backend/pipelines/lcm_lora.py +81 -0
  38. backend/tiny_decoder.py +32 -0
  39. backend/upscale/aura_sr.py +1004 -0
  40. backend/upscale/aura_sr_upscale.py +9 -0
  41. backend/upscale/edsr_upscale_onnx.py +37 -0
  42. backend/upscale/tiled_upscale.py +237 -0
  43. backend/upscale/upscaler.py +52 -0
  44. configs/lcm-lora-models.txt +4 -0
  45. configs/lcm-models.txt +8 -0
  46. configs/openvino-lcm-models.txt +9 -0
  47. configs/stable-diffusion-models.txt +7 -0
  48. constants.py +25 -0
  49. context.py +85 -0
  50. frontend/cli_interactive.py +661 -0
__init__.py ADDED
File without changes
app.py ADDED
@@ -0,0 +1,554 @@
1
+ import json
2
+ from argparse import ArgumentParser
3
+
4
+ from PIL import Image
5
+
6
+ import constants
7
+ from backend.controlnet import controlnet_settings_from_dict
8
+ from backend.device import get_device_name
9
+ from backend.models.gen_images import ImageFormat
10
+ from backend.models.lcmdiffusion_setting import DiffusionTask
11
+ from backend.upscale.tiled_upscale import generate_upscaled_image
12
+ from constants import APP_VERSION, DEVICE
13
+ from frontend.webui.image_variations_ui import generate_image_variations
14
+ from models.interface_types import InterfaceType
15
+ from paths import FastStableDiffusionPaths, ensure_path
16
+ from state import get_context, get_settings
17
+ from utils import show_system_info
18
+
19
+ parser = ArgumentParser(description=f"FAST SD CPU {constants.APP_VERSION}")
20
+ parser.add_argument(
21
+ "-s",
22
+ "--share",
23
+ action="store_true",
24
+ help="Create sharable link(Web UI)",
25
+ required=False,
26
+ )
27
+ group = parser.add_mutually_exclusive_group(required=False)
28
+ group.add_argument(
29
+ "-g",
30
+ "--gui",
31
+ action="store_true",
32
+ help="Start desktop GUI",
33
+ )
34
+ group.add_argument(
35
+ "-w",
36
+ "--webui",
37
+ action="store_true",
38
+ help="Start Web UI",
39
+ )
40
+ group.add_argument(
41
+ "-a",
42
+ "--api",
43
+ action="store_true",
44
+ help="Start Web API server",
45
+ )
46
+ group.add_argument(
47
+ "-m",
48
+ "--mcp",
49
+ action="store_true",
50
+ help="Start MCP(Model Context Protocol) server",
51
+ )
52
+ group.add_argument(
53
+ "-r",
54
+ "--realtime",
55
+ action="store_true",
56
+ help="Start realtime inference UI(experimental)",
57
+ )
58
+ group.add_argument(
59
+ "-v",
60
+ "--version",
61
+ action="store_true",
62
+ help="Version",
63
+ )
64
+
65
+ parser.add_argument(
66
+ "-b",
67
+ "--benchmark",
68
+ action="store_true",
69
+ help="Run inference benchmark on the selected device",
70
+ )
71
+ parser.add_argument(
72
+ "--lcm_model_id",
73
+ type=str,
74
+ help="Model ID or path,Default stabilityai/sd-turbo",
75
+ default="stabilityai/sd-turbo",
76
+ )
77
+ parser.add_argument(
78
+ "--openvino_lcm_model_id",
79
+ type=str,
80
+ help="OpenVINO Model ID or path,Default rupeshs/sd-turbo-openvino",
81
+ default="rupeshs/sd-turbo-openvino",
82
+ )
83
+ parser.add_argument(
84
+ "--prompt",
85
+ type=str,
86
+ help="Describe the image you want to generate",
87
+ default="",
88
+ )
89
+ parser.add_argument(
90
+ "--negative_prompt",
91
+ type=str,
92
+ help="Describe what you want to exclude from the generation",
93
+ default="",
94
+ )
95
+ parser.add_argument(
96
+ "--image_height",
97
+ type=int,
98
+ help="Height of the image",
99
+ default=512,
100
+ )
101
+ parser.add_argument(
102
+ "--image_width",
103
+ type=int,
104
+ help="Width of the image",
105
+ default=512,
106
+ )
107
+ parser.add_argument(
108
+ "--inference_steps",
109
+ type=int,
110
+ help="Number of steps,default : 1",
111
+ default=1,
112
+ )
113
+ parser.add_argument(
114
+ "--guidance_scale",
115
+ type=float,
116
+ help="Guidance scale,default : 1.0",
117
+ default=1.0,
118
+ )
119
+
120
+ parser.add_argument(
121
+ "--number_of_images",
122
+ type=int,
123
+ help="Number of images to generate ,default : 1",
124
+ default=1,
125
+ )
126
+ parser.add_argument(
127
+ "--seed",
128
+ type=int,
129
+ help="Seed,default : -1 (disabled) ",
130
+ default=-1,
131
+ )
132
+ parser.add_argument(
133
+ "--use_openvino",
134
+ action="store_true",
135
+ help="Use OpenVINO model",
136
+ )
137
+
138
+ parser.add_argument(
139
+ "--use_offline_model",
140
+ action="store_true",
141
+ help="Use offline model",
142
+ )
143
+ parser.add_argument(
144
+ "--clip_skip",
145
+ type=int,
146
+ help="CLIP Skip (1-12), default : 1 (disabled) ",
147
+ default=1,
148
+ )
149
+ parser.add_argument(
150
+ "--token_merging",
151
+ type=float,
152
+ help="Token merging scale, 0.0 - 1.0, default : 0.0",
153
+ default=0.0,
154
+ )
155
+
156
+ parser.add_argument(
157
+ "--use_safety_checker",
158
+ action="store_true",
159
+ help="Use safety checker",
160
+ )
161
+ parser.add_argument(
162
+ "--use_lcm_lora",
163
+ action="store_true",
164
+ help="Use LCM-LoRA",
165
+ )
166
+ parser.add_argument(
167
+ "--base_model_id",
168
+ type=str,
169
+ help="LCM LoRA base model ID,Default Lykon/dreamshaper-8",
170
+ default="Lykon/dreamshaper-8",
171
+ )
172
+ parser.add_argument(
173
+ "--lcm_lora_id",
174
+ type=str,
175
+ help="LCM LoRA model ID,Default latent-consistency/lcm-lora-sdv1-5",
176
+ default="latent-consistency/lcm-lora-sdv1-5",
177
+ )
178
+ parser.add_argument(
179
+ "-i",
180
+ "--interactive",
181
+ action="store_true",
182
+ help="Interactive CLI mode",
183
+ )
184
+ parser.add_argument(
185
+ "-t",
186
+ "--use_tiny_auto_encoder",
187
+ action="store_true",
188
+ help="Use tiny auto encoder for SD (TAESD)",
189
+ )
190
+ parser.add_argument(
191
+ "-f",
192
+ "--file",
193
+ type=str,
194
+ help="Input image for img2img mode",
195
+ default="",
196
+ )
197
+ parser.add_argument(
198
+ "--img2img",
199
+ action="store_true",
200
+ help="img2img mode; requires input file via -f argument",
201
+ )
202
+ parser.add_argument(
203
+ "--batch_count",
204
+ type=int,
205
+ help="Number of sequential generations",
206
+ default=1,
207
+ )
208
+ parser.add_argument(
209
+ "--strength",
210
+ type=float,
211
+ help="Denoising strength for img2img and Image variations",
212
+ default=0.3,
213
+ )
214
+ parser.add_argument(
215
+ "--sdupscale",
216
+ action="store_true",
217
+ help="Tiled SD upscale,works only for the resolution 512x512,(2x upscale)",
218
+ )
219
+ parser.add_argument(
220
+ "--upscale",
221
+ action="store_true",
222
+ help="EDSR SD upscale ",
223
+ )
224
+ parser.add_argument(
225
+ "--custom_settings",
226
+ type=str,
227
+ help="JSON file containing custom generation settings",
228
+ default=None,
229
+ )
230
+ parser.add_argument(
231
+ "--usejpeg",
232
+ action="store_true",
233
+ help="Images will be saved as JPEG format",
234
+ )
235
+ parser.add_argument(
236
+ "--noimagesave",
237
+ action="store_true",
238
+ help="Disable image saving",
239
+ )
240
+ parser.add_argument(
241
+ "--imagequality", type=int, help="Output image quality [0 to 100]", default=90
242
+ )
243
+ parser.add_argument(
244
+ "--lora",
245
+ type=str,
246
+ help="LoRA model full path e.g D:\lora_models\CuteCartoon15V-LiberteRedmodModel-Cartoon-CuteCartoonAF.safetensors",
247
+ default=None,
248
+ )
249
+ parser.add_argument(
250
+ "--lora_weight",
251
+ type=float,
252
+ help="LoRA adapter weight [0 to 1.0]",
253
+ default=0.5,
254
+ )
255
+ parser.add_argument(
256
+ "--port",
257
+ type=int,
258
+ help="Web server port",
259
+ default=8000,
260
+ )
261
+
262
+ args = parser.parse_args()
263
+
264
+ if args.version:
265
+ print(APP_VERSION)
266
+ exit()
267
+
268
+ # parser.print_help()
269
+ print("FastSD CPU - ", APP_VERSION)
270
+ show_system_info()
271
+ print(f"Using device : {constants.DEVICE}")
272
+
273
+
274
+ if args.webui:
275
+ app_settings = get_settings()
276
+ else:
277
+ app_settings = get_settings()
278
+
279
+ print(f"Output path : {app_settings.settings.generated_images.path}")
280
+ ensure_path(app_settings.settings.generated_images.path)
281
+
282
+ print(f"Found {len(app_settings.lcm_models)} LCM models in config/lcm-models.txt")
283
+ print(
284
+ f"Found {len(app_settings.stable_diffsuion_models)} stable diffusion models in config/stable-diffusion-models.txt"
285
+ )
286
+ print(
287
+ f"Found {len(app_settings.lcm_lora_models)} LCM-LoRA models in config/lcm-lora-models.txt"
288
+ )
289
+ print(
290
+ f"Found {len(app_settings.openvino_lcm_models)} OpenVINO LCM models in config/openvino-lcm-models.txt"
291
+ )
292
+
293
+ if args.noimagesave:
294
+ app_settings.settings.generated_images.save_image = False
295
+ else:
296
+ app_settings.settings.generated_images.save_image = True
297
+
298
+ app_settings.settings.generated_images.save_image_quality = args.imagequality
299
+
300
+ if not args.realtime:
301
+ # To minimize realtime mode dependencies
302
+ from backend.upscale.upscaler import upscale_image
303
+ from frontend.cli_interactive import interactive_mode
304
+
305
+ if args.gui:
306
+ from frontend.gui.ui import start_gui
307
+
308
+ print("Starting desktop GUI mode(Qt)")
309
+ start_gui(
310
+ [],
311
+ app_settings,
312
+ )
313
+ elif args.webui:
314
+ from frontend.webui.ui import start_webui
315
+
316
+ print("Starting web UI mode")
317
+ start_webui(
318
+ args.share,
319
+ )
320
+ elif args.realtime:
321
+ from frontend.webui.realtime_ui import start_realtime_text_to_image
322
+
323
+ print("Starting realtime text to image(EXPERIMENTAL)")
324
+ start_realtime_text_to_image(args.share)
325
+ elif args.api:
326
+ from backend.api.web import start_web_server
327
+
328
+ start_web_server(args.port)
329
+ elif args.mcp:
330
+ from backend.api.mcp_server import start_mcp_server
331
+
332
+ start_mcp_server(args.port)
333
+ else:
334
+ context = get_context(InterfaceType.CLI)
335
+ config = app_settings.settings
336
+
337
+ if args.use_openvino:
338
+ config.lcm_diffusion_setting.openvino_lcm_model_id = args.openvino_lcm_model_id
339
+ else:
340
+ config.lcm_diffusion_setting.lcm_model_id = args.lcm_model_id
341
+
342
+ config.lcm_diffusion_setting.prompt = args.prompt
343
+ config.lcm_diffusion_setting.negative_prompt = args.negative_prompt
344
+ config.lcm_diffusion_setting.image_height = args.image_height
345
+ config.lcm_diffusion_setting.image_width = args.image_width
346
+ config.lcm_diffusion_setting.guidance_scale = args.guidance_scale
347
+ config.lcm_diffusion_setting.number_of_images = args.number_of_images
348
+ config.lcm_diffusion_setting.inference_steps = args.inference_steps
349
+ config.lcm_diffusion_setting.strength = args.strength
350
+ config.lcm_diffusion_setting.seed = args.seed
351
+ config.lcm_diffusion_setting.use_openvino = args.use_openvino
352
+ config.lcm_diffusion_setting.use_tiny_auto_encoder = args.use_tiny_auto_encoder
353
+ config.lcm_diffusion_setting.use_lcm_lora = args.use_lcm_lora
354
+ config.lcm_diffusion_setting.lcm_lora.base_model_id = args.base_model_id
355
+ config.lcm_diffusion_setting.lcm_lora.lcm_lora_id = args.lcm_lora_id
356
+ config.lcm_diffusion_setting.diffusion_task = DiffusionTask.text_to_image.value
357
+ config.lcm_diffusion_setting.lora.enabled = False
358
+ config.lcm_diffusion_setting.lora.path = args.lora
359
+ config.lcm_diffusion_setting.lora.weight = args.lora_weight
360
+ config.lcm_diffusion_setting.lora.fuse = True
361
+ if config.lcm_diffusion_setting.lora.path:
362
+ config.lcm_diffusion_setting.lora.enabled = True
363
+ if args.usejpeg:
364
+ config.generated_images.format = ImageFormat.JPEG.value.upper()
365
+ if args.seed > -1:
366
+ config.lcm_diffusion_setting.use_seed = True
367
+ else:
368
+ config.lcm_diffusion_setting.use_seed = False
369
+ config.lcm_diffusion_setting.use_offline_model = args.use_offline_model
370
+ config.lcm_diffusion_setting.clip_skip = args.clip_skip
371
+ config.lcm_diffusion_setting.token_merging = args.token_merging
372
+ config.lcm_diffusion_setting.use_safety_checker = args.use_safety_checker
373
+
374
+ # Read custom settings from JSON file
375
+ custom_settings = {}
376
+ if args.custom_settings:
377
+ with open(args.custom_settings) as f:
378
+ custom_settings = json.load(f)
379
+
380
+ # Basic ControlNet settings; if ControlNet is enabled, an image is
381
+ # required even in txt2img mode
382
+ config.lcm_diffusion_setting.controlnet = None
383
+ controlnet_settings_from_dict(
384
+ config.lcm_diffusion_setting,
385
+ custom_settings,
386
+ )
387
+
388
+ # Interactive mode
389
+ if args.interactive:
390
+ # wrapper(interactive_mode, config, context)
391
+ config.lcm_diffusion_setting.lora.fuse = False
392
+ interactive_mode(config, context)
393
+
394
+ # Start of non-interactive CLI image generation
395
+ if args.img2img and args.file != "":
396
+ config.lcm_diffusion_setting.init_image = Image.open(args.file)
397
+ config.lcm_diffusion_setting.diffusion_task = DiffusionTask.image_to_image.value
398
+ elif args.img2img and args.file == "":
399
+ print("Error : You need to specify a file in img2img mode")
400
+ exit()
401
+ elif args.upscale and args.file == "" and args.custom_settings == None:
402
+ print("Error : You need to specify a file in SD upscale mode")
403
+ exit()
404
+ elif (
405
+ args.prompt == ""
406
+ and args.file == ""
407
+ and args.custom_settings == None
408
+ and not args.benchmark
409
+ ):
410
+ print("Error : You need to provide a prompt")
411
+ exit()
412
+
413
+ if args.upscale:
414
+ # image = Image.open(args.file)
415
+ output_path = FastStableDiffusionPaths.get_upscale_filepath(
416
+ args.file,
417
+ 2,
418
+ config.generated_images.format,
419
+ )
420
+ result = upscale_image(
421
+ context,
422
+ args.file,
423
+ output_path,
424
+ 2,
425
+ )
426
+ # Perform Tiled SD upscale (EXPERIMENTAL)
427
+ elif args.sdupscale:
428
+ if args.use_openvino:
429
+ config.lcm_diffusion_setting.strength = 0.3
430
+ upscale_settings = None
431
+ if custom_settings != {}:
432
+ upscale_settings = custom_settings
433
+ filepath = args.file
434
+ output_format = config.generated_images.format
435
+ if upscale_settings:
436
+ filepath = upscale_settings["source_file"]
437
+ output_format = upscale_settings["output_format"].upper()
438
+ output_path = FastStableDiffusionPaths.get_upscale_filepath(
439
+ filepath,
440
+ 2,
441
+ output_format,
442
+ )
443
+
444
+ generate_upscaled_image(
445
+ config,
446
+ filepath,
447
+ config.lcm_diffusion_setting.strength,
448
+ upscale_settings=upscale_settings,
449
+ context=context,
450
+ tile_overlap=32 if config.lcm_diffusion_setting.use_openvino else 16,
451
+ output_path=output_path,
452
+ image_format=output_format,
453
+ )
454
+ exit()
455
+ # If img2img argument is set and prompt is empty, use image variations mode
456
+ elif args.img2img and args.prompt == "":
457
+ for i in range(0, args.batch_count):
458
+ generate_image_variations(
459
+ config.lcm_diffusion_setting.init_image, args.strength
460
+ )
461
+ else:
462
+ if args.benchmark:
463
+ print("Initializing benchmark...")
464
+ bench_lcm_setting = config.lcm_diffusion_setting
465
+ bench_lcm_setting.prompt = "a cat"
466
+ bench_lcm_setting.use_tiny_auto_encoder = False
467
+ context.generate_text_to_image(
468
+ settings=config,
469
+ device=DEVICE,
470
+ )
471
+
472
+ latencies = []
473
+
474
+ print("Starting benchmark please wait...")
475
+ for _ in range(3):
476
+ context.generate_text_to_image(
477
+ settings=config,
478
+ device=DEVICE,
479
+ )
480
+ latencies.append(context.latency)
481
+
482
+ avg_latency = sum(latencies) / 3
483
+
484
+ bench_lcm_setting.use_tiny_auto_encoder = True
485
+
486
+ context.generate_text_to_image(
487
+ settings=config,
488
+ device=DEVICE,
489
+ )
490
+ latencies = []
491
+ for _ in range(3):
492
+ context.generate_text_to_image(
493
+ settings=config,
494
+ device=DEVICE,
495
+ )
496
+ latencies.append(context.latency)
497
+
498
+ avg_latency_taesd = sum(latencies) / 3
499
+
500
+ benchmark_name = ""
501
+
502
+ if config.lcm_diffusion_setting.use_openvino:
503
+ benchmark_name = "OpenVINO"
504
+ else:
505
+ benchmark_name = "PyTorch"
506
+
507
+ bench_model_id = ""
508
+ if bench_lcm_setting.use_openvino:
509
+ bench_model_id = bench_lcm_setting.openvino_lcm_model_id
510
+ elif bench_lcm_setting.use_lcm_lora:
511
+ bench_model_id = bench_lcm_setting.lcm_lora.base_model_id
512
+ else:
513
+ bench_model_id = bench_lcm_setting.lcm_model_id
514
+
515
+ benchmark_result = [
516
+ ["Device", f"{DEVICE.upper()},{get_device_name()}"],
517
+ ["Stable Diffusion Model", bench_model_id],
518
+ [
519
+ "Image Size ",
520
+ f"{bench_lcm_setting.image_width}x{bench_lcm_setting.image_height}",
521
+ ],
522
+ [
523
+ "Inference Steps",
524
+ f"{bench_lcm_setting.inference_steps}",
525
+ ],
526
+ [
527
+ "Benchmark Passes",
528
+ 3,
529
+ ],
530
+ [
531
+ "Average Latency",
532
+ f"{round(avg_latency, 3)} sec",
533
+ ],
534
+ [
535
+ "Average Latency(TAESD* enabled)",
536
+ f"{round(avg_latency_taesd, 3)} sec",
537
+ ],
538
+ ]
539
+ print()
540
+ print(
541
+ f" FastSD Benchmark - {benchmark_name:8} "
542
+ )
543
+ print(f"-" * 80)
544
+ for benchmark in benchmark_result:
545
+ print(f"{benchmark[0]:35} - {benchmark[1]}")
546
+ print(f"-" * 80)
547
+ print("*TAESD - Tiny AutoEncoder for Stable Diffusion")
548
+
549
+ else:
550
+ for i in range(0, args.batch_count):
551
+ context.generate_text_to_image(
552
+ settings=config,
553
+ device=DEVICE,
554
+ )
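The parser above defines the full CLI surface of FastSD CPU. As an illustrative sketch (not part of the uploaded files), the snippet below drives this entry point from Python with a handful of the flags declared above; the flag names come straight from the parser, while the prompt text, image size, and the choice of OpenVINO are arbitrary example values.

```python
# Illustrative only: invoke the CLI defined in app.py with a few of the flags
# declared by its ArgumentParser. Assumes it runs from the repository root.
import subprocess

subprocess.run(
    [
        "python", "app.py",
        "--use_openvino",                 # use the OpenVINO pipeline
        "--prompt", "a cup of coffee on a wooden table",
        "--image_width", "512",
        "--image_height", "512",
        "--inference_steps", "1",
        "--number_of_images", "1",
    ],
    check=True,
)
```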
app_settings.py ADDED
@@ -0,0 +1,124 @@
1
+ from copy import deepcopy
2
+ from os import makedirs, path
3
+
4
+ import yaml
5
+ from constants import (
6
+ LCM_LORA_MODELS_FILE,
7
+ LCM_MODELS_FILE,
8
+ OPENVINO_LCM_MODELS_FILE,
9
+ SD_MODELS_FILE,
10
+ )
11
+ from paths import FastStableDiffusionPaths, join_paths
12
+ from utils import get_files_in_dir, get_models_from_text_file
13
+
14
+ from models.settings import Settings
15
+
16
+
17
+ class AppSettings:
18
+ def __init__(self):
19
+ self.config_path = FastStableDiffusionPaths().get_app_settings_path()
20
+ self._stable_diffsuion_models = get_models_from_text_file(
21
+ FastStableDiffusionPaths().get_models_config_path(SD_MODELS_FILE)
22
+ )
23
+ self._lcm_lora_models = get_models_from_text_file(
24
+ FastStableDiffusionPaths().get_models_config_path(LCM_LORA_MODELS_FILE)
25
+ )
26
+ self._openvino_lcm_models = get_models_from_text_file(
27
+ FastStableDiffusionPaths().get_models_config_path(OPENVINO_LCM_MODELS_FILE)
28
+ )
29
+ self._lcm_models = get_models_from_text_file(
30
+ FastStableDiffusionPaths().get_models_config_path(LCM_MODELS_FILE)
31
+ )
32
+ self._gguf_diffusion_models = get_files_in_dir(
33
+ join_paths(FastStableDiffusionPaths().get_gguf_models_path(), "diffusion")
34
+ )
35
+ self._gguf_clip_models = get_files_in_dir(
36
+ join_paths(FastStableDiffusionPaths().get_gguf_models_path(), "clip")
37
+ )
38
+ self._gguf_vae_models = get_files_in_dir(
39
+ join_paths(FastStableDiffusionPaths().get_gguf_models_path(), "vae")
40
+ )
41
+ self._gguf_t5xxl_models = get_files_in_dir(
42
+ join_paths(FastStableDiffusionPaths().get_gguf_models_path(), "t5xxl")
43
+ )
44
+ self._config = None
45
+
46
+ @property
47
+ def settings(self):
48
+ return self._config
49
+
50
+ @property
51
+ def stable_diffsuion_models(self):
52
+ return self._stable_diffsuion_models
53
+
54
+ @property
55
+ def openvino_lcm_models(self):
56
+ return self._openvino_lcm_models
57
+
58
+ @property
59
+ def lcm_models(self):
60
+ return self._lcm_models
61
+
62
+ @property
63
+ def lcm_lora_models(self):
64
+ return self._lcm_lora_models
65
+
66
+ @property
67
+ def gguf_diffusion_models(self):
68
+ return self._gguf_diffusion_models
69
+
70
+ @property
71
+ def gguf_clip_models(self):
72
+ return self._gguf_clip_models
73
+
74
+ @property
75
+ def gguf_vae_models(self):
76
+ return self._gguf_vae_models
77
+
78
+ @property
79
+ def gguf_t5xxl_models(self):
80
+ return self._gguf_t5xxl_models
81
+
82
+ def load(self, skip_file=False):
83
+ if skip_file:
84
+ print("Skipping config file")
85
+ settings_dict = self._load_default()
86
+ self._config = Settings.model_validate(settings_dict)
87
+ else:
88
+ if not path.exists(self.config_path):
89
+ base_dir = path.dirname(self.config_path)
90
+ if not path.exists(base_dir):
91
+ makedirs(base_dir)
92
+ try:
93
+ print("Settings not found creating default settings")
94
+ with open(self.config_path, "w") as file:
95
+ yaml.dump(
96
+ self._load_default(),
97
+ file,
98
+ )
99
+ except Exception as ex:
100
+ print(f"Error in creating settings : {ex}")
101
+ exit()
102
+ try:
103
+ with open(self.config_path) as file:
104
+ settings_dict = yaml.safe_load(file)
105
+ self._config = Settings.model_validate(settings_dict)
106
+ except Exception as ex:
107
+ print(f"Error in loading settings : {ex}")
108
+
109
+ def save(self):
110
+ try:
111
+ with open(self.config_path, "w") as file:
112
+ tmp_cfg = deepcopy(self._config)
113
+ tmp_cfg.lcm_diffusion_setting.init_image = None
114
+ configurations = tmp_cfg.model_dump(
115
+ exclude=["init_image"],
116
+ )
117
+ if configurations:
118
+ yaml.dump(configurations, file)
119
+ except Exception as ex:
120
+ print(f"Error in saving settings : {ex}")
121
+
122
+ def _load_default(self) -> dict:
123
+ default_config = Settings()
124
+ return default_config.model_dump()
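Elsewhere in this commit the class is reached through state.get_settings(); the sketch below is an illustrative direct use, assuming the model list files referenced by the constants are present on disk.

```python
# Illustrative sketch: load (or create) the YAML settings file and read the
# parsed model lists. The application itself obtains this object via
# state.get_settings().
from app_settings import AppSettings

app_settings = AppSettings()
app_settings.load()  # writes a default settings file on first run

print(f"LCM models       : {len(app_settings.lcm_models)}")
print(f"OpenVINO models  : {len(app_settings.openvino_lcm_models)}")
print(f"Output directory : {app_settings.settings.generated_images.path}")

app_settings.save()  # persist any changes back to the YAML file
```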
backend/__init__.py ADDED
File without changes
backend/annotators/canny_control.py ADDED
@@ -0,0 +1,15 @@
1
+ import numpy as np
2
+ from backend.annotators.control_interface import ControlInterface
3
+ from cv2 import Canny
4
+ from PIL import Image
5
+
6
+
7
+ class CannyControl(ControlInterface):
8
+ def get_control_image(self, image: Image) -> Image:
9
+ low_threshold = 100
10
+ high_threshold = 200
11
+ image = np.array(image)
12
+ image = Canny(image, low_threshold, high_threshold)
13
+ image = image[:, :, None]
14
+ image = np.concatenate([image, image, image], axis=2)
15
+ return Image.fromarray(image)
backend/annotators/control_interface.py ADDED
@@ -0,0 +1,12 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+ from PIL import Image
4
+
5
+
6
+ class ControlInterface(ABC):
7
+ @abstractmethod
8
+ def get_control_image(
9
+ self,
10
+ image: Image,
11
+ ) -> Image:
12
+ pass
backend/annotators/depth_control.py ADDED
@@ -0,0 +1,15 @@
1
+ import numpy as np
2
+ from backend.annotators.control_interface import ControlInterface
3
+ from PIL import Image
4
+ from transformers import pipeline
5
+
6
+
7
+ class DepthControl(ControlInterface):
8
+ def get_control_image(self, image: Image) -> Image:
9
+ depth_estimator = pipeline("depth-estimation")
10
+ image = depth_estimator(image)["depth"]
11
+ image = np.array(image)
12
+ image = image[:, :, None]
13
+ image = np.concatenate([image, image, image], axis=2)
14
+ image = Image.fromarray(image)
15
+ return image
backend/annotators/image_control_factory.py ADDED
@@ -0,0 +1,31 @@
1
+ from backend.annotators.canny_control import CannyControl
2
+ from backend.annotators.depth_control import DepthControl
3
+ from backend.annotators.lineart_control import LineArtControl
4
+ from backend.annotators.mlsd_control import MlsdControl
5
+ from backend.annotators.normal_control import NormalControl
6
+ from backend.annotators.pose_control import PoseControl
7
+ from backend.annotators.shuffle_control import ShuffleControl
8
+ from backend.annotators.softedge_control import SoftEdgeControl
9
+
10
+
11
+ class ImageControlFactory:
12
+ def create_control(self, controlnet_type: str):
13
+ if controlnet_type == "Canny":
14
+ return CannyControl()
15
+ elif controlnet_type == "Pose":
16
+ return PoseControl()
17
+ elif controlnet_type == "MLSD":
18
+ return MlsdControl()
19
+ elif controlnet_type == "Depth":
20
+ return DepthControl()
21
+ elif controlnet_type == "LineArt":
22
+ return LineArtControl()
23
+ elif controlnet_type == "Shuffle":
24
+ return ShuffleControl()
25
+ elif controlnet_type == "NormalBAE":
26
+ return NormalControl()
27
+ elif controlnet_type == "SoftEdge":
28
+ return SoftEdgeControl()
29
+ else:
30
+ print("Error: Control type not implemented!")
31
+ raise Exception("Error: Control type not implemented!")
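The factory maps a ControlNet type name onto the matching annotator class above. A minimal usage sketch follows (illustrative; "input.png" is a hypothetical file name).

```python
# Illustrative sketch: produce a Canny edge map suitable for ControlNet
# conditioning from an input photo. "input.png" is a hypothetical file.
from PIL import Image

from backend.annotators.image_control_factory import ImageControlFactory

control = ImageControlFactory().create_control("Canny")
control_image = control.get_control_image(Image.open("input.png"))
control_image.save("canny_control.png")
```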
backend/annotators/lineart_control.py ADDED
@@ -0,0 +1,11 @@
1
+ import numpy as np
2
+ from backend.annotators.control_interface import ControlInterface
3
+ from controlnet_aux import LineartDetector
4
+ from PIL import Image
5
+
6
+
7
+ class LineArtControl(ControlInterface):
8
+ def get_control_image(self, image: Image) -> Image:
9
+ processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
10
+ control_image = processor(image)
11
+ return control_image
backend/annotators/mlsd_control.py ADDED
@@ -0,0 +1,10 @@
1
+ from backend.annotators.control_interface import ControlInterface
2
+ from controlnet_aux import MLSDdetector
3
+ from PIL import Image
4
+
5
+
6
+ class MlsdControl(ControlInterface):
7
+ def get_control_image(self, image: Image) -> Image:
8
+ mlsd = MLSDdetector.from_pretrained("lllyasviel/ControlNet")
9
+ image = mlsd(image)
10
+ return image
backend/annotators/normal_control.py ADDED
@@ -0,0 +1,10 @@
1
+ from backend.annotators.control_interface import ControlInterface
2
+ from controlnet_aux import NormalBaeDetector
3
+ from PIL import Image
4
+
5
+
6
+ class NormalControl(ControlInterface):
7
+ def get_control_image(self, image: Image) -> Image:
8
+ processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
9
+ control_image = processor(image)
10
+ return control_image
backend/annotators/pose_control.py ADDED
@@ -0,0 +1,10 @@
1
+ from backend.annotators.control_interface import ControlInterface
2
+ from controlnet_aux import OpenposeDetector
3
+ from PIL import Image
4
+
5
+
6
+ class PoseControl(ControlInterface):
7
+ def get_control_image(self, image: Image) -> Image:
8
+ openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
9
+ image = openpose(image)
10
+ return image
backend/annotators/shuffle_control.py ADDED
@@ -0,0 +1,10 @@
1
+ from backend.annotators.control_interface import ControlInterface
2
+ from controlnet_aux import ContentShuffleDetector
3
+ from PIL import Image
4
+
5
+
6
+ class ShuffleControl(ControlInterface):
7
+ def get_control_image(self, image: Image) -> Image:
8
+ shuffle_processor = ContentShuffleDetector()
9
+ image = shuffle_processor(image)
10
+ return image
backend/annotators/softedge_control.py ADDED
@@ -0,0 +1,10 @@
1
+ from backend.annotators.control_interface import ControlInterface
2
+ from controlnet_aux import PidiNetDetector
3
+ from PIL import Image
4
+
5
+
6
+ class SoftEdgeControl(ControlInterface):
7
+ def get_control_image(self, image: Image) -> Image:
8
+ processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
9
+ control_image = processor(image)
10
+ return control_image
backend/api/mcp_server.py ADDED
@@ -0,0 +1,97 @@
1
+ import platform
2
+
3
+ import uvicorn
4
+ from backend.device import get_device_name
5
+ from backend.models.device import DeviceInfo
6
+ from constants import APP_VERSION, DEVICE
7
+ from context import Context
8
+ from fastapi import FastAPI, Request
9
+ from fastapi_mcp import FastApiMCP
10
+ from state import get_settings
11
+ from fastapi.middleware.cors import CORSMiddleware
12
+ from models.interface_types import InterfaceType
13
+ from fastapi.staticfiles import StaticFiles
14
+
15
+ app_settings = get_settings()
16
+ app = FastAPI(
17
+ title="FastSD CPU",
18
+ description="Fast stable diffusion on CPU",
19
+ version=APP_VERSION,
20
+ license_info={
21
+ "name": "MIT",
22
+ "identifier": "MIT",
23
+ },
24
+ describe_all_responses=True,
25
+ describe_full_response_schema=True,
26
+ )
27
+ origins = ["*"]
28
+
29
+ app.add_middleware(
30
+ CORSMiddleware,
31
+ allow_origins=origins,
32
+ allow_credentials=True,
33
+ allow_methods=["*"],
34
+ allow_headers=["*"],
35
+ )
36
+ print(app_settings.settings.lcm_diffusion_setting)
37
+
38
+ context = Context(InterfaceType.API_SERVER)
39
+ app.mount("/results", StaticFiles(directory="results"), name="results")
40
+
41
+
42
+ @app.get(
43
+ "/info",
44
+ description="Get system information",
45
+ summary="Get system information",
46
+ operation_id="get_system_info",
47
+ )
48
+ async def info() -> dict:
49
+ device_info = DeviceInfo(
50
+ device_type=DEVICE,
51
+ device_name=get_device_name(),
52
+ os=platform.system(),
53
+ platform=platform.platform(),
54
+ processor=platform.processor(),
55
+ )
56
+ return device_info.model_dump()
57
+
58
+
59
+ @app.post(
60
+ "/generate",
61
+ description="Generate image from text prompt",
62
+ summary="Text to image generation",
63
+ operation_id="generate",
64
+ )
65
+ async def generate(
66
+ prompt: str,
67
+ request: Request,
68
+ ) -> str:
69
+ """
70
+ Returns URL of the generated image for text prompt
71
+ """
72
+
73
+ app_settings.settings.lcm_diffusion_setting.prompt = prompt
74
+ images = context.generate_text_to_image(app_settings.settings)
75
+ image_names = context.save_images(
76
+ images,
77
+ app_settings.settings,
78
+ )
79
+ url = request.url_for("results", path=image_names[0])
80
+ image_url = f"The generated image available at the URL {url}"
81
+ return image_url
82
+
83
+
84
+ def start_mcp_server(port: int = 8000):
85
+ mcp = FastApiMCP(
86
+ app,
87
+ name="FastSDCPU MCP",
88
+ description="MCP server for FastSD CPU API",
89
+ base_url=f"http://localhost:{port}",
90
+ )
91
+
92
+ mcp.mount()
93
+ uvicorn.run(
94
+ app,
95
+ host="0.0.0.0",
96
+ port=port,
97
+ )
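Besides MCP clients, the /generate route above can also be exercised as a plain HTTP endpoint. A hedged client sketch (assuming the server was started with start_mcp_server() on port 8000 and that the requests package is installed):

```python
# Illustrative client sketch for the /generate endpoint defined above.
# Assumes the MCP server is running locally on port 8000.
import requests

response = requests.post(
    "http://localhost:8000/generate",
    params={"prompt": "a watercolor painting of a lighthouse"},
    timeout=300,
)
response.raise_for_status()
print(response.json())  # message containing the URL of the generated image
```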
backend/api/models/response.py ADDED
@@ -0,0 +1,16 @@
1
+ from typing import List
2
+
3
+ from pydantic import BaseModel
4
+
5
+
6
+ class StableDiffusionResponse(BaseModel):
7
+ """
8
+ Stable diffusion response model
9
+
10
+ Attributes:
11
+ images (List[str]): List of JPEG image as base64 encoded
12
+ latency (float): Latency in seconds
13
+ """
14
+
15
+ images: List[str]
16
+ latency: float
backend/api/web.py ADDED
@@ -0,0 +1,112 @@
1
+ import platform
2
+
3
+ import uvicorn
4
+ from fastapi import FastAPI
5
+ from fastapi.middleware.cors import CORSMiddleware
6
+
7
+ from backend.api.models.response import StableDiffusionResponse
8
+ from backend.base64_image import base64_image_to_pil, pil_image_to_base64_str
9
+ from backend.device import get_device_name
10
+ from backend.models.device import DeviceInfo
11
+ from backend.models.lcmdiffusion_setting import DiffusionTask, LCMDiffusionSetting
12
+ from constants import APP_VERSION, DEVICE
13
+ from context import Context
14
+ from models.interface_types import InterfaceType
15
+ from state import get_settings
16
+
17
+ app_settings = get_settings()
18
+ app = FastAPI(
19
+ title="FastSD CPU",
20
+ description="Fast stable diffusion on CPU",
21
+ version=APP_VERSION,
22
+ license_info={
23
+ "name": "MIT",
24
+ "identifier": "MIT",
25
+ },
26
+ docs_url="/api/docs",
27
+ redoc_url="/api/redoc",
28
+ openapi_url="/api/openapi.json",
29
+ )
30
+ print(app_settings.settings.lcm_diffusion_setting)
31
+ origins = ["*"]
32
+ app.add_middleware(
33
+ CORSMiddleware,
34
+ allow_origins=origins,
35
+ allow_credentials=True,
36
+ allow_methods=["*"],
37
+ allow_headers=["*"],
38
+ )
39
+ context = Context(InterfaceType.API_SERVER)
40
+
41
+
42
+ @app.get("/api/")
43
+ async def root():
44
+ return {"message": "Welcome to FastSD CPU API"}
45
+
46
+
47
+ @app.get(
48
+ "/api/info",
49
+ description="Get system information",
50
+ summary="Get system information",
51
+ )
52
+ async def info():
53
+ device_info = DeviceInfo(
54
+ device_type=DEVICE,
55
+ device_name=get_device_name(),
56
+ os=platform.system(),
57
+ platform=platform.platform(),
58
+ processor=platform.processor(),
59
+ )
60
+ return device_info.model_dump()
61
+
62
+
63
+ @app.get(
64
+ "/api/config",
65
+ description="Get current configuration",
66
+ summary="Get configurations",
67
+ )
68
+ async def config():
69
+ return app_settings.settings
70
+
71
+
72
+ @app.get(
73
+ "/api/models",
74
+ description="Get available models",
75
+ summary="Get available models",
76
+ )
77
+ async def models():
78
+ return {
79
+ "lcm_lora_models": app_settings.lcm_lora_models,
80
+ "stable_diffusion": app_settings.stable_diffsuion_models,
81
+ "openvino_models": app_settings.openvino_lcm_models,
82
+ "lcm_models": app_settings.lcm_models,
83
+ }
84
+
85
+
86
+ @app.post(
87
+ "/api/generate",
88
+ description="Generate image(Text to image,Image to Image)",
89
+ summary="Generate image(Text to image,Image to Image)",
90
+ )
91
+ async def generate(diffusion_config: LCMDiffusionSetting) -> StableDiffusionResponse:
92
+ app_settings.settings.lcm_diffusion_setting = diffusion_config
93
+ if diffusion_config.diffusion_task == DiffusionTask.image_to_image:
94
+ app_settings.settings.lcm_diffusion_setting.init_image = base64_image_to_pil(
95
+ diffusion_config.init_image
96
+ )
97
+
98
+ images = context.generate_text_to_image(app_settings.settings)
99
+
100
+ images_base64 = [pil_image_to_base64_str(img) for img in images]
101
+ return StableDiffusionResponse(
102
+ latency=round(context.latency, 2),
103
+ images=images_base64,
104
+ )
105
+
106
+
107
+ def start_web_server(port: int = 8000):
108
+ uvicorn.run(
109
+ app,
110
+ host="0.0.0.0",
111
+ port=port,
112
+ )
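A hedged example of calling POST /api/generate from Python (assumes the API server is running locally on port 8000 and the requests package is available). The request body mirrors LCMDiffusionSetting, so unspecified fields fall back to that model's defaults, and the base64 images in the response can be written straight to disk.

```python
# Illustrative client sketch for POST /api/generate. Assumes the server was
# started with start_web_server() on port 8000.
import base64
import requests

payload = {
    "prompt": "a cozy cabin in a snowy forest",
    "image_width": 512,
    "image_height": 512,
    "inference_steps": 1,
}
resp = requests.post("http://localhost:8000/api/generate", json=payload, timeout=600)
resp.raise_for_status()
result = resp.json()

print(f"Latency: {result['latency']} sec")
for i, img_b64 in enumerate(result["images"]):
    with open(f"result_{i}.jpg", "wb") as f:
        f.write(base64.b64decode(img_b64))
```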
backend/base64_image.py ADDED
@@ -0,0 +1,21 @@
1
+ from io import BytesIO
2
+ from base64 import b64encode, b64decode
3
+ from PIL import Image
4
+
5
+
6
+ def pil_image_to_base64_str(
7
+ image: Image,
8
+ format: str = "JPEG",
9
+ ) -> str:
10
+ buffer = BytesIO()
11
+ image.save(buffer, format=format)
12
+ buffer.seek(0)
13
+ img_base64 = b64encode(buffer.getvalue()).decode("utf-8")
14
+ return img_base64
15
+
16
+
17
+ def base64_image_to_pil(base64_str) -> Image:
18
+ image_data = b64decode(base64_str)
19
+ image_buffer = BytesIO(image_data)
20
+ image = Image.open(image_buffer)
21
+ return image
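A quick illustrative round trip through the two helpers above (assumes Pillow is installed):

```python
# Illustrative round trip: PIL image -> base64 string -> PIL image.
from PIL import Image

from backend.base64_image import base64_image_to_pil, pil_image_to_base64_str

original = Image.new("RGB", (64, 64), color="orange")
encoded = pil_image_to_base64_str(original, format="JPEG")
decoded = base64_image_to_pil(encoded)
print(decoded.size, decoded.mode)  # (64, 64) RGB
```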
backend/controlnet.py ADDED
@@ -0,0 +1,90 @@
1
+ import logging
2
+ from PIL import Image
3
+ from diffusers import ControlNetModel
4
+ from backend.models.lcmdiffusion_setting import (
5
+ DiffusionTask,
6
+ ControlNetSetting,
7
+ )
8
+
9
+
10
+ # Prepares ControlNet adapters for use with FastSD CPU
11
+ #
12
+ # This function loads the ControlNet adapters defined by the
13
+ # _lcm_diffusion_setting.controlnet_ object and returns a dictionary
14
+ # with the pipeline arguments required to use the loaded adapters
15
+ def load_controlnet_adapters(lcm_diffusion_setting) -> dict:
16
+ controlnet_args = {}
17
+ if (
18
+ lcm_diffusion_setting.controlnet is None
19
+ or not lcm_diffusion_setting.controlnet.enabled
20
+ ):
21
+ return controlnet_args
22
+
23
+ logging.info("Loading ControlNet adapter")
24
+ controlnet_adapter = ControlNetModel.from_single_file(
25
+ lcm_diffusion_setting.controlnet.adapter_path,
26
+ # local_files_only=True,
27
+ use_safetensors=True,
28
+ )
29
+ controlnet_args["controlnet"] = controlnet_adapter
30
+ return controlnet_args
31
+
32
+
33
+ # Updates the ControlNet pipeline arguments to use for image generation
34
+ #
35
+ # This function uses the contents of the _lcm_diffusion_setting.controlnet_
36
+ # object to generate a dictionary with the corresponding pipeline arguments
37
+ # to be used for image generation; in particular, it sets the ControlNet control
38
+ # image and conditioning scale
39
+ def update_controlnet_arguments(lcm_diffusion_setting) -> dict:
40
+ controlnet_args = {}
41
+ if (
42
+ lcm_diffusion_setting.controlnet is None
43
+ or not lcm_diffusion_setting.controlnet.enabled
44
+ ):
45
+ return controlnet_args
46
+
47
+ controlnet_args["controlnet_conditioning_scale"] = (
48
+ lcm_diffusion_setting.controlnet.conditioning_scale
49
+ )
50
+ if lcm_diffusion_setting.diffusion_task == DiffusionTask.text_to_image.value:
51
+ controlnet_args["image"] = lcm_diffusion_setting.controlnet._control_image
52
+ elif lcm_diffusion_setting.diffusion_task == DiffusionTask.image_to_image.value:
53
+ controlnet_args["control_image"] = (
54
+ lcm_diffusion_setting.controlnet._control_image
55
+ )
56
+ return controlnet_args
57
+
58
+
59
+ # Helper function to adjust ControlNet settings from a dictionary
60
+ def controlnet_settings_from_dict(
61
+ lcm_diffusion_setting,
62
+ dictionary,
63
+ ) -> None:
64
+ if lcm_diffusion_setting is None or dictionary is None:
65
+ logging.error("Invalid arguments!")
66
+ return
67
+ if (
68
+ "controlnet" not in dictionary
69
+ or dictionary["controlnet"] is None
70
+ or len(dictionary["controlnet"]) == 0
71
+ ):
72
+ logging.warning("ControlNet settings not found, ControlNet will be disabled")
73
+ lcm_diffusion_setting.controlnet = None
74
+ return
75
+
76
+ controlnet = ControlNetSetting()
77
+ controlnet.enabled = dictionary["controlnet"][0]["enabled"]
78
+ controlnet.conditioning_scale = dictionary["controlnet"][0]["conditioning_scale"]
79
+ controlnet.adapter_path = dictionary["controlnet"][0]["adapter_path"]
80
+ controlnet._control_image = None
81
+ image_path = dictionary["controlnet"][0]["control_image"]
82
+ if controlnet.enabled:
83
+ try:
84
+ controlnet._control_image = Image.open(image_path)
85
+ except (AttributeError, FileNotFoundError) as err:
86
+ print(err)
87
+ if controlnet._control_image is None:
88
+ logging.error("Wrong ControlNet control image! Disabling ControlNet")
89
+ controlnet.enabled = False
90
+ lcm_diffusion_setting.controlnet = controlnet
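controlnet_settings_from_dict expects the same structure that app.py reads from the --custom_settings JSON file. A hedged sketch of that structure is shown below; the adapter and control image paths are hypothetical placeholders, and the example assumes LCMDiffusionSetting can be constructed with its defaults.

```python
# Illustrative sketch of the dictionary layout consumed by
# controlnet_settings_from_dict. The adapter and image paths are hypothetical.
from backend.controlnet import controlnet_settings_from_dict
from backend.models.lcmdiffusion_setting import LCMDiffusionSetting

custom_settings = {
    "controlnet": [
        {
            "enabled": True,
            "conditioning_scale": 0.5,
            "adapter_path": "controlnet_models/control_canny.safetensors",
            "control_image": "canny_control.png",
        }
    ]
}

lcm_diffusion_setting = LCMDiffusionSetting()
controlnet_settings_from_dict(lcm_diffusion_setting, custom_settings)
```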
backend/device.py ADDED
@@ -0,0 +1,23 @@
1
+ import platform
2
+ from constants import DEVICE
3
+ import torch
4
+ import openvino as ov
5
+
6
+ core = ov.Core()
7
+
8
+
9
+ def is_openvino_device() -> bool:
10
+ if DEVICE.lower() == "cpu" or DEVICE.lower()[0] == "g" or DEVICE.lower()[0] == "n":
11
+ return True
12
+ else:
13
+ return False
14
+
15
+
16
+ def get_device_name() -> str:
17
+ if DEVICE == "cuda" or DEVICE == "mps":
18
+ default_gpu_index = torch.cuda.current_device()
19
+ return torch.cuda.get_device_name(default_gpu_index)
20
+ elif platform.system().lower() == "darwin":
21
+ return platform.processor()
22
+ elif is_openvino_device():
23
+ return core.get_property(DEVICE.upper(), "FULL_DEVICE_NAME")
backend/gguf/gguf_diffusion.py ADDED
@@ -0,0 +1,319 @@
1
+ """
2
+ Wrapper class to call the stablediffusion.cpp shared library for GGUF support
3
+ """
4
+
5
+ import ctypes
6
+ import platform
7
+ from ctypes import (
8
+ POINTER,
9
+ c_bool,
10
+ c_char_p,
11
+ c_float,
12
+ c_int,
13
+ c_int64,
14
+ c_void_p,
15
+ )
16
+ from dataclasses import dataclass
17
+ from os import path
18
+ from typing import List, Any
19
+
20
+ import numpy as np
21
+ from PIL import Image
22
+
23
+ from backend.gguf.sdcpp_types import (
24
+ RngType,
25
+ SampleMethod,
26
+ Schedule,
27
+ SDCPPLogLevel,
28
+ SDImage,
29
+ SdType,
30
+ )
31
+
32
+
33
+ @dataclass
34
+ class ModelConfig:
35
+ model_path: str = ""
36
+ clip_l_path: str = ""
37
+ t5xxl_path: str = ""
38
+ diffusion_model_path: str = ""
39
+ vae_path: str = ""
40
+ taesd_path: str = ""
41
+ control_net_path: str = ""
42
+ lora_model_dir: str = ""
43
+ embed_dir: str = ""
44
+ stacked_id_embed_dir: str = ""
45
+ vae_decode_only: bool = True
46
+ vae_tiling: bool = False
47
+ free_params_immediately: bool = False
48
+ n_threads: int = 4
49
+ wtype: SdType = SdType.SD_TYPE_Q4_0
50
+ rng_type: RngType = RngType.CUDA_RNG
51
+ schedule: Schedule = Schedule.DEFAULT
52
+ keep_clip_on_cpu: bool = False
53
+ keep_control_net_cpu: bool = False
54
+ keep_vae_on_cpu: bool = False
55
+
56
+
57
+ @dataclass
58
+ class Txt2ImgConfig:
59
+ prompt: str = "a man wearing sun glasses, highly detailed"
60
+ negative_prompt: str = ""
61
+ clip_skip: int = -1
62
+ cfg_scale: float = 2.0
63
+ guidance: float = 3.5
64
+ width: int = 512
65
+ height: int = 512
66
+ sample_method: SampleMethod = SampleMethod.EULER_A
67
+ sample_steps: int = 1
68
+ seed: int = -1
69
+ batch_count: int = 2
70
+ control_cond: Image = None
71
+ control_strength: float = 0.90
72
+ style_strength: float = 0.5
73
+ normalize_input: bool = False
74
+ input_id_images_path: bytes = b""
75
+
76
+
77
+ class GGUFDiffusion:
78
+ """GGUF Diffusion
79
+ To support GGUF diffusion model based on stablediffusion.cpp
80
+ https://github.com/ggerganov/ggml/blob/master/docs/gguf.md
81
+ Implmented based on stablediffusion.h
82
+ """
83
+
84
+ def __init__(
85
+ self,
86
+ libpath: str,
87
+ config: ModelConfig,
88
+ logging_enabled: bool = False,
89
+ ):
90
+ sdcpp_shared_lib_path = self._get_sdcpp_shared_lib_path(libpath)
91
+ try:
92
+ self.libsdcpp = ctypes.CDLL(sdcpp_shared_lib_path)
93
+ except OSError as e:
94
+ print(f"Failed to load library {sdcpp_shared_lib_path}")
95
+ raise ValueError(f"Error: {e}")
96
+
97
+ if not config.clip_l_path or not path.exists(config.clip_l_path):
98
+ raise ValueError(
99
+ "CLIP model file not found,please check readme.md for GGUF model usage"
100
+ )
101
+
102
+ if not config.t5xxl_path or not path.exists(config.t5xxl_path):
103
+ raise ValueError(
104
+ "T5XXL model file not found,please check readme.md for GGUF model usage"
105
+ )
106
+
107
+ if not config.diffusion_model_path or not path.exists(
108
+ config.diffusion_model_path
109
+ ):
110
+ raise ValueError(
111
+ "Diffusion model file not found,please check readme.md for GGUF model usage"
112
+ )
113
+
114
+ if not config.vae_path or not path.exists(config.vae_path):
115
+ raise ValueError(
116
+ "VAE model file not found,please check readme.md for GGUF model usage"
117
+ )
118
+
119
+ self.model_config = config
120
+
121
+ self.libsdcpp.new_sd_ctx.argtypes = [
122
+ c_char_p, # const char* model_path
123
+ c_char_p, # const char* clip_l_path
124
+ c_char_p, # const char* t5xxl_path
125
+ c_char_p, # const char* diffusion_model_path
126
+ c_char_p, # const char* vae_path
127
+ c_char_p, # const char* taesd_path
128
+ c_char_p, # const char* control_net_path_c_str
129
+ c_char_p, # const char* lora_model_dir
130
+ c_char_p, # const char* embed_dir_c_str
131
+ c_char_p, # const char* stacked_id_embed_dir_c_str
132
+ c_bool, # bool vae_decode_only
133
+ c_bool, # bool vae_tiling
134
+ c_bool, # bool free_params_immediately
135
+ c_int, # int n_threads
136
+ SdType, # enum sd_type_t wtype
137
+ RngType, # enum rng_type_t rng_type
138
+ Schedule, # enum schedule_t s
139
+ c_bool, # bool keep_clip_on_cpu
140
+ c_bool, # bool keep_control_net_cpu
141
+ c_bool, # bool keep_vae_on_cpu
142
+ ]
143
+
144
+ self.libsdcpp.new_sd_ctx.restype = POINTER(c_void_p)
145
+
146
+ self.sd_ctx = self.libsdcpp.new_sd_ctx(
147
+ self._str_to_bytes(self.model_config.model_path),
148
+ self._str_to_bytes(self.model_config.clip_l_path),
149
+ self._str_to_bytes(self.model_config.t5xxl_path),
150
+ self._str_to_bytes(self.model_config.diffusion_model_path),
151
+ self._str_to_bytes(self.model_config.vae_path),
152
+ self._str_to_bytes(self.model_config.taesd_path),
153
+ self._str_to_bytes(self.model_config.control_net_path),
154
+ self._str_to_bytes(self.model_config.lora_model_dir),
155
+ self._str_to_bytes(self.model_config.embed_dir),
156
+ self._str_to_bytes(self.model_config.stacked_id_embed_dir),
157
+ self.model_config.vae_decode_only,
158
+ self.model_config.vae_tiling,
159
+ self.model_config.free_params_immediately,
160
+ self.model_config.n_threads,
161
+ self.model_config.wtype,
162
+ self.model_config.rng_type,
163
+ self.model_config.schedule,
164
+ self.model_config.keep_clip_on_cpu,
165
+ self.model_config.keep_control_net_cpu,
166
+ self.model_config.keep_vae_on_cpu,
167
+ )
168
+
169
+ if logging_enabled:
170
+ self._set_logcallback()
171
+
172
+ def _set_logcallback(self):
173
+ print("Setting logging callback")
174
+ # Define function callback
175
+ SdLogCallbackType = ctypes.CFUNCTYPE(
176
+ None,
177
+ SDCPPLogLevel,
178
+ ctypes.c_char_p,
179
+ ctypes.c_void_p,
180
+ )
181
+
182
+ self.libsdcpp.sd_set_log_callback.argtypes = [
183
+ SdLogCallbackType,
184
+ ctypes.c_void_p,
185
+ ]
186
+ self.libsdcpp.sd_set_log_callback.restype = None
187
+ # Convert the Python callback to a C func pointer
188
+ self.c_log_callback = SdLogCallbackType(
189
+ self.log_callback
190
+ ) # prevent GC,keep callback as member variable
191
+ self.libsdcpp.sd_set_log_callback(self.c_log_callback, None)
192
+
193
+ def _get_sdcpp_shared_lib_path(
194
+ self,
195
+ root_path: str,
196
+ ) -> str:
197
+ system_name = platform.system()
198
+ print(f"GGUF Diffusion on {system_name}")
199
+ lib_name = "stable-diffusion.dll"
200
+ sdcpp_lib_path = ""
201
+
202
+ if system_name == "Windows":
203
+ sdcpp_lib_path = path.join(root_path, lib_name)
204
+ elif system_name == "Linux":
205
+ lib_name = "libstable-diffusion.so"
206
+ sdcpp_lib_path = path.join(root_path, lib_name)
207
+ elif system_name == "Darwin":
208
+ lib_name = "libstable-diffusion.dylib"
209
+ sdcpp_lib_path = path.join(root_path, lib_name)
210
+ else:
211
+ print("Unknown platform.")
212
+
213
+ return sdcpp_lib_path
214
+
215
+ @staticmethod
216
+ def log_callback(
217
+ level,
218
+ text,
219
+ data,
220
+ ):
221
+ print(f"{text.decode('utf-8')}", end="")
222
+
223
+ def _str_to_bytes(self, in_str: str, encoding: str = "utf-8") -> bytes:
224
+ if in_str:
225
+ return in_str.encode(encoding)
226
+ else:
227
+ return b""
228
+
229
+ def generate_text2mg(self, txt2img_cfg: Txt2ImgConfig) -> List[Any]:
230
+ self.libsdcpp.txt2img.restype = POINTER(SDImage)
231
+ self.libsdcpp.txt2img.argtypes = [
232
+ c_void_p, # sd_ctx_t* sd_ctx (pointer to context object)
233
+ c_char_p, # const char* prompt
234
+ c_char_p, # const char* negative_prompt
235
+ c_int, # int clip_skip
236
+ c_float, # float cfg_scale
237
+ c_float, # float guidance
238
+ c_int, # int width
239
+ c_int, # int height
240
+ SampleMethod, # enum sample_method_t sample_method
241
+ c_int, # int sample_steps
242
+ c_int64, # int64_t seed
243
+ c_int, # int batch_count
244
+ POINTER(SDImage), # const sd_image_t* control_cond (pointer to SDImage)
245
+ c_float, # float control_strength
246
+ c_float, # float style_strength
247
+ c_bool, # bool normalize_input
248
+ c_char_p, # const char* input_id_images_path
249
+ ]
250
+
251
+ image_buffer = self.libsdcpp.txt2img(
252
+ self.sd_ctx,
253
+ self._str_to_bytes(txt2img_cfg.prompt),
254
+ self._str_to_bytes(txt2img_cfg.negative_prompt),
255
+ txt2img_cfg.clip_skip,
256
+ txt2img_cfg.cfg_scale,
257
+ txt2img_cfg.guidance,
258
+ txt2img_cfg.width,
259
+ txt2img_cfg.height,
260
+ txt2img_cfg.sample_method,
261
+ txt2img_cfg.sample_steps,
262
+ txt2img_cfg.seed,
263
+ txt2img_cfg.batch_count,
264
+ txt2img_cfg.control_cond,
265
+ txt2img_cfg.control_strength,
266
+ txt2img_cfg.style_strength,
267
+ txt2img_cfg.normalize_input,
268
+ txt2img_cfg.input_id_images_path,
269
+ )
270
+
271
+ images = self._get_sd_images_from_buffer(
272
+ image_buffer,
273
+ txt2img_cfg.batch_count,
274
+ )
275
+
276
+ return images
277
+
278
+ def _get_sd_images_from_buffer(
279
+ self,
280
+ image_buffer: Any,
281
+ batch_count: int,
282
+ ) -> List[Any]:
283
+ images = []
284
+ if image_buffer:
285
+ for i in range(batch_count):
286
+ image = image_buffer[i]
287
+ print(
288
+ f"Generated image: {image.width}x{image.height} with {image.channel} channels"
289
+ )
290
+
291
+ width = image.width
292
+ height = image.height
293
+ channels = image.channel
294
+ pixel_data = np.ctypeslib.as_array(
295
+ image.data, shape=(height, width, channels)
296
+ )
297
+
298
+ if channels == 1:
299
+ pil_image = Image.fromarray(pixel_data.squeeze(), mode="L")
300
+ elif channels == 3:
301
+ pil_image = Image.fromarray(pixel_data, mode="RGB")
302
+ elif channels == 4:
303
+ pil_image = Image.fromarray(pixel_data, mode="RGBA")
304
+ else:
305
+ raise ValueError(f"Unsupported number of channels: {channels}")
306
+
307
+ images.append(pil_image)
308
+ return images
309
+
310
+ def terminate(self):
311
+ if self.libsdcpp:
312
+ if self.sd_ctx:
313
+ self.libsdcpp.free_sd_ctx.argtypes = [c_void_p]
314
+ self.libsdcpp.free_sd_ctx.restype = None
315
+ self.libsdcpp.free_sd_ctx(self.sd_ctx)
316
+ del self.sd_ctx
317
+ self.sd_ctx = None
318
+ del self.libsdcpp
319
+ self.libsdcpp = None
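A hedged usage sketch of the wrapper above. All model file paths are hypothetical placeholders (see the project readme for actual GGUF model setup), and the libpath argument must point at the directory containing the stablediffusion.cpp shared library.

```python
# Illustrative sketch of driving GGUFDiffusion directly. Every path below is a
# hypothetical placeholder.
from backend.gguf.gguf_diffusion import GGUFDiffusion, ModelConfig, Txt2ImgConfig

config = ModelConfig(
    diffusion_model_path="models/gguf/diffusion/model-q4_0.gguf",
    clip_l_path="models/gguf/clip/clip_l.safetensors",
    t5xxl_path="models/gguf/t5xxl/t5xxl-q4_0.gguf",
    vae_path="models/gguf/vae/ae.safetensors",
    n_threads=4,
)

diffusion = GGUFDiffusion("lib", config, logging_enabled=True)
images = diffusion.generate_text2mg(
    Txt2ImgConfig(prompt="a robot reading a book", width=512, height=512, batch_count=1)
)
images[0].save("gguf_output.png")
diffusion.terminate()
```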
backend/gguf/sdcpp_types.py ADDED
@@ -0,0 +1,104 @@
1
+ """
2
+ Ctypes for stablediffusion.cpp shared library
3
+ This is as per the stablediffusion.h file
4
+ """
5
+
6
+ from enum import IntEnum
7
+ from ctypes import (
8
+ c_int,
9
+ c_uint32,
10
+ c_uint8,
11
+ POINTER,
12
+ Structure,
13
+ )
14
+
15
+
16
+ class CtypesEnum(IntEnum):
17
+ """A ctypes-compatible IntEnum superclass."""
18
+
19
+ @classmethod
20
+ def from_param(cls, obj):
21
+ return int(obj)
22
+
23
+
24
+ class RngType(CtypesEnum):
25
+ STD_DEFAULT_RNG = 0
26
+ CUDA_RNG = 1
27
+
28
+
29
+ class SampleMethod(CtypesEnum):
30
+ EULER_A = 0
31
+ EULER = 1
32
+ HEUN = 2
33
+ DPM2 = 3
34
+ DPMPP2S_A = 4
35
+ DPMPP2M = 5
36
+ DPMPP2Mv2 = 6
37
+ IPNDM = 7
38
+ IPNDM_V = 7
39
+ LCM = 8
40
+ N_SAMPLE_METHODS = 9
41
+
42
+
43
+ class Schedule(CtypesEnum):
44
+ DEFAULT = 0
45
+ DISCRETE = 1
46
+ KARRAS = 2
47
+ EXPONENTIAL = 3
48
+ AYS = 4
49
+ GITS = 5
50
+ N_SCHEDULES = 5
51
+
52
+
53
+ class SdType(CtypesEnum):
54
+ SD_TYPE_F32 = 0
55
+ SD_TYPE_F16 = 1
56
+ SD_TYPE_Q4_0 = 2
57
+ SD_TYPE_Q4_1 = 3
58
+ # SD_TYPE_Q4_2 = 4, support has been removed
59
+ # SD_TYPE_Q4_3 = 5, support has been removed
60
+ SD_TYPE_Q5_0 = 6
61
+ SD_TYPE_Q5_1 = 7
62
+ SD_TYPE_Q8_0 = 8
63
+ SD_TYPE_Q8_1 = 9
64
+ SD_TYPE_Q2_K = 10
65
+ SD_TYPE_Q3_K = 11
66
+ SD_TYPE_Q4_K = 12
67
+ SD_TYPE_Q5_K = 13
68
+ SD_TYPE_Q6_K = 14
69
+ SD_TYPE_Q8_K = 15
70
+ SD_TYPE_IQ2_XXS = 16
71
+ SD_TYPE_IQ2_XS = 17
72
+ SD_TYPE_IQ3_XXS = 18
73
+ SD_TYPE_IQ1_S = 19
74
+ SD_TYPE_IQ4_NL = 20
75
+ SD_TYPE_IQ3_S = 21
76
+ SD_TYPE_IQ2_S = 22
77
+ SD_TYPE_IQ4_XS = 23
78
+ SD_TYPE_I8 = 24
79
+ SD_TYPE_I16 = 25
80
+ SD_TYPE_I32 = 26
81
+ SD_TYPE_I64 = 27
82
+ SD_TYPE_F64 = 28
83
+ SD_TYPE_IQ1_M = 29
84
+ SD_TYPE_BF16 = 30
85
+ SD_TYPE_Q4_0_4_4 = 31
86
+ SD_TYPE_Q4_0_4_8 = 32
87
+ SD_TYPE_Q4_0_8_8 = 33
88
+ SD_TYPE_COUNT = 34
89
+
90
+
91
+ class SDImage(Structure):
92
+ _fields_ = [
93
+ ("width", c_uint32),
94
+ ("height", c_uint32),
95
+ ("channel", c_uint32),
96
+ ("data", POINTER(c_uint8)),
97
+ ]
98
+
99
+
100
+ class SDCPPLogLevel(c_int):
101
+ SD_LOG_LEVEL_DEBUG = 0
102
+ SD_LOG_LEVEL_INFO = 1
103
+ SD_LOG_LEVEL_WARNING = 2
104
+ SD_LOG_LEVEL_ERROR = 3
backend/image_saver.py ADDED
@@ -0,0 +1,75 @@
1
+ import json
2
+ from os import path, mkdir
3
+ from typing import Any
4
+ from uuid import uuid4
5
+ from backend.models.lcmdiffusion_setting import LCMDiffusionSetting
6
+ from utils import get_image_file_extension
7
+
8
+
9
+ def get_exclude_keys():
10
+ exclude_keys = {
11
+ "init_image": True,
12
+ "generated_images": True,
13
+ "lora": {
14
+ "models_dir": True,
15
+ "path": True,
16
+ },
17
+ "dirs": True,
18
+ "controlnet": {
19
+ "adapter_path": True,
20
+ },
21
+ }
22
+ return exclude_keys
23
+
24
+
25
+ class ImageSaver:
26
+ @staticmethod
27
+ def save_images(
28
+ output_path: str,
29
+ images: Any,
30
+ folder_name: str = "",
31
+ format: str = "PNG",
32
+ jpeg_quality: int = 90,
33
+ lcm_diffusion_setting: LCMDiffusionSetting = None,
34
+ ) -> list[str]:
35
+ gen_id = uuid4()
36
+ image_ids = []
37
+
38
+ if images:
39
+ image_seeds = []
40
+
41
+ for index, image in enumerate(images):
42
+
43
+ image_seed = image.info.get('image_seed')
44
+ if image_seed is not None:
45
+ image_seeds.append(image_seed)
46
+
47
+ if not path.exists(output_path):
48
+ mkdir(output_path)
49
+
50
+ if folder_name:
51
+ out_path = path.join(
52
+ output_path,
53
+ folder_name,
54
+ )
55
+ else:
56
+ out_path = output_path
57
+
58
+ if not path.exists(out_path):
59
+ mkdir(out_path)
60
+ image_extension = get_image_file_extension(format)
61
+ image_file_name = f"{gen_id}-{index+1}{image_extension}"
62
+ image_ids.append(image_file_name)
63
+ image.save(path.join(out_path, image_file_name), quality = jpeg_quality)
64
+ if lcm_diffusion_setting:
65
+ data = lcm_diffusion_setting.model_dump(exclude=get_exclude_keys())
66
+ if image_seeds:
67
+ data['image_seeds'] = image_seeds
68
+ with open(path.join(out_path, f"{gen_id}.json"), "w") as json_file:
69
+ json.dump(
70
+ data,
71
+ json_file,
72
+ indent=4,
73
+ )
74
+ return image_ids
75
+
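A hypothetical usage sketch of `ImageSaver.save_images`, assuming the repository modules are importable; the output directory, folder name, and prompt are placeholders. Alongside the image files it writes a `<uuid>.json` sidecar with the serialized settings:

```python
from PIL import Image

from backend.image_saver import ImageSaver
from backend.models.lcmdiffusion_setting import LCMDiffusionSetting

settings = LCMDiffusionSetting(prompt="a cup of coffee")   # placeholder settings
frames = [Image.new("RGB", (512, 512)) for _ in range(2)]  # stand-ins for generated images
saved_files = ImageSaver.save_images(
    "results",              # output_path, created if it does not exist
    images=frames,
    folder_name="demo",
    format="PNG",
    lcm_diffusion_setting=settings,  # also triggers the JSON sidecar
)
print(saved_files)  # e.g. ["<uuid>-1.png", "<uuid>-2.png"]
```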
backend/lcm_text_to_image.py ADDED
@@ -0,0 +1,577 @@
1
+ import gc
2
+ from math import ceil
3
+ from typing import Any, List
4
+ import random
5
+
6
+ import numpy as np
7
+ import torch
8
+ from backend.device import is_openvino_device
9
+ from backend.controlnet import (
10
+ load_controlnet_adapters,
11
+ update_controlnet_arguments,
12
+ )
13
+ from backend.models.lcmdiffusion_setting import (
14
+ DiffusionTask,
15
+ LCMDiffusionSetting,
16
+ LCMLora,
17
+ )
18
+ from backend.openvino.pipelines import (
19
+ get_ov_image_to_image_pipeline,
20
+ get_ov_text_to_image_pipeline,
21
+ ov_load_taesd,
22
+ )
23
+ from backend.pipelines.lcm import (
24
+ get_image_to_image_pipeline,
25
+ get_lcm_model_pipeline,
26
+ load_taesd,
27
+ )
28
+ from backend.pipelines.lcm_lora import get_lcm_lora_pipeline
29
+ from constants import DEVICE, GGUF_THREADS
30
+ from diffusers import LCMScheduler
31
+ from image_ops import resize_pil_image
32
+ from backend.openvino.flux_pipeline import get_flux_pipeline
33
+ from backend.openvino.ov_hc_stablediffusion_pipeline import OvHcLatentConsistency
34
+ from backend.gguf.gguf_diffusion import (
35
+ GGUFDiffusion,
36
+ ModelConfig,
37
+ Txt2ImgConfig,
38
+ SampleMethod,
39
+ )
40
+ from paths import get_app_path
41
+ from pprint import pprint
42
+
43
+ try:
44
+ # support for token merging; keeping it optional for now
45
+ import tomesd
46
+ except ImportError:
47
+ print("tomesd library unavailable; disabling token merging support")
48
+ tomesd = None
49
+
50
+
51
+ class LCMTextToImage:
52
+ def __init__(
53
+ self,
54
+ device: str = "cpu",
55
+ ) -> None:
56
+ self.pipeline = None
57
+ self.use_openvino = False
58
+ self.device = ""
59
+ self.previous_model_id = None
60
+ self.previous_use_tae_sd = False
61
+ self.previous_use_lcm_lora = False
62
+ self.previous_ov_model_id = ""
63
+ self.previous_token_merging = 0.0
64
+ self.previous_safety_checker = False
65
+ self.previous_use_openvino = False
66
+ self.img_to_img_pipeline = None
67
+ self.is_openvino_init = False
68
+ self.previous_lora = None
69
+ self.task_type = DiffusionTask.text_to_image
70
+ self.previous_use_gguf_model = False
71
+ self.previous_gguf_model = None
72
+ self.torch_data_type = (
73
+ torch.float32 if is_openvino_device() or DEVICE == "mps" else torch.float16
74
+ )
75
+ self.ov_model_id = None
76
+ print(f"Torch datatype : {self.torch_data_type}")
77
+
78
+ def _pipeline_to_device(self):
79
+ print(f"Pipeline device : {DEVICE}")
80
+ print(f"Pipeline dtype : {self.torch_data_type}")
81
+ self.pipeline.to(
82
+ torch_device=DEVICE,
83
+ torch_dtype=self.torch_data_type,
84
+ )
85
+
86
+ def _add_freeu(self):
87
+ pipeline_class = self.pipeline.__class__.__name__
88
+ if isinstance(self.pipeline.scheduler, LCMScheduler):
89
+ if pipeline_class == "StableDiffusionPipeline":
90
+ print("Add FreeU - SD")
91
+ self.pipeline.enable_freeu(
92
+ s1=0.9,
93
+ s2=0.2,
94
+ b1=1.2,
95
+ b2=1.4,
96
+ )
97
+ elif pipeline_class == "StableDiffusionXLPipeline":
98
+ print("Add FreeU - SDXL")
99
+ self.pipeline.enable_freeu(
100
+ s1=0.6,
101
+ s2=0.4,
102
+ b1=1.1,
103
+ b2=1.2,
104
+ )
105
+
106
+ def _enable_vae_tiling(self):
107
+ self.pipeline.vae.enable_tiling()
108
+
109
+ def _update_lcm_scheduler_params(self):
110
+ if isinstance(self.pipeline.scheduler, LCMScheduler):
111
+ self.pipeline.scheduler = LCMScheduler.from_config(
112
+ self.pipeline.scheduler.config,
113
+ beta_start=0.001,
114
+ beta_end=0.01,
115
+ )
116
+
117
+ def _is_hetero_pipeline(self) -> bool:
118
+ return "square" in self.ov_model_id.lower()
119
+
120
+ def _load_ov_hetero_pipeline(self):
121
+ print("Loading Heterogeneous Compute pipeline")
122
+ if DEVICE.upper() == "NPU":
123
+ device = ["NPU", "NPU", "NPU"]
124
+ self.pipeline = OvHcLatentConsistency(self.ov_model_id, device)
125
+ else:
126
+ self.pipeline = OvHcLatentConsistency(self.ov_model_id)
127
+
128
+ def _generate_images_hetero_compute(
129
+ self,
130
+ lcm_diffusion_setting: LCMDiffusionSetting,
131
+ ):
132
+ print("Using OpenVINO (heterogeneous compute)")
133
+ if lcm_diffusion_setting.diffusion_task == DiffusionTask.text_to_image.value:
134
+ return [
135
+ self.pipeline.generate(
136
+ prompt=lcm_diffusion_setting.prompt,
137
+ neg_prompt=lcm_diffusion_setting.negative_prompt,
138
+ init_image=None,
139
+ strength=1.0,
140
+ num_inference_steps=lcm_diffusion_setting.inference_steps,
141
+ )
142
+ ]
143
+ else:
144
+ return [
145
+ self.pipeline.generate(
146
+ prompt=lcm_diffusion_setting.prompt,
147
+ neg_prompt=lcm_diffusion_setting.negative_prompt,
148
+ init_image=lcm_diffusion_setting.init_image,
149
+ strength=lcm_diffusion_setting.strength,
150
+ num_inference_steps=lcm_diffusion_setting.inference_steps,
151
+ )
152
+ ]
153
+
154
+ def _is_valid_mode(
155
+ self,
156
+ modes: List,
157
+ ) -> bool:
158
+ return modes.count(True) == 1 or modes.count(False) == 3
159
+
160
+ def _validate_mode(
161
+ self,
162
+ modes: List,
163
+ ) -> None:
164
+ if not self._is_valid_mode(modes):
165
+ raise ValueError("Invalid mode, delete configs/settings.yaml and retry!")
166
+
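`_is_valid_mode` encodes the rule that the GGUF, OpenVINO, and LCM-LoRA flags are mutually exclusive: either exactly one of them is enabled, or none are (the plain PyTorch LCM model path). A standalone restatement of the predicate, for illustration only:

```python
def is_valid_mode(modes: list) -> bool:
    # modes = [use_gguf_model, use_openvino, use_lcm_lora]
    return modes.count(True) == 1 or modes.count(False) == 3


assert is_valid_mode([True, False, False])       # one backend selected
assert is_valid_mode([False, False, False])      # default PyTorch LCM pipeline
assert not is_valid_mode([True, True, False])    # conflicting settings are rejected
```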
167
+ def init(
168
+ self,
169
+ device: str = "cpu",
170
+ lcm_diffusion_setting: LCMDiffusionSetting = LCMDiffusionSetting(),
171
+ ) -> None:
172
+ # Mode validation: use at most one of GGUF, OpenVINO, or LCM-LoRA
173
+
174
+ modes = [
175
+ lcm_diffusion_setting.use_gguf_model,
176
+ lcm_diffusion_setting.use_openvino,
177
+ lcm_diffusion_setting.use_lcm_lora,
178
+ ]
179
+ self._validate_mode(modes)
180
+ self.device = device
181
+ self.use_openvino = lcm_diffusion_setting.use_openvino
182
+ model_id = lcm_diffusion_setting.lcm_model_id
183
+ use_local_model = lcm_diffusion_setting.use_offline_model
184
+ use_tiny_auto_encoder = lcm_diffusion_setting.use_tiny_auto_encoder
185
+ use_lora = lcm_diffusion_setting.use_lcm_lora
186
+ lcm_lora: LCMLora = lcm_diffusion_setting.lcm_lora
187
+ token_merging = lcm_diffusion_setting.token_merging
188
+ self.ov_model_id = lcm_diffusion_setting.openvino_lcm_model_id
189
+
190
+ if lcm_diffusion_setting.diffusion_task == DiffusionTask.image_to_image.value:
191
+ lcm_diffusion_setting.init_image = resize_pil_image(
192
+ lcm_diffusion_setting.init_image,
193
+ lcm_diffusion_setting.image_width,
194
+ lcm_diffusion_setting.image_height,
195
+ )
196
+
197
+ if (
198
+ self.pipeline is None
199
+ or self.previous_model_id != model_id
200
+ or self.previous_use_tae_sd != use_tiny_auto_encoder
201
+ or self.previous_lcm_lora_base_id != lcm_lora.base_model_id
202
+ or self.previous_lcm_lora_id != lcm_lora.lcm_lora_id
203
+ or self.previous_use_lcm_lora != use_lora
204
+ or self.previous_ov_model_id != self.ov_model_id
205
+ or self.previous_token_merging != token_merging
206
+ or self.previous_safety_checker != lcm_diffusion_setting.use_safety_checker
207
+ or self.previous_use_openvino != lcm_diffusion_setting.use_openvino
208
+ or self.previous_use_gguf_model != lcm_diffusion_setting.use_gguf_model
209
+ or self.previous_gguf_model != lcm_diffusion_setting.gguf_model
210
+ or (
211
+ self.use_openvino
212
+ and (
213
+ self.previous_task_type != lcm_diffusion_setting.diffusion_task
214
+ or self.previous_lora != lcm_diffusion_setting.lora
215
+ )
216
+ )
217
+ or lcm_diffusion_setting.rebuild_pipeline
218
+ ):
219
+ if self.use_openvino and is_openvino_device():
220
+ if self.pipeline:
221
+ del self.pipeline
222
+ self.pipeline = None
223
+ gc.collect()
224
+ self.is_openvino_init = True
225
+ if (
226
+ lcm_diffusion_setting.diffusion_task
227
+ == DiffusionTask.text_to_image.value
228
+ ):
229
+ print(
230
+ f"***** Init Text to image (OpenVINO) - {self.ov_model_id} *****"
231
+ )
232
+ if "flux" in self.ov_model_id.lower():
233
+ print("Loading OpenVINO Flux pipeline")
234
+ self.pipeline = get_flux_pipeline(
235
+ self.ov_model_id,
236
+ lcm_diffusion_setting.use_tiny_auto_encoder,
237
+ )
238
+ elif self._is_hetero_pipeline():
239
+ self._load_ov_hetero_pipeline()
240
+ else:
241
+ self.pipeline = get_ov_text_to_image_pipeline(
242
+ self.ov_model_id,
243
+ use_local_model,
244
+ )
245
+ elif (
246
+ lcm_diffusion_setting.diffusion_task
247
+ == DiffusionTask.image_to_image.value
248
+ ):
249
+ if not self.pipeline and self._is_hetero_pipeline():
250
+ self._load_ov_hetero_pipeline()
251
+ else:
252
+ print(
253
+ f"***** Image to image (OpenVINO) - {self.ov_model_id} *****"
254
+ )
255
+ self.pipeline = get_ov_image_to_image_pipeline(
256
+ self.ov_model_id,
257
+ use_local_model,
258
+ )
259
+ elif lcm_diffusion_setting.use_gguf_model:
260
+ model = lcm_diffusion_setting.gguf_model.diffusion_path
261
+ print(f"***** Init Text to image (GGUF) - {model} *****")
262
+ # if self.pipeline:
263
+ # self.pipeline.terminate()
264
+ # del self.pipeline
265
+ # self.pipeline = None
266
+ self._init_gguf_diffusion(lcm_diffusion_setting)
267
+ else:
268
+ if self.pipeline or self.img_to_img_pipeline:
269
+ self.pipeline = None
270
+ self.img_to_img_pipeline = None
271
+ gc.collect()
272
+
273
+ controlnet_args = load_controlnet_adapters(lcm_diffusion_setting)
274
+ if use_lora:
275
+ print(
276
+ f"***** Init LCM-LoRA pipeline - {lcm_lora.base_model_id} *****"
277
+ )
278
+ self.pipeline = get_lcm_lora_pipeline(
279
+ lcm_lora.base_model_id,
280
+ lcm_lora.lcm_lora_id,
281
+ use_local_model,
282
+ torch_data_type=self.torch_data_type,
283
+ pipeline_args=controlnet_args,
284
+ )
285
+
286
+ else:
287
+ print(f"***** Init LCM Model pipeline - {model_id} *****")
288
+ self.pipeline = get_lcm_model_pipeline(
289
+ model_id,
290
+ use_local_model,
291
+ controlnet_args,
292
+ )
293
+
294
+ self.img_to_img_pipeline = get_image_to_image_pipeline(self.pipeline)
295
+
296
+ if tomesd and token_merging > 0.001:
297
+ print(f"***** Token Merging: {token_merging} *****")
298
+ tomesd.apply_patch(self.pipeline, ratio=token_merging)
299
+ tomesd.apply_patch(self.img_to_img_pipeline, ratio=token_merging)
300
+
301
+ if use_tiny_auto_encoder:
302
+ if self.use_openvino and is_openvino_device():
303
+ if self.pipeline.__class__.__name__ != "OVFluxPipeline":
304
+ print("Using Tiny Auto Encoder (OpenVINO)")
305
+ ov_load_taesd(
306
+ self.pipeline,
307
+ use_local_model,
308
+ )
309
+ else:
310
+ print("Using Tiny Auto Encoder")
311
+ load_taesd(
312
+ self.pipeline,
313
+ use_local_model,
314
+ self.torch_data_type,
315
+ )
316
+ load_taesd(
317
+ self.img_to_img_pipeline,
318
+ use_local_model,
319
+ self.torch_data_type,
320
+ )
321
+
322
+ if not self.use_openvino and not is_openvino_device():
323
+ self._pipeline_to_device()
324
+
325
+ if not self._is_hetero_pipeline():
326
+ if (
327
+ lcm_diffusion_setting.diffusion_task
328
+ == DiffusionTask.image_to_image.value
329
+ and lcm_diffusion_setting.use_openvino
330
+ ):
331
+ self.pipeline.scheduler = LCMScheduler.from_config(
332
+ self.pipeline.scheduler.config,
333
+ )
334
+ else:
335
+ if not lcm_diffusion_setting.use_gguf_model:
336
+ self._update_lcm_scheduler_params()
337
+
338
+ if use_lora:
339
+ self._add_freeu()
340
+
341
+ self.previous_model_id = model_id
342
+ self.previous_ov_model_id = self.ov_model_id
343
+ self.previous_use_tae_sd = use_tiny_auto_encoder
344
+ self.previous_lcm_lora_base_id = lcm_lora.base_model_id
345
+ self.previous_lcm_lora_id = lcm_lora.lcm_lora_id
346
+ self.previous_use_lcm_lora = use_lora
347
+ self.previous_token_merging = lcm_diffusion_setting.token_merging
348
+ self.previous_safety_checker = lcm_diffusion_setting.use_safety_checker
349
+ self.previous_use_openvino = lcm_diffusion_setting.use_openvino
350
+ self.previous_task_type = lcm_diffusion_setting.diffusion_task
351
+ self.previous_lora = lcm_diffusion_setting.lora.model_copy(deep=True)
352
+ self.previous_use_gguf_model = lcm_diffusion_setting.use_gguf_model
353
+ self.previous_gguf_model = lcm_diffusion_setting.gguf_model.model_copy(
354
+ deep=True
355
+ )
356
+ lcm_diffusion_setting.rebuild_pipeline = False
357
+ if (
358
+ lcm_diffusion_setting.diffusion_task
359
+ == DiffusionTask.text_to_image.value
360
+ ):
361
+ print(f"Pipeline : {self.pipeline}")
362
+ elif (
363
+ lcm_diffusion_setting.diffusion_task
364
+ == DiffusionTask.image_to_image.value
365
+ ):
366
+ if self.use_openvino and is_openvino_device():
367
+ print(f"Pipeline : {self.pipeline}")
368
+ else:
369
+ print(f"Pipeline : {self.img_to_img_pipeline}")
370
+ if self.use_openvino:
371
+ if lcm_diffusion_setting.lora.enabled:
372
+ print("Warning: LoRA models are not supported in OpenVINO mode")
373
+ elif not lcm_diffusion_setting.use_gguf_model:
374
+ adapters = self.pipeline.get_active_adapters()
375
+ print(f"Active adapters : {adapters}")
376
+
377
+ def _get_timesteps(self):
378
+ time_steps = self.pipeline.scheduler.config.get("timesteps")
379
+ time_steps_value = [int(time_steps)] if time_steps else None
380
+ return time_steps_value
381
+
382
+ def generate(
383
+ self,
384
+ lcm_diffusion_setting: LCMDiffusionSetting,
385
+ reshape: bool = False,
386
+ ) -> Any:
387
+ guidance_scale = lcm_diffusion_setting.guidance_scale
388
+ img_to_img_inference_steps = lcm_diffusion_setting.inference_steps
389
+ check_step_value = int(
390
+ lcm_diffusion_setting.inference_steps * lcm_diffusion_setting.strength
391
+ )
392
+ if (
393
+ lcm_diffusion_setting.diffusion_task == DiffusionTask.image_to_image.value
394
+ and check_step_value < 1
395
+ ):
396
+ img_to_img_inference_steps = ceil(1 / lcm_diffusion_setting.strength)
397
+ print(
398
+ f"Strength: {lcm_diffusion_setting.strength}, adjusted inference steps: {img_to_img_inference_steps}"
399
+ )
400
+
401
+ pipeline_extra_args = {}
402
+
403
+ if lcm_diffusion_setting.use_seed:
404
+ cur_seed = lcm_diffusion_setting.seed
405
+ # for multiple images with a fixed seed, use sequential seeds
406
+ seeds = [
407
+ (cur_seed + i) for i in range(lcm_diffusion_setting.number_of_images)
408
+ ]
409
+ else:
410
+ seeds = [
411
+ random.randint(0, 999999999)
412
+ for i in range(lcm_diffusion_setting.number_of_images)
413
+ ]
414
+
415
+ if self.use_openvino:
416
+ # no support for generators; try at least to ensure reproducible results for single images
417
+ np.random.seed(seeds[0])
418
+ if self._is_hetero_pipeline():
419
+ torch.manual_seed(seeds[0])
420
+ lcm_diffusion_setting.seed = seeds[0]
421
+ else:
422
+ pipeline_extra_args["generator"] = [
423
+ torch.Generator(device=self.device).manual_seed(s) for s in seeds
424
+ ]
425
+
426
+ is_openvino_pipe = lcm_diffusion_setting.use_openvino and is_openvino_device()
427
+ if is_openvino_pipe and not self._is_hetero_pipeline():
428
+ print("Using OpenVINO")
429
+ if reshape and not self.is_openvino_init:
430
+ print("Reshape and compile")
431
+ self.pipeline.reshape(
432
+ batch_size=-1,
433
+ height=lcm_diffusion_setting.image_height,
434
+ width=lcm_diffusion_setting.image_width,
435
+ num_images_per_prompt=lcm_diffusion_setting.number_of_images,
436
+ )
437
+ self.pipeline.compile()
438
+
439
+ if self.is_openvino_init:
440
+ self.is_openvino_init = False
441
+
442
+ if is_openvino_pipe and self._is_hetero_pipeline():
443
+ return self._generate_images_hetero_compute(lcm_diffusion_setting)
444
+ elif lcm_diffusion_setting.use_gguf_model:
445
+ return self._generate_images_gguf(lcm_diffusion_setting)
446
+
447
+ if lcm_diffusion_setting.clip_skip > 1:
448
+ # We follow the convention that "CLIP Skip == 2" means "skip
449
+ # the last layer", so "CLIP Skip == 1" means "no skipping"
450
+ pipeline_extra_args["clip_skip"] = lcm_diffusion_setting.clip_skip - 1
451
+
452
+ if not lcm_diffusion_setting.use_safety_checker:
453
+ self.pipeline.safety_checker = None
454
+ if (
455
+ lcm_diffusion_setting.diffusion_task
456
+ == DiffusionTask.image_to_image.value
457
+ and not is_openvino_pipe
458
+ ):
459
+ self.img_to_img_pipeline.safety_checker = None
460
+
461
+ if (
462
+ not lcm_diffusion_setting.use_lcm_lora
463
+ and not lcm_diffusion_setting.use_openvino
464
+ and lcm_diffusion_setting.guidance_scale != 1.0
465
+ ):
466
+ print("Not using LCM-LoRA, setting guidance_scale to 1.0")
467
+ guidance_scale = 1.0
468
+
469
+ controlnet_args = update_controlnet_arguments(lcm_diffusion_setting)
470
+ if lcm_diffusion_setting.use_openvino:
471
+ if (
472
+ lcm_diffusion_setting.diffusion_task
473
+ == DiffusionTask.text_to_image.value
474
+ ):
475
+ result_images = self.pipeline(
476
+ prompt=lcm_diffusion_setting.prompt,
477
+ negative_prompt=lcm_diffusion_setting.negative_prompt,
478
+ num_inference_steps=lcm_diffusion_setting.inference_steps,
479
+ guidance_scale=guidance_scale,
480
+ width=lcm_diffusion_setting.image_width,
481
+ height=lcm_diffusion_setting.image_height,
482
+ num_images_per_prompt=lcm_diffusion_setting.number_of_images,
483
+ ).images
484
+ elif (
485
+ lcm_diffusion_setting.diffusion_task
486
+ == DiffusionTask.image_to_image.value
487
+ ):
488
+ result_images = self.pipeline(
489
+ image=lcm_diffusion_setting.init_image,
490
+ strength=lcm_diffusion_setting.strength,
491
+ prompt=lcm_diffusion_setting.prompt,
492
+ negative_prompt=lcm_diffusion_setting.negative_prompt,
493
+ num_inference_steps=img_to_img_inference_steps * 3,
494
+ guidance_scale=guidance_scale,
495
+ num_images_per_prompt=lcm_diffusion_setting.number_of_images,
496
+ ).images
497
+
498
+ else:
499
+ if (
500
+ lcm_diffusion_setting.diffusion_task
501
+ == DiffusionTask.text_to_image.value
502
+ ):
503
+ result_images = self.pipeline(
504
+ prompt=lcm_diffusion_setting.prompt,
505
+ negative_prompt=lcm_diffusion_setting.negative_prompt,
506
+ num_inference_steps=lcm_diffusion_setting.inference_steps,
507
+ guidance_scale=guidance_scale,
508
+ width=lcm_diffusion_setting.image_width,
509
+ height=lcm_diffusion_setting.image_height,
510
+ num_images_per_prompt=lcm_diffusion_setting.number_of_images,
511
+ timesteps=self._get_timesteps(),
512
+ **pipeline_extra_args,
513
+ **controlnet_args,
514
+ ).images
515
+
516
+ elif (
517
+ lcm_diffusion_setting.diffusion_task
518
+ == DiffusionTask.image_to_image.value
519
+ ):
520
+ result_images = self.img_to_img_pipeline(
521
+ image=lcm_diffusion_setting.init_image,
522
+ strength=lcm_diffusion_setting.strength,
523
+ prompt=lcm_diffusion_setting.prompt,
524
+ negative_prompt=lcm_diffusion_setting.negative_prompt,
525
+ num_inference_steps=img_to_img_inference_steps,
526
+ guidance_scale=guidance_scale,
527
+ width=lcm_diffusion_setting.image_width,
528
+ height=lcm_diffusion_setting.image_height,
529
+ num_images_per_prompt=lcm_diffusion_setting.number_of_images,
530
+ **pipeline_extra_args,
531
+ **controlnet_args,
532
+ ).images
533
+
534
+ for i, seed in enumerate(seeds):
535
+ result_images[i].info["image_seed"] = seed
536
+
537
+ return result_images
538
+
539
+ def _init_gguf_diffusion(
540
+ self,
541
+ lcm_diffusion_setting: LCMDiffusionSetting,
542
+ ):
543
+ config = ModelConfig()
544
+ config.model_path = lcm_diffusion_setting.gguf_model.diffusion_path
545
+ config.diffusion_model_path = lcm_diffusion_setting.gguf_model.diffusion_path
546
+ config.clip_l_path = lcm_diffusion_setting.gguf_model.clip_path
547
+ config.t5xxl_path = lcm_diffusion_setting.gguf_model.t5xxl_path
548
+ config.vae_path = lcm_diffusion_setting.gguf_model.vae_path
549
+ config.n_threads = GGUF_THREADS
550
+ print(f"GGUF Threads : {GGUF_THREADS}")
551
+ print("GGUF - Model config")
552
+ pprint(lcm_diffusion_setting.gguf_model.model_dump())
553
+ self.pipeline = GGUFDiffusion(
554
+ get_app_path(), # Place DLL in fastsdcpu folder
555
+ config,
556
+ True,
557
+ )
558
+
559
+ def _generate_images_gguf(
560
+ self,
561
+ lcm_diffusion_setting: LCMDiffusionSetting,
562
+ ):
563
+ if lcm_diffusion_setting.diffusion_task == DiffusionTask.text_to_image.value:
564
+ t2iconfig = Txt2ImgConfig()
565
+ t2iconfig.prompt = lcm_diffusion_setting.prompt
566
+ t2iconfig.batch_count = lcm_diffusion_setting.number_of_images
567
+ t2iconfig.cfg_scale = lcm_diffusion_setting.guidance_scale
568
+ t2iconfig.height = lcm_diffusion_setting.image_height
569
+ t2iconfig.width = lcm_diffusion_setting.image_width
570
+ t2iconfig.sample_steps = lcm_diffusion_setting.inference_steps
571
+ t2iconfig.sample_method = SampleMethod.EULER
572
+ if lcm_diffusion_setting.use_seed:
573
+ t2iconfig.seed = lcm_diffusion_setting.seed
574
+ else:
575
+ t2iconfig.seed = -1
576
+
577
+ return self.pipeline.generate_text2mg(t2iconfig)
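A hypothetical end-to-end text-to-image call using the class above with its default settings (the prompt and output filename are placeholders, and the first run downloads the default LCM model from the Hugging Face Hub):

```python
from backend.lcm_text_to_image import LCMTextToImage
from backend.models.lcmdiffusion_setting import LCMDiffusionSetting
from constants import DEVICE

settings = LCMDiffusionSetting(prompt="a watercolor painting of a lighthouse")
lcm = LCMTextToImage(DEVICE)
lcm.init(DEVICE, settings)        # builds the pipeline for the selected mode
images = lcm.generate(settings)   # list of PIL images
images[0].save("lighthouse.png")
```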
backend/lora.py ADDED
@@ -0,0 +1,136 @@
1
+ import glob
2
+ from os import path
3
+ from paths import get_file_name, FastStableDiffusionPaths
4
+ from pathlib import Path
5
+
6
+
7
+ # A basic class to keep track of the currently loaded LoRAs and
8
+ # their weights; the diffusers function get_active_adapters()
9
+ # returns a list of adapter names but not their weights so we need
10
+ # a way to keep track of the current LoRA weights to set whenever
11
+ # a new LoRA is loaded
12
+ class _lora_info:
13
+ def __init__(
14
+ self,
15
+ path: str,
16
+ weight: float,
17
+ ):
18
+ self.path = path
19
+ self.adapter_name = get_file_name(path)
20
+ self.weight = weight
21
+
22
+ def __del__(self):
23
+ self.path = None
24
+ self.adapter_name = None
25
+
26
+
27
+ _loaded_loras = []
28
+ _current_pipeline = None
29
+
30
+
31
+ # This function loads a LoRA from the LoRA path setting, so it's
32
+ # possible to load multiple LoRAs by calling this function more than
33
+ # once with a different LoRA path setting; note that if you plan to
34
+ # load multiple LoRAs and dynamically change their weights, you
35
+ # might want to set the LoRA fuse option to False
36
+ def load_lora_weight(
37
+ pipeline,
38
+ lcm_diffusion_setting,
39
+ ):
40
+ if not lcm_diffusion_setting.lora.path:
41
+ raise Exception("Empty lora model path")
42
+
43
+ if not path.exists(lcm_diffusion_setting.lora.path):
44
+ raise Exception("Lora model path is invalid")
45
+
46
+ # If the pipeline has been rebuilt since the last call, remove all
47
+ # references to previously loaded LoRAs and store the new pipeline
48
+ global _loaded_loras
49
+ global _current_pipeline
50
+ if pipeline != _current_pipeline:
51
+ for lora in _loaded_loras:
52
+ del lora
53
+ del _loaded_loras
54
+ _loaded_loras = []
55
+ _current_pipeline = pipeline
56
+
57
+ current_lora = _lora_info(
58
+ lcm_diffusion_setting.lora.path,
59
+ lcm_diffusion_setting.lora.weight,
60
+ )
61
+ _loaded_loras.append(current_lora)
62
+
63
+ if lcm_diffusion_setting.lora.enabled:
64
+ print(f"LoRA adapter name : {current_lora.adapter_name}")
65
+ pipeline.load_lora_weights(
66
+ FastStableDiffusionPaths.get_lora_models_path(),
67
+ weight_name=Path(lcm_diffusion_setting.lora.path).name,
68
+ local_files_only=True,
69
+ adapter_name=current_lora.adapter_name,
70
+ )
71
+ update_lora_weights(
72
+ pipeline,
73
+ lcm_diffusion_setting,
74
+ )
75
+
76
+ if lcm_diffusion_setting.lora.fuse:
77
+ pipeline.fuse_lora()
78
+
79
+
80
+ def get_lora_models(root_dir: str):
81
+ lora_models = glob.glob(f"{root_dir}/**/*.safetensors", recursive=True)
82
+ lora_models_map = {}
83
+ for file_path in lora_models:
84
+ lora_name = get_file_name(file_path)
85
+ if lora_name is not None:
86
+ lora_models_map[lora_name] = file_path
87
+ return lora_models_map
88
+
89
+
90
+ # This function returns a list of (adapter_name, weight) tuples for the
91
+ # currently loaded LoRAs
92
+ def get_active_lora_weights():
93
+ active_loras = []
94
+ for lora_info in _loaded_loras:
95
+ active_loras.append(
96
+ (
97
+ lora_info.adapter_name,
98
+ lora_info.weight,
99
+ )
100
+ )
101
+ return active_loras
102
+
103
+
104
+ # This function receives a pipeline, an lcm_diffusion_setting object and
105
+ # an optional list of updated (adapter_name, weight) tuples
106
+ def update_lora_weights(
107
+ pipeline,
108
+ lcm_diffusion_setting,
109
+ lora_weights=None,
110
+ ):
111
+ global _loaded_loras
112
+ global _current_pipeline
113
+ if pipeline != _current_pipeline:
114
+ print("Wrong pipeline when trying to update LoRA weights")
115
+ return
116
+ if lora_weights:
117
+ for idx, lora in enumerate(lora_weights):
118
+ if _loaded_loras[idx].adapter_name != lora[0]:
119
+ print("Wrong adapter name in LoRA enumeration!")
120
+ continue
121
+ _loaded_loras[idx].weight = lora[1]
122
+
123
+ adapter_names = []
124
+ adapter_weights = []
125
+ if lcm_diffusion_setting.use_lcm_lora:
126
+ adapter_names.append("lcm")
127
+ adapter_weights.append(1.0)
128
+ for lora in _loaded_loras:
129
+ adapter_names.append(lora.adapter_name)
130
+ adapter_weights.append(lora.weight)
131
+ pipeline.set_adapters(
132
+ adapter_names,
133
+ adapter_weights=adapter_weights,
134
+ )
135
+ adapter_weights = zip(adapter_names, adapter_weights)
136
+ print(f"Adapters: {list(adapter_weights)}")
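As implied by the comments above, `get_active_lora_weights` and `update_lora_weights` are meant to be used together to rescale already-loaded adapters without reloading them (this assumes the LoRAs were loaded with `fuse=False`; `pipeline` and `lcm_diffusion_setting` stand for the objects used in the earlier `load_lora_weight` call):

```python
from backend.lora import get_active_lora_weights, update_lora_weights

# Hypothetical: halve the weight of every currently loaded LoRA adapter
active = get_active_lora_weights()            # [(adapter_name, weight), ...]
rescaled = [(name, weight * 0.5) for name, weight in active]
update_lora_weights(pipeline, lcm_diffusion_setting, rescaled)
```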
backend/models/device.py ADDED
@@ -0,0 +1,9 @@
1
+ from pydantic import BaseModel
2
+
3
+
4
+ class DeviceInfo(BaseModel):
5
+ device_type: str
6
+ device_name: str
7
+ os: str
8
+ platform: str
9
+ processor: str
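A sketch of how this model could be populated from the standard library `platform` module; the `device_type` and `device_name` values below are placeholders, since in the app they come from its own device detection:

```python
import platform

from backend.models.device import DeviceInfo

info = DeviceInfo(
    device_type="cpu",                               # placeholder
    device_name=platform.processor() or "unknown",   # placeholder fallback
    os=platform.system(),
    platform=platform.platform(),
    processor=platform.processor(),
)
print(info.model_dump())
```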
backend/models/gen_images.py ADDED
@@ -0,0 +1,17 @@
1
+ from pydantic import BaseModel
2
+ from enum import Enum
3
+ from paths import FastStableDiffusionPaths
4
+
5
+
6
+ class ImageFormat(str, Enum):
7
+ """Image format"""
8
+
9
+ JPEG = "jpeg"
10
+ PNG = "png"
11
+
12
+
13
+ class GeneratedImages(BaseModel):
14
+ path: str = FastStableDiffusionPaths.get_results_path()
15
+ format: str = ImageFormat.PNG.value.upper()
16
+ save_image: bool = True
17
+ save_image_quality: int = 90
backend/models/lcmdiffusion_setting.py ADDED
@@ -0,0 +1,76 @@
1
+ from enum import Enum
2
+ from PIL import Image
3
+ from typing import Any, Optional, Union
4
+
5
+ from constants import LCM_DEFAULT_MODEL, LCM_DEFAULT_MODEL_OPENVINO
6
+ from paths import FastStableDiffusionPaths
7
+ from pydantic import BaseModel
8
+
9
+
10
+ class LCMLora(BaseModel):
11
+ base_model_id: str = "Lykon/dreamshaper-8"
12
+ lcm_lora_id: str = "latent-consistency/lcm-lora-sdv1-5"
13
+
14
+
15
+ class DiffusionTask(str, Enum):
16
+ """Diffusion task types"""
17
+
18
+ text_to_image = "text_to_image"
19
+ image_to_image = "image_to_image"
20
+
21
+
22
+ class Lora(BaseModel):
23
+ models_dir: str = FastStableDiffusionPaths.get_lora_models_path()
24
+ path: Optional[Any] = None
25
+ weight: Optional[float] = 0.5
26
+ fuse: bool = True
27
+ enabled: bool = False
28
+
29
+
30
+ class ControlNetSetting(BaseModel):
31
+ adapter_path: Optional[str] = None # ControlNet adapter path
32
+ conditioning_scale: float = 0.5
33
+ enabled: bool = False
34
+ _control_image: Image = None # Control image, PIL image
35
+
36
+
37
+ class GGUFModel(BaseModel):
38
+ gguf_models: str = FastStableDiffusionPaths.get_gguf_models_path()
39
+ diffusion_path: Optional[str] = None
40
+ clip_path: Optional[str] = None
41
+ t5xxl_path: Optional[str] = None
42
+ vae_path: Optional[str] = None
43
+
44
+
45
+ class LCMDiffusionSetting(BaseModel):
46
+ lcm_model_id: str = LCM_DEFAULT_MODEL
47
+ openvino_lcm_model_id: str = LCM_DEFAULT_MODEL_OPENVINO
48
+ use_offline_model: bool = False
49
+ use_lcm_lora: bool = False
50
+ lcm_lora: Optional[LCMLora] = LCMLora()
51
+ use_tiny_auto_encoder: bool = False
52
+ use_openvino: bool = False
53
+ prompt: str = ""
54
+ negative_prompt: str = ""
55
+ init_image: Any = None
56
+ strength: Optional[float] = 0.6
57
+ image_height: Optional[int] = 512
58
+ image_width: Optional[int] = 512
59
+ inference_steps: Optional[int] = 1
60
+ guidance_scale: Optional[float] = 1
61
+ clip_skip: Optional[int] = 1
62
+ token_merging: Optional[float] = 0
63
+ number_of_images: Optional[int] = 1
64
+ seed: Optional[int] = 123123
65
+ use_seed: bool = False
66
+ use_safety_checker: bool = False
67
+ diffusion_task: str = DiffusionTask.text_to_image.value
68
+ lora: Optional[Lora] = Lora()
69
+ controlnet: Optional[Union[ControlNetSetting, list[ControlNetSetting]]] = None
70
+ dirs: dict = {
71
+ "controlnet": FastStableDiffusionPaths.get_controlnet_models_path(),
72
+ "lora": FastStableDiffusionPaths.get_lora_models_path(),
73
+ }
74
+ rebuild_pipeline: bool = False
75
+ use_gguf_model: bool = False
76
+ gguf_model: Optional[GGUFModel] = GGUFModel()
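A short sketch of constructing and serializing the settings model; the field values are placeholders, and `model_dump` is the Pydantic v2 call that `ImageSaver` uses for its JSON sidecar:

```python
from backend.models.lcmdiffusion_setting import DiffusionTask, LCMDiffusionSetting

settings = LCMDiffusionSetting(
    prompt="oil painting of a harbour at dusk",
    diffusion_task=DiffusionTask.image_to_image.value,
    image_width=768,
    image_height=512,
    strength=0.4,
    inference_steps=4,
)
print(settings.model_dump(exclude={"init_image"}))
```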
backend/models/upscale.py ADDED
@@ -0,0 +1,9 @@
1
+ from enum import Enum
2
+
3
+
4
+ class UpscaleMode(str, Enum):
5
+ """Diffusion task types"""
6
+
7
+ normal = "normal"
8
+ sd_upscale = "sd_upscale"
9
+ aura_sr = "aura_sr"
backend/openvino/custom_ov_model_vae_decoder.py ADDED
@@ -0,0 +1,21 @@
1
+ from backend.device import is_openvino_device
2
+
3
+ if is_openvino_device():
4
+ from optimum.intel.openvino.modeling_diffusion import OVModelVaeDecoder
5
+
6
+
7
+ class CustomOVModelVaeDecoder(OVModelVaeDecoder):
8
+ def __init__(
9
+ self,
10
+ model,
11
+ parent_model,
12
+ ov_config=None,
13
+ model_dir=None,
14
+ ):
15
+ super(OVModelVaeDecoder, self).__init__(
16
+ model,
17
+ parent_model,
18
+ ov_config,
19
+ "vae_decoder",
20
+ model_dir,
21
+ )
backend/openvino/flux_pipeline.py ADDED
@@ -0,0 +1,36 @@
1
+ from pathlib import Path
2
+
3
+ from constants import DEVICE, LCM_DEFAULT_MODEL_OPENVINO, TAEF1_MODEL_OPENVINO
4
+ from huggingface_hub import snapshot_download
5
+
6
+ from backend.openvino.ovflux import (
7
+ TEXT_ENCODER_2_PATH,
8
+ TEXT_ENCODER_PATH,
9
+ TRANSFORMER_PATH,
10
+ VAE_DECODER_PATH,
11
+ init_pipeline,
12
+ )
13
+
14
+
15
+ def get_flux_pipeline(
16
+ model_id: str = LCM_DEFAULT_MODEL_OPENVINO,
17
+ use_taef1: bool = False,
18
+ taef1_path: str = TAEF1_MODEL_OPENVINO,
19
+ ):
20
+ model_dir = Path(snapshot_download(model_id))
21
+ vae_dir = Path(snapshot_download(taef1_path)) if use_taef1 else model_dir
22
+
23
+ model_dict = {
24
+ "transformer": model_dir / TRANSFORMER_PATH,
25
+ "text_encoder": model_dir / TEXT_ENCODER_PATH,
26
+ "text_encoder_2": model_dir / TEXT_ENCODER_2_PATH,
27
+ "vae": vae_dir / VAE_DECODER_PATH,
28
+ }
29
+ ov_pipe = init_pipeline(
30
+ model_dir,
31
+ model_dict,
32
+ device=DEVICE.upper(),
33
+ use_taef1=use_taef1,
34
+ )
35
+
36
+ return ov_pipe
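A hypothetical call of the helper above; the model id is a placeholder for an OpenVINO FLUX export on the Hugging Face Hub, and the low step count with zero guidance assumes a schnell-style model:

```python
from backend.openvino.flux_pipeline import get_flux_pipeline

pipe = get_flux_pipeline("your-org/flux.1-schnell-openvino")  # placeholder model id
result = pipe(
    prompt="a macro photo of a snowflake",
    width=512,
    height=512,
    num_inference_steps=4,
    guidance_scale=0.0,
)
result.images[0].save("snowflake.png")
```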
backend/openvino/ov_hc_stablediffusion_pipeline.py ADDED
@@ -0,0 +1,93 @@
1
+ """This is an experimental pipeline used to test AI PC NPU and GPU"""
2
+
3
+ from pathlib import Path
4
+
5
+ from diffusers import EulerDiscreteScheduler, LCMScheduler
6
+ from huggingface_hub import snapshot_download
7
+ from PIL import Image
8
+ from backend.openvino.stable_diffusion_engine import (
9
+ StableDiffusionEngineAdvanced,
10
+ LatentConsistencyEngineAdvanced
11
+ )
12
+
13
+
14
+ class OvHcStableDiffusion:
15
+ "OpenVINO heterogeneous compute Stable Diffusion pipeline"
16
+
17
+ def __init__(
18
+ self,
19
+ model_path,
20
+ device: list = ["GPU", "NPU", "GPU", "GPU"],
21
+ ):
22
+ model_dir = Path(snapshot_download(model_path))
23
+ self.scheduler = EulerDiscreteScheduler(
24
+ beta_start=0.00085,
25
+ beta_end=0.012,
26
+ beta_schedule="scaled_linear",
27
+ )
28
+ self.ov_sd_pipeline = StableDiffusionEngineAdvanced(
29
+ model=model_dir,
30
+ device=device,
31
+ )
32
+
33
+ def generate(
34
+ self,
35
+ prompt: str,
36
+ neg_prompt: str,
37
+ init_image: Image = None,
38
+ strength: float = 1.0,
39
+ ):
40
+ image = self.ov_sd_pipeline(
41
+ prompt=prompt,
42
+ negative_prompt=neg_prompt,
43
+ init_image=init_image,
44
+ strength=strength,
45
+ num_inference_steps=25,
46
+ scheduler=self.scheduler,
47
+ )
48
+ image_rgb = image[..., ::-1]
49
+ return Image.fromarray(image_rgb)
50
+
51
+
52
+ class OvHcLatentConsistency:
53
+ """
54
+ OpenVINO Heterogeneous compute Latent consistency models
55
+ On current Intel Core Ultra processors, the text encoder and UNet can run on the NPU
56
+ Supports text to image, image to image, and image variations
57
+ """
58
+
59
+ def __init__(
60
+ self,
61
+ model_path,
62
+ device: list = ["NPU", "NPU", "GPU"],
63
+ ):
64
+
65
+ model_dir = Path(snapshot_download(model_path))
66
+
67
+ self.scheduler = LCMScheduler(
68
+ beta_start=0.001,
69
+ beta_end=0.01,
70
+ )
71
+ self.ov_sd_pipeline = LatentConsistencyEngineAdvanced(
72
+ model=model_dir,
73
+ device=device,
74
+ )
75
+
76
+ def generate(
77
+ self,
78
+ prompt: str,
79
+ neg_prompt: str,
80
+ init_image: Image = None,
81
+ num_inference_steps=4,
82
+ strength: float = 0.5,
83
+ ):
84
+ image = self.ov_sd_pipeline(
85
+ prompt=prompt,
86
+ init_image=init_image,
87
+ strength=strength,
88
+ num_inference_steps=num_inference_steps,
89
+ scheduler=self.scheduler,
90
+ seed=None,
91
+ )
92
+
93
+ return image
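A hypothetical text-to-image call for the LCM variant; the model id is a placeholder for a "square" OpenVINO LCM export (the hetero-pipeline check in `lcm_text_to_image.py` keys off that word), and the returned object is assumed to be a PIL image:

```python
from backend.openvino.ov_hc_stablediffusion_pipeline import OvHcLatentConsistency

pipe = OvHcLatentConsistency("your-org/lcm-sd15-square-openvino")  # placeholder model id
image = pipe.generate(
    prompt="isometric illustration of a tiny island",
    neg_prompt="blurry, low quality",
    init_image=None,
    num_inference_steps=4,
)
image.save("island.png")  # assuming the engine returns a PIL image
```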
backend/openvino/ovflux.py ADDED
@@ -0,0 +1,675 @@
1
+ """Based on https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/notebooks/flux.1-image-generation/flux_helper.py"""
2
+
3
+ import inspect
4
+ import json
5
+ from pathlib import Path
6
+ from typing import Any, Dict, List, Optional, Union
7
+
8
+ import numpy as np
9
+ import openvino as ov
10
+ import torch
11
+ from diffusers.image_processor import VaeImageProcessor
12
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
13
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
14
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
15
+ from diffusers.utils.torch_utils import randn_tensor
16
+ from transformers import AutoTokenizer
17
+
18
+ TRANSFORMER_PATH = Path("transformer/transformer.xml")
19
+ VAE_DECODER_PATH = Path("vae/vae_decoder.xml")
20
+ TEXT_ENCODER_PATH = Path("text_encoder/text_encoder.xml")
21
+ TEXT_ENCODER_2_PATH = Path("text_encoder_2/text_encoder_2.xml")
22
+
23
+
24
+ def cleanup_torchscript_cache():
25
+ """
26
+ Helper for removing cached model representation
27
+ """
28
+ torch._C._jit_clear_class_registry()
29
+ torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
30
+ torch.jit._state._clear_class_state()
31
+
32
+
33
+ def _prepare_latent_image_ids(
34
+ batch_size, height, width, device=torch.device("cpu"), dtype=torch.float32
35
+ ):
36
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
37
+ latent_image_ids[..., 1] = (
38
+ latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
39
+ )
40
+ latent_image_ids[..., 2] = (
41
+ latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
42
+ )
43
+
44
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = (
45
+ latent_image_ids.shape
46
+ )
47
+
48
+ latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
49
+ latent_image_ids = latent_image_ids.reshape(
50
+ batch_size,
51
+ latent_image_id_height * latent_image_id_width,
52
+ latent_image_id_channels,
53
+ )
54
+
55
+ return latent_image_ids.to(device=device, dtype=dtype)
56
+
57
+
58
+ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
59
+ assert dim % 2 == 0, "The dimension must be even."
60
+
61
+ scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
62
+ omega = 1.0 / (theta**scale)
63
+
64
+ batch_size, seq_length = pos.shape
65
+ out = pos.unsqueeze(-1) * omega.unsqueeze(0).unsqueeze(0)
66
+ cos_out = torch.cos(out)
67
+ sin_out = torch.sin(out)
68
+
69
+ stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
70
+ out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
71
+ return out.float()
72
+
73
+
74
+ def calculate_shift(
75
+ image_seq_len,
76
+ base_seq_len: int = 256,
77
+ max_seq_len: int = 4096,
78
+ base_shift: float = 0.5,
79
+ max_shift: float = 1.16,
80
+ ):
81
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
82
+ b = base_shift - m * base_seq_len
83
+ mu = image_seq_len * m + b
84
+ return mu
85
+
86
+
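With the default arguments, `calculate_shift` is a linear interpolation from `base_shift` to `max_shift` over the packed sequence length. A quick sanity check of the formula (assuming the dependencies of this module are installed so it can be imported):

```python
from backend.openvino.ovflux import calculate_shift

# m = (1.16 - 0.5) / (4096 - 256) ~= 1.72e-4 and b = 0.5 - m * 256 ~= 0.456
assert abs(calculate_shift(256) - 0.5) < 1e-6    # shortest schedule -> base_shift
assert abs(calculate_shift(4096) - 1.16) < 1e-6  # longest schedule -> max_shift
print(calculate_shift(1024))                     # ~0.63 for a 512x512 packed latent
```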
87
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
88
+ def retrieve_timesteps(
89
+ scheduler,
90
+ num_inference_steps: Optional[int] = None,
91
+ timesteps: Optional[List[int]] = None,
92
+ sigmas: Optional[List[float]] = None,
93
+ **kwargs,
94
+ ):
95
+ """
96
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
97
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
98
+
99
+ Args:
100
+ scheduler (`SchedulerMixin`):
101
+ The scheduler to get timesteps from.
102
+ num_inference_steps (`int`):
103
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
104
+ must be `None`.
105
+ device (`str` or `torch.device`, *optional*):
106
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
107
+ timesteps (`List[int]`, *optional*):
108
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
109
+ `num_inference_steps` and `sigmas` must be `None`.
110
+ sigmas (`List[float]`, *optional*):
111
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
112
+ `num_inference_steps` and `timesteps` must be `None`.
113
+
114
+ Returns:
115
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
116
+ second element is the number of inference steps.
117
+ """
118
+ if timesteps is not None and sigmas is not None:
119
+ raise ValueError(
120
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
121
+ )
122
+ if timesteps is not None:
123
+ accepts_timesteps = "timesteps" in set(
124
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
125
+ )
126
+ if not accepts_timesteps:
127
+ raise ValueError(
128
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
129
+ f" timestep schedules. Please check whether you are using the correct scheduler."
130
+ )
131
+ scheduler.set_timesteps(timesteps=timesteps, **kwargs)
132
+ timesteps = scheduler.timesteps
133
+ num_inference_steps = len(timesteps)
134
+ elif sigmas is not None:
135
+ accept_sigmas = "sigmas" in set(
136
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
137
+ )
138
+ if not accept_sigmas:
139
+ raise ValueError(
140
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
141
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
142
+ )
143
+ scheduler.set_timesteps(sigmas=sigmas, **kwargs)
144
+ timesteps = scheduler.timesteps
145
+ num_inference_steps = len(timesteps)
146
+ else:
147
+ scheduler.set_timesteps(num_inference_steps, **kwargs)
148
+ timesteps = scheduler.timesteps
149
+ return timesteps, num_inference_steps
150
+
151
+
152
+ class OVFluxPipeline(DiffusionPipeline):
153
+ def __init__(
154
+ self,
155
+ scheduler,
156
+ transformer,
157
+ vae,
158
+ text_encoder,
159
+ text_encoder_2,
160
+ tokenizer,
161
+ tokenizer_2,
162
+ transformer_config,
163
+ vae_config,
164
+ ):
165
+ super().__init__()
166
+
167
+ self.register_modules(
168
+ vae=vae,
169
+ text_encoder=text_encoder,
170
+ text_encoder_2=text_encoder_2,
171
+ tokenizer=tokenizer,
172
+ tokenizer_2=tokenizer_2,
173
+ transformer=transformer,
174
+ scheduler=scheduler,
175
+ )
176
+ self.vae_config = vae_config
177
+ self.transformer_config = transformer_config
178
+ self.vae_scale_factor = 2 ** (
179
+ len(self.vae_config.get("block_out_channels", [0] * 16))
180
+ if hasattr(self, "vae") and self.vae is not None
181
+ else 16
182
+ )
183
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
184
+ self.tokenizer_max_length = (
185
+ self.tokenizer.model_max_length
186
+ if hasattr(self, "tokenizer") and self.tokenizer is not None
187
+ else 77
188
+ )
189
+ self.default_sample_size = 64
190
+
191
+ def _get_t5_prompt_embeds(
192
+ self,
193
+ prompt: Union[str, List[str]] = None,
194
+ num_images_per_prompt: int = 1,
195
+ max_sequence_length: int = 512,
196
+ ):
197
+ prompt = [prompt] if isinstance(prompt, str) else prompt
198
+ batch_size = len(prompt)
199
+
200
+ text_inputs = self.tokenizer_2(
201
+ prompt,
202
+ padding="max_length",
203
+ max_length=max_sequence_length,
204
+ truncation=True,
205
+ return_length=False,
206
+ return_overflowing_tokens=False,
207
+ return_tensors="pt",
208
+ )
209
+ text_input_ids = text_inputs.input_ids
210
+ prompt_embeds = torch.from_numpy(self.text_encoder_2(text_input_ids)[0])
211
+
212
+ _, seq_len, _ = prompt_embeds.shape
213
+
214
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
215
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
216
+ prompt_embeds = prompt_embeds.view(
217
+ batch_size * num_images_per_prompt, seq_len, -1
218
+ )
219
+
220
+ return prompt_embeds
221
+
222
+ def _get_clip_prompt_embeds(
223
+ self,
224
+ prompt: Union[str, List[str]],
225
+ num_images_per_prompt: int = 1,
226
+ ):
227
+
228
+ prompt = [prompt] if isinstance(prompt, str) else prompt
229
+ batch_size = len(prompt)
230
+
231
+ text_inputs = self.tokenizer(
232
+ prompt,
233
+ padding="max_length",
234
+ max_length=self.tokenizer_max_length,
235
+ truncation=True,
236
+ return_overflowing_tokens=False,
237
+ return_length=False,
238
+ return_tensors="pt",
239
+ )
240
+
241
+ text_input_ids = text_inputs.input_ids
242
+ prompt_embeds = torch.from_numpy(self.text_encoder(text_input_ids)[1])
243
+
244
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
245
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
246
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
247
+
248
+ return prompt_embeds
249
+
250
+ def encode_prompt(
251
+ self,
252
+ prompt: Union[str, List[str]],
253
+ prompt_2: Union[str, List[str]],
254
+ num_images_per_prompt: int = 1,
255
+ prompt_embeds: Optional[torch.FloatTensor] = None,
256
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
257
+ max_sequence_length: int = 512,
258
+ ):
259
+ r"""
260
+
261
+ Args:
262
+ prompt (`str` or `List[str]`, *optional*):
263
+ prompt to be encoded
264
+ prompt_2 (`str` or `List[str]`, *optional*):
265
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
266
+ used in all text-encoders
267
+ num_images_per_prompt (`int`):
268
+ number of images that should be generated per prompt
269
+ prompt_embeds (`torch.FloatTensor`, *optional*):
270
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
271
+ provided, text embeddings will be generated from `prompt` input argument.
272
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
273
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
274
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
275
+ lora_scale (`float`, *optional*):
276
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
277
+ """
278
+
279
+ prompt = [prompt] if isinstance(prompt, str) else prompt
280
+ if prompt is not None:
281
+ batch_size = len(prompt)
282
+ else:
283
+ batch_size = prompt_embeds.shape[0]
284
+
285
+ if prompt_embeds is None:
286
+ prompt_2 = prompt_2 or prompt
287
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
288
+
289
+ # We only use the pooled prompt output from the CLIPTextModel
290
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
291
+ prompt=prompt,
292
+ num_images_per_prompt=num_images_per_prompt,
293
+ )
294
+ prompt_embeds = self._get_t5_prompt_embeds(
295
+ prompt=prompt_2,
296
+ num_images_per_prompt=num_images_per_prompt,
297
+ max_sequence_length=max_sequence_length,
298
+ )
299
+ text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3)
300
+ text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
301
+
302
+ return prompt_embeds, pooled_prompt_embeds, text_ids
303
+
304
+ def check_inputs(
305
+ self,
306
+ prompt,
307
+ prompt_2,
308
+ height,
309
+ width,
310
+ prompt_embeds=None,
311
+ pooled_prompt_embeds=None,
312
+ max_sequence_length=None,
313
+ ):
314
+ if height % 8 != 0 or width % 8 != 0:
315
+ raise ValueError(
316
+ f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
317
+ )
318
+
319
+ if prompt is not None and prompt_embeds is not None:
320
+ raise ValueError(
321
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
322
+ " only forward one of the two."
323
+ )
324
+ elif prompt_2 is not None and prompt_embeds is not None:
325
+ raise ValueError(
326
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
327
+ " only forward one of the two."
328
+ )
329
+ elif prompt is None and prompt_embeds is None:
330
+ raise ValueError(
331
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
332
+ )
333
+ elif prompt is not None and (
334
+ not isinstance(prompt, str) and not isinstance(prompt, list)
335
+ ):
336
+ raise ValueError(
337
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
338
+ )
339
+ elif prompt_2 is not None and (
340
+ not isinstance(prompt_2, str) and not isinstance(prompt_2, list)
341
+ ):
342
+ raise ValueError(
343
+ f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}"
344
+ )
345
+
346
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
347
+ raise ValueError(
348
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
349
+ )
350
+
351
+ if max_sequence_length is not None and max_sequence_length > 512:
352
+ raise ValueError(
353
+ f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}"
354
+ )
355
+
356
+ @staticmethod
357
+ def _prepare_latent_image_ids(batch_size, height, width):
358
+ return _prepare_latent_image_ids(batch_size, height, width)
359
+
360
+ @staticmethod
361
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
362
+ latents = latents.view(
363
+ batch_size, num_channels_latents, height // 2, 2, width // 2, 2
364
+ )
365
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
366
+ latents = latents.reshape(
367
+ batch_size, (height // 2) * (width // 2), num_channels_latents * 4
368
+ )
369
+
370
+ return latents
371
+
372
+ @staticmethod
373
+ def _unpack_latents(latents, height, width, vae_scale_factor):
374
+ batch_size, num_patches, channels = latents.shape
375
+
376
+ height = height // vae_scale_factor
377
+ width = width // vae_scale_factor
378
+
379
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
380
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
381
+
382
+ latents = latents.reshape(
383
+ batch_size, channels // (2 * 2), height * 2, width * 2
384
+ )
385
+
386
+ return latents
387
+
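`_pack_latents` and `_unpack_latents` are shape-level inverses: packing folds each 2x2 patch of the latent grid into the channel dimension, and unpacking restores the original layout. A small round-trip sketch, using the sizes produced for a 512x512 image with a `vae_scale_factor` of 16 (so the latent grid passed to `_pack_latents` is 2 * (512 // 16) = 64, as in `prepare_latents` below):

```python
import torch

from backend.openvino.ovflux import OVFluxPipeline

lat = torch.randn(1, 16, 64, 64)  # (batch, latent_channels, height, width)
packed = OVFluxPipeline._pack_latents(lat, 1, 16, 64, 64)
print(packed.shape)               # torch.Size([1, 1024, 64])
restored = OVFluxPipeline._unpack_latents(packed, 512, 512, 16)
print(torch.allclose(lat, restored))  # True
```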
388
+ def prepare_latents(
389
+ self,
390
+ batch_size,
391
+ num_channels_latents,
392
+ height,
393
+ width,
394
+ generator,
395
+ latents=None,
396
+ ):
397
+ height = 2 * (int(height) // self.vae_scale_factor)
398
+ width = 2 * (int(width) // self.vae_scale_factor)
399
+
400
+ shape = (batch_size, num_channels_latents, height, width)
401
+
402
+ if latents is not None:
403
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width)
404
+ return latents, latent_image_ids
405
+
406
+ if isinstance(generator, list) and len(generator) != batch_size:
407
+ raise ValueError(
408
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
409
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
410
+ )
411
+
412
+ latents = randn_tensor(shape, generator=generator)
413
+ latents = self._pack_latents(
414
+ latents, batch_size, num_channels_latents, height, width
415
+ )
416
+
417
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width)
418
+
419
+ return latents, latent_image_ids
420
+
421
+ @property
422
+ def guidance_scale(self):
423
+ return self._guidance_scale
424
+
425
+ @property
426
+ def num_timesteps(self):
427
+ return self._num_timesteps
428
+
429
+ @property
430
+ def interrupt(self):
431
+ return self._interrupt
432
+
433
+ def __call__(
434
+ self,
435
+ prompt: Union[str, List[str]] = None,
436
+ prompt_2: Optional[Union[str, List[str]]] = None,
437
+ height: Optional[int] = None,
438
+ width: Optional[int] = None,
439
+ negative_prompt: str = None,
440
+ num_inference_steps: int = 28,
441
+ timesteps: List[int] = None,
442
+ guidance_scale: float = 7.0,
443
+ num_images_per_prompt: Optional[int] = 1,
444
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
445
+ latents: Optional[torch.FloatTensor] = None,
446
+ prompt_embeds: Optional[torch.FloatTensor] = None,
447
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
448
+ output_type: Optional[str] = "pil",
449
+ return_dict: bool = True,
450
+ max_sequence_length: int = 512,
451
+ ):
452
+ r"""
453
+ Function invoked when calling the pipeline for generation.
454
+
455
+ Args:
456
+ prompt (`str` or `List[str]`, *optional*):
457
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
458
+ instead.
459
+ prompt_2 (`str` or `List[str]`, *optional*):
460
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
461
+ will be used instead
462
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
463
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
464
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
465
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
466
+ num_inference_steps (`int`, *optional*, defaults to 50):
467
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
468
+ expense of slower inference.
469
+ timesteps (`List[int]`, *optional*):
470
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
471
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
472
+ passed will be used. Must be in descending order.
473
+ guidance_scale (`float`, *optional*, defaults to 7.0):
474
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
475
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
476
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
477
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
478
+ usually at the expense of lower image quality.
479
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
480
+ The number of images to generate per prompt.
481
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
482
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
483
+ to make generation deterministic.
484
+ latents (`torch.FloatTensor`, *optional*):
485
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
486
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
487
+ tensor will be generated by sampling using the supplied random `generator`.
488
+ prompt_embeds (`torch.FloatTensor`, *optional*):
489
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
490
+ provided, text embeddings will be generated from `prompt` input argument.
491
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
492
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
493
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
494
+ output_type (`str`, *optional*, defaults to `"pil"`):
495
+ The output format of the generated image. Choose between
496
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
497
+ return_dict (`bool`, *optional*, defaults to `True`):
498
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
499
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
500
+ Returns:
501
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
502
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
503
+ images.
504
+ """
505
+
506
+ height = height or self.default_sample_size * self.vae_scale_factor
507
+ width = width or self.default_sample_size * self.vae_scale_factor
508
+
509
+ # 1. Check inputs. Raise error if not correct
510
+ self.check_inputs(
511
+ prompt,
512
+ prompt_2,
513
+ height,
514
+ width,
515
+ prompt_embeds=prompt_embeds,
516
+ pooled_prompt_embeds=pooled_prompt_embeds,
517
+ max_sequence_length=max_sequence_length,
518
+ )
519
+
520
+ self._guidance_scale = guidance_scale
521
+ self._interrupt = False
522
+
523
+ # 2. Define call parameters
524
+ if prompt is not None and isinstance(prompt, str):
525
+ batch_size = 1
526
+ elif prompt is not None and isinstance(prompt, list):
527
+ batch_size = len(prompt)
528
+ else:
529
+ batch_size = prompt_embeds.shape[0]
530
+
531
+ (
532
+ prompt_embeds,
533
+ pooled_prompt_embeds,
534
+ text_ids,
535
+ ) = self.encode_prompt(
536
+ prompt=prompt,
537
+ prompt_2=prompt_2,
538
+ prompt_embeds=prompt_embeds,
539
+ pooled_prompt_embeds=pooled_prompt_embeds,
540
+ num_images_per_prompt=num_images_per_prompt,
541
+ max_sequence_length=max_sequence_length,
542
+ )
543
+
544
+ # 4. Prepare latent variables
545
+ num_channels_latents = self.transformer_config.get("in_channels", 64) // 4
546
+ latents, latent_image_ids = self.prepare_latents(
547
+ batch_size * num_images_per_prompt,
548
+ num_channels_latents,
549
+ height,
550
+ width,
551
+ generator,
552
+ latents,
553
+ )
554
+
555
+ # 5. Prepare timesteps
556
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
557
+ image_seq_len = latents.shape[1]
558
+ mu = calculate_shift(
559
+ image_seq_len,
560
+ self.scheduler.config.base_image_seq_len,
561
+ self.scheduler.config.max_image_seq_len,
562
+ self.scheduler.config.base_shift,
563
+ self.scheduler.config.max_shift,
564
+ )
565
+ timesteps, num_inference_steps = retrieve_timesteps(
566
+ scheduler=self.scheduler,
567
+ num_inference_steps=num_inference_steps,
568
+ timesteps=timesteps,
569
+ sigmas=sigmas,
570
+ mu=mu,
571
+ )
572
+ num_warmup_steps = max(
573
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
574
+ )
575
+ self._num_timesteps = len(timesteps)
576
+
577
+ # 6. Denoising loop
578
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
579
+ for i, t in enumerate(timesteps):
580
+ if self.interrupt:
581
+ continue
582
+
583
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
584
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
585
+
586
+ # handle guidance
587
+ if self.transformer_config.get("guidance_embeds"):
588
+ guidance = torch.tensor([guidance_scale])
589
+ guidance = guidance.expand(latents.shape[0])
590
+ else:
591
+ guidance = None
592
+
593
+ transformer_input = {
594
+ "hidden_states": latents,
595
+ "timestep": timestep / 1000,
596
+ "pooled_projections": pooled_prompt_embeds,
597
+ "encoder_hidden_states": prompt_embeds,
598
+ "txt_ids": text_ids,
599
+ "img_ids": latent_image_ids,
600
+ }
601
+ if guidance is not None:
602
+ transformer_input["guidance"] = guidance
603
+
604
+ noise_pred = torch.from_numpy(self.transformer(transformer_input)[0])
605
+
606
+ latents = self.scheduler.step(
607
+ noise_pred, t, latents, return_dict=False
608
+ )[0]
609
+
610
+ # call the callback, if provided
611
+ if i == len(timesteps) - 1 or (
612
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
613
+ ):
614
+ progress_bar.update()
615
+
616
+ if output_type == "latent":
617
+ image = latents
618
+
619
+ else:
620
+ latents = self._unpack_latents(
621
+ latents, height, width, self.vae_scale_factor
622
+ )
623
+ latents = latents / self.vae_config.get(
624
+ "scaling_factor"
625
+ ) + self.vae_config.get("shift_factor")
626
+ image = self.vae(latents)[0]
627
+ image = self.image_processor.postprocess(
628
+ torch.from_numpy(image), output_type=output_type
629
+ )
630
+
631
+ if not return_dict:
632
+ return (image,)
633
+
634
+ return FluxPipelineOutput(images=image)
635
+
636
+
637
+ def init_pipeline(
638
+ model_dir,
639
+ models_dict: Dict[str, Any],
640
+ device: str,
641
+ use_taef1: bool = False,
642
+ ):
643
+ pipeline_args = {}
644
+
645
+ print("OpenVINO FLUX Model compilation")
646
+ core = ov.Core()
647
+ for model_name, model_path in models_dict.items():
648
+ pipeline_args[model_name] = core.compile_model(model_path, device)
649
+ if model_name == "vae" and use_taef1:
650
+ print(f"✅ VAE(TAEF1) - Done!")
651
+ else:
652
+ print(f"✅ {model_name} - Done!")
653
+
654
+ transformer_path = models_dict["transformer"]
655
+ transformer_config_path = transformer_path.parent / "config.json"
656
+ with transformer_config_path.open("r") as f:
657
+ transformer_config = json.load(f)
658
+ vae_path = models_dict["vae"]
659
+ vae_config_path = vae_path.parent / "config.json"
660
+ with vae_config_path.open("r") as f:
661
+ vae_config = json.load(f)
662
+
663
+ pipeline_args["vae_config"] = vae_config
664
+ pipeline_args["transformer_config"] = transformer_config
665
+
666
+ scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_dir / "scheduler")
667
+
668
+ tokenizer = AutoTokenizer.from_pretrained(model_dir / "tokenizer")
669
+ tokenizer_2 = AutoTokenizer.from_pretrained(model_dir / "tokenizer_2")
670
+
671
+ pipeline_args["scheduler"] = scheduler
672
+ pipeline_args["tokenizer"] = tokenizer
673
+ pipeline_args["tokenizer_2"] = tokenizer_2
674
+ ov_pipe = OVFluxPipeline(**pipeline_args)
675
+ return ov_pipe
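For reference, a minimal usage sketch of `init_pipeline`, assuming an exported OpenVINO FLUX model directory; the paths, model-dict keys, prompt, and device string below are illustrative assumptions, not part of the diff:

from pathlib import Path

# Illustrative paths and model keys; the actual export layout may differ.
model_dir = Path("models/flux-ov")
models_dict = {
    "transformer": model_dir / "transformer" / "openvino_model.xml",
    "vae": model_dir / "vae" / "openvino_model.xml",
    "text_encoder": model_dir / "text_encoder" / "openvino_model.xml",
    "text_encoder_2": model_dir / "text_encoder_2" / "openvino_model.xml",
}

pipe = init_pipeline(model_dir, models_dict, device="CPU")
image = pipe(
    prompt="a watercolor fox in a snowy forest",  # illustrative prompt
    num_inference_steps=4,
    height=512,
    width=512,
).images[0]
image.save("flux_ov.png")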
backend/openvino/pipelines.py ADDED
@@ -0,0 +1,75 @@
1
+ from constants import DEVICE, LCM_DEFAULT_MODEL_OPENVINO
2
+ from backend.tiny_decoder import get_tiny_decoder_vae_model
3
+ from typing import Any
4
+ from backend.device import is_openvino_device
5
+ from paths import get_base_folder_name
6
+
7
+ if is_openvino_device():
8
+ from huggingface_hub import snapshot_download
9
+ from optimum.intel.openvino.modeling_diffusion import OVBaseModel
10
+
11
+ from optimum.intel.openvino.modeling_diffusion import (
12
+ OVStableDiffusionPipeline,
13
+ OVStableDiffusionImg2ImgPipeline,
14
+ OVStableDiffusionXLPipeline,
15
+ OVStableDiffusionXLImg2ImgPipeline,
16
+ )
17
+ from backend.openvino.custom_ov_model_vae_decoder import CustomOVModelVaeDecoder
18
+
19
+
20
+ def ov_load_taesd(
21
+ pipeline: Any,
22
+ use_local_model: bool = False,
23
+ ):
24
+ taesd_dir = snapshot_download(
25
+ repo_id=get_tiny_decoder_vae_model(pipeline.__class__.__name__),
26
+ local_files_only=use_local_model,
27
+ )
28
+ pipeline.vae_decoder = CustomOVModelVaeDecoder(
29
+ model=OVBaseModel.load_model(f"{taesd_dir}/vae_decoder/openvino_model.xml"),
30
+ parent_model=pipeline,
31
+ model_dir=taesd_dir,
32
+ )
33
+
34
+
35
+ def get_ov_text_to_image_pipeline(
36
+ model_id: str = LCM_DEFAULT_MODEL_OPENVINO,
37
+ use_local_model: bool = False,
38
+ ) -> Any:
39
+ if "xl" in get_base_folder_name(model_id).lower():
40
+ pipeline = OVStableDiffusionXLPipeline.from_pretrained(
41
+ model_id,
42
+ local_files_only=use_local_model,
43
+ ov_config={"CACHE_DIR": ""},
44
+ device=DEVICE.upper(),
45
+ )
46
+ else:
47
+ pipeline = OVStableDiffusionPipeline.from_pretrained(
48
+ model_id,
49
+ local_files_only=use_local_model,
50
+ ov_config={"CACHE_DIR": ""},
51
+ device=DEVICE.upper(),
52
+ )
53
+
54
+ return pipeline
55
+
56
+
57
+ def get_ov_image_to_image_pipeline(
58
+ model_id: str = LCM_DEFAULT_MODEL_OPENVINO,
59
+ use_local_model: bool = False,
60
+ ) -> Any:
61
+ if "xl" in get_base_folder_name(model_id).lower():
62
+ pipeline = OVStableDiffusionXLImg2ImgPipeline.from_pretrained(
63
+ model_id,
64
+ local_files_only=use_local_model,
65
+ ov_config={"CACHE_DIR": ""},
66
+ device=DEVICE.upper(),
67
+ )
68
+ else:
69
+ pipeline = OVStableDiffusionImg2ImgPipeline.from_pretrained(
70
+ model_id,
71
+ local_files_only=use_local_model,
72
+ ov_config={"CACHE_DIR": ""},
73
+ device=DEVICE.upper(),
74
+ )
75
+ return pipeline
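A minimal usage sketch of the helpers above, assuming a standard diffusers-style call on the returned OpenVINO pipeline; the model id, prompt, and step count are illustrative assumptions:

# Illustrative only; substitute a real OpenVINO LCM model id.
pipeline = get_ov_text_to_image_pipeline("rupeshs/LCM-dreamshaper-v7-openvino")
ov_load_taesd(pipeline)  # optional: swap in the tiny TAESD decoder for faster VAE decoding
result = pipeline(
    prompt="an isometric pixel-art castle",
    num_inference_steps=4,
    guidance_scale=1.0,
)
result.images[0].save("out.png")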
backend/openvino/stable_diffusion_engine.py ADDED
@@ -0,0 +1,1817 @@
1
+ """
2
+ Copyright(C) 2022-2023 Intel Corporation
3
+ SPDX - License - Identifier: Apache - 2.0
4
+
5
+ """
6
+ import inspect
7
+ from typing import Union, Optional, Any, List, Dict
8
+ import numpy as np
9
+ # openvino
10
+ from openvino.runtime import Core
11
+ # tokenizer
12
+ from transformers import CLIPTokenizer
13
+ import torch
14
+ import random
15
+
16
+ from diffusers import DiffusionPipeline
17
+ from diffusers.schedulers import (DDIMScheduler,
18
+ LMSDiscreteScheduler,
19
+ PNDMScheduler,
20
+ EulerDiscreteScheduler,
21
+ EulerAncestralDiscreteScheduler)
22
+
23
+
24
+ from diffusers.image_processor import VaeImageProcessor
25
+ from diffusers.utils.torch_utils import randn_tensor
26
+ from diffusers.utils import PIL_INTERPOLATION
27
+
28
+ import cv2
29
+ import os
30
+ import sys
31
+
32
+ # for multithreading
33
+ import concurrent.futures
34
+
35
+ #For GIF
36
+ import PIL
37
+ from PIL import Image
38
+ import glob
39
+ import json
40
+ import time
41
+
42
+ def scale_fit_to_window(dst_width:int, dst_height:int, image_width:int, image_height:int):
43
+ """
44
+ Preprocessing helper function for calculating the resized image size while preserving the original aspect ratio
45
+ and fitting image to specific window size
46
+
47
+ Parameters:
48
+ dst_width (int): destination window width
49
+ dst_height (int): destination window height
50
+ image_width (int): source image width
51
+ image_height (int): source image height
52
+ Returns:
53
+ result_width (int): calculated width for resize
54
+ result_height (int): calculated height for resize
55
+ """
56
+ im_scale = min(dst_height / image_height, dst_width / image_width)
57
+ return int(im_scale * image_width), int(im_scale * image_height)
58
+
59
+ def preprocess(image: PIL.Image.Image, ht=512, wt=512):
60
+ """
61
+ Image preprocessing function. Takes an image in PIL.Image format, resizes it to keep the aspect ratio and fit the model input window (512x512 by default),
62
+ then converts it to np.ndarray and adds zero padding on the right or bottom side of the image (depending on the aspect ratio); after that it
63
+ converts the data to float32, changes the range of values from [0, 255] to [-1, 1], and finally converts the data layout from NHWC to NCHW.
64
+ The function returns preprocessed input tensor and padding size, which can be used in postprocessing.
65
+
66
+ Parameters:
67
+ image (PIL.Image.Image): input image
68
+ Returns:
69
+ image (np.ndarray): preprocessed image tensor
70
+ meta (Dict): dictionary with preprocessing metadata info
71
+ """
72
+
73
+ src_width, src_height = image.size
74
+ image = image.convert('RGB')
75
+ dst_width, dst_height = scale_fit_to_window(
76
+ wt, ht, src_width, src_height)
77
+ image = np.array(image.resize((dst_width, dst_height),
78
+ resample=PIL.Image.Resampling.LANCZOS))[None, :]
79
+
80
+ pad_width = wt - dst_width
81
+ pad_height = ht - dst_height
82
+ pad = ((0, 0), (0, pad_height), (0, pad_width), (0, 0))
83
+ image = np.pad(image, pad, mode="constant")
84
+ image = image.astype(np.float32) / 255.0
85
+ image = 2.0 * image - 1.0
86
+ image = image.transpose(0, 3, 1, 2)
87
+
88
+ return image, {"padding": pad, "src_width": src_width, "src_height": src_height}
89
+
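A quick numeric check of the resize/pad math above; the 640x480 source size is an illustrative assumption:

# A 640x480 source fit into a 512x512 window:
#   im_scale = min(512/480, 512/640) = 0.8  ->  resized to 512x384
# preprocess() then zero-pads 512 - 384 = 128 rows at the bottom, so
# meta["padding"] == ((0, 0), (0, 128), (0, 0), (0, 0)).
w, h = scale_fit_to_window(512, 512, 640, 480)
assert (w, h) == (512, 384)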
90
+ def try_enable_npu_turbo(device, core):
91
+ import platform
92
+ if "windows" in platform.system().lower():
93
+ if "NPU" in device and "3720" not in core.get_property('NPU', 'DEVICE_ARCHITECTURE'):
94
+ try:
95
+ core.set_property(properties={'NPU_TURBO': 'YES'},device_name='NPU')
96
+ except:
97
+ print(f"Failed loading NPU_TURBO for device {device}. Skipping... ")
98
+ else:
99
+ print_npu_turbo_art()
100
+ else:
101
+ print(f"Skipping NPU_TURBO for device {device}")
102
+ elif "linux" in platform.system().lower():
103
+ if os.path.isfile('/sys/module/intel_vpu/parameters/test_mode'):
104
+ with open('/sys/module/intel_vpu/version', 'r') as f:
105
+ version = f.readline().split()[0]
106
+ if tuple(map(int, version.split('.'))) < tuple(map(int, '1.9.0'.split('.'))):
107
+ print(f"The driver intel_vpu-1.9.0 (or later) needs to be loaded for NPU Turbo (currently {version}). Skipping...")
108
+ else:
109
+ with open('/sys/module/intel_vpu/parameters/test_mode', 'r') as tm_file:
110
+ test_mode = int(tm_file.readline().split()[0])
111
+ if test_mode == 512:
112
+ print_npu_turbo_art()
113
+ else:
114
+ print("The driver >=intel_vpu-1.9.0 was must be loaded with "
115
+ "\"modprobe intel_vpu test_mode=512\" to enable NPU_TURBO "
116
+ f"(currently test_mode={test_mode}). Skipping...")
117
+ else:
118
+ print(f"The driver >=intel_vpu-1.9.0 must be loaded with \"modprobe intel_vpu test_mode=512\" to enable NPU_TURBO. Skipping...")
119
+ else:
120
+ print(f"This platform ({platform.system()}) does not support NPU Turbo")
121
+
122
+ def result(var):
123
+ return next(iter(var.values()))
124
+
125
+ class StableDiffusionEngineAdvanced(DiffusionPipeline):
126
+ def __init__(self, model="runwayml/stable-diffusion-v1-5",
127
+ tokenizer="openai/clip-vit-large-patch14",
128
+ device=["CPU", "CPU", "CPU", "CPU"]):
129
+ try:
130
+ self.tokenizer = CLIPTokenizer.from_pretrained(model, local_files_only=True)
131
+ except:
132
+ self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer)
133
+ self.tokenizer.save_pretrained(model)
134
+
135
+ self.core = Core()
136
+ self.core.set_property({'CACHE_DIR': os.path.join(model, 'cache')})
137
+ try_enable_npu_turbo(device, self.core)
138
+
139
+ print("Loading models... ")
140
+
141
+
142
+
143
+ with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
144
+ futures = {
145
+ "unet_time_proj": executor.submit(self.core.compile_model, os.path.join(model, "unet_time_proj.xml"), device[0]),
146
+ "text": executor.submit(self.load_model, model, "text_encoder", device[0]),
147
+ "unet": executor.submit(self.load_model, model, "unet_int8", device[1]),
148
+ "unet_neg": executor.submit(self.load_model, model, "unet_int8", device[2]) if device[1] != device[2] else None,
149
+ "vae_decoder": executor.submit(self.load_model, model, "vae_decoder", device[3]),
150
+ "vae_encoder": executor.submit(self.load_model, model, "vae_encoder", device[3])
151
+ }
152
+
153
+ self.unet_time_proj = futures["unet_time_proj"].result()
154
+ self.text_encoder = futures["text"].result()
155
+ self.unet = futures["unet"].result()
156
+ self.unet_neg = futures["unet_neg"].result() if futures["unet_neg"] else self.unet
157
+ self.vae_decoder = futures["vae_decoder"].result()
158
+ self.vae_encoder = futures["vae_encoder"].result()
159
+ print("Text Device:", device[0])
160
+ print("unet Device:", device[1])
161
+ print("unet-neg Device:", device[2])
162
+ print("VAE Device:", device[3])
163
+
164
+ self._text_encoder_output = self.text_encoder.output(0)
165
+ self._vae_d_output = self.vae_decoder.output(0)
166
+ self._vae_e_output = self.vae_encoder.output(0) if self.vae_encoder else None
167
+
168
+ self.set_dimensions()
169
+ self.infer_request_neg = self.unet_neg.create_infer_request()
170
+ self.infer_request = self.unet.create_infer_request()
171
+ self.infer_request_time_proj = self.unet_time_proj.create_infer_request()
172
+ self.time_proj_constants = np.load(os.path.join(model, "time_proj_constants.npy"))
173
+
174
+ def load_model(self, model, model_name, device):
175
+ if "NPU" in device:
176
+ with open(os.path.join(model, f"{model_name}.blob"), "rb") as f:
177
+ return self.core.import_model(f.read(), device)
178
+ return self.core.compile_model(os.path.join(model, f"{model_name}.xml"), device)
179
+
180
+ def set_dimensions(self):
181
+ latent_shape = self.unet.input("latent_model_input").shape
182
+ if latent_shape[1] == 4:
183
+ self.height = latent_shape[2] * 8
184
+ self.width = latent_shape[3] * 8
185
+ else:
186
+ self.height = latent_shape[1] * 8
187
+ self.width = latent_shape[2] * 8
188
+
189
+ def __call__(
190
+ self,
191
+ prompt,
192
+ init_image = None,
193
+ negative_prompt=None,
194
+ scheduler=None,
195
+ strength = 0.5,
196
+ num_inference_steps = 32,
197
+ guidance_scale = 7.5,
198
+ eta = 0.0,
199
+ create_gif = False,
200
+ model = None,
201
+ callback = None,
202
+ callback_userdata = None
203
+ ):
204
+
205
+ # extract condition
206
+ text_input = self.tokenizer(
207
+ prompt,
208
+ padding="max_length",
209
+ max_length=self.tokenizer.model_max_length,
210
+ truncation=True,
211
+ return_tensors="np",
212
+ )
213
+ text_embeddings = self.text_encoder(text_input.input_ids)[self._text_encoder_output]
214
+
215
+ # do classifier free guidance
216
+ do_classifier_free_guidance = guidance_scale > 1.0
217
+ if do_classifier_free_guidance:
218
+
219
+ if negative_prompt is None:
220
+ uncond_tokens = [""]
221
+ elif isinstance(negative_prompt, str):
222
+ uncond_tokens = [negative_prompt]
223
+ else:
224
+ uncond_tokens = negative_prompt
225
+
226
+ tokens_uncond = self.tokenizer(
227
+ uncond_tokens,
228
+ padding="max_length",
229
+ max_length=self.tokenizer.model_max_length, #truncation=True,
230
+ return_tensors="np"
231
+ )
232
+ uncond_embeddings = self.text_encoder(tokens_uncond.input_ids)[self._text_encoder_output]
233
+ text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
234
+
235
+ # set timesteps
236
+ accepts_offset = "offset" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
237
+ extra_set_kwargs = {}
238
+
239
+ if accepts_offset:
240
+ extra_set_kwargs["offset"] = 1
241
+
242
+ scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
243
+
244
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, scheduler)
245
+ latent_timestep = timesteps[:1]
246
+
247
+ # get the initial random noise unless the user supplied it
248
+ latents, meta = self.prepare_latents(init_image, latent_timestep, scheduler)
249
+
250
+
251
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
252
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
253
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
254
+ # and should be between [0, 1]
255
+ accepts_eta = "eta" in set(inspect.signature(scheduler.step).parameters.keys())
256
+ extra_step_kwargs = {}
257
+ if accepts_eta:
258
+ extra_step_kwargs["eta"] = eta
259
+ if create_gif:
260
+ frames = []
261
+
262
+ for i, t in enumerate(self.progress_bar(timesteps)):
263
+ if callback:
264
+ callback(i, callback_userdata)
265
+
266
+ # expand the latents if we are doing classifier free guidance
267
+ noise_pred = []
268
+ latent_model_input = latents
269
+ latent_model_input = scheduler.scale_model_input(latent_model_input, t)
270
+
271
+ latent_model_input_neg = latent_model_input
272
+ if self.unet.input("latent_model_input").shape[1] != 4:
273
+ #print("In transpose")
274
+ try:
275
+ latent_model_input = latent_model_input.permute(0,2,3,1)
276
+ except:
277
+ latent_model_input = latent_model_input.transpose(0,2,3,1)
278
+
279
+ if self.unet_neg.input("latent_model_input").shape[1] != 4:
280
+ #print("In transpose")
281
+ try:
282
+ latent_model_input_neg = latent_model_input_neg.permute(0,2,3,1)
283
+ except:
284
+ latent_model_input_neg = latent_model_input_neg.transpose(0,2,3,1)
285
+
286
+
287
+ time_proj_constants_fp16 = np.float16(self.time_proj_constants)
288
+ t_scaled_fp16 = time_proj_constants_fp16 * np.float16(t)
289
+ cosine_t_fp16 = np.cos(t_scaled_fp16)
290
+ sine_t_fp16 = np.sin(t_scaled_fp16)
291
+
292
+ t_scaled = self.time_proj_constants * np.float32(t)
293
+
294
+ cosine_t = np.cos(t_scaled)
295
+ sine_t = np.sin(t_scaled)
296
+
297
+ time_proj_dict = {"sine_t" : np.float32(sine_t), "cosine_t" : np.float32(cosine_t)}
298
+ self.infer_request_time_proj.start_async(time_proj_dict)
299
+ self.infer_request_time_proj.wait()
300
+ time_proj = self.infer_request_time_proj.get_output_tensor(0).data.astype(np.float32)
301
+
302
+ input_tens_neg_dict = {"time_proj": np.float32(time_proj), "latent_model_input":latent_model_input_neg, "encoder_hidden_states": np.expand_dims(text_embeddings[0], axis=0)}
303
+ input_tens_dict = {"time_proj": np.float32(time_proj), "latent_model_input":latent_model_input, "encoder_hidden_states": np.expand_dims(text_embeddings[1], axis=0)}
304
+
305
+ self.infer_request_neg.start_async(input_tens_neg_dict)
306
+ self.infer_request.start_async(input_tens_dict)
307
+ self.infer_request_neg.wait()
308
+ self.infer_request.wait()
309
+
310
+ noise_pred_neg = self.infer_request_neg.get_output_tensor(0)
311
+ noise_pred_pos = self.infer_request.get_output_tensor(0)
312
+
313
+ noise_pred.append(noise_pred_neg.data.astype(np.float32))
314
+ noise_pred.append(noise_pred_pos.data.astype(np.float32))
315
+
316
+ # perform guidance
317
+ if do_classifier_free_guidance:
318
+ noise_pred_uncond, noise_pred_text = noise_pred[0], noise_pred[1]
319
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
320
+
321
+ # compute the previous noisy sample x_t -> x_t-1
322
+ latents = scheduler.step(torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs)["prev_sample"].numpy()
323
+
324
+ if create_gif:
325
+ frames.append(latents)
326
+
327
+ if callback:
328
+ callback(num_inference_steps, callback_userdata)
329
+
330
+ # scale and decode the image latents with vae
331
+ latents = 1 / 0.18215 * latents
332
+
333
+ start = time.time()
334
+ image = self.vae_decoder(latents)[self._vae_d_output]
335
+ print("Decoder ended:",time.time() - start)
336
+
337
+ image = self.postprocess_image(image, meta)
338
+
339
+ if create_gif:
340
+ gif_folder=os.path.join(model,"../../../gif")
341
+ print("gif_folder:",gif_folder)
342
+ if not os.path.exists(gif_folder):
343
+ os.makedirs(gif_folder)
344
+ for i in range(0,len(frames)):
345
+ image = self.vae_decoder(frames[i]*(1/0.18215))[self._vae_d_output]
346
+ image = self.postprocess_image(image, meta)
347
+ output = gif_folder + "/" + str(i).zfill(3) +".png"
348
+ cv2.imwrite(output, image)
349
+ with open(os.path.join(gif_folder, "prompt.json"), "w") as file:
350
+ json.dump({"prompt": prompt}, file)
351
+ frames_image = [Image.open(image) for image in glob.glob(f"{gif_folder}/*.png")]
352
+ frame_one = frames_image[0]
353
+ gif_file=os.path.join(gif_folder,"stable_diffusion.gif")
354
+ frame_one.save(gif_file, format="GIF", append_images=frames_image, save_all=True, duration=100, loop=0)
355
+
356
+ return image
357
+
358
+ def prepare_latents(self, image:PIL.Image.Image = None, latent_timestep:torch.Tensor = None, scheduler = LMSDiscreteScheduler):
359
+ """
360
+ Function for getting initial latents for starting generation
361
+
362
+ Parameters:
363
+ image (PIL.Image.Image, *optional*, None):
364
+ Input image for generation, if not provided randon noise will be used as starting point
365
+ latent_timestep (torch.Tensor, *optional*, None):
366
+ Predicted by scheduler initial step for image generation, required for latent image mixing with nosie
367
+ Returns:
368
+ latents (np.ndarray):
369
+ Image encoded in latent space
370
+ """
371
+ latents_shape = (1, 4, self.height // 8, self.width // 8)
372
+
373
+ noise = np.random.randn(*latents_shape).astype(np.float32)
374
+ if image is None:
375
+ ##print("Image is NONE")
376
+ # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
377
+ if isinstance(scheduler, LMSDiscreteScheduler):
378
+
379
+ noise = noise * scheduler.sigmas[0].numpy()
380
+ return noise, {}
381
+ elif isinstance(scheduler, EulerDiscreteScheduler) or isinstance(scheduler,EulerAncestralDiscreteScheduler):
382
+
383
+ noise = noise * scheduler.sigmas.max().numpy()
384
+ return noise, {}
385
+ else:
386
+ return noise, {}
387
+ input_image, meta = preprocess(image,self.height,self.width)
388
+
389
+ moments = self.vae_encoder(input_image)[self._vae_e_output]
390
+
391
+ mean, logvar = np.split(moments, 2, axis=1)
392
+
393
+ std = np.exp(logvar * 0.5)
394
+ latents = (mean + std * np.random.randn(*mean.shape)) * 0.18215
395
+
396
+
397
+ latents = scheduler.add_noise(torch.from_numpy(latents), torch.from_numpy(noise), latent_timestep).numpy()
398
+ return latents, meta
399
+
400
+ def postprocess_image(self, image:np.ndarray, meta:Dict):
401
+ """
402
+ Postprocessing for the decoded image. Takes the generated image decoded by the VAE decoder, unpads it to the initial image size (if required),
403
+ normalizes and converts it to the [0, 255] pixel range. Optionally converts it from np.ndarray to PIL.Image format
404
+
405
+ Parameters:
406
+ image (np.ndarray):
407
+ Generated image
408
+ meta (Dict):
409
+ Metadata obtained on latents preparing step, can be empty
410
+ output_type (str, *optional*, pil):
411
+ Output format for result, can be pil or numpy
412
+ Returns:
413
+ image (List of np.ndarray or PIL.Image.Image):
414
+ Postprocessed images
415
+
416
+ if "src_height" in meta:
417
+ orig_height, orig_width = meta["src_height"], meta["src_width"]
418
+ image = [cv2.resize(img, (orig_width, orig_height))
419
+ for img in image]
420
+
421
+ return image
422
+ """
423
+ if "padding" in meta:
424
+ pad = meta["padding"]
425
+ (_, end_h), (_, end_w) = pad[1:3]
426
+ h, w = image.shape[2:]
427
+ #print("image shape",image.shape[2:])
428
+ unpad_h = h - end_h
429
+ unpad_w = w - end_w
430
+ image = image[:, :, :unpad_h, :unpad_w]
431
+ image = np.clip(image / 2 + 0.5, 0, 1)
432
+ image = (image[0].transpose(1, 2, 0)[:, :, ::-1] * 255).astype(np.uint8)
433
+
434
+
435
+
436
+ if "src_height" in meta:
437
+ orig_height, orig_width = meta["src_height"], meta["src_width"]
438
+ image = cv2.resize(image, (orig_width, orig_height))
439
+
440
+ return image
441
+
442
+
443
+
444
+
445
+ def get_timesteps(self, num_inference_steps:int, strength:float, scheduler):
446
+ """
447
+ Helper function for getting scheduler timesteps for generation
448
+ In case of image-to-image generation, it updates number of steps according to strength
449
+
450
+ Parameters:
451
+ num_inference_steps (int):
452
+ number of inference steps for generation
453
+ strength (float):
454
+ value between 0.0 and 1.0, that controls the amount of noise that is added to the input image.
455
+ Values that approach 1.0 allow for lots of variations but will also produce images that are not semantically consistent with the input.
456
+ """
457
+ # get the original timestep using init_timestep
458
+
459
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
460
+
461
+ t_start = max(num_inference_steps - init_timestep, 0)
462
+ timesteps = scheduler.timesteps[t_start:]
463
+
464
+ return timesteps, num_inference_steps - t_start
465
+
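To make the strength/steps trade-off above concrete (the values are illustrative):

num_inference_steps, strength = 32, 0.5  # illustrative values
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 16
t_start = max(num_inference_steps - init_timestep, 0)                          # 16
# Only scheduler.timesteps[t_start:] are run, so the img2img pass denoises
# over roughly the last half of the schedule.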
466
+ class StableDiffusionEngine(DiffusionPipeline):
467
+ def __init__(
468
+ self,
469
+ model="bes-dev/stable-diffusion-v1-4-openvino",
470
+ tokenizer="openai/clip-vit-large-patch14",
471
+ device=["CPU","CPU","CPU","CPU"]):
472
+
473
+ self.core = Core()
474
+ self.core.set_property({'CACHE_DIR': os.path.join(model, 'cache')})
475
+
476
+ self.batch_size = 2 if device[1] == device[2] and device[1] == "GPU" else 1
477
+ try_enable_npu_turbo(device, self.core)
478
+
479
+ try:
480
+ self.tokenizer = CLIPTokenizer.from_pretrained(model, local_files_only=True)
481
+ except Exception as e:
482
+ print("Local tokenizer not found. Attempting to download...")
483
+ self.tokenizer = self.download_tokenizer(tokenizer, model)
484
+
485
+ print("Loading models... ")
486
+
487
+ with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
488
+ text_future = executor.submit(self.load_model, model, "text_encoder", device[0])
489
+ vae_de_future = executor.submit(self.load_model, model, "vae_decoder", device[3])
490
+ vae_en_future = executor.submit(self.load_model, model, "vae_encoder", device[3])
491
+
492
+ if self.batch_size == 1:
493
+ if "int8" not in model:
494
+ unet_future = executor.submit(self.load_model, model, "unet_bs1", device[1])
495
+ unet_neg_future = executor.submit(self.load_model, model, "unet_bs1", device[2]) if device[1] != device[2] else None
496
+ else:
497
+ unet_future = executor.submit(self.load_model, model, "unet_int8a16", device[1])
498
+ unet_neg_future = executor.submit(self.load_model, model, "unet_int8a16", device[2]) if device[1] != device[2] else None
499
+ else:
500
+ unet_future = executor.submit(self.load_model, model, "unet", device[1])
501
+ unet_neg_future = None
502
+
503
+ self.unet = unet_future.result()
504
+ self.unet_neg = unet_neg_future.result() if unet_neg_future else self.unet
505
+ self.text_encoder = text_future.result()
506
+ self.vae_decoder = vae_de_future.result()
507
+ self.vae_encoder = vae_en_future.result()
508
+ print("Text Device:", device[0])
509
+ print("unet Device:", device[1])
510
+ print("unet-neg Device:", device[2])
511
+ print("VAE Device:", device[3])
512
+
513
+ self._text_encoder_output = self.text_encoder.output(0)
514
+ self._unet_output = self.unet.output(0)
515
+ self._vae_d_output = self.vae_decoder.output(0)
516
+ self._vae_e_output = self.vae_encoder.output(0) if self.vae_encoder else None
517
+
518
+ self.unet_input_tensor_name = "sample" if 'sample' in self.unet.input(0).names else "latent_model_input"
519
+
520
+ if self.batch_size == 1:
521
+ self.infer_request = self.unet.create_infer_request()
522
+ self.infer_request_neg = self.unet_neg.create_infer_request()
523
+ self._unet_neg_output = self.unet_neg.output(0)
524
+ else:
525
+ self.infer_request = None
526
+ self.infer_request_neg = None
527
+ self._unet_neg_output = None
528
+
529
+ self.set_dimensions()
530
+
531
+
532
+
533
+ def load_model(self, model, model_name, device):
534
+ if "NPU" in device:
535
+ with open(os.path.join(model, f"{model_name}.blob"), "rb") as f:
536
+ return self.core.import_model(f.read(), device)
537
+ return self.core.compile_model(os.path.join(model, f"{model_name}.xml"), device)
538
+
539
+ def set_dimensions(self):
540
+ latent_shape = self.unet.input(self.unet_input_tensor_name).shape
541
+ if latent_shape[1] == 4:
542
+ self.height = latent_shape[2] * 8
543
+ self.width = latent_shape[3] * 8
544
+ else:
545
+ self.height = latent_shape[1] * 8
546
+ self.width = latent_shape[2] * 8
547
+
548
+ def __call__(
549
+ self,
550
+ prompt,
551
+ init_image=None,
552
+ negative_prompt=None,
553
+ scheduler=None,
554
+ strength=0.5,
555
+ num_inference_steps=32,
556
+ guidance_scale=7.5,
557
+ eta=0.0,
558
+ create_gif=False,
559
+ model=None,
560
+ callback=None,
561
+ callback_userdata=None
562
+ ):
563
+ # extract condition
564
+ text_input = self.tokenizer(
565
+ prompt,
566
+ padding="max_length",
567
+ max_length=self.tokenizer.model_max_length,
568
+ truncation=True,
569
+ return_tensors="np",
570
+ )
571
+ text_embeddings = self.text_encoder(text_input.input_ids)[self._text_encoder_output]
572
+
573
+
574
+ # do classifier free guidance
575
+ do_classifier_free_guidance = guidance_scale > 1.0
576
+ if do_classifier_free_guidance:
577
+ if negative_prompt is None:
578
+ uncond_tokens = [""]
579
+ elif isinstance(negative_prompt, str):
580
+ uncond_tokens = [negative_prompt]
581
+ else:
582
+ uncond_tokens = negative_prompt
583
+
584
+ tokens_uncond = self.tokenizer(
585
+ uncond_tokens,
586
+ padding="max_length",
587
+ max_length=self.tokenizer.model_max_length, # truncation=True,
588
+ return_tensors="np"
589
+ )
590
+ uncond_embeddings = self.text_encoder(tokens_uncond.input_ids)[self._text_encoder_output]
591
+ text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
592
+
593
+ # set timesteps
594
+ accepts_offset = "offset" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
595
+ extra_set_kwargs = {}
596
+
597
+ if accepts_offset:
598
+ extra_set_kwargs["offset"] = 1
599
+
600
+ scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
601
+
602
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, scheduler)
603
+ latent_timestep = timesteps[:1]
604
+
605
+ # get the initial random noise unless the user supplied it
606
+ latents, meta = self.prepare_latents(init_image, latent_timestep, scheduler,model)
607
+
608
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
609
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
610
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
611
+ # and should be between [0, 1]
612
+ accepts_eta = "eta" in set(inspect.signature(scheduler.step).parameters.keys())
613
+ extra_step_kwargs = {}
614
+ if accepts_eta:
615
+ extra_step_kwargs["eta"] = eta
616
+ if create_gif:
617
+ frames = []
618
+
619
+ for i, t in enumerate(self.progress_bar(timesteps)):
620
+ if callback:
621
+ callback(i, callback_userdata)
622
+
623
+ if self.batch_size == 1:
624
+ # expand the latents if we are doing classifier free guidance
625
+ noise_pred = []
626
+ latent_model_input = latents
627
+
628
+ #Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
629
+ latent_model_input = scheduler.scale_model_input(latent_model_input, t)
630
+ latent_model_input_pos = latent_model_input
631
+ latent_model_input_neg = latent_model_input
632
+
633
+ if self.unet.input(self.unet_input_tensor_name).shape[1] != 4:
634
+ try:
635
+ latent_model_input_pos = latent_model_input_pos.permute(0,2,3,1)
636
+ except:
637
+ latent_model_input_pos = latent_model_input_pos.transpose(0,2,3,1)
638
+
639
+ if self.unet_neg.input(self.unet_input_tensor_name).shape[1] != 4:
640
+ try:
641
+ latent_model_input_neg = latent_model_input_neg.permute(0,2,3,1)
642
+ except:
643
+ latent_model_input_neg = latent_model_input_neg.transpose(0,2,3,1)
644
+
645
+ if "sample" in self.unet_input_tensor_name:
646
+ input_tens_neg_dict = {"sample" : latent_model_input_neg, "encoder_hidden_states": np.expand_dims(text_embeddings[0], axis=0), "timestep": np.expand_dims(np.float32(t), axis=0)}
647
+ input_tens_pos_dict = {"sample" : latent_model_input_pos, "encoder_hidden_states": np.expand_dims(text_embeddings[1], axis=0), "timestep": np.expand_dims(np.float32(t), axis=0)}
648
+ else:
649
+ input_tens_neg_dict = {"latent_model_input" : latent_model_input_neg, "encoder_hidden_states": np.expand_dims(text_embeddings[0], axis=0), "t": np.expand_dims(np.float32(t), axis=0)}
650
+ input_tens_pos_dict = {"latent_model_input" : latent_model_input_pos, "encoder_hidden_states": np.expand_dims(text_embeddings[1], axis=0), "t": np.expand_dims(np.float32(t), axis=0)}
651
+
652
+ self.infer_request_neg.start_async(input_tens_neg_dict)
653
+ self.infer_request.start_async(input_tens_pos_dict)
654
+
655
+ self.infer_request_neg.wait()
656
+ self.infer_request.wait()
657
+
658
+ noise_pred_neg = self.infer_request_neg.get_output_tensor(0)
659
+ noise_pred_pos = self.infer_request.get_output_tensor(0)
660
+
661
+ noise_pred.append(noise_pred_neg.data.astype(np.float32))
662
+ noise_pred.append(noise_pred_pos.data.astype(np.float32))
663
+ else:
664
+ latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
665
+ latent_model_input = scheduler.scale_model_input(latent_model_input, t)
666
+ noise_pred = self.unet([latent_model_input, np.array(t, dtype=np.float32), text_embeddings])[self._unet_output]
667
+
668
+ if do_classifier_free_guidance:
669
+ noise_pred_uncond, noise_pred_text = noise_pred[0], noise_pred[1]
670
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
671
+
672
+ # compute the previous noisy sample x_t -> x_t-1
673
+ latents = scheduler.step(torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs)["prev_sample"].numpy()
674
+
675
+ if create_gif:
676
+ frames.append(latents)
677
+
678
+ if callback:
679
+ callback(num_inference_steps, callback_userdata)
680
+
681
+ # scale and decode the image latents with vae
682
+ #if self.height == 512 and self.width == 512:
683
+ latents = 1 / 0.18215 * latents
684
+ image = self.vae_decoder(latents)[self._vae_d_output]
685
+ image = self.postprocess_image(image, meta)
686
+
687
+ return image
688
+
689
+ def prepare_latents(self, image: PIL.Image.Image = None, latent_timestep: torch.Tensor = None,
690
+ scheduler=LMSDiscreteScheduler,model=None):
691
+ """
692
+ Function for getting initial latents for starting generation
693
+
694
+ Parameters:
695
+ image (PIL.Image.Image, *optional*, None):
696
+ Input image for generation, if not provided randon noise will be used as starting point
697
+ latent_timestep (torch.Tensor, *optional*, None):
698
+ Predicted by scheduler initial step for image generation, required for latent image mixing with nosie
699
+ Returns:
700
+ latents (np.ndarray):
701
+ Image encoded in latent space
702
+ """
703
+ latents_shape = (1, 4, self.height // 8, self.width // 8)
704
+
705
+ noise = np.random.randn(*latents_shape).astype(np.float32)
706
+ if image is None:
707
+ #print("Image is NONE")
708
+ # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
709
+ if isinstance(scheduler, LMSDiscreteScheduler):
710
+
711
+ noise = noise * scheduler.sigmas[0].numpy()
712
+ return noise, {}
713
+ elif isinstance(scheduler, EulerDiscreteScheduler):
714
+
715
+ noise = noise * scheduler.sigmas.max().numpy()
716
+ return noise, {}
717
+ else:
718
+ return noise, {}
719
+ input_image, meta = preprocess(image, self.height, self.width)
720
+
721
+ moments = self.vae_encoder(input_image)[self._vae_e_output]
722
+
723
+ if "sd_2.1" in model:
724
+ latents = moments * 0.18215
725
+
726
+ else:
727
+
728
+ mean, logvar = np.split(moments, 2, axis=1)
729
+
730
+ std = np.exp(logvar * 0.5)
731
+ latents = (mean + std * np.random.randn(*mean.shape)) * 0.18215
732
+
733
+ latents = scheduler.add_noise(torch.from_numpy(latents), torch.from_numpy(noise), latent_timestep).numpy()
734
+ return latents, meta
735
+
736
+
737
+ def postprocess_image(self, image: np.ndarray, meta: Dict):
738
+ """
739
+ Postprocessing for the decoded image. Takes the generated image decoded by the VAE decoder, unpads it to the initial image size (if required),
740
+ normalizes and converts it to the [0, 255] pixel range. Optionally converts it from np.ndarray to PIL.Image format
741
+
742
+ Parameters:
743
+ image (np.ndarray):
744
+ Generated image
745
+ meta (Dict):
746
+ Metadata obtained on latents preparing step, can be empty
747
+ output_type (str, *optional*, pil):
748
+ Output format for result, can be pil or numpy
749
+ Returns:
750
+ image (List of np.ndarray or PIL.Image.Image):
751
+ Postprocessed images
752
+
753
+ if "src_height" in meta:
754
+ orig_height, orig_width = meta["src_height"], meta["src_width"]
755
+ image = [cv2.resize(img, (orig_width, orig_height))
756
+ for img in image]
757
+
758
+ return image
759
+ """
760
+ if "padding" in meta:
761
+ pad = meta["padding"]
762
+ (_, end_h), (_, end_w) = pad[1:3]
763
+ h, w = image.shape[2:]
764
+ # print("image shape",image.shape[2:])
765
+ unpad_h = h - end_h
766
+ unpad_w = w - end_w
767
+ image = image[:, :, :unpad_h, :unpad_w]
768
+ image = np.clip(image / 2 + 0.5, 0, 1)
769
+ image = (image[0].transpose(1, 2, 0)[:, :, ::-1] * 255).astype(np.uint8)
770
+
771
+ if "src_height" in meta:
772
+ orig_height, orig_width = meta["src_height"], meta["src_width"]
773
+ image = cv2.resize(image, (orig_width, orig_height))
774
+
775
+ return image
776
+
777
+ # image = (image / 2 + 0.5).clip(0, 1)
778
+ # image = (image[0].transpose(1, 2, 0)[:, :, ::-1] * 255).astype(np.uint8)
779
+
780
+ def get_timesteps(self, num_inference_steps: int, strength: float, scheduler):
781
+ """
782
+ Helper function for getting scheduler timesteps for generation
783
+ In case of image-to-image generation, it updates number of steps according to strength
784
+
785
+ Parameters:
786
+ num_inference_steps (int):
787
+ number of inference steps for generation
788
+ strength (float):
789
+ value between 0.0 and 1.0, that controls the amount of noise that is added to the input image.
790
+ Values that approach 1.0 allow for lots of variations but will also produce images that are not semantically consistent with the input.
791
+ """
792
+ # get the original timestep using init_timestep
793
+
794
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
795
+
796
+ t_start = max(num_inference_steps - init_timestep, 0)
797
+ timesteps = scheduler.timesteps[t_start:]
798
+
799
+ return timesteps, num_inference_steps - t_start
800
+
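The classifier-free guidance step shared by both engines above reduces to one blend of the two UNet outputs; a minimal NumPy sketch follows, with placeholder shapes and values:

import numpy as np

guidance_scale = 7.5
noise_pred_uncond = np.zeros((1, 4, 64, 64), dtype=np.float32)  # UNet output for the negative prompt (placeholder)
noise_pred_text = np.ones((1, 4, 64, 64), dtype=np.float32)     # UNet output for the positive prompt (placeholder)

# uncond + w * (text - uncond): w > 1 pushes the prediction toward the prompt.
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
assert noise_pred.shape == (1, 4, 64, 64)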
801
+ class LatentConsistencyEngine(DiffusionPipeline):
802
+ def __init__(
803
+ self,
804
+ model="SimianLuo/LCM_Dreamshaper_v7",
805
+ tokenizer="openai/clip-vit-large-patch14",
806
+ device=["CPU", "CPU", "CPU"],
807
+ ):
808
+ super().__init__()
809
+ try:
810
+ self.tokenizer = CLIPTokenizer.from_pretrained(model, local_files_only=True)
811
+ except:
812
+ self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer)
813
+ self.tokenizer.save_pretrained(model)
814
+
815
+ self.core = Core()
816
+ self.core.set_property({'CACHE_DIR': os.path.join(model, 'cache')}) # adding caching to reduce init time
817
+ try_enable_npu_turbo(device, self.core)
818
+
819
+
820
+ with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
821
+ text_future = executor.submit(self.load_model, model, "text_encoder", device[0])
822
+ unet_future = executor.submit(self.load_model, model, "unet", device[1])
823
+ vae_de_future = executor.submit(self.load_model, model, "vae_decoder", device[2])
824
+
825
+ print("Text Device:", device[0])
826
+ self.text_encoder = text_future.result()
827
+ self._text_encoder_output = self.text_encoder.output(0)
828
+
829
+ print("Unet Device:", device[1])
830
+ self.unet = unet_future.result()
831
+ self._unet_output = self.unet.output(0)
832
+ self.infer_request = self.unet.create_infer_request()
833
+
834
+ print(f"VAE Device: {device[2]}")
835
+ self.vae_decoder = vae_de_future.result()
836
+ self.infer_request_vae = self.vae_decoder.create_infer_request()
837
+ self.safety_checker = None #pipe.safety_checker
838
+ self.feature_extractor = None #pipe.feature_extractor
839
+ self.vae_scale_factor = 2 ** 3
840
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
841
+
842
+ def load_model(self, model, model_name, device):
843
+ if "NPU" in device:
844
+ with open(os.path.join(model, f"{model_name}.blob"), "rb") as f:
845
+ return self.core.import_model(f.read(), device)
846
+ return self.core.compile_model(os.path.join(model, f"{model_name}.xml"), device)
847
+
848
+ def _encode_prompt(
849
+ self,
850
+ prompt,
851
+ num_images_per_prompt,
852
+ prompt_embeds: None,
853
+ ):
854
+ r"""
855
+ Encodes the prompt into text encoder hidden states.
856
+ Args:
857
+ prompt (`str` or `List[str]`, *optional*):
858
+ prompt to be encoded
859
+ num_images_per_prompt (`int`):
860
+ number of images that should be generated per prompt
861
+ prompt_embeds (`torch.FloatTensor`, *optional*):
862
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
863
+ provided, text embeddings will be generated from `prompt` input argument.
864
+ """
865
+
866
+ if prompt_embeds is None:
867
+
868
+ text_inputs = self.tokenizer(
869
+ prompt,
870
+ padding="max_length",
871
+ max_length=self.tokenizer.model_max_length,
872
+ truncation=True,
873
+ return_tensors="pt",
874
+ )
875
+ text_input_ids = text_inputs.input_ids
876
+ untruncated_ids = self.tokenizer(
877
+ prompt, padding="longest", return_tensors="pt"
878
+ ).input_ids
879
+
880
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[
881
+ -1
882
+ ] and not torch.equal(text_input_ids, untruncated_ids):
883
+ removed_text = self.tokenizer.batch_decode(
884
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
885
+ )
886
+ logger.warning(
887
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
888
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
889
+ )
890
+
891
+ prompt_embeds = self.text_encoder(text_input_ids, share_inputs=True, share_outputs=True)
892
+ prompt_embeds = torch.from_numpy(prompt_embeds[0])
893
+
894
+ bs_embed, seq_len, _ = prompt_embeds.shape
895
+ # duplicate text embeddings for each generation per prompt
896
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
897
+ prompt_embeds = prompt_embeds.view(
898
+ bs_embed * num_images_per_prompt, seq_len, -1
899
+ )
900
+
901
+ # Don't need to get uncond prompt embedding because of LCM Guided Distillation
902
+ return prompt_embeds
903
+
904
+ def run_safety_checker(self, image, dtype):
905
+ if self.safety_checker is None:
906
+ has_nsfw_concept = None
907
+ else:
908
+ if torch.is_tensor(image):
909
+ feature_extractor_input = self.image_processor.postprocess(
910
+ image, output_type="pil"
911
+ )
912
+ else:
913
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
914
+ safety_checker_input = self.feature_extractor(
915
+ feature_extractor_input, return_tensors="pt"
916
+ )
917
+ image, has_nsfw_concept = self.safety_checker(
918
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
919
+ )
920
+ return image, has_nsfw_concept
921
+
922
+ def prepare_latents(
923
+ self, batch_size, num_channels_latents, height, width, dtype, latents=None
924
+ ):
925
+ shape = (
926
+ batch_size,
927
+ num_channels_latents,
928
+ height // self.vae_scale_factor,
929
+ width // self.vae_scale_factor,
930
+ )
931
+ if latents is None:
932
+ latents = torch.randn(shape, dtype=dtype)
933
+ # scale the initial noise by the standard deviation required by the scheduler
934
+ return latents
935
+
936
+ def get_w_embedding(self, w, embedding_dim=512, dtype=torch.float32):
937
+ """
938
+ see https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
939
+ Args:
940
+ timesteps: torch.Tensor: generate embedding vectors at these timesteps
941
+ embedding_dim: int: dimension of the embeddings to generate
942
+ dtype: data type of the generated embeddings
943
+ Returns:
944
+ embedding vectors with shape `(len(timesteps), embedding_dim)`
945
+ """
946
+ assert len(w.shape) == 1
947
+ w = w * 1000.0
948
+
949
+ half_dim = embedding_dim // 2
950
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
951
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
952
+ emb = w.to(dtype)[:, None] * emb[None, :]
953
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
954
+ if embedding_dim % 2 == 1: # zero pad
955
+ emb = torch.nn.functional.pad(emb, (0, 1))
956
+ assert emb.shape == (w.shape[0], embedding_dim)
957
+ return emb
958
+
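A standalone check of the guidance-embedding math above, assuming batch size 2 and guidance_scale 8.0 (both illustrative):

import torch

w = torch.tensor(8.0).repeat(2) * 1000.0
half_dim = 256 // 2
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = w[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
assert emb.shape == (2, 256)  # same result as get_w_embedding(w, embedding_dim=256)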
959
+ @torch.no_grad()
960
+ def __call__(
961
+ self,
962
+ prompt: Union[str, List[str]] = None,
963
+ height: Optional[int] = 512,
964
+ width: Optional[int] = 512,
965
+ guidance_scale: float = 7.5,
966
+ scheduler = None,
967
+ num_images_per_prompt: Optional[int] = 1,
968
+ latents: Optional[torch.FloatTensor] = None,
969
+ num_inference_steps: int = 4,
970
+ lcm_origin_steps: int = 50,
971
+ prompt_embeds: Optional[torch.FloatTensor] = None,
972
+ output_type: Optional[str] = "pil",
973
+ return_dict: bool = True,
974
+ model: Optional[Dict[str, any]] = None,
975
+ seed: Optional[int] = 1234567,
976
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
977
+ callback = None,
978
+ callback_userdata = None
979
+ ):
980
+
981
+ # 1. Define call parameters
982
+ if prompt is not None and isinstance(prompt, str):
983
+ batch_size = 1
984
+ elif prompt is not None and isinstance(prompt, list):
985
+ batch_size = len(prompt)
986
+ else:
987
+ batch_size = prompt_embeds.shape[0]
988
+
989
+ if seed is not None:
990
+ torch.manual_seed(seed)
991
+
992
+ #print("After Step 1: batch size is ", batch_size)
993
+ # do_classifier_free_guidance = guidance_scale > 0.0
994
+ # In LCM Implementation: cfg_noise = noise_cond + cfg_scale * (noise_cond - noise_uncond) , (cfg_scale > 0.0 using CFG)
995
+
996
+ # 2. Encode input prompt
997
+ prompt_embeds = self._encode_prompt(
998
+ prompt,
999
+ num_images_per_prompt,
1000
+ prompt_embeds=prompt_embeds,
1001
+ )
1002
+ #print("After Step 2: prompt embeds is ", prompt_embeds)
1003
+ #print("After Step 2: scheduler is ", scheduler )
1004
+ # 3. Prepare timesteps
1005
+ scheduler.set_timesteps(num_inference_steps, original_inference_steps=lcm_origin_steps)
1006
+ timesteps = scheduler.timesteps
1007
+
1008
+ #print("After Step 3: timesteps is ", timesteps)
1009
+
1010
+ # 4. Prepare latent variable
1011
+ num_channels_latents = 4
1012
+ latents = self.prepare_latents(
1013
+ batch_size * num_images_per_prompt,
1014
+ num_channels_latents,
1015
+ height,
1016
+ width,
1017
+ prompt_embeds.dtype,
1018
+ latents,
1019
+ )
1020
+ latents = latents * scheduler.init_noise_sigma
1021
+
1022
+ #print("After Step 4: ")
1023
+ bs = batch_size * num_images_per_prompt
1024
+
1025
+ # 5. Get Guidance Scale Embedding
1026
+ w = torch.tensor(guidance_scale).repeat(bs)
1027
+ w_embedding = self.get_w_embedding(w, embedding_dim=256)
1028
+ #print("After Step 5: ")
1029
+ # 6. LCM MultiStep Sampling Loop:
1030
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1031
+ for i, t in enumerate(timesteps):
1032
+ if callback:
1033
+ callback(i+1, callback_userdata)
1034
+
1035
+ ts = torch.full((bs,), t, dtype=torch.long)
1036
+
1037
+ # model prediction (v-prediction, eps, x)
1038
+ model_pred = self.unet([latents, ts, prompt_embeds, w_embedding],share_inputs=True, share_outputs=True)[0]
1039
+
1040
+ # compute the previous noisy sample x_t -> x_t-1
1041
+ latents, denoised = scheduler.step(
1042
+ torch.from_numpy(model_pred), t, latents, return_dict=False
1043
+ )
1044
+ progress_bar.update()
1045
+
1046
+ #print("After Step 6: ")
1047
+
1048
+ vae_start = time.time()
1049
+
1050
+ if not output_type == "latent":
1051
+ image = torch.from_numpy(self.vae_decoder(denoised / 0.18215, share_inputs=True, share_outputs=True)[0])
1052
+ else:
1053
+ image = denoised
1054
+
1055
+ print("Decoder Ended: ", time.time() - vae_start)
1056
+ #post_start = time.time()
1057
+
1058
+ #if has_nsfw_concept is None:
1059
+ do_denormalize = [True] * image.shape[0]
1060
+ #else:
1061
+ # do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
1062
+
1063
+ #print ("After do_denormalize: image is ", image)
1064
+
1065
+ image = self.image_processor.postprocess(
1066
+ image, output_type=output_type, do_denormalize=do_denormalize
1067
+ )
1068
+
1069
+ return image[0]
1070
+
1071
+ class LatentConsistencyEngineAdvanced(DiffusionPipeline):
1072
+ def __init__(
1073
+ self,
1074
+ model="SimianLuo/LCM_Dreamshaper_v7",
1075
+ tokenizer="openai/clip-vit-large-patch14",
1076
+ device=["CPU", "CPU", "CPU"],
1077
+ ):
1078
+ super().__init__()
1079
+ try:
1080
+ self.tokenizer = CLIPTokenizer.from_pretrained(model, local_files_only=True)
1081
+ except:
1082
+ self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer)
1083
+ self.tokenizer.save_pretrained(model)
1084
+
1085
+ self.core = Core()
1086
+ self.core.set_property({'CACHE_DIR': os.path.join(model, 'cache')}) # adding caching to reduce init time
1087
+ #try_enable_npu_turbo(device, self.core)
1088
+
1089
+
1090
+ with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
1091
+ text_future = executor.submit(self.load_model, model, "text_encoder", device[0])
1092
+ unet_future = executor.submit(self.load_model, model, "unet", device[1])
1093
+ vae_de_future = executor.submit(self.load_model, model, "vae_decoder", device[2])
1094
+ vae_encoder_future = executor.submit(self.load_model, model, "vae_encoder", device[2])
1095
+
1096
+
1097
+ print("Text Device:", device[0])
1098
+ self.text_encoder = text_future.result()
1099
+ self._text_encoder_output = self.text_encoder.output(0)
1100
+
1101
+ print("Unet Device:", device[1])
1102
+ self.unet = unet_future.result()
1103
+ self._unet_output = self.unet.output(0)
1104
+ self.infer_request = self.unet.create_infer_request()
1105
+
1106
+ print(f"VAE Device: {device[2]}")
1107
+ self.vae_decoder = vae_de_future.result()
1108
+ self.vae_encoder = vae_encoder_future.result()
1109
+ self._vae_e_output = self.vae_encoder.output(0) if self.vae_encoder else None
1110
+
1111
+ self.infer_request_vae = self.vae_decoder.create_infer_request()
1112
+ self.safety_checker = None #pipe.safety_checker
1113
+ self.feature_extractor = None #pipe.feature_extractor
1114
+ self.vae_scale_factor = 2 ** 3
1115
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
1116
+
1117
+ def load_model(self, model, model_name, device):
1118
+ print(f"Compiling the {model_name} to {device} ...")
1119
+ return self.core.compile_model(os.path.join(model, f"{model_name}.xml"), device)
1120
+
1121
+ def get_timesteps(self, num_inference_steps:int, strength:float, scheduler):
1122
+ """
1123
+ Helper function for getting scheduler timesteps for generation
1124
+ In case of image-to-image generation, it updates number of steps according to strength
1125
+
1126
+ Parameters:
1127
+ num_inference_steps (int):
1128
+ number of inference steps for generation
1129
+ strength (float):
1130
+ value between 0.0 and 1.0, that controls the amount of noise that is added to the input image.
1131
+ Values that approach 1.0 allow for lots of variations but will also produce images that are not semantically consistent with the input.
1132
+ """
1133
+ # get the original timestep using init_timestep
1134
+
1135
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
1136
+
1137
+ t_start = max(num_inference_steps - init_timestep, 0)
1138
+ timesteps = scheduler.timesteps[t_start:]
1139
+
1140
+ return timesteps, num_inference_steps - t_start
1141
+
1142
+ def _encode_prompt(
1143
+ self,
1144
+ prompt,
1145
+ num_images_per_prompt,
1146
+ prompt_embeds: None,
1147
+ ):
1148
+ r"""
1149
+ Encodes the prompt into text encoder hidden states.
1150
+ Args:
1151
+ prompt (`str` or `List[str]`, *optional*):
1152
+ prompt to be encoded
1153
+ num_images_per_prompt (`int`):
1154
+ number of images that should be generated per prompt
1155
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1156
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1157
+ provided, text embeddings will be generated from `prompt` input argument.
1158
+ """
1159
+
1160
+ if prompt_embeds is None:
1161
+
1162
+ text_inputs = self.tokenizer(
1163
+ prompt,
1164
+ padding="max_length",
1165
+ max_length=self.tokenizer.model_max_length,
1166
+ truncation=True,
1167
+ return_tensors="pt",
1168
+ )
1169
+ text_input_ids = text_inputs.input_ids
1170
+ untruncated_ids = self.tokenizer(
1171
+ prompt, padding="longest", return_tensors="pt"
1172
+ ).input_ids
1173
+
1174
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[
1175
+ -1
1176
+ ] and not torch.equal(text_input_ids, untruncated_ids):
1177
+ removed_text = self.tokenizer.batch_decode(
1178
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
1179
+ )
1180
+ logger.warning(
1181
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
1182
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
1183
+ )
1184
+
1185
+ prompt_embeds = self.text_encoder(text_input_ids, share_inputs=True, share_outputs=True)
1186
+ prompt_embeds = torch.from_numpy(prompt_embeds[0])
1187
+
1188
+ bs_embed, seq_len, _ = prompt_embeds.shape
1189
+ # duplicate text embeddings for each generation per prompt
1190
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
1191
+ prompt_embeds = prompt_embeds.view(
1192
+ bs_embed * num_images_per_prompt, seq_len, -1
1193
+ )
1194
+
1195
+ # Don't need to get uncond prompt embedding because of LCM Guided Distillation
1196
+ return prompt_embeds
1197
+
1198
+ def run_safety_checker(self, image, dtype):
1199
+ if self.safety_checker is None:
1200
+ has_nsfw_concept = None
1201
+ else:
1202
+ if torch.is_tensor(image):
1203
+ feature_extractor_input = self.image_processor.postprocess(
1204
+ image, output_type="pil"
1205
+ )
1206
+ else:
1207
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
1208
+ safety_checker_input = self.feature_extractor(
1209
+ feature_extractor_input, return_tensors="pt"
1210
+ )
1211
+ image, has_nsfw_concept = self.safety_checker(
1212
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
1213
+ )
1214
+ return image, has_nsfw_concept
1215
+
1216
+ def prepare_latents(
1217
+ self,image,timestep,batch_size, num_channels_latents, height, width, dtype, scheduler,latents=None,
1218
+ ):
1219
+ shape = (
1220
+ batch_size,
1221
+ num_channels_latents,
1222
+ height // self.vae_scale_factor,
1223
+ width // self.vae_scale_factor,
1224
+ )
1225
+ if image:
1226
+ #latents_shape = (1, 4, 512, 512 // 8)
1227
+ #input_image, meta = preprocess(image,512,512)
1228
+ latents_shape = (1, 4, 512 // 8, 512 // 8)
1229
+ noise = np.random.randn(*latents_shape).astype(np.float32)
1230
+ input_image,meta = preprocess(image,512,512)
1231
+ moments = self.vae_encoder(input_image)[self._vae_e_output]
1232
+ mean, logvar = np.split(moments, 2, axis=1)
1233
+ std = np.exp(logvar * 0.5)
1234
+ latents = (mean + std * np.random.randn(*mean.shape)) * 0.18215
1235
+ noise = torch.randn(shape, dtype=dtype)
1236
+ #latents = scheduler.add_noise(init_latents, noise, timestep)
1237
+ latents = scheduler.add_noise(torch.from_numpy(latents), noise, timestep)
1238
+
1239
+ else:
1240
+ latents = torch.randn(shape, dtype=dtype)
1241
+ # scale the initial noise by the standard deviation required by the scheduler
1242
+ return latents
1243
+
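The image branch above encodes the input through the VAE and samples a latent with the reparameterization trick before noising it. A minimal NumPy-only sketch of that sampling step, with a random array standing in for the OpenVINO vae_encoder output:

# illustrative sketch only; 'moments' stands in for the VAE encoder output
import numpy as np

moments = np.random.randn(1, 8, 64, 64).astype(np.float32)  # concatenated [mean, logvar]
mean, logvar = np.split(moments, 2, axis=1)
std = np.exp(logvar * 0.5)
latents = (mean + std * np.random.randn(*mean.shape)) * 0.18215  # SD latent scaling factor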
1244
+ def get_w_embedding(self, w, embedding_dim=512, dtype=torch.float32):
1245
+ """
1246
+ see https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
1247
+ Args:
1248
+ timesteps: torch.Tensor: generate embedding vectors at these timesteps
1249
+ embedding_dim: int: dimension of the embeddings to generate
1250
+ dtype: data type of the generated embeddings
1251
+ Returns:
1252
+ embedding vectors with shape `(len(timesteps), embedding_dim)`
1253
+ """
1254
+ assert len(w.shape) == 1
1255
+ w = w * 1000.0
1256
+
1257
+ half_dim = embedding_dim // 2
1258
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
1259
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
1260
+ emb = w.to(dtype)[:, None] * emb[None, :]
1261
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
1262
+ if embedding_dim % 2 == 1: # zero pad
1263
+ emb = torch.nn.functional.pad(emb, (0, 1))
1264
+ assert emb.shape == (w.shape[0], embedding_dim)
1265
+ return emb
1266
+
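A quick self-contained check of the embedding above: a single guidance scale w maps to a sinusoidal vector of length embedding_dim, matching what the UNet receives as w_embedding (the dimension mirrors the call in __call__ below but is otherwise arbitrary):

# illustrative sketch only
import torch

w = torch.tensor([8.0]) * 1000.0            # guidance scale, scaled as in get_w_embedding
half_dim = 256 // 2
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = w[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
assert emb.shape == (1, 256)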
1267
+ @torch.no_grad()
1268
+ def __call__(
1269
+ self,
1270
+ prompt: Union[str, List[str]] = None,
1271
+ init_image: Optional[PIL.Image.Image] = None,
1272
+ strength: Optional[float] = 0.8,
1273
+ height: Optional[int] = 512,
1274
+ width: Optional[int] = 512,
1275
+ guidance_scale: float = 7.5,
1276
+ scheduler = None,
1277
+ num_images_per_prompt: Optional[int] = 1,
1278
+ latents: Optional[torch.FloatTensor] = None,
1279
+ num_inference_steps: int = 4,
1280
+ lcm_origin_steps: int = 50,
1281
+ prompt_embeds: Optional[torch.FloatTensor] = None,
1282
+ output_type: Optional[str] = "pil",
1283
+ return_dict: bool = True,
1284
+ model: Optional[Dict[str, any]] = None,
1285
+ seed: Optional[int] = 1234567,
1286
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1287
+ callback = None,
1288
+ callback_userdata = None
1289
+ ):
1290
+
1291
+ # 1. Define call parameters
1292
+ if prompt is not None and isinstance(prompt, str):
1293
+ batch_size = 1
1294
+ elif prompt is not None and isinstance(prompt, list):
1295
+ batch_size = len(prompt)
1296
+ else:
1297
+ batch_size = prompt_embeds.shape[0]
1298
+
1299
+ if seed is not None:
1300
+ torch.manual_seed(seed)
1301
+
1302
+ #print("After Step 1: batch size is ", batch_size)
1303
+ # do_classifier_free_guidance = guidance_scale > 0.0
1304
+ # In LCM Implementation: cfg_noise = noise_cond + cfg_scale * (noise_cond - noise_uncond) , (cfg_scale > 0.0 using CFG)
1305
+
1306
+ # 2. Encode input prompt
1307
+ prompt_embeds = self._encode_prompt(
1308
+ prompt,
1309
+ num_images_per_prompt,
1310
+ prompt_embeds=prompt_embeds,
1311
+ )
1312
+ #print("After Step 2: prompt embeds is ", prompt_embeds)
1313
+ #print("After Step 2: scheduler is ", scheduler )
1314
+ # 3. Prepare timesteps
1315
+ #scheduler.set_timesteps(num_inference_steps, original_inference_steps=lcm_origin_steps)
1316
+ latent_timestep = None
1317
+ if init_image:
1318
+ scheduler.set_timesteps(num_inference_steps, original_inference_steps=lcm_origin_steps)
1319
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, scheduler)
1320
+ latent_timestep = timesteps[:1]
1321
+ else:
1322
+ scheduler.set_timesteps(num_inference_steps, original_inference_steps=lcm_origin_steps)
1323
+ timesteps = scheduler.timesteps
1324
+ #timesteps = scheduler.timesteps
1325
+ #latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
1326
+ #print("timesteps: ", latent_timestep)
1327
+
1328
+ #print("After Step 3: timesteps is ", timesteps)
1329
+
1330
+ # 4. Prepare latent variable
1331
+ num_channels_latents = 4
1332
+ latents = self.prepare_latents(
1333
+ init_image,
1334
+ latent_timestep,
1335
+ batch_size * num_images_per_prompt,
1336
+ num_channels_latents,
1337
+ height,
1338
+ width,
1339
+ prompt_embeds.dtype,
1340
+ scheduler,
1341
+ latents,
1342
+ )
1343
+
1344
+ latents = latents * scheduler.init_noise_sigma
1345
+
1346
+ #print("After Step 4: ")
1347
+ bs = batch_size * num_images_per_prompt
1348
+
1349
+ # 5. Get Guidance Scale Embedding
1350
+ w = torch.tensor(guidance_scale).repeat(bs)
1351
+ w_embedding = self.get_w_embedding(w, embedding_dim=256)
1352
+ #print("After Step 5: ")
1353
+ # 6. LCM MultiStep Sampling Loop:
1354
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1355
+ for i, t in enumerate(timesteps):
1356
+ if callback:
1357
+ callback(i+1, callback_userdata)
1358
+
1359
+ ts = torch.full((bs,), t, dtype=torch.long)
1360
+
1361
+ # model prediction (v-prediction, eps, x)
1362
+ model_pred = self.unet([latents, ts, prompt_embeds, w_embedding],share_inputs=True, share_outputs=True)[0]
1363
+
1364
+ # compute the previous noisy sample x_t -> x_t-1
1365
+ latents, denoised = scheduler.step(
1366
+ torch.from_numpy(model_pred), t, latents, return_dict=False
1367
+ )
1368
+ progress_bar.update()
1369
+
1370
+ #print("After Step 6: ")
1371
+
1372
+ vae_start = time.time()
1373
+
1374
+ if not output_type == "latent":
1375
+ image = torch.from_numpy(self.vae_decoder(denoised / 0.18215, share_inputs=True, share_outputs=True)[0])
1376
+ else:
1377
+ image = denoised
1378
+
1379
+ print("Decoder Ended: ", time.time() - vae_start)
1380
+ #post_start = time.time()
1381
+
1382
+ #if has_nsfw_concept is None:
1383
+ do_denormalize = [True] * image.shape[0]
1384
+ #else:
1385
+ # do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
1386
+
1387
+ #print ("After do_denormalize: image is ", image)
1388
+
1389
+ image = self.image_processor.postprocess(
1390
+ image, output_type=output_type, do_denormalize=do_denormalize
1391
+ )
1392
+
1393
+ return image[0]
1394
+
1395
+ class StableDiffusionEngineReferenceOnly(DiffusionPipeline):
1396
+ def __init__(
1397
+ self,
1398
+ #scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
1399
+ model="bes-dev/stable-diffusion-v1-4-openvino",
1400
+ tokenizer="openai/clip-vit-large-patch14",
1401
+ device=["CPU","CPU","CPU"]
1402
+ ):
1403
+ #self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer)
1404
+ try:
1405
+ self.tokenizer = CLIPTokenizer.from_pretrained(model,local_files_only=True)
1406
+ except:
1407
+ self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer)
1408
+ self.tokenizer.save_pretrained(model)
1409
+
1410
+ #self.scheduler = scheduler
1411
+ # models
1412
+
1413
+ self.core = Core()
1414
+ self.core.set_property({'CACHE_DIR': os.path.join(model, 'cache')}) #adding caching to reduce init time
1415
+ # text features
1416
+
1417
+ print("Text Device:",device[0])
1418
+ self.text_encoder = self.core.compile_model(os.path.join(model, "text_encoder.xml"), device[0])
1419
+
1420
+ self._text_encoder_output = self.text_encoder.output(0)
1421
+
1422
+ # diffusion
1423
+ print("unet_w Device:",device[1])
1424
+ self.unet_w = self.core.compile_model(os.path.join(model, "unet_reference_write.xml"), device[1])
1425
+ self._unet_w_output = self.unet_w.output(0)
1426
+ self.latent_shape = tuple(self.unet_w.inputs[0].shape)[1:]
1427
+
1428
+ print("unet_r Device:",device[1])
1429
+ self.unet_r = self.core.compile_model(os.path.join(model, "unet_reference_read.xml"), device[1])
1430
+ self._unet_r_output = self.unet_r.output(0)
1431
+ # decoder
1432
+ print("Vae Device:",device[2])
1433
+
1434
+ self.vae_decoder = self.core.compile_model(os.path.join(model, "vae_decoder.xml"), device[2])
1435
+
1436
+ # encoder
1437
+
1438
+ self.vae_encoder = self.core.compile_model(os.path.join(model, "vae_encoder.xml"), device[2])
1439
+
1440
+ self.init_image_shape = tuple(self.vae_encoder.inputs[0].shape)[2:]
1441
+
1442
+ self._vae_d_output = self.vae_decoder.output(0)
1443
+ self._vae_e_output = self.vae_encoder.output(0) if self.vae_encoder is not None else None
1444
+
1445
+ self.height = self.unet_w.input(0).shape[2] * 8
1446
+ self.width = self.unet_w.input(0).shape[3] * 8
1447
+
1448
+
1449
+
1450
+ def __call__(
1451
+ self,
1452
+ prompt,
1453
+ image = None,
1454
+ negative_prompt=None,
1455
+ scheduler=None,
1456
+ strength = 1.0,
1457
+ num_inference_steps = 32,
1458
+ guidance_scale = 7.5,
1459
+ eta = 0.0,
1460
+ create_gif = False,
1461
+ model = None,
1462
+ callback = None,
1463
+ callback_userdata = None
1464
+ ):
1465
+ # extract condition
1466
+ text_input = self.tokenizer(
1467
+ prompt,
1468
+ padding="max_length",
1469
+ max_length=self.tokenizer.model_max_length,
1470
+ truncation=True,
1471
+ return_tensors="np",
1472
+ )
1473
+ text_embeddings = self.text_encoder(text_input.input_ids)[self._text_encoder_output]
1474
+
1475
+
1476
+ # do classifier free guidance
1477
+ do_classifier_free_guidance = guidance_scale > 1.0
1478
+ if do_classifier_free_guidance:
1479
+
1480
+ if negative_prompt is None:
1481
+ uncond_tokens = [""]
1482
+ elif isinstance(negative_prompt, str):
1483
+ uncond_tokens = [negative_prompt]
1484
+ else:
1485
+ uncond_tokens = negative_prompt
1486
+
1487
+ tokens_uncond = self.tokenizer(
1488
+ uncond_tokens,
1489
+ padding="max_length",
1490
+ max_length=self.tokenizer.model_max_length, #truncation=True,
1491
+ return_tensors="np"
1492
+ )
1493
+ uncond_embeddings = self.text_encoder(tokens_uncond.input_ids)[self._text_encoder_output]
1494
+ text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
1495
+
1496
+ # set timesteps
1497
+ accepts_offset = "offset" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
1498
+ extra_set_kwargs = {}
1499
+
1500
+ if accepts_offset:
1501
+ extra_set_kwargs["offset"] = 1
1502
+
1503
+ scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
1504
+
1505
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, scheduler)
1506
+ latent_timestep = timesteps[:1]
1507
+
1508
+ ref_image = self.prepare_image(
1509
+ image=image,
1510
+ width=512,
1511
+ height=512,
1512
+ )
1513
+ # get the initial random noise unless the user supplied it
1514
+ latents, meta = self.prepare_latents(None, latent_timestep, scheduler)
1515
+ #ref_image_latents, _ = self.prepare_latents(init_image, latent_timestep, scheduler)
1516
+ ref_image_latents = self.ov_prepare_ref_latents(ref_image)
1517
+
1518
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
1519
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
1520
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
1521
+ # and should be between [0, 1]
1522
+ accepts_eta = "eta" in set(inspect.signature(scheduler.step).parameters.keys())
1523
+ extra_step_kwargs = {}
1524
+ if accepts_eta:
1525
+ extra_step_kwargs["eta"] = eta
1526
+ if create_gif:
1527
+ frames = []
1528
+
1529
+ for i, t in enumerate(self.progress_bar(timesteps)):
1530
+ if callback:
1531
+ callback(i, callback_userdata)
1532
+
1533
+ # expand the latents if we are doing classifier free guidance
1534
+ latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
1535
+ latent_model_input = scheduler.scale_model_input(latent_model_input, t)
1536
+
1537
+ # ref only part
1538
+ noise = randn_tensor(
1539
+ ref_image_latents.shape
1540
+ )
1541
+
1542
+ ref_xt = scheduler.add_noise(
1543
+ torch.from_numpy(ref_image_latents),
1544
+ noise,
1545
+ t.reshape(
1546
+ 1,
1547
+ ),
1548
+ ).numpy()
1549
+ ref_xt = np.concatenate([ref_xt] * 2) if do_classifier_free_guidance else ref_xt
1550
+ ref_xt = scheduler.scale_model_input(ref_xt, t)
1551
+
1552
+ # MODE = "write"
1553
+ result_w_dict = self.unet_w([
1554
+ ref_xt,
1555
+ t,
1556
+ text_embeddings
1557
+ ])
1558
+ down_0_attn0 = result_w_dict["/unet/down_blocks.0/attentions.0/transformer_blocks.0/norm1/LayerNormalization_output_0"]
1559
+ down_0_attn1 = result_w_dict["/unet/down_blocks.0/attentions.1/transformer_blocks.0/norm1/LayerNormalization_output_0"]
1560
+ down_1_attn0 = result_w_dict["/unet/down_blocks.1/attentions.0/transformer_blocks.0/norm1/LayerNormalization_output_0"]
1561
+ down_1_attn1 = result_w_dict["/unet/down_blocks.1/attentions.1/transformer_blocks.0/norm1/LayerNormalization_output_0"]
1562
+ down_2_attn0 = result_w_dict["/unet/down_blocks.2/attentions.0/transformer_blocks.0/norm1/LayerNormalization_output_0"]
1563
+ down_2_attn1 = result_w_dict["/unet/down_blocks.2/attentions.1/transformer_blocks.0/norm1/LayerNormalization_output_0"]
1564
+ mid_attn0 = result_w_dict["/unet/mid_block/attentions.0/transformer_blocks.0/norm1/LayerNormalization_output_0"]
1565
+ up_1_attn0 = result_w_dict["/unet/up_blocks.1/attentions.0/transformer_blocks.0/norm1/LayerNormalization_output_0"]
1566
+ up_1_attn1 = result_w_dict["/unet/up_blocks.1/attentions.1/transformer_blocks.0/norm1/LayerNormalization_output_0"]
1567
+ up_1_attn2 = result_w_dict["/unet/up_blocks.1/attentions.2/transformer_blocks.0/norm1/LayerNormalization_output_0"]
1568
+ up_2_attn0 = result_w_dict["/unet/up_blocks.2/attentions.0/transformer_blocks.0/norm1/LayerNormalization_output_0"]
1569
+ up_2_attn1 = result_w_dict["/unet/up_blocks.2/attentions.1/transformer_blocks.0/norm1/LayerNormalization_output_0"]
1570
+ up_2_attn2 = result_w_dict["/unet/up_blocks.2/attentions.2/transformer_blocks.0/norm1/LayerNormalization_output_0"]
1571
+ up_3_attn0 = result_w_dict["/unet/up_blocks.3/attentions.0/transformer_blocks.0/norm1/LayerNormalization_output_0"]
1572
+ up_3_attn1 = result_w_dict["/unet/up_blocks.3/attentions.1/transformer_blocks.0/norm1/LayerNormalization_output_0"]
1573
+ up_3_attn2 = result_w_dict["/unet/up_blocks.3/attentions.2/transformer_blocks.0/norm1/LayerNormalization_output_0"]
1574
+
1575
+ # MODE = "read"
1576
+ noise_pred = self.unet_r([
1577
+ latent_model_input, t, text_embeddings, down_0_attn0, down_0_attn1, down_1_attn0,
1578
+ down_1_attn1, down_2_attn0, down_2_attn1, mid_attn0, up_1_attn0, up_1_attn1, up_1_attn2,
1579
+ up_2_attn0, up_2_attn1, up_2_attn2, up_3_attn0, up_3_attn1, up_3_attn2
1580
+ ])[0]
1581
+
1582
+ # perform guidance
1583
+ if do_classifier_free_guidance:
1584
+ noise_pred_uncond, noise_pred_text = noise_pred[0], noise_pred[1]
1585
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1586
+
1587
+ # compute the previous noisy sample x_t -> x_t-1
1588
+ latents = scheduler.step(torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs)["prev_sample"].numpy()
1589
+
1590
+ if create_gif:
1591
+ frames.append(latents)
1592
+
1593
+ if callback:
1594
+ callback(num_inference_steps, callback_userdata)
1595
+
1596
+ # scale and decode the image latents with vae
1597
+
1598
+ image = self.vae_decoder(latents)[self._vae_d_output]
1599
+
1600
+ image = self.postprocess_image(image, meta)
1601
+
1602
+ if create_gif:
1603
+ gif_folder=os.path.join(model,"../../../gif")
1604
+ if not os.path.exists(gif_folder):
1605
+ os.makedirs(gif_folder)
1606
+ for i in range(0,len(frames)):
1607
+ image = self.vae_decoder(frames[i])[self._vae_d_output]
1608
+ image = self.postprocess_image(image, meta)
1609
+ output = gif_folder + "/" + str(i).zfill(3) +".png"
1610
+ cv2.imwrite(output, image)
1611
+ with open(os.path.join(gif_folder, "prompt.json"), "w") as file:
1612
+ json.dump({"prompt": prompt}, file)
1613
+ frames_image = [Image.open(image) for image in glob.glob(f"{gif_folder}/*.png")]
1614
+ frame_one = frames_image[0]
1615
+ gif_file=os.path.join(gif_folder,"stable_diffusion.gif")
1616
+ frame_one.save(gif_file, format="GIF", append_images=frames_image, save_all=True, duration=100, loop=0)
1617
+
1618
+ return image
1619
+
1620
+ def ov_prepare_ref_latents(self, refimage, vae_scaling_factor=0.18215):
1621
+ #refimage = refimage.to(device=device, dtype=dtype)
1622
+
1623
+ # encode the mask image into latents space so we can concatenate it to the latents
1624
+ moments = self.vae_encoder(refimage)[0]
1625
+ mean, logvar = np.split(moments, 2, axis=1)
1626
+ std = np.exp(logvar * 0.5)
1627
+ ref_image_latents = (mean + std * np.random.randn(*mean.shape))
1628
+ ref_image_latents = vae_scaling_factor * ref_image_latents
1629
+ #ref_image_latents = scheduler.add_noise(torch.from_numpy(ref_image_latents), torch.from_numpy(noise), latent_timestep).numpy()
1630
+
1631
+ # aligning device to prevent device errors when concating it with the latent model input
1632
+ #ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
1633
+ return ref_image_latents
1634
+
1635
+ def prepare_latents(self, image:PIL.Image.Image = None, latent_timestep:torch.Tensor = None, scheduler = LMSDiscreteScheduler):
1636
+ """
1637
+ Function for getting initial latents for starting generation
1638
+
1639
+ Parameters:
1640
+ image (PIL.Image.Image, *optional*, None):
1641
+ Input image for generation; if not provided, random noise will be used as the starting point
1642
+ latent_timestep (torch.Tensor, *optional*, None):
1643
+ Initial step predicted by the scheduler for image generation, required for mixing the latent image with noise
1644
+ Returns:
1645
+ latents (np.ndarray):
1646
+ Image encoded in latent space
1647
+ """
1648
+ latents_shape = (1, 4, self.height // 8, self.width // 8)
1649
+
1650
+ noise = np.random.randn(*latents_shape).astype(np.float32)
1651
+ if image is None:
1652
+ #print("Image is NONE")
1653
+ # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
1654
+ if isinstance(scheduler, LMSDiscreteScheduler):
1655
+
1656
+ noise = noise * scheduler.sigmas[0].numpy()
1657
+ return noise, {}
1658
+ elif isinstance(scheduler, EulerDiscreteScheduler):
1659
+
1660
+ noise = noise * scheduler.sigmas.max().numpy()
1661
+ return noise, {}
1662
+ else:
1663
+ return noise, {}
1664
+ input_image, meta = preprocess(image,self.height,self.width)
1665
+
1666
+ moments = self.vae_encoder(input_image)[self._vae_e_output]
1667
+
1668
+ mean, logvar = np.split(moments, 2, axis=1)
1669
+
1670
+ std = np.exp(logvar * 0.5)
1671
+ latents = (mean + std * np.random.randn(*mean.shape)) * 0.18215
1672
+
1673
+
1674
+ latents = scheduler.add_noise(torch.from_numpy(latents), torch.from_numpy(noise), latent_timestep).numpy()
1675
+ return latents, meta
1676
+
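The scheduler branches above scale the initial noise differently; a hedged sketch for the LMSDiscreteScheduler case, assuming diffusers (and its scipy dependency for LMS) is installed and using an illustrative 64x64 latent:

# illustrative sketch only
import numpy as np
from diffusers import LMSDiscreteScheduler

scheduler = LMSDiscreteScheduler()
scheduler.set_timesteps(32)
noise = np.random.randn(1, 4, 64, 64).astype(np.float32)
noise = noise * scheduler.sigmas[0].numpy()  # pre-scale by the largest sigma, as in prepare_latents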
1677
+ def postprocess_image(self, image:np.ndarray, meta:Dict):
1678
+ """
1679
+ Postprocessing for the decoded image. Takes the image generated by the VAE decoder, unpads it to the initial image size (if required),
1680
+ normalizes and converts it to the [0, 255] pixel range. Optionally, converts it from np.ndarray to PIL.Image format
1681
+
1682
+ Parameters:
1683
+ image (np.ndarray):
1684
+ Generated image
1685
+ meta (Dict):
1686
+ Metadata obtained on latents preparing step, can be empty
1687
+ output_type (str, *optional*, pil):
1688
+ Output format for result, can be pil or numpy
1689
+ Returns:
1690
+ image (List of np.ndarray or PIL.Image.Image):
1691
+ Postprocessed images
1692
+
1693
+ if "src_height" in meta:
1694
+ orig_height, orig_width = meta["src_height"], meta["src_width"]
1695
+ image = [cv2.resize(img, (orig_width, orig_height))
1696
+ for img in image]
1697
+
1698
+ return image
1699
+ """
1700
+ if "padding" in meta:
1701
+ pad = meta["padding"]
1702
+ (_, end_h), (_, end_w) = pad[1:3]
1703
+ h, w = image.shape[2:]
1704
+ #print("image shape",image.shape[2:])
1705
+ unpad_h = h - end_h
1706
+ unpad_w = w - end_w
1707
+ image = image[:, :, :unpad_h, :unpad_w]
1708
+ image = np.clip(image / 2 + 0.5, 0, 1)
1709
+ image = (image[0].transpose(1, 2, 0)[:, :, ::-1] * 255).astype(np.uint8)
1710
+
1711
+
1712
+
1713
+ if "src_height" in meta:
1714
+ orig_height, orig_width = meta["src_height"], meta["src_width"]
1715
+ image = cv2.resize(image, (orig_width, orig_height))
1716
+
1717
+ return image
1718
+
1719
+
1720
+ #image = (image / 2 + 0.5).clip(0, 1)
1721
+ #image = (image[0].transpose(1, 2, 0)[:, :, ::-1] * 255).astype(np.uint8)
1722
+
1723
+
1724
+ def get_timesteps(self, num_inference_steps:int, strength:float, scheduler):
1725
+ """
1726
+ Helper function for getting scheduler timesteps for generation
1727
+ In the case of image-to-image generation, it updates the number of steps according to strength
1728
+
1729
+ Parameters:
1730
+ num_inference_steps (int):
1731
+ number of inference steps for generation
1732
+ strength (float):
1733
+ value between 0.0 and 1.0 that controls the amount of noise that is added to the input image.
1734
+ Values that approach 1.0 allow for lots of variations but will also produce images that are not semantically consistent with the input.
1735
+ """
1736
+ # get the original timestep using init_timestep
1737
+
1738
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
1739
+
1740
+ t_start = max(num_inference_steps - init_timestep, 0)
1741
+ timesteps = scheduler.timesteps[t_start:]
1742
+
1743
+ return timesteps, num_inference_steps - t_start
1744
+ def prepare_image(
1745
+ self,
1746
+ image,
1747
+ width,
1748
+ height,
1749
+ do_classifier_free_guidance=False,
1750
+ guess_mode=False,
1751
+ ):
1752
+ if not isinstance(image, np.ndarray):
1753
+ if isinstance(image, PIL.Image.Image):
1754
+ image = [image]
1755
+
1756
+ if isinstance(image[0], PIL.Image.Image):
1757
+ images = []
1758
+
1759
+ for image_ in image:
1760
+ image_ = image_.convert("RGB")
1761
+ image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
1762
+ image_ = np.array(image_)
1763
+ image_ = image_[None, :]
1764
+ images.append(image_)
1765
+
1766
+ image = images
1767
+
1768
+ image = np.concatenate(image, axis=0)
1769
+ image = np.array(image).astype(np.float32) / 255.0
1770
+ image = (image - 0.5) / 0.5
1771
+ image = image.transpose(0, 3, 1, 2)
1772
+ elif isinstance(image[0], np.ndarray):
1773
+ image = np.concatenate(image, dim=0)
1774
+
1775
+ if do_classifier_free_guidance and not guess_mode:
1776
+ image = np.concatenate([image] * 2)
1777
+
1778
+ return image
1779
+
1780
+ def print_npu_turbo_art():
1781
+ random_number = random.randint(1, 3)
1782
+
1783
+ if random_number == 1:
1784
+ print(" ")
1785
+ print(" ___ ___ ___ ___ ___ ___ ")
1786
+ print(" /\ \ /\ \ /\ \ /\ \ /\ \ _____ /\ \ ")
1787
+ print(" \:\ \ /::\ \ \:\ \ ___ \:\ \ /::\ \ /::\ \ /::\ \ ")
1788
+ print(" \:\ \ /:/\:\__\ \:\ \ /\__\ \:\ \ /:/\:\__\ /:/\:\ \ /:/\:\ \ ")
1789
+ print(" _____\:\ \ /:/ /:/ / ___ \:\ \ /:/ / ___ \:\ \ /:/ /:/ / /:/ /::\__\ /:/ \:\ \ ")
1790
+ print(" /::::::::\__\ /:/_/:/ / /\ \ \:\__\ /:/__/ /\ \ \:\__\ /:/_/:/__/___ /:/_/:/\:|__| /:/__/ \:\__\ ")
1791
+ print(" \:\~~\~~\/__/ \:\/:/ / \:\ \ /:/ / /::\ \ \:\ \ /:/ / \:\/:::::/ / \:\/:/ /:/ / \:\ \ /:/ / ")
1792
+ print(" \:\ \ \::/__/ \:\ /:/ / /:/\:\ \ \:\ /:/ / \::/~~/~~~~ \::/_/:/ / \:\ /:/ / ")
1793
+ print(" \:\ \ \:\ \ \:\/:/ / \/__\:\ \ \:\/:/ / \:\~~\ \:\/:/ / \:\/:/ / ")
1794
+ print(" \:\__\ \:\__\ \::/ / \:\__\ \::/ / \:\__\ \::/ / \::/ / ")
1795
+ print(" \/__/ \/__/ \/__/ \/__/ \/__/ \/__/ \/__/ \/__/ ")
1796
+ print(" ")
1797
+ elif random_number == 2:
1798
+ print(" _ _ ____ _ _ _____ _ _ ____ ____ ___ ")
1799
+ print("| \ | | | _ \ | | | | |_ _| | | | | | _ \ | __ ) / _ \ ")
1800
+ print("| \| | | |_) | | | | | | | | | | | | |_) | | _ \ | | | |")
1801
+ print("| |\ | | __/ | |_| | | | | |_| | | _ < | |_) | | |_| |")
1802
+ print("|_| \_| |_| \___/ |_| \___/ |_| \_\ |____/ \___/ ")
1803
+ print(" ")
1804
+ else:
1805
+ print("")
1806
+ print(" ) ( ( ) ")
1807
+ print(" ( /( )\ ) * ) )\ ) ( ( /( ")
1808
+ print(" )\()) (()/( ( ` ) /( ( (()/( ( )\ )\()) ")
1809
+ print("((_)\ /(_)) )\ ( )(_)) )\ /(_)) )((_) ((_)\ ")
1810
+ print(" _((_) (_)) _ ((_) (_(_()) _ ((_) (_)) ((_)_ ((_) ")
1811
+ print("| \| | | _ \ | | | | |_ _| | | | | | _ \ | _ ) / _ \ ")
1812
+ print("| .` | | _/ | |_| | | | | |_| | | / | _ \ | (_) | ")
1813
+ print("|_|\_| |_| \___/ |_| \___/ |_|_\ |___/ \___/ ")
1814
+ print(" ")
1815
+
1816
+
1817
+
backend/pipelines/lcm.py ADDED
@@ -0,0 +1,122 @@
1
+ from constants import LCM_DEFAULT_MODEL
2
+ from diffusers import (
3
+ DiffusionPipeline,
4
+ AutoencoderTiny,
5
+ UNet2DConditionModel,
6
+ LCMScheduler,
7
+ StableDiffusionPipeline,
8
+ )
9
+ import torch
10
+ from backend.tiny_decoder import get_tiny_decoder_vae_model
11
+ from typing import Any
12
+ from diffusers import (
13
+ LCMScheduler,
14
+ StableDiffusionImg2ImgPipeline,
15
+ StableDiffusionXLImg2ImgPipeline,
16
+ AutoPipelineForText2Image,
17
+ AutoPipelineForImage2Image,
18
+ StableDiffusionControlNetPipeline,
19
+ )
20
+ import pathlib
21
+
22
+
23
+ def _get_lcm_pipeline_from_base_model(
24
+ lcm_model_id: str,
25
+ base_model_id: str,
26
+ use_local_model: bool,
27
+ ):
28
+ pipeline = None
29
+ unet = UNet2DConditionModel.from_pretrained(
30
+ lcm_model_id,
31
+ torch_dtype=torch.float32,
32
+ local_files_only=use_local_model,
33
+ resume_download=True,
34
+ )
35
+ pipeline = DiffusionPipeline.from_pretrained(
36
+ base_model_id,
37
+ unet=unet,
38
+ torch_dtype=torch.float32,
39
+ local_files_only=use_local_model,
40
+ resume_download=True,
41
+ )
42
+ pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)
43
+ return pipeline
44
+
45
+
46
+ def load_taesd(
47
+ pipeline: Any,
48
+ use_local_model: bool = False,
49
+ torch_data_type: torch.dtype = torch.float32,
50
+ ):
51
+ vae_model = get_tiny_decoder_vae_model(pipeline.__class__.__name__)
52
+ pipeline.vae = AutoencoderTiny.from_pretrained(
53
+ vae_model,
54
+ torch_dtype=torch_data_type,
55
+ local_files_only=use_local_model,
56
+ )
57
+
58
+
59
+ def get_lcm_model_pipeline(
60
+ model_id: str = LCM_DEFAULT_MODEL,
61
+ use_local_model: bool = False,
62
+ pipeline_args={},
63
+ ):
64
+ pipeline = None
65
+ if model_id == "latent-consistency/lcm-sdxl":
66
+ pipeline = _get_lcm_pipeline_from_base_model(
67
+ model_id,
68
+ "stabilityai/stable-diffusion-xl-base-1.0",
69
+ use_local_model,
70
+ )
71
+
72
+ elif model_id == "latent-consistency/lcm-ssd-1b":
73
+ pipeline = _get_lcm_pipeline_from_base_model(
74
+ model_id,
75
+ "segmind/SSD-1B",
76
+ use_local_model,
77
+ )
78
+ elif pathlib.Path(model_id).suffix == ".safetensors":
79
+ # When loading a .safetensors model, the pipeline has to be created
80
+ # with StableDiffusionPipeline() since it's the only class that
81
+ # defines the method from_single_file()
82
+ dummy_pipeline = StableDiffusionPipeline.from_single_file(
83
+ model_id,
84
+ safety_checker=None,
85
+ run_safety_checker=False,
86
+ load_safety_checker=False,
87
+ local_files_only=use_local_model,
88
+ use_safetensors=True,
89
+ )
90
+ if 'lcm' in model_id.lower():
91
+ dummy_pipeline.scheduler = LCMScheduler.from_config(dummy_pipeline.scheduler.config)
92
+
93
+ pipeline = AutoPipelineForText2Image.from_pipe(
94
+ dummy_pipeline,
95
+ **pipeline_args,
96
+ )
97
+ del dummy_pipeline
98
+ else:
99
+ # pipeline = DiffusionPipeline.from_pretrained(
100
+ pipeline = AutoPipelineForText2Image.from_pretrained(
101
+ model_id,
102
+ local_files_only=use_local_model,
103
+ **pipeline_args,
104
+ )
105
+
106
+ return pipeline
107
+
108
+
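A hedged usage sketch of the loader above (the prompt is arbitrary and the call assumes the default LCM model can be downloaded or is already cached locally):

# illustrative usage only
from backend.pipelines.lcm import get_lcm_model_pipeline, load_taesd

pipeline = get_lcm_model_pipeline()   # defaults to LCM_DEFAULT_MODEL
load_taesd(pipeline)                  # optional: swap in the tiny VAE decoder
result = pipeline(prompt="a cup of coffee", num_inference_steps=4, guidance_scale=1.0)
result.images[0].save("out.png")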
109
+ def get_image_to_image_pipeline(pipeline: Any) -> Any:
110
+ components = pipeline.components
111
+ pipeline_class = pipeline.__class__.__name__
112
+ if (
113
+ pipeline_class == "LatentConsistencyModelPipeline"
114
+ or pipeline_class == "StableDiffusionPipeline"
115
+ ):
116
+ return StableDiffusionImg2ImgPipeline(**components)
117
+ elif pipeline_class == "StableDiffusionControlNetPipeline":
118
+ return AutoPipelineForImage2Image.from_pipe(pipeline)
119
+ elif pipeline_class == "StableDiffusionXLPipeline":
120
+ return StableDiffusionXLImg2ImgPipeline(**components)
121
+ else:
122
+ raise Exception(f"Unknown pipeline {pipeline_class}")
backend/pipelines/lcm_lora.py ADDED
@@ -0,0 +1,81 @@
1
+ import pathlib
2
+ from os import path
3
+
4
+ import torch
5
+ from diffusers import (
6
+ AutoPipelineForText2Image,
7
+ LCMScheduler,
8
+ StableDiffusionPipeline,
9
+ )
10
+
11
+
12
+ def load_lcm_weights(
13
+ pipeline,
14
+ use_local_model,
15
+ lcm_lora_id,
16
+ ):
17
+ kwargs = {
18
+ "local_files_only": use_local_model,
19
+ "weight_name": "pytorch_lora_weights.safetensors",
20
+ }
21
+ pipeline.load_lora_weights(
22
+ lcm_lora_id,
23
+ **kwargs,
24
+ adapter_name="lcm",
25
+ )
26
+
27
+
28
+ def get_lcm_lora_pipeline(
29
+ base_model_id: str,
30
+ lcm_lora_id: str,
31
+ use_local_model: bool,
32
+ torch_data_type: torch.dtype,
33
+ pipeline_args={},
34
+ ):
35
+ if pathlib.Path(base_model_id).suffix == ".safetensors":
36
+ # SD 1.5 models only
37
+ # When loading a .safetensors model, the pipeline has to be created
38
+ # with StableDiffusionPipeline() since it's the only class that
39
+ # defines the method from_single_file(); afterwards a new pipeline
40
+ # is created using AutoPipelineForText2Image() for ControlNet
41
+ # support, in case ControlNet is enabled
42
+ if not path.exists(base_model_id):
43
+ raise FileNotFoundError(
44
+ f"Model file not found,Please check your model path: {base_model_id}"
45
+ )
46
+ print("Using single file Safetensors model (Supported models - SD 1.5 models)")
47
+
48
+ dummy_pipeline = StableDiffusionPipeline.from_single_file(
49
+ base_model_id,
50
+ torch_dtype=torch_data_type,
51
+ safety_checker=None,
52
+ local_files_only=use_local_model,
53
+ use_safetensors=True,
54
+ )
55
+ pipeline = AutoPipelineForText2Image.from_pipe(
56
+ dummy_pipeline,
57
+ **pipeline_args,
58
+ )
59
+ del dummy_pipeline
60
+ else:
61
+ pipeline = AutoPipelineForText2Image.from_pretrained(
62
+ base_model_id,
63
+ torch_dtype=torch_data_type,
64
+ local_files_only=use_local_model,
65
+ **pipeline_args,
66
+ )
67
+
68
+ load_lcm_weights(
69
+ pipeline,
70
+ use_local_model,
71
+ lcm_lora_id,
72
+ )
73
+ # Always fuse LCM-LoRA
74
+ # pipeline.fuse_lora()
75
+
76
+ if "lcm" in lcm_lora_id.lower() or "hypersd" in lcm_lora_id.lower():
77
+ print("LCM LoRA model detected so using recommended LCMScheduler")
78
+ pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)
79
+
80
+ # pipeline.unet.to(memory_format=torch.channels_last)
81
+ return pipeline
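A hedged usage sketch of get_lcm_lora_pipeline; the model ids are common examples from the Hugging Face Hub, not values hard-coded anywhere in this repository, and the call assumes the weights can be downloaded:

# illustrative usage only
import torch
from backend.pipelines.lcm_lora import get_lcm_lora_pipeline

pipeline = get_lcm_lora_pipeline(
    base_model_id="Lykon/dreamshaper-8",               # example base SD 1.5 model
    lcm_lora_id="latent-consistency/lcm-lora-sdv1-5",  # example LCM-LoRA
    use_local_model=False,
    torch_data_type=torch.float32,
)
image = pipeline("a watercolor landscape", num_inference_steps=4, guidance_scale=1.0).images[0]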
backend/tiny_decoder.py ADDED
@@ -0,0 +1,32 @@
1
+ from constants import (
2
+ TAESD_MODEL,
3
+ TAESDXL_MODEL,
4
+ TAESD_MODEL_OPENVINO,
5
+ TAESDXL_MODEL_OPENVINO,
6
+ )
7
+
8
+
9
+ def get_tiny_decoder_vae_model(pipeline_class) -> str:
10
+ print(f"Pipeline class : {pipeline_class}")
11
+ if (
12
+ pipeline_class == "LatentConsistencyModelPipeline"
13
+ or pipeline_class == "StableDiffusionPipeline"
14
+ or pipeline_class == "StableDiffusionImg2ImgPipeline"
15
+ or pipeline_class == "StableDiffusionControlNetPipeline"
16
+ or pipeline_class == "StableDiffusionControlNetImg2ImgPipeline"
17
+ ):
18
+ return TAESD_MODEL
19
+ elif (
20
+ pipeline_class == "StableDiffusionXLPipeline"
21
+ or pipeline_class == "StableDiffusionXLImg2ImgPipeline"
22
+ ):
23
+ return TAESDXL_MODEL
24
+ elif (
25
+ pipeline_class == "OVStableDiffusionPipeline"
26
+ or pipeline_class == "OVStableDiffusionImg2ImgPipeline"
27
+ ):
28
+ return TAESD_MODEL_OPENVINO
29
+ elif pipeline_class == "OVStableDiffusionXLPipeline":
30
+ return TAESDXL_MODEL_OPENVINO
31
+ else:
32
+ raise Exception("No valid pipeline class found!")
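A small illustrative check of the mapping above; the class names are the same strings the rest of the backend passes in via pipeline.__class__.__name__:

# illustrative usage only
from backend.tiny_decoder import get_tiny_decoder_vae_model

sd_vae = get_tiny_decoder_vae_model("StableDiffusionPipeline")      # TAESD_MODEL
sdxl_vae = get_tiny_decoder_vae_model("StableDiffusionXLPipeline")  # TAESDXL_MODEL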
backend/upscale/aura_sr.py ADDED
@@ -0,0 +1,1004 @@
1
+ # AuraSR: GAN-based Super-Resolution for real-world, a reproduction of the GigaGAN* paper. Implementation is
2
+ # based on the unofficial lucidrains/gigagan-pytorch repository. Heavily modified from there.
3
+ #
4
+ # https://mingukkang.github.io/GigaGAN/
5
+ from math import log2, ceil
6
+ from functools import partial
7
+ from typing import Any, Optional, List, Iterable
8
+
9
+ import torch
10
+ from torchvision import transforms
11
+ from PIL import Image
12
+ from torch import nn, einsum, Tensor
13
+ import torch.nn.functional as F
14
+
15
+ from einops import rearrange, repeat, reduce
16
+ from einops.layers.torch import Rearrange
17
+ from torchvision.utils import save_image
18
+ import math
19
+
20
+
21
+ def get_same_padding(size, kernel, dilation, stride):
22
+ return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2
23
+
24
+
25
+ class AdaptiveConv2DMod(nn.Module):
26
+ def __init__(
27
+ self,
28
+ dim,
29
+ dim_out,
30
+ kernel,
31
+ *,
32
+ demod=True,
33
+ stride=1,
34
+ dilation=1,
35
+ eps=1e-8,
36
+ num_conv_kernels=1, # set this to be greater than 1 for adaptive
37
+ ):
38
+ super().__init__()
39
+ self.eps = eps
40
+
41
+ self.dim_out = dim_out
42
+
43
+ self.kernel = kernel
44
+ self.stride = stride
45
+ self.dilation = dilation
46
+ self.adaptive = num_conv_kernels > 1
47
+
48
+ self.weights = nn.Parameter(
49
+ torch.randn((num_conv_kernels, dim_out, dim, kernel, kernel))
50
+ )
51
+
52
+ self.demod = demod
53
+
54
+ nn.init.kaiming_normal_(
55
+ self.weights, a=0, mode="fan_in", nonlinearity="leaky_relu"
56
+ )
57
+
58
+ def forward(
59
+ self, fmap, mod: Optional[Tensor] = None, kernel_mod: Optional[Tensor] = None
60
+ ):
61
+ """
62
+ notation
63
+
64
+ b - batch
65
+ n - convs
66
+ o - output
67
+ i - input
68
+ k - kernel
69
+ """
70
+
71
+ b, h = fmap.shape[0], fmap.shape[-2]
72
+
73
+ # account for feature map that has been expanded by the scale in the first dimension
74
+ # due to multiscale inputs and outputs
75
+
76
+ if mod.shape[0] != b:
77
+ mod = repeat(mod, "b ... -> (s b) ...", s=b // mod.shape[0])
78
+
79
+ if exists(kernel_mod):
80
+ kernel_mod_has_el = kernel_mod.numel() > 0
81
+
82
+ assert self.adaptive or not kernel_mod_has_el
83
+
84
+ if kernel_mod_has_el and kernel_mod.shape[0] != b:
85
+ kernel_mod = repeat(
86
+ kernel_mod, "b ... -> (s b) ...", s=b // kernel_mod.shape[0]
87
+ )
88
+
89
+ # prepare weights for modulation
90
+
91
+ weights = self.weights
92
+
93
+ if self.adaptive:
94
+ weights = repeat(weights, "... -> b ...", b=b)
95
+
96
+ # determine an adaptive weight and 'select' the kernel to use with softmax
97
+
98
+ assert exists(kernel_mod) and kernel_mod.numel() > 0
99
+
100
+ kernel_attn = kernel_mod.softmax(dim=-1)
101
+ kernel_attn = rearrange(kernel_attn, "b n -> b n 1 1 1 1")
102
+
103
+ weights = reduce(weights * kernel_attn, "b n ... -> b ...", "sum")
104
+
105
+ # do the modulation, demodulation, as done in stylegan2
106
+
107
+ mod = rearrange(mod, "b i -> b 1 i 1 1")
108
+
109
+ weights = weights * (mod + 1)
110
+
111
+ if self.demod:
112
+ inv_norm = (
113
+ reduce(weights**2, "b o i k1 k2 -> b o 1 1 1", "sum")
114
+ .clamp(min=self.eps)
115
+ .rsqrt()
116
+ )
117
+ weights = weights * inv_norm
118
+
119
+ fmap = rearrange(fmap, "b c h w -> 1 (b c) h w")
120
+
121
+ weights = rearrange(weights, "b o ... -> (b o) ...")
122
+
123
+ padding = get_same_padding(h, self.kernel, self.dilation, self.stride)
124
+ fmap = F.conv2d(fmap, weights, padding=padding, groups=b)
125
+
126
+ return rearrange(fmap, "1 (b o) ... -> b o ...", b=b)
127
+
128
+
129
+ class Attend(nn.Module):
130
+ def __init__(self, dropout=0.0, flash=False):
131
+ super().__init__()
132
+ self.dropout = dropout
133
+ self.attn_dropout = nn.Dropout(dropout)
134
+ self.scale = nn.Parameter(torch.randn(1))
135
+ self.flash = flash
136
+
137
+ def flash_attn(self, q, k, v):
138
+ q, k, v = map(lambda t: t.contiguous(), (q, k, v))
139
+ out = F.scaled_dot_product_attention(
140
+ q, k, v, dropout_p=self.dropout if self.training else 0.0
141
+ )
142
+ return out
143
+
144
+ def forward(self, q, k, v):
145
+ if self.flash:
146
+ return self.flash_attn(q, k, v)
147
+
148
+ scale = q.shape[-1] ** -0.5
149
+
150
+ # similarity
151
+ sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale
152
+
153
+ # attention
154
+ attn = sim.softmax(dim=-1)
155
+ attn = self.attn_dropout(attn)
156
+
157
+ # aggregate values
158
+ out = einsum("b h i j, b h j d -> b h i d", attn, v)
159
+
160
+ return out
161
+
162
+
163
+ def exists(x):
164
+ return x is not None
165
+
166
+
167
+ def default(val, d):
168
+ if exists(val):
169
+ return val
170
+ return d() if callable(d) else d
171
+
172
+
173
+ def cast_tuple(t, length=1):
174
+ if isinstance(t, tuple):
175
+ return t
176
+ return (t,) * length
177
+
178
+
179
+ def identity(t, *args, **kwargs):
180
+ return t
181
+
182
+
183
+ def is_power_of_two(n):
184
+ return log2(n).is_integer()
185
+
186
+
187
+ def null_iterator():
188
+ while True:
189
+ yield None
190
+
191
+
192
+ def Downsample(dim, dim_out=None):
193
+ return nn.Sequential(
194
+ Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),
195
+ nn.Conv2d(dim * 4, default(dim_out, dim), 1),
196
+ )
197
+
198
+
199
+ class RMSNorm(nn.Module):
200
+ def __init__(self, dim):
201
+ super().__init__()
202
+ self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
203
+ self.eps = 1e-4
204
+
205
+ def forward(self, x):
206
+ return F.normalize(x, dim=1) * self.g * (x.shape[1] ** 0.5)
207
+
208
+
209
+ # building block modules
210
+
211
+
212
+ class Block(nn.Module):
213
+ def __init__(self, dim, dim_out, groups=8, num_conv_kernels=0):
214
+ super().__init__()
215
+ self.proj = AdaptiveConv2DMod(
216
+ dim, dim_out, kernel=3, num_conv_kernels=num_conv_kernels
217
+ )
218
+ self.kernel = 3
219
+ self.dilation = 1
220
+ self.stride = 1
221
+
222
+ self.act = nn.SiLU()
223
+
224
+ def forward(self, x, conv_mods_iter: Optional[Iterable] = None):
225
+ conv_mods_iter = default(conv_mods_iter, null_iterator())
226
+
227
+ x = self.proj(x, mod=next(conv_mods_iter), kernel_mod=next(conv_mods_iter))
228
+
229
+ x = self.act(x)
230
+ return x
231
+
232
+
233
+ class ResnetBlock(nn.Module):
234
+ def __init__(
235
+ self, dim, dim_out, *, groups=8, num_conv_kernels=0, style_dims: List = []
236
+ ):
237
+ super().__init__()
238
+ style_dims.extend([dim, num_conv_kernels, dim_out, num_conv_kernels])
239
+
240
+ self.block1 = Block(
241
+ dim, dim_out, groups=groups, num_conv_kernels=num_conv_kernels
242
+ )
243
+ self.block2 = Block(
244
+ dim_out, dim_out, groups=groups, num_conv_kernels=num_conv_kernels
245
+ )
246
+ self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
247
+
248
+ def forward(self, x, conv_mods_iter: Optional[Iterable] = None):
249
+ h = self.block1(x, conv_mods_iter=conv_mods_iter)
250
+ h = self.block2(h, conv_mods_iter=conv_mods_iter)
251
+
252
+ return h + self.res_conv(x)
253
+
254
+
255
+ class LinearAttention(nn.Module):
256
+ def __init__(self, dim, heads=4, dim_head=32):
257
+ super().__init__()
258
+ self.scale = dim_head**-0.5
259
+ self.heads = heads
260
+ hidden_dim = dim_head * heads
261
+
262
+ self.norm = RMSNorm(dim)
263
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
264
+
265
+ self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), RMSNorm(dim))
266
+
267
+ def forward(self, x):
268
+ b, c, h, w = x.shape
269
+
270
+ x = self.norm(x)
271
+
272
+ qkv = self.to_qkv(x).chunk(3, dim=1)
273
+ q, k, v = map(
274
+ lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
275
+ )
276
+
277
+ q = q.softmax(dim=-2)
278
+ k = k.softmax(dim=-1)
279
+
280
+ q = q * self.scale
281
+
282
+ context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
283
+
284
+ out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
285
+ out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
286
+ return self.to_out(out)
287
+
288
+
289
+ class Attention(nn.Module):
290
+ def __init__(self, dim, heads=4, dim_head=32, flash=False):
291
+ super().__init__()
292
+ self.heads = heads
293
+ hidden_dim = dim_head * heads
294
+
295
+ self.norm = RMSNorm(dim)
296
+
297
+ self.attend = Attend(flash=flash)
298
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
299
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
300
+
301
+ def forward(self, x):
302
+ b, c, h, w = x.shape
303
+ x = self.norm(x)
304
+ qkv = self.to_qkv(x).chunk(3, dim=1)
305
+
306
+ q, k, v = map(
307
+ lambda t: rearrange(t, "b (h c) x y -> b h (x y) c", h=self.heads), qkv
308
+ )
309
+
310
+ out = self.attend(q, k, v)
311
+ out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
312
+
313
+ return self.to_out(out)
314
+
315
+
316
+ # feedforward
317
+ def FeedForward(dim, mult=4):
318
+ return nn.Sequential(
319
+ RMSNorm(dim),
320
+ nn.Conv2d(dim, dim * mult, 1),
321
+ nn.GELU(),
322
+ nn.Conv2d(dim * mult, dim, 1),
323
+ )
324
+
325
+
326
+ # transformers
327
+ class Transformer(nn.Module):
328
+ def __init__(self, dim, dim_head=64, heads=8, depth=1, flash_attn=True, ff_mult=4):
329
+ super().__init__()
330
+ self.layers = nn.ModuleList([])
331
+
332
+ for _ in range(depth):
333
+ self.layers.append(
334
+ nn.ModuleList(
335
+ [
336
+ Attention(
337
+ dim=dim, dim_head=dim_head, heads=heads, flash=flash_attn
338
+ ),
339
+ FeedForward(dim=dim, mult=ff_mult),
340
+ ]
341
+ )
342
+ )
343
+
344
+ def forward(self, x):
345
+ for attn, ff in self.layers:
346
+ x = attn(x) + x
347
+ x = ff(x) + x
348
+
349
+ return x
350
+
351
+
352
+ class LinearTransformer(nn.Module):
353
+ def __init__(self, dim, dim_head=64, heads=8, depth=1, ff_mult=4):
354
+ super().__init__()
355
+ self.layers = nn.ModuleList([])
356
+
357
+ for _ in range(depth):
358
+ self.layers.append(
359
+ nn.ModuleList(
360
+ [
361
+ LinearAttention(dim=dim, dim_head=dim_head, heads=heads),
362
+ FeedForward(dim=dim, mult=ff_mult),
363
+ ]
364
+ )
365
+ )
366
+
367
+ def forward(self, x):
368
+ for attn, ff in self.layers:
369
+ x = attn(x) + x
370
+ x = ff(x) + x
371
+
372
+ return x
373
+
374
+
375
+ class NearestNeighborhoodUpsample(nn.Module):
376
+ def __init__(self, dim, dim_out=None):
377
+ super().__init__()
378
+ dim_out = default(dim_out, dim)
379
+ self.conv = nn.Conv2d(dim, dim_out, kernel_size=3, stride=1, padding=1)
380
+
381
+ def forward(self, x):
382
+
383
+ if x.shape[0] >= 64:
384
+ x = x.contiguous()
385
+
386
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
387
+ x = self.conv(x)
388
+
389
+ return x
390
+
391
+
392
+ class EqualLinear(nn.Module):
393
+ def __init__(self, dim, dim_out, lr_mul=1, bias=True):
394
+ super().__init__()
395
+ self.weight = nn.Parameter(torch.randn(dim_out, dim))
396
+ if bias:
397
+ self.bias = nn.Parameter(torch.zeros(dim_out))
398
+
399
+ self.lr_mul = lr_mul
400
+
401
+ def forward(self, input):
402
+ return F.linear(input, self.weight * self.lr_mul, bias=self.bias * self.lr_mul)
403
+
404
+
405
+ class StyleGanNetwork(nn.Module):
406
+ def __init__(self, dim_in=128, dim_out=512, depth=8, lr_mul=0.1, dim_text_latent=0):
407
+ super().__init__()
408
+ self.dim_in = dim_in
409
+ self.dim_out = dim_out
410
+ self.dim_text_latent = dim_text_latent
411
+
412
+ layers = []
413
+ for i in range(depth):
414
+ is_first = i == 0
415
+
416
+ if is_first:
417
+ dim_in_layer = dim_in + dim_text_latent
418
+ else:
419
+ dim_in_layer = dim_out
420
+
421
+ dim_out_layer = dim_out
422
+
423
+ layers.extend(
424
+ [EqualLinear(dim_in_layer, dim_out_layer, lr_mul), nn.LeakyReLU(0.2)]
425
+ )
426
+
427
+ self.net = nn.Sequential(*layers)
428
+
429
+ def forward(self, x, text_latent=None):
430
+ x = F.normalize(x, dim=1)
431
+ if self.dim_text_latent > 0:
432
+ assert exists(text_latent)
433
+ x = torch.cat((x, text_latent), dim=-1)
434
+ return self.net(x)
435
+
436
+
437
+ class UnetUpsampler(torch.nn.Module):
438
+
439
+ def __init__(
440
+ self,
441
+ dim: int,
442
+ *,
443
+ image_size: int,
444
+ input_image_size: int,
445
+ init_dim: Optional[int] = None,
446
+ out_dim: Optional[int] = None,
447
+ style_network: Optional[dict] = None,
448
+ up_dim_mults: tuple = (1, 2, 4, 8, 16),
449
+ down_dim_mults: tuple = (4, 8, 16),
450
+ channels: int = 3,
451
+ resnet_block_groups: int = 8,
452
+ full_attn: tuple = (False, False, False, True, True),
453
+ flash_attn: bool = True,
454
+ self_attn_dim_head: int = 64,
455
+ self_attn_heads: int = 8,
456
+ attn_depths: tuple = (2, 2, 2, 2, 4),
457
+ mid_attn_depth: int = 4,
458
+ num_conv_kernels: int = 4,
459
+ resize_mode: str = "bilinear",
460
+ unconditional: bool = True,
461
+ skip_connect_scale: Optional[float] = None,
462
+ ):
463
+ super().__init__()
464
+ self.style_network = style_network = StyleGanNetwork(**style_network)
465
+ self.unconditional = unconditional
466
+ assert not (
467
+ unconditional
468
+ and exists(style_network)
469
+ and style_network.dim_text_latent > 0
470
+ )
471
+
472
+ assert is_power_of_two(image_size) and is_power_of_two(
473
+ input_image_size
474
+ ), "both output image size and input image size must be power of 2"
475
+ assert (
476
+ input_image_size < image_size
477
+ ), "input image size must be smaller than the output image size, thus upsampling"
478
+
479
+ self.image_size = image_size
480
+ self.input_image_size = input_image_size
481
+
482
+ style_embed_split_dims = []
483
+
484
+ self.channels = channels
485
+ input_channels = channels
486
+
487
+ init_dim = default(init_dim, dim)
488
+
489
+ up_dims = [init_dim, *map(lambda m: dim * m, up_dim_mults)]
490
+ init_down_dim = up_dims[len(up_dim_mults) - len(down_dim_mults)]
491
+ down_dims = [init_down_dim, *map(lambda m: dim * m, down_dim_mults)]
492
+ self.init_conv = nn.Conv2d(input_channels, init_down_dim, 7, padding=3)
493
+
494
+ up_in_out = list(zip(up_dims[:-1], up_dims[1:]))
495
+ down_in_out = list(zip(down_dims[:-1], down_dims[1:]))
496
+
497
+ block_klass = partial(
498
+ ResnetBlock,
499
+ groups=resnet_block_groups,
500
+ num_conv_kernels=num_conv_kernels,
501
+ style_dims=style_embed_split_dims,
502
+ )
503
+
504
+ FullAttention = partial(Transformer, flash_attn=flash_attn)
505
+ *_, mid_dim = up_dims
506
+
507
+ self.skip_connect_scale = default(skip_connect_scale, 2**-0.5)
508
+
509
+ self.downs = nn.ModuleList([])
510
+ self.ups = nn.ModuleList([])
511
+
512
+ block_count = 6
513
+
514
+ for ind, (
515
+ (dim_in, dim_out),
516
+ layer_full_attn,
517
+ layer_attn_depth,
518
+ ) in enumerate(zip(down_in_out, full_attn, attn_depths)):
519
+ attn_klass = FullAttention if layer_full_attn else LinearTransformer
520
+
521
+ blocks = []
522
+ for i in range(block_count):
523
+ blocks.append(block_klass(dim_in, dim_in))
524
+
525
+ self.downs.append(
526
+ nn.ModuleList(
527
+ [
528
+ nn.ModuleList(blocks),
529
+ nn.ModuleList(
530
+ [
531
+ (
532
+ attn_klass(
533
+ dim_in,
534
+ dim_head=self_attn_dim_head,
535
+ heads=self_attn_heads,
536
+ depth=layer_attn_depth,
537
+ )
538
+ if layer_full_attn
539
+ else None
540
+ ),
541
+ nn.Conv2d(
542
+ dim_in, dim_out, kernel_size=3, stride=2, padding=1
543
+ ),
544
+ ]
545
+ ),
546
+ ]
547
+ )
548
+ )
549
+
550
+ self.mid_block1 = block_klass(mid_dim, mid_dim)
551
+ self.mid_attn = FullAttention(
552
+ mid_dim,
553
+ dim_head=self_attn_dim_head,
554
+ heads=self_attn_heads,
555
+ depth=mid_attn_depth,
556
+ )
557
+ self.mid_block2 = block_klass(mid_dim, mid_dim)
558
+
559
+ *_, last_dim = up_dims
560
+
561
+ for ind, (
562
+ (dim_in, dim_out),
563
+ layer_full_attn,
564
+ layer_attn_depth,
565
+ ) in enumerate(
566
+ zip(
567
+ reversed(up_in_out),
568
+ reversed(full_attn),
569
+ reversed(attn_depths),
570
+ )
571
+ ):
572
+ attn_klass = FullAttention if layer_full_attn else LinearTransformer
573
+
574
+ blocks = []
575
+ input_dim = dim_in * 2 if ind < len(down_in_out) else dim_in
576
+ for i in range(block_count):
577
+ blocks.append(block_klass(input_dim, dim_in))
578
+
579
+ self.ups.append(
580
+ nn.ModuleList(
581
+ [
582
+ nn.ModuleList(blocks),
583
+ nn.ModuleList(
584
+ [
585
+ NearestNeighborhoodUpsample(
586
+ last_dim if ind == 0 else dim_out,
587
+ dim_in,
588
+ ),
589
+ (
590
+ attn_klass(
591
+ dim_in,
592
+ dim_head=self_attn_dim_head,
593
+ heads=self_attn_heads,
594
+ depth=layer_attn_depth,
595
+ )
596
+ if layer_full_attn
597
+ else None
598
+ ),
599
+ ]
600
+ ),
601
+ ]
602
+ )
603
+ )
604
+
605
+ self.out_dim = default(out_dim, channels)
606
+ self.final_res_block = block_klass(dim, dim)
607
+ self.final_to_rgb = nn.Conv2d(dim, channels, 1)
608
+ self.resize_mode = resize_mode
609
+ self.style_to_conv_modulations = nn.Linear(
610
+ style_network.dim_out, sum(style_embed_split_dims)
611
+ )
612
+ self.style_embed_split_dims = style_embed_split_dims
613
+
614
+ @property
615
+ def allowable_rgb_resolutions(self):
616
+ input_res_base = int(log2(self.input_image_size))
617
+ output_res_base = int(log2(self.image_size))
618
+ allowed_rgb_res_base = list(range(input_res_base, output_res_base))
619
+ return [*map(lambda p: 2**p, allowed_rgb_res_base)]
620
+
621
+ @property
622
+ def device(self):
623
+ return next(self.parameters()).device
624
+
625
+ @property
626
+ def total_params(self):
627
+ return sum([p.numel() for p in self.parameters()])
628
+
629
+ def resize_image_to(self, x, size):
630
+ return F.interpolate(x, (size, size), mode=self.resize_mode)
631
+
632
+ def forward(
633
+ self,
634
+ lowres_image: torch.Tensor,
635
+ styles: Optional[torch.Tensor] = None,
636
+ noise: Optional[torch.Tensor] = None,
637
+ global_text_tokens: Optional[torch.Tensor] = None,
638
+ return_all_rgbs: bool = False,
639
+ ):
640
+ x = lowres_image
641
+
642
+ noise_scale = 0.001 # Adjust the scale of the noise as needed
643
+ noise_aug = torch.randn_like(x) * noise_scale
644
+ x = x + noise_aug
645
+ x = x.clamp(0, 1)
646
+
647
+ shape = x.shape
648
+ batch_size = shape[0]
649
+
650
+ assert shape[-2:] == ((self.input_image_size,) * 2)
651
+
652
+ # styles
653
+ if not exists(styles):
654
+ assert exists(self.style_network)
655
+
656
+ noise = default(
657
+ noise,
658
+ torch.randn(
659
+ (batch_size, self.style_network.dim_in), device=self.device
660
+ ),
661
+ )
662
+ styles = self.style_network(noise, global_text_tokens)
663
+
664
+ # project styles to conv modulations
665
+ conv_mods = self.style_to_conv_modulations(styles)
666
+ conv_mods = conv_mods.split(self.style_embed_split_dims, dim=-1)
667
+ conv_mods = iter(conv_mods)
668
+
669
+ x = self.init_conv(x)
670
+
671
+ h = []
672
+ for blocks, (attn, downsample) in self.downs:
673
+ for block in blocks:
674
+ x = block(x, conv_mods_iter=conv_mods)
675
+ h.append(x)
676
+
677
+ if attn is not None:
678
+ x = attn(x)
679
+
680
+ x = downsample(x)
681
+
682
+ x = self.mid_block1(x, conv_mods_iter=conv_mods)
683
+ x = self.mid_attn(x)
684
+ x = self.mid_block2(x, conv_mods_iter=conv_mods)
685
+
686
+ for (
687
+ blocks,
688
+ (
689
+ upsample,
690
+ attn,
691
+ ),
692
+ ) in self.ups:
693
+ x = upsample(x)
694
+ for block in blocks:
695
+ if h != []:
696
+ res = h.pop()
697
+ res = res * self.skip_connect_scale
698
+ x = torch.cat((x, res), dim=1)
699
+
700
+ x = block(x, conv_mods_iter=conv_mods)
701
+
702
+ if attn is not None:
703
+ x = attn(x)
704
+
705
+ x = self.final_res_block(x, conv_mods_iter=conv_mods)
706
+ rgb = self.final_to_rgb(x)
707
+
708
+ if not return_all_rgbs:
709
+ return rgb
710
+
711
+ return rgb, []
712
+
713
+
714
+ def tile_image(image, chunk_size=64):
715
+ c, h, w = image.shape
716
+ h_chunks = ceil(h / chunk_size)
717
+ w_chunks = ceil(w / chunk_size)
718
+ tiles = []
719
+ for i in range(h_chunks):
720
+ for j in range(w_chunks):
721
+ tile = image[
722
+ :,
723
+ i * chunk_size : (i + 1) * chunk_size,
724
+ j * chunk_size : (j + 1) * chunk_size,
725
+ ]
726
+ tiles.append(tile)
727
+ return tiles, h_chunks, w_chunks
728
+
729
+
730
+ # This helps create a checkerboard pattern with some edge blending
731
+ def create_checkerboard_weights(tile_size):
732
+ x = torch.linspace(-1, 1, tile_size)
733
+ y = torch.linspace(-1, 1, tile_size)
734
+
735
+ x, y = torch.meshgrid(x, y, indexing="ij")
736
+ d = torch.sqrt(x * x + y * y)
737
+ sigma, mu = 0.5, 0.0
738
+ weights = torch.exp(-((d - mu) ** 2 / (2.0 * sigma**2)))
739
+
740
+ # saturate the values to make sure we get high weights in the center
741
+ weights = weights**8
742
+
743
+ return weights / weights.max() # Normalize to [0, 1]
744
+
745
+
746
+ def repeat_weights(weights, image_size):
747
+ tile_size = weights.shape[0]
748
+ repeats = (
749
+ math.ceil(image_size[0] / tile_size),
750
+ math.ceil(image_size[1] / tile_size),
751
+ )
752
+ return weights.repeat(repeats)[: image_size[0], : image_size[1]]
753
+
754
+
755
+ def create_offset_weights(weights, image_size):
756
+ tile_size = weights.shape[0]
757
+ offset = tile_size // 2
758
+ full_weights = repeat_weights(
759
+ weights, (image_size[0] + offset, image_size[1] + offset)
760
+ )
761
+ return full_weights[offset:, offset:]
762
+
763
+
764
+ def merge_tiles(tiles, h_chunks, w_chunks, chunk_size=64):
765
+ # Determine the shape of the output tensor
766
+ c = tiles[0].shape[0]
767
+ h = h_chunks * chunk_size
768
+ w = w_chunks * chunk_size
769
+
770
+ # Create an empty tensor to hold the merged image
771
+ merged = torch.zeros((c, h, w), dtype=tiles[0].dtype)
772
+
773
+ # Iterate over the tiles and place them in the correct position
774
+ for idx, tile in enumerate(tiles):
775
+ i = idx // w_chunks
776
+ j = idx % w_chunks
777
+
778
+ h_start = i * chunk_size
779
+ w_start = j * chunk_size
780
+
781
+ tile_h, tile_w = tile.shape[1:]
782
+ merged[:, h_start : h_start + tile_h, w_start : w_start + tile_w] = tile
783
+
784
+ return merged
785
+
786
+
787
+ class AuraSR:
788
+ def __init__(self, config: dict[str, Any], device: str = "cuda"):
789
+ self.upsampler = UnetUpsampler(**config).to(device)
790
+ self.input_image_size = config["input_image_size"]
791
+
792
+ @classmethod
793
+ def from_pretrained(
794
+ cls,
795
+ model_id: str = "fal-ai/AuraSR",
796
+ use_safetensors: bool = True,
797
+ device: str = "cuda",
798
+ ):
799
+ import json
800
+ import torch
801
+ from pathlib import Path
802
+ from huggingface_hub import snapshot_download
803
+
804
+ # Check if model_id is a local file
805
+ if Path(model_id).is_file():
806
+ local_file = Path(model_id)
807
+ if local_file.suffix == ".safetensors":
808
+ use_safetensors = True
809
+ elif local_file.suffix == ".ckpt":
810
+ use_safetensors = False
811
+ else:
812
+ raise ValueError(
813
+ f"Unsupported file format: {local_file.suffix}. Please use .safetensors or .ckpt files."
814
+ )
815
+
816
+ # For local files, we need to provide the config separately
817
+ config_path = local_file.with_name("config.json")
818
+ if not config_path.exists():
819
+ raise FileNotFoundError(
820
+ f"Config file not found: {config_path}. "
821
+ f"When loading from a local file, ensure that 'config.json' "
822
+ f"is present in the same directory as '{local_file.name}'. "
823
+ f"If you're trying to load a model from Hugging Face, "
824
+ f"please provide the model ID instead of a file path."
825
+ )
826
+
827
+ config = json.loads(config_path.read_text())
828
+ hf_model_path = local_file.parent
829
+ else:
830
+ hf_model_path = Path(
831
+ snapshot_download(model_id, ignore_patterns=["*.ckpt"])
832
+ )
833
+ config = json.loads((hf_model_path / "config.json").read_text())
834
+
835
+ model = cls(config, device)
836
+
837
+ if use_safetensors:
838
+ try:
839
+ from safetensors.torch import load_file
840
+
841
+ checkpoint = load_file(
842
+ hf_model_path / "model.safetensors"
843
+ if not Path(model_id).is_file()
844
+ else model_id
845
+ )
846
+ except ImportError:
847
+ raise ImportError(
848
+ "The safetensors library is not installed. "
849
+ "Please install it with `pip install safetensors` "
850
+ "or use `use_safetensors=False` to load the model with PyTorch."
851
+ )
852
+ else:
853
+ checkpoint = torch.load(
854
+ hf_model_path / "model.ckpt"
855
+ if not Path(model_id).is_file()
856
+ else model_id
857
+ )
858
+
859
+ model.upsampler.load_state_dict(checkpoint, strict=True)
860
+ return model
861
+
862
+ @torch.no_grad()
863
+ def upscale_4x(self, image: Image.Image, max_batch_size=8) -> Image.Image:
864
+ tensor_transform = transforms.ToTensor()
865
+ device = self.upsampler.device
866
+
867
+ image_tensor = tensor_transform(image).unsqueeze(0)
868
+ _, _, h, w = image_tensor.shape
869
+ pad_h = (
870
+ self.input_image_size - h % self.input_image_size
871
+ ) % self.input_image_size
872
+ pad_w = (
873
+ self.input_image_size - w % self.input_image_size
874
+ ) % self.input_image_size
875
+
876
+ # Pad the image
877
+ image_tensor = torch.nn.functional.pad(
878
+ image_tensor, (0, pad_w, 0, pad_h), mode="reflect"
879
+ ).squeeze(0)
880
+ tiles, h_chunks, w_chunks = tile_image(image_tensor, self.input_image_size)
881
+
882
+ # Batch processing of tiles
883
+ num_tiles = len(tiles)
884
+ batches = [
885
+ tiles[i : i + max_batch_size] for i in range(0, num_tiles, max_batch_size)
886
+ ]
887
+ reconstructed_tiles = []
888
+
889
+ for batch in batches:
890
+ model_input = torch.stack(batch).to(device)
891
+ generator_output = self.upsampler(
892
+ lowres_image=model_input,
893
+ noise=torch.randn(model_input.shape[0], 128, device=device),
894
+ )
895
+ reconstructed_tiles.extend(
896
+ list(generator_output.clamp_(0, 1).detach().cpu())
897
+ )
898
+
899
+ merged_tensor = merge_tiles(
900
+ reconstructed_tiles, h_chunks, w_chunks, self.input_image_size * 4
901
+ )
902
+ unpadded = merged_tensor[:, : h * 4, : w * 4]
903
+
904
+ to_pil = transforms.ToPILImage()
905
+ return to_pil(unpadded)
906
+
907
+ # Tiled 4x upscaling with overlapping tiles to reduce seam artifacts
908
+ # weight_type options are 'checkboard' and 'constant'
909
+ @torch.no_grad()
910
+ def upscale_4x_overlapped(self, image, max_batch_size=8, weight_type="checkboard"):
911
+ tensor_transform = transforms.ToTensor()
912
+ device = self.upsampler.device
913
+
914
+ image_tensor = tensor_transform(image).unsqueeze(0)
915
+ _, _, h, w = image_tensor.shape
916
+
917
+ # Calculate paddings
918
+ pad_h = (
919
+ self.input_image_size - h % self.input_image_size
920
+ ) % self.input_image_size
921
+ pad_w = (
922
+ self.input_image_size - w % self.input_image_size
923
+ ) % self.input_image_size
924
+
925
+ # Pad the image
926
+ image_tensor = torch.nn.functional.pad(
927
+ image_tensor, (0, pad_w, 0, pad_h), mode="reflect"
928
+ ).squeeze(0)
929
+
930
+ # Function to process tiles
931
+ def process_tiles(tiles, h_chunks, w_chunks):
932
+ num_tiles = len(tiles)
933
+ batches = [
934
+ tiles[i : i + max_batch_size]
935
+ for i in range(0, num_tiles, max_batch_size)
936
+ ]
937
+ reconstructed_tiles = []
938
+
939
+ for batch in batches:
940
+ model_input = torch.stack(batch).to(device)
941
+ generator_output = self.upsampler(
942
+ lowres_image=model_input,
943
+ noise=torch.randn(model_input.shape[0], 128, device=device),
944
+ )
945
+ reconstructed_tiles.extend(
946
+ list(generator_output.clamp_(0, 1).detach().cpu())
947
+ )
948
+
949
+ return merge_tiles(
950
+ reconstructed_tiles, h_chunks, w_chunks, self.input_image_size * 4
951
+ )
952
+
953
+ # First pass
954
+ tiles1, h_chunks1, w_chunks1 = tile_image(image_tensor, self.input_image_size)
955
+ result1 = process_tiles(tiles1, h_chunks1, w_chunks1)
956
+
957
+ # Second pass with offset
958
+ offset = self.input_image_size // 2
959
+ image_tensor_offset = torch.nn.functional.pad(
960
+ image_tensor, (offset, offset, offset, offset), mode="reflect"
961
+ ).squeeze(0)
962
+
963
+ tiles2, h_chunks2, w_chunks2 = tile_image(
964
+ image_tensor_offset, self.input_image_size
965
+ )
966
+ result2 = process_tiles(tiles2, h_chunks2, w_chunks2)
967
+
968
+ # unpad
969
+ offset_4x = offset * 4
970
+ result2_interior = result2[:, offset_4x:-offset_4x, offset_4x:-offset_4x]
971
+
972
+ if weight_type == "checkboard":
973
+ weight_tile = create_checkerboard_weights(self.input_image_size * 4)
974
+
975
+ weight_shape = result2_interior.shape[1:]
976
+ weights_1 = create_offset_weights(weight_tile, weight_shape)
977
+ weights_2 = repeat_weights(weight_tile, weight_shape)
978
+
979
+ normalizer = weights_1 + weights_2
980
+ weights_1 = weights_1 / normalizer
981
+ weights_2 = weights_2 / normalizer
982
+
983
+ weights_1 = weights_1.unsqueeze(0).repeat(3, 1, 1)
984
+ weights_2 = weights_2.unsqueeze(0).repeat(3, 1, 1)
985
+ elif weight_type == "constant":
986
+ weights_1 = torch.ones_like(result2_interior) * 0.5
987
+ weights_2 = weights_1
988
+ else:
989
+ raise ValueError(
990
+ "weight_type should be either 'checkboard' or 'constant' but got",
991
+ weight_type,
992
+ )
993
+
994
+ result1 = result1 * weights_2
995
+ result2 = result2_interior * weights_1
996
+
997
+ # Average the overlapping region
998
+ result1 = result1 + result2
999
+
1000
+ # Remove padding
1001
+ unpadded = result1[:, : h * 4, : w * 4]
1002
+
1003
+ to_pil = transforms.ToPILImage()
1004
+ return to_pil(unpadded)
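A quick usage note for the AuraSR wrapper above: from_pretrained pulls the weights and config.json from the Hugging Face Hub (or a local .safetensors/.ckpt file), and both upscale_4x methods return a PIL image at 4x the input resolution. A minimal sketch, assuming the fal/AuraSR-v2 checkpoint used elsewhere in this commit and a placeholder input.png:

from PIL import Image
from backend.upscale.aura_sr import AuraSR

# Load the GAN upsampler; device can be "cpu" or "cuda"
aura_sr = AuraSR.from_pretrained("fal/AuraSR-v2", device="cpu")

image = Image.open("input.png").convert("RGB")  # placeholder path

# Plain tiled 4x upscale (tile seams may be visible on some images)
upscaled = aura_sr.upscale_4x(image)

# Overlapped variant: blends two offset tilings with checkerboard weights
upscaled_smooth = aura_sr.upscale_4x_overlapped(image, weight_type="checkboard")
upscaled_smooth.save("output_4x.png")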
backend/upscale/aura_sr_upscale.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from backend.upscale.aura_sr import AuraSR
2
+ from PIL import Image
3
+
4
+
5
+ def upscale_aura_sr(image_path: str):
6
+
7
+ aura_sr = AuraSR.from_pretrained("fal/AuraSR-v2", device="cpu")
8
+ image_in = Image.open(image_path) # .resize((256, 256))
9
+ return aura_sr.upscale_4x(image_in)
backend/upscale/edsr_upscale_onnx.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import onnxruntime
3
+ from huggingface_hub import hf_hub_download
4
+ from PIL import Image
5
+
6
+
7
+ def upscale_edsr_2x(image_path: str):
8
+ input_image = Image.open(image_path).convert("RGB")
9
+ input_image = np.array(input_image).astype("float32")
10
+ input_image = np.transpose(input_image, (2, 0, 1))
11
+ img_arr = np.expand_dims(input_image, axis=0)
12
+
13
+ if np.max(img_arr) > 256: # 16-bit image
14
+ max_range = 65535
15
+ else:
16
+ max_range = 255.0
17
+ img = img_arr / max_range
18
+
19
+ model_path = hf_hub_download(
20
+ repo_id="rupeshs/edsr-onnx",
21
+ filename="edsr_onnxsim_2x.onnx",
22
+ )
23
+ sess = onnxruntime.InferenceSession(model_path)
24
+
25
+ input_name = sess.get_inputs()[0].name
26
+ output_name = sess.get_outputs()[0].name
27
+ output = sess.run(
28
+ [output_name],
29
+ {input_name: img},
30
+ )[0]
31
+
32
+ result = output.squeeze()
33
+ result = result.clip(0, 1)
34
+ image_array = np.transpose(result, (1, 2, 0))
35
+ image_array = np.uint8(image_array * 255)
36
+ upscaled_image = Image.fromarray(image_array)
37
+ return upscaled_image
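upscale_edsr_2x above downloads the EDSR ONNX model from the rupeshs/edsr-onnx repo on first use and returns a PIL image at twice the input resolution. A minimal sketch, with placeholder file names:

from backend.upscale.edsr_upscale_onnx import upscale_edsr_2x

upscaled = upscale_edsr_2x("photo.png")  # placeholder input path
print(upscaled.size)  # width and height are doubled
upscaled.save("photo_2x.png")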
backend/upscale/tiled_upscale.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import math
3
+ import logging
4
+ from PIL import Image, ImageDraw, ImageFilter
5
+ from backend.models.lcmdiffusion_setting import DiffusionTask
6
+ from context import Context
7
+ from constants import DEVICE
8
+
9
+
10
+ def generate_upscaled_image(
11
+ config,
12
+ input_path=None,
13
+ strength=0.3,
14
+ scale_factor=2.0,
15
+ tile_overlap=16,
16
+ upscale_settings=None,
17
+ context: Context = None,
18
+ output_path=None,
19
+ image_format="PNG",
20
+ ):
21
+ if config is None or (
22
+ (input_path is None or input_path == "") and upscale_settings is None
23
+ ):
24
+ logging.error("Wrong arguments in tiled upscale function call!")
25
+ return
26
+
27
+ # Use the upscale_settings dict if provided; otherwise, build the
28
+ # upscale_settings dict using the function arguments and default values
29
+ if upscale_settings == None:
30
+ upscale_settings = {
31
+ "source_file": input_path,
32
+ "target_file": None,
33
+ "output_format": image_format,
34
+ "strength": strength,
35
+ "scale_factor": scale_factor,
36
+ "prompt": config.lcm_diffusion_setting.prompt,
37
+ "tile_overlap": tile_overlap,
38
+ "tile_size": 256,
39
+ "tiles": [],
40
+ }
41
+ source_image = Image.open(input_path) # PIL image
42
+ else:
43
+ source_image = Image.open(upscale_settings["source_file"])
44
+
45
+ upscale_settings["source_image"] = source_image
46
+
47
+ if upscale_settings["target_file"]:
48
+ result = Image.open(upscale_settings["target_file"])
49
+ else:
50
+ result = Image.new(
51
+ mode="RGBA",
52
+ size=(
53
+ source_image.size[0] * int(upscale_settings["scale_factor"]),
54
+ source_image.size[1] * int(upscale_settings["scale_factor"]),
55
+ ),
56
+ color=(0, 0, 0, 0),
57
+ )
58
+ upscale_settings["target_image"] = result
59
+
60
+ # If the custom tile definition array 'tiles' is empty, proceed with the
61
+ # default tiled upscale task by defining all the possible image tiles; note
62
+ # that the actual tile size is 'tile_size' + 'tile_overlap' and the target
63
+ # image width and height are no longer constrained to multiples of 256 but
64
+ # are instead multiples of the actual tile size
65
+ if len(upscale_settings["tiles"]) == 0:
66
+ tile_size = upscale_settings["tile_size"]
67
+ scale_factor = upscale_settings["scale_factor"]
68
+ tile_overlap = upscale_settings["tile_overlap"]
69
+ total_cols = math.ceil(
70
+ source_image.size[0] / tile_size
71
+ ) # Image width / tile size
72
+ total_rows = math.ceil(
73
+ source_image.size[1] / tile_size
74
+ ) # Image height / tile size
75
+ for y in range(0, total_rows):
76
+ y_offset = tile_overlap if y > 0 else 0 # Tile mask offset
77
+ for x in range(0, total_cols):
78
+ x_offset = tile_overlap if x > 0 else 0 # Tile mask offset
79
+ x1 = x * tile_size
80
+ y1 = y * tile_size
81
+ w = tile_size + (tile_overlap if x < total_cols - 1 else 0)
82
+ h = tile_size + (tile_overlap if y < total_rows - 1 else 0)
83
+ mask_box = ( # Default tile mask box definition
84
+ x_offset,
85
+ y_offset,
86
+ int(w * scale_factor),
87
+ int(h * scale_factor),
88
+ )
89
+ upscale_settings["tiles"].append(
90
+ {
91
+ "x": x1,
92
+ "y": y1,
93
+ "w": w,
94
+ "h": h,
95
+ "mask_box": mask_box,
96
+ "prompt": upscale_settings["prompt"], # Use top level prompt if available
97
+ "scale_factor": scale_factor,
98
+ }
99
+ )
100
+
101
+ # Generate the output image tiles
102
+ for i in range(0, len(upscale_settings["tiles"])):
103
+ generate_upscaled_tile(
104
+ config,
105
+ i,
106
+ upscale_settings,
107
+ context=context,
108
+ )
109
+
110
+ # Save completed upscaled image
111
+ if upscale_settings["output_format"].upper() == "JPEG":
112
+ result_rgb = result.convert("RGB")
113
+ result.close()
114
+ result = result_rgb
115
+ result.save(output_path)
116
+ result.close()
117
+ source_image.close()
118
+ return
119
+
120
+
121
+ def get_current_tile(
122
+ config,
123
+ context,
124
+ strength,
125
+ ):
126
+ config.lcm_diffusion_setting.strength = strength
127
+ config.lcm_diffusion_setting.diffusion_task = DiffusionTask.image_to_image.value
128
+ if (
129
+ config.lcm_diffusion_setting.use_tiny_auto_encoder
130
+ and config.lcm_diffusion_setting.use_openvino
131
+ ):
132
+ config.lcm_diffusion_setting.use_tiny_auto_encoder = False
133
+ current_tile = context.generate_text_to_image(
134
+ settings=config,
135
+ reshape=True,
136
+ device=DEVICE,
137
+ save_config=False,
138
+ )[0]
139
+ return current_tile
140
+
141
+
142
+ # Generates a single tile from the source image as defined in the
143
+ # upscale_settings["tiles"] array with the corresponding index and pastes the
144
+ # generated tile into the target image using the corresponding mask and scale
145
+ # factor; note that scale factor for the target image and the individual tiles
146
+ # can be different; this function adjusts the scale factors as needed
147
+ def generate_upscaled_tile(
148
+ config,
149
+ index,
150
+ upscale_settings,
151
+ context: Context = None,
152
+ ):
153
+ if config == None or upscale_settings == None:
154
+ logging.error("Wrong arguments in tile creation function call!")
155
+ return
156
+
157
+ x = upscale_settings["tiles"][index]["x"]
158
+ y = upscale_settings["tiles"][index]["y"]
159
+ w = upscale_settings["tiles"][index]["w"]
160
+ h = upscale_settings["tiles"][index]["h"]
161
+ tile_prompt = upscale_settings["tiles"][index]["prompt"]
162
+ scale_factor = upscale_settings["scale_factor"]
163
+ tile_scale_factor = upscale_settings["tiles"][index]["scale_factor"]
164
+ target_width = int(w * tile_scale_factor)
165
+ target_height = int(h * tile_scale_factor)
166
+ strength = upscale_settings["strength"]
167
+ source_image = upscale_settings["source_image"]
168
+ target_image = upscale_settings["target_image"]
169
+ mask_image = generate_tile_mask(config, index, upscale_settings)
170
+
171
+ config.lcm_diffusion_setting.number_of_images = 1
172
+ config.lcm_diffusion_setting.prompt = tile_prompt
173
+ config.lcm_diffusion_setting.image_width = target_width
174
+ config.lcm_diffusion_setting.image_height = target_height
175
+ config.lcm_diffusion_setting.init_image = source_image.crop((x, y, x + w, y + h))
176
+
177
+ current_tile = None
178
+ print(f"[SD Upscale] Generating tile {index + 1}/{len(upscale_settings['tiles'])} ")
179
+ if tile_prompt == None or tile_prompt == "":
180
+ config.lcm_diffusion_setting.prompt = ""
181
+ config.lcm_diffusion_setting.negative_prompt = ""
182
+ current_tile = get_current_tile(config, context, strength)
183
+ else:
184
+ # Attempt to use img2img with low denoising strength to
185
+ # generate the tiles with the extra aid of a prompt
186
+ # context = get_context(InterfaceType.CLI)
187
+ current_tile = get_current_tile(config, context, strength)
188
+
189
+ if math.isclose(scale_factor, tile_scale_factor):
190
+ target_image.paste(
191
+ current_tile, (int(x * scale_factor), int(y * scale_factor)), mask_image
192
+ )
193
+ else:
194
+ target_image.paste(
195
+ current_tile.resize((int(w * scale_factor), int(h * scale_factor))),
196
+ (int(x * scale_factor), int(y * scale_factor)),
197
+ mask_image.resize((int(w * scale_factor), int(h * scale_factor))),
198
+ )
199
+ mask_image.close()
200
+ current_tile.close()
201
+ config.lcm_diffusion_setting.init_image.close()
202
+
203
+
204
+ # Generate tile mask using the box definition in the upscale_settings["tiles"]
205
+ # array with the corresponding index; note that tile masks for the default
206
+ # tiled upscale task can be reused but that would complicate the code, so
207
+ # new tile masks are instead created for each tile
208
+ def generate_tile_mask(
209
+ config,
210
+ index,
211
+ upscale_settings,
212
+ ):
213
+ scale_factor = upscale_settings["scale_factor"]
214
+ tile_overlap = upscale_settings["tile_overlap"]
215
+ tile_scale_factor = upscale_settings["tiles"][index]["scale_factor"]
216
+ w = int(upscale_settings["tiles"][index]["w"] * tile_scale_factor)
217
+ h = int(upscale_settings["tiles"][index]["h"] * tile_scale_factor)
218
+ # The Stable Diffusion pipeline automatically adjusts the output size
219
+ # to multiples of 8 pixels; the mask must be created with the same
220
+ # size as the output tile
221
+ w = w - (w % 8)
222
+ h = h - (h % 8)
223
+ mask_box = upscale_settings["tiles"][index]["mask_box"]
224
+ if mask_box == None:
225
+ # Build a default solid mask with soft/transparent edges
226
+ mask_box = (
227
+ tile_overlap,
228
+ tile_overlap,
229
+ w - tile_overlap,
230
+ h - tile_overlap,
231
+ )
232
+ mask_image = Image.new(mode="RGBA", size=(w, h), color=(0, 0, 0, 0))
233
+ mask_draw = ImageDraw.Draw(mask_image)
234
+ mask_draw.rectangle(tuple(mask_box), fill=(0, 0, 0))
235
+ mask_blur = mask_image.filter(ImageFilter.BoxBlur(tile_overlap - 1))
236
+ mask_image.close()
237
+ return mask_blur
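The tiled upscaler above is driven entirely by the upscale_settings dict: when the tiles list is empty, generate_upscaled_image covers the whole source image with tile_size x tile_size tiles (plus tile_overlap), regenerates each tile with img2img at the given strength, and pastes it into the target image through a blurred mask. A sketch of the dict shape it expects, with placeholder paths and prompt; config and context are assumed to be the application's Settings and Context objects:

custom_settings = {
    "source_file": "input.png",   # placeholder path
    "target_file": None,          # set to an existing upscale to patch only selected tiles
    "output_format": "PNG",
    "strength": 0.3,              # img2img denoising strength per tile
    "scale_factor": 2.0,
    "prompt": "a detailed photo", # placeholder; reused for every auto-generated tile
    "tile_overlap": 16,
    "tile_size": 256,
    "tiles": [],                  # empty list -> tiles are generated automatically
}

generate_upscaled_image(
    config,
    input_path="input.png",            # placeholder path
    upscale_settings=custom_settings,  # takes precedence over the bare arguments
    context=context,
    output_path="output_2x.png",       # placeholder path
)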
backend/upscale/upscaler.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from backend.models.lcmdiffusion_setting import DiffusionTask
2
+ from backend.models.upscale import UpscaleMode
3
+ from backend.upscale.edsr_upscale_onnx import upscale_edsr_2x
4
+ from backend.upscale.aura_sr_upscale import upscale_aura_sr
5
+ from backend.upscale.tiled_upscale import generate_upscaled_image
6
+ from context import Context
7
+ from PIL import Image
8
+ from state import get_settings
9
+
10
+
11
+ config = get_settings()
12
+
13
+
14
+ def upscale_image(
15
+ context: Context,
16
+ src_image_path: str,
17
+ dst_image_path: str,
18
+ scale_factor: int = 2,
19
+ upscale_mode: UpscaleMode = UpscaleMode.normal.value,
20
+ strength: float = 0.1,
21
+ ):
22
+ if upscale_mode == UpscaleMode.normal.value:
23
+ upscaled_img = upscale_edsr_2x(src_image_path)
24
+ upscaled_img.save(dst_image_path)
25
+ print(f"Upscaled image saved {dst_image_path}")
26
+ elif upscale_mode == UpscaleMode.aura_sr.value:
27
+ upscaled_img = upscale_aura_sr(src_image_path)
28
+ upscaled_img.save(dst_image_path)
29
+ print(f"Upscaled image saved {dst_image_path}")
30
+ else:
31
+ config.settings.lcm_diffusion_setting.strength = (
32
+ 0.3 if config.settings.lcm_diffusion_setting.use_openvino else strength
33
+ )
34
+ config.settings.lcm_diffusion_setting.diffusion_task = (
35
+ DiffusionTask.image_to_image.value
36
+ )
37
+
38
+ generate_upscaled_image(
39
+ config.settings,
40
+ src_image_path,
41
+ config.settings.lcm_diffusion_setting.strength,
42
+ upscale_settings=None,
43
+ context=context,
44
+ tile_overlap=(
45
+ 32 if config.settings.lcm_diffusion_setting.use_openvino else 16
46
+ ),
47
+ output_path=dst_image_path,
48
+ image_format=config.settings.generated_images.format,
49
+ )
50
+ print(f"Upscaled image saved {dst_image_path}")
51
+
52
+ return [Image.open(dst_image_path)]
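upscale_image above dispatches on upscale_mode: UpscaleMode.normal runs the 2x EDSR ONNX path, UpscaleMode.aura_sr runs the 4x AuraSR GAN, and anything else falls through to the SD tiled upscale. A minimal sketch, assuming the application settings have already been initialized (upscaler.py calls get_settings() at import time) and a CLI-style Context as referenced elsewhere in this commit; paths are placeholders:

from backend.models.upscale import UpscaleMode
from backend.upscale.upscaler import upscale_image
from context import Context
from models.interface_types import InterfaceType

context = Context(InterfaceType.CLI)
results = upscale_image(
    context,
    src_image_path="photo.png",           # placeholder
    dst_image_path="photo_upscaled.png",  # placeholder
    upscale_mode=UpscaleMode.aura_sr.value,
)
results[0].show()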
configs/lcm-lora-models.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ latent-consistency/lcm-lora-sdv1-5
2
+ latent-consistency/lcm-lora-sdxl
3
+ latent-consistency/lcm-lora-ssd-1b
4
+ rupeshs/hypersd-sd1-5-1-step-lora
configs/lcm-models.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ stabilityai/sd-turbo
2
+ rupeshs/sdxs-512-0.9-orig-vae
3
+ rupeshs/hyper-sd-sdxl-1-step
4
+ rupeshs/SDXL-Lightning-2steps
5
+ stabilityai/sdxl-turbo
6
+ SimianLuo/LCM_Dreamshaper_v7
7
+ latent-consistency/lcm-sdxl
8
+ latent-consistency/lcm-ssd-1b
configs/openvino-lcm-models.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ rupeshs/sd-turbo-openvino
2
+ rupeshs/sdxs-512-0.9-openvino
3
+ rupeshs/hyper-sd-sdxl-1-step-openvino-int8
4
+ rupeshs/SDXL-Lightning-2steps-openvino-int8
5
+ rupeshs/sdxl-turbo-openvino-int8
6
+ rupeshs/LCM-dreamshaper-v7-openvino
7
+ Disty0/LCM_SoteMix
8
+ rupeshs/FLUX.1-schnell-openvino-int4
9
+ rupeshs/sd15-lcm-square-openvino-int8
configs/stable-diffusion-models.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ Lykon/dreamshaper-8
2
+ Fictiverse/Stable_Diffusion_PaperCut_Model
3
+ stabilityai/stable-diffusion-xl-base-1.0
4
+ runwayml/stable-diffusion-v1-5
5
+ segmind/SSD-1B
6
+ stablediffusionapi/anything-v5
7
+ prompthero/openjourney-v4
constants.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from os import environ, cpu_count
2
+
3
+ cpu_cores = cpu_count()
4
+ cpus = cpu_cores // 2 if cpu_cores else 0
5
+ APP_VERSION = "v1.0.0 beta 200"
6
+ LCM_DEFAULT_MODEL = "stabilityai/sd-turbo"
7
+ LCM_DEFAULT_MODEL_OPENVINO = "rupeshs/sd-turbo-openvino"
8
+ APP_NAME = "FastSD CPU"
9
+ APP_SETTINGS_FILE = "settings.yaml"
10
+ RESULTS_DIRECTORY = "results"
11
+ CONFIG_DIRECTORY = "configs"
12
+ DEVICE = environ.get("DEVICE", "cpu")
13
+ SD_MODELS_FILE = "stable-diffusion-models.txt"
14
+ LCM_LORA_MODELS_FILE = "lcm-lora-models.txt"
15
+ OPENVINO_LCM_MODELS_FILE = "openvino-lcm-models.txt"
16
+ TAESD_MODEL = "madebyollin/taesd"
17
+ TAESDXL_MODEL = "madebyollin/taesdxl"
18
+ TAESD_MODEL_OPENVINO = "deinferno/taesd-openvino"
19
+ LCM_MODELS_FILE = "lcm-models.txt"
20
+ TAESDXL_MODEL_OPENVINO = "rupeshs/taesdxl-openvino"
21
+ LORA_DIRECTORY = "lora_models"
22
+ CONTROLNET_DIRECTORY = "controlnet_models"
23
+ MODELS_DIRECTORY = "models"
24
+ GGUF_THREADS = environ.get("GGUF_THREADS", cpus)
25
+ TAEF1_MODEL_OPENVINO = "rupeshs/taef1-openvino"
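Both DEVICE and GGUF_THREADS are read from the environment at import time, so they can be overridden without touching the code. A small sketch; the values shown are only examples:

import os

# Must be set before constants is imported anywhere in the process
os.environ["DEVICE"] = "cpu"       # example value
os.environ["GGUF_THREADS"] = "8"   # example value; note env values arrive as strings

import constants

print(constants.APP_NAME, constants.APP_VERSION)
print(constants.DEVICE, constants.GGUF_THREADS)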
context.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+ from app_settings import Settings
3
+ from models.interface_types import InterfaceType
4
+ from backend.models.lcmdiffusion_setting import DiffusionTask
5
+ from backend.lcm_text_to_image import LCMTextToImage
6
+ from time import perf_counter
7
+ from backend.image_saver import ImageSaver
8
+ from pprint import pprint
9
+
10
+
11
+ class Context:
12
+ def __init__(
13
+ self,
14
+ interface_type: InterfaceType,
15
+ device="cpu",
16
+ ):
17
+ self.interface_type = interface_type.value
18
+ self.lcm_text_to_image = LCMTextToImage(device)
19
+ self._latency = 0
20
+
21
+ @property
22
+ def latency(self):
23
+ return self._latency
24
+
25
+ def generate_text_to_image(
26
+ self,
27
+ settings: Settings,
28
+ reshape: bool = False,
29
+ device: str = "cpu",
30
+ save_config=True,
31
+ ) -> Any:
32
+ if (
33
+ settings.lcm_diffusion_setting.use_tiny_auto_encoder
34
+ and settings.lcm_diffusion_setting.use_openvino
35
+ ):
36
+ print(
37
+ "WARNING: Tiny AutoEncoder is not supported in Image to image mode (OpenVINO)"
38
+ )
39
+ tick = perf_counter()
40
+ from state import get_settings
41
+
42
+ if (
43
+ settings.lcm_diffusion_setting.diffusion_task
44
+ == DiffusionTask.text_to_image.value
45
+ ):
46
+ settings.lcm_diffusion_setting.init_image = None
47
+
48
+ if save_config:
49
+ get_settings().save()
50
+
51
+ pprint(settings.lcm_diffusion_setting.model_dump())
52
+ if not settings.lcm_diffusion_setting.lcm_lora:
53
+ return None
54
+ self.lcm_text_to_image.init(
55
+ device,
56
+ settings.lcm_diffusion_setting,
57
+ )
58
+ images = self.lcm_text_to_image.generate(
59
+ settings.lcm_diffusion_setting,
60
+ reshape,
61
+ )
62
+ elapsed = perf_counter() - tick
63
+ self._latency = elapsed
64
+ print(f"Latency : {elapsed:.2f} seconds")
65
+ if settings.lcm_diffusion_setting.controlnet:
66
+ if settings.lcm_diffusion_setting.controlnet.enabled:
67
+ images.append(settings.lcm_diffusion_setting.controlnet._control_image)
68
+ return images
69
+
70
+
71
+ def save_images(
72
+ self,
73
+ images: Any,
74
+ settings: Settings,
75
+ ) -> list[str]:
76
+ saved_images = []
77
+ if images and settings.generated_images.save_image:
78
+ saved_images = ImageSaver.save_images(
79
+ settings.generated_images.path,
80
+ images=images,
81
+ lcm_diffusion_setting=settings.lcm_diffusion_setting,
82
+ format=settings.generated_images.format,
83
+ jpeg_quality=settings.generated_images.save_image_quality,
84
+ )
85
+ return saved_images
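Context ties the pieces together: generate_text_to_image (re)initializes the LCMTextToImage pipeline from the current LCM diffusion settings, records the latency, and returns PIL images, while save_images writes them out according to the generated_images settings. A minimal sketch of direct use, assuming settings are loaded through state.get_settings() as in the rest of this commit and an example prompt:

from constants import DEVICE
from context import Context
from models.interface_types import InterfaceType
from state import get_settings

app_settings = get_settings()
app_settings.settings.lcm_diffusion_setting.prompt = "a cat wearing sunglasses"  # example prompt

context = Context(InterfaceType.CLI)
images = context.generate_text_to_image(
    settings=app_settings.settings,
    device=DEVICE,
)
context.save_images(images, app_settings.settings)
print(f"Latency: {context.latency:.2f} seconds")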
frontend/cli_interactive.py ADDED
@@ -0,0 +1,661 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from os import path
2
+ from PIL import Image
3
+ from typing import Any
4
+
5
+ from constants import DEVICE
6
+ from paths import FastStableDiffusionPaths
7
+ from backend.upscale.upscaler import upscale_image
8
+ from backend.upscale.tiled_upscale import generate_upscaled_image
9
+ from frontend.webui.image_variations_ui import generate_image_variations
10
+ from backend.lora import (
11
+ get_active_lora_weights,
12
+ update_lora_weights,
13
+ load_lora_weight,
14
+ )
15
+ from backend.models.lcmdiffusion_setting import (
16
+ DiffusionTask,
17
+ ControlNetSetting,
18
+ )
19
+
20
+
21
+ _batch_count = 1
22
+ _edit_lora_settings = False
23
+
24
+
25
+ def user_value(
26
+ value_type: type,
27
+ message: str,
28
+ default_value: Any,
29
+ ) -> Any:
30
+ try:
31
+ value = value_type(input(message))
32
+ except (TypeError, ValueError, EOFError):
33
+ value = default_value
34
+ return value
35
+
36
+
37
+ def interactive_mode(
38
+ config,
39
+ context,
40
+ ):
41
+ print("=============================================")
42
+ print("Welcome to FastSD CPU Interactive CLI")
43
+ print("=============================================")
44
+ while True:
45
+ print("> 1. Text to Image")
46
+ print("> 2. Image to Image")
47
+ print("> 3. Image Variations")
48
+ print("> 4. EDSR Upscale")
49
+ print("> 5. SD Upscale")
50
+ print("> 6. Edit default generation settings")
51
+ print("> 7. Edit LoRA settings")
52
+ print("> 8. Edit ControlNet settings")
53
+ print("> 9. Edit negative prompt")
54
+ print("> 10. Quit")
55
+ option = user_value(
56
+ int,
57
+ "Enter a Diffusion Task number (1): ",
58
+ 1,
59
+ )
60
+ if option not in range(1, 11):
61
+ print("Wrong Diffusion Task number!")
62
+ exit()
63
+
64
+ if option == 1:
65
+ interactive_txt2img(
66
+ config,
67
+ context,
68
+ )
69
+ elif option == 2:
70
+ interactive_img2img(
71
+ config,
72
+ context,
73
+ )
74
+ elif option == 3:
75
+ interactive_variations(
76
+ config,
77
+ context,
78
+ )
79
+ elif option == 4:
80
+ interactive_edsr(
81
+ config,
82
+ context,
83
+ )
84
+ elif option == 5:
85
+ interactive_sdupscale(
86
+ config,
87
+ context,
88
+ )
89
+ elif option == 6:
90
+ interactive_settings(
91
+ config,
92
+ context,
93
+ )
94
+ elif option == 7:
95
+ interactive_lora(
96
+ config,
97
+ context,
98
+ True,
99
+ )
100
+ elif option == 8:
101
+ interactive_controlnet(
102
+ config,
103
+ context,
104
+ True,
105
+ )
106
+ elif option == 9:
107
+ interactive_negative(
108
+ config,
109
+ context,
110
+ )
111
+ elif option == 10:
112
+ exit()
113
+
114
+
115
+ def interactive_negative(
116
+ config,
117
+ context,
118
+ ):
119
+ settings = config.lcm_diffusion_setting
120
+ print(f"Current negative prompt: '{settings.negative_prompt}'")
121
+ user_input = input("Write a negative prompt (set guidance > 1.0): ")
122
+ if user_input == "":
123
+ return
124
+ else:
125
+ settings.negative_prompt = user_input
126
+
127
+
128
+ def interactive_controlnet(
129
+ config,
130
+ context,
131
+ menu_flag=False,
132
+ ):
133
+ """
134
+ @param menu_flag: Indicates whether this function was called from the main
135
+ interactive CLI menu; _True_ if called from the main menu, _False_ otherwise
136
+ """
137
+ settings = config.lcm_diffusion_setting
138
+ if not settings.controlnet:
139
+ settings.controlnet = ControlNetSetting()
140
+
141
+ current_enabled = settings.controlnet.enabled
142
+ current_adapter_path = settings.controlnet.adapter_path
143
+ current_conditioning_scale = settings.controlnet.conditioning_scale
144
+ current_control_image = settings.controlnet._control_image
145
+
146
+ option = input("Enable ControlNet? (y/N): ")
147
+ settings.controlnet.enabled = True if option.upper() == "Y" else False
148
+ if settings.controlnet.enabled:
149
+ option = input(
150
+ f"Enter ControlNet adapter path ({settings.controlnet.adapter_path}): "
151
+ )
152
+ if option != "":
153
+ settings.controlnet.adapter_path = option
154
+ settings.controlnet.conditioning_scale = user_value(
155
+ float,
156
+ f"Enter ControlNet conditioning scale ({settings.controlnet.conditioning_scale}): ",
157
+ settings.controlnet.conditioning_scale,
158
+ )
159
+ option = input(
160
+ f"Enter ControlNet control image path (Leave empty to reuse current): "
161
+ )
162
+ if option != "":
163
+ try:
164
+ new_image = Image.open(option)
165
+ settings.controlnet._control_image = new_image
166
+ except (AttributeError, FileNotFoundError) as e:
167
+ settings.controlnet._control_image = None
168
+ if (
169
+ not settings.controlnet.adapter_path
170
+ or not path.exists(settings.controlnet.adapter_path)
171
+ or not settings.controlnet._control_image
172
+ ):
173
+ print("Invalid ControlNet settings! Disabling ControlNet")
174
+ settings.controlnet.enabled = False
175
+
176
+ if (
177
+ settings.controlnet.enabled != current_enabled
178
+ or settings.controlnet.adapter_path != current_adapter_path
179
+ ):
180
+ settings.rebuild_pipeline = True
181
+
182
+
183
+ def interactive_lora(
184
+ config,
185
+ context,
186
+ menu_flag=False,
187
+ ):
188
+ """
189
+ @param menu_flag: Indicates whether this function was called from the main
190
+ interactive CLI menu; _True_ if called from the main menu, _False_ otherwise
191
+ """
192
+ if context == None or context.lcm_text_to_image.pipeline == None:
193
+ print("Diffusion pipeline not initialized, please run a generation task first!")
194
+ return
195
+
196
+ print("> 1. Change LoRA weights")
197
+ print("> 2. Load new LoRA model")
198
+ option = user_value(
199
+ int,
200
+ "Enter a LoRA option (1): ",
201
+ 1,
202
+ )
203
+ if option not in range(1, 3):
204
+ print("Wrong LoRA option!")
205
+ return
206
+
207
+ if option == 1:
208
+ update_weights = []
209
+ active_weights = get_active_lora_weights()
210
+ for lora in active_weights:
211
+ weight = user_value(
212
+ float,
213
+ f"Enter a new LoRA weight for {lora[0]} ({lora[1]}): ",
214
+ lora[1],
215
+ )
216
+ update_weights.append(
217
+ (
218
+ lora[0],
219
+ weight,
220
+ )
221
+ )
222
+ if len(update_weights) > 0:
223
+ update_lora_weights(
224
+ context.lcm_text_to_image.pipeline,
225
+ config.lcm_diffusion_setting,
226
+ update_weights,
227
+ )
228
+ elif option == 2:
229
+ # Load a new LoRA
230
+ settings = config.lcm_diffusion_setting
231
+ settings.lora.fuse = False
232
+ settings.lora.enabled = False
233
+ settings.lora.path = input("Enter LoRA model path: ")
234
+ settings.lora.weight = user_value(
235
+ float,
236
+ "Enter a LoRA weight (0.5): ",
237
+ 0.5,
238
+ )
239
+ if not path.exists(settings.lora.path):
240
+ print("Invalid LoRA model path!")
241
+ return
242
+ settings.lora.enabled = True
243
+ load_lora_weight(context.lcm_text_to_image.pipeline, settings)
244
+
245
+ if menu_flag:
246
+ global _edit_lora_settings
247
+ _edit_lora_settings = False
248
+ option = input("Edit LoRA settings after every generation? (y/N): ")
249
+ if option.upper() == "Y":
250
+ _edit_lora_settings = True
251
+
252
+
253
+ def interactive_settings(
254
+ config,
255
+ context,
256
+ ):
257
+ global _batch_count
258
+ settings = config.lcm_diffusion_setting
259
+ print("Enter generation settings (leave empty to use current value)")
260
+ print("> 1. Use LCM")
261
+ print("> 2. Use LCM-Lora")
262
+ print("> 3. Use OpenVINO")
263
+ option = user_value(
264
+ int,
265
+ "Select inference model option (1): ",
266
+ 1,
267
+ )
268
+ if option not in range(1, 4):
269
+ print("Wrong inference model option! Falling back to defaults")
270
+ return
271
+
272
+ settings.use_lcm_lora = False
273
+ settings.use_openvino = False
274
+ if option == 1:
275
+ lcm_model_id = input(f"Enter LCM model ID ({settings.lcm_model_id}): ")
276
+ if lcm_model_id != "":
277
+ settings.lcm_model_id = lcm_model_id
278
+ elif option == 2:
279
+ settings.use_lcm_lora = True
280
+ lcm_lora_id = input(
281
+ f"Enter LCM-Lora model ID ({settings.lcm_lora.lcm_lora_id}): "
282
+ )
283
+ if lcm_lora_id != "":
284
+ settings.lcm_lora.lcm_lora_id = lcm_lora_id
285
+ base_model_id = input(
286
+ f"Enter Base model ID ({settings.lcm_lora.base_model_id}): "
287
+ )
288
+ if base_model_id != "":
289
+ settings.lcm_lora.base_model_id = base_model_id
290
+ elif option == 3:
291
+ settings.use_openvino = True
292
+ openvino_lcm_model_id = input(
293
+ f"Enter OpenVINO model ID ({settings.openvino_lcm_model_id}): "
294
+ )
295
+ if openvino_lcm_model_id != "":
296
+ settings.openvino_lcm_model_id = openvino_lcm_model_id
297
+
298
+ settings.use_offline_model = True
299
+ settings.use_tiny_auto_encoder = True
300
+ option = input("Work offline? (Y/n): ")
301
+ if option.upper() == "N":
302
+ settings.use_offline_model = False
303
+ option = input("Use Tiny Auto Encoder? (Y/n): ")
304
+ if option.upper() == "N":
305
+ settings.use_tiny_auto_encoder = False
306
+
307
+ settings.image_width = user_value(
308
+ int,
309
+ f"Image width ({settings.image_width}): ",
310
+ settings.image_width,
311
+ )
312
+ settings.image_height = user_value(
313
+ int,
314
+ f"Image height ({settings.image_height}): ",
315
+ settings.image_height,
316
+ )
317
+ settings.inference_steps = user_value(
318
+ int,
319
+ f"Inference steps ({settings.inference_steps}): ",
320
+ settings.inference_steps,
321
+ )
322
+ settings.guidance_scale = user_value(
323
+ float,
324
+ f"Guidance scale ({settings.guidance_scale}): ",
325
+ settings.guidance_scale,
326
+ )
327
+ settings.number_of_images = user_value(
328
+ int,
329
+ f"Number of images per batch ({settings.number_of_images}): ",
330
+ settings.number_of_images,
331
+ )
332
+ _batch_count = user_value(
333
+ int,
334
+ f"Batch count ({_batch_count}): ",
335
+ _batch_count,
336
+ )
337
+ # output_format = user_value(int, f"Output format (PNG)", 1)
338
+ print(config.lcm_diffusion_setting)
339
+
340
+
341
+ def interactive_txt2img(
342
+ config,
343
+ context,
344
+ ):
345
+ global _batch_count
346
+ config.lcm_diffusion_setting.diffusion_task = DiffusionTask.text_to_image.value
347
+ user_input = input("Write a prompt (write 'exit' to quit): ")
348
+ while True:
349
+ if user_input == "exit":
350
+ return
351
+ elif user_input == "":
352
+ user_input = config.lcm_diffusion_setting.prompt
353
+ config.lcm_diffusion_setting.prompt = user_input
354
+ for _ in range(0, _batch_count):
355
+ images = context.generate_text_to_image(
356
+ settings=config,
357
+ device=DEVICE,
358
+ )
359
+ context.save_images(
360
+ images,
361
+ config,
362
+ )
363
+ if _edit_lora_settings:
364
+ interactive_lora(
365
+ config,
366
+ context,
367
+ )
368
+ user_input = input("Write a prompt: ")
369
+
370
+
371
+ def interactive_img2img(
372
+ config,
373
+ context,
374
+ ):
375
+ global _batch_count
376
+ settings = config.lcm_diffusion_setting
377
+ settings.diffusion_task = DiffusionTask.image_to_image.value
378
+ steps = settings.inference_steps
379
+ source_path = input("Image path: ")
380
+ if source_path == "":
381
+ print("Error : You need to provide a file in img2img mode")
382
+ return
383
+ settings.strength = user_value(
384
+ float,
385
+ f"img2img strength ({settings.strength}): ",
386
+ settings.strength,
387
+ )
388
+ settings.inference_steps = int(steps / settings.strength + 1)
389
+ user_input = input("Write a prompt (write 'exit' to quit): ")
390
+ while True:
391
+ if user_input == "exit":
392
+ settings.inference_steps = steps
393
+ return
394
+ settings.init_image = Image.open(source_path)
395
+ settings.prompt = user_input
396
+ for _ in range(0, _batch_count):
397
+ images = context.generate_text_to_image(
398
+ settings=config,
399
+ device=DEVICE,
400
+ )
401
+ context.save_images(
402
+ images,
403
+ config,
404
+ )
405
+ new_path = input(f"Image path ({source_path}): ")
406
+ if new_path != "":
407
+ source_path = new_path
408
+ settings.strength = user_value(
409
+ float,
410
+ f"img2img strength ({settings.strength}): ",
411
+ settings.strength,
412
+ )
413
+ if _edit_lora_settings:
414
+ interactive_lora(
415
+ config,
416
+ context,
417
+ )
418
+ settings.inference_steps = int(steps / settings.strength + 1)
419
+ user_input = input("Write a prompt: ")
420
+
421
+
422
+ def interactive_variations(
423
+ config,
424
+ context,
425
+ ):
426
+ global _batch_count
427
+ settings = config.lcm_diffusion_setting
428
+ settings.diffusion_task = DiffusionTask.image_to_image.value
429
+ steps = settings.inference_steps
430
+ source_path = input("Image path: ")
431
+ if source_path == "":
432
+ print("Error : You need to provide a file in Image variations mode")
433
+ return
434
+ settings.strength = user_value(
435
+ float,
436
+ f"Image variations strength ({settings.strength}): ",
437
+ settings.strength,
438
+ )
439
+ settings.inference_steps = int(steps / settings.strength + 1)
440
+ while True:
441
+ settings.init_image = Image.open(source_path)
442
+ settings.prompt = ""
443
+ for i in range(0, _batch_count):
444
+ generate_image_variations(
445
+ settings.init_image,
446
+ settings.strength,
447
+ )
448
+ if _edit_lora_settings:
449
+ interactive_lora(
450
+ config,
451
+ context,
452
+ )
453
+ user_input = input("Continue in Image variations mode? (Y/n): ")
454
+ if user_input.upper() == "N":
455
+ settings.inference_steps = steps
456
+ return
457
+ new_path = input(f"Image path ({source_path}): ")
458
+ if new_path != "":
459
+ source_path = new_path
460
+ settings.strength = user_value(
461
+ float,
462
+ f"Image variations strength ({settings.strength}): ",
463
+ settings.strength,
464
+ )
465
+ settings.inference_steps = int(steps / settings.strength + 1)
466
+
467
+
468
+ def interactive_edsr(
469
+ config,
470
+ context,
471
+ ):
472
+ source_path = input("Image path: ")
473
+ if source_path == "":
474
+ print("Error : You need to provide a file in EDSR mode")
475
+ return
476
+ while True:
477
+ output_path = FastStableDiffusionPaths.get_upscale_filepath(
478
+ source_path,
479
+ 2,
480
+ config.generated_images.format,
481
+ )
482
+ result = upscale_image(
483
+ context,
484
+ source_path,
485
+ output_path,
486
+ 2,
487
+ )
488
+ user_input = input("Continue in EDSR upscale mode? (Y/n): ")
489
+ if user_input.upper() == "N":
490
+ return
491
+ new_path = input(f"Image path ({source_path}): ")
492
+ if new_path != "":
493
+ source_path = new_path
494
+
495
+
496
+ def interactive_sdupscale_settings(config):
497
+ steps = config.lcm_diffusion_setting.inference_steps
498
+ custom_settings = {}
499
+ print("> 1. Upscale whole image")
500
+ print("> 2. Define custom tiles (advanced)")
501
+ option = user_value(
502
+ int,
503
+ "Select an SD Upscale option (1): ",
504
+ 1,
505
+ )
506
+ if option not in range(1, 3):
507
+ print("Wrong SD Upscale option!")
508
+ return
509
+
510
+ # custom_settings["source_file"] = args.file
511
+ custom_settings["source_file"] = ""
512
+ new_path = input(f"Input image path ({custom_settings['source_file']}): ")
513
+ if new_path != "":
514
+ custom_settings["source_file"] = new_path
515
+ if custom_settings["source_file"] == "":
516
+ print("Error : You need to provide a file in SD Upscale mode")
517
+ return
518
+ custom_settings["target_file"] = None
519
+ if option == 2:
520
+ custom_settings["target_file"] = input("Image to patch: ")
521
+ if custom_settings["target_file"] == "":
522
+ print("No target file provided, upscaling whole input image instead!")
523
+ custom_settings["target_file"] = None
524
+ option = 1
525
+ custom_settings["output_format"] = config.generated_images.format
526
+ custom_settings["strength"] = user_value(
527
+ float,
528
+ f"SD Upscale strength ({config.lcm_diffusion_setting.strength}): ",
529
+ config.lcm_diffusion_setting.strength,
530
+ )
531
+ config.lcm_diffusion_setting.inference_steps = int(
532
+ steps / custom_settings["strength"] + 1
533
+ )
534
+ if option == 1:
535
+ custom_settings["scale_factor"] = user_value(
536
+ float,
537
+ f"Scale factor (2.0): ",
538
+ 2.0,
539
+ )
540
+ custom_settings["tile_size"] = user_value(
541
+ int,
542
+ f"Split input image into tiles of the following size, in pixels (256): ",
543
+ 256,
544
+ )
545
+ custom_settings["tile_overlap"] = user_value(
546
+ int,
547
+ f"Tile overlap, in pixels (16): ",
548
+ 16,
549
+ )
550
+ elif option == 2:
551
+ custom_settings["scale_factor"] = user_value(
552
+ float,
553
+ "Input image to Image-to-patch scale_factor (2.0): ",
554
+ 2.0,
555
+ )
556
+ custom_settings["tile_size"] = 256
557
+ custom_settings["tile_overlap"] = 16
558
+ custom_settings["prompt"] = input(
559
+ "Write a prompt describing the input image (optional): "
560
+ )
561
+ custom_settings["tiles"] = []
562
+ if option == 2:
563
+ add_tile = True
564
+ while add_tile:
565
+ print("=== Define custom SD Upscale tile ===")
566
+ tile_x = user_value(
567
+ int,
568
+ "Enter tile's X position: ",
569
+ 0,
570
+ )
571
+ tile_y = user_value(
572
+ int,
573
+ "Enter tile's Y position: ",
574
+ 0,
575
+ )
576
+ tile_w = user_value(
577
+ int,
578
+ "Enter tile's width (256): ",
579
+ 256,
580
+ )
581
+ tile_h = user_value(
582
+ int,
583
+ "Enter tile's height (256): ",
584
+ 256,
585
+ )
586
+ tile_scale = user_value(
587
+ float,
588
+ "Enter tile's scale factor (2.0): ",
589
+ 2.0,
590
+ )
591
+ tile_prompt = input("Enter tile's prompt (optional): ")
592
+ custom_settings["tiles"].append(
593
+ {
594
+ "x": tile_x,
595
+ "y": tile_y,
596
+ "w": tile_w,
597
+ "h": tile_h,
598
+ "mask_box": None,
599
+ "prompt": tile_prompt,
600
+ "scale_factor": tile_scale,
601
+ }
602
+ )
603
+ tile_option = input("Do you want to define another tile? (y/N): ")
604
+ if tile_option == "" or tile_option.upper() == "N":
605
+ add_tile = False
606
+
607
+ return custom_settings
608
+
609
+
610
+ def interactive_sdupscale(
611
+ config,
612
+ context,
613
+ ):
614
+ settings = config.lcm_diffusion_setting
615
+ settings.diffusion_task = DiffusionTask.image_to_image.value
616
+ settings.init_image = ""
617
+ source_path = ""
618
+ steps = settings.inference_steps
619
+
620
+ while True:
621
+ custom_upscale_settings = None
622
+ option = input("Edit custom SD Upscale settings? (y/N): ")
623
+ if option.upper() == "Y":
624
+ config.lcm_diffusion_setting.inference_steps = steps
625
+ custom_upscale_settings = interactive_sdupscale_settings(config)
626
+ if not custom_upscale_settings:
627
+ return
628
+ source_path = custom_upscale_settings["source_file"]
629
+ else:
630
+ new_path = input(f"Image path ({source_path}): ")
631
+ if new_path != "":
632
+ source_path = new_path
633
+ if source_path == "":
634
+ print("Error : You need to provide a file in SD Upscale mode")
635
+ return
636
+ settings.strength = user_value(
637
+ float,
638
+ f"SD Upscale strength ({settings.strength}): ",
639
+ settings.strength,
640
+ )
641
+ settings.inference_steps = int(steps / settings.strength + 1)
642
+
643
+ output_path = FastStableDiffusionPaths.get_upscale_filepath(
644
+ source_path,
645
+ 2,
646
+ config.generated_images.format,
647
+ )
648
+ generate_upscaled_image(
649
+ config,
650
+ source_path,
651
+ settings.strength,
652
+ upscale_settings=custom_upscale_settings,
653
+ context=context,
654
+ tile_overlap=32 if settings.use_openvino else 16,
655
+ output_path=output_path,
656
+ image_format=config.generated_images.format,
657
+ )
658
+ user_input = input("Continue in SD Upscale mode? (Y/n): ")
659
+ if user_input.upper() == "N":
660
+ settings.inference_steps = steps
661
+ return
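interactive_mode above is the entry point for the whole menu loop; it only needs the application's Settings object and an initialized Context. The real wiring lives in app.py (not shown in this view), so treat the following launch sketch as an assumption:

from context import Context
from frontend.cli_interactive import interactive_mode
from models.interface_types import InterfaceType
from state import get_settings

app_settings = get_settings()
context = Context(InterfaceType.CLI)

# Starts the menu: text to image, image to image, variations, upscaling, settings
interactive_mode(app_settings.settings, context)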