Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -3,7 +3,6 @@ import requests
|
|
3 |
from typing import Dict, List, Tuple
|
4 |
|
5 |
def get_features(text: str) -> Dict:
|
6 |
-
"""Get neural features from the API."""
|
7 |
url = "https://www.neuronpedia.org/api/search-with-topk"
|
8 |
payload = {
|
9 |
"modelId": "gemma-2-2b",
|
@@ -22,69 +21,44 @@ def get_features(text: str) -> Dict:
|
|
22 |
except Exception as e:
|
23 |
return None
|
24 |
|
25 |
-
def process_features(text: str
|
26 |
-
"""Process features and return UI components."""
|
27 |
if not text:
|
28 |
-
return
|
29 |
|
30 |
features_data = get_features(text)
|
31 |
if not features_data:
|
32 |
-
return
|
33 |
|
34 |
-
|
35 |
-
state['features_data'] = features_data
|
36 |
-
if 'expanded_tokens' not in state:
|
37 |
-
state['expanded_tokens'] = []
|
38 |
-
|
39 |
-
# Create tabs for each token
|
40 |
-
tokens_data = []
|
41 |
for result in features_data['results']:
|
42 |
if result['token'] == '<bos>':
|
43 |
continue
|
44 |
|
45 |
token = result['token']
|
46 |
features = result['top_features']
|
47 |
-
is_expanded = token in state.get('expanded_tokens', [])
|
48 |
-
num_features = len(features) if is_expanded else min(3, len(features))
|
49 |
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
size="sm"
|
62 |
-
)
|
63 |
-
feature_list.append(expand_btn)
|
64 |
-
|
65 |
-
tokens_data.append((token, feature_list))
|
66 |
-
|
67 |
-
# Select first feature for dashboard
|
68 |
-
if tokens_data and 'selected_feature' not in state:
|
69 |
-
first_token = tokens_data[0][0]
|
70 |
first_feature = features_data['results'][0]['top_features'][0]
|
71 |
-
|
72 |
'feature_id': first_feature['feature_index'],
|
73 |
'activation': first_feature['activation_value']
|
74 |
-
}
|
75 |
-
dashboard_html = create_dashboard(state['selected_feature'])
|
76 |
else:
|
77 |
-
dashboard_html =
|
78 |
-
|
79 |
-
# Create tabs component
|
80 |
-
token_tabs = gr.Tabs(
|
81 |
-
[gr.Tab(token, feature_list) for token, feature_list in tokens_data]
|
82 |
-
)
|
83 |
|
84 |
-
return
|
85 |
|
86 |
def create_dashboard(feature: Dict) -> str:
|
87 |
-
"""Create dashboard HTML for a feature."""
|
88 |
if not feature:
|
89 |
return ""
|
90 |
|
@@ -101,32 +75,6 @@ def create_dashboard(feature: Dict) -> str:
|
|
101 |
</div>
|
102 |
"""
|
103 |
|
104 |
-
def select_feature(feature_id: str, state: Dict) -> Tuple[str, Dict]:
|
105 |
-
"""Handle feature selection."""
|
106 |
-
if not state.get('features_data'):
|
107 |
-
return None, state
|
108 |
-
|
109 |
-
for result in state['features_data']['results']:
|
110 |
-
for feature in result['top_features']:
|
111 |
-
if feature['feature_index'] == feature_id:
|
112 |
-
state['selected_feature'] = {
|
113 |
-
'feature_id': feature_id,
|
114 |
-
'activation': feature['activation_value']
|
115 |
-
}
|
116 |
-
return create_dashboard(state['selected_feature']), state
|
117 |
-
|
118 |
-
return None, state
|
119 |
-
|
120 |
-
def toggle_token_expansion(token: str, state: Dict) -> Tuple[gr.Tabs, Dict]:
|
121 |
-
"""Toggle token expansion state."""
|
122 |
-
if token in state.get('expanded_tokens', []):
|
123 |
-
state['expanded_tokens'].remove(token)
|
124 |
-
else:
|
125 |
-
state['expanded_tokens'].append(token)
|
126 |
-
|
127 |
-
tabs, _, state = process_features(state['current_text'], state)
|
128 |
-
return tabs, state
|
129 |
-
|
130 |
css = """
|
131 |
@import url('https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;600;700&display=swap');
|
132 |
|
@@ -152,8 +100,6 @@ theme = gr.themes.Soft(
|
|
152 |
)
|
153 |
|
154 |
def create_interface():
|
155 |
-
state = gr.State({})
|
156 |
-
|
157 |
with gr.Blocks(theme=theme, css=css) as interface:
|
158 |
gr.Markdown("# Neural Feature Analyzer", elem_classes="text-2xl font-bold mb-2")
|
159 |
gr.Markdown("*Analyze text using Gemma's interpretable neural features*", elem_classes="text-gray-600 mb-6")
|
@@ -177,11 +123,11 @@ def create_interface():
|
|
177 |
|
178 |
analyze_btn.click(
|
179 |
fn=process_features,
|
180 |
-
inputs=[input_text
|
181 |
-
outputs=[token_tabs, dashboard
|
182 |
)
|
183 |
|
184 |
return interface
|
185 |
|
186 |
if __name__ == "__main__":
|
187 |
-
create_interface().launch(
|
|
|
3 |
from typing import Dict, List, Tuple
|
4 |
|
5 |
def get_features(text: str) -> Dict:
|
|
|
6 |
url = "https://www.neuronpedia.org/api/search-with-topk"
|
7 |
payload = {
|
8 |
"modelId": "gemma-2-2b",
|
|
|
21 |
except Exception as e:
|
22 |
return None
|
23 |
|
24 |
+
def process_features(text: str) -> Tuple[List[gr.Tab], str]:
|
|
|
25 |
if not text:
|
26 |
+
return [], ""
|
27 |
|
28 |
features_data = get_features(text)
|
29 |
if not features_data:
|
30 |
+
return [], ""
|
31 |
|
32 |
+
tabs = []
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
for result in features_data['results']:
|
34 |
if result['token'] == '<bos>':
|
35 |
continue
|
36 |
|
37 |
token = result['token']
|
38 |
features = result['top_features']
|
|
|
|
|
39 |
|
40 |
+
with gr.Tab(token):
|
41 |
+
feature_list = []
|
42 |
+
for feature in features[:3]:
|
43 |
+
gr.Button(
|
44 |
+
f"Feature {feature['feature_index']} (Activation: {feature['activation_value']:.2f})",
|
45 |
+
variant="secondary"
|
46 |
+
)
|
47 |
+
tabs.append(gr.Tab)
|
48 |
+
|
49 |
+
# Create initial dashboard for first feature
|
50 |
+
if features_data['results']:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
first_feature = features_data['results'][0]['top_features'][0]
|
52 |
+
dashboard_html = create_dashboard({
|
53 |
'feature_id': first_feature['feature_index'],
|
54 |
'activation': first_feature['activation_value']
|
55 |
+
})
|
|
|
56 |
else:
|
57 |
+
dashboard_html = ""
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
+
return tabs, dashboard_html
|
60 |
|
61 |
def create_dashboard(feature: Dict) -> str:
|
|
|
62 |
if not feature:
|
63 |
return ""
|
64 |
|
|
|
75 |
</div>
|
76 |
"""
|
77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
css = """
|
79 |
@import url('https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;600;700&display=swap');
|
80 |
|
|
|
100 |
)
|
101 |
|
102 |
def create_interface():
|
|
|
|
|
103 |
with gr.Blocks(theme=theme, css=css) as interface:
|
104 |
gr.Markdown("# Neural Feature Analyzer", elem_classes="text-2xl font-bold mb-2")
|
105 |
gr.Markdown("*Analyze text using Gemma's interpretable neural features*", elem_classes="text-gray-600 mb-6")
|
|
|
123 |
|
124 |
analyze_btn.click(
|
125 |
fn=process_features,
|
126 |
+
inputs=[input_text],
|
127 |
+
outputs=[token_tabs, dashboard]
|
128 |
)
|
129 |
|
130 |
return interface
|
131 |
|
132 |
if __name__ == "__main__":
|
133 |
+
create_interface().launch()
|