frankaging commited on
Commit
bb5c56b
·
1 Parent(s): 7f3db95
Files changed (1) hide show
  1. app.py +10 -7
app.py CHANGED
@@ -34,8 +34,11 @@ class Steer(pv.SourcelessIntervention):
34
  def forward(self, base, source=None, subspaces=None):
35
  if subspaces is None:
36
  return base
37
- steering_vec = torch.tensor(subspaces["mag"]) * \
38
- self.proj.weight[subspaces["idx"]].unsqueeze(dim=0)
 
 
 
39
  return base + steering_vec
40
 
41
  # Check GPU
@@ -91,9 +94,9 @@ def generate(
91
 
92
  # build list of messages
93
  messages = []
94
- # for user_msg, model_msg in recent_history:
95
- # messages.append({"role": "user", "content": user_msg})
96
- # messages.append({"role": "model", "content": model_msg})
97
  messages.append({"role": "user", "content": message})
98
 
99
  input_ids = torch.tensor([tokenizer.apply_chat_template(
@@ -113,8 +116,8 @@ def generate(
113
  "intervene_on_prompt": True,
114
  "subspaces": [
115
  {
116
- "idx": int(subspaces_list[0]["idx"]),
117
- "mag": int(subspaces_list[0]["internal_mag"])
118
  }
119
  ] if subspaces_list else None,
120
  "streamer": streamer,
 
34
  def forward(self, base, source=None, subspaces=None):
35
  if subspaces is None:
36
  return base
37
+ steering_vec = []
38
+ for idx, mag in zip(subspaces["idx"], subspaces["mag"]):
39
+ steering_vec.append(self.proj.weight[idx].unsqueeze(dim=0))
40
+ steering_vec = torch.cat(steering_vec, dim=0).mean(dim=0)
41
+ steering_vec = mag * steering_vec
42
  return base + steering_vec
43
 
44
  # Check GPU
 
94
 
95
  # build list of messages
96
  messages = []
97
+ for user_msg, model_msg in recent_history:
98
+ messages.append({"role": "user", "content": user_msg})
99
+ messages.append({"role": "model", "content": model_msg})
100
  messages.append({"role": "user", "content": message})
101
 
102
  input_ids = torch.tensor([tokenizer.apply_chat_template(
 
116
  "intervene_on_prompt": True,
117
  "subspaces": [
118
  {
119
+ "idx": [int(sl["idx"]) for sl in subspaces_list],
120
+ "mag": [int(sl["internal_mag"]) for sl in subspaces_list]
121
  }
122
  ] if subspaces_list else None,
123
  "streamer": streamer,