wanghuging committed on
Commit
a66bb06
·
1 Parent(s): 6c593b0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +215 -0
app.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import AutoPipelineForImage2Image, AutoPipelineForText2Image
2
+ import torch
3
+ import os
4
+
5
+ try:
6
+ import intel_extension_for_pytorch as ipex
7
+ except:
8
+ pass
9
+
10
+ from PIL import Image
11
+ import numpy as np
12
+ import gradio as gr
13
+ import psutil
14
+ import time
15
+ import math
16
+
17
# Runtime configuration, read once from the environment at import time.
SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
TORCH_COMPILE = os.environ.get("TORCH_COMPILE", None)
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# check if MPS is available OSX only M1/M2/M3 chips
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()

# Default target: CUDA first, then Intel XPU, then plain CPU, in half precision.
device = torch.device(
    "cuda" if torch.cuda.is_available() else "xpu" if xpu_available else "cpu"
)
torch_device = device
torch_dtype = torch.float16

if mps_available:
    # With MPS the pipelines are first moved to ``torch_device`` ("cpu") in
    # float32 and then to ``device`` ("mps") by the chained ``.to()`` calls
    # further down.
    device = torch.device("mps")
    torch_device = "cpu"
    torch_dtype = torch.float32

# Log only after the MPS override so the reported device is the one in use.
print(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
print(f"TORCH_COMPILE: {TORCH_COMPILE}")
print(f"device: {device}")

_MODEL_ID = "stabilityai/sdxl-turbo"

# Keyword arguments shared by both pipelines.
# NOTE(review): the model repo appears to publish only an ``fp16`` weight
# variant -- confirm. Asking for the default weights (``variant=None``) in
# full precision avoids requesting an ``fp32`` variant that may not exist.
_pipe_kwargs = {
    "torch_dtype": torch_dtype,
    "variant": "fp16" if torch_dtype == torch.float16 else None,
}
if SAFETY_CHECKER != "True":
    # The safety checker is disabled unless SAFETY_CHECKER is exactly "True".
    _pipe_kwargs["safety_checker"] = None

i2i_pipe = AutoPipelineForImage2Image.from_pretrained(_MODEL_ID, **_pipe_kwargs)
t2i_pipe = AutoPipelineForText2Image.from_pretrained(_MODEL_ID, **_pipe_kwargs)


t2i_pipe.to(device=torch_device, dtype=torch_dtype).to(device)
t2i_pipe.set_progress_bar_config(disable=True)
i2i_pipe.to(device=torch_device, dtype=torch_dtype).to(device)
i2i_pipe.set_progress_bar_config(disable=True)
68
+
69
+
70
def resize_crop(image, size=512):
    """Convert ``image`` to RGB and scale it to ``size`` pixels wide.

    The height is scaled by the same factor (truncated to an integer), so the
    aspect ratio is kept. Despite the name, no cropping is performed.

    Args:
        image: image object exposing ``convert``, ``size`` and ``resize``.
        size: target width in pixels.

    Returns:
        The resized RGB image.
    """
    rgb = image.convert("RGB")
    width, height = rgb.size
    target_height = int(size * (height / width))
    return rgb.resize((size, target_height), Image.BICUBIC)
75
+
76
+
77
# Smallest image-to-image strength the demo will run with.
_MIN_STRENGTH = 0.10


def _img2img_schedule(steps, strength):
    """Return ``(steps, strength)`` adjusted so ``int(steps * strength) >= 1``."""
    # Clamp the strength itself, not only the value used to derive ``steps``;
    # otherwise a strength below the minimum still yields zero effective steps.
    strength = max(_MIN_STRENGTH, strength)
    if int(steps * strength) < 1:
        steps = math.ceil(1 / strength)
    return steps, strength


async def predict(init_image, prompt, strength, steps, seed=1231231):
    """Generate one 512x512-configured image from ``prompt``.

    Uses the image-to-image pipeline when ``init_image`` is given and the
    text-to-image pipeline otherwise.

    Args:
        init_image: starting image, or ``None`` for text-to-image.
        prompt: text prompt passed to the pipeline.
        strength: image-to-image strength; ignored for text-to-image. Values
            below 0.10 are raised to 0.10.
        steps: requested number of inference steps; may be increased for
            image-to-image so that ``int(steps * strength)`` is at least 1.
        seed: seed for the torch random generator.

    Returns:
        The first image of the pipeline result, or a blank 512x512 RGB image
        when the result flags NSFW content.
    """
    # Coerce defensively: UI widgets are not guaranteed to deliver ints.
    steps = int(steps)

    if init_image is not None:
        init_image = resize_crop(init_image)

    generator = torch.manual_seed(int(seed))
    shared_kwargs = {
        "prompt": prompt,
        "generator": generator,
        "guidance_scale": 0.0,
        "width": 512,
        "height": 512,
        "output_type": "pil",
    }

    last_time = time.time()
    if init_image is not None:
        steps, strength = _img2img_schedule(steps, strength)
        results = i2i_pipe(
            image=init_image,
            num_inference_steps=steps,
            strength=strength,
            **shared_kwargs,
        )
    else:
        results = t2i_pipe(num_inference_steps=steps, **shared_kwargs)
    print(f"Pipe took {time.time() - last_time} seconds")

    # The NSFW field is optional; treat a missing or empty value as "not flagged".
    nsfw_flags = (
        results.nsfw_content_detected
        if "nsfw_content_detected" in results
        else None
    )
    if nsfw_flags and nsfw_flags[0]:
        gr.Warning("NSFW content detected.")
        return Image.new("RGB", (512, 512))
    return results.images[0]
119
+
120
+
121
# Page-level CSS: centers the main container and the intro text.
css = """
#container{
margin: 0 auto;
max-width: 80rem;
}
#intro{
max-width: 100%;
text-align: center;
margin: 0 auto;
}
"""

# Markdown shown at the top of the page.
_INTRO_MARKDOWN = """# SDXL Turbo Image to Image/Text to Image
## Unofficial Demo
SDXL Turbo model can generate high quality images in a single pass read more on [stability.ai post](https://stability.ai/news/stability-ai-sdxl-turbo).
**Model**: https://huggingface.co/stabilityai/sdxl-turbo
"""

# Markdown shown inside the "Run with diffusers" accordion.
_USAGE_MARKDOWN = """## Running SDXL Turbo with `diffusers`
```bash
pip install diffusers==0.23.1
```
```py
from diffusers import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained(
"stabilityai/sdxl-turbo"
).to("cuda")
results = pipe(
prompt="A cinematic shot of a baby racoon wearing an intricate italian priest robe",
num_inference_steps=1,
guidance_scale=0.0,
)
imga = results.images[0]
imga.save("image.png")
```
"""

with gr.Blocks(css=css) as demo:
    init_image_state = gr.State()
    with gr.Column(elem_id="container"):
        gr.Markdown(_INTRO_MARKDOWN, elem_id="intro")

        # Prompt box and the explicit "Generate" button share one row.
        with gr.Row():
            prompt = gr.Textbox(
                placeholder="Insert your prompt here:",
                scale=5,
                container=False,
            )
            generate_bt = gr.Button("Generate", scale=1)

        # Left column: optional input image. Right column: output and options.
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(
                    sources=["upload", "webcam", "clipboard"],
                    label="Webcam",
                    type="pil",
                )
            with gr.Column():
                image = gr.Image(type="filepath")
                with gr.Accordion("Advanced options", open=False):
                    strength = gr.Slider(
                        label="Strength",
                        value=0.7,
                        minimum=0.0,
                        maximum=1.0,
                        step=0.001,
                    )
                    steps = gr.Slider(
                        label="Steps", value=2, minimum=1, maximum=10, step=1
                    )
                    seed = gr.Slider(
                        randomize=True,
                        minimum=0,
                        maximum=12013012031030,
                        label="Seed",
                        step=1,
                    )

        with gr.Accordion("Run with diffusers"):
            gr.Markdown(_USAGE_MARKDOWN)

        inputs = [image_input, prompt, strength, steps, seed]

        # Each of these events re-runs predict with the same inputs and output.
        # Registration order is kept: button, prompt, steps, seed, strength.
        for register in (
            generate_bt.click,
            prompt.input,
            steps.change,
            seed.change,
            strength.change,
        ):
            register(fn=predict, inputs=inputs, outputs=image, show_progress=False)

        # Changing the input image only copies it into the state component.
        image_input.change(
            fn=lambda x: x,
            inputs=image_input,
            outputs=init_image_state,
            show_progress=False,
            queue=False,
        )

demo.queue()
demo.launch()