cyberandy commited on
Commit
285af3f
·
verified ·
1 Parent(s): 321a1b2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -117
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import gradio as gr
2
  import requests
3
  from typing import Dict, List, Tuple
4
- import json
5
 
6
  def get_features(text: str) -> Dict:
7
  """Get neural features from the API."""
@@ -23,77 +22,77 @@ def get_features(text: str) -> Dict:
23
  except Exception as e:
24
  return None
25
 
26
- def format_features(features_data: Dict, expanded_tokens: List[str], selected_feature: Dict) -> str:
27
- """Format features as HTML with expanded state."""
28
- if not features_data or 'results' not in features_data:
29
- return ""
 
 
 
 
30
 
31
- output = ['<div class="p-6">']
 
 
 
32
 
33
- # Process each token's features
 
34
  for result in features_data['results']:
35
  if result['token'] == '<bos>':
36
  continue
37
-
38
  token = result['token']
39
  features = result['top_features']
40
- is_expanded = token in expanded_tokens
41
- feature_count = len(features) if is_expanded else min(3, len(features))
42
-
43
- output.append(f'<div class="mb-8"><h2 class="text-xl font-bold mb-4">Token: {token}</h2>')
44
 
45
- # Display features
46
- for idx in range(feature_count):
47
- feature = features[idx]
48
- feature_id = feature['feature_index']
49
- activation = feature['activation_value']
50
- is_selected = selected_feature and selected_feature.get('feature_id') == feature_id
51
 
52
- selected_class = "border-blue-500 border-2" if is_selected else ""
53
-
54
- output.append(f"""
55
- <div class="feature-card p-4 rounded-lg mb-4 hover:border-blue-500 {selected_class}">
56
- <div class="flex justify-between items-center">
57
- <div>
58
- <span class="font-semibold">Feature {feature_id}</span>
59
- <span class="ml-2 text-gray-600">(Activation: {activation:.2f})</span>
60
- </div>
61
- </div>
62
- </div>
63
- """)
64
-
65
- # Show more/less button if needed
66
  if len(features) > 3:
67
- action = "less" if is_expanded else f"{len(features) - 3} more"
68
- output.append(f"""
69
- <div class="text-center mb-4">
70
- <button class="text-blue-600 hover:text-blue-800 text-sm"
71
- onclick="gradio('toggle_expansion', '{token}')">
72
- Show {action} features
73
- </button>
74
- </div>
75
- """)
76
-
77
- output.append('</div>')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
- output.append('</div>')
80
- return "\n".join(output)
81
 
82
- def format_dashboard(feature: Dict) -> str:
83
- """Format the feature dashboard."""
84
  if not feature:
85
  return ""
86
 
87
- feature_id = feature['feature_id']
88
- activation = feature['activation']
89
-
90
  return f"""
91
  <div class="dashboard-container p-4">
92
- <h3 class="text-lg font-semibold mb-4">
93
- Feature {feature_id} Dashboard (Activation: {activation:.2f})
94
- </h3>
95
  <iframe
96
- src="https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/{feature_id}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
97
  width="100%"
98
  height="600"
99
  frameborder="0"
@@ -102,47 +101,31 @@ def format_dashboard(feature: Dict) -> str:
102
  </div>
103
  """
104
 
105
- def analyze_features(text: str, state: Dict) -> Tuple[str, str, Dict]:
106
- """Process text and update state."""
107
- if not text:
108
- return "", "", state
109
-
110
- features_data = get_features(text)
111
- if not features_data:
112
- return "Error analyzing text", "", state
113
-
114
- # Update state
115
- state['features_data'] = features_data
116
- if not state.get('expanded_tokens'):
117
- state['expanded_tokens'] = []
118
-
119
- # Select first feature by default if none selected
120
- if not state.get('selected_feature'):
121
- for result in features_data['results']:
122
- if result['token'] != '<bos>' and result['top_features']:
123
- first_feature = result['top_features'][0]
124
  state['selected_feature'] = {
125
- 'feature_id': first_feature['feature_index'],
126
- 'activation': first_feature['activation_value']
127
  }
128
- break
129
-
130
- features_html = format_features(features_data, state['expanded_tokens'], state['selected_feature'])
131
- dashboard_html = format_dashboard(state['selected_feature'])
132
-
133
- return features_html, dashboard_html, state
134
 
135
- def toggle_expansion(token: str, state: Dict) -> Tuple[str, str, Dict]:
136
- """Toggle expansion state for a token."""
137
- if token in state['expanded_tokens']:
138
  state['expanded_tokens'].remove(token)
139
  else:
140
  state['expanded_tokens'].append(token)
141
-
142
- features_html = format_features(state['features_data'], state['expanded_tokens'], state['selected_feature'])
143
- dashboard_html = format_dashboard(state['selected_feature'])
144
-
145
- return features_html, dashboard_html, state
146
 
147
  css = """
148
  @import url('https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;600;700&display=swap');
@@ -151,16 +134,6 @@ body {
151
  font-family: 'Open Sans', sans-serif !important;
152
  }
153
 
154
- .feature-card {
155
- border: 1px solid #e0e5ff;
156
- background-color: #ffffff;
157
- transition: all 0.2s ease;
158
- }
159
-
160
- .feature-card:hover {
161
- box-shadow: 0 2px 4px rgba(52, 82, 219, 0.1);
162
- }
163
-
164
  .dashboard-container {
165
  border: 1px solid #e0e5ff;
166
  border-radius: 8px;
@@ -179,12 +152,7 @@ theme = gr.themes.Soft(
179
  )
180
 
181
  def create_interface():
182
- # Initialize state
183
- state = gr.State({
184
- 'features_data': None,
185
- 'expanded_tokens': [],
186
- 'selected_feature': None
187
- })
188
 
189
  with gr.Blocks(theme=theme, css=css) as interface:
190
  gr.Markdown("# Neural Feature Analyzer", elem_classes="text-2xl font-bold mb-2")
@@ -204,22 +172,14 @@ def create_interface():
204
  )
205
 
206
  with gr.Column(scale=2):
207
- features_html = gr.HTML()
208
- dashboard_html = gr.HTML()
209
 
210
- # Event handlers
211
  analyze_btn.click(
212
- fn=analyze_features,
213
  inputs=[input_text, state],
214
- outputs=[features_html, dashboard_html, state]
215
  )
216
-
217
- # Custom JavaScript function for token expansion
218
- interface.load(None, None, None, _js="""
219
- function toggle_expansion(token) {
220
- // Function will be called from HTML onclick
221
- }
222
- """)
223
 
224
  return interface
225
 
 
1
  import gradio as gr
2
  import requests
3
  from typing import Dict, List, Tuple
 
4
 
5
  def get_features(text: str) -> Dict:
6
  """Get neural features from the API."""
 
22
  except Exception as e:
23
  return None
24
 
25
+ def process_features(text: str, state: Dict) -> Tuple[gr.Tabs, gr.HTML, Dict]:
26
+ """Process features and return UI components."""
27
+ if not text:
28
+ return None, None, state
29
+
30
+ features_data = get_features(text)
31
+ if not features_data:
32
+ return None, None, state
33
 
34
+ # Update state with new features data
35
+ state['features_data'] = features_data
36
+ if 'expanded_tokens' not in state:
37
+ state['expanded_tokens'] = []
38
 
39
+ # Create tabs for each token
40
+ tokens_data = []
41
  for result in features_data['results']:
42
  if result['token'] == '<bos>':
43
  continue
44
+
45
  token = result['token']
46
  features = result['top_features']
47
+ is_expanded = token in state.get('expanded_tokens', [])
48
+ num_features = len(features) if is_expanded else min(3, len(features))
 
 
49
 
50
+ feature_list = []
51
+ for feature in features[:num_features]:
52
+ feature_list.append(gr.Button(
53
+ f"Feature {feature['feature_index']} (Activation: {feature['activation_value']:.2f})",
54
+ variant="secondary"
55
+ ))
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  if len(features) > 3:
58
+ expand_btn = gr.Button(
59
+ f"Show {'less' if is_expanded else 'more'} features",
60
+ variant="secondary",
61
+ size="sm"
62
+ )
63
+ feature_list.append(expand_btn)
64
+
65
+ tokens_data.append((token, feature_list))
66
+
67
+ # Select first feature for dashboard
68
+ if tokens_data and 'selected_feature' not in state:
69
+ first_token = tokens_data[0][0]
70
+ first_feature = features_data['results'][0]['top_features'][0]
71
+ state['selected_feature'] = {
72
+ 'feature_id': first_feature['feature_index'],
73
+ 'activation': first_feature['activation_value']
74
+ }
75
+ dashboard_html = create_dashboard(state['selected_feature'])
76
+ else:
77
+ dashboard_html = None
78
+
79
+ # Create tabs component
80
+ token_tabs = gr.Tabs(
81
+ [gr.Tab(token, feature_list) for token, feature_list in tokens_data]
82
+ )
83
 
84
+ return token_tabs, dashboard_html, state
 
85
 
86
+ def create_dashboard(feature: Dict) -> str:
87
+ """Create dashboard HTML for a feature."""
88
  if not feature:
89
  return ""
90
 
 
 
 
91
  return f"""
92
  <div class="dashboard-container p-4">
93
+ <h3 class="text-lg font-semibold mb-4">Feature {feature['feature_id']} Dashboard</h3>
 
 
94
  <iframe
95
+ src="https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/{feature['feature_id']}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
96
  width="100%"
97
  height="600"
98
  frameborder="0"
 
101
  </div>
102
  """
103
 
104
+ def select_feature(feature_id: str, state: Dict) -> Tuple[str, Dict]:
105
+ """Handle feature selection."""
106
+ if not state.get('features_data'):
107
+ return None, state
108
+
109
+ for result in state['features_data']['results']:
110
+ for feature in result['top_features']:
111
+ if feature['feature_index'] == feature_id:
 
 
 
 
 
 
 
 
 
 
 
112
  state['selected_feature'] = {
113
+ 'feature_id': feature_id,
114
+ 'activation': feature['activation_value']
115
  }
116
+ return create_dashboard(state['selected_feature']), state
117
+
118
+ return None, state
 
 
 
119
 
120
+ def toggle_token_expansion(token: str, state: Dict) -> Tuple[gr.Tabs, Dict]:
121
+ """Toggle token expansion state."""
122
+ if token in state.get('expanded_tokens', []):
123
  state['expanded_tokens'].remove(token)
124
  else:
125
  state['expanded_tokens'].append(token)
126
+
127
+ tabs, _, state = process_features(state['current_text'], state)
128
+ return tabs, state
 
 
129
 
130
  css = """
131
  @import url('https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;600;700&display=swap');
 
134
  font-family: 'Open Sans', sans-serif !important;
135
  }
136
 
 
 
 
 
 
 
 
 
 
 
137
  .dashboard-container {
138
  border: 1px solid #e0e5ff;
139
  border-radius: 8px;
 
152
  )
153
 
154
  def create_interface():
155
+ state = gr.State({})
 
 
 
 
 
156
 
157
  with gr.Blocks(theme=theme, css=css) as interface:
158
  gr.Markdown("# Neural Feature Analyzer", elem_classes="text-2xl font-bold mb-2")
 
172
  )
173
 
174
  with gr.Column(scale=2):
175
+ token_tabs = gr.Tabs()
176
+ dashboard = gr.HTML()
177
 
 
178
  analyze_btn.click(
179
+ fn=process_features,
180
  inputs=[input_text, state],
181
+ outputs=[token_tabs, dashboard, state]
182
  )
 
 
 
 
 
 
 
183
 
184
  return interface
185