radames committed on
Commit e0273d5 · 1 Parent(s): 2bca65e

Delete server/wrapper.py

Files changed (1)
  1. server/wrapper.py +0 -529
server/wrapper.py DELETED
@@ -1,529 +0,0 @@
- import gc
- import os
- import traceback
- from typing import List, Literal, Optional, Union
-
- import numpy as np
- import torch
- from diffusers import AutoencoderTiny, StableDiffusionPipeline
- from PIL import Image
- from polygraphy import cuda
-
- from streamdiffusion import StreamDiffusion
- from streamdiffusion.image_utils import postprocess_image
-
- torch.set_grad_enabled(False)
- torch.backends.cuda.matmul.allow_tf32 = True
- torch.backends.cudnn.allow_tf32 = True
-
-
- class StreamDiffusionWrapper:
-     def __init__(
-         self,
-         model_id: str,
-         t_index_list: List[int],
-         mode: Literal["img2img", "txt2img"] = "img2img",
-         output_type: Literal["pil", "pt", "np", "latent"] = "pil",
-         lcm_lora_id: Optional[str] = None,
-         vae_id: Optional[str] = None,
-         device: Literal["cpu", "cuda"] = "cuda",
-         dtype: torch.dtype = torch.float16,
-         frame_buffer_size: int = 1,
-         width: int = 512,
-         height: int = 512,
-         warmup: int = 10,
-         acceleration: Literal["none", "xformers", "sfast", "tensorrt"] = "xformers",
-         is_drawing: bool = True,
-         device_ids: Optional[List[int]] = None,
-         use_lcm_lora: bool = True,
-         use_tiny_vae: bool = True,
-         enable_similar_image_filter: bool = False,
-         similar_image_filter_threshold: float = 0.98,
-         use_denoising_batch: bool = True,
-         cfg_type: Literal["none", "full", "self", "initialize"] = "none",
-         use_safety_checker: bool = False,
-     ):
-         if mode == "txt2img":
-             if cfg_type != "none":
-                 raise ValueError(
-                     f"txt2img mode accepts only cfg_type = 'none', but got {cfg_type}"
-                 )
-             if use_denoising_batch and frame_buffer_size > 1:
-                 raise ValueError(
-                     "txt2img mode cannot use denoising batch with frame_buffer_size > 1."
-                 )
-
-         if mode == "img2img":
-             if not use_denoising_batch:
-                 raise NotImplementedError(
-                     "img2img mode must use denoising batch for now."
-                 )
-
-         self.sd_turbo = "turbo" in model_id
-         self.device = device
-         self.dtype = dtype
-         self.width = width
-         self.height = height
-         self.mode = mode
-         self.output_type = output_type
-         self.frame_buffer_size = frame_buffer_size
-         # With denoising batching, every denoising step runs in a single batch,
-         # so the batch grows with len(t_index_list).
-         self.batch_size = (
-             len(t_index_list) * frame_buffer_size
-             if use_denoising_batch
-             else frame_buffer_size
-         )
-
-         self.use_denoising_batch = use_denoising_batch
-         self.use_safety_checker = use_safety_checker
-
-         self.stream = self._load_model(
-             model_id=model_id,
-             lcm_lora_id=lcm_lora_id,
-             vae_id=vae_id,
-             t_index_list=t_index_list,
-             acceleration=acceleration,
-             warmup=warmup,
-             is_drawing=is_drawing,
-             use_lcm_lora=use_lcm_lora,
-             use_tiny_vae=use_tiny_vae,
-             cfg_type=cfg_type,
-         )
-
-         if device_ids is not None:
-             self.stream.unet = torch.nn.DataParallel(
-                 self.stream.unet, device_ids=device_ids
-             )
-
-         if enable_similar_image_filter:
-             self.stream.enable_similar_image_filter(similar_image_filter_threshold)
-
-     def prepare(
-         self,
-         prompt: str,
-         negative_prompt: str = "",
-         num_inference_steps: int = 50,
-         guidance_scale: float = 1.2,
-         delta: float = 1.0,
-     ) -> None:
-         """
-         Prepares the model for inference.
-
-         Parameters
-         ----------
-         prompt : str
-             The prompt to generate images from.
-         negative_prompt : str, optional
-             The negative prompt to steer generation away from, by default "".
-         num_inference_steps : int, optional
-             The number of inference steps to perform, by default 50.
-         guidance_scale : float, optional
-             The guidance scale to use, by default 1.2.
-         delta : float, optional
-             The delta to use, by default 1.0.
-         """
-         self.stream.prepare(
-             prompt,
-             negative_prompt,
-             num_inference_steps=num_inference_steps,
-             guidance_scale=guidance_scale,
-             delta=delta,
-         )
-
-     def __call__(
-         self,
-         image: Optional[Union[str, Image.Image, torch.Tensor]] = None,
-         prompt: Optional[str] = None,
-     ) -> Union[Image.Image, List[Image.Image]]:
-         """
-         Performs img2img or txt2img based on the mode.
-
-         Parameters
-         ----------
-         image : Optional[Union[str, Image.Image, torch.Tensor]]
-             The image to generate from.
-         prompt : Optional[str]
-             The prompt to generate images from.
-
-         Returns
-         -------
-         Union[Image.Image, List[Image.Image]]
-             The generated image.
-         """
-         if self.mode == "img2img":
-             return self.img2img(image)
-         else:
-             return self.txt2img(prompt)
-
-     def txt2img(
-         self, prompt: str
-     ) -> Union[Image.Image, List[Image.Image], torch.Tensor, np.ndarray]:
-         """
-         Performs txt2img.
-
-         Parameters
-         ----------
-         prompt : str
-             The prompt to generate images from.
-
-         Returns
-         -------
-         Union[Image.Image, List[Image.Image], torch.Tensor, np.ndarray]
-             The generated image.
-         """
-         self.stream.update_prompt(prompt)
-
-         if self.sd_turbo:
-             image_tensor = self.stream.txt2img_sd_turbo(self.batch_size)
-         else:
-             image_tensor = self.stream.txt2img(self.frame_buffer_size)
-         image = self.postprocess_image(image_tensor, output_type=self.output_type)
-
-         if self.use_safety_checker:
-             safety_checker_input = self.feature_extractor(
-                 image, return_tensors="pt"
-             ).to(self.device)
-             _, has_nsfw_concept = self.safety_checker(
-                 images=image_tensor.to(self.dtype),
-                 clip_input=safety_checker_input.pixel_values.to(self.dtype),
-             )
-             image = self.nsfw_fallback_img if has_nsfw_concept[0] else image
-
-         return image
-
-     def img2img(
-         self, image: Union[str, Image.Image, torch.Tensor]
-     ) -> Union[Image.Image, List[Image.Image], torch.Tensor, np.ndarray]:
-         """
-         Performs img2img.
-
-         Parameters
-         ----------
-         image : Union[str, Image.Image, torch.Tensor]
-             The image to generate from.
-
-         Returns
-         -------
-         Union[Image.Image, List[Image.Image], torch.Tensor, np.ndarray]
-             The generated image.
-         """
-         if isinstance(image, (str, Image.Image)):
-             image = self.preprocess_image(image)
-
-         image_tensor = self.stream(image)
-         return self.postprocess_image(image_tensor, output_type=self.output_type)
-
-     def preprocess_image(self, image: Union[str, Image.Image]) -> torch.Tensor:
-         """
-         Preprocesses the image.
-
-         Parameters
-         ----------
-         image : Union[str, Image.Image]
-             The image to preprocess.
-
-         Returns
-         -------
-         torch.Tensor
-             The preprocessed image.
-         """
-         if isinstance(image, str):
-             image = Image.open(image).convert("RGB").resize((self.width, self.height))
-         elif isinstance(image, Image.Image):
-             image = image.convert("RGB").resize((self.width, self.height))
-
-         return self.stream.image_processor.preprocess(
-             image, self.height, self.width
-         ).to(device=self.device, dtype=self.dtype)
-
-     def postprocess_image(
-         self, image_tensor: torch.Tensor, output_type: str = "pil"
-     ) -> Union[Image.Image, List[Image.Image], torch.Tensor, np.ndarray]:
-         """
-         Postprocesses the image.
-
-         Parameters
-         ----------
-         image_tensor : torch.Tensor
-             The image tensor to postprocess.
-
-         Returns
-         -------
-         Union[Image.Image, List[Image.Image]]
-             The postprocessed image.
-         """
-         if self.frame_buffer_size > 1:
-             return postprocess_image(image_tensor.cpu(), output_type=output_type)
-         else:
-             return postprocess_image(image_tensor.cpu(), output_type=output_type)[0]
-
-     def _load_model(
-         self,
-         model_id: str,
-         t_index_list: List[int],
-         lcm_lora_id: Optional[str] = None,
-         vae_id: Optional[str] = None,
-         acceleration: Literal["none", "xformers", "sfast", "tensorrt"] = "tensorrt",
-         is_drawing: bool = True,
-         warmup: int = 10,
-         use_lcm_lora: bool = True,
-         use_tiny_vae: bool = True,
-         cfg_type: Literal["none", "full", "self", "initialize"] = "self",
-     ):
-         """
-         Loads the model.
-
-         This method does the following:
-
-         1. Loads the model from the model_id.
-         2. Loads and fuses the LCM-LoRA model from the lcm_lora_id if needed.
-         3. Loads the VAE model from the vae_id if needed.
-         4. Enables acceleration if needed.
-         5. Prepares the model for inference.
-         6. Warms up the model.
-
-         Parameters
-         ----------
-         model_id : str
-             The model id to load.
-         t_index_list : List[int]
-             The t_index_list to use for inference.
-         lcm_lora_id : Optional[str], optional
-             The lcm_lora_id to load, by default None.
-         vae_id : Optional[str], optional
-             The vae_id to load, by default None.
-         acceleration : Literal["none", "xformers", "sfast", "tensorrt"], optional
-             The acceleration method to use, by default "tensorrt".
-         warmup : int, optional
-             The number of warmup steps to perform, by default 10.
-         is_drawing : bool, optional
-             Whether to draw the image or not, by default True.
-         use_lcm_lora : bool, optional
-             Whether to use LCM-LoRA or not, by default True.
-         use_tiny_vae : bool, optional
-             Whether to use TinyVAE or not, by default True.
-         cfg_type : Literal["none", "full", "self", "initialize"], optional
-             The cfg_type to use, by default "self".
-         """
-
-         try:  # Load from the Hub or a local diffusers directory
-             pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(
-                 model_id,
-             ).to(device=self.device, dtype=self.dtype)
-
-         except ValueError:  # Fall back to loading a single checkpoint file
-             pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_single_file(
-                 model_id
-             ).to(device=self.device, dtype=self.dtype)
-         except Exception:  # No model found
-             traceback.print_exc()
-             print("Model load has failed: model not found.")
-             exit()
-
-         if self.use_safety_checker:
-             from transformers import CLIPFeatureExtractor
-             from diffusers.pipelines.stable_diffusion.safety_checker import (
-                 StableDiffusionSafetyChecker,
-             )
-
-             self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
-                 "CompVis/stable-diffusion-safety-checker"
-             ).to(pipe.device)
-             self.feature_extractor = CLIPFeatureExtractor.from_pretrained(
-                 "openai/clip-vit-base-patch32"
-             )
-             self.nsfw_fallback_img = Image.new("RGB", (512, 512), (0, 0, 0))
-
-         stream = StreamDiffusion(
-             pipe=pipe,
-             t_index_list=t_index_list,
-             torch_dtype=self.dtype,
-             width=self.width,
-             height=self.height,
-             is_drawing=is_drawing,
-             frame_buffer_size=self.frame_buffer_size,
-             use_denoising_batch=self.use_denoising_batch,
-             cfg_type=cfg_type,
-         )
-         if not self.sd_turbo:
-             if use_lcm_lora:
-                 if lcm_lora_id is not None:
-                     stream.load_lcm_lora(
-                         pretrained_model_name_or_path_or_dict=lcm_lora_id
-                     )
-                 else:
-                     stream.load_lcm_lora()
-                 stream.fuse_lora()
-
-             if use_tiny_vae:
-                 if vae_id is not None:
-                     stream.vae = AutoencoderTiny.from_pretrained(vae_id).to(
-                         device=pipe.device, dtype=pipe.dtype
-                     )
-                 else:
-                     stream.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd").to(
-                         device=pipe.device, dtype=pipe.dtype
-                     )
-
-         try:
-             if acceleration == "xformers":
-                 stream.pipe.enable_xformers_memory_efficient_attention()
-             if acceleration == "tensorrt":
-                 from streamdiffusion.acceleration.tensorrt import (
-                     TorchVAEEncoder,
-                     compile_unet,
-                     compile_vae_decoder,
-                     compile_vae_encoder,
-                 )
-                 from streamdiffusion.acceleration.tensorrt.engine import (
-                     AutoencoderKLEngine,
-                     UNet2DConditionModelEngine,
-                 )
-                 from streamdiffusion.acceleration.tensorrt.models import (
-                     VAE,
-                     UNet,
-                     VAEEncoder,
-                 )
-
-                 def create_prefix(
-                     max_batch_size: int,
-                     min_batch_size: int,
-                 ):
-                     return f"{model_id}--lcm_lora-{use_lcm_lora}--tiny_vae-{use_tiny_vae}--max_batch-{max_batch_size}--min_batch-{min_batch_size}--mode-{self.mode}"
-
-                 # Compiled engines are cached on disk under a prefix that
-                 # encodes the model, LoRA/VAE flags, and batch configuration.
-                 engine_dir = "engines"
-                 unet_path = os.path.join(
-                     engine_dir,
-                     create_prefix(
-                         stream.trt_unet_batch_size, stream.trt_unet_batch_size
-                     ),
-                     "unet.engine",
-                 )
-                 vae_encoder_path = os.path.join(
-                     engine_dir,
-                     create_prefix(
-                         self.batch_size
-                         if self.mode == "txt2img"
-                         else stream.frame_bff_size,
-                         self.batch_size
-                         if self.mode == "txt2img"
-                         else stream.frame_bff_size,
-                     ),
-                     "vae_encoder.engine",
-                 )
-                 vae_decoder_path = os.path.join(
-                     engine_dir,
-                     create_prefix(
-                         self.batch_size
-                         if self.mode == "txt2img"
-                         else stream.frame_bff_size,
-                         self.batch_size
-                         if self.mode == "txt2img"
-                         else stream.frame_bff_size,
-                     ),
-                     "vae_decoder.engine",
-                 )
-
-                 if not os.path.exists(unet_path):
-                     os.makedirs(os.path.dirname(unet_path), exist_ok=True)
-                     unet_model = UNet(
-                         fp16=True,
-                         device=stream.device,
-                         max_batch_size=stream.trt_unet_batch_size,
-                         min_batch_size=stream.trt_unet_batch_size,
-                         embedding_dim=stream.text_encoder.config.hidden_size,
-                         unet_dim=stream.unet.config.in_channels,
-                     )
-                     compile_unet(
-                         stream.unet,
-                         unet_model,
-                         unet_path + ".onnx",
-                         unet_path + ".opt.onnx",
-                         unet_path,
-                         opt_batch_size=stream.trt_unet_batch_size,
-                     )
-
-                 if not os.path.exists(vae_decoder_path):
-                     os.makedirs(os.path.dirname(vae_decoder_path), exist_ok=True)
-                     stream.vae.forward = stream.vae.decode
-                     vae_decoder_model = VAE(
-                         device=stream.device,
-                         max_batch_size=self.batch_size
-                         if self.mode == "txt2img"
-                         else stream.frame_bff_size,
-                         min_batch_size=self.batch_size
-                         if self.mode == "txt2img"
-                         else stream.frame_bff_size,
-                     )
-                     compile_vae_decoder(
-                         stream.vae,
-                         vae_decoder_model,
-                         vae_decoder_path + ".onnx",
-                         vae_decoder_path + ".opt.onnx",
-                         vae_decoder_path,
-                         opt_batch_size=self.batch_size
-                         if self.mode == "txt2img"
-                         else stream.frame_bff_size,
-                     )
-                     delattr(stream.vae, "forward")
-
-                 if not os.path.exists(vae_encoder_path):
-                     os.makedirs(os.path.dirname(vae_encoder_path), exist_ok=True)
-                     vae_encoder = TorchVAEEncoder(stream.vae).to(torch.device("cuda"))
-                     vae_encoder_model = VAEEncoder(
-                         device=stream.device,
-                         max_batch_size=self.batch_size
-                         if self.mode == "txt2img"
-                         else stream.frame_bff_size,
-                         min_batch_size=self.batch_size
-                         if self.mode == "txt2img"
-                         else stream.frame_bff_size,
-                     )
-                     compile_vae_encoder(
-                         vae_encoder,
-                         vae_encoder_model,
-                         vae_encoder_path + ".onnx",
-                         vae_encoder_path + ".opt.onnx",
-                         vae_encoder_path,
-                         opt_batch_size=self.batch_size
-                         if self.mode == "txt2img"
-                         else stream.frame_bff_size,
-                     )
-
-                 cuda_stream = cuda.Stream()
-
-                 vae_config = stream.vae.config
-                 vae_dtype = stream.vae.dtype
-
-                 # Replace the PyTorch UNet and VAE with the compiled TensorRT engines.
-                 stream.unet = UNet2DConditionModelEngine(
-                     unet_path, cuda_stream, use_cuda_graph=False
-                 )
-                 stream.vae = AutoencoderKLEngine(
-                     vae_encoder_path,
-                     vae_decoder_path,
-                     cuda_stream,
-                     stream.pipe.vae_scale_factor,
-                     use_cuda_graph=False,
-                 )
-                 setattr(stream.vae, "config", vae_config)
-                 setattr(stream.vae, "dtype", vae_dtype)
-
-                 gc.collect()
-                 torch.cuda.empty_cache()
-
-                 print("TensorRT acceleration enabled.")
-             if acceleration == "sfast":
-                 from streamdiffusion.acceleration.sfast import (
-                     accelerate_with_stable_fast,
-                 )
-
-                 stream = accelerate_with_stable_fast(stream)
-                 print("StableFast acceleration enabled.")
-         except Exception:
-             traceback.print_exc()
-             print("Acceleration has failed. Falling back to normal mode.")
-
-         stream.prepare(
-             "",
-             "",
-             num_inference_steps=50,
-             guidance_scale=1.1
-             if stream.cfg_type in ["full", "self", "initialize"]
-             else 1.0,
-             generator=torch.manual_seed(2),
-         )
-
-         return stream
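
For reference, a minimal usage sketch of the deleted wrapper, assuming a CUDA device, the streamdiffusion package installed, and server/wrapper.py on the import path; the model id, image paths, and t_index_list values are placeholders, not part of this commit:

    # Hypothetical example: drive the deleted StreamDiffusionWrapper for img2img.
    from wrapper import StreamDiffusionWrapper

    wrapper = StreamDiffusionWrapper(
        model_id="runwayml/stable-diffusion-v1-5",  # placeholder SD 1.5 checkpoint
        t_index_list=[32, 45],                      # placeholder denoising step indices
        mode="img2img",
        acceleration="xformers",
    )
    wrapper.prepare(prompt="a photo of a cat, best quality")
    result = wrapper(image="input.png")  # str/PIL input is preprocessed to a tensor
    result.save("output.png")            # default output_type is "pil"

With frame_buffer_size=1 the call returns a single PIL image; larger buffers return a list, per postprocess_image above.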