RanM commited on
Commit
9872917
·
verified ·
1 Parent(s): f45116f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -33
app.py CHANGED
@@ -1,43 +1,19 @@
1
  import gradio as gr
2
- import asyncio
3
  from diffusers import AutoPipelineForText2Image
4
  from io import BytesIO
5
- from generate_propmts import generate_prompt
6
- import threading
7
-
8
- # Define the Scheduler class
9
- class Scheduler:
10
- def __init__(self):
11
- self._step = threading.local()
12
- self._step.step = None
13
-
14
- def _init_step_index(self):
15
- self._step.step = 0
16
-
17
- @property
18
- def step(self):
19
- return self._step.step
20
-
21
- def step_process(self):
22
- if self._step.step is None:
23
- self._init_step_index()
24
- self._step.step += 1
25
 
26
  # Load the model once outside of the function
27
  model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
28
 
29
- # Create a Scheduler instance
30
- scheduler = Scheduler()
31
-
32
  async def generate_image(prompt):
33
  try:
34
- # Update the scheduler step
35
- scheduler.step_process()
36
- print(f"Current step: {scheduler.step}")
37
-
38
  # Generate an image based on the prompt
39
  output = await asyncio.to_thread(model, prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
40
-
 
 
41
  if isinstance(output.images, list) and len(output.images) > 0:
42
  image = output.images[0]
43
  buffered = BytesIO()
@@ -46,15 +22,13 @@ async def generate_image(prompt):
46
  return image_bytes
47
  else:
48
  raise Exception("No images returned by the model.")
49
- except IndexError as e:
50
- print(f"IndexError: {e}")
51
- return None
52
  except Exception as e:
53
  print(f"Error generating image: {e}")
54
  return None
55
 
56
  async def process_prompt(sentence_mapping, character_dict, selected_style):
57
  images = {}
 
58
  prompts = []
59
 
60
  # Generate prompts for each paragraph
@@ -62,6 +36,7 @@ async def process_prompt(sentence_mapping, character_dict, selected_style):
62
  combined_sentence = " ".join(sentences)
63
  prompt = generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
64
  prompts.append((paragraph_number, prompt))
 
65
 
66
  # Create tasks for all prompts and run them concurrently
67
  tasks = [generate_image(prompt) for _, prompt in prompts]
@@ -74,7 +49,7 @@ async def process_prompt(sentence_mapping, character_dict, selected_style):
74
 
75
  return images
76
 
77
- # Define Gradio interface with high concurrency limit
78
  gradio_interface = gr.Interface(
79
  fn=process_prompt,
80
  inputs=[
 
1
  import gradio as gr
2
+ import torch
3
  from diffusers import AutoPipelineForText2Image
4
  from io import BytesIO
5
+ import asyncio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  # Load the model once outside of the function
8
  model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
9
 
 
 
 
10
  async def generate_image(prompt):
11
  try:
 
 
 
 
12
  # Generate an image based on the prompt
13
  output = await asyncio.to_thread(model, prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
14
+ print(f"Model output: {output}")
15
+
16
+ # Check if the model returned images
17
  if isinstance(output.images, list) and len(output.images) > 0:
18
  image = output.images[0]
19
  buffered = BytesIO()
 
22
  return image_bytes
23
  else:
24
  raise Exception("No images returned by the model.")
 
 
 
25
  except Exception as e:
26
  print(f"Error generating image: {e}")
27
  return None
28
 
29
  async def process_prompt(sentence_mapping, character_dict, selected_style):
30
  images = {}
31
+ print(f'sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}')
32
  prompts = []
33
 
34
  # Generate prompts for each paragraph
 
36
  combined_sentence = " ".join(sentences)
37
  prompt = generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
38
  prompts.append((paragraph_number, prompt))
39
+ print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")
40
 
41
  # Create tasks for all prompts and run them concurrently
42
  tasks = [generate_image(prompt) for _, prompt in prompts]
 
49
 
50
  return images
51
 
52
+ # Gradio interface with high concurrency limit
53
  gradio_interface = gr.Interface(
54
  fn=process_prompt,
55
  inputs=[