RanM committed on
Commit
4294a68
·
verified ·
1 Parent(s): bb032a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -29
app.py CHANGED
@@ -1,10 +1,12 @@
1
  import os
2
- import asyncio
3
  from io import BytesIO
4
  from PIL import Image
5
- from diffusers import AutoPipelineForText2Image
6
  import gradio as gr
 
 
7
 
 
8
  print("Loading the Stable Diffusion model...")
9
  try:
10
  model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
@@ -13,56 +15,64 @@ except Exception as e:
13
  print(f"Error loading model: {e}")
14
  model = None
15
 
16
def generate_image(prompt, prompt_name):
    """Run the text-to-image pipeline for one prompt.

    Returns the rendered image as raw PNG bytes, or None when the model is
    unavailable or generation fails for any reason; never raises.
    """
    try:
        if model is None:
            raise ValueError("Model not loaded properly.")

        print(f"Generating image for {prompt_name} with prompt: {prompt}")
        result = model(prompt=prompt, num_inference_steps=50, guidance_scale=7.5)
        print(f"Model output for {prompt_name}: {result}")

        if result is None:
            raise ValueError(f"Model returned None for {prompt_name}")

        pictures = getattr(result, 'images', None)
        if not pictures:
            print(f"No images found in model output for {prompt_name}")
            raise ValueError(f"No images found in model output for {prompt_name}")

        print(f"Image generated for {prompt_name}")
        # Serialize the first PIL image to an in-memory PNG buffer.
        sink = BytesIO()
        pictures[0].save(sink, format="PNG")
        return sink.getvalue()
    except Exception as e:
        print(f"An error occurred while generating image for {prompt_name}: {e}")
        return None
41
 
42
def process_prompt(sentence_mapping, character_dict, selected_style):
    """Render one illustration per paragraph.

    ``sentence_mapping`` maps a paragraph number to a list of sentences; each
    list is joined into a single styled prompt and passed to generate_image().
    Returns a dict of paragraph number -> raw image bytes (or None on failure).
    """
    print("Processing prompt...")
    print(f"Sentence Mapping: {sentence_mapping}")
    print(f"Character Dict: {character_dict}")
    print(f"Selected Style: {selected_style}")

    results = {}
    for number, sentence_list in sentence_mapping.items():
        joined = " ".join(sentence_list)
        request = f"Make an illustration in {selected_style} style from: {joined}"
        results[number] = generate_image(request, f"Prompt {number}")
    return results
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  gradio_interface = gr.Interface(
58
- fn=process_prompt,
59
  inputs=[
60
  gr.JSON(label="Sentence Mapping"),
61
  gr.JSON(label="Character Dict"),
62
  gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style")
63
  ],
64
  outputs="json"
65
- ).queue(concurrency_limit=10)
66
 
67
  if __name__ == "__main__":
68
  print("Launching Gradio interface...")
 
1
  import os
 
2
  from io import BytesIO
3
  from PIL import Image
4
+ from transformers import AutoPipelineForText2Image
5
  import gradio as gr
6
+ from generate_prompts import generate_prompt
7
+ import base64
8
 
9
+ # Load the model once at the start
10
  print("Loading the Stable Diffusion model...")
11
  try:
12
  model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
 
15
  print(f"Error loading model: {e}")
16
  model = None
17
 
18
def generate_image(prompt):
    """Render *prompt* with the global SDXL-Turbo pipeline.

    Returns a ``(base64_jpeg_string, None)`` pair on success or
    ``(None, error_message)`` on any failure; never raises.
    """
    try:
        if model is None:
            raise ValueError("Model not loaded properly.")

        print(f"Generating image with prompt: {prompt}")
        # Single step, no guidance — the SDXL-Turbo inference configuration.
        result = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
        print(f"Model output: {result}")

        if result is None:
            raise ValueError("Model returned None")

        pictures = getattr(result, 'images', None)
        if not pictures:
            print(f"No images found in model output")
            raise ValueError("No images found in model output")

        print(f"Image generated")
        # Serialize the PIL image to JPEG bytes, then to a JSON-safe string.
        sink = BytesIO()
        pictures[0].save(sink, format="JPEG")
        encoded = base64.b64encode(sink.getvalue()).decode("utf-8")
        return encoded, None
    except Exception as e:
        print(f"An error occurred while generating image: {e}")
        return None, str(e)
44
 
45
def inference(sentence_mapping, character_dict, selected_style):
    """Generate one image per paragraph of the story.

    For each entry in ``sentence_mapping`` (paragraph number -> list of
    sentences) a prompt is built via generate_prompt() and rendered with
    generate_image(). Returns a dict mapping paragraph number to either a
    base64 image string or an ``"Error: ..."`` message; top-level failures
    are reported as ``{"error": ...}``.
    """
    try:
        print(f"Received sentence_mapping: {sentence_mapping}")
        print(f"Received character_dict: {character_dict}")
        print(f"Received selected_style: {selected_style}")

        # Guard clause: refuse to run with any missing input.
        if sentence_mapping is None or character_dict is None or selected_style is None:
            return {"error": "One or more inputs are None"}

        results = {}
        for number, sentence_list in sentence_mapping.items():
            text = " ".join(sentence_list)
            request = generate_prompt(text, sentence_mapping, character_dict, selected_style)
            encoded, failure = generate_image(request)
            results[number] = f"Error: {failure}" if failure else encoded
        return results
    except Exception as e:
        return {"error": str(e)}
66
 
67
# Gradio UI: takes the story structure (sentence mapping + character dict) as
# JSON plus a style choice, and returns inference()'s JSON result — a mapping
# of paragraph number to base64 image string or error message.
gradio_interface = gr.Interface(
    fn=inference,
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style")
    ],
    outputs="json"
)
76
 
77
  if __name__ == "__main__":
78
  print("Launching Gradio interface...")