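"""Gradio Space: steering google/gemma-2-2b-it with a ReFT-r1 concept dictionary via pyvene."""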
import os, json
import torch
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from huggingface_hub import login, hf_hub_download
import pyvene as pv
from threading import Thread
from typing import Iterator
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:  # only authenticate when a token is provided
    login(token=HF_TOKEN)
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 128 # smaller default to save memory
MAX_INPUT_TOKEN_LENGTH = 4096
def load_jsonl(jsonl_path):
jsonl_data = []
with open(jsonl_path, 'r') as f:
for line in f:
data = json.loads(line)
jsonl_data.append(data)
return jsonl_data
class Steer(pv.SourcelessIntervention):
"""Steer model via activation addition"""
def __init__(self, **kwargs):
super().__init__(**kwargs, keep_last_dim=True)
self.proj = torch.nn.Linear(
self.embed_dim, kwargs["latent_dim"], bias=False
)
def forward(self, base, source=None, subspaces=None):
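        # base: hidden states [batch, seq, embed_dim]; subspaces supplies the
        # concept row index ("idx") and signed magnitude ("mag") chosen in the UI.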
steering_vec = torch.tensor(subspaces["mag"]) * \
self.proj.weight[subspaces["idx"]].unsqueeze(dim=0)
return base + steering_vec
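# Illustrative sketch (not executed): with subspaces={"idx": 7, "mag": 150},
# forward() returns base + 150 * proj.weight[7], broadcast over all positions,
# i.e. plain activation addition along a single dictionary direction.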
# Check GPU
if not torch.cuda.is_available():
print("Warning: Running on CPU, may be slow.")
# Load model & dictionary
model_id = "google/gemma-2-2b-it"
pv_model = None
tokenizer = None
concept_list = []
concept_id_map = {}
if torch.cuda.is_available():
model = AutoModelForCausalLM.from_pretrained(
model_id, device_map="cuda", torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Download dictionary
weight_path = hf_hub_download(repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/weight.pt")
meta_path = hf_hub_download(repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/metadata.jsonl")
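    # Assumed layout: l20/weight.pt is an [embed_dim x num_concepts] dictionary of
    # steering directions; l20/metadata.jsonl has one {"concept", "concept_id"}
    # record per column (see how params.shape and the fields are used below).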
params = torch.load(weight_path).cuda()
md = load_jsonl(meta_path)
concept_list = [item["concept"] for item in md]
concept_id_map = {item["concept"]: item["concept_id"] for item in md}
steer = Steer(embed_dim=params.shape[0], latent_dim=params.shape[1])
steer.proj.weight.data = params.float()
    pv_model = pv.IntervenableModel({
        "component": "model.layers[20].output",
        "intervention": steer}, model=model)
terminators = [tokenizer.eos_token_id] if tokenizer else []
@spaces.GPU
def generate(
message: str,
    chat_history: list[dict],  # ChatInterface(type="messages") passes role/content dicts
subspaces_list: list[dict],
max_new_tokens: int=DEFAULT_MAX_NEW_TOKENS,
) -> Iterator[str]:
# limit to last 3 turns
start_idx = max(0, len(chat_history) - 3)
recent_history = chat_history[start_idx:]
    # build the message list (history is currently unused; re-enable below to include it)
    messages = []
    # for past in recent_history:  # each item: {"role": ..., "content": ...}
    #     messages.append({"role": past["role"], "content": past["content"]})
    messages.append({"role": "user", "content": message})
input_ids = torch.tensor([tokenizer.apply_chat_template(
messages, tokenize=True, add_generation_prompt=True)]).cuda()
    # trim if needed; fold the notice into the stream, since each later yield
    # replaces the displayed message and would otherwise erase it
    prefix = ""
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        prefix = "[Truncated prior text]\n"
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
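    # Stream tokens from a background thread; pv.IntervenableModel.generate forwards
    # extra kwargs (streamer, sampling flags, eos ids) to the underlying model.generate.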
generate_kwargs = {
"base": {"input_ids": input_ids},
"unit_locations": None,
"max_new_tokens": max_new_tokens,
"intervene_on_prompt": True,
"subspaces": [
{
"idx": int(subspaces_list[0]["idx"]),
"mag": int(subspaces_list[0]["internal_mag"])
}
] if subspaces_list else [],
"streamer": streamer,
"do_sample": True
}
t = Thread(target=pv_model.generate, kwargs=generate_kwargs)
t.start()
    partial_text = [prefix]  # begins with the truncation notice, if any
for token_str in streamer:
partial_text.append(token_str)
yield "".join(partial_text)
def filter_concepts(search_text: str):
if not search_text.strip():
return concept_list[:500]
filtered = [c for c in concept_list if search_text.lower() in c.lower()]
return filtered[:500]
def add_concept_to_list(selected_concept, user_slider_val, current_list):
"""
Return exactly 2 values:
1) The updated list of concepts (list of dicts).
2) A Gradio update for the removal dropdown’s choices.
"""
if not selected_concept:
return current_list, gr.update(choices=_build_remove_choices(current_list))
idx = concept_id_map[selected_concept]
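    # Map the UI slider (-5..5) to the internal steering magnitude (x50 -> -250..250).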
internal_mag = user_slider_val * 50
new_entry = {
"text": selected_concept,
"idx": idx,
"display_mag": user_slider_val,
"internal_mag": internal_mag,
}
updated_list = current_list + [new_entry]
return updated_list, gr.update(choices=_build_remove_choices(updated_list))
def remove_concept_from_list(selected_text, current_list):
"""
Return exactly 2 values:
1) The updated list of concepts (list of dicts).
2) A Gradio update for the removal dropdown’s choices.
"""
if not selected_text:
return current_list, gr.update(choices=_build_remove_choices(current_list))
updated_list = [x for x in current_list if x["text"] != selected_text]
return updated_list, gr.update(choices=_build_remove_choices(updated_list))
def _build_remove_choices(subspaces):
return [x["text"] for x in subspaces]
def update_dropdown_choices(search_text):
filtered = filter_concepts(search_text)
return gr.update(choices=filtered)
with gr.Blocks(css="style.css") as demo:
    # Pre-populate with a default concept when the dictionary loaded successfully
    default_subspaces = []
    default_concept = "words related to time travel and its consequences"
    if pv_model and default_concept in concept_id_map:
default_subspaces = [{
"text": default_concept,
"idx": concept_id_map[default_concept],
"display_mag": 3,
"internal_mag": 150.0,
}]
selected_subspaces = gr.State(default_subspaces)
with gr.Row():
# Left side: bigger chat area
with gr.Column(scale=7):
chat_interface = gr.ChatInterface(
fn=generate,
title="LM Steering with ReFT-r1 (16K concepts)",
type="messages",
additional_inputs=[selected_subspaces],
)
# Right side: concept management
with gr.Column(scale=3):
gr.Markdown("# Steering Concepts")
search_box = gr.Textbox(
label="Search concepts",
placeholder="e.g. 'time travel'"
)
concept_dropdown = gr.Dropdown(
label="Filtered Concepts",
choices=[]
)
concept_magnitude = gr.Slider(
label="Steering Factor",
minimum=-5,
maximum=5,
step=1,
value=3
)
add_button = gr.Button("Add Concept")
# Row with the remove dropdown + button
with gr.Row():
remove_dropdown = gr.Dropdown(
label="Remove concept",
choices=_build_remove_choices(default_subspaces),
multiselect=False
)
remove_button = gr.Button("Remove", variant="secondary")
# Wire up events
# When the search box changes, update the concept dropdown choices:
search_box.change(
update_dropdown_choices,
[search_box],
[concept_dropdown]
)
# When "Add Concept" is clicked, add the concept + magnitude to the list,
# and update the "Remove" dropdown choices.
add_button.click(
add_concept_to_list,
[concept_dropdown, concept_magnitude, selected_subspaces],
[selected_subspaces, remove_dropdown]
)
# When "Remove" is clicked, remove the selected concept from the list,
# and update the "Remove" dropdown choices.
remove_button.click(
remove_concept_from_list,
[remove_dropdown, selected_subspaces],
[selected_subspaces, remove_dropdown]
)
demo.launch()