Ryukijano committed
Commit 16c45c8 · verified · 1 Parent(s): b2b60ba

Update app.py

Files changed (1):
  1. app.py +13 -68
app.py CHANGED
@@ -67,74 +67,17 @@ def generate_image(
 
     start_time = time.time()
 
-    # Initialize static inputs for CUDA graph
-    static_latents = torch.randn(
-        (1, 4, height // 8, width // 8), dtype=dtype, device="cuda"
-    )
-    static_prompt_embeds = torch.randn(
-        (2, 77, 768), dtype=dtype, device="cuda"
-    )  # Adjust dimensions as needed
-    static_pooled_prompt_embeds = torch.randn(
-        (2, 768), dtype=dtype, device="cuda"
-    )  # Adjust dimensions as needed
-    static_text_ids = torch.tensor([[[1, 2, 3]]], dtype=torch.int32, device="cuda")
-    static_latent_image_ids = torch.tensor([1], dtype=torch.int64, device="cuda")
-    static_timestep = torch.tensor([999], dtype=dtype, device="cuda")
-
-    # Warmup
-    s = torch.cuda.Stream()
-    s.wait_stream(torch.cuda.current_stream())
-    with torch.cuda.stream(s):
-        for _ in range(3):
-            _ = pipe.transformer(
-                hidden_states=static_latents,
-                timestep=static_timestep / 1000,
-                guidance=None,
-                pooled_projections=static_pooled_prompt_embeds,
-                encoder_hidden_states=static_prompt_embeds,
-                txt_ids=static_text_ids,
-                img_ids=static_latent_image_ids,
-                return_dict=False,
-            )
-    torch.cuda.current_stream().wait_stream(s)
-
-    # Capture CUDA Graph
-    g = torch.cuda.CUDAGraph()
-    with torch.cuda.graph(g):
-        static_noise_pred = pipe.transformer(
-            hidden_states=static_latents,
-            timestep=static_timestep / 1000,
-            guidance=None,
-            pooled_projections=static_pooled_prompt_embeds,
-            encoder_hidden_states=static_prompt_embeds,
-            txt_ids=static_text_ids,
-            img_ids=static_latent_image_ids,
-            return_dict=False,
-        )[0]
-        static_latents_out = pipe.scheduler.step(
-            static_noise_pred, static_timestep, static_latents, return_dict=False
-        )[0]
-        static_output = pipe._decode_latents_to_image(
-            static_latents_out, height, width, "pil"
-        )
-
-    # Graph-based generation function
-    def generate_with_graph(
-        latents,
-        prompt_embeds,
-        pooled_prompt_embeds,
-        text_ids,
-        latent_image_ids,
-        timestep,
-    ):
-        static_latents.copy_(latents)
-        static_prompt_embeds.copy_(prompt_embeds)
-        static_pooled_prompt_embeds.copy_(pooled_prompt_embeds)
-        static_text_ids.copy_(text_ids)
-        static_latent_image_ids.copy_(latent_image_ids)
-        static_timestep.copy_(timestep)
-        g.replay()
-        return static_output
+    # Dynamically determine shapes based on input width/height
+    latents_shape = (1, 4, height // 8, width // 8)
+    prompt_embeds_shape = (
+        1,
+        pipe.transformer.text_encoder.config.max_position_embeddings,
+        pipe.transformer.text_encoder.config.hidden_size,
+    )
+    pooled_prompt_embeds_shape = (
+        1,
+        pipe.transformer.text_encoder.config.hidden_size,
+    )
 
     # Only generate the last image in the sequence
     img = pipe.generate_images(
@@ -143,7 +86,9 @@ def generate_image(
         height=height,
         num_inference_steps=num_inference_steps,
         generator=generator,
-        generate_with_graph=generate_with_graph,
+        latents_shape=latents_shape,
+        prompt_embeds_shape=prompt_embeds_shape,
+        pooled_prompt_embeds_shape=pooled_prompt_embeds_shape
    )
    latency = f"Latency: {(time.time()-start_time):.2f} seconds"
    return img, seed, latency
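For reference, the removed block followed PyTorch's standard CUDA Graph capture/replay recipe: warm up on a side stream, capture one step into a torch.cuda.CUDAGraph, then replay the recorded kernels after copying fresh inputs into the captured static buffers. Graph capture requires fixed tensor shapes and addresses, which is why the hard-coded static tensors above existed. A minimal runnable sketch of that pattern, using a stand-in nn.Linear rather than the app's pipe.transformer:

import torch
import torch.nn as nn

device = "cuda"
model = nn.Linear(64, 64).to(device)

# Static input buffer: CUDA graphs need fixed shapes and fixed addresses.
static_in = torch.randn(8, 64, device=device)

# Warmup on a side stream (the recipe from the torch.cuda.graphs docs).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_in)
torch.cuda.current_stream().wait_stream(s)

# Capture: kernels launched inside this context are recorded, not executed.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out = model(static_in)

def run(x):
    # Replay overwrites static_out in place, so copy new inputs in first.
    static_in.copy_(x)
    g.replay()
    return static_out

print(run(torch.randn(8, 64, device=device)).shape)

Replay skips per-kernel launch overhead but pins the whole computation to one set of shapes, which is presumably why it gives way here to shape hints computed per request.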
 
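The replacement derives the text-embedding shapes from the encoder config instead of hard-coding (2, 77, 768). A sketch of the same computation with a stand-in config object; the 77/768 values are assumptions chosen only to mirror the dims the old code hard-coded:

from types import SimpleNamespace

# Stand-in for pipe.transformer.text_encoder.config; 77 and 768 are
# assumed here only to match the (2, 77, 768) the old code hard-coded.
config = SimpleNamespace(max_position_embeddings=77, hidden_size=768)

height, width = 1024, 1024
latents_shape = (1, 4, height // 8, width // 8)  # VAE downsamples by 8x
prompt_embeds_shape = (1, config.max_position_embeddings, config.hidden_size)
pooled_prompt_embeds_shape = (1, config.hidden_size)

print(latents_shape)               # (1, 4, 128, 128)
print(prompt_embeds_shape)         # (1, 77, 768)
print(pooled_prompt_embeds_shape)  # (1, 768)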