RanM committed on
Commit
6b1b953
·
verified ·
1 Parent(s): 53490b9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -11
app.py CHANGED
@@ -2,28 +2,29 @@ import os
2
  import asyncio
3
  from concurrent.futures import ProcessPoolExecutor
4
  from io import BytesIO
 
5
  from diffusers import StableDiffusionPipeline
6
  import gradio as gr
7
  from generate_prompts import generate_prompt
8
 
9
  # Load the model once at the start
10
  print("Loading the Stable Diffusion model...")
11
- model = StableDiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo")
12
- print("Model loaded successfully.")
13
-
14
- def truncate_prompt(prompt, max_length=77):
15
- tokens = prompt.split()
16
- if len(tokens) > max_length:
17
- prompt = " ".join(tokens[:max_length])
18
- return prompt
19
 
20
  def generate_image(prompt, prompt_name):
21
  try:
 
 
 
22
  print(f"Generating image for {prompt_name} with prompt: {prompt}")
23
- truncated_prompt = truncate_prompt(prompt)
24
- output = model(prompt=truncated_prompt, num_inference_steps=1, guidance_scale=0.0)
25
  print(f"Model output for {prompt_name}: {output}")
26
-
27
  if output and hasattr(output, 'images') and output.images:
28
  print(f"Image generated for {prompt_name}")
29
  image = output.images[0]
@@ -40,6 +41,7 @@ def generate_image(prompt, prompt_name):
40
 
41
  async def queue_api_calls(sentence_mapping, character_dict, selected_style):
42
  print("Starting to queue API calls...")
 
43
  prompts = []
44
  for paragraph_number, sentences in sentence_mapping.items():
45
  combined_sentence = " ".join(sentences)
 
2
  import asyncio
3
  from concurrent.futures import ProcessPoolExecutor
4
  from io import BytesIO
5
+ from PIL import Image
6
  from diffusers import StableDiffusionPipeline
7
  import gradio as gr
8
  from generate_prompts import generate_prompt
9
 
10
  # Load the model once at the start
11
  print("Loading the Stable Diffusion model...")
12
+ try:
13
+ model = StableDiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo")
14
+ print("Model loaded successfully.")
15
+ except Exception as e:
16
+ print(f"Error loading model: {e}")
17
+ model = None
 
 
18
 
19
  def generate_image(prompt, prompt_name):
20
  try:
21
+ if model is None:
22
+ raise ValueError("Model not loaded properly.")
23
+
24
  print(f"Generating image for {prompt_name} with prompt: {prompt}")
25
+ output = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
 
26
  print(f"Model output for {prompt_name}: {output}")
27
+
28
  if output and hasattr(output, 'images') and output.images:
29
  print(f"Image generated for {prompt_name}")
30
  image = output.images[0]
 
41
 
42
  async def queue_api_calls(sentence_mapping, character_dict, selected_style):
43
  print("Starting to queue API calls...")
44
+ print(f'sentence_mapping"{sentence_mapping}')
45
  prompts = []
46
  for paragraph_number, sentences in sentence_mapping.items():
47
  combined_sentence = " ".join(sentences)