frankaging
switch to pyvene
330e95b
raw
history blame
5.12 kB
# Authenticate with the Hugging Face Hub so the model and dictionary downloads
# below can use the credentials in the HF_TOKEN environment variable.
import json
import os
from threading import Thread
from typing import Iterator

from huggingface_hub import login, hf_hub_download
import gradio as gr
import spaces  # keep before torch: ZeroGPU's `spaces` wants to be imported first
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import pyvene as pv

HF_TOKEN = os.environ.get("HF_TOKEN")
# login(token=None) falls back to an interactive token prompt, which cannot be
# answered in a headless Space; only log in when a token is actually provided
# (otherwise huggingface_hub uses any cached credentials).
if HF_TOKEN:
    login(token=HF_TOKEN)
# Upper bound and initial value of the "Max new tokens" slider.
MAX_MAX_NEW_TOKENS: int = 2048
DEFAULT_MAX_NEW_TOKENS: int = 1024
# Prompts longer than this many tokens are trimmed to their last tokens before
# generation; overridable through the MAX_INPUT_TOKEN_LENGTH environment variable.
MAX_INPUT_TOKEN_LENGTH: int = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Markdown rendered above the chat interface.
DESCRIPTION: str = """\
# Model Steering with Supervised Dictionary Learning (SDL)
### What's Model Steering with SDL?
This is a demo of model steering with Supervised Dictionary Learning (SDL) using AxBench-ReFT-r1-16K, which hosts steering vectors for 16K concepts. We evaluate various steering methods, including ReFT-r1, a novel weakly-supervised dictionary learning method. ReFT-r1 demonstrates competitive steering capabilities compared to finetuning and prompting baselines.
"""

# Markdown rendered below the chat interface.
LICENSE: str = """
<p/>
---
This demo is governed by the original license and acceptable use policy of the model it is derived from. Please refer to the specific licensing and use policy of the underlying model.
"""
def load_jsonl(jsonl_path):
    """Read a JSON Lines file into a list of decoded records.

    Args:
        jsonl_path: Path to a file holding one JSON value per line.

    Returns:
        A list with one decoded value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Decode as UTF-8 explicitly rather than relying on the platform's default
    # encoding, which may not be UTF-8.
    with open(jsonl_path, "r", encoding="utf-8") as f:
        # Blank lines (e.g. an extra trailing newline) are skipped; json.loads
        # would otherwise raise on them.
        return [json.loads(line) for line in f if line.strip()]
class Steer(pv.SourcelessIntervention):
    """Steer model via activation addition.

    Adds ``mag * proj.weight[idx]`` to the intervened representation, where
    ``idx`` and ``mag`` are read from the ``subspaces`` dict given at call time.
    """

    def __init__(self, **kwargs):
        # keep_last_dim is always forced on; merging it into kwargs (instead of
        # passing it a second time) avoids a "multiple values for keyword" error
        # if a caller also supplies it.
        super().__init__(**{**kwargs, "keep_last_dim": True})
        # Rows of proj.weight are the steering directions selected in forward().
        self.proj = torch.nn.Linear(
            self.embed_dim, kwargs["latent_dim"], bias=False)

    def forward(self, base, source=None, subspaces=None):
        """Shift ``base`` along one dictionary direction.

        Args:
            base: Activations being intervened on.
            source: Unused; this intervention has no source representation.
            subspaces: Dict with ``"idx"`` (row of ``proj.weight`` to add) and
                ``"mag"`` (scalar multiplier).

        Returns:
            ``base + mag * proj.weight[idx]``, cast back to ``base``'s dtype.
        """
        weight = self.proj.weight
        # as_tensor accepts plain numbers or tensors without the copy-construct
        # warning torch.tensor emits, and puts the scalar on the weight's device.
        mag = torch.as_tensor(subspaces["mag"], dtype=weight.dtype, device=weight.device)
        steering_vec = mag * weight[subspaces["idx"]].unsqueeze(dim=0)
        # The dictionary may be stored in higher precision than the model's
        # activations (e.g. float32 vs bfloat16): add in the promoted dtype,
        # then hand back the caller's dtype so downstream layers are unaffected.
        return (base + steering_vec).to(base.dtype)
# Decoder layer whose output is steered. The dictionary files on the Hub are
# organised per layer under "l<N>/", so this one index selects both.
STEERING_LAYER = 20

if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
else:
    # load the LLM (bfloat16, on the GPU)
    model_id = "google/gemma-2-2b-it"
    model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map="cuda", torch_dtype=torch.bfloat16
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # load the dictionary: steering weights plus per-concept metadata
    path_to_params = hf_hub_download(
        repo_id="pyvene/gemma-reft-2b-it-res",
        filename=f"l{STEERING_LAYER}/weight.pt",
        force_download=False,
    )
    path_to_md = hf_hub_download(
        repo_id="pyvene/gemma-reft-2b-it-res",
        filename=f"l{STEERING_LAYER}/metadata.jsonl",
        force_download=False,
    )
    # weights_only=True refuses arbitrary pickled objects in the downloaded
    # file; map_location loads the tensor straight onto the current GPU.
    params = torch.load(path_to_params, map_location="cuda", weights_only=True)
    md = load_jsonl(path_to_md)
    id_to_concept = {item["id"]: item["concept"] for item in md}
    concept_list = [item["concept"] for item in md]

    # NOTE(review): proj.weight is overwritten wholesale with `params`, so
    # Steer.forward's weight[idx] indexes the first axis of the stored tensor.
    # nn.Linear(embed_dim, latent_dim) would allocate a (latent_dim, embed_dim)
    # weight, i.e. the transpose of params' shape here, which suggests the two
    # names are swapped relative to the on-disk layout — confirm before relying
    # on steer.embed_dim.
    steer = Steer(embed_dim=params.shape[0], latent_dim=params.shape[1])
    steer.proj.weight.data = params.float()

    # Mount the steering intervention on the output of the chosen decoder layer.
    pv_model = pv.IntervenableModel({
        "component": f"model.layers[{STEERING_LAYER}].output",
        "intervention": steer}, model=model)
    terminators = [
        tokenizer.eos_token_id,
    ]
@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    concept_idx: int = 1795,
    steering_magnitude: float = 150.0,
) -> Iterator[str]:
    """Stream a steered reply to the latest chat message.

    Only ``message`` is sent to the model; ``chat_history`` is accepted because
    ``gr.ChatInterface`` passes it, but it is not used.

    Args:
        message: The user's latest message.
        chat_history: Earlier turns supplied by Gradio (ignored).
        max_new_tokens: Maximum number of tokens to generate.
        concept_idx: Dictionary row added by the steering intervention.
        steering_magnitude: Scalar multiplier applied to that row.

    Yields:
        The reply decoded so far, growing as tokens are streamed.
    """
    # Ask the tokenizer for a dict of tensors explicitly; with tokenize=True alone
    # apply_chat_template returns a plain list of ids, which has no
    # "input_ids"/"attention_mask" keys.
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": message}],
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to("cuda")
    input_ids = prompt["input_ids"]
    attention_mask = prompt["attention_mask"]
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep the most recent tokens so the generation prompt at the end survives.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        attention_mask = attention_mask[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "base": {"input_ids": input_ids, "attention_mask": attention_mask},
        "unit_locations": None,
        "max_new_tokens": max_new_tokens,
        "intervene_on_prompt": True,
        "subspaces": [{"idx": concept_idx, "mag": steering_magnitude}],
        "streamer": streamer,
        "eos_token_id": terminators,
        # early_stopping is dropped: it only applies to beam search and
        # transformers warns when it is set for sampling.
        "do_sample": True,
    }
    # pv_model.generate blocks until finished, so run it on a worker thread and
    # forward text from the streamer to Gradio as it arrives.
    worker = Thread(target=pv_model.generate, kwargs=generate_kwargs)
    worker.start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
    worker.join()
# Extra control shown alongside the chat box; ChatInterface passes its value to
# `generate` after the message and history arguments.
_max_new_tokens_slider = gr.Slider(
    label="Max new tokens",
    minimum=1,
    maximum=MAX_MAX_NEW_TOKENS,
    step=1,
    value=DEFAULT_MAX_NEW_TOKENS,
)

chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[_max_new_tokens_slider],
    stop_btn=None,
    title="Model Steering with ReFT-r1 (16K concepts)",
)

# Page layout: description, duplicate button, the chat UI, then the license footer.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
    )
    chat_interface.render()
    gr.Markdown(LICENSE)
# Start the web app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    # Enable Gradio's request queue (capped at 20 waiting events via max_size), then serve.
    demo.queue(max_size=20).launch()