RanM commited on
Commit
eb48f29
·
verified ·
1 Parent(s): 9cd3a95

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -46
app.py CHANGED
@@ -1,50 +1,40 @@
1
  import os
2
  import asyncio
3
- import time
4
- from generate_prompts import generate_prompt
5
- from diffusers import AutoPipelineForText2Image
6
  from io import BytesIO
 
7
  import gradio as gr
8
- import ray
9
-
10
- ray.init()
11
 
12
- @ray.remote
13
- class ModelActor:
14
- def __init__(self):
15
- """
16
- Initializes the ModelActor class and loads the text-to-image model.
17
- """
18
- self.model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
19
 
20
- async def generate_image(self, prompt, prompt_name):
21
- """
22
- Generates an image based on the provided prompt.
23
- Parameters:
24
- - prompt (str): The input text for image generation.
25
- - prompt_name (str): A name for the prompt, used for logging.
26
- Returns:
27
- bytes: The generated image data in bytes format, or None if generation fails.
28
- """
29
- start_time = time.time()
30
- process_id = os.getpid()
31
- try:
32
- output = await self.model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
33
- if isinstance(output.images, list) and len(output.images) > 0:
34
- image = output.images[0]
35
- buffered = BytesIO()
36
- image.save(buffered, format="JPEG")
37
- image_bytes = buffered.getvalue()
38
- end_time = time.time()
39
- return image_bytes
40
- else:
41
- return None
42
- except Exception as e:
43
  return None
 
 
 
44
 
45
  async def queue_api_calls(sentence_mapping, character_dict, selected_style):
46
  """
47
- Generates images for all provided prompts in parallel using Ray actors.
48
  Parameters:
49
  - sentence_mapping (dict): Mapping between paragraph numbers and sentences.
50
  - character_dict (dict): Dictionary mapping characters to their descriptions.
@@ -58,14 +48,8 @@ async def queue_api_calls(sentence_mapping, character_dict, selected_style):
58
  prompt = generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
59
  prompts.append((paragraph_number, prompt))
60
 
61
- num_prompts = len(prompts)
62
- num_actors = min(num_prompts, 20)
63
- model_actors = [ModelActor.remote() for _ in range(num_actors)]
64
- tasks = [model_actors[i % num_actors].generate_image.remote(prompt, f"Prompt {paragraph_number}") for i, (paragraph_number, prompt) in enumerate(prompts)]
65
-
66
- # Convert ray.get(task) to awaitable coroutines
67
- async_tasks = [asyncio.wrap_future(ray.get(task)) for task in tasks]
68
- responses = await asyncio.gather(*async_tasks)
69
 
70
  images = {paragraph_number: response for (paragraph_number, _), response in zip(prompts, responses)}
71
  return images
@@ -93,7 +77,7 @@ gradio_interface = gr.Interface(
93
  fn=process_prompt,
94
  inputs=[gr.JSON(label="Sentence Mapping"), gr.JSON(label="Character Dict"), gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style")],
95
  outputs="json"
96
- ).queue(default_concurrency_limit=20) # Set concurrency limit to match the number of model actors
97
 
98
  if __name__ == "__main__":
99
  gradio_interface.launch()
 
1
  import os
2
  import asyncio
 
 
 
3
  from io import BytesIO
4
+ from diffusers import AutoPipelineForText2Image
5
  import gradio as gr
6
+ from generate_prompts import generate_prompt
 
 
7
 
8
+ # Initialize model
9
+ model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
 
 
 
 
 
10
 
11
+ async def generate_image(prompt, prompt_name):
12
+ """
13
+ Generates an image based on the provided prompt.
14
+ Parameters:
15
+ - prompt (str): The input text for image generation.
16
+ - prompt_name (str): A name for the prompt, used for logging.
17
+ Returns:
18
+ bytes: The generated image data in bytes format, or None if generation fails.
19
+ """
20
+ try:
21
+ print(f"Generating image for {prompt_name}")
22
+ output = await model(prompt=prompt, num_inference_steps=50, guidance_scale=7.5)
23
+ if isinstance(output.images, list) and len(output.images) > 0:
24
+ image = output.images[0]
25
+ buffered = BytesIO()
26
+ image.save(buffered, format="JPEG")
27
+ image_bytes = buffered.getvalue()
28
+ return image_bytes
29
+ else:
 
 
 
 
30
  return None
31
+ except Exception as e:
32
+ print(f"An error occurred while generating image for {prompt_name}: {e}")
33
+ return None
34
 
35
  async def queue_api_calls(sentence_mapping, character_dict, selected_style):
36
  """
37
+ Generates images for all provided prompts in parallel using asyncio.
38
  Parameters:
39
  - sentence_mapping (dict): Mapping between paragraph numbers and sentences.
40
  - character_dict (dict): Dictionary mapping characters to their descriptions.
 
48
  prompt = generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
49
  prompts.append((paragraph_number, prompt))
50
 
51
+ tasks = [generate_image(prompt, f"Prompt {paragraph_number}") for paragraph_number, prompt in prompts]
52
+ responses = await asyncio.gather(*tasks)
 
 
 
 
 
 
53
 
54
  images = {paragraph_number: response for (paragraph_number, _), response in zip(prompts, responses)}
55
  return images
 
77
  fn=process_prompt,
78
  inputs=[gr.JSON(label="Sentence Mapping"), gr.JSON(label="Character Dict"), gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style")],
79
  outputs="json"
80
+ ).queue(default_concurrency_limit=20) # Set concurrency limit if needed
81
 
82
  if __name__ == "__main__":
83
  gradio_interface.launch()