ktrndy commited on
Commit
afb02b7
·
verified ·
1 Parent(s): af954e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -13
app.py CHANGED
@@ -34,7 +34,7 @@ def get_lora_sd_pipeline(
34
  if base_model_name_or_path is None:
35
  raise ValueError("Please specify the base model name or path")
36
 
37
- pipe = StableDiffusionPipeline.from_pretrained(base_model_name_or_path, torch_dtype=dtype).to(device)
38
  pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir)
39
 
40
  if os.path.exists(text_encoder_sub_dir):
@@ -42,9 +42,9 @@ def get_lora_sd_pipeline(
42
  pipe.text_encoder, text_encoder_sub_dir
43
  )
44
 
45
- if dtype in (torch.float16, torch.bfloat16):
46
- pipe.unet.half()
47
- pipe.text_encoder.half()
48
 
49
  pipe.to(device)
50
  return pipe
@@ -90,10 +90,8 @@ def infer(
90
  progress=gr.Progress(track_tqdm=True),
91
  ):
92
  generator = torch.Generator(device).manual_seed(seed)
93
- pipe = get_lora_sd_pipeline(base_model_name_or_path=model_id)
94
- pipe = pipe.to(device)
95
  # pipe.fuse_lora(lora_scale=lora_scale)
96
- pipe.safety_checker = None
97
  # prompt_embeds = encode_prompt(prompt, pipe.tokenizer, pipe.text_encoder)
98
  # negative_prompt_embeds = encode_prompt(negative_prompt, pipe.tokenizer, pipe.text_encoder)
99
 
@@ -157,12 +155,12 @@ with gr.Blocks(css=css, fill_height=True) as demo:
157
  value=7.0, # Replace with defaults that work for your model
158
  )
159
  with gr.Row():
160
- lora_scale = gr.Slider(
161
- label="LoRA scale",
162
- minimum=0.0,
163
- maximum=1.0,
164
- step=0.1,
165
- value=1.0,
166
  )
167
 
168
  num_inference_steps = gr.Slider(
 
34
  if base_model_name_or_path is None:
35
  raise ValueError("Please specify the base model name or path")
36
 
37
+ pipe = StableDiffusionPipeline.from_pretrained(base_model_name_or_path, torch_dtype=dtype, safety_checker=None).to(device)
38
  pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir)
39
 
40
  if os.path.exists(text_encoder_sub_dir):
 
42
  pipe.text_encoder, text_encoder_sub_dir
43
  )
44
 
45
+ # if dtype in (torch.float16, torch.bfloat16):
46
+ # pipe.unet.half()
47
+ # pipe.text_encoder.half()
48
 
49
  pipe.to(device)
50
  return pipe
 
90
  progress=gr.Progress(track_tqdm=True),
91
  ):
92
  generator = torch.Generator(device).manual_seed(seed)
93
+ pipe = get_lora_sd_pipeline(base_model_name_or_path=model_id)
 
94
  # pipe.fuse_lora(lora_scale=lora_scale)
 
95
  # prompt_embeds = encode_prompt(prompt, pipe.tokenizer, pipe.text_encoder)
96
  # negative_prompt_embeds = encode_prompt(negative_prompt, pipe.tokenizer, pipe.text_encoder)
97
 
 
155
  value=7.0, # Replace with defaults that work for your model
156
  )
157
  with gr.Row():
158
+ # lora_scale = gr.Slider(
159
+ # label="LoRA scale",
160
+ # minimum=0.0,
161
+ # maximum=1.0,
162
+ # step=0.1,
163
+ # value=1.0,
164
  )
165
 
166
  num_inference_steps = gr.Slider(