import os, json
import torch
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from huggingface_hub import login, hf_hub_download
import pyreft
import pyvene as pv
from threading import Thread
from typing import Iterator
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
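# Gemma weights are gated on the Hub, so HF_TOKEN (e.g. set as a Space
# secret) is needed for the from_pretrained calls below.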
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 256 # smaller default to save memory
MAX_INPUT_TOKEN_LENGTH = 4096
def load_jsonl(jsonl_path):
jsonl_data = []
with open(jsonl_path, 'r') as f:
for line in f:
data = json.loads(line)
jsonl_data.append(data)
return jsonl_data
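# Each line of the metadata file is expected to be a standalone JSON object
# with at least a "concept" field, e.g. {"concept": "time travel", ...};
# only that field is used below.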
class Steer(pv.SourcelessIntervention):
"""Steer model via activation addition"""
def __init__(self, **kwargs):
super().__init__(**kwargs, keep_last_dim=True)
self.proj = torch.nn.Linear(
self.embed_dim, kwargs["latent_dim"], bias=False
)
    def forward(self, base, source=None, subspaces=None):
        if subspaces is None:
            return base
        # Average the requested magnitudes and the selected dictionary
        # directions, then add the scaled steering vector to the hidden states.
        avg_mag = sum(subspaces["mag"]) / len(subspaces["mag"])
        directions = [self.proj.weight[idx].unsqueeze(dim=0) for idx in subspaces["idx"]]
        steering_vec = torch.cat(directions, dim=0).mean(dim=0)
        return base + avg_mag * steering_vec
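# A minimal sketch of what Steer.forward computes, assuming two selected
# concepts i, j with magnitudes m_i, m_j (all values illustrative):
#   v  = (W[i] + W[j]) / 2            # W = steer.proj.weight, the dictionary
#   h' = h + ((m_i + m_j) / 2) * v    # added at every intervened position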
# Check GPU
if not torch.cuda.is_available():
    print("Warning: no GPU detected; the model will not be loaded and chat is disabled.")
# Load model & dictionary
model_id = "google/gemma-2-2b-it"
pv_model = None
tokenizer = None
concept_list = []
concept_id_map = {}
if torch.cuda.is_available():
model = AutoModelForCausalLM.from_pretrained(
model_id, device_map="cuda", torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Download dictionary
weight_path = hf_hub_download(repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/weight.pt")
meta_path = hf_hub_download(repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/metadata.jsonl")
params = torch.load(weight_path).cuda()
md = load_jsonl(meta_path)
concept_list = [item["concept"] for item in md]
    # Reindex sequentially: one concept is missing from the dictionary, so the
    # metadata row order, not the raw concept ID, indexes the weight matrix.
    concept_id_map = {item["concept"]: i for i, item in enumerate(md)}
steer = Steer(embed_dim=params.shape[0], latent_dim=params.shape[1])
steer.proj.weight.data = params.float()
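    # proj serves purely as a lookup table of steering directions: assigning
    # .data rebinds the tensor without nn.Linear's shape check, so the rows of
    # `params` are exactly what Steer.forward indexes.
    # Hook the intervention onto the residual-stream output of transformer
    # layer 20, matching the "l20" dictionary files loaded above.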
    pv_model = pv.IntervenableModel({
        "component": "model.layers[20].output",
        "intervention": steer,
    }, model=model)
terminators = [tokenizer.eos_token_id] if tokenizer else []  # stop tokens (currently unused)
@spaces.GPU
def generate(
message: str,
    chat_history: list[dict],
subspaces_list: list[dict],
max_new_tokens: int=DEFAULT_MAX_NEW_TOKENS,
) -> Iterator[str]:
    if pv_model is None:
        yield "Model unavailable: this demo requires a GPU."
        return
    # limit context to the last 4 turns
    start_idx = max(0, len(chat_history) - 4)
    recent_history = chat_history[start_idx:]
# build list of messages
messages = []
for rh in recent_history:
messages.append({"role": rh["role"], "content": rh["content"]})
messages.append({"role": "user", "content": message})
    input_ids = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    ).cuda()
# trim if needed
    # Trim overly long prompts, keeping only the most recent tokens.
    prefix = ""
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        prefix = "[Truncated prior text]\n"
        yield prefix
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
print(subspaces_list)
generate_kwargs = {
"base": {"input_ids": input_ids},
"unit_locations": None,
"max_new_tokens": max_new_tokens,
"intervene_on_prompt": True,
"subspaces": [
{
"idx": [int(sl["idx"]) for sl in subspaces_list],
"mag": [int(sl["internal_mag"]) for sl in subspaces_list]
}
] if subspaces_list else None,
"streamer": streamer,
"do_sample": True
}
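    # pyvene forwards unrecognized kwargs to the underlying HF generate(),
    # applying the Steer intervention during decoding.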
t = Thread(target=pv_model.generate, kwargs=generate_kwargs)
t.start()
    # Each yield replaces the whole message in gr.ChatInterface, so keep the
    # truncation notice as a prefix rather than letting it be overwritten.
    partial_text = [prefix]
    for token_str in streamer:
        partial_text.append(token_str)
        yield "".join(partial_text)
def _build_remove_choices(subspaces):
    # Labels like "(+3.0*) time travel"; the sign comes from the magnitude.
    return [f"({x['display_mag']:+.1f}*) {x['text']}" for x in subspaces]
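# Note: nothing wires this helper up in this version of the demo; it looks
# intended for a dropdown that removes previously added concepts.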
def filter_concepts(search_text: str):
if not search_text.strip():
return concept_list[:500]
filtered = [c for c in concept_list if search_text.lower() in c.lower()]
return filtered[:500]
def add_concept_to_list(selected_concept, user_slider_val, current_list):
    if not selected_concept:
        return current_list
    idx = concept_id_map[selected_concept]
    # The slider spans [-5, 5]; the internal steering magnitude is 50x larger.
    internal_mag = user_slider_val * 50
    new_entry = {
        "text": selected_concept,
        "idx": idx,
        "display_mag": user_slider_val,
        "internal_mag": internal_mag,
    }
    # Replace the list: this demo steers with one concept at a time.
    current_list = [new_entry]
    return current_list
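# Example entry for slider value 3.0 (concept name and index illustrative):
#   {"text": "time travel", "idx": 42, "display_mag": 3.0, "internal_mag": 150}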
def update_dropdown_choices(search_text):
filtered = filter_concepts(search_text)
if not filtered:
return gr.update(choices=[], value=None, interactive=True)
# Automatically select the first matching concept
return gr.update(
choices=filtered,
value=filtered[0], # Select the first match
interactive=True
)
with gr.Blocks(fill_height=True) as demo:
    # Start with no concepts selected
selected_subspaces = gr.State([])
with gr.Row(min_height=700):
# Left side: bigger chat area
with gr.Column(scale=7):
chat_interface = gr.ChatInterface(
fn=generate,
title="Chat with a Concept Steering Model",
description="Steer responses by selecting concepts on the right →",
type="messages",
additional_inputs=[selected_subspaces],
fill_height=True
)
# Right side: concept management
with gr.Column(scale=3):
gr.Markdown("## Steer Model Responses")
gr.Markdown("Search and then select a concept to steer. The closest match will be automatically selected.")
# Concept Search and Selection
with gr.Group():
search_box = gr.Textbox(
label="Search Concepts",
placeholder="Find concepts to steer the model (e.g. 'time travel')",
lines=2,
)
concept_dropdown = gr.Dropdown(
label="Select a concept to steer the model",
interactive=True,
allow_custom_value=False
)
concept_magnitude = gr.Slider(
label="Steering Intensity",
minimum=-5,
maximum=5,
step=0.1,
value=3,
)
# Wire up events
# When search box changes, update dropdown AND trigger concept selection
search_box.change(
update_dropdown_choices,
[search_box],
[concept_dropdown]
).then( # Chain the events to automatically add the concept
add_concept_to_list,
[concept_dropdown, concept_magnitude, selected_subspaces],
[selected_subspaces]
)
concept_dropdown.select(
add_concept_to_list,
[concept_dropdown, concept_magnitude, selected_subspaces],
[selected_subspaces]
)
concept_magnitude.input(
add_concept_to_list,
[concept_dropdown, concept_magnitude, selected_subspaces],
[selected_subspaces]
)
demo.launch(share=True)