RanM commited on
Commit
f6b8b7e
·
verified ·
1 Parent(s): e3fc553

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -50
app.py CHANGED
@@ -1,66 +1,28 @@
1
- import threading
2
  from diffusers import AutoPipelineForText2Image
 
3
  from concurrent.futures import ThreadPoolExecutor, as_completed
4
  from PIL import Image
5
  import traceback
6
 
7
- class Scheduler:
8
- def __init__(self):
9
- self._step = threading.local()
10
- self._step.step = None
11
-
12
- def _init_step_index(self):
13
- self._step.step = 0
14
-
15
- @property
16
- def step_index(self):
17
- if self._step.step is None:
18
- self._init_step_index()
19
- return self._step.step
20
-
21
- @step_index.setter
22
- def step_index(self, value):
23
- self._step.step = value
24
-
25
- def step_process(self, noise_pred, t, latents, **extra_step_kwargs):
26
- try:
27
- sigma_to = self.sigmas[self.step_index + 1]
28
- self.step_index += 1
29
- # Process the step (pseudocode)
30
- # latents = process_latents(noise_pred, t, latents, sigma_to, **extra_step_kwargs)
31
- return latents
32
- except IndexError as e:
33
- print(f"Index error during step processing: {e}")
34
- traceback.print_exc()
35
- return latents
36
-
37
- # Mocking a model class for demonstration purposes
38
- class MockModel:
39
- def __init__(self):
40
- self.scheduler = Scheduler()
41
-
42
- def __call__(self, prompt, num_inference_steps, guidance_scale):
43
- # Simulate the inference steps
44
- latents = None
45
- for t in range(num_inference_steps):
46
- noise_pred = None # Replace with actual noise prediction
47
- latents = self.scheduler.step_process(noise_pred, t, latents)
48
- return {"images": [Image.new("RGB", (512, 512))]} # Return a dummy image for now
49
-
50
- # Load the actual model
51
- model = MockModel()
52
 
53
  def generate_image(prompt):
54
  try:
55
- output = model(prompt=prompt, num_inference_steps=3, guidance_scale=0.0)
56
  print(f"Model output: {output}")
57
 
58
  # Check if the model returned images
59
- if isinstance(output['images'], list) and len(output['images']) > 0:
60
- return output['images'][0]
61
  else:
62
  raise Exception("No images returned by the model.")
63
 
 
 
 
 
64
  except Exception as e:
65
  print(f"Error generating image: {e}")
66
  traceback.print_exc()
@@ -71,6 +33,7 @@ def inference(sentence_mapping, character_dict, selected_style):
71
  print(f'sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}')
72
  prompts = []
73
 
 
74
  for paragraph_number, sentences in sentence_mapping.items():
75
  combined_sentence = " ".join(sentences)
76
  prompt = generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
@@ -99,7 +62,7 @@ gradio_interface = gr.Interface(
99
  gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style")
100
  ],
101
  outputs=gr.Gallery(label="Generated Images")
102
- )
103
 
104
  if __name__ == "__main__":
105
  gradio_interface.launch()
 
1
+ import gradio as gr
2
  from diffusers import AutoPipelineForText2Image
3
+ from generate_propmts import generate_prompt
4
  from concurrent.futures import ThreadPoolExecutor, as_completed
5
  from PIL import Image
6
  import traceback
7
 
8
+ # Load the model once outside of the function
9
+ model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  def generate_image(prompt):
12
  try:
13
+ output = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
14
  print(f"Model output: {output}")
15
 
16
  # Check if the model returned images
17
+ if isinstance(output.images, list) and len(output.images) > 0:
18
+ return output.images[0]
19
  else:
20
  raise Exception("No images returned by the model.")
21
 
22
+ except IndexError as e:
23
+ print(f"Index error during image generation: {e}")
24
+ traceback.print_exc()
25
+ return None
26
  except Exception as e:
27
  print(f"Error generating image: {e}")
28
  traceback.print_exc()
 
33
  print(f'sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}')
34
  prompts = []
35
 
36
+ # Generate prompts for each paragraph
37
  for paragraph_number, sentences in sentence_mapping.items():
38
  combined_sentence = " ".join(sentences)
39
  prompt = generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
 
62
  gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style")
63
  ],
64
  outputs=gr.Gallery(label="Generated Images")
65
+ .queue(default_concurrency_limit=5)
66
 
67
  if __name__ == "__main__":
68
  gradio_interface.launch()