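"""Gradio Space for image-to-image generation with a Stable Diffusion checkpoint
downloaded from Civitai. The model is selected via the MODEL_URL environment
variable (defaults to DreamShaper)."""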
import os
import shutil
from urllib.parse import urlparse
import gradio as gr
import requests
import spaces
import torch
from diffusers import AutoencoderKL, StableDiffusionImg2ImgPipeline
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
download_from_original_stable_diffusion_ckpt,
)
from loguru import logger
from PIL import Image
from slugify import slugify
from tqdm import tqdm
from tqdm.contrib.concurrent import thread_map
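
# Civitai model pages this Space supports; MODEL_URL may point at any of them.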
SUPPORTED_MODELS = [
"https://civitai.com/models/4384/dreamshaper",
"https://civitai.com/models/44960/mpixel",
"https://civitai.com/models/92444/lelo-lego-lora-for-xl-and-sd15",
"https://civitai.com/models/120298/chinese-landscape-art",
"https://civitai.com/models/150986/blueprintify-sd-xl-10",
"https://civitai.com/models/257749/pony-diffusion-v6-xl",
]
DEFAULT_MODEL = "https://civitai.com/models/4384/dreamshaper"
model_url = os.environ.get("MODEL_URL", DEFAULT_MODEL)
gpu_duration = int(os.environ.get("GPU_DURATION", 60))
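
# Resolve the numeric model ID from the Civitai URL and fetch its metadata from
# the public Civitai API, surfacing the response body on HTTP errors.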
logger.debug(f"Loading model info for: {model_url}")
model_id = int(urlparse(model_url).path.split("/")[2])
r = requests.get(f"https://civitai.com/api/v1/models/{model_id}")
try:
r.raise_for_status()
except requests.HTTPError as e:
raise requests.HTTPError(
r.text.strip(), request=e.request, response=e.response
) from e
model = r.json()
logger.debug(f"Model info: {model}")
model_version = model["modelVersions"][0]
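
# Sanity-check the selected model version: at most two files (a Model and
# optionally a VAE), each in SafeTensor format.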
assert len(model_version["files"]) <= 2
assert len({file["type"] for file in model_version["files"]}) == len(
model_version["files"]
)
assert all(file["type"] in ["Model", "VAE"] for file in model_version["files"])
assert all(
file["metadata"]["format"] in ["SafeTensor"] for file in model_version["files"]
)
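
# Stream a file to /tmp in 1 MiB chunks with a progress bar, then move it into
# place; the download is skipped if the target file already exists.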
def download(file: str, url: str):
if os.path.exists(file):
return
r = requests.get(url, stream=True)
r.raise_for_status()
temp_file = f"/tmp/{file}"
with tqdm(
desc=file, total=int(r.headers["content-length"]), unit="B", unit_scale=True
) as pbar, open(temp_file, "wb") as f:
for chunk in r.iter_content(chunk_size=1024 * 1024):
f.write(chunk)
pbar.update(len(chunk))
shutil.move(temp_file, file)
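
# Local filenames look like "<model-name>.<file-type>.safetensors"; all files
# for the model version are downloaded in parallel threads.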
model_name = model["name"]
def get_file_name(file_type):
return f"{slugify(model_name)}.{slugify(file_type)}.safetensors"
for _ in thread_map(
lambda file: download(get_file_name(file["type"]), file["downloadUrl"]),
model_version["files"],
):
pass
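
# Build the img2img pipeline from the single-file checkpoint, attaching the
# standalone VAE if the model provides one, and move it to the GPU.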
pipe_args = {}
if os.path.exists(get_file_name("VAE")):
    logger.debug("Loading VAE")
pipe_args["vae"] = AutoencoderKL.from_single_file(
get_file_name("VAE"),
torch_dtype=torch.float16,
use_safetensors=True,
)
logger.debug("Loading pipeline")
pipe = download_from_original_stable_diffusion_ckpt(
checkpoint_path_or_dict=get_file_name("Model"),
from_safetensors=True,
pipeline_class=StableDiffusionImg2ImgPipeline,
load_safety_checker=False,
**pipe_args,
)
pipe = pipe.to("cuda")
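
# Run img2img generation on the GPU allocated by the @spaces.GPU decorator.
# Advanced settings left at 0 are dropped so the pipeline uses its own defaults.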
@logger.catch(reraise=True)
@spaces.GPU(duration=gpu_duration)
def infer(
prompt: str,
init_image: Image.Image,
negative_prompt: str | None,
strength: float,
num_inference_steps: int,
guidance_scale: float,
progress=gr.Progress(track_tqdm=True),
):
logger.info(f"Starting image generation: {dict(prompt=prompt, image=init_image)}")
    # Downscale the input image in place so it fits within 1024x1024
init_image.thumbnail((1024, 1024))
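    # Forward only the advanced settings the user changed (non-zero values);
    # everything else falls back to the pipeline's defaults.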
additional_args = {
k: v
for k, v in dict(
strength=strength,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
).items()
if v
}
logger.debug(f"Generating image: {dict(prompt=prompt, **additional_args)}")
images = pipe(
prompt=prompt,
image=init_image,
negative_prompt=negative_prompt,
**additional_args,
).images
return images[0]
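
# Stack the two image panels vertically on narrow screens.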
css = """
@media (max-width: 1280px) {
#images-container {
flex-direction: column;
}
}
"""
with gr.Blocks(css=css) as demo:
with gr.Column():
gr.Markdown("# Image-to-Image")
gr.Markdown(f"## Model: [{model_name}]({model_url})")
with gr.Row():
prompt = gr.Text(
label="Prompt",
show_label=False,
max_lines=1,
placeholder="Enter your prompt",
container=False,
)
run_button = gr.Button("Run", scale=0, variant="primary")
with gr.Row(elem_id="images-container"):
init_image = gr.Image(label="Initial image", type="pil")
result = gr.Image(label="Result")
with gr.Accordion("Advanced Settings", open=False):
negative_prompt = gr.Text(
label="Negative prompt",
max_lines=1,
placeholder="Enter a negative prompt",
)
with gr.Row():
strength = gr.Slider(
label="Strength",
minimum=0.0,
maximum=1.0,
step=0.01,
value=0.0,
)
num_inference_steps = gr.Slider(
label="Number of inference steps",
minimum=0,
maximum=100,
step=1,
value=0,
)
guidance_scale = gr.Slider(
label="Guidance scale",
minimum=0.0,
maximum=100.0,
step=0.1,
value=0.0,
)
gr.on(
triggers=[run_button.click, prompt.submit],
fn=infer,
inputs=[
prompt,
init_image,
negative_prompt,
strength,
num_inference_steps,
guidance_scale,
],
outputs=[result],
)
if __name__ == "__main__":
demo.launch()