ktrndy commited on
Commit
7de50e2
·
verified ·
1 Parent(s): 55f37d0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -10
app.py CHANGED
@@ -10,10 +10,11 @@ from diffusers import DiffusionPipeline
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
  model_id_default = "stable-diffusion-v1-5/stable-diffusion-v1-5" # Replace to the model you would like to use
12
 
13
- if torch.cuda.is_available():
14
- torch_dtype = torch.float16
15
- else:
16
- torch_dtype = torch.float32
 
17
 
18
  MAX_SEED = np.iinfo(np.int32).max
19
  MAX_IMAGE_SIZE = 1024
@@ -23,8 +24,7 @@ def get_lora_sd_pipeline(
23
  ckpt_dir='./output',
24
  base_model_name_or_path=model_id_default,
25
  dtype=torch_dtype,
26
- device=device,
27
- adapter_name="default"
28
  ):
29
  unet_sub_dir = os.path.join(ckpt_dir, "unet")
30
  text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")
@@ -36,8 +36,7 @@ def get_lora_sd_pipeline(
36
  raise ValueError("Please specify the base model name or path")
37
 
38
  pipe = StableDiffusionPipeline.from_pretrained(base_model_name_or_path, torch_dtype=dtype).to(device)
39
- pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name)
40
- pipe.unet.set_adapter(adapter_name)
41
 
42
  if os.path.exists(text_encoder_sub_dir):
43
  pipe.text_encoder = PeftModel.from_pretrained(
@@ -92,8 +91,7 @@ def infer(
92
  progress=gr.Progress(track_tqdm=True),
93
  ):
94
  generator = torch.Generator(device).manual_seed(seed)
95
- pipe = get_lora_sd_pipeline(base_model_name_or_path=model_id,
96
- adapter_name="sticker_of_funny_cat_Pusheen")
97
  pipe = pipe.to(device)
98
  # pipe.fuse_lora(lora_scale=lora_scale)
99
  # prompt_embeds = encode_prompt(prompt, pipe.tokenizer, pipe.text_encoder)
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
  model_id_default = "stable-diffusion-v1-5/stable-diffusion-v1-5" # Replace to the model you would like to use
12
 
13
+ # if torch.cuda.is_available():
14
+ # torch_dtype = torch.float16
15
+ # else:
16
+ # torch_dtype = torch.float32
17
+ torch_dtype = torch.float32
18
 
19
  MAX_SEED = np.iinfo(np.int32).max
20
  MAX_IMAGE_SIZE = 1024
 
24
  ckpt_dir='./output',
25
  base_model_name_or_path=model_id_default,
26
  dtype=torch_dtype,
27
+ device=device
 
28
  ):
29
  unet_sub_dir = os.path.join(ckpt_dir, "unet")
30
  text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")
 
36
  raise ValueError("Please specify the base model name or path")
37
 
38
  pipe = StableDiffusionPipeline.from_pretrained(base_model_name_or_path, torch_dtype=dtype).to(device)
39
+ pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir)
 
40
 
41
  if os.path.exists(text_encoder_sub_dir):
42
  pipe.text_encoder = PeftModel.from_pretrained(
 
91
  progress=gr.Progress(track_tqdm=True),
92
  ):
93
  generator = torch.Generator(device).manual_seed(seed)
94
+ pipe = get_lora_sd_pipeline(base_model_name_or_path=model_id)
 
95
  pipe = pipe.to(device)
96
  # pipe.fuse_lora(lora_scale=lora_scale)
97
  # prompt_embeds = encode_prompt(prompt, pipe.tokenizer, pipe.text_encoder)