RanM committed on
Commit
5e2c7ed
·
verified ·
1 Parent(s): 9da79fd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -24
app.py CHANGED
@@ -1,16 +1,24 @@
1
- import gradio as gr
2
- from diffusers import AutoPipelineForText2Image
3
- from io import BytesIO
4
  import asyncio
5
  from generate_prompts import generate_prompt
 
 
 
 
6
 
7
  # Load the model once outside of the function
8
  model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
9
 
 
 
 
 
 
 
10
  async def generate_image(prompt, prompt_name):
11
  try:
12
- print(f"Generating image for {prompt_name}")
13
- output = await asyncio.to_thread(model, prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
14
 
15
  # Check if the model returned images
16
  if isinstance(output.images, list) and len(output.images) > 0:
@@ -30,24 +38,24 @@ async def generate_image(prompt, prompt_name):
30
  print(f"Error generating image for {prompt_name}: {e}")
31
  return None
32
 
33
- async def process_prompt(sentence_mapping, character_dict, selected_style):
34
- images = {}
35
  print(f'sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}')
36
  prompts = []
37
-
38
  # Generate prompts for each paragraph
39
  for paragraph_number, sentences in sentence_mapping.items():
40
  combined_sentence = " ".join(sentences)
41
  prompt = generate_prompt(combined_sentence, character_dict, selected_style)
42
  prompts.append((paragraph_number, prompt))
43
  print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")
44
-
45
- print(f'prompts: {prompts}')
46
- # Create tasks for all prompts and run them concurrently
47
  tasks = [generate_image(prompt, f"Prompt {paragraph_number}") for paragraph_number, prompt in prompts]
48
- print(f'tasks: {tasks}')
49
- results = await asyncio.gather(*tasks)
50
-
 
 
 
51
  # Map results back to paragraphs
52
  for i, (paragraph_number, _) in enumerate(prompts):
53
  if i < len(results):
@@ -57,22 +65,28 @@ async def process_prompt(sentence_mapping, character_dict, selected_style):
57
 
58
  return images
59
 
60
- # Helper function to generate a prompt based on the input
61
- def generate_prompt(combined_sentence, character_dict, selected_style):
62
- characters = " ".join([" ".join(character) if isinstance(character, list) else character for character in character_dict.values()])
63
- return f"Make an illustration in {selected_style} style from: {characters}. {combined_sentence}"
 
 
 
 
 
 
 
 
64
 
65
  # Gradio interface with high concurrency limit
66
  gradio_interface = gr.Interface(
67
- fn=process_prompt,
68
  inputs=[
69
  gr.JSON(label="Sentence Mapping"),
70
  gr.JSON(label="Character Dict"),
71
  gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style")
72
  ],
73
- outputs="json",
74
- concurrency_limit=20 # Set a high concurrency limit
75
- ).queue(default_concurrency_limit=20)
76
-
77
  if __name__ == "__main__":
78
- gradio_interface.launch()
 
1
+ import os
 
 
2
  import asyncio
3
  from generate_prompts import generate_prompt
4
+ from diffusers import AutoPipelineForText2Image
5
+ from io import BytesIO
6
+ import json
7
+ import gradio as gr
8
 
9
# Load the SDXL-Turbo text-to-image pipeline once at module import so every
# request reuses the same in-memory weights (from_pretrained is slow and
# memory-heavy; re-loading per call would dominate request latency).
model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
11
 
12
+
13
# Standalone demo prompts (not referenced by the Gradio flow below —
# presumably kept for manual experimentation; TODO confirm they are needed).
# Each asks for a reply wrapped in a small JSON envelope keyed by prompt name.
# Fix: prompt2 previously closed its placeholder with '}' ('[response}')
# instead of ']' like the other three.
prompt1 = "write a 5 paragraph explanation of how to use python async and await. Return a JSON structure as follows {'prompt_name': 'prompt1','response': '[response]'}"
prompt2 = "write a 5 paragraph explanation of limitations for using asyncio.run(). Return a JSON structure as follows {'prompt_name': 'prompt2','response': '[response]'}"
prompt3 = "write a 5 paragraph explanation of how to use asyncio.get_running_loop(). Return a JSON structure as follows {'prompt_name': 'prompt3','response': '[response]'}"
prompt4 = "write a 5 paragraph explanation of how to use asyncio.gather(). Return a JSON structure as follows {'prompt_name': 'prompt4','response': '[response]'}"
17
+
18
  async def generate_image(prompt, prompt_name):
19
  try:
20
+ print(f"Generating response for {prompt_name}")
21
+ output = await model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
22
 
23
  # Check if the model returned images
24
  if isinstance(output.images, list) and len(output.images) > 0:
 
38
  print(f"Error generating image for {prompt_name}: {e}")
39
  return None
40
 
41
+ async def queue_api_calls(sentence_mapping, character_dict, selected_style):
 
42
  print(f'sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}')
43
  prompts = []
44
+
45
  # Generate prompts for each paragraph
46
  for paragraph_number, sentences in sentence_mapping.items():
47
  combined_sentence = " ".join(sentences)
48
  prompt = generate_prompt(combined_sentence, character_dict, selected_style)
49
  prompts.append((paragraph_number, prompt))
50
  print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")
51
+
 
 
52
  tasks = [generate_image(prompt, f"Prompt {paragraph_number}") for paragraph_number, prompt in prompts]
53
+ responses = await asyncio.gather(generate_image(*task))
54
+
55
+ # Note: although the API calls are processed concurrently, asyncio.gather returns their results in request order
56
+ images = {}
57
+
58
+ # Iterate through each response
59
  # Map results back to paragraphs
60
  for i, (paragraph_number, _) in enumerate(prompts):
61
  if i < len(results):
 
65
 
66
  return images
67
 
68
def process_prompt(sentence_mapping, character_dict, selected_style):
    """Synchronous bridge between Gradio and the async image pipeline.

    Builds the coroutine that fans out one image request per paragraph and
    blocks until every request has completed.

    Args:
        sentence_mapping: dict mapping paragraph number -> list of sentences.
        character_dict: mapping of character name -> description(s) —
            presumably; verify against the caller that feeds this JSON input.
        selected_style: illustration style chosen in the UI dropdown.

    Returns:
        Whatever queue_api_calls returns (a dict keyed by paragraph number).
    """
    coroutine = queue_api_calls(sentence_mapping, character_dict, selected_style)
    try:
        # get_running_loop() only succeeds when this thread is *inside* a
        # running event loop — in which case run_until_complete() on that
        # loop raises RuntimeError, and blocking here would deadlock it.
        asyncio.get_running_loop()
    except RuntimeError:
        # Normal case (Gradio worker thread, no loop): asyncio.run creates
        # a fresh loop, runs the coroutine, and tears the loop down cleanly.
        return asyncio.run(coroutine)
    # Rare case: called from inside an event loop's own thread. Run the
    # coroutine on its own loop in a worker thread and block on the result.
    import concurrent.futures
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coroutine).result()
80
 
81
  # Gradio interface with high concurrency limit
82
  gradio_interface = gr.Interface(
83
+ fn=process_prompt,
84
  inputs=[
85
  gr.JSON(label="Sentence Mapping"),
86
  gr.JSON(label="Character Dict"),
87
  gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style")
88
  ],
89
+ outputs="json")
90
+
 
 
91
if __name__ == "__main__":
    # Fix: `demo` is never defined in this file (would raise NameError);
    # launch the Interface object that is actually constructed above.
    gradio_interface.launch()