RanM committed on
Commit 66e43e3 · verified · 1 Parent(s): 849ef9a

Update app.py

Files changed (1)
  1. app.py +21 -38
app.py CHANGED
@@ -1,77 +1,60 @@
 import gradio as gr
 from diffusers import AutoPipelineForText2Image
-from generate_propmts import generate_prompt
+from transformers import AutoTokenizer
 from PIL import Image
 import asyncio
-import threading
-import traceback
-
-# Load the model once outside of the function
-model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")

 class SchedulerWrapper:
     def __init__(self, scheduler):
         self.scheduler = scheduler
-        self._step = threading.local()
-        self._step.step = 0
+
+    def __getattr__(self, name):
+        return getattr(self.scheduler, name)
+
+    @property
+    def timesteps(self):
+        return self.scheduler.timesteps

-    def step(self, *args, **kwargs):
-        try:
-            self._step.step += 1
-            return self.scheduler.step(*args, **kwargs)
-        except IndexError:
-            self._step.step = 0
-            return self.scheduler.step(*args, **kwargs)
-
-    def set_timesteps(self, *args, **kwargs):
-        return self.scheduler.set_timesteps(*args, **kwargs)
+    def set_timesteps(self, timesteps):
+        self.scheduler.set_timesteps(timesteps)
+
+# Load the model and tokenizer
+tokenizer = AutoTokenizer.from_pretrained("stabilityai/sdxl-turbo")
+model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")

 # Wrap the scheduler
-model.scheduler = SchedulerWrapper(model.scheduler)
+scheduler = model.scheduler
+wrapped_scheduler = SchedulerWrapper(scheduler)
+model.scheduler = wrapped_scheduler

 async def generate_image(prompt):
     try:
-        # Set a higher value for num_inference_steps
-        num_inference_steps = 5 # Adjust this value as needed
-
-        # Use the model to generate an image
+        num_inference_steps = 5
         output = await asyncio.to_thread(
             model,
             prompt=prompt,
             num_inference_steps=num_inference_steps,
-            guidance_scale=0.0, # Typical value for guidance scale in image generation
-            output_type="pil" # Directly get PIL Image objects
+            guidance_scale=0.0,
+            output_type="pil"
         )
-
-        # Check for output validity and return
         if output.images:
             return output.images[0]
         else:
             raise Exception("No images returned by the model.")
     except Exception as e:
         print(f"Error generating image: {e}")
-        traceback.print_exc()
-        return None # Return None on error to handle it gracefully in the UI
+        return None

 async def inference(sentence_mapping, character_dict, selected_style):
     images = []
-    print(f'sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}')
     prompts = []
-
-    # Generate prompts for each paragraph
     for paragraph_number, sentences in sentence_mapping.items():
         combined_sentence = " ".join(sentences)
         prompt = generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
         prompts.append(prompt)
-        print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")
-
-    # Use asyncio.gather to run generate_image in parallel
     tasks = [generate_image(prompt) for prompt in prompts]
     images = await asyncio.gather(*tasks)
-
-    # Filter out None values
     images = [image for image in images if image is not None]
-
     return images

 gradio_interface = gr.Interface(
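
For reference, the rewritten SchedulerWrapper relies on Python's __getattr__ hook, which is only consulted when normal attribute lookup fails, so any attribute or method the wrapper does not define itself (for example step or config on a diffusers scheduler) falls through to the wrapped scheduler. Below is a minimal, self-contained sketch of that delegation pattern; DummyScheduler and its fields are invented stand-ins for a real diffusers scheduler and are not part of the commit.

class DummyScheduler:
    """Stand-in for a diffusers scheduler; the fields below are illustrative only."""
    def __init__(self):
        self.timesteps = [999, 500, 1]
        self.order = 1

    def step(self, noise_pred, t, sample):
        # Pretend to advance the denoising process by one step.
        return sample


class SchedulerWrapper:
    def __init__(self, scheduler):
        self.scheduler = scheduler

    def __getattr__(self, name):
        # Only called when `name` is not found on the wrapper itself,
        # so everything undefined here is forwarded to the real scheduler.
        return getattr(self.scheduler, name)


wrapped = SchedulerWrapper(DummyScheduler())
print(wrapped.timesteps)             # forwarded attribute -> [999, 500, 1]
print(wrapped.order)                 # forwarded attribute -> 1
print(wrapped.step(None, 999, "x"))  # forwarded method -> "x"

Because self.scheduler is assigned normally in __init__, looking it up inside __getattr__ does not recurse; the wrapper only intercepts names it does not already have.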
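
On the generation side, both versions offload the synchronous pipeline call to a worker thread with asyncio.to_thread and run the per-paragraph generations concurrently with asyncio.gather. A minimal sketch of that pattern follows, using a sleep-based stand-in (blocking_generate) instead of the real SDXL-Turbo pipeline; note that the new file drops the from generate_propmts import generate_prompt line even though inference still calls generate_prompt, so this sketch takes ready-made prompts as input.

import asyncio
import time


def blocking_generate(prompt):
    # Stand-in for the synchronous diffusers pipeline call; it sleeps instead of rendering.
    time.sleep(1)
    return f"image for: {prompt}"


async def generate_image(prompt):
    try:
        # Run the blocking call in a worker thread so the event loop stays responsive.
        return await asyncio.to_thread(blocking_generate, prompt)
    except Exception as e:
        print(f"Error generating image: {e}")
        return None  # mirror the app's behaviour of returning None on failure


async def inference(prompts):
    tasks = [generate_image(p) for p in prompts]
    images = await asyncio.gather(*tasks)  # all prompts are awaited concurrently
    return [img for img in images if img is not None]


print(asyncio.run(inference(["a castle", "a forest", "a dragon"])))

asyncio.to_thread requires Python 3.9 or newer, and with a single shared pipeline the worker threads still contend for the same model, so the gather mostly keeps the Gradio handler responsive rather than delivering fully parallel rendering.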