RanM committed on
Commit
690f094
·
verified ·
1 Parent(s): e4c2663

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -48
app.py CHANGED
@@ -1,7 +1,8 @@
1
- import asyncio
2
  import gradio as gr
3
  from diffusers import AutoPipelineForText2Image
4
- from generate_prompts import generate_prompt
 
 
5
 
6
  # Load the model once outside of the function
7
  model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
@@ -9,61 +10,69 @@ model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
9
  async def generate_image(prompt, prompt_name):
10
  try:
11
  print(f"Generating image for {prompt_name}")
12
- output = await model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
13
- image = output.images[0]
14
- img_bytes = image.tobytes()
15
- print(f"Image bytes length for {prompt_name}: {len(img_bytes)}")
16
- return img_bytes
 
 
 
 
 
 
 
 
 
 
 
17
  except Exception as e:
18
  print(f"Error generating image for {prompt_name}: {e}")
19
  return None
20
 
21
- async def queue_image_calls(prompts):
22
- tasks = [generate_image(prompts[i], f"Prompt {i}") for i in range(len(prompts))]
23
- responses = await asyncio.gather(*tasks)
24
- return responses
25
 
26
- def async_image_generation(prompts):
27
- try:
28
- loop = asyncio.get_running_loop()
29
- except RuntimeError:
30
- loop = asyncio.new_event_loop()
31
- asyncio.set_event_loop(loop)
32
- results = loop.run_until_complete(queue_image_calls(prompts))
33
- return results
34
 
35
- def gradio_interface(sentence_mapping, character_dict, selected_style):
36
- prompts = generate_prompt(sentence_mapping, character_dict, selected_style)
37
- image_bytes_list = async_image_generation(prompts)
38
- outputs = [gr.Image.update(value=img_bytes) if img_bytes else gr.Image.update(value=None) for img_bytes in image_bytes_list]
39
- return outputs
40
 
41
- # Gradio Interface
42
- def update_images(sentence_mapping, character_dict, selected_style):
43
- prompts = generate_prompt(sentence_mapping, character_dict, selected_style)
44
- image_bytes_list = async_image_generation(prompts)
45
- return image_bytes_list
 
46
 
47
- with gr.Blocks() as demo:
48
- sentence_mapping_input = gr.Textbox(label="Sentence Mapping")
49
- character_dict_input = gr.Textbox(label="Character Dictionary")
50
- selected_style_input = gr.Textbox(label="Selected Style")
51
-
52
- output_images = gr.Gallery(label="Generated Images")
53
 
54
- def generate_and_update_images(sentence_mapping, character_dict, selected_style):
55
- image_bytes_list = update_images(sentence_mapping, character_dict, selected_style)
56
- return [gr.Image.update(value=img_bytes) if img_bytes else gr.Image.update(value=None) for img_bytes in image_bytes_list]
 
57
 
58
- sentence_mapping_input.change(fn=generate_and_update_images,
59
- inputs=[sentence_mapping_input, character_dict_input, selected_style_input],
60
- outputs=output_images)
61
- character_dict_input.change(fn=generate_and_update_images,
62
- inputs=[sentence_mapping_input, character_dict_input, selected_style_input],
63
- outputs=output_images)
64
- selected_style_input.change(fn=generate_and_update_images,
65
- inputs=[sentence_mapping_input, character_dict_input, selected_style_input],
66
- outputs=output_images)
 
 
67
 
68
  if __name__ == "__main__":
69
- demo.launch()
 
 
1
  import gradio as gr
2
  from diffusers import AutoPipelineForText2Image
3
+ from io import BytesIO
4
+ import asyncio
5
+ from generate_propmts import generate_prompt
6
 
7
  # Load the model once outside of the function
8
  model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
 
10
  async def generate_image(prompt, prompt_name):
11
  try:
12
  print(f"Generating image for {prompt_name}")
13
+ output = await asyncio.to_thread(model, prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
14
+
15
+ # Check if the model returned images
16
+ if isinstance(output.images, list) and len(output.images) > 0:
17
+ image = output.images[0]
18
+ buffered = BytesIO()
19
+ try:
20
+ image.save(buffered, format="JPEG")
21
+ image_bytes = buffered.getvalue()
22
+ print(f"Image bytes length for {prompt_name}: {len(image_bytes)}")
23
+ return image_bytes
24
+ except Exception as e:
25
+ print(f"Error saving image for {prompt_name}: {e}")
26
+ return None
27
+ else:
28
+ raise Exception(f"No images returned by the model for {prompt_name}.")
29
  except Exception as e:
30
  print(f"Error generating image for {prompt_name}: {e}")
31
  return None
32
 
33
async def process_prompt(sentence_mapping, character_dict, selected_style):
    """Generate one image per paragraph and return ``{paragraph_number: bytes}``.

    Each paragraph's sentences are joined into a single prompt, all image
    generations run concurrently via ``asyncio.gather``, and results are
    mapped back to their paragraph numbers (``None`` for failed generations).
    """
    images = {}
    print(f'sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}')
    prompts = []

    # One prompt per paragraph, built from its concatenated sentences.
    for paragraph_number, sentences in sentence_mapping.items():
        combined_sentence = " ".join(sentences)
        prompt = generate_prompt(combined_sentence, character_dict, selected_style)
        prompts.append((paragraph_number, prompt))
        print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")

    print(f'prompts: {prompts}')
    # Kick off every generation concurrently.
    tasks = [generate_image(text, f"Prompt {number}") for number, text in prompts]
    print(f'tasks: {tasks}')
    results = await asyncio.gather(*tasks)

    # Pair results back with their paragraph numbers. gather() preserves
    # order, so the guard below is purely defensive.
    for idx, (paragraph_number, _) in enumerate(prompts):
        if idx < len(results):
            images[paragraph_number] = results[idx]
        else:
            print(f"Error: No result for paragraph {paragraph_number}")

    return images
 
 
 
 
 
59
 
60
+ # Helper function to generate a prompt based on the input
61
def generate_prompt(combined_sentence, character_dict, selected_style):
    """Compose the final text-to-image prompt.

    Character entries may be a string or a list of strings; list entries are
    flattened with spaces, then all characters are joined into one clause
    ahead of the paragraph text.
    """
    flattened = []
    for entry in character_dict.values():
        flattened.append(" ".join(entry) if isinstance(entry, list) else entry)
    characters = " ".join(flattened)
    return f"Make an illustration in {selected_style} style from: {characters}. {combined_sentence}"
64
 
65
+ # Gradio interface with high concurrency limit
66
# Web UI: JSON inputs for the sentence mapping and character dict, a style
# dropdown, and a JSON output of paragraph-number -> image bytes. Built in
# two steps (construct, then queue) for readability.
_interface = gr.Interface(
    fn=process_prompt,
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style"),
    ],
    outputs="json",
    concurrency_limit=20,  # Set a high concurrency limit
)
gradio_interface = _interface.queue(default_concurrency_limit=20)
76
 
77
if __name__ == "__main__":
    # Start the Gradio server (blocking) only when run as a script,
    # not when this module is imported.
    gradio_interface.launch()