dgoot commited on
Commit
36e7224
·
1 Parent(s): 4adfa9c

SDXL base models and model versions

Browse files
Files changed (1) hide show
  1. app.py +34 -15
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import os
2
  import shutil
3
- from urllib.parse import urlparse
4
 
5
  import gradio as gr
6
  import requests
@@ -10,6 +10,7 @@ from diffusers import (
10
  AutoencoderKL,
11
  AutoPipelineForImage2Image,
12
  StableDiffusionImg2ImgPipeline,
 
13
  )
14
  from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
15
  download_from_original_stable_diffusion_ckpt,
@@ -36,7 +37,16 @@ gpu_duration = int(os.environ.get("GPU_DURATION", 60))
36
 
37
  logger.debug(f"Loading model info for: {model_url}")
38
 
39
- model_id = int(urlparse(model_url).path.split("/")[2])
 
 
 
 
 
 
 
 
 
40
  r = requests.get(f"https://civitai.com/api/v1/models/{model_id}")
41
  try:
42
  r.raise_for_status()
@@ -49,7 +59,11 @@ model = r.json()
49
 
50
  logger.debug(f"Model info: {model}")
51
 
52
- model_version = model["modelVersions"][0]
 
 
 
 
53
 
54
  assert len(model_version["files"]) <= 2
55
  assert len({file["type"] for file in model_version["files"]}) == len(
@@ -92,26 +106,31 @@ for _ in thread_map(
92
  ):
93
  pass
94
 
95
-
96
- pipe_args = {}
97
- if os.path.exists(get_file_name("VAE")):
98
- logger.debug(f"Loading VAE")
99
-
100
- pipe_args["vae"] = AutoencoderKL.from_single_file(
101
- get_file_name("VAE"),
102
- torch_dtype=torch.float16,
103
- use_safetensors=True,
104
- )
105
-
106
  model_type = model["type"]
107
 
108
  if model_type == "Checkpoint":
109
  logger.debug(f"Loading pipeline for checkpoint")
110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  pipe = download_from_original_stable_diffusion_ckpt(
112
  checkpoint_path_or_dict=get_file_name("Model"),
113
  from_safetensors=True,
114
- pipeline_class=StableDiffusionImg2ImgPipeline,
115
  load_safety_checker=False,
116
  **pipe_args,
117
  )
 
1
  import os
2
  import shutil
3
+ from urllib.parse import parse_qs, urlparse
4
 
5
  import gradio as gr
6
  import requests
 
10
  AutoencoderKL,
11
  AutoPipelineForImage2Image,
12
  StableDiffusionImg2ImgPipeline,
13
+ StableDiffusionXLImg2ImgPipeline,
14
  )
15
  from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
16
  download_from_original_stable_diffusion_ckpt,
 
37
 
38
  logger.debug(f"Loading model info for: {model_url}")
39
 
40
+ model_url_parsed = urlparse(model_url)
41
+
42
+ model_id = int(model_url_parsed.path.split("/")[2])
43
+
44
+ model_version_id = parse_qs(model_url_parsed.query).get("modelVersionId")
45
+ if model_version_id is not None:
46
+ model_version_id = int(model_version_id[0])
47
+
48
+ logger.debug(f"Model version id: {model_version_id}")
49
+
50
  r = requests.get(f"https://civitai.com/api/v1/models/{model_id}")
51
  try:
52
  r.raise_for_status()
 
59
 
60
  logger.debug(f"Model info: {model}")
61
 
62
+ model_version = (
63
+ model["modelVersions"][0]
64
+ if model_version_id is None
65
+ else next(mv for mv in model["modelVersions"] if mv["id"] == model_version_id)
66
+ )
67
 
68
  assert len(model_version["files"]) <= 2
69
  assert len({file["type"] for file in model_version["files"]}) == len(
 
106
  ):
107
  pass
108
 
 
 
 
 
 
 
 
 
 
 
 
109
  model_type = model["type"]
110
 
111
  if model_type == "Checkpoint":
112
  logger.debug(f"Loading pipeline for checkpoint")
113
 
114
+ pipe_args = {}
115
+ if os.path.exists(get_file_name("VAE")):
116
+ logger.debug(f"Loading VAE")
117
+
118
+ pipe_args["vae"] = AutoencoderKL.from_single_file(
119
+ get_file_name("VAE"),
120
+ torch_dtype=torch.float16,
121
+ use_safetensors=True,
122
+ )
123
+
124
+ base_model = model_version["baseModel"]
125
+ if base_model == "SD 1.5":
126
+ pipeline_class = StableDiffusionImg2ImgPipeline
127
+ elif base_model == "SDXL 1.0":
128
+ pipeline_class = StableDiffusionXLImg2ImgPipeline
129
+
130
  pipe = download_from_original_stable_diffusion_ckpt(
131
  checkpoint_path_or_dict=get_file_name("Model"),
132
  from_safetensors=True,
133
+ pipeline_class=pipeline_class,
134
  load_safety_checker=False,
135
  **pipe_args,
136
  )