dgoot commited on
Commit
7dfe38c
·
1 Parent(s): ae93796

LORA support

Browse files
Files changed (1) hide show
  1. app.py +35 -9
app.py CHANGED
@@ -6,7 +6,11 @@ import gradio as gr
6
  import requests
7
  import spaces
8
  import torch
9
- from diffusers import AutoencoderKL, StableDiffusionImg2ImgPipeline
 
 
 
 
10
  from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
11
  download_from_original_stable_diffusion_ckpt,
12
  )
@@ -99,16 +103,38 @@ if os.path.exists(get_file_name("VAE")):
99
  use_safetensors=True,
100
  )
101
 
 
102
 
103
- logger.debug(f"Loading pipeline")
 
104
 
105
- pipe = download_from_original_stable_diffusion_ckpt(
106
- checkpoint_path_or_dict=get_file_name("Model"),
107
- from_safetensors=True,
108
- pipeline_class=StableDiffusionImg2ImgPipeline,
109
- load_safety_checker=False,
110
- **pipe_args,
111
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
  pipe = pipe.to("cuda")
114
 
 
6
  import requests
7
  import spaces
8
  import torch
9
+ from diffusers import (
10
+ AutoencoderKL,
11
+ AutoPipelineForImage2Image,
12
+ StableDiffusionImg2ImgPipeline,
13
+ )
14
  from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
15
  download_from_original_stable_diffusion_ckpt,
16
  )
 
103
  use_safetensors=True,
104
  )
105
 
106
+ model_type = model["type"]
107
 
108
+ if model_type == "Checkpoint":
109
+ logger.debug(f"Loading pipeline for checkpoint")
110
 
111
+ pipe = download_from_original_stable_diffusion_ckpt(
112
+ checkpoint_path_or_dict=get_file_name("Model"),
113
+ from_safetensors=True,
114
+ pipeline_class=StableDiffusionImg2ImgPipeline,
115
+ load_safety_checker=False,
116
+ **pipe_args,
117
+ )
118
+ elif model_type == "LORA":
119
+ logger.debug(f"Loading pipeline for LORA")
120
+
121
+ base_model = model_version["baseModel"]
122
+
123
+ if base_model == "SD 1.5":
124
+ pipe = AutoPipelineForImage2Image.from_pretrained(
125
+ "stable-diffusion-v1-5/stable-diffusion-v1-5",
126
+ safety_checker=None,
127
+ requires_safety_checker=False,
128
+ torch_dtype=torch.float16,
129
+ use_safetensors=True,
130
+ variant="fp16",
131
+ )
132
+ else:
133
+ raise ValueError(f"Unsupported base model: {base_model}")
134
+
135
+ pipe.load_lora_weights(get_file_name("Model"), adapter_name=slugify(model_name))
136
+ else:
137
+ raise ValueError(f"Unsupported model type: {model_type}")
138
 
139
  pipe = pipe.to("cuda")
140