LORA support
app.py CHANGED
@@ -6,7 +6,11 @@ import gradio as gr
 import requests
 import spaces
 import torch
-from diffusers import
+from diffusers import (
+    AutoencoderKL,
+    AutoPipelineForImage2Image,
+    StableDiffusionImg2ImgPipeline,
+)
 from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
     download_from_original_stable_diffusion_ckpt,
 )
@@ -99,16 +103,38 @@ if os.path.exists(get_file_name("VAE")):
         use_safetensors=True,
     )
 
+model_type = model["type"]
 
-
+if model_type == "Checkpoint":
+    logger.debug(f"Loading pipeline for checkpoint")
 
-pipe = download_from_original_stable_diffusion_ckpt(
-
-
-
-
-
-)
+    pipe = download_from_original_stable_diffusion_ckpt(
+        checkpoint_path_or_dict=get_file_name("Model"),
+        from_safetensors=True,
+        pipeline_class=StableDiffusionImg2ImgPipeline,
+        load_safety_checker=False,
+        **pipe_args,
+    )
+elif model_type == "LORA":
+    logger.debug(f"Loading pipeline for LORA")
+
+    base_model = model_version["baseModel"]
+
+    if base_model == "SD 1.5":
+        pipe = AutoPipelineForImage2Image.from_pretrained(
+            "stable-diffusion-v1-5/stable-diffusion-v1-5",
+            safety_checker=None,
+            requires_safety_checker=False,
+            torch_dtype=torch.float16,
+            use_safetensors=True,
+            variant="fp16",
+        )
+    else:
+        raise ValueError(f"Unsupported base model: {base_model}")
+
+    pipe.load_lora_weights(get_file_name("Model"), adapter_name=slugify(model_name))
+else:
+    raise ValueError(f"Unsupported model type: {model_type}")
 
 pipe = pipe.to("cuda")
 
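For reference, a minimal usage sketch of the pipeline this commit builds (not part of the diff): after `pipe = pipe.to("cuda")`, either branch yields an img2img pipeline that can be called directly. The input image path, prompt, and generation parameters below are illustrative placeholders, not values from this commit.

# Hypothetical usage sketch; `pipe` is assumed to be the pipeline produced
# above by either the Checkpoint or the LORA branch. Paths, prompt, and
# parameters are placeholders.
import torch
from diffusers.utils import load_image

init_image = load_image("input.png").resize((512, 512))  # placeholder input image

result = pipe(
    prompt="a watercolor portrait",  # placeholder prompt
    image=init_image,
    strength=0.6,                    # how far to move away from the init image
    guidance_scale=7.5,
    num_inference_steps=30,
    generator=torch.Generator("cuda").manual_seed(0),  # fixed seed for reproducibility
).images[0]
result.save("output.png")

Loading the LoRA with an explicit adapter_name (slugify(model_name) above) also makes it possible, in recent diffusers releases, to select or weight adapters later, e.g. pipe.set_adapters([slugify(model_name)], adapter_weights=[1.0]), if more than one LoRA is ever attached.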