"""Gradio demo: steer gemma-2-2b-it with ReFT-r1 (AxBench) steering vectors via pyvene."""

import json
import os
from threading import Thread
from typing import Iterator

from huggingface_hub import hf_hub_download, login
import gradio as gr
import spaces  # kept ahead of torch, matching the original import order
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import pyvene as pv

# Log in as a privileged user so gated/private Hub repos can be downloaded.
# login(token=None) falls back to an interactive prompt, so only call it when
# a token is actually configured.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Layer whose dictionary is downloaded (the "l20/..." files) and where the
# steering vector is added; one constant keeps the two in sync.
STEERING_LAYER = 20
DICTIONARY_REPO_ID = "pyvene/gemma-reft-2b-it-res"
# Dictionary row and scale applied to every request.
DEFAULT_CONCEPT_ID = 1795
DEFAULT_STEERING_FACTOR = 150.0

DESCRIPTION = """\
# Model Steering with Supervised Dictionary Learning (SDL)

### What's Model Steering with SDL?
This is a demo of model steering with Supervised Dictionary Learning (SDL) using AxBench-ReFT-r1-16K, which hosts steering vectors for 16K concepts. We evaluate various steering methods, including ReFT-r1, a novel weakly-supervised dictionary learning method. ReFT-r1 demonstrates competitive steering capabilities compared to finetuning and prompting baselines.
"""

LICENSE = """

---
This demo is governed by the original license and acceptable use policy of the model it is derived from. Please refer to the specific licensing and use policy of the underlying model.
"""


def load_jsonl(jsonl_path):
    """Read a JSON Lines file and return its records as a list.

    Blank lines are skipped rather than passed to ``json.loads`` (which
    would raise on them).
    """
    with open(jsonl_path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


class Steer(pv.SourcelessIntervention):
    """Steer model via activation addition.

    The steering dictionary lives in ``self.proj.weight``; row ``idx`` of that
    weight, scaled by ``mag``, is added to the intervened hidden states.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs, keep_last_dim=True)
        self.proj = torch.nn.Linear(
            self.embed_dim, kwargs["latent_dim"], bias=False)

    def forward(self, base, source=None, subspaces=None):
        # Expects subspaces as a dict with "idx" (dictionary row) and "mag"
        # (scale), as supplied through generate()'s "subspaces" kwarg.
        if subspaces is None:
            return base
        steering_vec = subspaces["mag"] * \
            self.proj.weight[subspaces["idx"]].unsqueeze(dim=0)
        # The dictionary is float32 while the model runs in bfloat16; cast so
        # the addition does not promote the hidden states fed to later layers.
        return base + steering_vec.to(base.dtype)


if not torch.cuda.is_available():
    DESCRIPTION += "\n\nRunning on CPU 🥶 This demo does not work on CPU.\n"

# Populated only when a GPU is present; generate() checks pv_model before use.
model = None
tokenizer = None
pv_model = None
terminators = []
id_to_concept = {}
concept_list = []

if torch.cuda.is_available():
    # load the LLM
    model_id = "google/gemma-2-2b-it"
    model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map="cuda", torch_dtype=torch.bfloat16
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # load the dictionary
    path_to_params = hf_hub_download(
        repo_id=DICTIONARY_REPO_ID,
        filename=f"l{STEERING_LAYER}/weight.pt",
        force_download=False,
    )
    path_to_md = hf_hub_download(
        repo_id=DICTIONARY_REPO_ID,
        filename=f"l{STEERING_LAYER}/metadata.jsonl",
        force_download=False,
    )
    # weights_only=True: the file only needs to hold a tensor, so refuse
    # arbitrary pickled objects coming from a download.
    params = torch.load(path_to_params, weights_only=True).cuda()
    md = load_jsonl(path_to_md)
    id_to_concept = {item["id"]: item["concept"] for item in md}
    concept_list = [item["concept"] for item in md]

    # NOTE(review): proj.weight is replaced wholesale below, so row `idx` is
    # params[idx]; this assumes params is (num_concepts, hidden_dim) — confirm.
    steer = Steer(embed_dim=params.shape[0], latent_dim=params.shape[1])
    steer.proj.weight.data = params.float()

    # Mount the steering intervention on the chosen layer's output.
    pv_model = pv.IntervenableModel({
        "component": f"model.layers[{STEERING_LAYER}].output",
        "intervention": steer}, model=model)

    terminators = [
        tokenizer.eos_token_id,
    ]


@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
) -> Iterator[str]:
    """Stream a steered reply to *message*.

    Only the latest user message is sent to the model; *chat_history* is
    accepted because gr.ChatInterface passes it, but it is not used.
    Dictionary row DEFAULT_CONCEPT_ID, scaled by DEFAULT_STEERING_FACTOR, is
    added at layer STEERING_LAYER on the prompt and generated tokens.

    Yields:
        The cumulative generated text after each streamed chunk.

    Raises:
        gr.Error: if the model was not loaded (no GPU at startup).
    """
    if pv_model is None:
        raise gr.Error("This demo requires a GPU; the model was not loaded.")

    # tokenize and prepare the input; return_dict=True yields both input_ids
    # and attention_mask as (1, seq_len) tensors.
    inputs = tokenizer.apply_chat_template(
        [{"role": "user", "content": message}],
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to("cuda")
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep the most recent tokens.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        attention_mask = attention_mask[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    # Pure sampling (no beams), so beam-only options like early_stopping are not set.
    generate_kwargs = {
        "base": {"input_ids": input_ids, "attention_mask": attention_mask},
        "unit_locations": None,
        "max_new_tokens": max_new_tokens,
        "intervene_on_prompt": True,
        "subspaces": [{"idx": DEFAULT_CONCEPT_ID, "mag": DEFAULT_STEERING_FACTOR}],
        "streamer": streamer,
        "eos_token_id": terminators,
        "do_sample": True,
    }

    # Generation runs in a worker thread; this generator drains the streamer.
    t = Thread(target=pv_model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)


chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        )
    ],
    stop_btn=None,
    title="Model Steering with ReFT-r1 (16K concepts)",
)

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    chat_interface.render()
    gr.Markdown(LICENSE)

if __name__ == "__main__":
    demo.queue(max_size=20).launch()