dgoot commited on
Commit
dcf713b
·
1 Parent(s): 4590b87

Initial implementation

Browse files
Files changed (1) hide show
  1. app.py +225 -4
app.py CHANGED
@@ -1,7 +1,228 @@
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
1
+ import os
2
+ import shutil
3
+ from urllib.parse import urlparse
4
+
5
  import gradio as gr
6
+ import requests
7
+ import spaces
8
+ import torch
9
+ from diffusers import AutoencoderKL, StableDiffusionImg2ImgPipeline
10
+ from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
11
+ download_from_original_stable_diffusion_ckpt,
12
+ )
13
+ from loguru import logger
14
+ from PIL import Image
15
+ from slugify import slugify
16
+ from tqdm import tqdm
17
+ from tqdm.contrib.concurrent import thread_map
18
+
19
+ SUPPORTED_MODELS = [
20
+ "https://civitai.com/models/4384/dreamshaper",
21
+ "https://civitai.com/models/44960/mpixel",
22
+ "https://civitai.com/models/92444/lelo-lego-lora-for-xl-and-sd15",
23
+ "https://civitai.com/models/120298/chinese-landscape-art",
24
+ "https://civitai.com/models/150986/blueprintify-sd-xl-10",
25
+ "https://civitai.com/models/257749/pony-diffusion-v6-xl",
26
+ ]
27
+ DEFAULT_MODEL = "https://civitai.com/models/4384/dreamshaper"
28
+
29
+ model_url = os.environ.get("MODEL_URL", DEFAULT_MODEL)
30
+ gpu_duration = int(os.environ.get("GPU_DURATION", 60))
31
+
32
+
33
+ logger.debug(f"Loading model info for: {model_url}")
34
+
35
+ model_id = int(urlparse(model_url).path.split("/")[2])
36
+ r = requests.get(f"https://civitai.com/api/v1/models/{model_id}")
37
+ try:
38
+ r.raise_for_status()
39
+ except requests.HTTPError as e:
40
+ raise requests.HTTPError(
41
+ r.text.strip(), request=e.request, response=e.response
42
+ ) from e
43
+
44
+ model = r.json()
45
+
46
+ logger.debug(f"Model info: {model}")
47
+
48
+ model_version = model["modelVersions"][0]
49
+
50
+ assert len(model_version["files"]) <= 2
51
+ assert len({file["type"] for file in model_version["files"]}) == len(
52
+ model_version["files"]
53
+ )
54
+ assert all(file["type"] in ["Model", "VAE"] for file in model_version["files"])
55
+ assert all(
56
+ file["metadata"]["format"] in ["SafeTensor"] for file in model_version["files"]
57
+ )
58
+
59
+
60
+ def download(file: str, url: str):
61
+ if os.path.exists(file):
62
+ return
63
+
64
+ r = requests.get(url, stream=True)
65
+ r.raise_for_status()
66
+
67
+ temp_file = f"/tmp/{file}"
68
+ with tqdm(
69
+ desc=file, total=int(r.headers["content-length"]), unit="B", unit_scale=True
70
+ ) as pbar, open(temp_file, "wb") as f:
71
+ for chunk in r.iter_content(chunk_size=1024 * 1024):
72
+ f.write(chunk)
73
+ pbar.update(len(chunk))
74
+
75
+ shutil.move(temp_file, file)
76
+
77
+
78
+ model_name = model["name"]
79
+
80
+
81
+ def get_file_name(file_type):
82
+ return f"{slugify(model_name)}.{slugify(file_type)}.safetensors"
83
+
84
+
85
+ for _ in thread_map(
86
+ lambda file: download(get_file_name(file["type"]), file["downloadUrl"]),
87
+ model_version["files"],
88
+ ):
89
+ pass
90
+
91
+
92
+ pipe_args = {}
93
+ if os.path.exists(get_file_name("VAE")):
94
+ logger.debug(f"Loading VAE")
95
+
96
+ pipe_args["vae"] = AutoencoderKL.from_single_file(
97
+ get_file_name("VAE"),
98
+ torch_dtype=torch.float16,
99
+ use_safetensors=True,
100
+ )
101
+
102
+
103
+ logger.debug(f"Loading pipeline")
104
+
105
+ pipe = download_from_original_stable_diffusion_ckpt(
106
+ checkpoint_path_or_dict=get_file_name("Model"),
107
+ from_safetensors=True,
108
+ pipeline_class=StableDiffusionImg2ImgPipeline,
109
+ load_safety_checker=False,
110
+ **pipe_args,
111
+ )
112
+
113
+ pipe = pipe.to("cuda")
114
+
115
+
116
+ @logger.catch(reraise=True)
117
+ @spaces.GPU(duration=gpu_duration)
118
+ def infer(
119
+ prompt: str,
120
+ init_image: Image.Image,
121
+ negative_prompt: str | None,
122
+ strength: float,
123
+ num_inference_steps: int,
124
+ guidance_scale: float,
125
+ progress=gr.Progress(track_tqdm=True),
126
+ ):
127
+ logger.info(f"Starting image generation: {dict(prompt=prompt, image=init_image)}")
128
+
129
+ # Downscale the image
130
+ init_image.thumbnail((1024, 1024))
131
+
132
+ additional_args = {
133
+ k: v
134
+ for k, v in dict(
135
+ strength=strength,
136
+ num_inference_steps=num_inference_steps,
137
+ guidance_scale=guidance_scale,
138
+ ).items()
139
+ if v
140
+ }
141
+
142
+ logger.debug(f"Generating image: {dict(prompt=prompt, **additional_args)}")
143
+
144
+ images = pipe(
145
+ prompt=prompt,
146
+ image=init_image,
147
+ negative_prompt=negative_prompt,
148
+ **additional_args,
149
+ ).images
150
+ return images[0]
151
+
152
+
153
+ css = """
154
+ @media (max-width: 1280px) {
155
+ #images-container {
156
+ flex-direction: column;
157
+ }
158
+ }
159
+ """
160
+
161
+ with gr.Blocks(css=css) as demo:
162
+ with gr.Column():
163
+ gr.Markdown("# Image-to-Image")
164
+ gr.Markdown(f"## Model: [{model_url}]({model_name})")
165
+
166
+ with gr.Row():
167
+ prompt = gr.Text(
168
+ label="Prompt",
169
+ show_label=False,
170
+ max_lines=1,
171
+ placeholder="Enter your prompt",
172
+ container=False,
173
+ )
174
+
175
+ run_button = gr.Button("Run", scale=0, variant="primary")
176
+
177
+ with gr.Row(elem_id="images-container"):
178
+ init_image = gr.Image(label="Initial image", type="pil")
179
+
180
+ result = gr.Image(label="Result")
181
+
182
+ with gr.Accordion("Advanced Settings", open=False):
183
+ negative_prompt = gr.Text(
184
+ label="Negative prompt",
185
+ max_lines=1,
186
+ placeholder="Enter a negative prompt",
187
+ )
188
+
189
+ with gr.Row():
190
+ strength = gr.Slider(
191
+ label="Strength",
192
+ minimum=0.0,
193
+ maximum=1.0,
194
+ step=0.01,
195
+ value=0.0,
196
+ )
197
+
198
+ num_inference_steps = gr.Slider(
199
+ label="Number of inference steps",
200
+ minimum=0,
201
+ maximum=100,
202
+ step=1,
203
+ value=0,
204
+ )
205
 
206
+ guidance_scale = gr.Slider(
207
+ label="Guidance scale",
208
+ minimum=0.0,
209
+ maximum=100.0,
210
+ step=0.1,
211
+ value=0.0,
212
+ )
213
+ gr.on(
214
+ triggers=[run_button.click, prompt.submit],
215
+ fn=infer,
216
+ inputs=[
217
+ prompt,
218
+ init_image,
219
+ negative_prompt,
220
+ strength,
221
+ num_inference_steps,
222
+ guidance_scale,
223
+ ],
224
+ outputs=[result],
225
+ )
226
 
227
+ if __name__ == "__main__":
228
+ demo.launch()