Spaces:
Sleeping
Sleeping
File size: 7,727 Bytes
44c881e 8712c90 321a1b2 5ac398b 321a1b2 5ac398b 321a1b2 5ac398b c3f5f94 321a1b2 5ac398b 321a1b2 5ac398b c3f5f94 321a1b2 c3f5f94 321a1b2 c3f5f94 f643580 321a1b2 f643580 321a1b2 f643580 6465b33 44c881e 321a1b2 f643580 5ac398b 44c881e 5ac398b 44c881e a24593e 5ac398b f643580 5ac398b f643580 7cdea90 44c881e 5ac398b 321a1b2 44c881e 321a1b2 44c881e e53e16b 321a1b2 44c881e 321a1b2 c3f5f94 44c881e 8712c90 44c881e c3f5f94 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 |
import html
import json
from typing import Dict, List, Optional, Tuple

import gradio as gr
import requests
def get_features(text: str) -> Optional[Dict]:
    """Fetch the top-activating features for *text* from the Neuronpedia API.

    Posts ``text`` to the ``search-with-topk`` endpoint for model
    ``gemma-2-2b``, layer ``20-gemmascope-res-16k``.

    Args:
        text: The text to analyze.

    Returns:
        The decoded JSON response body, or ``None`` when the request fails,
        times out, returns an HTTP error status, or the body is not valid
        JSON.
    """
    url = "https://www.neuronpedia.org/api/search-with-topk"
    payload = {
        "modelId": "gemma-2-2b",
        "text": text,
        "layer": "20-gemmascope-res-16k",
    }
    try:
        response = requests.post(
            url,
            headers={"Content-Type": "application/json"},
            json=payload,
            # requests has no default timeout; without one a stalled
            # connection would block the UI handler indefinitely.
            timeout=30,
        )
        response.raise_for_status()
        return response.json()
    except (requests.RequestException, ValueError):
        # Best-effort lookup: callers treat None as "analysis failed".
        # ValueError covers a non-JSON body on older requests versions.
        return None
def format_features(features_data: Dict, expanded_tokens: List[str], selected_feature: Dict) -> str:
    """Render the per-token feature list as an HTML fragment.

    Args:
        features_data: API response. Must contain a ``'results'`` list whose
            items carry ``'token'`` and ``'top_features'``; each feature
            carries ``'feature_index'`` and ``'activation_value'``.
        expanded_tokens: Tokens whose complete feature list is shown. Every
            other token shows at most its first three features.
        selected_feature: The currently selected feature (a dict with a
            ``'feature_id'`` key) or ``None``. The matching card gets a
            highlighted border.

    Returns:
        The HTML string, or ``""`` when ``features_data`` is empty or has no
        ``'results'`` key.
    """
    if not features_data or 'results' not in features_data:
        return ""

    selected_id = selected_feature.get('feature_id') if selected_feature else None
    output = ['<div class="p-6">']

    for result in features_data['results']:
        token = result['token']
        if token == '<bos>':
            continue

        features = result['top_features']
        is_expanded = token in expanded_tokens
        shown = features if is_expanded else features[:3]

        # The token is derived from user-entered text, so it must be escaped
        # before being placed in markup.
        html_token = html.escape(str(token))
        output.append(f'<div class="mb-8"><h2 class="text-xl font-bold mb-4">Token: {html_token}</h2>')

        for feature in shown:
            feature_id = feature['feature_index']
            activation = feature['activation_value']
            is_selected = bool(selected_feature) and selected_id == feature_id
            selected_class = "border-blue-500 border-2" if is_selected else ""
            output.append(
                "\n"
                f'<div class="feature-card p-4 rounded-lg mb-4 hover:border-blue-500 {selected_class}">\n'
                '<div class="flex justify-between items-center">\n'
                '<div>\n'
                f'<span class="font-semibold">Feature {feature_id}</span>\n'
                f'<span class="ml-2 text-gray-600">(Activation: {activation:.2f})</span>\n'
                '</div>\n'
                '</div>\n'
                '</div>\n'
            )

        # Show more/less button only when there is something to toggle.
        if len(features) > 3:
            action = "less" if is_expanded else f"{len(features) - 3} more"
            # Encode the token as a JS string literal first (json.dumps), then
            # escape that literal for the HTML attribute, so quotes in the
            # token can terminate neither the JS string nor the attribute.
            js_token = html.escape(json.dumps(str(token)))
            output.append(
                "\n"
                '<div class="text-center mb-4">\n'
                '<button class="text-blue-600 hover:text-blue-800 text-sm"\n'
                f"onclick=\"gradio('toggle_expansion', {js_token})\">\n"
                f'Show {action} features\n'
                '</button>\n'
                '</div>\n'
            )

        output.append('</div>')

    output.append('</div>')
    return "\n".join(output)
def format_dashboard(feature: Dict) -> str:
    """Build the HTML panel that embeds the Neuronpedia page for one feature.

    Args:
        feature: Dict with ``'feature_id'`` and ``'activation'`` keys, or a
            falsy value when nothing is selected.

    Returns:
        An HTML fragment with a heading and an ``<iframe>`` pointing at the
        feature's Neuronpedia embed URL, or ``""`` when ``feature`` is falsy.
    """
    if not feature:
        return ""

    fid = feature['feature_id']
    act = feature['activation']
    embed_url = (
        f"https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/{fid}"
        "?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
    )
    # Leading and trailing empty entries reproduce the newline that opens and
    # closes the fragment.
    markup = [
        "",
        '<div class="dashboard-container p-4">',
        '<h3 class="text-lg font-semibold mb-4">',
        f"Feature {fid} Dashboard (Activation: {act:.2f})",
        "</h3>",
        "<iframe",
        f'src="{embed_url}"',
        'width="100%"',
        'height="600"',
        'frameborder="0"',
        'class="rounded-lg"',
        "></iframe>",
        "</div>",
        "",
    ]
    return "\n".join(markup)
def analyze_features(text: str, state: Dict) -> Tuple[str, str, Dict]:
    """Analyze *text*, store the result in *state* and render both panels.

    Args:
        text: The text to analyze. Empty input short-circuits with no output.
        state: Per-session dict holding ``'features_data'``,
            ``'expanded_tokens'`` and ``'selected_feature'``. It is mutated
            in place and also returned.

    Returns:
        ``(features_html, dashboard_html, state)``. When the API call fails,
        ``features_html`` is the message ``"Error analyzing text"`` and the
        dashboard is empty.
    """
    if not text:
        return "", "", state

    features_data = get_features(text)
    if not features_data:
        return "Error analyzing text", "", state

    # Update state
    state['features_data'] = features_data
    if not state.get('expanded_tokens'):
        state['expanded_tokens'] = []

    # Every selectable feature in the new results, in display order.
    candidates = [
        feature
        for result in features_data.get('results', [])
        if result['token'] != '<bos>'
        for feature in result['top_features']
    ]

    # Keep the current selection only if it still exists in the new results.
    # Otherwise a selection left over from an earlier text would keep the
    # dashboard pointing at a feature that is no longer listed.
    selected = state.get('selected_feature')
    still_present = bool(selected) and any(
        feature['feature_index'] == selected.get('feature_id') for feature in candidates
    )
    if not still_present:
        if candidates:
            first_feature = candidates[0]
            state['selected_feature'] = {
                'feature_id': first_feature['feature_index'],
                'activation': first_feature['activation_value'],
            }
        else:
            state['selected_feature'] = None

    features_html = format_features(features_data, state['expanded_tokens'], state['selected_feature'])
    dashboard_html = format_dashboard(state['selected_feature'])
    return features_html, dashboard_html, state
def toggle_expansion(token: str, state: Dict) -> Tuple[str, str, Dict]:
    """Flip whether *token* shows its full feature list and re-render.

    Args:
        token: The token whose expansion flag is toggled.
        state: Session dict; its ``'expanded_tokens'`` list is mutated in
            place, and ``'features_data'`` / ``'selected_feature'`` are read.

    Returns:
        ``(features_html, dashboard_html, state)`` rendered from the updated
        state.
    """
    open_tokens = state['expanded_tokens']
    if token not in open_tokens:
        open_tokens.append(token)
    else:
        open_tokens.remove(token)

    data = state['features_data']
    current = state['selected_feature']
    return (
        format_features(data, open_tokens, current),
        format_dashboard(current),
        state,
    )
# Custom stylesheet handed to gr.Blocks(css=...) in create_interface.
# It loads the Open Sans web font for the page body and styles the two classes
# emitted by the HTML builders in this file: `.feature-card` (one card per
# feature in format_features) and `.dashboard-container` (the wrapper produced
# by format_dashboard).
css = """
@import url('https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;600;700&display=swap');
body {
font-family: 'Open Sans', sans-serif !important;
}
.feature-card {
border: 1px solid #e0e5ff;
background-color: #ffffff;
transition: all 0.2s ease;
}
.feature-card:hover {
box-shadow: 0 2px 4px rgba(52, 82, 219, 0.1);
}
.dashboard-container {
border: 1px solid #e0e5ff;
border-radius: 8px;
background-color: #ffffff;
}
"""
# Gradio "Soft" theme with its primary hue replaced by a custom blue scale
# (c50 lightest through c950 darkest). Passed to gr.Blocks in create_interface.
theme = gr.themes.Soft(
    primary_hue=gr.themes.colors.Color(
        name="blue",
        c50="#eef1ff",
        c100="#e0e5ff",
        c200="#c3cbff",
        c300="#a5b2ff",
        c400="#8798ff",
        c500="#6a7eff",
        c600="#3452db",
        c700="#2a41af",
        c800="#1f3183",
        c900="#152156",
        c950="#0a102b",
    )
)
def create_interface():
    """Build the Gradio Blocks UI for the feature analyzer.

    Layout: a left column with the input textbox, the analyze button and
    example inputs; a right column with the feature list and the dashboard
    for the selected feature.

    Returns:
        The assembled ``gr.Blocks`` instance (not yet launched).
    """
    with gr.Blocks(theme=theme, css=css) as interface:
        # The session state must be created inside the Blocks context so it
        # belongs to this app and can be used as an event input/output.
        state = gr.State({
            'features_data': None,
            'expanded_tokens': [],
            'selected_feature': None,
        })

        gr.Markdown("# Neural Feature Analyzer", elem_classes="text-2xl font-bold mb-2")
        gr.Markdown("*Analyze text using Gemma's interpretable neural features*", elem_classes="text-gray-600 mb-6")

        with gr.Row():
            with gr.Column(scale=1):
                input_text = gr.Textbox(
                    lines=5,
                    placeholder="Enter text to analyze...",
                    label="Input Text",
                )
                analyze_btn = gr.Button("Analyze Features", variant="primary")
                gr.Examples(
                    examples=["WordLift", "Think Different", "Just Do It"],
                    inputs=input_text,
                )
            with gr.Column(scale=2):
                features_html = gr.HTML()
                dashboard_html = gr.HTML()

        # Event handlers
        analyze_btn.click(
            fn=analyze_features,
            inputs=[input_text, state],
            outputs=[features_html, dashboard_html, state],
        )

        # Custom JavaScript function for token expansion.
        # NOTE(review): this stub is empty and no event in this function is
        # bound to the Python toggle_expansion handler, so the "Show more/less"
        # buttons rendered by format_features currently have no effect.
        # NOTE(review): newer Gradio releases may expect `js=` instead of
        # `_js=` here; confirm against the pinned Gradio version.
        interface.load(None, None, None, _js="""
        function toggle_expansion(token) {
            // Function will be called from HTML onclick
        }
        """)

    return interface
if __name__ == "__main__":
    # Script entry point: build the UI and start the Gradio server.
    # share=True is passed through to launch() to request a shareable link.
    create_interface().launch(share=True)