RanM commited on
Commit
f45116f
·
verified ·
1 Parent(s): 6450ed3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -9
app.py CHANGED
@@ -1,20 +1,43 @@
1
  import gradio as gr
2
- import torch
3
  from diffusers import AutoPipelineForText2Image
4
  from io import BytesIO
5
  from generate_propmts import generate_prompt
6
- import asyncio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  # Load the model once outside of the function
9
  model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
10
 
 
 
 
11
  async def generate_image(prompt):
12
  try:
 
 
 
 
13
  # Generate an image based on the prompt
14
  output = await asyncio.to_thread(model, prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
15
- print(f"Model output: {output}")
16
-
17
- # Check if the model returned images
18
  if isinstance(output.images, list) and len(output.images) > 0:
19
  image = output.images[0]
20
  buffered = BytesIO()
@@ -23,13 +46,15 @@ async def generate_image(prompt):
23
  return image_bytes
24
  else:
25
  raise Exception("No images returned by the model.")
 
 
 
26
  except Exception as e:
27
  print(f"Error generating image: {e}")
28
  return None
29
 
30
  async def process_prompt(sentence_mapping, character_dict, selected_style):
31
  images = {}
32
- print(f'sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}')
33
  prompts = []
34
 
35
  # Generate prompts for each paragraph
@@ -37,7 +62,6 @@ async def process_prompt(sentence_mapping, character_dict, selected_style):
37
  combined_sentence = " ".join(sentences)
38
  prompt = generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
39
  prompts.append((paragraph_number, prompt))
40
- print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")
41
 
42
  # Create tasks for all prompts and run them concurrently
43
  tasks = [generate_image(prompt) for _, prompt in prompts]
@@ -45,11 +69,12 @@ async def process_prompt(sentence_mapping, character_dict, selected_style):
45
 
46
  # Map results back to paragraphs
47
  for i, (paragraph_number, _) in enumerate(prompts):
48
- images[paragraph_number] = results[i]
 
49
 
50
  return images
51
 
52
- # Gradio interface with high concurrency limit
53
  gradio_interface = gr.Interface(
54
  fn=process_prompt,
55
  inputs=[
 
1
  import gradio as gr
2
+ import asyncio
3
  from diffusers import AutoPipelineForText2Image
4
  from io import BytesIO
5
  from generate_propmts import generate_prompt
6
+ import threading
7
+
8
+ # Define the Scheduler class
9
+ class Scheduler:
10
+ def __init__(self):
11
+ self._step = threading.local()
12
+ self._step.step = None
13
+
14
+ def _init_step_index(self):
15
+ self._step.step = 0
16
+
17
+ @property
18
+ def step(self):
19
+ return self._step.step
20
+
21
+ def step_process(self):
22
+ if self._step.step is None:
23
+ self._init_step_index()
24
+ self._step.step += 1
25
 
26
  # Load the model once outside of the function
27
  model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
28
 
29
+ # Create a Scheduler instance
30
+ scheduler = Scheduler()
31
+
32
  async def generate_image(prompt):
33
  try:
34
+ # Update the scheduler step
35
+ scheduler.step_process()
36
+ print(f"Current step: {scheduler.step}")
37
+
38
  # Generate an image based on the prompt
39
  output = await asyncio.to_thread(model, prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
40
+
 
 
41
  if isinstance(output.images, list) and len(output.images) > 0:
42
  image = output.images[0]
43
  buffered = BytesIO()
 
46
  return image_bytes
47
  else:
48
  raise Exception("No images returned by the model.")
49
+ except IndexError as e:
50
+ print(f"IndexError: {e}")
51
+ return None
52
  except Exception as e:
53
  print(f"Error generating image: {e}")
54
  return None
55
 
56
  async def process_prompt(sentence_mapping, character_dict, selected_style):
57
  images = {}
 
58
  prompts = []
59
 
60
  # Generate prompts for each paragraph
 
62
  combined_sentence = " ".join(sentences)
63
  prompt = generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
64
  prompts.append((paragraph_number, prompt))
 
65
 
66
  # Create tasks for all prompts and run them concurrently
67
  tasks = [generate_image(prompt) for _, prompt in prompts]
 
69
 
70
  # Map results back to paragraphs
71
  for i, (paragraph_number, _) in enumerate(prompts):
72
+ if i < len(results):
73
+ images[paragraph_number] = results[i]
74
 
75
  return images
76
 
77
+ # Define Gradio interface with high concurrency limit
78
  gradio_interface = gr.Interface(
79
  fn=process_prompt,
80
  inputs=[