cyberandy commited on
Commit
c3f5f94
·
verified ·
1 Parent(s): 5ac398b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -127
app.py CHANGED
@@ -1,22 +1,6 @@
1
  import gradio as gr
2
  import requests
3
  from typing import Dict, Tuple, List
4
- import json
5
- from dataclasses import dataclass
6
- from typing import Optional
7
-
8
- @dataclass
9
- class Feature:
10
- feature_id: int
11
- activation: float
12
- token: str
13
- position: int
14
-
15
- class FeatureState:
16
- def __init__(self):
17
- self.features_by_token = {}
18
- self.expanded_tokens = set()
19
- self.selected_feature = None
20
 
21
  def get_features(text: str) -> Dict:
22
  """Get neural features from the API using the exact website parameters."""
@@ -38,72 +22,88 @@ def get_features(text: str) -> Dict:
38
  except Exception as e:
39
  return None
40
 
41
- def format_feature_list(features: List[Feature], token: str, expanded: bool = False) -> str:
42
  """Format features as HTML list."""
43
- display_features = features if expanded else features[:3]
44
  features_html = ""
45
 
46
- for feature in display_features:
 
 
 
 
47
  features_html += f"""
48
- <div class="feature-card p-4 rounded-lg mb-4 cursor-pointer hover:border-blue-500"
49
- data-feature-id="{feature.feature_id}">
50
  <div class="flex justify-between items-center">
51
  <div>
52
- <span class="font-semibold">Feature {feature.feature_id}</span>
53
- <span class="ml-2 text-gray-600">(Activation: {feature.activation:.2f})</span>
54
  </div>
 
 
 
 
 
55
  </div>
56
  </div>
57
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
- if not expanded and len(features) > 3:
60
- remaining = len(features) - 3
61
  features_html += f"""
62
- <div class="text-center">
63
- <span class="text-blue-500 text-sm">{remaining} more features available</span>
 
 
 
64
  </div>
65
  """
66
 
67
  return features_html
68
 
69
- def format_dashboard(feature: Feature) -> str:
70
- """Format the dashboard HTML for a selected feature."""
71
- if not feature:
72
  return ""
73
 
74
- return f"""
75
- <div class="dashboard-container p-4">
76
- <h3 class="text-lg font-semibold mb-4 text-gray-900">
77
- Feature {feature.feature_id} Dashboard (Activation: {feature.activation:.2f})
78
- </h3>
79
- <iframe
80
- src="https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/{feature.feature_id}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
81
- width="100%"
82
- height="600"
83
- frameborder="0"
84
- class="rounded-lg"
85
- ></iframe>
86
- </div>
87
- """
88
-
89
- def process_features(data: Dict) -> Dict[str, List[Feature]]:
90
- """Process API response into features grouped by token."""
91
- features_by_token = {}
92
- for result in data.get('results', []):
93
  if result['token'] == '<bos>':
94
  continue
95
 
96
  token = result['token']
97
- features = []
98
- for idx, feature in enumerate(result.get('top_features', [])):
99
- features.append(Feature(
100
- feature_id=feature['feature_index'],
101
- activation=feature['activation_value'],
102
- token=token,
103
- position=idx
104
- ))
105
- features_by_token[token] = features
106
- return features_by_token
 
107
 
108
  css = """
109
  @import url('https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;600;700&display=swap');
@@ -119,7 +119,6 @@ body {
119
  }
120
 
121
  .feature-card:hover {
122
- border-color: #3452db;
123
  box-shadow: 0 2px 4px rgba(52, 82, 219, 0.1);
124
  }
125
 
@@ -128,6 +127,10 @@ body {
128
  border-radius: 8px;
129
  background-color: #ffffff;
130
  }
 
 
 
 
131
  """
132
 
133
  theme = gr.themes.Soft(
@@ -147,69 +150,7 @@ theme = gr.themes.Soft(
147
  )
148
  )
149
 
150
- def analyze_features(text: str, state: Optional[Dict] = None) -> Tuple[str, Dict]:
151
- """Main analysis function that processes text and returns formatted output."""
152
- if not text:
153
- return "", None
154
-
155
- data = get_features(text)
156
- if not data:
157
- return "Error analyzing text", None
158
-
159
- # Process features and build state
160
- features_by_token = process_features(data)
161
-
162
- # Initialize state if needed
163
- if not state:
164
- state = {
165
- 'features_by_token': features_by_token,
166
- 'expanded_tokens': set(),
167
- 'selected_feature': None
168
- }
169
- # Select first feature as default
170
- first_token = next(iter(features_by_token))
171
- if features_by_token[first_token]:
172
- state['selected_feature'] = features_by_token[first_token][0]
173
-
174
- # Build output HTML
175
- output = []
176
- for token, features in features_by_token.items():
177
- expanded = token in state['expanded_tokens']
178
- token_html = f"<h2 class='text-xl font-bold mb-4'>Token: {token}</h2>"
179
- features_html = format_feature_list(features, token, expanded)
180
-
181
- output.append(f"<div class='mb-6'>{token_html}{features_html}</div>")
182
-
183
- # Add dashboard if a feature is selected
184
- if state['selected_feature']:
185
- output.append(format_dashboard(state['selected_feature']))
186
-
187
- return "\n".join(output), state
188
-
189
- def toggle_expansion(token: str, state: Dict) -> Tuple[str, Dict]:
190
- """Toggle expansion state for a token's features."""
191
- if token in state['expanded_tokens']:
192
- state['expanded_tokens'].remove(token)
193
- else:
194
- state['expanded_tokens'].add(token)
195
-
196
- output_html, state = analyze_features(None, state)
197
- return output_html, state
198
-
199
- def select_feature(feature_id: int, state: Dict) -> Tuple[str, Dict]:
200
- """Select a feature and update the dashboard."""
201
- for features in state['features_by_token'].values():
202
- for feature in features:
203
- if feature.feature_id == feature_id:
204
- state['selected_feature'] = feature
205
- break
206
-
207
- output_html, state = analyze_features(None, state)
208
- return output_html, state
209
-
210
  def create_interface():
211
- state = gr.State({})
212
-
213
  with gr.Blocks(theme=theme, css=css) as interface:
214
  gr.Markdown("# Neural Feature Analyzer", elem_classes="text-2xl font-bold mb-2")
215
  gr.Markdown("*Analyze text using Gemma's interpretable neural features*", elem_classes="text-gray-600 mb-6")
@@ -230,14 +171,13 @@ def create_interface():
230
  with gr.Column(scale=2):
231
  output = gr.HTML()
232
 
233
- # Event handlers
234
  analyze_btn.click(
235
  fn=analyze_features,
236
- inputs=[input_text, state],
237
- outputs=[output, state]
238
  )
239
-
240
  return interface
241
 
242
  if __name__ == "__main__":
243
- create_interface().launch()
 
1
  import gradio as gr
2
  import requests
3
  from typing import Dict, Tuple, List
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  def get_features(text: str) -> Dict:
6
  """Get neural features from the API using the exact website parameters."""
 
22
  except Exception as e:
23
  return None
24
 
25
+ def format_feature_list(token: str, features: List[Dict], show_all: bool = False) -> str:
26
  """Format features as HTML list."""
27
+ feature_count = len(features) if show_all else min(3, len(features))
28
  features_html = ""
29
 
30
+ for idx in range(feature_count):
31
+ feature = features[idx]
32
+ feature_id = feature['feature_index']
33
+ activation = feature['activation_value']
34
+
35
  features_html += f"""
36
+ <div class="feature-card p-4 rounded-lg mb-4 hover:border-blue-500">
 
37
  <div class="flex justify-between items-center">
38
  <div>
39
+ <span class="font-semibold">Feature {feature_id}</span>
40
+ <span class="ml-2 text-gray-600">(Activation: {activation:.2f})</span>
41
  </div>
42
+ <a href="https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/{feature_id}"
43
+ target="_blank"
44
+ class="text-blue-600 hover:text-blue-800">
45
+ View on Neuronpedia →
46
+ </a>
47
  </div>
48
  </div>
49
  """
50
+
51
+ # Add dashboard for first feature only
52
+ if idx == 0:
53
+ features_html += f"""
54
+ <div class="dashboard-container mb-6 p-4">
55
+ <h3 class="text-lg font-semibold mb-4">Feature {feature_id} Dashboard</h3>
56
+ <iframe
57
+ src="https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/{feature_id}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
58
+ width="100%"
59
+ height="600"
60
+ frameborder="0"
61
+ class="rounded-lg"
62
+ ></iframe>
63
+ </div>
64
+ """
65
 
66
+ remaining = len(features) - 3
67
+ if not show_all and remaining > 0:
68
  features_html += f"""
69
+ <div class="text-sm text-gray-600 mb-4">
70
+ {remaining} more features available.
71
+ <a href="https://www.neuronpedia.org/gemma-2-2b" target="_blank" class="text-blue-600 hover:text-blue-800">
72
+ View all on Neuronpedia →
73
+ </a>
74
  </div>
75
  """
76
 
77
  return features_html
78
 
79
+ def analyze_features(text: str) -> str:
80
+ """Main analysis function that processes text and returns formatted output."""
81
+ if not text:
82
  return ""
83
 
84
+ data = get_features(text)
85
+ if not data:
86
+ return "Error analyzing text"
87
+
88
+ output = ['<div class="p-6">']
89
+
90
+ # Process each token's features
91
+ for result in data['results']:
 
 
 
 
 
 
 
 
 
 
 
92
  if result['token'] == '<bos>':
93
  continue
94
 
95
  token = result['token']
96
+ features = result['top_features']
97
+
98
+ output.append(f"""
99
+ <div class="mb-8">
100
+ <h2 class="text-xl font-bold mb-4">Token: {token}</h2>
101
+ {format_feature_list(token, features)}
102
+ </div>
103
+ """)
104
+
105
+ output.append('</div>')
106
+ return "\n".join(output)
107
 
108
  css = """
109
  @import url('https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;600;700&display=swap');
 
119
  }
120
 
121
  .feature-card:hover {
 
122
  box-shadow: 0 2px 4px rgba(52, 82, 219, 0.1);
123
  }
124
 
 
127
  border-radius: 8px;
128
  background-color: #ffffff;
129
  }
130
+
131
+ .hljs {
132
+ background: #f5f7ff !important;
133
+ }
134
  """
135
 
136
  theme = gr.themes.Soft(
 
150
  )
151
  )
152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  def create_interface():
 
 
154
  with gr.Blocks(theme=theme, css=css) as interface:
155
  gr.Markdown("# Neural Feature Analyzer", elem_classes="text-2xl font-bold mb-2")
156
  gr.Markdown("*Analyze text using Gemma's interpretable neural features*", elem_classes="text-gray-600 mb-6")
 
171
  with gr.Column(scale=2):
172
  output = gr.HTML()
173
 
 
174
  analyze_btn.click(
175
  fn=analyze_features,
176
+ inputs=input_text,
177
+ outputs=output
178
  )
179
+
180
  return interface
181
 
182
  if __name__ == "__main__":
183
+ create_interface().launch(share=True)