RanM commited on
Commit
081cd9c
·
verified ·
1 Parent(s): 7aae5d0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -13
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import gradio as gr
2
  from diffusers import AutoPipelineForText2Image
3
  from generate_prompts import generate_prompt
@@ -5,10 +6,10 @@ from generate_prompts import generate_prompt
5
  # Load the model once outside of the function
6
  model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
7
 
8
- def generate_image(prompt, prompt_name):
9
  try:
10
  print(f"Generating image for {prompt_name}")
11
- output = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
12
  image = output.images[0]
13
  img_bytes = image.tobytes()
14
  print(f"Image bytes length for {prompt_name}: {len(img_bytes)}")
@@ -17,24 +18,52 @@ def generate_image(prompt, prompt_name):
17
  print(f"Error generating image for {prompt_name}: {e}")
18
  return None
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  def gradio_interface(sentence_mapping, character_dict, selected_style):
21
  prompts = generate_prompt(sentence_mapping, character_dict, selected_style)
22
- image_bytes_list = [generate_image(prompt, f"Prompt {i}") for i, prompt in enumerate(prompts)]
23
  outputs = [gr.Image.update(value=img_bytes) if img_bytes else gr.Image.update(value=None) for img_bytes in image_bytes_list]
24
  return outputs
25
 
26
  # Gradio Interface
 
 
 
 
 
27
  with gr.Blocks() as demo:
28
- with gr.Row():
29
- with gr.Column():
30
- sentence_mapping_input = gr.Textbox(label="Sentence Mapping")
31
- character_dict_input = gr.Textbox(label="Character Dictionary")
32
- selected_style_input = gr.Textbox(label="Selected Style")
33
- submit_btn = gr.Button(value='Submit')
34
- prompt_responses = [] # Empty list for dynamic addition of Image components
35
- submit_btn.click(fn=gradio_interface,
36
- inputs=[sentence_mapping_input, character_dict_input, selected_style_input],
37
- outputs=prompt_responses)
 
 
 
 
 
 
 
 
 
38
 
39
  if __name__ == "__main__":
40
  demo.launch()
 
1
+ import asyncio
2
  import gradio as gr
3
  from diffusers import AutoPipelineForText2Image
4
  from generate_prompts import generate_prompt
 
6
  # Load the model once outside of the function
7
  model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
8
 
9
+ async def generate_image(prompt, prompt_name):
10
  try:
11
  print(f"Generating image for {prompt_name}")
12
+ output = await model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
13
  image = output.images[0]
14
  img_bytes = image.tobytes()
15
  print(f"Image bytes length for {prompt_name}: {len(img_bytes)}")
 
18
  print(f"Error generating image for {prompt_name}: {e}")
19
  return None
20
 
21
+ async def queue_image_calls(prompts):
22
+ tasks = [generate_image(prompts[i], f"Prompt {i}") for i in range(len(prompts))]
23
+ responses = await asyncio.gather(*tasks)
24
+ return responses
25
+
26
+ def async_image_generation(prompts):
27
+ try:
28
+ loop = asyncio.get_running_loop()
29
+ except RuntimeError:
30
+ loop = asyncio.new_event_loop()
31
+ asyncio.set_event_loop(loop)
32
+ results = loop.run_until_complete(queue_image_calls(prompts))
33
+ return results
34
+
35
  def gradio_interface(sentence_mapping, character_dict, selected_style):
36
  prompts = generate_prompt(sentence_mapping, character_dict, selected_style)
37
+ image_bytes_list = async_image_generation(prompts)
38
  outputs = [gr.Image.update(value=img_bytes) if img_bytes else gr.Image.update(value=None) for img_bytes in image_bytes_list]
39
  return outputs
40
 
41
  # Gradio Interface
42
+ def update_images(sentence_mapping, character_dict, selected_style):
43
+ prompts = generate_prompt(sentence_mapping, character_dict, selected_style)
44
+ image_bytes_list = async_image_generation(prompts)
45
+ return image_bytes_list
46
+
47
  with gr.Blocks() as demo:
48
+ sentence_mapping_input = gr.Textbox(label="Sentence Mapping")
49
+ character_dict_input = gr.Textbox(label="Character Dictionary")
50
+ selected_style_input = gr.Textbox(label="Selected Style")
51
+
52
+ output_images = gr.Gallery(label="Generated Images").style(grid=[2], height=300)
53
+
54
+ def generate_and_update_images(sentence_mapping, character_dict, selected_style):
55
+ image_bytes_list = update_images(sentence_mapping, character_dict, selected_style)
56
+ return [gr.Image.update(value=img_bytes) if img_bytes else gr.Image.update(value=None) for img_bytes in image_bytes_list]
57
+
58
+ sentence_mapping_input.change(fn=generate_and_update_images,
59
+ inputs=[sentence_mapping_input, character_dict_input, selected_style_input],
60
+ outputs=output_images)
61
+ character_dict_input.change(fn=generate_and_update_images,
62
+ inputs=[sentence_mapping_input, character_dict_input, selected_style_input],
63
+ outputs=output_images)
64
+ selected_style_input.change(fn=generate_and_update_images,
65
+ inputs=[sentence_mapping_input, character_dict_input, selected_style_input],
66
+ outputs=output_images)
67
 
68
  if __name__ == "__main__":
69
  demo.launch()