Spaces: Runtime error
Initial implementation
app.py CHANGED
@@ -1,7 +1,228 @@
import os
import shutil
from urllib.parse import urlparse

import gradio as gr
import requests
import spaces
import torch
from diffusers import AutoencoderKL, StableDiffusionImg2ImgPipeline
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    download_from_original_stable_diffusion_ckpt,
)
from loguru import logger
from PIL import Image
from slugify import slugify
from tqdm import tqdm
from tqdm.contrib.concurrent import thread_map
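
# Civitai checkpoints this Space is set up for; the list is informational,
# only MODEL_URL below is actually read at runtime.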
SUPPORTED_MODELS = [
    "https://civitai.com/models/4384/dreamshaper",
    "https://civitai.com/models/44960/mpixel",
    "https://civitai.com/models/92444/lelo-lego-lora-for-xl-and-sd15",
    "https://civitai.com/models/120298/chinese-landscape-art",
    "https://civitai.com/models/150986/blueprintify-sd-xl-10",
    "https://civitai.com/models/257749/pony-diffusion-v6-xl",
]
DEFAULT_MODEL = "https://civitai.com/models/4384/dreamshaper"

model_url = os.environ.get("MODEL_URL", DEFAULT_MODEL)
gpu_duration = int(os.environ.get("GPU_DURATION", 60))


logger.debug(f"Loading model info for: {model_url}")
model_id = int(urlparse(model_url).path.split("/")[2])
r = requests.get(f"https://civitai.com/api/v1/models/{model_id}")
try:
    r.raise_for_status()
except requests.HTTPError as e:
    raise requests.HTTPError(
        r.text.strip(), request=e.request, response=e.response
    ) from e

model = r.json()

logger.debug(f"Model info: {model}")
model_version = model["modelVersions"][0]
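
# Sanity checks: at most two files (a "Model" plus an optional "VAE"),
# no duplicate file types, and everything in SafeTensor format.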
assert len(model_version["files"]) <= 2
assert len({file["type"] for file in model_version["files"]}) == len(
    model_version["files"]
)
assert all(file["type"] in ["Model", "VAE"] for file in model_version["files"])
assert all(
    file["metadata"]["format"] in ["SafeTensor"] for file in model_version["files"]
)


def download(file: str, url: str):
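    # Nothing to do if the file is already on disk.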
    if os.path.exists(file):
        return

    r = requests.get(url, stream=True)
    r.raise_for_status()

    temp_file = f"/tmp/{file}"
    with tqdm(
        desc=file, total=int(r.headers["content-length"]), unit="B", unit_scale=True
    ) as pbar, open(temp_file, "wb") as f:
        for chunk in r.iter_content(chunk_size=1024 * 1024):
            f.write(chunk)
            pbar.update(len(chunk))
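
    # Move into place only once the download has finished, so a partial
    # file never looks complete.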
    shutil.move(temp_file, file)


model_name = model["name"]


def get_file_name(file_type):
    return f"{slugify(model_name)}.{slugify(file_type)}.safetensors"
for _ in thread_map(
    lambda file: download(get_file_name(file["type"]), file["downloadUrl"]),
    model_version["files"],
):
    pass
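

# Pass a standalone VAE to the pipeline only when the version ships one.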
pipe_args = {}
if os.path.exists(get_file_name("VAE")):
    logger.debug("Loading VAE")

    pipe_args["vae"] = AutoencoderKL.from_single_file(
        get_file_name("VAE"),
        torch_dtype=torch.float16,
        use_safetensors=True,
    )


logger.debug("Loading pipeline")
pipe = download_from_original_stable_diffusion_ckpt(
    checkpoint_path_or_dict=get_file_name("Model"),
    from_safetensors=True,
    pipeline_class=StableDiffusionImg2ImgPipeline,
    load_safety_checker=False,
    **pipe_args,
)

pipe = pipe.to("cuda")
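

# On ZeroGPU Spaces, @spaces.GPU reserves GPU time per call; the budget
# comes from the GPU_DURATION env var.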
@logger.catch(reraise=True)
@spaces.GPU(duration=gpu_duration)
def infer(
    prompt: str,
    init_image: Image.Image,
    negative_prompt: str | None,
    strength: float,
    num_inference_steps: int,
    guidance_scale: float,
    progress=gr.Progress(track_tqdm=True),
):
    logger.info(f"Starting image generation: {dict(prompt=prompt, image=init_image)}")

    # Downscale the image in place so its longest side is at most 1024 px
    init_image.thumbnail((1024, 1024))
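
    # Sliders left at 0 mean "use the pipeline default": drop falsy values
    # rather than forwarding them.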
    additional_args = {
        k: v
        for k, v in dict(
            strength=strength,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        ).items()
        if v
    }

    logger.debug(f"Generating image: {dict(prompt=prompt, **additional_args)}")

    images = pipe(
        prompt=prompt,
        image=init_image,
        negative_prompt=negative_prompt,
        **additional_args,
    ).images
    return images[0]
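

# On narrow viewports, stack the input and output images vertically.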
css = """
@media (max-width: 1280px) {
    #images-container {
        flex-direction: column;
    }
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column():
        gr.Markdown("# Image-to-Image")
        gr.Markdown(f"## Model: [{model_name}]({model_url})")

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )

            run_button = gr.Button("Run", scale=0, variant="primary")

        with gr.Row(elem_id="images-container"):
            init_image = gr.Image(label="Initial image", type="pil")

            result = gr.Image(label="Result")

        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
            )

            with gr.Row():
                strength = gr.Slider(
                    label="Strength",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=0.0,
                )

                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=0,
                    maximum=100,
                    step=1,
                    value=0,
                )

                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=100.0,
                    step=0.1,
                    value=0.0,
                )
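
    # Run inference from the button or by submitting the prompt box.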
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            init_image,
            negative_prompt,
            strength,
            num_inference_steps,
            guidance_scale,
        ],
        outputs=[result],
    )

if __name__ == "__main__":
    demo.launch()