RanM committed on
Commit
bc0d978
·
verified ·
1 Parent(s): f607069

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -63
app.py CHANGED
@@ -1,76 +1,55 @@
1
- import gradio as gr
2
- from diffusers import AutoPipelineForText2Image
3
- from io import BytesIO
4
  import asyncio
5
- from generate_propmts import generate_prompt
 
 
 
6
 
7
# Load the text-to-image model once at import time so every request reuses it
# instead of paying the (slow) checkpoint load per call.
model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
9
 
10
async def generate_image(prompt, prompt_name):
    """Render *prompt* in a worker thread and return encoded JPEG bytes, or None on failure."""
    try:
        print(f"Generating image for {prompt_name}")
        # The diffusion call is blocking; run it off the event loop so
        # concurrently gathered tasks can overlap.
        result = await asyncio.to_thread(model, prompt=prompt, num_inference_steps=1, guidance_scale=0.0)

        # Guard: the model must have produced at least one image.
        images = result.images
        if not (isinstance(images, list) and images):
            raise Exception(f"No images returned by the model for {prompt_name}.")

        try:
            buf = BytesIO()
            images[0].save(buf, format="JPEG")
            encoded = buf.getvalue()
            print(f"Image bytes length for {prompt_name}: {len(encoded)}")
            return encoded
        except Exception as e:
            print(f"Error saving image for {prompt_name}: {e}")
            return None
    except Exception as e:
        print(f"Error generating image for {prompt_name}: {e}")
        return None
32
 
33
async def process_prompt(sentence_mapping, character_dict, selected_style):
    """Build one prompt per paragraph, render all images concurrently, and map them back.

    Args:
        sentence_mapping: dict of paragraph_number -> list of sentence strings.
        character_dict: character descriptions forwarded to generate_prompt.
        selected_style: illustration style string forwarded to generate_prompt.

    Returns:
        dict of paragraph_number -> JPEG bytes (or None where generation failed).
    """
    print(f'sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}')

    # Generate prompts for each paragraph
    prompts = []
    for paragraph_number, sentences in sentence_mapping.items():
        combined_sentence = " ".join(sentences)
        prompt = generate_prompt(combined_sentence, character_dict, selected_style)
        prompts.append((paragraph_number, prompt))
        print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")

    # Run all renders concurrently. asyncio.gather preserves order and always
    # returns exactly one result per task, so the original "missing result"
    # branch was dead code and has been removed.
    tasks = [generate_image(prompt, f"Prompt {paragraph_number}") for paragraph_number, prompt in prompts]
    results = await asyncio.gather(*tasks)

    return {paragraph_number: result for (paragraph_number, _), result in zip(prompts, results)}
57
-
58
# Helper: compose the final text prompt from the characters, style, and paragraph text.
def generate_prompt(combined_sentence, character_dict, selected_style):
    """Return the illustration prompt string for one paragraph."""
    parts = []
    for character in character_dict.values():
        if isinstance(character, list):
            parts.append(" ".join(character))
        else:
            parts.append(character)
    characters = " ".join(parts)
    return f"Make an illustration in {selected_style} style from: {characters}. {combined_sentence}"
62
-
63
# Gradio interface with high concurrency limit.
# Inputs mirror process_prompt's three parameters; output is the JSON-serialized
# {paragraph_number: image_bytes} mapping it returns.
gradio_interface = gr.Interface(
    fn=process_prompt,
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style")
    ],
    outputs="json",
    concurrency_limit=20  # Set a high concurrency limit
).queue(default_concurrency_limit=20)
 
 
74
 
75
if __name__ == "__main__":
    # Start the Gradio server when executed as a script (not on import).
    gradio_interface.launch()
 
 
 
 
import asyncio
import concurrent.futures
import json
from io import BytesIO

import gradio as gr
from diffusers import StableDiffusionPipeline

from generate_prompts import generate_prompt
 
# Load the diffusion pipeline once at import time so all requests share it.
# NOTE(review): .to("cuda") assumes a GPU is present — this raises on CPU-only
# hosts; confirm the deployment target.
# NOTE(review): sdxl-turbo is an SDXL checkpoint — verify StableDiffusionPipeline
# (rather than an SDXL/auto pipeline class) actually loads it correctly.
pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo").to("cuda")
 
10
async def generate_image(prompt, prompt_name):
    """Render *prompt* with the module-level pipeline and return encoded JPEG bytes.

    Args:
        prompt: text prompt for the pipeline.
        prompt_name: label used only for log messages.

    Returns:
        JPEG-encoded bytes of the first generated image, or None on any failure.
    """
    try:
        print(f"Generating image for {prompt_name}")
        # BUG FIX: the pipeline call is synchronous — `await pipeline(prompt).images[0]`
        # tried to await a PIL image (TypeError) and would have blocked the event
        # loop anyway. Run the blocking call in a worker thread instead, so the
        # tasks fanned out by asyncio.gather actually overlap.
        result = await asyncio.to_thread(pipeline, prompt)
        image = result.images[0]
        # BUG FIX: Image.tobytes() is a raw pixel dump with no size/mode header
        # and is not decodable downstream; encode to JPEG instead.
        buf = BytesIO()
        image.save(buf, format="JPEG")
        img_bytes = buf.getvalue()
        print(f"Image bytes length for {prompt_name}: {len(img_bytes)}")
        return img_bytes
    except Exception as e:
        print(f"Error generating image for {prompt_name}: {e}")
        return None
20
 
21
async def queue_image_calls(prompts):
    """Fan out generate_image over *prompts* concurrently; results keep prompt order."""
    tasks = [
        generate_image(prompt, f"Prompt {index}")
        for index, prompt in enumerate(prompts)
    ]
    return await asyncio.gather(*tasks)
 
 
 
 
 
 
 
 
 
 
 
25
 
26
def async_image_generation(prompts):
    """Drive queue_image_calls to completion from synchronous code.

    Works whether or not an event loop is already running in this thread.

    Args:
        prompts: sequence of prompt strings.

    Returns:
        list of per-prompt results from queue_image_calls (bytes or None each).
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop in this thread (the common Gradio-handler case):
        # asyncio.run creates, drives, and closes a fresh loop.
        return asyncio.run(queue_image_calls(prompts))
    # BUG FIX: the original fell through to loop.run_until_complete() on the
    # already-running loop, which raises "This event loop is already running".
    # Run the coroutine on its own loop in a worker thread instead.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        return executor.submit(asyncio.run, queue_image_calls(prompts)).result()
34
+
35
def gradio_interface(sentence_mapping, character_dict, selected_style):
    # Click handler: turn the three UI inputs into prompts, render them, and
    # feed the image bytes back to the bound gr.Image components.
    # NOTE(review): the inputs arrive from gr.Textbox components as plain
    # strings — confirm generate_prompt expects strings rather than parsed
    # dicts (the unused `json` import suggests parsing may have been intended).
    prompts = generate_prompt(sentence_mapping, character_dict, selected_style)
    image_bytes_list = async_image_generation(prompts)
    # NOTE(review): gr.Image.update was removed in Gradio 4.x (return
    # gr.Image(...) or raw values instead) — verify against the pinned version.
    # NOTE(review): the click binding expects exactly 4 outputs; confirm
    # generate_prompt always yields 4 prompts, else Gradio will error.
    outputs = [gr.Image.update(value=img_bytes) if img_bytes else gr.Image.update(value=None) for img_bytes in image_bytes_list]
    return outputs
40
+
41
# Gradio Interface: three text inputs plus a submit button on the left,
# and four fixed image slots filled by gradio_interface on click.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            sentence_mapping_input = gr.Textbox(label="Sentence Mapping")
            character_dict_input = gr.Textbox(label="Character Dictionary")
            selected_style_input = gr.Textbox(label="Selected Style")
            submit_btn = gr.Button(value='Submit')
    # Four fixed result slots — gradio_interface must return exactly four values.
    prompt_responses = [gr.Image(label=f"Prompt {i} Response") for i in range(4)]
    submit_btn.click(fn=gradio_interface,
                     inputs=[sentence_mapping_input, character_dict_input, selected_style_input],
                     outputs=prompt_responses)
 
if __name__ == "__main__":
    # Start the Gradio server when executed as a script (not on import).
    demo.launch()