RanM committed on
Commit
c14304d
·
verified ·
1 Parent(s): 5c6cd7b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -8
app.py CHANGED
@@ -5,17 +5,16 @@ from diffusers import AutoPipelineForText2Image
5
  from io import BytesIO
6
  import gradio as gr
7
  from concurrent.futures import ProcessPoolExecutor
8
- import multiprocessing
9
 
10
- # Load the model once outside of the function and share across processes
11
  print("Loading the model...")
12
  model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
13
  print("Model loaded successfully.")
14
 
15
- def generate_image(prompt, prompt_name, num_inference_steps=1):
16
  try:
17
  print(f"Generating response for {prompt_name} with prompt: {prompt}")
18
- output = model(prompt=prompt, num_inference_steps=num_inference_steps, guidance_scale=0.0)
19
  print(f"Output for {prompt_name}: {output}")
20
 
21
  # Check if the model returned images
@@ -44,12 +43,12 @@ async def queue_api_calls(sentence_mapping, character_dict, selected_style):
44
  for paragraph_number, sentences in sentence_mapping.items():
45
  combined_sentence = " ".join(sentences)
46
  print(f"combined_sentence for paragraph {paragraph_number}: {combined_sentence}")
47
- prompt = generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
48
  prompts.append((paragraph_number, prompt))
49
  print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")
50
 
51
- # Set max_workers to the minimum of the number of prompts and available CPU cores
52
- max_workers = min(len(prompts), multiprocessing.cpu_count())
53
 
54
  # Generate images for each prompt in parallel using multiprocessing
55
  with ProcessPoolExecutor(max_workers=max_workers) as executor:
@@ -66,12 +65,15 @@ async def queue_api_calls(sentence_mapping, character_dict, selected_style):
66
  def process_prompt(sentence_mapping, character_dict, selected_style):
67
  print(f"process_prompt called with sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}")
68
  try:
 
69
  loop = asyncio.get_running_loop()
70
  except RuntimeError:
 
71
  loop = asyncio.new_event_loop()
72
  asyncio.set_event_loop(loop)
73
  print("Event loop created.")
74
 
 
75
  cmpt_return = loop.run_until_complete(queue_api_calls(sentence_mapping, character_dict, selected_style))
76
  print(f"process_prompt completed with return value: {cmpt_return}")
77
  return cmpt_return
@@ -89,5 +91,5 @@ gradio_interface = gr.Interface(
89
 
90
  if __name__ == "__main__":
91
  print("Launching Gradio interface...")
92
- gradio_interface.launch(server_name="0.0.0.0", server_port=7860, debug=True)
93
  print("Gradio interface launched.")
 
5
  from io import BytesIO
6
  import gradio as gr
7
  from concurrent.futures import ProcessPoolExecutor
 
8
 
9
+ # Load the model once outside of the function
10
  print("Loading the model...")
11
  model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
12
  print("Model loaded successfully.")
13
 
14
+ def generate_image(prompt, prompt_name):
15
  try:
16
  print(f"Generating response for {prompt_name} with prompt: {prompt}")
17
+ output = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
18
  print(f"Output for {prompt_name}: {output}")
19
 
20
  # Check if the model returned images
 
43
  for paragraph_number, sentences in sentence_mapping.items():
44
  combined_sentence = " ".join(sentences)
45
  print(f"combined_sentence for paragraph {paragraph_number}: {combined_sentence}")
46
+ prompt = generate_prompt(combined_sentence, character_dict, selected_style)
47
  prompts.append((paragraph_number, prompt))
48
  print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")
49
 
50
+ # Set max_workers to the total number of prompts
51
+ max_workers = len(prompts)
52
 
53
  # Generate images for each prompt in parallel using multiprocessing
54
  with ProcessPoolExecutor(max_workers=max_workers) as executor:
 
65
  def process_prompt(sentence_mapping, character_dict, selected_style):
66
  print(f"process_prompt called with sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}")
67
  try:
68
+ # See if there is a loop already running. If there is, reuse it.
69
  loop = asyncio.get_running_loop()
70
  except RuntimeError:
71
+ # Create new event loop if one is not running
72
  loop = asyncio.new_event_loop()
73
  asyncio.set_event_loop(loop)
74
  print("Event loop created.")
75
 
76
+ # This sends the prompts to function that sets up the async calls. Once all the calls to the API complete, it returns a list of the gr.Textbox with value= set.
77
  cmpt_return = loop.run_until_complete(queue_api_calls(sentence_mapping, character_dict, selected_style))
78
  print(f"process_prompt completed with return value: {cmpt_return}")
79
  return cmpt_return
 
91
 
92
  if __name__ == "__main__":
93
  print("Launching Gradio interface...")
94
+ gradio_interface.launch()
95
  print("Gradio interface launched.")