# Log in as a privileged user so gated model weights can be downloaded.
# HF_TOKEN is expected to be set as a secret on the Space.
import os
import json
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from huggingface_hub import login, hf_hub_download
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import pyvene as pv

HF_TOKEN = os.environ.get("HF_TOKEN")
login(token=HF_TOKEN)
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
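# Generation limits: cap the number of new tokens and truncate prompts that
# exceed the input budget (both enforced in generate() below).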
DESCRIPTION = """\
# Model Steering with Supervised Dictionary Learning (SDL)
### What's Model Steering with SDL?
This is a demo of model steering with Supervised Dictionary Learning (SDL) using AxBench-ReFT-r1-16K, which hosts steering vectors for 16K concepts. The demo uses ReFT-r1, a novel weakly-supervised dictionary learning method whose steering performance is competitive with finetuning and prompting baselines.
"""
LICENSE = """
<p/>
---
This demo is governed by the original license and acceptable use policy of the model it is derived from. Please refer to the specific licensing and use policy of the underlying model.
"""
def load_jsonl(jsonl_path):
    """Read a JSONL file into a list of dicts, one per line."""
    jsonl_data = []
    with open(jsonl_path, "r") as f:
        for line in f:
            jsonl_data.append(json.loads(line))
    return jsonl_data
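# Each metadata line pairs a dictionary entry with a human-readable concept,
# carrying at least "id" and "concept" fields; these are used below to build
# the id -> concept lookup.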
class Steer(pv.SourcelessIntervention):
"""Steer model via activation addition"""
def __init__(self, **kwargs):
super().__init__(**kwargs, keep_last_dim=True)
self.proj = torch.nn.Linear(
self.embed_dim, kwargs["latent_dim"], bias=False)
def forward(self, base, source=None, subspaces=None):
steering_vec = torch.tensor(subspaces["mag"]) * \
self.proj.weight[subspaces["idx"]].unsqueeze(dim=0)
return base + steering_vec
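# In effect, for hidden states h at the hooked layer, Steer computes
#     h + mag * proj.weight[idx]
# broadcast over all sequence positions, where proj.weight holds the SDL
# dictionary and idx selects a single concept direction.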
if not torch.cuda.is_available():
DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
if torch.cuda.is_available():
# load the LLM
model_id = "google/gemma-2-2b-it"
model = AutoModelForCausalLM.from_pretrained(
model_id, device_map="cuda", torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# load the dictionary
path_to_params = hf_hub_download(repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/weight.pt", force_download=False)
path_to_md = hf_hub_download(repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/metadata.jsonl", force_download=False)
params = torch.load(path_to_params).cuda()
md = load_jsonl(path_to_md)
id_to_concept = {item["id"]: item["concept"] for item in md}
concept_list = [item["concept"] for item in md]
steer = Steer(embed_dim=params.shape[0], latent_dim=params.shape[1])
steer.proj.weight.data = params.float()
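    # proj.weight now holds the downloaded dictionary, so proj.weight[idx]
    # in Steer.forward() picks out the steering vector for concept idx.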
    # Mount the steering intervention on the residual stream at layer 20.
    pv_model = pv.IntervenableModel({
        "component": "model.layers[20].output",
        "intervention": steer}, model=model)
terminators = [
tokenizer.eos_token_id,
]
@spaces.GPU
def generate(
message: str,
chat_history: list[tuple[str, str]],
max_new_tokens: int = 1024,
) -> Iterator[str]:
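    # chat_history is provided by gr.ChatInterface but is not used here;
    # each turn is steered and answered independently.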
    # Tokenize the user turn with the chat template; return_dict=True yields
    # both input_ids and an attention_mask.
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": message}], tokenize=True,
        add_generation_prompt=True, return_dict=True, return_tensors="pt").to("cuda")
    input_ids = prompt["input_ids"]
    attention_mask = prompt["attention_mask"]
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
attention_mask = attention_mask[:, -MAX_INPUT_TOKEN_LENGTH:]
gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = {
"base": {"input_ids": input_ids, "attention_mask": attention_mask},
"unit_locations": None,
"max_new_tokens": max_new_tokens,
"intervene_on_prompt": True,
"subspaces": [{"idx": 1795, "mag": 150.0}],
"streamer": streamer,
"eos_token_id": terminators,
"early_stopping": True,
"do_sample": True
}
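    # Run generation on a background thread so the streamer can yield
    # partial text to the UI as tokens are produced.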
t = Thread(target=pv_model.generate, kwargs=generate_kwargs)
t.start()
outputs = []
for text in streamer:
outputs.append(text)
yield "".join(outputs)
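# The "Max new tokens" slider below is passed to generate() as its
# max_new_tokens argument via additional_inputs.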
chat_interface = gr.ChatInterface(
fn=generate,
additional_inputs=[
gr.Slider(
label="Max new tokens",
minimum=1,
maximum=MAX_MAX_NEW_TOKENS,
step=1,
value=DEFAULT_MAX_NEW_TOKENS,
)
],
stop_btn=None,
title="Model Steering with ReFT-r1 (16K concepts)",
)
with gr.Blocks(css="style.css") as demo:
gr.Markdown(DESCRIPTION)
gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
chat_interface.render()
gr.Markdown(LICENSE)
if __name__ == "__main__":
demo.queue(max_size=20).launch()