import os, json, random
import torch
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from huggingface_hub import login, hf_hub_download
import pyreft
import pyvene as pv
from threading import Thread
from typing import Iterator

HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:  # only log in when a token is configured
    login(token=HF_TOKEN)
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 128  # smaller default to save memory
MAX_INPUT_TOKEN_LENGTH = 4096
def load_jsonl(jsonl_path):
    jsonl_data = []
    with open(jsonl_path, "r") as f:
        for line in f:
            data = json.loads(line)
            jsonl_data.append(data)
    return jsonl_data
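
# Each metadata row is expected to look roughly like this (illustrative
# values; the field names are taken from their use below):
#     {"concept": "words related to time travel ...", "concept_id": 123}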
class Steer(pv.SourcelessIntervention):
    """Steer model via activation addition."""
    def __init__(self, **kwargs):
        super().__init__(**kwargs, keep_last_dim=True)
        self.proj = torch.nn.Linear(
            self.embed_dim, kwargs["latent_dim"], bias=False
        )

    def forward(self, base, source=None, subspaces=None):
        # Add mag * (concept vector) to every position of the hidden states.
        # Build the scale on the same device as `base`, and cast the result
        # back to the model dtype (the projection weights are float32).
        steering_vec = torch.tensor(subspaces["mag"], device=base.device) * \
            self.proj.weight[subspaces["idx"]].unsqueeze(dim=0)
        return base + steering_vec.to(base.dtype)
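
# Shape sketch (illustrative): after the weight assignment below, each row of
# proj.weight is one concept's steering direction with the model's hidden
# size. Indexing row `idx` yields a (hidden,) vector; unsqueeze makes it
# (1, hidden), and broadcasting adds it to the (batch, seq_len, hidden)
# activations in `base`.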
# Check GPU
if not torch.cuda.is_available():
    print("Warning: Running on CPU, may be slow.")

# Load model & dictionary
model_id = "google/gemma-2-2b-it"
pv_model = None
tokenizer = None
concept_list = []
concept_id_map = {}
if torch.cuda.is_available():
    model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map="cuda", torch_dtype=torch.bfloat16
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Download the steering dictionary (concept vectors + metadata)
    weight_path = hf_hub_download(repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/weight.pt")
    meta_path = hf_hub_download(repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/metadata.jsonl")
    params = torch.load(weight_path).cuda()
    md = load_jsonl(meta_path)
    concept_list = [item["concept"] for item in md]
    concept_id_map = {item["concept"]: item["concept_id"] for item in md}

    steer = Steer(embed_dim=params.shape[0], latent_dim=params.shape[1])
    steer.proj.weight.data = params.float()

    # Hook the intervention into the residual stream at layer 20.
    pv_model = pv.IntervenableModel({
        "component": "model.layers[20].output",
        "intervention": steer,
    }, model=model)
terminators = [tokenizer.eos_token_id] if tokenizer else []
@spaces.GPU  # assumed ZeroGPU deployment; the decorator is a no-op elsewhere
def generate(
    message: str,
    chat_history: list[dict],
    subspaces_list: list[dict],
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
) -> Iterator[str]:
    # Prior turns are intentionally dropped: steering is most predictable on
    # a single-turn prompt, so only the latest user message is templated.
    messages = [{"role": "user", "content": message}]
    input_ids = torch.tensor([tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True)]).cuda()

    # Trim overly long prompts from the left, keeping the newest tokens.
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        yield "[Truncated prior text]\n"
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "base": {"input_ids": input_ids},
        "unit_locations": None,
        "max_new_tokens": max_new_tokens,
        "intervene_on_prompt": True,
        # Only the first selected concept is applied to a given generation.
        "subspaces": [
            {
                "idx": int(subspaces_list[0]["idx"]),
                "mag": int(subspaces_list[0]["internal_mag"]),
            }
        ] if subspaces_list else [],
        "streamer": streamer,
        "eos_token_id": terminators,  # previously computed but never used
        "do_sample": True,
    }
    # Run generation on a worker thread so tokens stream back as they arrive.
    t = Thread(target=pv_model.generate, kwargs=generate_kwargs)
    t.start()

    partial_text = []
    for token_str in streamer:
        partial_text.append(token_str)
        yield "".join(partial_text)
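
# Illustrative usage outside Gradio (assumes the GPU model loaded above; the
# subspace dict mirrors what the UI handlers below build):
#
#     last = ""
#     for last in generate("Plan a weekend trip.", [],
#                          [{"idx": 0, "internal_mag": 150}]):
#         pass  # each yield is the full response generated so far
#     print(last)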
def filter_concepts(search_text: str):
    if not search_text.strip():
        return concept_list[:500]
    filtered = [c for c in concept_list if search_text.lower() in c.lower()]
    return filtered[:500]
def add_concept_to_list(selected_concept, user_slider_val, current_list):
    """
    Return exactly 2 values:
      1) The updated list of concepts (list of dicts).
      2) A Gradio update for the removal dropdown's choices.
    """
    if not selected_concept:
        return current_list, gr.update(choices=_build_remove_choices(current_list))
    idx = concept_id_map[selected_concept]
    # The UI slider spans [-5, 5]; scale by 50 to the internal magnitude
    # range [-250, 250] used by the steering vector.
    internal_mag = user_slider_val * 50
    new_entry = {
        "text": selected_concept,
        "idx": idx,
        "display_mag": user_slider_val,
        "internal_mag": internal_mag,
    }
    updated_list = current_list + [new_entry]
    return updated_list, gr.update(choices=_build_remove_choices(updated_list))
def remove_concept_from_list(selected_text, current_list):
    """
    Return exactly 2 values:
      1) The updated list of concepts (list of dicts).
      2) A Gradio update for the removal dropdown's choices.
    """
    if not selected_text:
        return current_list, gr.update(choices=_build_remove_choices(current_list))
    updated_list = [x for x in current_list if x["text"] != selected_text]
    return updated_list, gr.update(choices=_build_remove_choices(updated_list))

def _build_remove_choices(subspaces):
    return [x["text"] for x in subspaces]

def update_dropdown_choices(search_text):
    filtered = filter_concepts(search_text)
    return gr.update(choices=filtered)
with gr.Blocks(css="style.css") as demo:
    # Pre-populate with a default concept when the dictionary is available.
    default_subspaces = []
    default_concept = "words related to time travel and its consequences"
    if pv_model is not None and default_concept in concept_id_map:
        default_subspaces = [{
            "text": default_concept,
            "idx": concept_id_map[default_concept],
            "display_mag": 3,
            "internal_mag": 150.0,  # display_mag 3 * scale factor 50
        }]
    selected_subspaces = gr.State(default_subspaces)
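
    # gr.State holds the per-session list of selected concepts; the handlers
    # above return updated copies rather than mutating the list in place.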
    with gr.Row():
        # Left side: bigger chat area
        with gr.Column(scale=7):
            chat_interface = gr.ChatInterface(
                fn=generate,
                title="LM Steering with ReFT-r1 (16K concepts)",
                type="messages",
                additional_inputs=[selected_subspaces],
            )

        # Right side: concept management
        with gr.Column(scale=3):
            gr.Markdown("# Steering Concepts")
            search_box = gr.Textbox(
                label="Search concepts",
                placeholder="e.g. 'time travel'"
            )
            concept_dropdown = gr.Dropdown(
                label="Filtered Concepts",
                choices=[]
            )
            concept_magnitude = gr.Slider(
                label="Steering Factor",
                minimum=-5,
                maximum=5,
                step=1,
                value=3
            )
            add_button = gr.Button("Add Concept")

            # Row with the remove dropdown + button
            with gr.Row():
                remove_dropdown = gr.Dropdown(
                    label="Remove concept",
                    choices=_build_remove_choices(default_subspaces),
                    multiselect=False
                )
                remove_button = gr.Button("Remove", variant="secondary")
    # Wire up events.
    # When the search box changes, update the concept dropdown choices:
    search_box.change(
        update_dropdown_choices,
        [search_box],
        [concept_dropdown]
    )

    # When "Add Concept" is clicked, add the concept + magnitude to the list,
    # and update the "Remove" dropdown choices.
    add_button.click(
        add_concept_to_list,
        [concept_dropdown, concept_magnitude, selected_subspaces],
        [selected_subspaces, remove_dropdown]
    )

    # When "Remove" is clicked, remove the selected concept from the list,
    # and update the "Remove" dropdown choices.
    remove_button.click(
        remove_concept_from_list,
        [remove_dropdown, selected_subspaces],
        [selected_subspaces, remove_dropdown]
    )
demo.launch()