Spaces:
Running
on
Zero
Running
on
Zero
diff --git a/pipeline.py b/pipeline.py | |
index 7c41e04..842c5b4 100644 | |
--- a/pipeline.py | |
+++ b/pipeline.py | |
@@ -27,7 +27,7 @@ def load_learned_concepts(pipe, root_folder="selected_outputs/", num_scales=10): | |
for exp_name in os.listdir(root_folder): | |
# get everything up to the first numeric | |
pure_names.append(exp_name) | |
- encoder = torch.load(os.path.join(root_folder, exp_name, "text_encoder/pytorch_model.bin")) | |
+ encoder = torch.load(os.path.join(root_folder, exp_name, "text_encoder/pytorch_model.bin"), map_location=pipe.device) | |
embeddings = encoder["text_model.embeddings.token_embedding.weight"] | |
param_value = embeddings[-10:] | |
@@ -36,23 +36,23 @@ def load_learned_concepts(pipe, root_folder="selected_outputs/", num_scales=10): | |
string_name = f"<{exp_name}|{t}|>" | |
tokens_to_add.append(string_name) | |
string_to_param_dict[string_name] = torch.nn.Parameter(param_value[t].unsqueeze(0).repeat([num_scales, 1])) | |
- | |
+ | |
# Fully Resolution: use appropriate time embedding for the whole generation time. | |
string_name = f"<{exp_name}[{t}]>" | |
tokens_to_add.append(string_name) | |
repeats = t + 1 | |
rep_param = param_value[t].unsqueeze(0).repeat([repeats, 1]) | |
left = param_value[rep_param.shape[0]:] | |
- new_param = torch.cat([rep_param, left]) | |
+ new_param = torch.cat([rep_param, left]) | |
string_to_param_dict[string_name] = torch.nn.Parameter(new_param) | |
# Semi Resolution: use appropriate time embedding up to a certain time and then no conditioning. | |
string_name = f"<{exp_name}({t})>" | |
tokens_to_add.append(string_name) | |
- null_embedding = torch.zeros((param_value.shape[1],), device=param_value.device, dtype=param_value.dtype) | |
+ null_embedding = torch.zeros((param_value.shape[1],), device=pipe.device, dtype=param_value.dtype) | |
rep_param = null_embedding.unsqueeze(0).repeat([t + 1, 1]) | |
left = param_value[rep_param.shape[0]:] | |
- new_param = torch.cat([rep_param, left]) | |
+ new_param = torch.cat([rep_param, left]) | |
string_to_param_dict[string_name] = torch.nn.Parameter(new_param) | |
pipe.tokenizer.add_tokens(tokens_to_add) | |