ktrndy committed on
Commit
ba92903
·
verified ·
1 Parent(s): 39e252c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -14
app.py CHANGED
@@ -5,7 +5,6 @@ import os
5
  import torch
6
  from diffusers import StableDiffusionPipeline
7
  from peft import PeftModel, LoraConfig
8
- from diffusers import DiffusionPipeline
9
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
  model_id_default = "stable-diffusion-v1-5/stable-diffusion-v1-5"
@@ -27,26 +26,22 @@ def get_lora_sd_pipeline(
27
  ):
28
  unet_sub_dir = os.path.join(ckpt_dir, "unet")
29
  text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")
30
- if os.path.exists(text_encoder_sub_dir) and base_model_name_or_path is None:
31
- config = LoraConfig.from_pretrained(text_encoder_sub_dir)
32
- base_model_name_or_path = config.base_model_name_or_path
33
 
34
  if base_model_name_or_path is None:
35
  raise ValueError("Please specify the base model name or path")
36
 
37
- pipe = StableDiffusionPipeline.from_pretrained(base_model_name_or_path, torch_dtype=dtype, safety_checker=None).to(device)
 
 
38
  pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir)
 
39
 
40
- if os.path.exists(text_encoder_sub_dir):
41
- pipe.text_encoder = PeftModel.from_pretrained(
42
- pipe.text_encoder, text_encoder_sub_dir
43
- )
44
-
45
- # if dtype in (torch.float16, torch.bfloat16):
46
- # pipe.unet.half()
47
- # pipe.text_encoder.half()
48
 
49
  pipe.to(device)
 
50
  return pipe
51
 
52
 
@@ -91,7 +86,7 @@ def infer(
91
  ):
92
  generator = torch.Generator(device).manual_seed(seed)
93
  pipe = get_lora_sd_pipeline(base_model_name_or_path=model_id)
94
- # pipe.fuse_lora(lora_scale=lora_scale)
95
  # prompt_embeds = encode_prompt(prompt, pipe.tokenizer, pipe.text_encoder)
96
  # negative_prompt_embeds = encode_prompt(negative_prompt, pipe.tokenizer, pipe.text_encoder)
97
 
@@ -104,6 +99,7 @@ def infer(
104
  height=height,
105
  generator=generator,
106
  ).images[0]
 
107
 
108
  return image
109
 
 
5
  import torch
6
  from diffusers import StableDiffusionPipeline
7
  from peft import PeftModel, LoraConfig
 
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
  model_id_default = "stable-diffusion-v1-5/stable-diffusion-v1-5"
 
26
  ):
27
  unet_sub_dir = os.path.join(ckpt_dir, "unet")
28
  text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")
 
 
 
29
 
30
  if base_model_name_or_path is None:
31
  raise ValueError("Please specify the base model name or path")
32
 
33
+ pipe = StableDiffusionPipeline.from_pretrained(base_model_name_or_path,
34
+ torch_dtype=dtype,
35
+ safety_checker=None).to(device)
36
  pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir)
37
+ pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, text_encoder_sub_dir)
38
 
39
+ if dtype in (torch.float16, torch.bfloat16):
40
+ pipe.unet.half()
41
+ pipe.text_encoder.half()
 
 
 
 
 
42
 
43
  pipe.to(device)
44
+
45
  return pipe
46
 
47
 
 
86
  ):
87
  generator = torch.Generator(device).manual_seed(seed)
88
  pipe = get_lora_sd_pipeline(base_model_name_or_path=model_id)
89
+ pipe.fuse_lora(lora_scale=lora_scale)
90
  # prompt_embeds = encode_prompt(prompt, pipe.tokenizer, pipe.text_encoder)
91
  # negative_prompt_embeds = encode_prompt(negative_prompt, pipe.tokenizer, pipe.text_encoder)
92
 
 
99
  height=height,
100
  generator=generator,
101
  ).images[0]
102
+ print(device, torch_dtype)
103
 
104
  return image
105