import os
import json
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from huggingface_hub import hf_hub_download, login
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

import pyvene as pv

# Log in as a privileged user (HF_TOKEN) so gated assets such as the Gemma
# weights can be downloaded.
HF_TOKEN = os.environ.get("HF_TOKEN")
login(token=HF_TOKEN)
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
DESCRIPTION = """\ | |
# Model Steering with Supervised Dictionary Learning (SDL) | |
### What's Model Steering with SDL? | |
This is a demo of model steering with Supervised Dictionary Learning (SDL) using AxBench-ReFT-r1-16K, which hosts steering vectors for 16K concepts. We evaluate various steering methods, including ReFT-r1, a novel weakly-supervised dictionary learning method. ReFT-r1 demonstrates competitive steering capabilities compared to finetuning and prompting baselines. | |
""" | |
LICENSE = """
<p/>

---
This demo is governed by the original license and acceptable use policy of the model it is derived from. Please refer to the specific licensing and use policy of the underlying model.
"""
def load_jsonl(jsonl_path):
    """Load a JSONL file into a list of dicts, one per line."""
    jsonl_data = []
    with open(jsonl_path, "r") as f:
        for line in f:
            jsonl_data.append(json.loads(line))
    return jsonl_data
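# Each metadata line is expected to be a JSON object carrying (at least) "id"
# and "concept" fields, judging from the lookups further below, e.g.:
#   {"id": 0, "concept": "..."}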
class Steer(pv.SourcelessIntervention):
    """Steer the model via activation addition."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs, keep_last_dim=True)
        # One learned direction per concept, stored as the rows of a linear map.
        self.proj = torch.nn.Linear(self.embed_dim, kwargs["latent_dim"], bias=False)

    def forward(self, base, source=None, subspaces=None):
        # Scale the selected concept direction and add it to the hidden states.
        steering_vec = torch.tensor(subspaces["mag"]) * \
            self.proj.weight[subspaces["idx"]].unsqueeze(dim=0)
        return base + steering_vec
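# A minimal sketch of the math Steer performs, for illustration only (random
# weights and made-up dimensions; this helper is never called by the app):
def _steering_math_demo():
    embed_dim, latent_dim = 8, 16
    hidden = torch.randn(1, 5, embed_dim)  # (batch, seq, embed_dim)
    proj = torch.nn.Linear(embed_dim, latent_dim, bias=False)
    # Pick concept row 3, scale it, and broadcast-add it to every position.
    steering_vec = torch.tensor(150.0) * proj.weight[3].unsqueeze(dim=0)
    return hidden + steering_vec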
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
if torch.cuda.is_available():
    # Load the LLM.
    model_id = "google/gemma-2-2b-it"
    model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map="cuda", torch_dtype=torch.bfloat16
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Load the dictionary of steering directions (layer-20 residual stream).
    path_to_params = hf_hub_download(
        repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/weight.pt", force_download=False
    )
    path_to_md = hf_hub_download(
        repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/metadata.jsonl", force_download=False
    )
    params = torch.load(path_to_params).cuda()
    md = load_jsonl(path_to_md)
    id_to_concept = {item["id"]: item["concept"] for item in md}
    concept_list = [item["concept"] for item in md]
    steer = Steer(embed_dim=params.shape[0], latent_dim=params.shape[1])
    steer.proj.weight.data = params.float()

    # Mount the encoder onto layer 20 of the model.
    pv_model = pv.IntervenableModel(
        {"component": "model.layers[20].output", "intervention": steer},
        model=model,
    )

    terminators = [tokenizer.eos_token_id]
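    # Sanity-check sketch (hypothetical helper, never called by the app): run a
    # single steered generation, mirroring the kwargs used in generate() below.
    def _steer_once(input_ids, attention_mask, idx, mag, max_new_tokens=32):
        return pv_model.generate(
            base={"input_ids": input_ids, "attention_mask": attention_mask},
            unit_locations=None,
            intervene_on_prompt=True,
            subspaces=[{"idx": idx, "mag": mag}],
            max_new_tokens=max_new_tokens,
            eos_token_id=terminators,
            do_sample=True,
        )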
@spaces.GPU  # ZeroGPU: allocate a GPU for the duration of each call.
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = 1024,
) -> Iterator[str]:
    # Tokenize and prepare the input.
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": message}],
        tokenize=True, add_generation_prompt=True,
        return_tensors="pt", return_dict=True,
    ).to("cuda")
    input_ids = prompt["input_ids"]
    attention_mask = prompt["attention_mask"]
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        attention_mask = attention_mask[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "base": {"input_ids": input_ids, "attention_mask": attention_mask},
        "unit_locations": None,
        "max_new_tokens": max_new_tokens,
        "intervene_on_prompt": True,
        # Hard-coded concept index and steering magnitude for this demo.
        "subspaces": [{"idx": 1795, "mag": 150.0}],
        "streamer": streamer,
        "eos_token_id": terminators,
        "early_stopping": True,
        "do_sample": True,
    }
    # Generate on a background thread and stream partial output as it arrives.
    t = Thread(target=pv_model.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
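# Usage sketch (illustrative only, not executed here): generate() yields the
# accumulated response, so the last yielded string is the complete reply.
#   last = ""
#   for last in generate("Hello!", [], max_new_tokens=64):
#       pass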
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        )
    ],
    stop_btn=None,
    title="Model Steering with ReFT-r1 (16K concepts)",
)

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    chat_interface.render()
    gr.Markdown(LICENSE)

if __name__ == "__main__":
    demo.queue(max_size=20).launch()