RanM commited on
Commit
a9b8939
·
verified ·
1 Parent(s): 1a20e42

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -13
app.py CHANGED
@@ -1,30 +1,26 @@
1
  import gradio as gr
2
  import torch
3
- from diffusers import DiffusionPipeline, AutoPipelineForText2Image
4
  import base64
5
  from io import BytesIO
6
 
 
 
7
 
8
-
9
- def text_to_image_model():
10
- # AutoPipelineForText2Image.from_pretrained("stabilityai/sd-turbo")
11
- # AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
12
- return AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
13
-
14
- # Generate image from prompt using AmusedPipeline
15
  def generate_image(prompt):
16
  try:
17
- pipe = text_to_image_model()
18
- image = pipe(prompt=prompt, num_inference_steps=1, guidance_scale=0.0).images[0]
19
  return image, None
20
  except Exception as e:
21
  return None, str(e)
22
 
23
  def inference(prompt):
24
- print(f"Received prompt: {prompt}") # Debugging statement
 
25
  image, error = generate_image(prompt)
26
  if error:
27
- print(f"Error generating image: {error}") # Debugging statement
 
28
  return "Error: " + error
29
 
30
  buffered = BytesIO()
@@ -39,4 +35,4 @@ gradio_interface = gr.Interface(
39
  )
40
 
41
  if __name__ == "__main__":
42
- gradio_interface.launch()
 
1
  import gradio as gr
2
  import torch
3
+ from diffusers import AutoPipelineForText2Image
4
  import base64
5
  from io import BytesIO
6
 
7
+ # Load the model once outside of the function
8
+ model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
9
 
 
 
 
 
 
 
 
10
  def generate_image(prompt):
11
  try:
12
+ image = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0).images[0]
 
13
  return image, None
14
  except Exception as e:
15
  return None, str(e)
16
 
17
  def inference(prompt):
18
+ print(f"Received prompt: {prompt}")
19
+ # Debugging statement
20
  image, error = generate_image(prompt)
21
  if error:
22
+ print(f"Error generating image: {error}")
23
+ # Debugging statement
24
  return "Error: " + error
25
 
26
  buffered = BytesIO()
 
35
  )
36
 
37
  if __name__ == "__main__":
38
+ gradio_interface.launch()