RanM committed
Commit 5c3986b · verified · 1 Parent(s): 6706e0b

Update app.py

Files changed (1)
  1. app.py +25 -32
app.py CHANGED
@@ -1,56 +1,45 @@
 import os
 import asyncio
-import concurrent.futures
+from concurrent.futures import ProcessPoolExecutor
 from io import BytesIO
 from diffusers import StableDiffusionPipeline
 import gradio as gr
 from generate_prompts import generate_prompt
 
-# Initialize model globally
+# Load the model once at the start
+print("Loading the Stable Diffusion model...")
 model = StableDiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo")
+print("Model loaded successfully.")
 
 def generate_image(prompt, prompt_name):
-    """
-    Generates an image based on the provided prompt.
-    Parameters:
-    - prompt (str): The input text for image generation.
-    - prompt_name (str): A name for the prompt, used for logging.
-    Returns:
-    bytes: The generated image data in bytes format, or None if generation fails.
-    """
     try:
         print(f"Generating image for {prompt_name}")
-        output = model(prompt=prompt, num_inference_steps=50, guidance_scale=7.5)
-        if isinstance(output.images, list) and len(output.images) > 0:
+        output = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
+        if output and hasattr(output, 'images') and len(output.images) > 0:
+            print(f"Image generated for {prompt_name}")
             image = output.images[0]
             buffered = BytesIO()
             image.save(buffered, format="JPEG")
             image_bytes = buffered.getvalue()
             return image_bytes
         else:
+            print(f"No images found for {prompt_name}")
             return None
     except Exception as e:
         print(f"An error occurred while generating image for {prompt_name}: {e}")
         return None
 
 async def queue_api_calls(sentence_mapping, character_dict, selected_style):
-    """
-    Generates images for all provided prompts in parallel using ProcessPoolExecutor.
-    Parameters:
-    - sentence_mapping (dict): Mapping between paragraph numbers and sentences.
-    - character_dict (dict): Dictionary mapping characters to their descriptions.
-    - selected_style (str): Selected illustration style.
-    Returns:
-    dict: A dictionary where keys are paragraph numbers and values are image data in bytes format.
-    """
+    print("Starting to queue API calls...")
     prompts = []
     for paragraph_number, sentences in sentence_mapping.items():
         combined_sentence = " ".join(sentences)
         prompt = generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
         prompts.append((paragraph_number, prompt))
+        print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")
 
     loop = asyncio.get_running_loop()
-    with concurrent.futures.ProcessPoolExecutor() as pool:
+    with ProcessPoolExecutor() as pool:
         tasks = [
             loop.run_in_executor(pool, generate_image, prompt, f"Prompt {paragraph_number}")
             for paragraph_number, prompt in prompts
@@ -58,32 +47,36 @@ async def queue_api_calls(sentence_mapping, character_dict, selected_style):
         responses = await asyncio.gather(*tasks)
 
     images = {paragraph_number: response for (paragraph_number, _), response in zip(prompts, responses)}
+    print("Finished queuing API calls.")
     return images
 
 def process_prompt(sentence_mapping, character_dict, selected_style):
-    """
-    Processes the provided prompts and generates images.
-    Parameters:
-    - sentence_mapping (dict): Mapping between paragraph numbers and sentences.
-    - character_dict (dict): Dictionary mapping characters to their descriptions.
-    - selected_style (str): Selected illustration style.
-    Returns:
-    dict: A dictionary where keys are paragraph numbers and values are image data in bytes format.
-    """
+    print("Processing prompt...")
+    print(f"Sentence Mapping: {sentence_mapping}")
+    print(f"Character Dict: {character_dict}")
+    print(f"Selected Style: {selected_style}")
     try:
         loop = asyncio.get_running_loop()
+        print("Using existing event loop.")
     except RuntimeError:
         loop = asyncio.new_event_loop()
         asyncio.set_event_loop(loop)
+        print("Created new event loop.")
 
     cmpt_return = loop.run_until_complete(queue_api_calls(sentence_mapping, character_dict, selected_style))
+    print("Prompt processing complete.")
     return cmpt_return
 
 gradio_interface = gr.Interface(
     fn=process_prompt,
-    inputs=[gr.JSON(label="Sentence Mapping"), gr.JSON(label="Character Dict"), gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style")],
+    inputs=[
+        gr.JSON(label="Sentence Mapping"),
+        gr.JSON(label="Character Dict"),
+        gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style")
+    ],
     outputs="json"
 ).queue(default_concurrency_limit=20) # Set concurrency limit if needed
 
 if __name__ == "__main__":
+    print("Launching Gradio interface...")
     gradio_interface.launch()
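
For reference, a minimal local smoke test of the updated process_prompt could look like the sketch below. It is not part of the commit: the sentence_mapping and character_dict values are hypothetical placeholders, and it assumes app.py and generate_prompts.py are importable from the working directory. Importing app loads the SDXL-Turbo pipeline, so the first run is slow.

# smoke_test.py - hypothetical local check of process_prompt (not part of this commit)
from app import process_prompt  # importing app loads the SDXL-Turbo pipeline

if __name__ == "__main__":  # guard matters because ProcessPoolExecutor may spawn worker processes
    # Hypothetical inputs mirroring the Gradio JSON components
    sentence_mapping = {"1": ["A small red fox naps under an old oak tree."]}
    character_dict = {"Fox": "a small red fox with a bushy tail"}

    images = process_prompt(sentence_mapping, character_dict, "watercolor")

    # Each value is raw JPEG bytes, or None if generation failed for that paragraph
    for paragraph_number, image_bytes in images.items():
        print(paragraph_number, "->", None if image_bytes is None else f"{len(image_bytes)} bytes")

Running it should print one line per paragraph with the size of the generated JPEG, which is enough to confirm the ProcessPoolExecutor path and the single-step SDXL-Turbo settings work end to end.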