cyberandy commited on
Commit
383d1f8
·
verified ·
1 Parent(s): 1df4bb8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -28
app.py CHANGED
@@ -1,8 +1,7 @@
1
  import gradio as gr
2
  import requests
3
- from typing import Dict, List, Tuple
4
 
5
- def get_features(text: str) -> Dict:
6
  url = "https://www.neuronpedia.org/api/search-with-topk"
7
  payload = {
8
  "modelId": "gemma-2-2b",
@@ -30,43 +29,70 @@ def create_dashboard(feature_id: int) -> str:
30
  </div>
31
  """
32
 
33
- def process_features(text: str) -> Tuple[gr.Column, str]:
34
  if not text:
35
- return gr.Column(), ""
36
-
37
  features_data = get_features(text)
38
  if not features_data:
39
- return gr.Column(), ""
40
-
 
41
  first_feature_id = None
42
- with gr.Column() as col:
43
- for result in features_data['results']:
44
- if result['token'] == '<bos>':
45
- continue
46
-
47
- gr.Markdown(f"### {result['token']}")
48
- for i, feature in enumerate(result['top_features'][:3]):
49
- feature_id = feature['feature_index']
50
- if first_feature_id is None:
51
- first_feature_id = feature_id
52
- gr.Button(
53
- f"Feature {feature_id} (Activation: {feature['activation_value']:.2f})",
54
- elem_id=str(feature_id)
55
- ).click(
56
- fn=lambda fid=feature_id: create_dashboard(fid),
57
- outputs=dashboard
58
- )
59
 
60
- return col, create_dashboard(first_feature_id) if first_feature_id else ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  css = """
63
  @import url('https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;600;700&display=swap');
 
64
  body { font-family: 'Open Sans', sans-serif !important; }
 
65
  .dashboard-container {
66
  border: 1px solid #e0e5ff;
67
  border-radius: 8px;
68
  background-color: #ffffff;
69
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  """
71
 
72
  theme = gr.themes.Soft(
@@ -79,6 +105,9 @@ theme = gr.themes.Soft(
79
  )
80
  )
81
 
 
 
 
82
  with gr.Blocks(theme=theme, css=css) as demo:
83
  gr.Markdown("# Brand Analyzer", elem_classes="text-2xl font-bold mb-2")
84
  gr.Markdown("*Analyze text using Gemma's interpretable neural features*", elem_classes="text-gray-600 mb-6")
@@ -97,13 +126,16 @@ with gr.Blocks(theme=theme, css=css) as demo:
97
  )
98
 
99
  with gr.Column(scale=2):
100
- features_col = gr.Column()
101
  dashboard = gr.HTML()
 
 
 
102
 
103
  analyze_btn.click(
104
- fn=process_features,
105
  inputs=[input_text],
106
- outputs=[features_col, dashboard]
107
  )
108
 
109
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  import requests
 
3
 
4
+ def get_features(text: str):
5
  url = "https://www.neuronpedia.org/api/search-with-topk"
6
  payload = {
7
  "modelId": "gemma-2-2b",
 
29
  </div>
30
  """
31
 
32
+ def analyze_text(text: str):
33
  if not text:
34
+ return gr.update(visible=False), ""
35
+
36
  features_data = get_features(text)
37
  if not features_data:
38
+ return gr.update(visible=False), ""
39
+
40
+ html = "<div class='features-list'>"
41
  first_feature_id = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
+ for result in features_data['results']:
44
+ if result['token'] == '<bos>':
45
+ continue
46
+
47
+ token = result['token']
48
+ html += f"<h3>{token}</h3>"
49
+
50
+ for feature in result['top_features'][:3]:
51
+ feature_id = feature['feature_index']
52
+ if first_feature_id is None:
53
+ first_feature_id = feature_id
54
+
55
+ html += f"""
56
+ <button onclick='document.dispatchEvent(new CustomEvent("select_feature",
57
+ {{detail: {{feature_id: {feature_id}}}}}))' class='feature-button'>
58
+ Feature {feature_id} (Activation: {feature['activation_value']:.2f})
59
+ </button>
60
+ """
61
+
62
+ html += "</div>"
63
+ initial_dashboard = create_dashboard(first_feature_id) if first_feature_id else ""
64
+
65
+ return html, initial_dashboard
66
 
67
  css = """
68
  @import url('https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;600;700&display=swap');
69
+
70
  body { font-family: 'Open Sans', sans-serif !important; }
71
+
72
  .dashboard-container {
73
  border: 1px solid #e0e5ff;
74
  border-radius: 8px;
75
  background-color: #ffffff;
76
  }
77
+
78
+ .features-list h3 {
79
+ margin-top: 1rem;
80
+ font-weight: 600;
81
+ }
82
+
83
+ .feature-button {
84
+ display: block;
85
+ margin: 0.5rem 0;
86
+ padding: 0.5rem 1rem;
87
+ background-color: #f3f4f6;
88
+ border: 1px solid #e5e7eb;
89
+ border-radius: 0.375rem;
90
+ cursor: pointer;
91
+ }
92
+
93
+ .feature-button:hover {
94
+ background-color: #e5e7eb;
95
+ }
96
  """
97
 
98
  theme = gr.themes.Soft(
 
105
  )
106
  )
107
 
108
+ def update_dashboard(feature_id: int):
109
+ return create_dashboard(feature_id)
110
+
111
  with gr.Blocks(theme=theme, css=css) as demo:
112
  gr.Markdown("# Brand Analyzer", elem_classes="text-2xl font-bold mb-2")
113
  gr.Markdown("*Analyze text using Gemma's interpretable neural features*", elem_classes="text-gray-600 mb-6")
 
126
  )
127
 
128
  with gr.Column(scale=2):
129
+ features_html = gr.HTML()
130
  dashboard = gr.HTML()
131
+
132
+ # Handle feature selection via JavaScript
133
+ dashboard.change(fn=update_dashboard, inputs=gr.Textbox(visible=False), outputs=dashboard)
134
 
135
  analyze_btn.click(
136
+ fn=analyze_text,
137
  inputs=[input_text],
138
+ outputs=[features_html, dashboard]
139
  )
140
 
141
  if __name__ == "__main__":