import os, json, random
import torch
import gradio as gr
import spaces
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from huggingface_hub import login, hf_hub_download
import pyreft
import pyvene as pv
from threading import Thread
from typing import Iterator
import torch.nn.functional as F

HF_TOKEN = os.environ.get("HF_TOKEN")
login(token=HF_TOKEN)

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 128  # smaller default to save memory
MAX_INPUT_TOKEN_LENGTH = 4096
css = """
#alert-message textarea {
    background-color: #e8f4ff;
    border: 1px solid #cce5ff;
    color: #084298;
    font-size: 1.1em;
    padding: 12px;
    border-radius: 4px;
    font-weight: 500;
}
.concept-help {
    font-size: 0.9em;
    color: #666;
    margin-top: 4px;
    font-style: italic;
}
"""
def load_jsonl(jsonl_path):
    jsonl_data = []
    with open(jsonl_path, 'r') as f:
        for line in f:
            data = json.loads(line)
            jsonl_data.append(data)
    return jsonl_data
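
# Activation steering: the intervention below adds a scaled concept vector v to
# the residual-stream activations h at one layer, i.e. h' = h + mag * v. The
# vector comes either from a pretrained dictionary (self.proj) or is generated
# on the fly from text by a subspace generator.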
class Steer(pv.SourcelessIntervention):
    """Steer model via activation addition."""
    def __init__(self, **kwargs):
        super().__init__(**kwargs, keep_last_dim=True)
        self.proj = torch.nn.Linear(
            self.embed_dim, kwargs["latent_dim"], bias=False)
        self.subspace_generator = kwargs["subspace_generator"]

    def steer(self, base, source=None, subspaces=None):
        if subspaces["steer"]["subspace_gen_inputs"] is not None:
            # Call the subspace generator to produce the steering vector on the fly.
            raw_steering_vec = self.subspace_generator(
                subspaces["steer"]["subspace_gen_inputs"]["input_ids"],
                subspaces["steer"]["subspace_gen_inputs"]["attention_mask"],
            )[0]
            steering_vec = torch.tensor(subspaces["steer"]["mag"]) * \
                raw_steering_vec.unsqueeze(dim=0)
            return base + steering_vec
        else:
            steering_vec = torch.tensor(subspaces["steer"]["mag"]) * \
                self.proj.weight[subspaces["steer"]["idx"]].unsqueeze(dim=0)
            return base + steering_vec
    def forward(self, base, source=None, subspaces=None):
        if subspaces is None:
            return base
        if subspaces["detect"] is not None:
            if subspaces["detect"]["subspace_gen_inputs"] is not None:
                # Call the subspace generator to produce the detection vector on the fly.
                raw_detection_vec = self.subspace_generator(
                    subspaces["detect"]["subspace_gen_inputs"]["input_ids"],
                    subspaces["detect"]["subspace_gen_inputs"]["attention_mask"],
                )[0].unsqueeze(dim=-1)
            else:
                raw_detection_vec = self.proj.weight[subspaces["detect"]["idx"]].unsqueeze(dim=-1)
            print(base.shape)
            print(raw_detection_vec.shape)
            # Project activations onto the detection vector: (batch, seq, 1) -> (batch, seq).
            detection_latent = torch.matmul(base, raw_detection_vec.to(base.dtype)).squeeze(dim=-1)
            # Max over sequence positions, then take the first batch element -> scalar.
            max_latent = torch.max(detection_latent, dim=-1).values[0]
            print("max_latent", max_latent)
            if max_latent > torch.tensor(subspaces["detect"]["mag"]):
                print("Detected!")
                return self.steer(base, source, subspaces)
            else:
                return base
        else:
            return self.steer(base, source, subspaces)
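
# The `subspaces` payload consumed by Steer.forward is assembled in generate()
# below. An illustrative example (hypothetical indices and magnitudes):
#   [{"detect": {"idx": 42, "mag": 25.0, "subspace_gen_inputs": None},
#     "steer":  {"idx": 7,  "mag": 175.0, "subspace_gen_inputs": None}}]
# "detect" gates the intervention on a threshold; "steer" adds the scaled vector.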
class RegressionWrapper(torch.nn.Module):
    def __init__(self, base_model, hidden_size, output_dim):
        super().__init__()
        self.base_model = base_model
        self.regression_head = torch.nn.Linear(hidden_size, output_dim)

    def forward(self, input_ids, attention_mask):
        outputs = self.base_model.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True
        )
        last_hiddens = outputs.hidden_states[-1]
        last_token_representations = last_hiddens[:, -1]
        preds = self.regression_head(last_token_representations)
        preds = F.normalize(preds, p=2, dim=-1)
        return preds
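
# RegressionWrapper is the subspace generator: it encodes a tokenized concept
# description with the base model and maps the final token's last hidden state
# through a linear head to a unit-norm vector, usable as a steering or
# detection direction.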
# Check GPU
if not torch.cuda.is_available():
    print("Warning: Running on CPU, may be slow.")

# Load model & dictionary
model_id = "google/gemma-2-2b-it"
pv_model = None
tokenizer = None
concept_list = []
concept_id_map = {}
if torch.cuda.is_available():
    model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map="cuda", torch_dtype=torch.bfloat16
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Download dictionary
    weight_path = hf_hub_download(repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/weight.pt")
    meta_path = hf_hub_download(repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/metadata.jsonl")
    params = torch.load(weight_path).cuda()
    md = load_jsonl(meta_path)
    concept_list = [item["concept"] for item in md]
    concept_id_map = {}
    # Reindex because one concept is missing from the dictionary.
    concept_reindex = 0
    for item in md:
        concept_id_map[item["concept"]] = concept_reindex
        concept_reindex += 1
    # Load subspace generator.
    base_tokenizer = AutoTokenizer.from_pretrained(
        "google/gemma-2-2b", model_max_length=512)
    config = AutoConfig.from_pretrained("google/gemma-2-2b")
    base_model = AutoModelForCausalLM.from_config(config)
    subspace_generator_weight_path = hf_hub_download(repo_id="pyvene/gemma-reft-2b-it-res-generator", filename="l20/weight.pt")
    hidden_size = base_model.config.hidden_size
    subspace_generator = RegressionWrapper(
        base_model, hidden_size, hidden_size).bfloat16().to("cuda")
    subspace_generator.load_state_dict(torch.load(subspace_generator_weight_path))
    print(f"Loading model from saved file {subspace_generator_weight_path}")
    _ = subspace_generator.eval()
    steer = Steer(
        embed_dim=params.shape[0], latent_dim=params.shape[1],
        subspace_generator=subspace_generator)
    steer.proj.weight.data = params.float()

    pv_model = pv.IntervenableModel({
        "component": "model.layers[20].output",
        "intervention": steer}, model=model)
terminators = [tokenizer.eos_token_id] if tokenizer else []

def generate(
    message: str,
    chat_history: list[dict],  # messages-format history: {"role", "content"} dicts
    detection_list: list[dict],
    steering_list: list[dict],
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
) -> Iterator[str]:
    # Keep only the last 4 messages of history.
    start_idx = max(0, len(chat_history) - 4)
    recent_history = chat_history[start_idx:]

    # Build the list of messages.
    messages = []
    for rh in recent_history:
        messages.append({"role": rh["role"], "content": rh["content"]})
    messages.append({"role": "user", "content": message})

    input_ids = torch.tensor([tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True)]).cuda()

    # Trim if needed.
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        yield "[Truncated prior text]\n"

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    print("detection_list: ", detection_list)
    print("steering_list: ", steering_list)
    generate_kwargs = {
        "base": {"input_ids": input_ids},
        "unit_locations": None,
        "max_new_tokens": max_new_tokens,
        "intervene_on_prompt": True,
        "subspaces": [
            {
                "detect": {
                    "idx": int(detection_list[0]["idx"]),
                    "mag": detection_list[0]["internal_mag"] * 50,
                    "subspace_gen_inputs": base_tokenizer(
                        detection_list[0]["subspace_gen_text"], return_tensors="pt"
                    ).to("cuda") if detection_list[0]["subspace_gen_text"] is not None else None
                } if detection_list else None,
                "steer": {
                    "idx": int(steering_list[0]["idx"]),
                    "mag": steering_list[0]["internal_mag"] * 50,
                    "subspace_gen_inputs": base_tokenizer(
                        steering_list[0]["subspace_gen_text"], return_tensors="pt"
                    ).to("cuda") if steering_list[0]["subspace_gen_text"] is not None else None
                }
            }
        ] if steering_list else None,  # if steering is not provided, we do not steer
        "streamer": streamer,
        "do_sample": True
    }
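    # Slider values are scaled x50 into internal magnitudes: for "detect", "mag"
    # is the threshold on the projected latent; for "steer", it scales the
    # concept vector added in Steer.steer above.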
    t = Thread(target=pv_model.generate, kwargs=generate_kwargs)
    t.start()

    # Generation runs in the background thread; the streamer yields decoded
    # tokens here so Gradio can render partial responses.
    partial_text = []
    for token_str in streamer:
        partial_text.append(token_str)
        yield "".join(partial_text)
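
# Illustrative call outside Gradio (hypothetical concept indices and magnitudes):
#   for partial in generate(
#           "List some search engines", [],
#           [{"idx": 42, "internal_mag": 0.5, "subspace_gen_text": None}],
#           [{"idx": 7, "internal_mag": 3.5, "subspace_gen_text": None}]):
#       print(partial)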
def filter_concepts(search_text: str):
    if not search_text.strip():
        return concept_list[:500]
    filtered = [c for c in concept_list if search_text.lower() in c.lower()]
    return filtered[:500]
def add_concept_to_list(selected_concept, user_slider_val, current_list):
    if not selected_concept:
        return current_list
    selected_concept_text = None
    if selected_concept.startswith("[New] "):
        # Strip the "[New] " prefix; these concepts go through the subspace generator.
        selected_concept_text = selected_concept[6:]
        idx = 0
    else:
        idx = concept_id_map[selected_concept]
    internal_mag = user_slider_val
    new_entry = {
        "text": selected_concept,
        "idx": idx,
        "display_mag": user_slider_val,
        "internal_mag": internal_mag,
        "subspace_gen_text": selected_concept_text
    }
    # Replace the list: only one concept is active at a time.
    current_list = [new_entry]
    return current_list
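
# Example entry (hypothetical) for a dictionary concept:
#   {"text": "Google", "idx": 42, "display_mag": 3.5,
#    "internal_mag": 3.5, "subspace_gen_text": None}
# For a "[New] ..." concept, idx is 0 and subspace_gen_text carries the raw text.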
def update_dropdown_choices(search_text, is_detection=False):
    filtered = filter_concepts(search_text)
    if not filtered:
        alert_message = (
            "Good news! Based on the topic you provided, we will automatically generate a detector for you!"
        ) if is_detection else (
            "Good news! Based on the topic you provided, we will automatically generate a steering vector. Try it out by starting a chat!"
        )
        return gr.update(
            choices=[],
            value=None,
            interactive=True
        ), gr.Textbox(
            label="No matching topics found",
            value=alert_message,
            lines=3,
            interactive=False,
            visible=True,
            elem_id="alert-message"
        )
    return gr.update(
        choices=filtered,
        value=filtered[0],
        interactive=True,
        visible=True
    ), gr.Textbox(visible=False)
with gr.Blocks(css=css, fill_height=True) as demo:
    selected_detection = gr.State([])
    selected_subspaces = gr.State([])

    with gr.Row(min_height=500, equal_height=True):
        # Left side: chat area
        with gr.Column(scale=7):
            chat_interface = gr.ChatInterface(
                fn=generate,
                type="messages",  # generate() expects role/content dicts
                title="Conditionally Steer AI Responses Based on Topics",
                description="""This is an experimental chatbot that you can steer using topics you care about:
Step 1: Choose a topic to detect (e.g., "Google").
Step 2: Choose a topic you want the model to discuss when the first topic comes up (e.g., "ethics").
Try it out! For example, set it to detect "Google" topics, steer toward discussing "ethics", and ask "List some search engines and their pros and cons". We intervene on Gemma-2-2B-it by adding steering vectors to the residual stream at layer 20.""",
                additional_inputs=[selected_detection, selected_subspaces],
                fill_height=True,
            )
        # Right side: concept detection and steering
        with gr.Column(scale=3):
            gr.Markdown("""#### Step 1: Choose a topic you want to recognize.""")
            with gr.Group():
                detect_search = gr.Textbox(
                    label="Search for topics to detect",
                    placeholder="Try: 'Google'",
                    lines=1,
                )
                detect_msg = gr.TextArea(visible=False)
                detect_dropdown = gr.Dropdown(
                    label="Choose a topic to detect (Click to see more!)",
                    interactive=True,
                    allow_custom_value=False,
                )
                detect_threshold = gr.Slider(
                    label="Detection sensitivity",
                    minimum=0,
                    maximum=1,
                    step=0.1,
                    value=0.5,
                )
            gr.Markdown("---")
            gr.Markdown("""#### Step 2: Choose another topic for the model to discuss when the topic above is detected.""")
            with gr.Group():
                search_box = gr.Textbox(
                    label="Search topics to steer",
                    placeholder="Try: 'ethics'",
                    lines=1,
                )
                msg = gr.TextArea(visible=False)
                concept_dropdown = gr.Dropdown(
                    label="Choose a topic to steer the model (Click to see more!)",
                    interactive=True,
                    allow_custom_value=False,
                )
                concept_magnitude = gr.Slider(
                    label="Steering intensity",
                    minimum=-5,
                    maximum=5,
                    step=0.1,
                    value=3.5,
                )
    # Wire up events for detection
    detect_search.input(
        lambda x: update_dropdown_choices(x, is_detection=True),
        [detect_search],
        [detect_dropdown, detect_msg]
    ).then(
        add_concept_to_list,
        [detect_dropdown, detect_threshold, selected_detection],
        [selected_detection]
    )
    detect_dropdown.select(
        add_concept_to_list,
        [detect_dropdown, detect_threshold, selected_detection],
        [selected_detection]
    )
    detect_threshold.input(
        add_concept_to_list,
        [detect_dropdown, detect_threshold, selected_detection],
        [selected_detection]
    )

    # Wire up events for steering
    search_box.input(
        lambda x: update_dropdown_choices(x, is_detection=False),
        [search_box],
        [concept_dropdown, msg]
    ).then(
        add_concept_to_list,
        [concept_dropdown, concept_magnitude, selected_subspaces],
        [selected_subspaces]
    )
    concept_dropdown.select(
        add_concept_to_list,
        [concept_dropdown, concept_magnitude, selected_subspaces],
        [selected_subspaces]
    )
    concept_magnitude.input(
        add_concept_to_list,
        [concept_dropdown, concept_magnitude, selected_subspaces],
        [selected_subspaces]
    )

demo.launch(share=True, height=1000)