RanM committed on
Commit
789e6b5
·
verified ·
1 Parent(s): cbd1935

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -66
app.py CHANGED
@@ -1,71 +1,36 @@
1
  import gradio as gr
 
2
  from diffusers import AutoPipelineForText2Image
3
- from diffusers.schedulers import DPMSolverMultistepScheduler
4
- from generate_propmts import generate_prompt # Assuming you have this module
5
- from PIL import Image
6
  import asyncio
7
- import threading
8
- import traceback
9
 
10
- # Define the SchedulerWrapper class
11
- class SchedulerWrapper:
12
- def __init__(self, scheduler):
13
- self.scheduler = scheduler
14
- self._step = threading.local()
15
- self._step.step = 0
16
-
17
- def __getattr__(self, name):
18
- return getattr(self.scheduler, name)
19
-
20
- def step(self, *args, **kwargs):
21
- try:
22
- self._step.step += 1
23
- return self.scheduler.step(*args, **kwargs)
24
- except IndexError:
25
- self._step.step = 0
26
- return self.scheduler.step(*args, **kwargs)
27
-
28
- @property
29
- def timesteps(self):
30
- return self.scheduler.timesteps
31
-
32
- def set_timesteps(self, *args, **kwargs):
33
- return self.scheduler.set_timesteps(*args, **kwargs)
34
-
35
- # Load the model and wrap the scheduler
36
  model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
37
 
38
- scheduler = DPMSolverMultistepScheduler.from_config(model.scheduler.config)
39
- wrapped_scheduler = SchedulerWrapper(scheduler)
40
- model.scheduler = wrapped_scheduler
41
-
42
- # Define the image generation function
43
  async def generate_image(prompt):
44
  try:
45
- num_inference_steps = 5 # Adjust this value as needed
46
-
47
- # Use the model to generate an image
48
- output = await asyncio.to_thread(
49
- model,
50
- prompt=prompt,
51
- num_inference_steps=num_inference_steps,
52
- guidance_scale=0.0, # Typical value for guidance scale in image generation
53
- output_type="pil" # Directly get PIL Image objects
54
- )
55
-
56
- # Check for output validity and return
57
- if output.images:
58
- return output.images[0]
59
  else:
60
  raise Exception("No images returned by the model.")
 
61
  except Exception as e:
62
  print(f"Error generating image: {e}")
63
- traceback.print_exc()
64
- return None # Return None on error to handle it gracefully in the UI
65
 
66
- # Define the inference function
67
- async def inference(sentence_mapping, character_dict, selected_style):
68
- images = []
69
  print(f'sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}')
70
  prompts = []
71
 
@@ -73,29 +38,36 @@ async def inference(sentence_mapping, character_dict, selected_style):
73
  for paragraph_number, sentences in sentence_mapping.items():
74
  combined_sentence = " ".join(sentences)
75
  prompt = generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
76
- prompts.append(prompt)
77
  print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")
78
 
79
- # Use asyncio.gather to run generate_image in parallel
80
- tasks = [generate_image(prompt) for prompt in prompts]
81
- images = await asyncio.gather(*tasks)
82
-
83
- # Filter out None values
84
- images = [image for image in images if image is not None]
 
 
 
 
 
 
 
 
85
 
86
  return images
87
 
88
- # Define the Gradio interface
89
  gradio_interface = gr.Interface(
90
- fn=inference,
91
  inputs=[
92
  gr.JSON(label="Sentence Mapping"),
93
  gr.JSON(label="Character Dict"),
94
  gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style")
95
  ],
96
- outputs=gr.Gallery(label="Generated Images")
 
97
  )
98
 
99
- # Run the Gradio app
100
  if __name__ == "__main__":
101
  gradio_interface.launch()
 
import asyncio
import base64
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO

import gradio as gr
import torch
from diffusers import AutoPipelineForText2Image

from generate_propmts import generate_prompt
+ # Load the model once outside of the function
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
11
 
 
 
 
 
 
async def generate_image(prompt):
    """Generate one image for *prompt* and return it as a base64 JPEG string.

    The diffusers pipeline call is synchronous and compute-bound, so it is
    off-loaded with ``asyncio.to_thread`` to keep the event loop responsive.

    Args:
        prompt: text prompt passed straight to the text-to-image pipeline.

    Returns:
        A base64-encoded JPEG string on success (base64 because the Gradio
        interface serializes results as JSON, and raw ``bytes`` are not
        JSON-serializable), or ``None`` on any failure so callers can skip
        failed paragraphs instead of crashing.
    """
    try:
        # Run the blocking pipeline call in a worker thread.
        output = await asyncio.to_thread(
            model,
            prompt=prompt,
            num_inference_steps=1,  # sdxl-turbo is designed for 1-4 steps
            guidance_scale=0.0,     # turbo models are trained without CFG
        )
        print(f"Model output: {output}")

        # Check if the model returned images
        if isinstance(output.images, list) and len(output.images) > 0:
            image = output.images[0]
            buffered = BytesIO()
            image.save(buffered, format="JPEG")
            # Encode so the value survives JSON serialization downstream.
            return base64.b64encode(buffered.getvalue()).decode("utf-8")
        else:
            raise Exception("No images returned by the model.")

    except Exception as e:
        print(f"Error generating image: {e}")
        return None
async def process_prompt(sentence_mapping, character_dict, selected_style):
    """Build one prompt per paragraph and generate all images concurrently.

    Args:
        sentence_mapping: dict mapping paragraph number -> list of sentences.
        character_dict: character metadata forwarded to ``generate_prompt``.
        selected_style: art-style string forwarded to ``generate_prompt``.

    Returns:
        dict mapping paragraph number -> generated image payload; paragraphs
        whose generation failed are omitted.
    """
    print(f'sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}')
    prompts = []

    for paragraph_number, sentences in sentence_mapping.items():
        combined_sentence = " ".join(sentences)
        prompt = generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
        prompts.append((paragraph_number, prompt))
        print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")

    # generate_image is a coroutine function, so schedule the calls on the
    # event loop and await them together. (Handing an async function to
    # run_in_executor — as the previous version did — merely creates
    # coroutine objects that are never awaited, so no image is ever made.)
    results = await asyncio.gather(
        *(generate_image(prompt) for _, prompt in prompts),
        return_exceptions=True,
    )

    # Pair each result with its paragraph via the prompts list itself,
    # rather than re-zipping against sentence_mapping.keys().
    images = {}
    for (paragraph_number, _), result in zip(prompts, results):
        if isinstance(result, Exception):
            print(f"Error processing paragraph {paragraph_number}: {result}")
        elif result:
            images[paragraph_number] = result

    return images
# Expose process_prompt over HTTP: two JSON inputs plus a style dropdown in,
# a JSON mapping of paragraph number -> image payload out.
gradio_interface = gr.Interface(
    fn=process_prompt,
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style"),
    ],
    outputs="json",
    concurrency_limit=10,  # allow up to 10 concurrent executions
)

if __name__ == "__main__":
    gradio_interface.launch()