Update app.py
app.py CHANGED
@@ -34,7 +34,7 @@ def get_lora_sd_pipeline(
     if base_model_name_or_path is None:
         raise ValueError("Please specify the base model name or path")
 
-    pipe = StableDiffusionPipeline.from_pretrained(base_model_name_or_path, torch_dtype=dtype).to(device)
+    pipe = StableDiffusionPipeline.from_pretrained(base_model_name_or_path, torch_dtype=dtype, safety_checker=None).to(device)
     pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir)
 
     if os.path.exists(text_encoder_sub_dir):
@@ -42,9 +42,9 @@ def get_lora_sd_pipeline(
             pipe.text_encoder, text_encoder_sub_dir
         )
 
-    if dtype in (torch.float16, torch.bfloat16):
-        pipe.unet.half()
-        pipe.text_encoder.half()
+    # if dtype in (torch.float16, torch.bfloat16):
+    # pipe.unet.half()
+    # pipe.text_encoder.half()
 
     pipe.to(device)
     return pipe
@@ -90,10 +90,8 @@ def infer(
     progress=gr.Progress(track_tqdm=True),
 ):
     generator = torch.Generator(device).manual_seed(seed)
-    pipe = get_lora_sd_pipeline(base_model_name_or_path=model_id)
-    pipe = pipe.to(device)
+    pipe = get_lora_sd_pipeline(base_model_name_or_path=model_id)
     # pipe.fuse_lora(lora_scale=lora_scale)
-    pipe.safety_checker = None
     # prompt_embeds = encode_prompt(prompt, pipe.tokenizer, pipe.text_encoder)
     # negative_prompt_embeds = encode_prompt(negative_prompt, pipe.tokenizer, pipe.text_encoder)
 
@@ -157,12 +155,12 @@ with gr.Blocks(css=css, fill_height=True) as demo:
                     value=7.0, # Replace with defaults that work for your model
                 )
             with gr.Row():
-                lora_scale = gr.Slider(
-                    label="LoRA scale",
-                    minimum=0.0,
-                    maximum=1.0,
-                    step=0.1,
-                    value=1.0,
+                # lora_scale = gr.Slider(
+                # label="LoRA scale",
+                # minimum=0.0,
+                # maximum=1.0,
+                # step=0.1,
+                # value=1.0,
                 )
 
                 num_inference_steps = gr.Slider(
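Net effect, for readers who do not want to splice the hunks by hand: the safety checker is now disabled when the pipeline is constructed (safety_checker=None passed to from_pretrained) instead of being nulled out afterwards in infer(), the redundant pipe = pipe.to(device) call is dropped (the loader already moves the pipeline to the device), and both the explicit half() casts and the unused LoRA-scale slider are commented out. The sketch below shows get_lora_sd_pipeline roughly as it stands after this commit; the function signature and the ckpt_dir sub-directory handling are assumptions carried over from the PEFT LoRA DreamBooth example this helper appears to be based on, and are not confirmed by the diff.

import os

import torch
from diffusers import StableDiffusionPipeline
from peft import PeftModel

def get_lora_sd_pipeline(
    ckpt_dir="lora_weights",      # hypothetical default; the real signature is not in the diff
    base_model_name_or_path=None,
    dtype=torch.float16,          # assumed, as in the PEFT example
    device="cuda",                # assumed
):
    # Assumed checkpoint layout: one sub-folder per LoRA-wrapped component.
    unet_sub_dir = os.path.join(ckpt_dir, "unet")
    text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")

    if base_model_name_or_path is None:
        raise ValueError("Please specify the base model name or path")

    # Changed by this commit: safety_checker=None at load time, so the checker
    # weights are never instantiated rather than loaded and then discarded.
    pipe = StableDiffusionPipeline.from_pretrained(
        base_model_name_or_path, torch_dtype=dtype, safety_checker=None
    ).to(device)
    pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir)

    if os.path.exists(text_encoder_sub_dir):
        pipe.text_encoder = PeftModel.from_pretrained(
            pipe.text_encoder, text_encoder_sub_dir
        )

    # Commented out by this commit; torch_dtype above already loads the base
    # weights in reduced precision, so the extra casts were redundant.
    # if dtype in (torch.float16, torch.bfloat16):
    #     pipe.unet.half()
    #     pipe.text_encoder.half()

    pipe.to(device)
    return pipe

With the loader shaped this way, the matching call in infer() reduces to pipe = get_lora_sd_pipeline(base_model_name_or_path=model_id), exactly as the third hunk shows.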