import os
import json
import random
from threading import Thread
from typing import Iterator

import torch
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from huggingface_hub import login, hf_hub_download
import pyvene as pv

HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 512  # smaller default to save memory
MAX_INPUT_TOKEN_LENGTH = 4096


def load_jsonl(jsonl_path):
    jsonl_data = []
    with open(jsonl_path, "r") as f:
        for line in f:
            jsonl_data.append(json.loads(line))
    return jsonl_data


class Steer(pv.SourcelessIntervention):
    """Steering intervention: adds scaled dictionary directions to the residual stream."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs, keep_last_dim=True)
        self.proj = torch.nn.Linear(self.embed_dim, kwargs["latent_dim"], bias=False)

    def forward(self, base, source=None, subspaces=None):
        steer_vec = base
        if subspaces is not None:
            for sp in subspaces:
                idx = sp["idx"]
                mag = sp["internal_mag"]  # slider value scaled by 50
                steering_vec = mag * self.proj.weight[idx].unsqueeze(dim=0)
                steer_vec = steer_vec + steering_vec
        return steer_vec


# Check GPU
if not torch.cuda.is_available():
    print("Warning: Running on CPU, may be slow.")

# Load model & dictionary
model_id = "google/gemma-2-2b-it"
pv_model = None
tokenizer = None
concept_list = []
concept_id_map = {}

if torch.cuda.is_available():
    model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map="cuda", torch_dtype=torch.bfloat16
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Download dictionary
    weight_path = hf_hub_download(repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/weight.pt")
    meta_path = hf_hub_download(repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/metadata.jsonl")
    params = torch.load(weight_path).cuda()
    md = load_jsonl(meta_path)
    concept_list = [item["concept"] for item in md]
    concept_id_map = {item["concept"]: item["concept_id"] for item in md}

    steer = Steer(embed_dim=params.shape[0], latent_dim=params.shape[1])
    steer.proj.weight.data = params.float()
    pv_model = pv.IntervenableModel(
        {
            "component": "model.layers[20].output",
            "intervention": steer,
        },
        model=model,
    )

terminators = [tokenizer.eos_token_id] if tokenizer else []


@spaces.GPU
def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int,
    subspaces_list: list[dict],
) -> Iterator[str]:
    if pv_model is None or tokenizer is None:
        yield "Model is not loaded (no GPU available)."
        return

    # Limit to the last 3 exchanges (6 messages in the "messages" format).
    recent_history = chat_history[-6:]

    # Rebuild the message list; the Gemma chat template uses "model" for the
    # assistant role.
    messages = []
    for msg in recent_history:
        role = "model" if msg["role"] == "assistant" else msg["role"]
        messages.append({"role": role, "content": msg["content"]})
    messages.append({"role": "user", "content": message})

    input_ids = torch.tensor([tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True)]).cuda()

    # Trim if needed
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        yield "[Truncated prior text]\n"

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = {
        "base": {"input_ids": input_ids},
        "unit_locations": None,
        "max_new_tokens": max_new_tokens,
        "intervene_on_prompt": True,
        "subspaces": subspaces_list,
        "streamer": streamer,
        "eos_token_id": terminators,
        "early_stopping": True,
        "do_sample": True,
    }
    t = Thread(target=pv_model.generate, kwargs=generate_kwargs)
    t.start()

    partial_text = []
    for token_str in streamer:
        partial_text.append(token_str)
        yield "".join(partial_text)
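
# Optional smoke test for the steering path (an illustrative sketch, not part
# of the original app flow; the helper name and the prompt are arbitrary). It
# hand-builds a subspaces payload of the same shape that add_concept_to_list
# produces below, then drains the generator. Assumes the model and dictionary
# were loaded above, i.e. a GPU is available.
def _smoke_test_generate():
    if pv_model is None or not concept_list:
        return
    concept = concept_list[0]
    payload = [{
        "text": concept,
        "idx": concept_id_map[concept],
        "display_mag": 3,
        "internal_mag": 3 * 50,  # slider value scaled by 50
    }]
    final_text = ""
    for chunk in generate("Tell me a short story.", [], 64, payload):
        final_text = chunk  # each yield carries the full partial text
    print(f"[smoke test] steered sample for {concept!r}:\n{final_text}")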
def filter_concepts(search_text: str):
    if not search_text.strip():
        return concept_list[:500]
    filtered = [c for c in concept_list if search_text.lower() in c.lower()]
    return filtered[:500]


def add_concept_to_list(selected_concept, user_slider_val, current_list):
    if not selected_concept:
        return (
            current_list,
            _build_table_data(current_list),
            gr.update(choices=_build_remove_choices(current_list)),
        )
    idx = concept_id_map[selected_concept]
    internal_mag = user_slider_val * 50
    new_entry = {
        "text": selected_concept,
        "idx": idx,
        "display_mag": user_slider_val,
        "internal_mag": internal_mag,
    }
    updated_list = current_list + [new_entry]
    return (
        updated_list,
        _build_table_data(updated_list),
        gr.update(choices=_build_remove_choices(updated_list)),
    )


def remove_concept_from_list(selected_text, current_list):
    if not selected_text:
        return (
            current_list,
            _build_table_data(current_list),
            gr.update(choices=_build_remove_choices(current_list)),
        )
    updated_list = [x for x in current_list if x["text"] != selected_text]
    return (
        updated_list,
        _build_table_data(updated_list),
        gr.update(choices=_build_remove_choices(updated_list)),
    )


def _build_table_data(subspaces):
    return [[x["text"], x["display_mag"]] for x in subspaces]


def _build_remove_choices(subspaces):
    return [x["text"] for x in subspaces]


def update_dropdown_choices(search_text):
    filtered = filter_concepts(search_text)
    return gr.update(choices=filtered)
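
# Illustrative round trip through the state helpers above (assumes the chosen
# concept exists in concept_id_map; a slider value of 3 becomes an internal
# magnitude of 150):
#
#   state, table, choices = add_concept_to_list("time travel", 3, [])
#   state, table, choices = remove_concept_from_list("time travel", state)
#   # state == []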
with gr.Blocks(css="style.css") as demo:
    # A short title only
    gr.Markdown("## Model Steering with ReFT-r1 (16K concepts)")

    # Pre-populate with a random concept if available
    default_subspaces = []
    if pv_model and concept_list:
        default_concept = random.choice(concept_list)
        default_subspaces = [{
            "text": default_concept,
            "idx": concept_id_map[default_concept],
            "display_mag": 3,
            "internal_mag": 150.0,
        }]
    selected_subspaces = gr.State(default_subspaces)

    # Define the max-tokens slider up front with render=False so it can be
    # wired into the ChatInterface as an additional input; ChatInterface
    # renders un-rendered additional inputs in an accordion below the chat.
    max_token_slider = gr.Slider(
        minimum=1,
        maximum=MAX_MAX_NEW_TOKENS,
        step=1,
        value=DEFAULT_MAX_NEW_TOKENS,
        label="Max New Tokens",
        render=False,
    )

    with gr.Row():
        # Left side: bigger chat area
        with gr.Column(scale=7):
            chat_interface = gr.ChatInterface(
                fn=generate,
                additional_inputs=[max_token_slider, selected_subspaces],
                title="",
                type="messages",
            )

        # Right side: concept management
        with gr.Column(scale=3):
            gr.Markdown("### Steering Concepts")
            search_box = gr.Textbox(
                label="Search concepts",
                placeholder="e.g. 'time travel'",
            )
            concept_dropdown = gr.Dropdown(
                label="Filtered Concepts",
                choices=[],
            )
            concept_magnitude = gr.Slider(
                label="Steering Factor",
                minimum=-5,
                maximum=5,
                step=1,
                value=3,
            )
            add_button = gr.Button("Add Concept")
            active_subspaces_table = gr.Dataframe(
                headers=["Concept", "Magnitude"],
                datatype=["str", "number"],
                value=_build_table_data(default_subspaces),
                interactive=False,
                label="Active Concept Subspaces",
            )

            # Row with the remove dropdown + button
            with gr.Row():
                remove_dropdown = gr.Dropdown(
                    label="Remove concept",
                    choices=_build_remove_choices(default_subspaces),
                    multiselect=False,
                )
                remove_button = gr.Button("Remove", variant="secondary")

    # Wire up events
    search_box.change(update_dropdown_choices, [search_box], [concept_dropdown])
    add_button.click(
        add_concept_to_list,
        [concept_dropdown, concept_magnitude, selected_subspaces],
        [selected_subspaces, active_subspaces_table, remove_dropdown],
    )
    remove_button.click(
        remove_concept_from_list,
        [remove_dropdown, selected_subspaces],
        [selected_subspaces, active_subspaces_table, remove_dropdown],
    )

demo.launch()
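
# Assumed runtime dependencies for this Space (a sketch; pin to whatever the
# deployment actually uses): torch, transformers, gradio, spaces,
# huggingface_hub, pyvene. HF_TOKEN should be set in the environment so the
# gated Gemma weights can be downloaded.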