cyberandy committed on
Commit 6465b33 · verified · 1 Parent(s): a24593e

Update app.py

Files changed (1)
  1. app.py +75 -30
app.py CHANGED
@@ -2,12 +2,19 @@ import gradio as gr
 import requests
 from typing import Dict, Tuple, List
 from operator import itemgetter
+from collections import Counter
+import logging
 
-def get_top_features(text: str, k: int = 5) -> Dict:
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+def get_features(text: str, k: int = 5) -> Dict:
+    """Get neural features from the API with detailed logging."""
     url = "https://www.neuronpedia.org/api/search-with-topk"
     payload = {
         "modelId": "gemma-2-2b",
-        "layer": "0-gemmascope-mlp-16k",
+        "layer": "20-gemmascope-res-16k", # Updated to match website
         "sourceSet": "gemma-scope",
         "text": text,
         "k": k,
@@ -15,43 +22,80 @@ def get_top_features(text: str, k: int = 5) -> Dict:
         "ignoreBos": True
     }
 
-    response = requests.post(
-        url,
-        headers={"Content-Type": "application/json"},
-        json=payload
-    )
-    return response.json() if response.status_code == 200 else None
+    try:
+        response = requests.post(
+            url,
+            headers={"Content-Type": "application/json"},
+            json=payload
+        )
+        response.raise_for_status()
+        data = response.json()
+
+        # Log the raw response for analysis
+        logger.info(f"API Response: {data}")
+
+        # Analyze feature distribution
+        all_features = []
+        feature_counter = Counter()
+
+        for result in data['results']:
+            token = result['token']
+            logger.info(f"\nToken: {token}")
+
+            for feature in result['top_features']:
+                feature_id = feature['feature_index']
+                activation = feature['activation_value']
+                logger.info(f"Feature {feature_id}: {activation}")
+
+                all_features.append({
+                    'token': token,
+                    'feature_id': feature_id,
+                    'activation': activation,
+                    'feature_data': feature.get('feature', {})
+                })
+                feature_counter[feature_id] += 1
+
+        # Log feature frequency analysis
+        logger.info("\nFeature Frequencies:")
+        for feature_id, count in feature_counter.most_common():
+            logger.info(f"Feature {feature_id}: {count} occurrences")
+
+        return data, all_features, feature_counter
+
+    except Exception as e:
+        logger.error(f"Error in API call: {str(e)}")
+        return None, [], Counter()
 
-def format_output(data: Dict) -> Tuple[str, str, str]:
+def format_output(text: str) -> Tuple[str, str, str]:
+    data, all_features, feature_counter = get_features(text)
+
     if not data:
         return "Error analyzing text", "", ""
 
-    # Collect all features from all tokens
-    all_features = []
-    for result in data['results']:
-        token = result['token']
-        if token == '<bos>':
-            continue
-
-        for feature in result['top_features']:
-            all_features.append({
-                'token': token,
-                'feature_id': feature['feature_index'],
-                'activation': feature['activation_value'],
-                'feature_data': feature.get('feature', {})
-            })
+    # Sort features by frequency first, then by maximum activation within each feature
+    feature_activations = {}
+    for feature in all_features:
+        feature_id = feature['feature_id']
+        activation = feature['activation']
+        if feature_id not in feature_activations or activation > feature_activations[feature_id]['activation']:
+            feature_activations[feature_id] = feature
+
+    # Get top features by frequency, then sort by activation
+    most_common_features = [
+        feature_activations[feature_id]
+        for feature_id, _ in feature_counter.most_common()
+    ]
 
-    # Sort all features by activation value and get top 5
-    top_features = sorted(all_features, key=itemgetter('activation'), reverse=True)[:5]
+    # Sort by activation within the most common features
+    top_features = sorted(most_common_features, key=itemgetter('activation'), reverse=True)[:5]
 
     # Format output
     output = "# Neural Feature Analysis\n\n"
     output += "## Top 5 Most Active Features\n\n"
 
     for idx, feat in enumerate(top_features, 1):
-        feature_url = f"https://www.neuronpedia.org/gemma-2-2b/0-gemmascope-mlp-16k/{feat['feature_id']}"
+        feature_url = f"https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/{feat['feature_id']}"
 
-        # Try to get feature name/description if available
         feature_info = ""
         if 'name' in feat['feature_data']:
             feature_info = f" - {feat['feature_data']['name']}"
@@ -61,12 +105,13 @@ def format_output(data: Dict) -> Tuple[str, str, str]:
         output += f"### {idx}. Feature {feat['feature_id']}{feature_info}\n"
         output += f"- **Token:** '{feat['token']}'\n"
         output += f"- **Activation:** {feat['activation']:.2f}\n"
+        output += f"- **Frequency:** {feature_counter[feat['feature_id']]} occurrences\n"
        output += f"- [View on Neuronpedia]({feature_url})\n\n"
 
     # Use highest activation feature for dashboard
     if top_features:
         top_feature = top_features[0]
-        dashboard_url = f"https://www.neuronpedia.org/gemma-2-2b/0-gemmascope-mlp-16k/{top_feature['feature_id']}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
+        dashboard_url = f"https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/{top_feature['feature_id']}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
         iframe = f'''
         <div style="border:1px solid #eee;border-radius:8px;padding:1px;background:#fff;">
             <iframe
@@ -99,9 +144,9 @@ def create_interface():
             )
             analyze_btn = gr.Button("Analyze Features", variant="primary")
             gr.Examples([
+                "WordLift",
                 "Nike - Just Do It. The power of determination.",
                 "Apple - Think Different. Innovation redefined.",
-                "McDonald's - I'm Lovin' It. Creating joy.",
             ], inputs=input_text)
 
         with gr.Column():
@@ -110,7 +155,7 @@ def create_interface():
             dashboard = gr.HTML()
 
         analyze_btn.click(
-            fn=lambda text: format_output(get_top_features(text)),
+            fn=lambda text: format_output(text),
             inputs=input_text,
             outputs=[output_text, dashboard, feature_label]
         )
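
Review note: the behavioural change here is the ranking in format_output, which now keeps the highest activation seen for each feature and then orders the most frequent features by that peak activation. A minimal offline sketch of that logic follows; the mocked results payload only imitates the shape of the search-with-topk response used in app.py, and its tokens, feature indices, and activation values are invented for illustration.

from collections import Counter
from operator import itemgetter

# Mocked fragment shaped like data['results'] in app.py (values are made up)
mock_results = [
    {"token": "Just", "top_features": [
        {"feature_index": 111, "activation_value": 3.2, "feature": {}},
        {"feature_index": 222, "activation_value": 1.1, "feature": {}},
    ]},
    {"token": "Do", "top_features": [
        {"feature_index": 222, "activation_value": 4.0, "feature": {}},
    ]},
]

all_features = []
feature_counter = Counter()
for result in mock_results:
    for feature in result["top_features"]:
        all_features.append({
            "token": result["token"],
            "feature_id": feature["feature_index"],
            "activation": feature["activation_value"],
        })
        feature_counter[feature["feature_index"]] += 1

# Keep the highest-activation record per feature, as the new format_output does
best = {}
for feat in all_features:
    fid = feat["feature_id"]
    if fid not in best or feat["activation"] > best[fid]["activation"]:
        best[fid] = feat

# Order the most frequent features by their peak activation and keep the top 5
top_features = sorted(
    (best[fid] for fid, _ in feature_counter.most_common()),
    key=itemgetter("activation"),
    reverse=True,
)[:5]

print(top_features)  # feature 222 (activation 4.0) ranks above feature 111 (3.2)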