File size: 5,120 Bytes
0e90065
330e95b
0e90065
 
330e95b
0e90065
 
 
 
 
 
 
 
 
330e95b
0e90065
 
 
 
 
 
 
 
330e95b
0e90065
330e95b
 
0e90065
 
 
 
 
 
330e95b
0e90065
 
330e95b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e90065
 
 
 
 
330e95b
 
0e90065
 
 
 
330e95b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e90065
fcb8864
 
 
 
0e90065
 
 
 
 
 
 
 
 
330e95b
 
0e90065
 
 
 
 
 
 
 
 
 
 
 
330e95b
0e90065
 
330e95b
0e90065
 
 
 
 
 
330e95b
0e90065
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
330e95b
0e90065
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# Log in to the Hugging Face Hub with the token from the environment so the
# model and dictionary downloads below are authenticated.
import os, json
HF_TOKEN = os.environ.get("HF_TOKEN")

from huggingface_hub import login, hf_hub_download
if HF_TOKEN:
    login(token=HF_TOKEN)
else:
    # ``login()`` without a token falls back to an interactive prompt, which a
    # headless deployment cannot answer; skip it so downloads proceed with
    # whatever credentials (if any) are already cached.
    print("HF_TOKEN is not set; skipping Hugging Face Hub login.")

from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import pyvene as pv


# Upper bound and default for the "Max new tokens" slider in the UI.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
# Prompts longer than this many tokens are truncated before generation;
# can be overridden through the MAX_INPUT_TOKEN_LENGTH environment variable.
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "4096"))


# Markdown shown at the top of the page.
DESCRIPTION = (
    "# Model Steering with Supervised Dictionary Learning (SDL)\n"
    "\n"
    "### What's Model Steering with SDL?\n"
    "This is a demo of model steering with Supervised Dictionary Learning (SDL) using AxBench-ReFT-r1-16K, which hosts steering vectors for 16K concepts. We evaluate various steering methods, including ReFT-r1, a novel weakly-supervised dictionary learning method. ReFT-r1 demonstrates competitive steering capabilities compared to finetuning and prompting baselines.\n"
)

# Markdown footer with the licensing notice.
LICENSE = (
    "\n"
    "<p/>\n"
    "\n"
    "---\n"
    "This demo is governed by the original license and acceptable use policy of the model it is derived from. Please refer to the specific licensing and use policy of the underlying model.\n"
)

def load_jsonl(jsonl_path):
    """Read a JSON Lines file into a list of decoded records.

    Each non-blank line is parsed with ``json.loads``. Blank or
    whitespace-only lines (e.g. a trailing empty line) are skipped instead of
    raising a decode error.

    Args:
        jsonl_path: Path to the ``.jsonl`` file.

    Returns:
        A list containing one decoded object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Explicit UTF-8 so decoding does not depend on the host locale.
    with open(jsonl_path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


class Steer(pv.SourcelessIntervention):
    """Steer model via activation addition.

    ``proj.weight`` holds one steering direction per row. ``forward`` selects
    row ``subspaces["idx"]``, scales it by ``subspaces["mag"]`` and adds the
    result to ``base`` (broadcast over the leading dimensions).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs, keep_last_dim=True)
        self.proj = torch.nn.Linear(
                self.embed_dim, kwargs["latent_dim"], bias=False)

    def forward(self, base, source=None, subspaces=None):
        """Add the selected, scaled steering direction to ``base``.

        Args:
            base: Activations to steer (assumed ``(..., hidden)`` — TODO
                confirm against the hooked component).
            source: Unused; this intervention needs no source run.
            subspaces: Dict with ``"idx"`` (row of ``proj.weight`` to use) and
                ``"mag"`` (scalar magnitude, float or tensor).

        Returns:
            ``base`` plus the steering vector, in ``base``'s dtype and device.
        """
        direction = self.proj.weight[subspaces["idx"]].unsqueeze(dim=0)
        magnitude = torch.as_tensor(
            subspaces["mag"], dtype=direction.dtype, device=direction.device)
        steering_vec = magnitude * direction
        # The dictionary weights are float32 while the host model may run in
        # lower precision (e.g. bfloat16). Without the cast, the addition would
        # promote the layer output to float32 and the next layer's
        # lower-precision weights would reject it.
        return base + steering_vec.to(dtype=base.dtype, device=base.device)


if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"


if torch.cuda.is_available():
    # Load the chat model in bfloat16 on the GPU.
    model_id = "google/gemma-2-2b-it"
    model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map="cuda", torch_dtype=torch.bfloat16
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Decoder layer whose output is steered. It must match the ``l<N>/``
    # folder the dictionary files are downloaded from.
    steering_layer = 20

    # Load the steering dictionary (weights) and its concept metadata.
    path_to_params = hf_hub_download(
        repo_id="pyvene/gemma-reft-2b-it-res",
        filename=f"l{steering_layer}/weight.pt",
        force_download=False,
    )
    path_to_md = hf_hub_download(
        repo_id="pyvene/gemma-reft-2b-it-res",
        filename=f"l{steering_layer}/metadata.jsonl",
        force_download=False,
    )
    # ``weights_only=True`` restricts unpickling to tensors and plain
    # containers, so a tampered download cannot run arbitrary code.
    params = torch.load(path_to_params, weights_only=True).cuda()
    md = load_jsonl(path_to_md)
    id_to_concept = {item["id"]: item["concept"] for item in md}
    concept_list = [item["concept"] for item in md]

    # NOTE(review): ``proj.weight`` is overwritten with ``params`` and
    # ``Steer`` indexes its rows by concept id, so rows of ``params`` are
    # presumably the per-concept directions — confirm the embed_dim /
    # latent_dim naming against the checkpoint layout.
    steer = Steer(embed_dim=params.shape[0], latent_dim=params.shape[1])
    steer.proj.weight.data = params.float()

    # Mount the steering intervention on the chosen layer's output.
    pv_model = pv.IntervenableModel({
        "component": f"model.layers[{steering_layer}].output",
        "intervention": steer}, model=model)

    # Stop on the tokenizer's EOS token or on Gemma's ``<end_of_turn>``
    # marker, which closes the assistant turn. Defined inside the GPU branch
    # because ``tokenizer`` only exists there; at module level this raised a
    # NameError on CPU-only hosts before the UI could start.
    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<end_of_turn>"),
    ]


@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = 1024,
    concept_idx: int = 1795,
    steering_magnitude: float = 150.0,
) -> Iterator[str]:
    """Stream a steered reply to ``message``.

    Args:
        message: The latest user message.
        chat_history: Earlier turns supplied by ``gr.ChatInterface``. Not
            used: each reply is generated from ``message`` alone.
        max_new_tokens: Upper bound on the number of generated tokens.
        concept_idx: Row of the steering dictionary to apply.
        steering_magnitude: Scale applied to the selected steering direction.

    Yields:
        The reply text generated so far, growing as tokens are streamed.
    """
    # Ask for a dict so both ``input_ids`` and ``attention_mask`` come back
    # as batched tensors.
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": message}],
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to("cuda")

    input_ids = prompt["input_ids"]
    attention_mask = prompt["attention_mask"]

    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        attention_mask = attention_mask[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "base": {"input_ids": input_ids, "attention_mask": attention_mask},
        "unit_locations": None,
        # Slider values may arrive as floats; generation expects an int.
        "max_new_tokens": int(max_new_tokens),
        # Presumably also steers the prompt positions, not only generated
        # ones — see pyvene's ``IntervenableModel.generate``.
        "intervene_on_prompt": True,
        "subspaces": [{"idx": concept_idx, "mag": steering_magnitude}],
        "streamer": streamer,
        "eos_token_id": terminators,
        "do_sample": True,
    }

    # Generate in a background thread so this generator can stream partial
    # text from ``streamer`` as it arrives.
    t = Thread(target=pv_model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)


# Slider forwarded to ``generate`` as its ``max_new_tokens`` argument.
max_tokens_slider = gr.Slider(
    label="Max new tokens",
    minimum=1,
    maximum=MAX_MAX_NEW_TOKENS,
    step=1,
    value=DEFAULT_MAX_NEW_TOKENS,
)

# Chat widget wired to the streaming ``generate`` function.
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[max_tokens_slider],
    stop_btn=None,
    title="Model Steering with ReFT-r1 (16K concepts)",
)

# Page layout: description, duplicate button, chat widget, license footer.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    chat_interface.render()
    gr.Markdown(LICENSE)

if __name__ == "__main__":
    # Allow at most 20 pending requests in the queue, then start the server.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch()