Update app.py
app.py CHANGED
```diff
@@ -18,59 +18,6 @@ MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1024
 
 
-def get_lora_sd_pipeline(
-    ckpt_dir='./output',
-    base_model_name_or_path=model_id_default,
-    dtype=torch_dtype,
-    device=device
-):
-    unet_sub_dir = os.path.join(ckpt_dir, "unet")
-    text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")
-
-    if base_model_name_or_path is None:
-        raise ValueError("Please specify the base model name or path")
-
-    pipe = StableDiffusionPipeline.from_pretrained(base_model_name_or_path,
-                                                   torch_dtype=dtype,
-                                                   safety_checker=None).to(device)
-    pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir)
-    pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, text_encoder_sub_dir)
-
-    if dtype in (torch.float16, torch.bfloat16):
-        pipe.unet.half()
-        pipe.text_encoder.half()
-
-    pipe.to(device)
-
-    return pipe
-
-
-# def encode_prompt(prompt, tokenizer, text_encoder):
-#     text_inputs = tokenizer(
-#         prompt,
-#         padding="max_length",
-#         max_length=tokenizer.model_max_length,
-#         return_tensors="pt",
-#     )
-#     with torch.no_grad():
-#         if len(text_inputs.input_ids[0]) < tokenizer.model_max_length:
-#             prompt_embeds = text_encoder(text_inputs.input_ids.to(text_encoder.device))[0]
-#         else:
-#             embeds = []
-#             start = 0
-#             while start < tokenizer.model_max_length:
-#                 end = start + tokenizer.model_max_length
-#                 part_of_text_inputs = text_inputs.input_ids[0][start:end]
-#                 if len(part_of_text_inputs) < tokenizer.model_max_length:
-#                     part_of_text_inputs = torch.cat([part_of_text_inputs, torch.tensor([tokenizer.pad_token_id] * (tokenizer.model_max_length - len(part_of_text_inputs)))])
-#                 embeds.append(text_encoder(part_of_text_inputs.to(text_encoder.device).unsqueeze(0))[0])
-#                 start += int((8/
-
-# 11)*tokenizer.model_max_length)
-#             prompt_embeds = torch.mean(torch.stack(embeds, dim=0), dim=0)
-#             return prompt_embeds
-
-
 # @spaces.GPU #[uncomment to use ZeroGPU]
 def infer(
     prompt,
```
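This hunk deletes two things at module scope: the `get_lora_sd_pipeline` helper, whose body reappears inside `infer` in the next hunk, and a commented-out `encode_prompt`. The latter was a sliding-window workaround for CLIP's 77-token limit: split a long prompt into overlapping windows of `tokenizer.model_max_length` tokens (advancing by roughly 8/11 of a window), encode each window, and mean-pool the results. Its loop condition (`while start < tokenizer.model_max_length`) would stop after the first window, and its stride expression was split across lines, so it likely never ran as intended. A cleaned-up sketch of the same idea, for illustration only (not the app's code):

```python
import torch

def encode_prompt_chunked(prompt, tokenizer, text_encoder):
    """Mean-pool overlapping windows so prompts can exceed CLIP's token limit.

    Illustrative reconstruction of the deleted helper; assumes a CLIP-style
    tokenizer/text_encoder pair as used by StableDiffusionPipeline.
    """
    max_len = tokenizer.model_max_length  # 77 for CLIP
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids[0]

    with torch.no_grad():
        if len(input_ids) <= max_len:
            # Short prompt: pad to max_len and encode in one pass.
            padded = tokenizer(prompt, padding="max_length", max_length=max_len,
                               truncation=True, return_tensors="pt")
            return text_encoder(padded.input_ids.to(text_encoder.device))[0]

        embeds = []
        stride = int((8 / 11) * max_len)  # overlap between consecutive windows
        for start in range(0, len(input_ids), stride):
            chunk = input_ids[start:start + max_len]
            if len(chunk) < max_len:
                # Right-pad the final window to the encoder's fixed length.
                pad = torch.full((max_len - len(chunk),), tokenizer.pad_token_id,
                                 dtype=chunk.dtype)
                chunk = torch.cat([chunk, pad])
            embeds.append(text_encoder(chunk.unsqueeze(0).to(text_encoder.device))[0])
        # Average the per-window embeddings into one prompt embedding.
        return torch.mean(torch.stack(embeds, dim=0), dim=0)
```

In this Space the function was never active: its only call site was also commented out and is removed in the next hunk, so over-length prompts are simply truncated by the pipeline's own tokenizer.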
```diff
@@ -85,10 +32,26 @@ def infer(
     progress=gr.Progress(track_tqdm=True),
 ):
     generator = torch.Generator(device).manual_seed(seed)
-
+
+    ckpt_dir='./output'
+    unet_sub_dir = os.path.join(ckpt_dir, "unet")
+    text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")
+
+    if model_id is None:
+        raise ValueError("Please specify the base model name or path")
+
+    pipe = StableDiffusionPipeline.from_pretrained(model_id,
+                                                   torch_dtype=dtype,
+                                                   safety_checker=None).to(device)
+    pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir)
+    pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, text_encoder_sub_dir)
+
+    if dtype in (torch.float16, torch.bfloat16):
+        pipe.unet.half()
+        pipe.text_encoder.half()
+
     pipe.fuse_lora(lora_scale=lora_scale)
-
-    # negative_prompt_embeds = encode_prompt(negative_prompt, pipe.tokenizer, pipe.text_encoder)
+    pipe.to(device)
 
     image = pipe(
         prompt=prompt,
```
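Net effect of this hunk: the pipeline is no longer built once by a helper; instead the base model is loaded and both LoRA adapters are attached on every call to `infer`, with `fuse_lora` applied each time. Note that the inlined code references `model_id` and `dtype`, which must exist at module scope in app.py. If per-call loading becomes a bottleneck, one option is to memoize the construction so repeated generations reuse the same pipeline. A minimal sketch, assuming module-level `model_id_default`, `torch_dtype`, and `device` as in the removed helper (the checkpoint id below is a placeholder):

```python
import os
from functools import lru_cache

import torch
from diffusers import StableDiffusionPipeline
from peft import PeftModel

# Placeholder defaults; app.py defines its own model_id, dtype and device.
model_id_default = "runwayml/stable-diffusion-v1-5"
torch_dtype = torch.float16
device = "cuda" if torch.cuda.is_available() else "cpu"

@lru_cache(maxsize=1)
def get_cached_pipeline(ckpt_dir="./output"):
    """Build the LoRA-patched pipeline once; later calls return the cached object."""
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id_default, torch_dtype=torch_dtype, safety_checker=None
    )
    # Attach the trained LoRA adapters to the UNet and text encoder.
    pipe.unet = PeftModel.from_pretrained(pipe.unet, os.path.join(ckpt_dir, "unet"))
    pipe.text_encoder = PeftModel.from_pretrained(
        pipe.text_encoder, os.path.join(ckpt_dir, "text_encoder")
    )
    return pipe.to(device)
```

With a cached pipeline, a per-request `fuse_lora(lora_scale=...)` should be paired with `unfuse_lora()` after generation; fusing folds the LoRA weights into the base weights, and doing it repeatedly on the same object compounds the update.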