diff --git a/pipeline.py b/pipeline.py
index 7c41e04..842c5b4 100644
--- a/pipeline.py
+++ b/pipeline.py
@@ -27,7 +27,7 @@ def load_learned_concepts(pipe, root_folder="selected_outputs/", num_scales=10):
     for exp_name in os.listdir(root_folder):
         # get everything up to the first numeric
         pure_names.append(exp_name)
-        encoder = torch.load(os.path.join(root_folder, exp_name, "text_encoder/pytorch_model.bin"))
+        encoder = torch.load(os.path.join(root_folder, exp_name, "text_encoder/pytorch_model.bin"), map_location=pipe.device)
         embeddings = encoder["text_model.embeddings.token_embedding.weight"]
         param_value = embeddings[-10:]
 
@@ -36,23 +36,23 @@ def load_learned_concepts(pipe, root_folder="selected_outputs/", num_scales=10):
             string_name = f"<{exp_name}|{t}|>"
             tokens_to_add.append(string_name)
             string_to_param_dict[string_name] = torch.nn.Parameter(param_value[t].unsqueeze(0).repeat([num_scales, 1]))
-
+
             # Full Resolution: use the appropriate time embedding for the whole generation time.
             string_name = f"<{exp_name}[{t}]>"
             tokens_to_add.append(string_name)
             repeats = t + 1
             rep_param = param_value[t].unsqueeze(0).repeat([repeats, 1])
             left = param_value[rep_param.shape[0]:]
-            new_param = torch.cat([rep_param, left])
+            new_param = torch.cat([rep_param, left])
             string_to_param_dict[string_name] = torch.nn.Parameter(new_param)
             # Semi Resolution: use the appropriate time embedding up to a certain time, then no conditioning.
             string_name = f"<{exp_name}({t})>"
             tokens_to_add.append(string_name)
-            null_embedding = torch.zeros((param_value.shape[1],), device=param_value.device, dtype=param_value.dtype)
+            null_embedding = torch.zeros((param_value.shape[1],), device=pipe.device, dtype=param_value.dtype)
             rep_param = null_embedding.unsqueeze(0).repeat([t + 1, 1])
             left = param_value[rep_param.shape[0]:]
-            new_param = torch.cat([rep_param, left])
+            new_param = torch.cat([rep_param, left])
             string_to_param_dict[string_name] = torch.nn.Parameter(new_param)
     pipe.tokenizer.add_tokens(tokens_to_add)
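
Note on the patch: both map_location changes address the same failure mode. By default, torch.load restores each tensor onto the device recorded at save time (typically cuda:0), which errors on CPU-only hosts and can otherwise leave the new embeddings on a different device than the rest of the pipeline; pinning everything to pipe.device avoids both problems. The sketch below mirrors the "Full Resolution" embedding construction from the second hunk; the checkpoint path, exp_name folder, and timestep t are hypothetical placeholders, and num_scales=10 matches the embeddings[-10:] slice in the patched function.

    import torch

    num_scales = 10
    t = 3  # hypothetical timestep index, 0 <= t < num_scales

    # Load the fine-tuned text-encoder weights onto an explicit device;
    # without map_location, torch.load restores each tensor to the device
    # it was saved from, which fails on machines lacking that device.
    encoder = torch.load(
        "selected_outputs/my_concept/text_encoder/pytorch_model.bin",  # hypothetical path
        map_location="cpu",  # the patch uses pipe.device instead
    )
    embeddings = encoder["text_model.embeddings.token_embedding.weight"]
    param_value = embeddings[-num_scales:]  # the learned per-scale embeddings

    # "Full Resolution" token <exp[t]>: repeat the t-th embedding for steps
    # 0..t, then keep the remaining learned embeddings unchanged.
    rep_param = param_value[t].unsqueeze(0).repeat([t + 1, 1])
    new_param = torch.cat([rep_param, param_value[rep_param.shape[0]:]])
    assert new_param.shape == param_value.shape  # still (num_scales, dim)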