RanM commited on
Commit
c513221
·
verified ·
1 Parent(s): e647062

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -34
app.py CHANGED
@@ -1,23 +1,19 @@
1
- import asyncio
 
2
  from generate_prompts import generate_prompt
3
  from diffusers import AutoPipelineForText2Image
4
  from io import BytesIO
5
  import gradio as gr
6
- import anyio
7
 
8
- # Asynchronously load the model once outside of the function
9
- model = None
10
-
11
- async def load_model():
12
  global model
13
  print("Loading the model...")
14
  model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
15
  print("Model loaded successfully.")
16
 
17
- # Run the model loading
18
- asyncio.run(load_model())
19
-
20
- async def generate_image(prompt, prompt_name):
21
  try:
22
  print(f"Generating response for {prompt_name} with prompt: {prompt}")
23
  output = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
@@ -31,50 +27,40 @@ async def generate_image(prompt, prompt_name):
31
  image.save(buffered, format="JPEG")
32
  image_bytes = buffered.getvalue()
33
  print(f"Image bytes length for {prompt_name}: {len(image_bytes)}")
34
- return image_bytes
35
  except Exception as e:
36
  print(f"Error saving image for {prompt_name}: {e}")
37
- return None
38
  else:
39
  raise Exception(f"No images returned by the model for {prompt_name}.")
40
  except Exception as e:
41
  print(f"Error generating image for {prompt_name}: {e}")
42
- return None
43
 
44
- async def queue_api_calls(sentence_mapping, character_dict, selected_style):
45
- print(f"queue_api_calls invoked with sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}")
 
46
  prompts = []
47
-
48
- # Generate prompts for each paragraph
49
  for paragraph_number, sentences in sentence_mapping.items():
50
  combined_sentence = " ".join(sentences)
51
  print(f"combined_sentence for paragraph {paragraph_number}: {combined_sentence}")
52
- prompt = generate_prompt(combined_sentence, character_dict, selected_style) # Correct prompt generation
53
  prompts.append((paragraph_number, prompt))
54
  print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")
55
 
56
- # Generate images for each prompt in parallel
57
- tasks = [generate_image(prompt, f"Prompt {paragraph_number}") for paragraph_number, prompt in prompts]
58
- print("Tasks created for image generation.")
59
- responses = await asyncio.gather(*tasks)
60
- print("Responses received from image generation tasks.")
61
 
62
- images = {paragraph_number: response for (paragraph_number, _), response in zip(prompts, responses)}
63
  print(f"Images generated: {images}")
64
  return images
65
 
66
  def process_prompt(sentence_mapping, character_dict, selected_style):
67
- print(f"process_prompt called with sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}")
68
-
69
- async def run_async():
70
- async with anyio.create_task_group() as task_group:
71
- return await task_group.start(queue_api_calls, sentence_mapping, character_dict, selected_style)
72
-
73
- cmpt_return = anyio.run(run_async)
74
- print(f"process_prompt completed with return value: {cmpt_return}")
75
- return cmpt_return
76
 
77
- # Gradio interface with high concurrency limit
78
  gradio_interface = gr.Interface(
79
  fn=process_prompt,
80
  inputs=[
 
1
+ import os
2
+ import multiprocessing
3
  from generate_prompts import generate_prompt
4
  from diffusers import AutoPipelineForText2Image
5
  from io import BytesIO
6
  import gradio as gr
7
+ import json
8
 
9
+ # Define a function to initialize the model. This will be called in each process.
10
+ def initialize_model():
 
 
11
  global model
12
  print("Loading the model...")
13
  model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
14
  print("Model loaded successfully.")
15
 
16
+ def generate_image(prompt, prompt_name):
 
 
 
17
  try:
18
  print(f"Generating response for {prompt_name} with prompt: {prompt}")
19
  output = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
 
27
  image.save(buffered, format="JPEG")
28
  image_bytes = buffered.getvalue()
29
  print(f"Image bytes length for {prompt_name}: {len(image_bytes)}")
30
+ return prompt_name, image_bytes
31
  except Exception as e:
32
  print(f"Error saving image for {prompt_name}: {e}")
33
+ return prompt_name, None
34
  else:
35
  raise Exception(f"No images returned by the model for {prompt_name}.")
36
  except Exception as e:
37
  print(f"Error generating image for {prompt_name}: {e}")
38
+ return prompt_name, None
39
 
40
+ def process_prompts(sentence_mapping, character_dict, selected_style):
41
+ print(f"process_prompts called with sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}")
42
+
43
  prompts = []
 
 
44
  for paragraph_number, sentences in sentence_mapping.items():
45
  combined_sentence = " ".join(sentences)
46
  print(f"combined_sentence for paragraph {paragraph_number}: {combined_sentence}")
47
+ prompt = generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
48
  prompts.append((paragraph_number, prompt))
49
  print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")
50
 
51
+ with multiprocessing.Pool(initializer=initialize_model) as pool:
52
+ tasks = [(prompt, f"Prompt {paragraph_number}") for paragraph_number, prompt in prompts]
53
+ results = pool.starmap(generate_image, tasks)
 
 
54
 
55
+ images = {prompt_name: image for prompt_name, image in results}
56
  print(f"Images generated: {images}")
57
  return images
58
 
59
  def process_prompt(sentence_mapping, character_dict, selected_style):
60
+ sentence_mapping = json.loads(sentence_mapping)
61
+ character_dict = json.loads(character_dict)
62
+ return process_prompts(sentence_mapping, character_dict, selected_style)
 
 
 
 
 
 
63
 
 
64
  gradio_interface = gr.Interface(
65
  fn=process_prompt,
66
  inputs=[