sudo-soldier commited on
Commit
ca51646
Β·
verified Β·
1 Parent(s): 3c0e3c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -6
app.py CHANGED
@@ -2,16 +2,23 @@ import gradio as gr
2
  from diffusers import StableDiffusionPipeline
3
  import torch
4
 
5
-
6
  pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32)
7
  pipe.to("cuda" if torch.cuda.is_available() else "cpu")
8
 
 
 
 
 
9
  def infer(prompt, guidance_scale, num_inference_steps):
10
  with torch.no_grad():
11
- image = pipe(prompt, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps).images[0]
12
- return image
13
-
 
 
14
 
 
15
  with gr.Blocks() as demo:
16
  gr.Markdown("πŸš€ Hyper-Sketch")
17
 
@@ -20,13 +27,13 @@ with gr.Blocks() as demo:
20
 
21
  with gr.Row():
22
  guidance_scale = gr.Slider(1.0, 20.0, value=7.5, step=0.1, label="Guidance Scale")
23
- num_inference_steps = gr.Slider(10, 100, value=50, step=1, label="Inference Steps")
24
 
25
  generate_button = gr.Button("Generate Image")
26
  output_image = gr.Image(label="Generated Image", type="pil")
27
 
28
  generate_button.click(infer, inputs=[prompt, guidance_scale, num_inference_steps], outputs=[output_image])
29
 
30
-
31
  demo.launch(share=True)
32
 
 
 
2
  from diffusers import StableDiffusionPipeline
3
  import torch
4
 
5
+ # Load model
6
  pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32)
7
  pipe.to("cuda" if torch.cuda.is_available() else "cpu")
8
 
9
+ # Disable safety checker (optional)
10
+ pipe.safety_checker = lambda images, **kwargs: (images, False)
11
+
12
+ # Define inference function
13
  def infer(prompt, guidance_scale, num_inference_steps):
14
  with torch.no_grad():
15
+ try:
16
+ image = pipe(prompt, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps).images[0]
17
+ return image
18
+ except Exception as e:
19
+ return f"Error: {str(e)}"
20
 
21
+ # Gradio UI
22
  with gr.Blocks() as demo:
23
  gr.Markdown("πŸš€ Hyper-Sketch")
24
 
 
27
 
28
  with gr.Row():
29
  guidance_scale = gr.Slider(1.0, 20.0, value=7.5, step=0.1, label="Guidance Scale")
30
+ num_inference_steps = gr.Slider(10, 50, value=20, step=1, label="Inference Steps") # Lowered max steps
31
 
32
  generate_button = gr.Button("Generate Image")
33
  output_image = gr.Image(label="Generated Image", type="pil")
34
 
35
  generate_button.click(infer, inputs=[prompt, guidance_scale, num_inference_steps], outputs=[output_image])
36
 
 
37
  demo.launch(share=True)
38
 
39
+