Spaces:
Runtime error
Runtime error
frankaging
commited on
Commit
·
bb5c56b
1
Parent(s):
7f3db95
o1 impl
Browse files
app.py
CHANGED
|
@@ -34,8 +34,11 @@ class Steer(pv.SourcelessIntervention):
|
|
| 34 |
def forward(self, base, source=None, subspaces=None):
|
| 35 |
if subspaces is None:
|
| 36 |
return base
|
| 37 |
-
steering_vec =
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
| 39 |
return base + steering_vec
|
| 40 |
|
| 41 |
# Check GPU
|
|
@@ -91,9 +94,9 @@ def generate(
|
|
| 91 |
|
| 92 |
# build list of messages
|
| 93 |
messages = []
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
messages.append({"role": "user", "content": message})
|
| 98 |
|
| 99 |
input_ids = torch.tensor([tokenizer.apply_chat_template(
|
|
@@ -113,8 +116,8 @@ def generate(
|
|
| 113 |
"intervene_on_prompt": True,
|
| 114 |
"subspaces": [
|
| 115 |
{
|
| 116 |
-
"idx": int(
|
| 117 |
-
"mag": int(
|
| 118 |
}
|
| 119 |
] if subspaces_list else None,
|
| 120 |
"streamer": streamer,
|
|
|
|
| 34 |
def forward(self, base, source=None, subspaces=None):
|
| 35 |
if subspaces is None:
|
| 36 |
return base
|
| 37 |
+
steering_vec = []
|
| 38 |
+
for idx, mag in zip(subspaces["idx"], subspaces["mag"]):
|
| 39 |
+
steering_vec.append(self.proj.weight[idx].unsqueeze(dim=0))
|
| 40 |
+
steering_vec = torch.cat(steering_vec, dim=0).mean(dim=0)
|
| 41 |
+
steering_vec = mag * steering_vec
|
| 42 |
return base + steering_vec
|
| 43 |
|
| 44 |
# Check GPU
|
|
|
|
| 94 |
|
| 95 |
# build list of messages
|
| 96 |
messages = []
|
| 97 |
+
for user_msg, model_msg in recent_history:
|
| 98 |
+
messages.append({"role": "user", "content": user_msg})
|
| 99 |
+
messages.append({"role": "model", "content": model_msg})
|
| 100 |
messages.append({"role": "user", "content": message})
|
| 101 |
|
| 102 |
input_ids = torch.tensor([tokenizer.apply_chat_template(
|
|
|
|
| 116 |
"intervene_on_prompt": True,
|
| 117 |
"subspaces": [
|
| 118 |
{
|
| 119 |
+
"idx": [int(sl["idx"]) for sl in subspaces_list],
|
| 120 |
+
"mag": [int(sl["internal_mag"]) for sl in subspaces_list]
|
| 121 |
}
|
| 122 |
] if subspaces_list else None,
|
| 123 |
"streamer": streamer,
|