frankaging committed
Commit 1baa5c3 · 1 Parent(s): 7497e24
Files changed (1):
  1. app.py +79 -82
app.py CHANGED
@@ -45,20 +45,20 @@ class Steer(pv.SourcelessIntervention):
 
     def forward(self, base, source=None, subspaces=None):
         # subspaces is a list of dicts:
-        # each has {"idx": int, "internal_mag": float, ...}
+        # each has {"idx": int, "internal_mag": float, "text": str, ...}
         steer_vec = base
         if subspaces is not None:
             for sp in subspaces:
                 idx = sp["idx"]
-                # Use the internal magnitude for actual steering
-                mag = sp["internal_mag"]
+                mag = sp["internal_mag"]  # the true scaling factor
                 steering_vec = mag * self.proj.weight[idx].unsqueeze(dim=0)
                 steer_vec = steer_vec + steering_vec
         return steer_vec
 
-# ---------------------------------------------------
-# Load Model & Dictionary if GPU is available
-# ---------------------------------------------------
+
+# ------------------------------------------
+# Load the Model & Dictionary if GPU exists
+# ------------------------------------------
 if not torch.cuda.is_available():
     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo won't perform well on CPU.</p>"
 
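The forward pass above adds each selected dictionary row, scaled by its "internal_mag", onto the layer's activations. A minimal self-contained sketch of that update, outside pyvene; the names hidden, dictionary, subspaces and all shapes here are illustrative assumptions, not values from app.py:

import torch

hidden = torch.randn(1, 8, 768)        # (batch, seq, d_model) activations
dictionary = torch.randn(16000, 768)   # one steering direction per concept
subspaces = [{"idx": 42, "internal_mag": 150.0}]

steered = hidden
for sp in subspaces:
    direction = dictionary[sp["idx"]].unsqueeze(0)       # (1, d_model)
    steered = steered + sp["internal_mag"] * direction   # broadcasts over seq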
@@ -91,10 +91,9 @@ if torch.cuda.is_available():
     terminators = [tokenizer.eos_token_id]
 
 
-# ---------------------------------------------------------------------
-# The main generation function, limiting to last 3 conversation turns
-# and then using apply_chat_template
-# ---------------------------------------------------------------------
+# --------------------------------------------
+# Main generation function: keep last 3 turns
+# --------------------------------------------
 @spaces.GPU
 def generate(
     message: str,
@@ -107,31 +106,36 @@ def generate(
     start_idx = max(0, len(chat_history) - 3)
     recent_history = chat_history[start_idx:]
 
-    # Build a list of messages
-    # each tuple is (user_message, model_message)
+    # Convert each (user_msg, model_msg) tuple into chat messages
     messages = []
     for user_msg, model_msg in recent_history:
         messages.append({"role": "user", "content": user_msg})
         messages.append({"role": "model", "content": model_msg})
 
-    # Now append the new user message
+    # Add the new user message
     messages.append({"role": "user", "content": message})
 
-    input_ids = torch.tensor([tokenizer.apply_chat_template(
-        messages, tokenize=True, add_generation_prompt=True)]).cuda()
+    # Apply the chat template; return_dict=True returns input_ids plus
+    # attention_mask (some templates use role "assistant" instead of "model").
+    prompt_dict = tokenizer.apply_chat_template(
+        messages, tokenize=True, add_generation_prompt=True, return_dict=True
+    )
+    input_ids = torch.tensor([prompt_dict["input_ids"]]).cuda()
+    attention_mask = torch.tensor([prompt_dict["attention_mask"]]).cuda()
 
-    # Possibly trim if over max length
+    # Trim from the left if the prompt is too long
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+        attention_mask = attention_mask[:, -MAX_INPUT_TOKEN_LENGTH:]
         yield "\n[Warning: Truncated conversation exceeds max allowed input tokens]\n"
 
     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = {
-        "base": {"input_ids": input_ids},
+        "base": {"input_ids": input_ids, "attention_mask": attention_mask},
         "unit_locations": None,
         "max_new_tokens": max_new_tokens,
         "intervene_on_prompt": True,
-        "subspaces": subspaces_list,  # pass entire structure, using "internal_mag"
+        "subspaces": subspaces_list,
         "streamer": streamer,
         "eos_token_id": terminators,
         "early_stopping": True,
@@ -147,11 +151,11 @@ def generate(
     yield "".join(partial_text)
 
 
-# --------------
+# ----------------
 # UI Callbacks
-# --------------
+# ----------------
 def filter_concepts(search_text: str):
-    """Return the first ~500 concepts that match (case-insensitive)."""
+    """Return the first 500 concepts that match (case-insensitive)."""
     if not search_text.strip():
         return concept_list[:500]
     filtered = [c for c in concept_list if search_text.lower() in c.lower()]
@@ -159,87 +163,75 @@ def filter_concepts(search_text: str):
 
 def add_concept_to_list(selected_concept, user_slider_val, current_list):
     """
-    When 'Add Concept' is clicked, add the chosen concept with the
-    scaled magnitude to the subspaces list.
-
-    user_slider_val is from [-5..5], we multiply by 50 internally.
+    user_slider_val is from [-5..5]. We multiply by 50 internally to get the real magnitude.
     """
     if not selected_concept:
-        return current_list, _build_table_data(current_list)
+        return current_list, _build_table_data(current_list), gr.update(choices=_build_remove_choices(current_list))
 
     concept_idx = concept_id_map[selected_concept]
-
-    # Multiply slider by 50 internally
-    internal_mag = user_slider_val * 50
-
-    # We'll store both displayed magnitude (for the table) and the internal
-    # magnitude for the model. Also store 'text' for easy display.
+    internal_mag = user_slider_val * 50  # scale by 50
     new_entry = {
         "text": selected_concept,
         "idx": concept_idx,
         "display_mag": user_slider_val,
         "internal_mag": internal_mag,
     }
-
-    # Avoid duplicates if you prefer:
-    # e.g. check if concept_idx already in current_list. We'll skip that for now.
     updated_list = current_list + [new_entry]
-    return updated_list, _build_table_data(updated_list)
+    return (
+        updated_list,
+        _build_table_data(updated_list),
+        gr.update(choices=_build_remove_choices(updated_list))
+    )
 
-def remove_selected_row(selected_rows, current_list):
+def remove_concept_from_list(concept_to_remove, current_list):
     """
-    Removes the row selected from the table.
-    selected_rows is a list of selected row indices,
-    e.g. [1] meaning row with index 1 is selected.
+    Remove the chosen concept name from the list.
     """
-    if not selected_rows:
-        return current_list, _build_table_data(current_list)
-    row_idx = selected_rows[0]  # single selection
-    # Safely remove if in range
-    if 0 <= row_idx < len(current_list):
-        updated_list = current_list[:row_idx] + current_list[row_idx+1:]
-        return updated_list, _build_table_data(updated_list)
-    else:
-        return current_list, _build_table_data(current_list)
+    if not concept_to_remove:
+        return current_list, _build_table_data(current_list), gr.update(choices=_build_remove_choices(current_list))
+
+    updated_list = [x for x in current_list if x["text"] != concept_to_remove]
+    return (
+        updated_list,
+        _build_table_data(updated_list),
+        gr.update(choices=_build_remove_choices(updated_list))
+    )
 
 def _build_table_data(subspaces):
-    """
-    Build a list of [concept_text, display_mag] to show in the table.
-    """
+    """Return [[concept_name, scaled_mag], ...] for display."""
     return [[x["text"], x["display_mag"]] for x in subspaces]
 
+def _build_remove_choices(subspaces):
+    """Return concept names for the remove dropdown."""
+    return [x["text"] for x in subspaces]
+
 def update_dropdown_choices(search_text):
     filtered = filter_concepts(search_text)
     return gr.update(choices=filtered)
 
-
-# -------------------------
-# Build the Gradio Blocks
-# -------------------------
+# --------------------------------------------------------------------
+# Build the Interface
+# --------------------------------------------------------------------
 with gr.Blocks(css="style.css") as demo:
     gr.Markdown(DESCRIPTION)
     gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
 
-    # If GPU is available, define a random default concept:
+    # If GPU is available, pick a random concept as default
     default_subspaces = []
     if torch.cuda.is_available() and len(concept_list) > 0:
-        default_index = random.randint(0, len(concept_list) - 1)
-        default_concept = concept_list[default_index]
-        default_concept_idx = concept_id_map[default_concept]
-        # default slider is 3 => 3*50=150 internally
+        default_concept = random.choice(concept_list)
         default_subspaces = [{
             "text": default_concept,
-            "idx": default_concept_idx,
-            "display_mag": 3,       # what user sees
-            "internal_mag": 150.0,  # actual scaling
+            "idx": concept_id_map[default_concept],
+            "display_mag": 3,       # user sees 3
+            "internal_mag": 150.0,  # actual factor
         }]
 
-    # Keep state of subspaces
     selected_subspaces = gr.State(default_subspaces)
 
     with gr.Row():
-        # Left column: Chat
        with gr.Column(scale=5):
+            # Use type="messages" to avoid tuple-format deprecation warnings
            chat_interface = gr.ChatInterface(
                fn=generate,
                additional_inputs=[
@@ -250,14 +242,14 @@ with gr.Blocks(css="style.css") as demo:
                         step=1,
                         value=DEFAULT_MAX_NEW_TOKENS,
                     ),
-                    selected_subspaces  # pass the entire subspaces list
+                    selected_subspaces
                 ],
                 title="Model Steering with ReFT-r1 (16K concepts)",
+                type="messages",  # <--- uses openai-style 'role' and 'content'
             )
-
-        # Right column: concept searching, adding, table display, removal
         with gr.Column(scale=4):
             gr.Markdown("## Steering Concepts")
+
             search_box = gr.Textbox(
                 label="Search concepts",
                 placeholder="Type text to filter concepts (e.g. 'sports')"
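One caveat on type="messages": gr.ChatInterface then passes chat_history to fn as a list of {"role": ..., "content": ...} dicts, while generate above still unpacks (user_msg, model_msg) tuples. A hedged sketch of a guard that accepts both formats; normalize_history is a hypothetical helper, not part of this commit:

def normalize_history(chat_history):
    messages = []
    for turn in chat_history:
        if isinstance(turn, dict):  # type="messages" format
            messages.append({"role": turn["role"], "content": turn["content"]})
        else:  # legacy (user, assistant) tuple format
            user_msg, model_msg = turn
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "model", "content": model_msg})
    return messages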
@@ -268,30 +260,35 @@ with gr.Blocks(css="style.css") as demo:
                 multiselect=False
             )
             concept_magnitude = gr.Slider(
-                label="Scaled Magnitude (multiplies by 50 internally)",
+                label="Scaled Magnitude (×50 internally)",
                 minimum=-5,
                 maximum=5,
-                step=1.0,
+                step=1,
                 value=3
             )
             add_button = gr.Button("Add Concept")
 
-            # Current subspaces table
+            # Show the table of active subspaces
             active_subspaces_table = gr.Dataframe(
                 headers=["Concept", "Magnitude (scaled)"],
                 datatype=["str", "number"],
+                value=_build_table_data(default_subspaces),
                 interactive=False,
-                row_selectable="single",
-                label="Active Concept Subspaces",
-                value=_build_table_data(default_subspaces)
+                label="Active Concept Subspaces"
             )
 
-            remove_button = gr.Button("Remove Selected Row")
+            # Remove concept by name
+            remove_dropdown = gr.Dropdown(
+                label="Remove a concept",
+                choices=_build_remove_choices(default_subspaces),
+                multiselect=False
+            )
+            remove_button = gr.Button("Remove Selected Concept")
 
     gr.Markdown(LICENSE)
 
     # Wire up events
-    # Whenever user types in search_box, update concept_dropdown
+    # Update concept dropdown when user types in search
     search_box.change(
         fn=update_dropdown_choices,
         inputs=[search_box],
@@ -302,14 +299,14 @@ with gr.Blocks(css="style.css") as demo:
     add_button.click(
         fn=add_concept_to_list,
         inputs=[concept_dropdown, concept_magnitude, selected_subspaces],
-        outputs=[selected_subspaces, active_subspaces_table],
+        outputs=[selected_subspaces, active_subspaces_table, remove_dropdown],
     )
 
-    # Remove selected row from table
+    # Remove a concept
     remove_button.click(
-        fn=remove_selected_row,
-        inputs=[active_subspaces_table, selected_subspaces],
-        outputs=[selected_subspaces, active_subspaces_table],
+        fn=remove_concept_from_list,
+        inputs=[remove_dropdown, selected_subspaces],
+        outputs=[selected_subspaces, active_subspaces_table, remove_dropdown],
     )
 
 demo.queue(max_size=20).launch()
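Because remove_concept_from_list filters on the "text" field, one click removes every entry sharing that concept name, including a concept added twice at different magnitudes. Illustrative data only:

current = [
    {"text": "sports", "idx": 7, "display_mag": 3, "internal_mag": 150.0},
    {"text": "music", "idx": 9, "display_mag": 2, "internal_mag": 100.0},
    {"text": "sports", "idx": 7, "display_mag": -1, "internal_mag": -50.0},
]
updated = [x for x in current if x["text"] != "sports"]
assert [x["text"] for x in updated] == ["music"]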
 