ktrndy committed on
Commit 17091ee · verified · 1 Parent(s): 99ba384

Update app.py

Files changed (1)
  1. app.py +19 -56
app.py CHANGED
@@ -18,59 +18,6 @@ MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1024
 
 
-def get_lora_sd_pipeline(
-    ckpt_dir='./output',
-    base_model_name_or_path=model_id_default,
-    dtype=torch_dtype,
-    device=device
-):
-    unet_sub_dir = os.path.join(ckpt_dir, "unet")
-    text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")
-
-    if base_model_name_or_path is None:
-        raise ValueError("Please specify the base model name or path")
-
-    pipe = StableDiffusionPipeline.from_pretrained(base_model_name_or_path,
-                                                   torch_dtype=dtype,
-                                                   safety_checker=None).to(device)
-    pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir)
-    pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, text_encoder_sub_dir)
-
-    if dtype in (torch.float16, torch.bfloat16):
-        pipe.unet.half()
-        pipe.text_encoder.half()
-
-    pipe.to(device)
-
-    return pipe
-
-
-# def encode_prompt(prompt, tokenizer, text_encoder):
-#     text_inputs = tokenizer(
-#         prompt,
-#         padding="max_length",
-#         max_length=tokenizer.model_max_length,
-#         return_tensors="pt",
-#     )
-#     with torch.no_grad():
-#         if len(text_inputs.input_ids[0]) < tokenizer.model_max_length:
-#             prompt_embeds = text_encoder(text_inputs.input_ids.to(text_encoder.device))[0]
-#         else:
-#             embeds = []
-#             start = 0
-#             while start < tokenizer.model_max_length:
-#                 end = start + tokenizer.model_max_length
-#                 part_of_text_inputs = text_inputs.input_ids[0][start:end]
-#                 if len(part_of_text_inputs) < tokenizer.model_max_length:
-#                     part_of_text_inputs = torch.cat([part_of_text_inputs, torch.tensor([tokenizer.pad_token_id] * (tokenizer.model_max_length - len(part_of_text_inputs)))])
-#                 embeds.append(text_encoder(part_of_text_inputs.to(text_encoder.device).unsqueeze(0))[0])
-#                 start += int((8/11)*tokenizer.model_max_length)
-#             prompt_embeds = torch.mean(torch.stack(embeds, dim=0), dim=0)
-#     return prompt_embeds
-
-
 # @spaces.GPU #[uncomment to use ZeroGPU]
 def infer(
     prompt,
@@ -85,10 +32,26 @@ def infer(
     progress=gr.Progress(track_tqdm=True),
 ):
     generator = torch.Generator(device).manual_seed(seed)
-    pipe = get_lora_sd_pipeline(base_model_name_or_path=model_id)
+
+    ckpt_dir = './output'
+    unet_sub_dir = os.path.join(ckpt_dir, "unet")
+    text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")
+
+    if model_id is None:
+        raise ValueError("Please specify the base model name or path")
+
+    pipe = StableDiffusionPipeline.from_pretrained(model_id,
+                                                   torch_dtype=dtype,
+                                                   safety_checker=None).to(device)
+    pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir)
+    pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, text_encoder_sub_dir)
+
+    if dtype in (torch.float16, torch.bfloat16):
+        pipe.unet.half()
+        pipe.text_encoder.half()
+
     pipe.fuse_lora(lora_scale=lora_scale)
-    # prompt_embeds = encode_prompt(prompt, pipe.tokenizer, pipe.text_encoder)
-    # negative_prompt_embeds = encode_prompt(negative_prompt, pipe.tokenizer, pipe.text_encoder)
+    pipe.to(device)
 
     image = pipe(
         prompt=prompt,
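
For context: the commit removes the standalone get_lora_sd_pipeline helper and the commented-out encode_prompt experiment, and instead assembles the LoRA pipeline inside infer() from the user-selected model_id. One consequence is that StableDiffusionPipeline.from_pretrained and both PeftModel.from_pretrained adapter loads now run on every generation. Below is a minimal sketch of a cached variant, assuming the same module-level device and dtype globals the diff relies on; the helper name load_lora_pipeline and the lru_cache wrapper are illustrative, not part of this commit.

import os
from functools import lru_cache

import torch
from diffusers import StableDiffusionPipeline
from peft import PeftModel

device = "cuda" if torch.cuda.is_available() else "cpu"       # assumed app.py global
dtype = torch.float16 if device == "cuda" else torch.float32  # assumed app.py global

@lru_cache(maxsize=1)
def load_lora_pipeline(model_id, ckpt_dir="./output"):
    # Hypothetical helper: build the base pipeline once per model_id and attach
    # the trained LoRA adapters; lru_cache returns the same object on repeat calls.
    if model_id is None:
        raise ValueError("Please specify the base model name or path")
    pipe = StableDiffusionPipeline.from_pretrained(model_id,
                                                   torch_dtype=dtype,
                                                   safety_checker=None)
    # Wrap the UNet and text encoder with the PEFT adapters saved during training.
    pipe.unet = PeftModel.from_pretrained(pipe.unet, os.path.join(ckpt_dir, "unet"))
    pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder,
                                                  os.path.join(ckpt_dir, "text_encoder"))
    if dtype in (torch.float16, torch.bfloat16):
        pipe.unet.half()
        pipe.text_encoder.half()
    return pipe.to(device)

With this variant, infer() would start with pipe = load_lora_pipeline(model_id). Note that pipe.fuse_lora(lora_scale=lora_scale) mutates the cached weights in place, so a cached pipeline should be unfused after generation (pipe.unfuse_lora()); otherwise repeated calls would compound the scale.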