patrickvonplaten commited on
Commit
dfaea05
·
1 Parent(s): d44229e
Files changed (1) hide show
  1. app.py +42 -267
app.py CHANGED
@@ -1,53 +1,11 @@
1
- from diffusers import (
2
- StableDiffusionPipeline,
3
- StableDiffusionImg2ImgPipeline,
4
- DPMSolverMultistepScheduler,
5
- )
6
  import gradio as gr
7
  import torch
8
- from PIL import Image
9
  import time
10
  import psutil
11
- import random
12
- from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
13
 
14
 
15
  start_time = time.time()
16
- current_steps = 25
17
-
18
- SAFETY_CHECKER = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker", torch_dtype=torch.float16)
19
-
20
-
21
- class Model:
22
- def __init__(self, name, path=""):
23
- self.name = name
24
- self.path = path
25
-
26
- if path != "":
27
- self.pipe_t2i = StableDiffusionPipeline.from_pretrained(
28
- path, torch_dtype=torch.float16, safety_checker=SAFETY_CHECKER
29
- )
30
- self.pipe_t2i.scheduler = DPMSolverMultistepScheduler.from_config(
31
- self.pipe_t2i.scheduler.config
32
- )
33
- self.pipe_i2i = StableDiffusionImg2ImgPipeline(**self.pipe_t2i.components)
34
- else:
35
- self.pipe_t2i = None
36
- self.pipe_i2i = None
37
-
38
-
39
- models = [
40
- Model("Protogen v2.2 (Anime)", "darkstorm2150/Protogen_v2.2_Official_Release"),
41
- Model("Protogen x3.4 (Photorealism)", "darkstorm2150/Protogen_x3.4_Official_Release"),
42
- Model("Protogen x5.3 (Photorealism)", "darkstorm2150/Protogen_x5.3_Official_Release"),
43
- Model("Protogen x5.8 Rebuilt (Scifi+Anime)", "darkstorm2150/Protogen_x5.8_Official_Release"),
44
- Model("Protogen Dragon (RPG Model)", "darkstorm2150/Protogen_Dragon_Official_Release"),
45
- Model("Protogen Nova", "darkstorm2150/Protogen_Nova_Official_Release"),
46
- Model("Protogen Eclipse", "darkstorm2150/Protogen_Eclipse_Official_Release"),
47
- Model("Protogen Infinity", "darkstorm2150/Protogen_Infinity_Official_Release"),
48
- ]
49
-
50
- MODELS = {m.name: m for m in models}
51
 
52
  device = "GPU 🔥" if torch.cuda.is_available() else "CPU 🥶"
53
 
@@ -62,263 +20,80 @@ def error_str(error, title="Error"):
62
 
63
 
64
  def inference(
65
- model_name,
 
66
  prompt,
67
- guidance,
68
- steps,
69
- n_images=1,
70
- width=512,
71
- height=512,
72
- seed=0,
73
- img=None,
74
- strength=0.5,
75
- neg_prompt="",
76
  ):
77
 
78
  print(psutil.virtual_memory()) # print memory usage
79
 
80
- if seed == 0:
81
- seed = random.randint(0, 2147483647)
82
-
83
- generator = torch.Generator("cuda").manual_seed(seed)
84
-
85
- try:
86
- if img is not None:
87
- return (
88
- img_to_img(
89
- model_name,
90
- prompt,
91
- n_images,
92
- neg_prompt,
93
- img,
94
- strength,
95
- guidance,
96
- steps,
97
- width,
98
- height,
99
- generator,
100
- seed,
101
- ),
102
- f"Done. Seed: {seed}",
103
- )
104
- else:
105
- return (
106
- txt_to_img(
107
- model_name,
108
- prompt,
109
- n_images,
110
- neg_prompt,
111
- guidance,
112
- steps,
113
- width,
114
- height,
115
- generator,
116
- seed,
117
- ),
118
- f"Done. Seed: {seed}",
119
- )
120
- except Exception as e:
121
- return None, error_str(e)
122
-
123
-
124
- def txt_to_img(
125
- model_name,
126
- prompt,
127
- n_images,
128
- neg_prompt,
129
- guidance,
130
- steps,
131
- width,
132
- height,
133
- generator,
134
- seed,
135
- ):
136
- pipe = MODELS[model_name].pipe_t2i
137
-
138
- if torch.cuda.is_available():
139
- pipe = pipe.to("cuda")
140
- pipe.enable_xformers_memory_efficient_attention()
141
-
142
- result = pipe(
143
- prompt,
144
- negative_prompt=neg_prompt,
145
- num_images_per_prompt=n_images,
146
- num_inference_steps=int(steps),
147
- guidance_scale=guidance,
148
- width=width,
149
- height=height,
150
- generator=generator,
151
- )
152
-
153
- pipe.to("cpu")
154
-
155
- return replace_nsfw_images(result)
156
-
157
-
158
- def img_to_img(
159
- model_name,
160
- prompt,
161
- n_images,
162
- neg_prompt,
163
- img,
164
- strength,
165
- guidance,
166
- steps,
167
- width,
168
- height,
169
- generator,
170
- seed,
171
- ):
172
- pipe = MODELS[model_name].pipe_i2i
173
-
174
- if torch.cuda.is_available():
175
- pipe = pipe.to("cuda")
176
- pipe.enable_xformers_memory_efficient_attention()
177
 
178
- ratio = min(height / img.height, width / img.width)
179
- img = img.resize((int(img.width * ratio), int(img.height * ratio)), Image.LANCZOS)
180
 
181
- result = pipe(
182
- prompt,
183
- negative_prompt=neg_prompt,
184
- num_images_per_prompt=n_images,
185
- image=img,
186
- num_inference_steps=int(steps),
187
- strength=strength,
188
- guidance_scale=guidance,
189
- generator=generator,
190
- )
191
-
192
- pipe.to("cpu")
193
-
194
- return replace_nsfw_images(result)
195
 
 
 
 
196
 
197
- def replace_nsfw_images(results):
198
- for i in range(len(results.images)):
199
- if results.nsfw_content_detected[i]:
200
- results.images[i] = Image.open("nsfw.png")
201
- return results.images
202
 
203
 
204
  with gr.Blocks(css="style.css") as demo:
 
 
 
 
 
 
 
 
 
 
 
 
205
  with gr.Row():
206
 
207
  with gr.Column(scale=55):
208
  with gr.Group():
209
- prompt = gr.Textbox(
210
  label="Repo id on Hub",
211
  placeholder="Path to model, e.g. CompVis/stable-diffusion-v1-4",
212
  )
213
- with gr.Box(visible=False) as custom_model_group:
214
- custom_model_path = gr.Textbox(
215
- label="Custom model path",
216
- placeholder="Path to model, e.g. darkstorm2150/Protogen_x3.4_Official_Release",
217
- interactive=True,
218
- )
219
- gr.HTML(
220
- "<div><font size='2'>Custom models have to be downloaded first, so give it some time.</font></div>"
221
- )
222
-
223
- with gr.Row():
224
- prompt = gr.Textbox(
225
- label="Prompt",
226
- show_label=False,
227
- max_lines=2,
228
- placeholder="Enter prompt.",
229
- ).style(container=False)
230
- generate = gr.Button(value="Generate").style(
231
- rounded=(False, True, True, False)
232
- )
233
-
234
- # image_out = gr.Image(height=512)
235
  gallery = gr.Gallery(
236
  label="Generated images", show_label=False, elem_id="gallery"
237
  ).style(grid=[2], height="auto")
238
 
239
- state_info = gr.Textbox(label="State", show_label=False, max_lines=2).style(
240
- container=False
241
- )
242
  error_output = gr.Markdown()
243
 
244
- with gr.Column(scale=45):
245
- with gr.Tab("Options"):
246
- with gr.Group():
247
- neg_prompt = gr.Textbox(
248
- label="Negative prompt",
249
- placeholder="What to exclude from the image",
250
- )
251
-
252
- n_images = gr.Slider(
253
- label="Images", value=1, minimum=1, maximum=4, step=1
254
- )
255
-
256
- with gr.Row():
257
- guidance = gr.Slider(
258
- label="Guidance scale", value=7.5, maximum=15
259
- )
260
- steps = gr.Slider(
261
- label="Steps",
262
- value=current_steps,
263
- minimum=2,
264
- maximum=75,
265
- step=1,
266
- )
267
-
268
- with gr.Row():
269
- width = gr.Slider(
270
- label="Width", value=512, minimum=64, maximum=1024, step=8
271
- )
272
- height = gr.Slider(
273
- label="Height", value=512, minimum=64, maximum=1024, step=8
274
- )
275
-
276
- seed = gr.Slider(
277
- 0, 2147483647, label="Seed (0 = random)", value=0, step=1
278
- )
279
-
280
- with gr.Tab("Image to image"):
281
- with gr.Group():
282
- image = gr.Image(
283
- label="Image", height=256, tool="editor", type="pil"
284
- )
285
- strength = gr.Slider(
286
- label="Transformation strength",
287
- minimum=0,
288
- maximum=1,
289
- step=0.01,
290
- value=0.5,
291
- )
292
 
293
  inputs = [
294
- model_name,
 
295
  prompt,
296
- guidance,
297
- steps,
298
- n_images,
299
- width,
300
- height,
301
- seed,
302
- image,
303
- strength,
304
- neg_prompt,
305
  ]
306
  outputs = [gallery, error_output]
307
  prompt.submit(inference, inputs=inputs, outputs=outputs)
308
  generate.click(inference, inputs=inputs, outputs=outputs)
309
 
310
- gr.HTML(
311
- """
312
- <div style="border-top: 1px solid #303030;">
313
- <br>
314
- <p>Models by <a href="https://huggingface.co/darkstorm2150">@darkstorm2150</a> and others. ❤️</p>
315
- <p>This space uses the <a href="https://github.com/LuChengTHU/dpm-solver">DPM-Solver++</a> sampler by <a href="https://arxiv.org/abs/2206.00927">Cheng Lu, et al.</a>.</p>
316
- <p>Space by: Darkstorm (Victor Espinoza)<br>
317
- <a href="https://www.instagram.com/officialvictorespinoza/">Instagram</a>
318
- </div>
319
- """
320
- )
321
-
322
  print(f"Space built in {time.time() - start_time:.2f} seconds")
323
 
324
  demo.queue(concurrency_count=1)
 
1
+ from diffusers import DiffusionPipeline
 
 
 
 
2
  import gradio as gr
3
  import torch
 
4
  import time
5
  import psutil
 
 
6
 
7
 
8
  start_time = time.time()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  device = "GPU 🔥" if torch.cuda.is_available() else "CPU 🥶"
11
 
 
20
 
21
 
22
  def inference(
23
+ repo_id,
24
+ pr,
25
  prompt,
 
 
 
 
 
 
 
 
 
26
  ):
27
 
28
  print(psutil.virtual_memory()) # print memory usage
29
 
30
+ seed = 0
31
+ torch_device = "cuda" if "GPU" in device else "cpu"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
+ generator = torch.Generator(torch_device).manual_seed(seed)
 
34
 
35
+ dtype = torch.float16 if torch_device == "cuda" else torch.float32
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
+ try:
38
+ pipe = DiffusionPipeline.from_pretrained(repo_id, revision=pr, torch_dtype=dtype)
39
+ pipe.to(torch_device)
40
 
41
+ return pipe(prompt, generator=generator, num_inference_steps=25).images
42
+ except Exception as e:
43
+ url = f"https://huggingface.co/{repo_id}/discussions/{pr.split('/')[-1]}"
44
+ message = f"There is a problem with your diffusers weights of the PR: {url}. Error message: \n"
45
+ return None, error_str(message + e)
46
 
47
 
48
  with gr.Blocks(css="style.css") as demo:
49
+ gr.HTML(
50
+ f"""
51
+ <div class="diffusion">
52
+ <p>
53
+ Space to test whether `diffusers` PRs work.
54
+ </p>
55
+ <p>
56
+ Running on <b>{device}</b>
57
+ </p>
58
+ </div>
59
+ """
60
+ )
61
  with gr.Row():
62
 
63
  with gr.Column(scale=55):
64
  with gr.Group():
65
+ repo_id = gr.Textbox(
66
  label="Repo id on Hub",
67
  placeholder="Path to model, e.g. CompVis/stable-diffusion-v1-4",
68
  )
69
+ pr = gr.Textbox(
70
+ label="PR branch",
71
+ placeholder="PR branch that should be checked, e.g. refs/pr/171",
72
+ )
73
+ prompt = gr.Textbox(
74
+ label="Prompt",
75
+ default="An astronaut riding a horse on Mars.",
76
+ placeholder="Enter prompt.",
77
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  gallery = gr.Gallery(
79
  label="Generated images", show_label=False, elem_id="gallery"
80
  ).style(grid=[2], height="auto")
81
 
 
 
 
82
  error_output = gr.Markdown()
83
 
84
+ generate = gr.Button(value="Generate").style(
85
+ rounded=(False, True, True, False)
86
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
  inputs = [
89
+ repo_id,
90
+ pr,
91
  prompt,
 
 
 
 
 
 
 
 
 
92
  ]
93
  outputs = [gallery, error_output]
94
  prompt.submit(inference, inputs=inputs, outputs=outputs)
95
  generate.click(inference, inputs=inputs, outputs=outputs)
96
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  print(f"Space built in {time.time() - start_time:.2f} seconds")
98
 
99
  demo.queue(concurrency_count=1)