diff --git a/pipeline.py b/pipeline.py
index 7c41e04..842c5b4 100644
--- a/pipeline.py
+++ b/pipeline.py
@@ -27,7 +27,7 @@ def load_learned_concepts(pipe, root_folder="selected_outputs/", num_scales=10):
     for exp_name in os.listdir(root_folder):
         # get everything up to the first numeric
         pure_names.append(exp_name)
-        encoder = torch.load(os.path.join(root_folder, exp_name, "text_encoder/pytorch_model.bin"))
+        encoder = torch.load(os.path.join(root_folder, exp_name, "text_encoder/pytorch_model.bin"), map_location=pipe.device)
         embeddings = encoder["text_model.embeddings.token_embedding.weight"]
         param_value = embeddings[-10:]
 
@@ -36,23 +36,23 @@ def load_learned_concepts(pipe, root_folder="selected_outputs/", num_scales=10):
             string_name = f"<{exp_name}|{t}|>"
             tokens_to_add.append(string_name)
             string_to_param_dict[string_name] = torch.nn.Parameter(param_value[t].unsqueeze(0).repeat([num_scales, 1]))
-            
+
             # Full Resolution: use the appropriate time embedding for the whole generation time.
             string_name = f"<{exp_name}[{t}]>"
             tokens_to_add.append(string_name)
             repeats = t + 1
             rep_param = param_value[t].unsqueeze(0).repeat([repeats, 1])
             left = param_value[rep_param.shape[0]:]
-            new_param = torch.cat([rep_param, left]) 
+            new_param = torch.cat([rep_param, left])
             string_to_param_dict[string_name] = torch.nn.Parameter(new_param)
 
             # Semi Resolution: use the appropriate time embedding up to a certain time, then no conditioning.
             string_name = f"<{exp_name}({t})>"
             tokens_to_add.append(string_name)
-            null_embedding = torch.zeros((param_value.shape[1],), device=param_value.device, dtype=param_value.dtype)
+            null_embedding = torch.zeros((param_value.shape[1],), device=pipe.device, dtype=param_value.dtype)
             rep_param = null_embedding.unsqueeze(0).repeat([t + 1, 1])
             left = param_value[rep_param.shape[0]:]
-            new_param = torch.cat([rep_param, left]) 
+            new_param = torch.cat([rep_param, left])
             string_to_param_dict[string_name] = torch.nn.Parameter(new_param)
 
     pipe.tokenizer.add_tokens(tokens_to_add)
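
Why map_location matters here: torch.load restores each tensor to the device recorded in the checkpoint, so a text_encoder state dict saved from a CUDA process fails to load on a CPU-only machine and, even when it does load, leaves the embeddings on a different device than the pipeline. A minimal sketch of the fixed loading pattern, with a plain torch.device standing in for pipe.device (the path and the state-dict key are the ones from the diff; "exp_name" is a placeholder folder):

import torch

# Stand-in for pipe.device; could equally be torch.device("cuda:0").
device = torch.device("cpu")

# Without map_location, loading a GPU-saved checkpoint on a CPU-only host
# raises "RuntimeError: Attempting to deserialize object on a CUDA device".
state_dict = torch.load(
    "selected_outputs/exp_name/text_encoder/pytorch_model.bin",
    map_location=device,
)
embeddings = state_dict["text_model.embeddings.token_embedding.weight"]
param_value = embeddings[-10:]  # the ten learned per-scale embeddings

The same reasoning drives the second change: building null_embedding on pipe.device instead of param_value.device keeps the zero rows colocated with the pipeline even if the checkpoint tensors came from elsewhere.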
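
For reference, the three token variants built in the loop differ only in how the num_scales x dim parameter is assembled: the scaled form <exp|t|> broadcasts one time embedding across all scales, the full-resolution form <exp[t]> repeats embedding t for steps 0..t and keeps the rest, and the semi-resolution form <exp(t)> zeroes out steps 0..t. A standalone sketch with a dummy param_value (num_scales=10 as in the function signature; the 768 width is an assumption matching CLIP's text-encoder hidden size):

import torch

num_scales, dim = 10, 768                   # dim assumed (CLIP text width)
param_value = torch.randn(num_scales, dim)  # stand-in for embeddings[-10:]
t = 3

# <exp|t|>: one time embedding broadcast across all scales.
scaled = param_value[t].unsqueeze(0).repeat([num_scales, 1])

# <exp[t]>: embedding t for steps 0..t, then the remaining rows unchanged.
rep_param = param_value[t].unsqueeze(0).repeat([t + 1, 1])
full = torch.cat([rep_param, param_value[rep_param.shape[0]:]])

# <exp(t)>: zero rows (no conditioning) for steps 0..t, then the rest.
null_rows = torch.zeros(t + 1, dim, dtype=param_value.dtype)
semi = torch.cat([null_rows, param_value[t + 1:]])

assert scaled.shape == full.shape == semi.shape == (num_scales, dim)

In the diff each of these tensors is wrapped in torch.nn.Parameter before being stored in string_to_param_dict, so the variants remain trainable once attached to the text encoder.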