Ryukijano commited on
Commit
89fffa1
·
verified ·
1 Parent(s): 70eea75

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -9
app.py CHANGED
@@ -53,7 +53,7 @@ torch.cuda.empty_cache()
53
 
54
  # Inference function
55
  @spaces.GPU(duration=25)
56
- def generate_image(
57
  prompt,
58
  seed=24,
59
  width=DEFAULT_WIDTH,
@@ -95,7 +95,6 @@ def generate_image(
95
  encoder_hidden_states=static_prompt_embeds,
96
  txt_ids=static_text_ids,
97
  img_ids=static_latent_image_ids,
98
- joint_attention_kwargs=pipe.joint_attention_kwargs,
99
  return_dict=False,
100
  )
101
  torch.cuda.current_stream().wait_stream(s)
@@ -111,7 +110,6 @@ def generate_image(
111
  encoder_hidden_states=static_prompt_embeds,
112
  txt_ids=static_text_ids,
113
  img_ids=static_latent_image_ids,
114
- joint_attention_kwargs=pipe.joint_attention_kwargs,
115
  return_dict=False,
116
  )[0]
117
  static_latents_out = pipe.scheduler.step(
@@ -122,7 +120,7 @@ def generate_image(
122
  )
123
 
124
  # Graph-based generation function
125
- def generate_with_graph(
126
  latents,
127
  prompt_embeds,
128
  pooled_prompt_embeds,
@@ -140,7 +138,7 @@ def generate_image(
140
  return static_output
141
 
142
  # Only generate the last image in the sequence
143
- img = pipe.generate_images(
144
  prompt=prompt,
145
  width=width,
146
  height=height,
@@ -251,7 +249,7 @@ with gr.Blocks() as demo:
251
  outputs=[result, seed, latency],
252
  show_progress="full",
253
  api_name="RealtimeFlux",
254
- queue=False,
255
  )
256
 
257
  def update_ui(realtime_enabled):
@@ -269,10 +267,9 @@ with gr.Blocks() as demo:
269
  )
270
 
271
  async def realtime_generation(*args):
 
272
  if args[0]: # If realtime is enabled
273
- loop = asyncio.get_event_loop()
274
- result = await loop.run_in_executor(None, next, generate_image(*args[1:]))
275
- return result
276
 
277
  prompt.submit(
278
  fn=generate_image,
 
53
 
54
  # Inference function
55
  @spaces.GPU(duration=25)
56
+ async def generate_image(
57
  prompt,
58
  seed=24,
59
  width=DEFAULT_WIDTH,
 
95
  encoder_hidden_states=static_prompt_embeds,
96
  txt_ids=static_text_ids,
97
  img_ids=static_latent_image_ids,
 
98
  return_dict=False,
99
  )
100
  torch.cuda.current_stream().wait_stream(s)
 
110
  encoder_hidden_states=static_prompt_embeds,
111
  txt_ids=static_text_ids,
112
  img_ids=static_latent_image_ids,
 
113
  return_dict=False,
114
  )[0]
115
  static_latents_out = pipe.scheduler.step(
 
120
  )
121
 
122
  # Graph-based generation function
123
+ async def generate_with_graph(
124
  latents,
125
  prompt_embeds,
126
  pooled_prompt_embeds,
 
138
  return static_output
139
 
140
  # Only generate the last image in the sequence
141
+ img = await pipe.generate_images(
142
  prompt=prompt,
143
  width=width,
144
  height=height,
 
249
  outputs=[result, seed, latency],
250
  show_progress="full",
251
  api_name="RealtimeFlux",
252
+ queue=False
253
  )
254
 
255
  def update_ui(realtime_enabled):
 
267
  )
268
 
269
  async def realtime_generation(*args):
270
+ print("realtime_generation")
271
  if args[0]: # If realtime is enabled
272
+ return await generate_image(*args[1:])
 
 
273
 
274
  prompt.submit(
275
  fn=generate_image,