import os, json
import torch
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from huggingface_hub import login, hf_hub_download
import pyvene as pv
from threading import Thread
from typing import Iterator

HF_TOKEN = os.environ.get("HF_TOKEN")
login(token=HF_TOKEN)

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
DESCRIPTION = """\ | |
# Model Steering with Supervised Dictionary Learning (SDL) | |
### What's Model Steering with SDL? | |
This is a demo of model steering with AxBench-ReFT-r1-16K, ... | |
""" | |
LICENSE = """ | |
<p/> | |
--- | |
Please refer to the specific licensing and use policy of the underlying model. | |
""" | |


def load_jsonl(jsonl_path):
    jsonl_data = []
    with open(jsonl_path, "r") as f:
        for line in f:
            data = json.loads(line)
            jsonl_data.append(data)
    return jsonl_data
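
# Each line of metadata.jsonl is expected to carry at least a "concept" label and
# its integer "concept_id" (a row index into the dictionary), which is all this
# script reads. An illustrative (not verbatim) line:
#   {"concept": "references to sports", "concept_id": 42}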


class Steer(pv.SourcelessIntervention):
    """Steer the model via activation addition."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs, keep_last_dim=True)
        # Rows of proj.weight serve as the per-concept steering vectors.
        self.proj = torch.nn.Linear(self.embed_dim, kwargs["latent_dim"], bias=False)

    def forward(self, base, source=None, subspaces=None):
        # subspaces is a list of dicts, each of the form {"idx": int, "mag": float}
        steered = base
        if subspaces is not None:
            for sp in subspaces:
                idx = sp["idx"]
                mag = sp["mag"]
                # each idx selects one row (one concept's steering vector) in self.proj.weight
                steering_vec = mag * self.proj.weight[idx].unsqueeze(dim=0)
                steered = steered + steering_vec
        return steered
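
# A minimal sketch of the payload forward() consumes (values illustrative):
#   subspaces = [{"idx": 42, "mag": 150.0}, {"idx": 7, "mag": -80.0}]
# would add 150 * row 42 and -80 * row 7 of proj.weight to each hidden state
# flowing through the intervened component.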


# ---------------------------------------------------
# Load Model & Dictionary if GPU is available
# ---------------------------------------------------
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo won't perform well on CPU.</p>"

if torch.cuda.is_available():
    model_id = "google/gemma-2-2b-it"
    model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map="cuda", torch_dtype=torch.bfloat16
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    path_to_params = hf_hub_download(repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/weight.pt")
    path_to_md = hf_hub_download(repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/metadata.jsonl")
    params = torch.load(path_to_params).cuda()
    md = load_jsonl(path_to_md)
    concept_list = [item["concept"] for item in md]
    concept_id_map = {item["concept"]: item["concept_id"] for item in md}

    steer = Steer(embed_dim=params.shape[0], latent_dim=params.shape[1])
    steer.proj.weight.data = params.float()

    pv_model = pv.IntervenableModel(
        {
            "component": "model.layers[20].output",
            "intervention": steer,
        },
        model=model,
    )
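
    # pv.IntervenableModel wraps the base model so that layer 20's output hidden
    # states pass through Steer.forward on every forward pass; the "subspaces"
    # kwarg handed to pv_model.generate() below is forwarded to that intervention.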

    # Gemma chat turns end with <end_of_turn>, which is distinct from the eos
    # token; include both so generation stops at the end of the model's turn.
    terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<end_of_turn>")]


# ---------------------------------------------------------------------
# The main generation function: keep only the last 3 conversation turns,
# then format them with apply_chat_template
# ---------------------------------------------------------------------
@spaces.GPU  # ZeroGPU Spaces attach a GPU only to functions carrying this decorator
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int,
    subspaces_list: list[dict],
) -> Iterator[str]:
    # Restrict to the last 3 turns only
    start_idx = max(0, len(chat_history) - 3)
    recent_history = chat_history[start_idx:]

    # Build a list of messages; each history tuple is (user_message, assistant_message)
    messages = []
    for user_msg, assistant_msg in recent_history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})

    # Now append the new user message
    messages.append({"role": "user", "content": message})

    # Convert messages into model input tokens with a generation prompt.
    # return_dict=True is required here: with tokenize=True alone,
    # apply_chat_template returns a bare list of token ids, and the
    # prompt["input_ids"] lookups below would fail.
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,  # appends the assistant turn header for the model to continue
        return_dict=True,
    )

    # Retrieve input_ids and mask
    input_ids = torch.tensor([prompt["input_ids"]]).cuda()
    attention_mask = torch.tensor([prompt["attention_mask"]]).cuda()

    # Trim from the left if over the max input length
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        attention_mask = attention_mask[:, -MAX_INPUT_TOKEN_LENGTH:]
        yield "\n[Warning: truncated the conversation because it exceeds the max allowed input tokens]\n"

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "base": {"input_ids": input_ids, "attention_mask": attention_mask},
        "unit_locations": None,
        "max_new_tokens": max_new_tokens,
        "intervene_on_prompt": True,
        "subspaces": subspaces_list,
        "streamer": streamer,
        "eos_token_id": terminators,
        "early_stopping": True,
        "do_sample": True,
    }

    # Run generation on a background thread so tokens can be yielded as they arrive.
    t = Thread(target=pv_model.generate, kwargs=generate_kwargs)
    t.start()

    partial_text = []
    for token_str in streamer:
        partial_text.append(token_str)
        yield "".join(partial_text)


# --------------
# UI Callbacks
# --------------
def filter_concepts(search_text: str):
    # Cap results at 500 to keep the dropdown responsive with a 16K-concept list.
    if not search_text.strip():
        return concept_list[:500]
    filtered = [c for c in concept_list if search_text.lower() in c.lower()]
    return filtered[:500]


def add_concept_to_list(selected_concept, magnitude, current_list):
    """When 'Add Concept' is clicked, add the chosen concept and magnitude to subspaces."""
    if not selected_concept:
        # Nothing selected: re-emit the current state unchanged.
        table_data = [[x["idx"], x["mag"]] for x in current_list]
        return current_list, table_data, gr.update(choices=[str(x["idx"]) for x in current_list])
    concept_idx = concept_id_map[selected_concept]
    new_entry = {"idx": concept_idx, "mag": magnitude}
    updated_list = current_list + [new_entry]
    remove_choices = [str(x["idx"]) for x in updated_list]
    table_data = [[x["idx"], x["mag"]] for x in updated_list]
    return updated_list, table_data, gr.update(choices=remove_choices)


def remove_concept_from_list(rem_concept_idx_str, current_list):
    """Remove the chosen concept from the list. Index is a string from remove_dropdown."""
    if not rem_concept_idx_str:
        # Nothing selected: re-emit the current state unchanged.
        table_data = [[x["idx"], x["mag"]] for x in current_list]
        return current_list, table_data, gr.update()
    rem_idx = int(rem_concept_idx_str)
    updated_list = [x for x in current_list if x["idx"] != rem_idx]
    remove_choices = [str(x["idx"]) for x in updated_list]
    table_data = [[x["idx"], x["mag"]] for x in updated_list]
    return updated_list, table_data, gr.update(choices=remove_choices)


def update_dropdown_choices(search_text):
    filtered = filter_concepts(search_text)
    return gr.update(choices=filtered)
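
# These callbacks return gr.update(...) so Gradio patches component properties
# (here, the dropdowns' choices) in place rather than rebuilding the components.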


# -------------------------
# Build the Gradio Blocks
# -------------------------
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")

    # Holds the active list of {"idx": ..., "mag": ...} dicts fed to generate()
    selected_subspaces = gr.State([])

    with gr.Row():
        with gr.Column():
            # Searching / selecting a concept
            search_box = gr.Textbox(
                label="Search concepts",
                placeholder="Type text to filter concepts (e.g. 'sports')",
            )
            concept_dropdown = gr.Dropdown(
                label="Filtered Concepts",
                choices=[],
                multiselect=False,
            )
            concept_magnitude = gr.Slider(
                label="Magnitude",
                minimum=-300.0,
                maximum=300.0,
                step=1.0,
                value=150.0,
            )
            add_button = gr.Button("Add Concept")

            # Removal
            remove_dropdown = gr.Dropdown(
                label="Remove from active list",
                choices=[],
                multiselect=False,
            )
            remove_button = gr.Button("Remove Selected")

        with gr.Column():
            # Display currently active subspaces
            active_subspaces_table = gr.Dataframe(
                headers=["idx", "magnitude"],
                datatype=["number", "number"],
                interactive=False,
                label="Active Concept Subspaces",
            )

    # The Chat Interface
    chat_interface = gr.ChatInterface(
        fn=generate,
        additional_inputs=[
            gr.Slider(
                label="Max new tokens",
                minimum=1,
                maximum=MAX_MAX_NEW_TOKENS,
                step=1,
                value=DEFAULT_MAX_NEW_TOKENS,
            ),
            selected_subspaces,
        ],
        title="Model Steering with ReFT-r1 (16K concepts)",
    )
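
    # ChatInterface invokes fn as generate(message, history, *additional_inputs),
    # so the slider and the subspaces State line up with generate's signature.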

    gr.Markdown(LICENSE)

    # Wire up events
    search_box.change(
        fn=update_dropdown_choices,
        inputs=[search_box],
        outputs=[concept_dropdown],
    )
    add_button.click(
        fn=add_concept_to_list,
        inputs=[concept_dropdown, concept_magnitude, selected_subspaces],
        outputs=[selected_subspaces, active_subspaces_table, remove_dropdown],
    )
    remove_button.click(
        fn=remove_concept_from_list,
        inputs=[remove_dropdown, selected_subspaces],
        outputs=[selected_subspaces, active_subspaces_table, remove_dropdown],
    )

demo.queue(max_size=20).launch()