Ryukijano committed
Commit 7c212f5 · verified · 1 Parent(s): af5f1ec

Update app.py

Files changed (1): app.py +71 -8
app.py CHANGED
@@ -1,14 +1,15 @@
+import torch
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.allow_tf32 = True
 import gradio as gr
 import numpy as np
 import random
 import spaces
-import torch
 import time
 from diffusers import DiffusionPipeline, AutoencoderTiny
 from diffusers.models.attention_processor import AttnProcessor2_0
 from custom_pipeline import FluxWithCFGPipeline
-
-torch.backends.cuda.matmul.allow_tf32 = True
+import asyncio
 
 # Constants
 MAX_SEED = np.iinfo(np.int32).max
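
Note: moving the torch import to the top lets the TF32 flags take effect before any CUDA work. These flags opt float32 matmuls and cuDNN convolutions into TF32 Tensor Cores on Ampere-or-newer GPUs, trading a few mantissa bits for throughput. A minimal sketch of the same setting via the newer single switch (assuming PyTorch 1.12+; roughly equivalent for matmuls only):

    import torch

    # "high" allows TF32 for float32 matmuls (PyTorch 1.12+).
    torch.set_float32_matmul_precision("high")

    # Per-backend flags, as used in this commit:
    torch.backends.cuda.matmul.allow_tf32 = True  # float32 matmuls
    torch.backends.cudnn.allow_tf32 = True        # cuDNN convolutions
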
@@ -29,6 +30,11 @@ pipe.set_adapters(["better"], adapter_weights=[1.0])
 pipe.fuse_lora(adapter_name=["better"], lora_scale=1.0)
 pipe.unload_lora_weights()
 
+pipe.unet.to(memory_format=torch.channels_last)
+pipe.vae.to(memory_format=torch.channels_last)
+
+pipe.enable_xformers_memory_efficient_attention()
+
 torch.cuda.empty_cache()
 
 # Inference function
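
Note: this hunk is best read as a sketch. The app builds a Flux pipeline (FluxWithCFGPipeline), which exposes pipe.transformer rather than pipe.unet (the capture code below calls pipe.transformer), so pipe.unet would likely raise AttributeError; and xformers attention is not available in every environment or for every model family. A guarded variant, assuming only that the pipeline has a vae plus an optional unet or transformer:

    # Hypothetical guarded variant; `pipe` as constructed earlier in app.py.
    backbone = getattr(pipe, "unet", None) or getattr(pipe, "transformer", None)
    if backbone is not None:
        backbone.to(memory_format=torch.channels_last)
    pipe.vae.to(memory_format=torch.channels_last)  # channels_last mainly helps conv nets

    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception:
        pass  # fall back to PyTorch SDPA attention if xformers is unsupported
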
@@ -40,13 +46,68 @@ def generate_image(prompt, seed=24, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT,
 
     start_time = time.time()
 
+    # Initialize static inputs for CUDA graph
+    static_latents = torch.randn((1, 4, height // 8, width // 8), dtype=dtype, device="cuda")
+    static_prompt_embeds = torch.randn((2, 77, 768), dtype=dtype, device="cuda")  # Adjust dimensions as needed
+    static_pooled_prompt_embeds = torch.randn((2, 768), dtype=dtype, device="cuda")  # Adjust dimensions as needed
+    static_text_ids = torch.tensor([[[1, 2, 3]]], dtype=torch.int32, device="cuda")
+    static_latent_image_ids = torch.tensor([1], dtype=torch.int64, device="cuda")
+    static_timestep = torch.tensor([999], dtype=dtype, device="cuda")
+
+    # Warmup
+    s = torch.cuda.Stream()
+    s.wait_stream(torch.cuda.current_stream())
+    with torch.cuda.stream(s):
+        for _ in range(3):
+            _ = pipe.transformer(
+                hidden_states=static_latents,
+                timestep=static_timestep / 1000,
+                guidance=None,
+                pooled_projections=static_pooled_prompt_embeds,
+                encoder_hidden_states=static_prompt_embeds,
+                txt_ids=static_text_ids,
+                img_ids=static_latent_image_ids,
+                joint_attention_kwargs=pipe.joint_attention_kwargs,
+                return_dict=False,
+            )
+    torch.cuda.current_stream().wait_stream(s)
+
+    # Capture CUDA Graph
+    g = torch.cuda.CUDAGraph()
+    with torch.cuda.graph(g):
+        static_noise_pred = pipe.transformer(
+            hidden_states=static_latents,
+            timestep=static_timestep / 1000,
+            guidance=None,
+            pooled_projections=static_pooled_prompt_embeds,
+            encoder_hidden_states=static_prompt_embeds,
+            txt_ids=static_text_ids,
+            img_ids=static_latent_image_ids,
+            joint_attention_kwargs=pipe.joint_attention_kwargs,
+            return_dict=False,
+        )[0]
+        static_latents_out = pipe.scheduler.step(static_noise_pred, static_timestep, static_latents, return_dict=False)[0]
+        static_output = pipe._decode_latents_to_image(static_latents_out, height, width, "pil")
+
+    # Graph-based generation function
+    def generate_with_graph(latents, prompt_embeds, pooled_prompt_embeds, text_ids, latent_image_ids, timestep):
+        static_latents.copy_(latents)
+        static_prompt_embeds.copy_(prompt_embeds)
+        static_pooled_prompt_embeds.copy_(pooled_prompt_embeds)
+        static_text_ids.copy_(text_ids)
+        static_latent_image_ids.copy_(latent_image_ids)
+        static_timestep.copy_(timestep)
+        g.replay()
+        return static_output
+
     # Only generate the last image in the sequence
     img = pipe.generate_images(
         prompt=prompt,
         width=width,
         height=height,
         num_inference_steps=num_inference_steps,
-        generator=generator
+        generator=generator,
+        generate_with_graph=generate_with_graph
     )
     latency = f"Latency: {(time.time()-start_time):.2f} seconds"
     return img, seed, latency
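
Note: the added block follows PyTorch's standard side-stream-warmup / capture / replay recipe for CUDA Graphs, with caveats: the static tensor shapes are placeholders (per the author's own "Adjust dimensions as needed", and `dtype` is not defined in this hunk), capture happens inside generate_image so the graph is rebuilt on every request rather than amortized, and the decode-to-PIL call sits inside the capture region even though graph capture cannot record host-side work. The core pattern, as a self-contained sketch on a toy module (names like static_in and run_graph are illustrative, not from app.py; requires a CUDA device):

    import torch

    device = "cuda"
    model = torch.nn.Linear(64, 64).to(device)
    static_in = torch.randn(8, 64, device=device)  # fixed shape: replay reuses exact allocations

    # Warm up on a side stream so one-time kernel/initialization work is not captured.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            model(static_in)
    torch.cuda.current_stream().wait_stream(s)

    # Capture once; tensors created inside become graph-owned outputs.
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_out = model(static_in)

    def run_graph(x):
        static_in.copy_(x)   # refill the static input buffer in place
        g.replay()           # re-launch the recorded kernels
        return static_out    # same storage, now holding fresh results

    print(run_graph(torch.randn(8, 64, device=device)).sum().item())

Capturing once at startup (outside generate_image) and only replaying per request is what actually amortizes the launch overhead.
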
@@ -138,9 +199,11 @@ with gr.Blocks() as demo:
         concurrency_limit=None
     )
 
-    def realtime_generation(*args):
+    async def realtime_generation(*args):
         if args[0]:  # If realtime is enabled
-            return next(generate_image(*args[1:]))
+            loop = asyncio.get_event_loop()
+            result = await loop.run_in_executor(None, next, generate_image(*args[1:]))
+            return result
 
     prompt.submit(
         fn=generate_image,
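
Note: the async rewrite aims to keep Gradio's event loop responsive, but as written the arguments to run_in_executor are evaluated eagerly, so generate_image(*args[1:]) still runs on the event loop and only the next() call is offloaded; next() also assumes generate_image is a generator, whereas the version above returns a tuple. A simpler sketch that offloads the whole call (assuming Python 3.9+ and a plain, non-generator generate_image):

    import asyncio

    async def realtime_generation(realtime, *gen_args):
        if realtime:  # run the blocking generation in a worker thread
            return await asyncio.to_thread(generate_image, *gen_args)
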
@@ -158,9 +221,9 @@ with gr.Blocks() as demo:
         outputs=[result, seed, latency],
         show_progress="hidden",
         trigger_mode="always_last",
-        queue=False,
+        queue=True,
         concurrency_limit=None
     )
 
 # Launch the app
-demo.launch()
+demo.launch()