Spaces:
Running
on
Zero
Running
on
Zero
File size: 8,529 Bytes
7497e24 f860e61 0e90065 f860e61 884bfb5 330e95b f860e61 0e90065 f860e61 0e90065 3d0e95b e39562b 0e90065 330e95b f860e61 330e95b e3ab52c 330e95b e3ab52c bddba98 f860e61 1644e6b bb5c56b 60cfe00 bb5c56b 60cfe00 e3ab52c 330e95b e39562b 0e90065 e39562b 0e90065 330e95b e39562b f860e61 330e95b 7f3db95 330e95b e3ab52c 0e90065 e39562b 0e90065 f860e61 884bfb5 0e90065 3cb38a4 de8f900 f860e61 e39562b f860e61 7f543e6 f860e61 e39562b f860e61 e39562b 0e90065 e39562b f860e61 0e90065 1644e6b 0e90065 e39562b 330e95b 0e90065 bddba98 bb5c56b bddba98 1644e6b 0e90065 330e95b 0e90065 f860e61 3d0e95b f860e61 7497e24 f860e61 bddba98 e39562b 7497e24 e39562b 7497e24 3d0e95b bddba98 7497e24 e39562b bddba98 3d0e95b bddba98 f860e61 0e90065 3d0e95b f860e61 e39562b 7497e24 3d0e95b e39562b 4a08f06 7497e24 3d0e95b e39562b 3d0e95b 884bfb5 3d0e95b e39562b 3d0e95b e39562b 3d0e95b e39562b f860e61 bddba98 f860e61 e39562b 884bfb5 f860e61 bddba98 f860e61 e39562b 884bfb5 e39562b 3d0e95b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 |
import os, json, random
import torch
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from huggingface_hub import login, hf_hub_download
import pyreft
import pyvene as pv
from threading import Thread
from typing import Iterator
# Hugging Face access token, read from the environment (None when unset).
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    # Authenticate only when a token is configured: login(token=None) falls
    # back to an interactive prompt, which a headless app cannot answer.
    login(token=HF_TOKEN)
else:
    print("Warning: HF_TOKEN is not set; skipping Hugging Face login.")

# Generation limits.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 256  # smaller default to save memory
MAX_INPUT_TOKEN_LENGTH = 4096  # prompts longer than this are trimmed in generate()
def load_jsonl(jsonl_path):
    """Read a JSON Lines file and return its records as a list.

    Args:
        jsonl_path: Path of the file to read; every non-blank line must hold
            one JSON document.

    Returns:
        A list with one decoded object per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Decode as UTF-8 explicitly instead of relying on the platform default
    # encoding, and skip blank lines (such as a trailing empty line), which
    # json.loads would otherwise reject.
    with open(jsonl_path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
class Steer(pv.SourcelessIntervention):
    """Intervention that adds a scaled concept direction to the activations.

    Concept directions are read from the rows of ``self.proj.weight``.
    """

    def __init__(self, **kwargs):
        """Initialise the parent class and create the bias-free projection.

        ``kwargs`` is forwarded to the parent together with
        ``keep_last_dim=True`` and must contain ``latent_dim``.
        """
        super().__init__(**kwargs, keep_last_dim=True)
        self.proj = torch.nn.Linear(self.embed_dim, kwargs["latent_dim"], bias=False)

    def forward(self, base, source=None, subspaces=None):
        """Return ``base`` plus the steering vector described by ``subspaces``.

        ``subspaces`` is a mapping with an ``"idx"`` sequence (rows of
        ``self.proj.weight`` to use) and a ``"mag"`` sequence (their
        magnitudes). The selected rows are averaged and the result is scaled
        by the mean magnitude. ``base`` is returned unchanged when
        ``subspaces`` is None; ``source`` is not used.
        """
        if subspaces is None:
            return base

        magnitudes = subspaces["mag"]
        mean_magnitude = sum(magnitudes) / len(magnitudes)

        # Pairing indices with magnitudes caps the rows used at the shorter
        # of the two sequences.
        rows = [
            self.proj.weight[row_idx]
            for row_idx, _ in zip(subspaces["idx"], magnitudes)
        ]
        direction = torch.stack(rows, dim=0).mean(dim=0)
        return base + mean_magnitude * direction
# Module-level state. Without a CUDA device nothing is loaded, so the model
# wrapper and tokenizer stay None and the concept tables stay empty.
model_id = "google/gemma-2-2b-it"
pv_model = None
tokenizer = None
concept_list = []
concept_id_map = {}

if torch.cuda.is_available():
    # Chat model in bfloat16 on the GPU, plus its tokenizer.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="cuda",
        torch_dtype=torch.bfloat16,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Concept dictionary: a weight tensor and one metadata record per concept.
    dictionary_repo = "pyvene/gemma-reft-2b-it-res"
    weight_path = hf_hub_download(repo_id=dictionary_repo, filename="l20/weight.pt")
    meta_path = hf_hub_download(repo_id=dictionary_repo, filename="l20/metadata.jsonl")
    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # torch version; this relies on the repository above being trusted.
    params = torch.load(weight_path).cuda()
    md = load_jsonl(meta_path)

    concept_list = [record["concept"] for record in md]
    # Concepts are numbered by their position in the metadata file; the
    # original author's note says this re-indexing is needed because one
    # concept is missing.
    concept_id_map = {
        record["concept"]: position for position, record in enumerate(md)
    }

    # Wrap the model so that the steering intervention acts on the output of
    # layer 20, using the downloaded weights (cast to float32) as directions.
    steer = Steer(embed_dim=params.shape[0], latent_dim=params.shape[1])
    steer.proj.weight.data = params.float()
    pv_model = pv.IntervenableModel(
        {"component": "model.layers[20].output", "intervention": steer},
        model=model,
    )
else:
    print("Warning: Running on CPU, may be slow.")

terminators = [tokenizer.eos_token_id] if tokenizer else []
@spaces.GPU
def generate(
    message: str,
    chat_history: list[dict],
    subspaces_list: list[dict],
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
) -> Iterator[str]:
    """Stream a steered reply to ``message``.

    Args:
        message: The new user message.
        chat_history: Earlier conversation as dicts with ``"role"`` and
            ``"content"`` keys; only the last four entries are sent to the
            model.
        subspaces_list: Active steering entries, each with an ``"idx"`` and
            an ``"internal_mag"`` key. An empty list disables steering.
        max_new_tokens: Maximum number of tokens to generate.

    Yields:
        The reply accumulated so far, growing as text is streamed. A
        truncation notice is yielded first when the prompt had to be trimmed.

    Raises:
        gr.Error: If the model or tokenizer was never loaded.
    """
    if pv_model is None or tokenizer is None:
        # The module only loads the model when CUDA is available; report that
        # clearly instead of failing with an AttributeError on None.
        raise gr.Error("Model is not loaded: a CUDA GPU is required.")

    # Keep only the four most recent history messages, then add the new one.
    messages = [
        {"role": past["role"], "content": past["content"]}
        for past in chat_history[-4:]
    ]
    messages.append({"role": "user", "content": message})

    input_ids = torch.tensor(
        [
            tokenizer.apply_chat_template(
                messages, tokenize=True, add_generation_prompt=True
            )
        ]
    ).cuda()

    # Trim from the left so that the most recent tokens are the ones kept.
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        yield "[Truncated prior text]\n"

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    print(subspaces_list)

    subspaces = None
    if subspaces_list:
        subspaces = [
            {
                "idx": [int(entry["idx"]) for entry in subspaces_list],
                # Round instead of truncating: the magnitude is a float
                # product (e.g. 2.3 * 50 == 114.99999999999999), which int()
                # would turn into 114 rather than 115.
                "mag": [
                    round(float(entry["internal_mag"])) for entry in subspaces_list
                ],
            }
        ]

    generate_kwargs = {
        "base": {"input_ids": input_ids},
        "unit_locations": None,
        "max_new_tokens": max_new_tokens,
        "intervene_on_prompt": True,
        "subspaces": subspaces,
        "streamer": streamer,
        "do_sample": True,
    }

    # Generation runs in a worker thread so this generator can relay text
    # from the streamer while it is being produced.
    worker = Thread(target=pv_model.generate, kwargs=generate_kwargs)
    worker.start()

    partial_text = []
    for token_str in streamer:
        partial_text.append(token_str)
        yield "".join(partial_text)

    # The streamer is exhausted once generation ends; wait for the worker so
    # it has fully finished before this function returns.
    worker.join()
def _format_choice(entry):
    """Return the dropdown label of one steering entry, e.g. ``(+3.0*) text``."""
    # The "+" format flag prints the sign for either polarity, so a negative
    # magnitude renders as "(-3.0*)" rather than "(+-3.0*)".
    return f"({entry['display_mag']:+.1f}*) {entry['text']}"


def _build_remove_choices(subspaces):
    """Return the dropdown labels of all steering entries, in list order."""
    return [_format_choice(entry) for entry in subspaces]


def filter_concepts(search_text: str):
    """Return up to 500 concepts whose name contains ``search_text``.

    Matching is case-insensitive. A blank query returns the first 500
    entries of ``concept_list``.
    """
    if not search_text.strip():
        return concept_list[:500]
    needle = search_text.lower()  # lower-case the query once, not per concept
    return [concept for concept in concept_list if needle in concept.lower()][:500]


def add_concept_to_list(selected_concept, user_slider_val, current_list):
    """Prepend a steering entry for ``selected_concept`` to ``current_list``.

    Args:
        selected_concept: Concept name to add; a falsy value leaves the list
            unchanged.
        user_slider_val: Steering intensity chosen by the user.
        current_list: Steering entries selected so far (not mutated).

    Returns:
        A tuple of the resulting list and a Gradio update carrying the
        refreshed labels for the removal dropdown.

    Raises:
        KeyError: If ``selected_concept`` is not in ``concept_id_map``.
    """
    if not selected_concept:
        return current_list, gr.update(choices=_build_remove_choices(current_list))

    new_entry = {
        "text": selected_concept,
        "idx": concept_id_map[selected_concept],
        "display_mag": user_slider_val,
        # The magnitude applied to the model is 50x the slider value.
        "internal_mag": user_slider_val * 50,
    }
    # Newest entry goes first.
    updated_list = [new_entry] + current_list
    return updated_list, gr.update(choices=_build_remove_choices(updated_list))


def remove_concept_from_list(selected_text, current_list):
    """Remove every steering entry whose label equals ``selected_text``.

    Args:
        selected_text: Label as produced by ``_build_remove_choices``; a
            falsy value leaves the list unchanged.
        current_list: Steering entries selected so far (not mutated).

    Returns:
        A tuple of the resulting list and a Gradio update for the removal
        dropdown.
    """
    if not selected_text:
        return current_list, gr.update(choices=_build_remove_choices(current_list))

    # Compare against the same formatter that produced the dropdown labels so
    # the two can never drift apart.
    updated_list = [
        entry for entry in current_list if _format_choice(entry) != selected_text
    ]
    # Clear the selection as well: the removed label is no longer a choice.
    return updated_list, gr.update(
        choices=_build_remove_choices(updated_list), value=None
    )
def update_dropdown_choices(search_text):
    """Return a Gradio update whose choices are the concepts matching ``search_text``."""
    return gr.update(choices=filter_concepts(search_text))
# Build the Gradio UI: a chat panel plus controls for choosing which concepts
# steer the model, then start the app.
with gr.Blocks(css="style.css") as demo:
# Steering entries chosen so far; starts empty (no default subspaces).
selected_subspaces = gr.State([])
with gr.Row():
# Left side: bigger chat area
with gr.Column(scale=7):
# generate() receives selected_subspaces as its additional input.
chat_interface = gr.ChatInterface(
fn=generate,
title="Language Model Concept Steering",
description="Steer responses by selecting concepts on the right →",
type="messages",
additional_inputs=[selected_subspaces],
)
# Right side: concept management
with gr.Column(scale=3):
gr.Markdown("## Steer Model Responses")
# Concept Search and Selection
with gr.Group():
search_box = gr.Textbox(
label="Search Concepts",
placeholder="Find concepts to steer the model (e.g. 'time travel')",
)
# No initial choices are given here; they are supplied by the search box
# change handler wired up below.
concept_dropdown = gr.Dropdown(
label="Select a Concept",
interactive=True,
)
# add_concept_to_list multiplies this value by 50 to get the magnitude
# that is passed on for steering.
concept_magnitude = gr.Slider(
label="Steering Intensity",
minimum=-5,
maximum=5,
step=0.1, # Allow 1 decimal point
value=3,
)
add_button = gr.Button("Add Concept to Steering")
# Current Steering Concepts
gr.Markdown("## Current Steering Concepts")
with gr.Group():
remove_dropdown = gr.Dropdown(
label="Select a Current Steering Concept to Stop",
choices=[],
multiselect=False,
)
remove_button = gr.Button("Remove Current Steering Concept", variant="secondary")
# Wire up events
# When the search box changes, update the concept dropdown choices:
search_box.change(
update_dropdown_choices,
[search_box],
[concept_dropdown]
)
# When "Add Concept" is clicked, add the concept + magnitude to the list,
# and update the "Remove" dropdown choices.
add_button.click(
add_concept_to_list,
[concept_dropdown, concept_magnitude, selected_subspaces],
[selected_subspaces, remove_dropdown]
)
# When "Remove" is clicked, remove the selected concept from the list,
# and update the "Remove" dropdown choices.
remove_button.click(
remove_concept_from_list,
[remove_dropdown, selected_subspaces],
[selected_subspaces, remove_dropdown]
)
# Start the app; share=True requests a public share link from Gradio.
demo.launch(share=True)
|