RanM commited on
Commit
b0bcf89
·
verified ·
1 Parent(s): 66e43e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -15
app.py CHANGED
@@ -1,62 +1,88 @@
1
  import gradio as gr
2
  from diffusers import AutoPipelineForText2Image
3
- from transformers import AutoTokenizer
 
4
  from PIL import Image
5
  import asyncio
 
 
6
 
 
7
  class SchedulerWrapper:
8
  def __init__(self, scheduler):
9
  self.scheduler = scheduler
10
-
11
- def __getattr__(self, name):
12
- return getattr(self.scheduler, name)
13
-
 
 
 
 
 
 
 
14
  @property
15
  def timesteps(self):
16
  return self.scheduler.timesteps
17
 
18
- def set_timesteps(self, timesteps):
19
- self.scheduler.set_timesteps(timesteps)
20
 
21
- # Load the model and tokenizer
22
- tokenizer = AutoTokenizer.from_pretrained("stabilityai/sdxl-turbo")
23
  model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
24
 
25
- # Wrap the scheduler
26
- scheduler = model.scheduler
27
  wrapped_scheduler = SchedulerWrapper(scheduler)
28
  model.scheduler = wrapped_scheduler
29
 
 
30
  async def generate_image(prompt):
31
  try:
32
- num_inference_steps = 5
 
 
33
  output = await asyncio.to_thread(
34
  model,
35
  prompt=prompt,
36
  num_inference_steps=num_inference_steps,
37
- guidance_scale=0.0,
38
- output_type="pil"
39
  )
 
 
40
  if output.images:
41
  return output.images[0]
42
  else:
43
  raise Exception("No images returned by the model.")
44
  except Exception as e:
45
  print(f"Error generating image: {e}")
46
- return None
 
47
 
 
48
  async def inference(sentence_mapping, character_dict, selected_style):
49
  images = []
 
50
  prompts = []
 
 
51
  for paragraph_number, sentences in sentence_mapping.items():
52
  combined_sentence = " ".join(sentences)
53
  prompt = generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
54
  prompts.append(prompt)
 
 
 
55
  tasks = [generate_image(prompt) for prompt in prompts]
56
  images = await asyncio.gather(*tasks)
 
 
57
  images = [image for image in images if image is not None]
 
58
  return images
59
 
 
60
  gradio_interface = gr.Interface(
61
  fn=inference,
62
  inputs=[
@@ -67,5 +93,6 @@ gradio_interface = gr.Interface(
67
  outputs=gr.Gallery(label="Generated Images")
68
  )
69
 
 
70
  if __name__ == "__main__":
71
  gradio_interface.launch()
 
1
  import gradio as gr
2
  from diffusers import AutoPipelineForText2Image
3
+ from diffusers.schedulers import DPMSolverMultistepScheduler
4
+ from generate_propmts import generate_prompt # Assuming you have this module
5
  from PIL import Image
6
  import asyncio
7
+ import threading
8
+ import traceback
9
 
10
+ # Define the SchedulerWrapper class
11
  class SchedulerWrapper:
12
  def __init__(self, scheduler):
13
  self.scheduler = scheduler
14
+ self._step = threading.local()
15
+ self._step.step = 0
16
+
17
+ def step(self, *args, **kwargs):
18
+ try:
19
+ self._step.step += 1
20
+ return self.scheduler.step(*args, **kwargs)
21
+ except IndexError:
22
+ self._step.step = 0
23
+ return self.scheduler.step(*args, **kwargs)
24
+
25
  @property
26
  def timesteps(self):
27
  return self.scheduler.timesteps
28
 
29
+ def set_timesteps(self, *args, **kwargs):
30
+ return self.scheduler.set_timesteps(*args, **kwargs)
31
 
32
+ # Load the model and wrap the scheduler
 
33
  model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
34
 
35
+ scheduler = DPMSolverMultistepScheduler.from_config(model.scheduler.config)
 
36
  wrapped_scheduler = SchedulerWrapper(scheduler)
37
  model.scheduler = wrapped_scheduler
38
 
39
+ # Define the image generation function
40
  async def generate_image(prompt):
41
  try:
42
+ num_inference_steps = 5 # Adjust this value as needed
43
+
44
+ # Use the model to generate an image
45
  output = await asyncio.to_thread(
46
  model,
47
  prompt=prompt,
48
  num_inference_steps=num_inference_steps,
49
+ guidance_scale=0.0, # Typical value for guidance scale in image generation
50
+ output_type="pil" # Directly get PIL Image objects
51
  )
52
+
53
+ # Check for output validity and return
54
  if output.images:
55
  return output.images[0]
56
  else:
57
  raise Exception("No images returned by the model.")
58
  except Exception as e:
59
  print(f"Error generating image: {e}")
60
+ traceback.print_exc()
61
+ return None # Return None on error to handle it gracefully in the UI
62
 
63
+ # Define the inference function
64
  async def inference(sentence_mapping, character_dict, selected_style):
65
  images = []
66
+ print(f'sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}')
67
  prompts = []
68
+
69
+ # Generate prompts for each paragraph
70
  for paragraph_number, sentences in sentence_mapping.items():
71
  combined_sentence = " ".join(sentences)
72
  prompt = generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
73
  prompts.append(prompt)
74
+ print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")
75
+
76
+ # Use asyncio.gather to run generate_image in parallel
77
  tasks = [generate_image(prompt) for prompt in prompts]
78
  images = await asyncio.gather(*tasks)
79
+
80
+ # Filter out None values
81
  images = [image for image in images if image is not None]
82
+
83
  return images
84
 
85
+ # Define the Gradio interface
86
  gradio_interface = gr.Interface(
87
  fn=inference,
88
  inputs=[
 
93
  outputs=gr.Gallery(label="Generated Images")
94
  )
95
 
96
+ # Run the Gradio app
97
  if __name__ == "__main__":
98
  gradio_interface.launch()