RanM committed on
Commit
28413d5
·
verified ·
1 Parent(s): d253f4a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -39
app.py CHANGED
@@ -12,86 +12,91 @@ ray.init()
12
@ray.remote
class ModelActor:
    """Ray actor that owns one SDXL-Turbo pipeline and renders prompts to JPEG bytes."""

    def __init__(self):
        # Loading the pipeline is expensive; each actor process holds its own copy.
        print("Loading the model...")
        self.model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
        print("Model loaded successfully.")

    def generate_image(self, prompt, prompt_name):
        """
        Render *prompt* to a JPEG image.

        Parameters:
        - prompt (str): text passed to the text-to-image pipeline.
        - prompt_name (str): label used only in log messages.

        Returns:
        bytes or None: JPEG-encoded image bytes, or None on any failure
        (callers treat None as "generation failed").
        """
        start_time = time.time()
        process_id = os.getpid()
        try:
            print(f"[{process_id}] Generating response for {prompt_name} with prompt: {prompt}")
            # SDXL-Turbo is tuned for a single inference step with no guidance.
            output = self.model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
            print(f"[{process_id}] Output for {prompt_name}: {output}")

            if isinstance(output.images, list) and len(output.images) > 0:
                image = output.images[0]
                buffered = BytesIO()
                # Fix: the original wrapped save() in a second try/except that
                # duplicated the outer handler; one handler is enough.
                image.save(buffered, format="JPEG")
                image_bytes = buffered.getvalue()
                end_time = time.time()
                print(f"[{process_id}] Image bytes length for {prompt_name}: {len(image_bytes)}")
                print(f"[{process_id}] Time taken for {prompt_name}: {end_time - start_time} seconds")
                return image_bytes

            # Fix: the original raised an Exception here only to catch it in its
            # own except clause below; log and return None directly instead.
            print(f"[{process_id}] No images returned by the model for {prompt_name}.")
            return None
        except Exception as e:
            # Best-effort contract: any failure is logged and reported as None.
            print(f"[{process_id}] Error generating image for {prompt_name}: {e}")
            return None
45
 
46
async def queue_api_calls(sentence_mapping, character_dict, selected_style):
    """Build one prompt per paragraph, then fan the image work out over Ray actors.

    Returns a dict mapping paragraph number -> image bytes (or None on failure).
    """
    print(f"queue_api_calls invoked with sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}")
    prompts = []
    for para_no, sentence_list in sentence_mapping.items():
        joined = " ".join(sentence_list)
        print(f"combined_sentence for paragraph {para_no}: {joined}")
        styled = generate_prompt(joined, sentence_mapping, character_dict, selected_style)
        prompts.append((para_no, styled))
        print(f"Generated prompt for paragraph {para_no}: {styled}")

    pool_size = min(len(prompts), 20)  # cap the actor pool at 20 processes
    actors = [ModelActor.remote() for _ in range(pool_size)]

    # Round-robin the prompts across the actor pool.
    pending = [
        actors[idx % pool_size].generate_image.remote(text, f"Prompt {para_no}")
        for idx, (para_no, text) in enumerate(prompts)
    ]
    print("Tasks created for image generation.")

    # ray.get blocks, so each wait runs in its own thread under asyncio.
    results = await asyncio.gather(*(asyncio.to_thread(ray.get, ref) for ref in pending))
    print("Responses received from image generation tasks.")

    images = {}
    for (para_no, _), payload in zip(prompts, results):
        images[para_no] = payload
    print(f"Images generated: {images}")
    return images
70
 
71
def process_prompt(sentence_mapping, character_dict, selected_style):
    """
    Synchronous entry point (used by Gradio) that drives queue_api_calls.

    Parameters:
    - sentence_mapping (dict): paragraph number -> list of sentences.
    - character_dict (dict): character name -> description.
    - selected_style (str): illustration style.

    Returns:
    dict: paragraph number -> image bytes (or None on failure).
    """
    print(f"process_prompt called with sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}")
    coro = queue_api_calls(sentence_mapping, character_dict, selected_style)
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # Normal path: no loop on this thread, so we may drive one ourselves.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        print("Event loop created.")
        cmpt_return = loop.run_until_complete(coro)
    else:
        # Bug fix: the original called run_until_complete() on the already
        # running loop, which always raises RuntimeError. Run the coroutine
        # on a fresh loop in a helper thread instead.
        from concurrent.futures import ThreadPoolExecutor
        with ThreadPoolExecutor(max_workers=1) as pool:
            cmpt_return = pool.submit(asyncio.run, coro).result()
    print(f"process_prompt completed with return value: {cmpt_return}")
    return cmpt_return
83
 
84
# Wire process_prompt into a JSON-in / JSON-out Gradio app.
gradio_interface = gr.Interface(
    fn=process_prompt,
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style"),
    ],
    outputs="json",
)

if __name__ == "__main__":
    print("Launching Gradio interface...")
    gradio_interface.launch()
    print("Gradio interface launched.")
 
12
@ray.remote
class ModelActor:
    """Ray actor that owns one SDXL-Turbo pipeline and renders prompts to JPEG bytes."""

    def __init__(self):
        """
        Initializes the ModelActor class and loads the text-to-image model.
        """
        self.model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")

    async def generate_image(self, prompt, prompt_name):
        """
        Generates an image based on the provided prompt.

        Parameters:
        - prompt (str): The input text for image generation.
        - prompt_name (str): A name for the prompt, used for logging.

        Returns:
        bytes: The generated image data in bytes format, or None if generation fails.
        """
        try:
            # Bug fix: the diffusers pipeline call is synchronous, so
            # `await self.model(...)` raised TypeError (not awaitable), which the
            # broad except silently turned into None for EVERY request. Run the
            # blocking call in a thread so the actor's event loop stays free.
            output = await asyncio.to_thread(
                self.model, prompt=prompt, num_inference_steps=1, guidance_scale=0.0
            )

            if isinstance(output.images, list) and len(output.images) > 0:
                image = output.images[0]
                buffered = BytesIO()
                image.save(buffered, format="JPEG")
                return buffered.getvalue()
            # Pipeline returned no images.
            return None
        except Exception:
            # Best-effort contract: callers treat None as "generation failed".
            return None
46
 
47
async def queue_api_calls(sentence_mapping, character_dict, selected_style):
    """
    Generates images for all provided prompts in parallel using Ray actors.

    Parameters:
    - sentence_mapping (dict): Mapping between paragraph numbers and sentences.
    - character_dict (dict): Dictionary mapping characters to their descriptions.
    - selected_style (str): Selected illustration style.

    Returns:
    dict: A dictionary where keys are paragraph numbers and values are image data in bytes format.
    """
    # One (paragraph, prompt) pair per paragraph.
    prompts = [
        (para_no, generate_prompt(" ".join(sents), sentence_mapping, character_dict, selected_style))
        for para_no, sents in sentence_mapping.items()
    ]

    pool_size = min(len(prompts), 20)  # cap the actor pool at 20 processes
    actors = [ModelActor.remote() for _ in range(pool_size)]

    # Round-robin the prompts across the actor pool.
    pending = [
        actors[idx % pool_size].generate_image.remote(text, f"Prompt {para_no}")
        for idx, (para_no, text) in enumerate(prompts)
    ]

    # ray.get blocks, so each wait runs in its own thread under asyncio.
    results = await asyncio.gather(*(asyncio.to_thread(ray.get, ref) for ref in pending))

    return {para_no: payload for (para_no, _), payload in zip(prompts, results)}
73
 
74
def process_prompt(sentence_mapping, character_dict, selected_style):
    """
    Processes the provided prompts and generates images.

    Parameters:
    - sentence_mapping (dict): Mapping between paragraph numbers and sentences.
    - character_dict (dict): Dictionary mapping characters to their descriptions.
    - selected_style (str): Selected illustration style.

    Returns:
    dict: A dictionary where keys are paragraph numbers and values are image data in bytes format.
    """
    coro = queue_api_calls(sentence_mapping, character_dict, selected_style)
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # Normal path: no loop on this thread, so we may drive one ourselves.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        return loop.run_until_complete(coro)
    # Bug fix: the original called run_until_complete() on the already-running
    # loop, which always raises RuntimeError. Run the coroutine on a fresh loop
    # in a helper thread instead.
    from concurrent.futures import ThreadPoolExecutor
    with ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()
94
 
95
# Wire process_prompt into a JSON-in / JSON-out Gradio app.
gradio_interface = gr.Interface(
    fn=process_prompt,
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style"),
    ],
    outputs="json",
)

if __name__ == "__main__":
    gradio_interface.launch()