# Captured from a Hugging Face Spaces page; Space status at capture time: "Sleeping".
import html
import json
import logging
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import gradio as gr
import requests
@dataclass
class Feature:
    """One feature activation reported by the Neuronpedia API for a token.

    Instances are built by process_features from the ``top_features`` entries
    of each per-token result.
    """

    feature_id: int  # feature index as reported by the API (``feature_index``)
    activation: float  # activation value as reported by the API (``activation_value``)
    token: str  # text of the token this activation belongs to
    position: int  # rank within the token's top-feature list (0 = first returned)
class FeatureState:
    """Per-session UI state: parsed features, expanded tokens and the selection.

    NOTE(review): analyze_features and its helpers keep this same information
    in a plain dict; this class is not referenced in the visible code.
    """

    def __init__(self):
        # token -> list of features found for that token
        self.features_by_token = dict()
        # tokens whose complete (untruncated) feature list should be shown
        self.expanded_tokens = set()
        # nothing is selected until a feature is chosen
        self.selected_feature = None
def get_features(text: str, timeout: float = 30.0) -> Optional[Dict]:
    """Get neural features from the API using the exact website parameters.

    Posts *text* to Neuronpedia's ``search-with-topk`` endpoint for the
    ``gemma-2-2b`` model and the ``20-gemmascope-res-16k`` layer.

    Args:
        text: Input text to analyze.
        timeout: Seconds to wait for the HTTP request before giving up, so a
            stalled server cannot block the UI indefinitely.

    Returns:
        The decoded JSON response, or ``None`` if the request failed, the
        server answered with an HTTP error status, or the body was not JSON.
    """
    url = "https://www.neuronpedia.org/api/search-with-topk"
    payload = {
        "modelId": "gemma-2-2b",
        "text": text,
        "layer": "20-gemmascope-res-16k",
    }
    try:
        response = requests.post(
            url,
            headers={"Content-Type": "application/json"},
            json=payload,
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()
    except (requests.RequestException, ValueError):
        # Best-effort by design: the caller shows a generic error when this
        # returns None. Network/HTTP/JSON failures are logged so they can be
        # diagnosed; unrelated programming errors still propagate.
        logging.getLogger(__name__).exception("Neuronpedia feature request failed")
        return None
def format_feature_list(
    features: "List[Feature]",
    token: str,
    expanded: bool = False,
    limit: int = 3,
) -> str:
    """Format features as an HTML list of feature cards.

    Args:
        features: Features to render, in display order. Each needs
            ``feature_id`` and ``activation`` attributes.
        token: The token these features belong to. Not used in the output;
            kept for call-site compatibility.
        expanded: When True, render every feature; otherwise only the first
            *limit* are rendered.
        limit: Number of cards shown when not expanded (default 3).

    Returns:
        The concatenated card HTML, followed by an "N more features available"
        note when some features were hidden. ``""`` for an empty list.
    """
    # A negative slice bound would drop items from the end instead of
    # showing none, so clamp it.
    limit = max(limit, 0)
    shown = features if expanded else features[:limit]
    parts = [
        f"""
<div class="feature-card p-4 rounded-lg mb-4 cursor-pointer hover:border-blue-500"
     data-feature-id="{feature.feature_id}">
    <div class="flex justify-between items-center">
        <div>
            <span class="font-semibold">Feature {feature.feature_id}</span>
            <span class="ml-2 text-gray-600">(Activation: {feature.activation:.2f})</span>
        </div>
    </div>
</div>
"""
        for feature in shown
    ]
    hidden = len(features) - len(shown)
    if hidden > 0:
        parts.append(f"""
<div class="text-center">
    <span class="text-blue-500 text-sm">{hidden} more features available</span>
</div>
""")
    return "".join(parts)
def format_dashboard(
    feature: "Optional[Feature]",
    model_id: str = "gemma-2-2b",
    layer: str = "20-gemmascope-res-16k",
) -> str:
    """Format the dashboard HTML for a selected feature.

    Args:
        feature: The selected feature (needs ``feature_id`` and
            ``activation``). A falsy value such as ``None`` renders nothing.
        model_id: Neuronpedia model id used in the embed URL. The default
            matches the model queried by get_features.
        layer: Neuronpedia layer/SAE id used in the embed URL. The default
            matches the layer queried by get_features.

    Returns:
        An HTML fragment with a heading and an ``<iframe>`` embedding the
        feature's Neuronpedia page, or ``""`` when *feature* is falsy.
    """
    if not feature:
        return ""
    src = (
        f"https://www.neuronpedia.org/{model_id}/{layer}/{feature.feature_id}"
        "?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
    )
    return f"""
<div class="dashboard-container p-4">
    <h3 class="text-lg font-semibold mb-4 text-gray-900">
        Feature {feature.feature_id} Dashboard (Activation: {feature.activation:.2f})
    </h3>
    <iframe
        src="{src}"
        width="100%"
        height="600"
        frameborder="0"
        class="rounded-lg"
    ></iframe>
</div>
"""
def process_features(data: Dict) -> "Dict[str, List[Feature]]":
    """Process API response into features grouped by token.

    Args:
        data: Decoded JSON from get_features. Keys read here: ``results``
            (list); per result ``token`` and optional ``top_features``; per
            top feature ``feature_index`` and ``activation_value``.

    Returns:
        Mapping, in response order, from a per-occurrence token key to that
        token's features. ``<bos>`` results are skipped. The first occurrence
        of a token is keyed by the token text. Later occurrences of the same
        token get ``"<token> (2)"``, ``"<token> (3)"``, ... so they no longer
        overwrite the earlier entry.
    """
    features_by_token = {}
    for result in data.get('results', []):
        token = result['token']
        if token == '<bos>':
            continue
        # The same token can appear several times in one text. Give each
        # occurrence its own key instead of silently replacing the previous one.
        key, occurrence = token, 1
        while key in features_by_token:
            occurrence += 1
            key = f"{token} ({occurrence})"
        features_by_token[key] = [
            Feature(
                feature_id=feature['feature_index'],
                activation=feature['activation_value'],
                token=token,
                # Rank within this token's top-feature list, not the token's
                # position in the text.
                position=rank,
            )
            for rank, feature in enumerate(result.get('top_features', []))
        ]
    return features_by_token
css = """ | |
@import url('https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;600;700&display=swap'); | |
body { | |
font-family: 'Open Sans', sans-serif !important; | |
} | |
.feature-card { | |
border: 1px solid #e0e5ff; | |
background-color: #ffffff; | |
transition: all 0.2s ease; | |
} | |
.feature-card:hover { | |
border-color: #3452db; | |
box-shadow: 0 2px 4px rgba(52, 82, 219, 0.1); | |
} | |
.dashboard-container { | |
border: 1px solid #e0e5ff; | |
border-radius: 8px; | |
background-color: #ffffff; | |
} | |
""" | |
# Gradio "Soft" theme with a custom primary palette. Shades run from c50
# (lightest, #eef1ff) to c950 (darkest, #0a102b). c100 (#e0e5ff) and
# c600 (#3452db) are the same colours the css string uses for card borders.
theme = gr.themes.Soft(
    primary_hue=gr.themes.colors.Color(
        name="blue",
        c50="#eef1ff",
        c100="#e0e5ff",
        c200="#c3cbff",
        c300="#a5b2ff",
        c400="#8798ff",
        c500="#6a7eff",
        c600="#3452db",
        c700="#2a41af",
        c800="#1f3183",
        c900="#152156",
        c950="#0a102b",
    )
)
def _render_analysis(features_by_token: Dict, state: Dict) -> str:
    """Build the output HTML: one section per token, then the selected feature's dashboard."""
    if not features_by_token:
        return "No features found"
    expanded_tokens = state.get('expanded_tokens', set())
    sections = []
    for token, features in features_by_token.items():
        # Tokens come from the API response, so escape them before embedding
        # them in HTML (e.g. a literal "<" token would otherwise break markup).
        token_html = f"<h2 class='text-xl font-bold mb-4'>Token: {html.escape(token)}</h2>"
        features_html = format_feature_list(features, token, token in expanded_tokens)
        sections.append(f"<div class='mb-6'>{token_html}{features_html}</div>")
    selected = state.get('selected_feature')
    if selected:
        sections.append(format_dashboard(selected))
    return "\n".join(sections)


def analyze_features(text: str, state: Optional[Dict] = None) -> Tuple[str, Optional[Dict]]:
    """Main analysis function that processes text and returns formatted output.

    Args:
        text: Text to analyze. When non-empty, the Neuronpedia API is queried
            and a fresh state replaces any previous analysis. When empty or
            None, the analysis already stored in *state* is re-rendered. This
            is how toggle_expansion and select_feature call it.
        state: Per-session dict with keys ``features_by_token``,
            ``expanded_tokens`` and ``selected_feature``.

    Returns:
        ``(html, state)``. The result is ``("Error analyzing text", None)`` when
        the API call fails. It is ``("", None)`` when there is neither text nor
        a stored analysis.
    """
    if text:
        data = get_features(text)
        if not data:
            return "Error analyzing text", None
        features_by_token = process_features(data)
        # Default selection: first feature of the first token, if any. Using a
        # default for next() avoids StopIteration when no tokens came back.
        first_features = next(iter(features_by_token.values()), [])
        # A new text always starts a fresh state. Reusing the old one would
        # keep the previous text's features and selection.
        state = {
            'features_by_token': features_by_token,
            'expanded_tokens': set(),
            'selected_feature': first_features[0] if first_features else None,
        }
    elif state and state.get('features_by_token'):
        # No new text: re-render the stored analysis with its current
        # expansion/selection instead of discarding the session state.
        features_by_token = state['features_by_token']
    else:
        return "", None
    return _render_analysis(features_by_token, state), state
def toggle_expansion(token: str, state: Dict) -> Tuple[str, Dict]:
    """Flip whether *token*'s full feature list is expanded.

    Adds *token* to ``state['expanded_tokens']`` when absent and removes it
    when present. Then delegates to ``analyze_features(None, state)`` and
    returns its ``(html, state)`` pair.
    """
    # A symmetric difference with a one-element set is exactly
    # "add if absent, remove if present", applied to the same set object.
    state['expanded_tokens'].symmetric_difference_update({token})
    return analyze_features(None, state)
def select_feature(feature_id: int, state: Dict) -> Tuple[str, Dict]:
    """Select a feature and update the dashboard.

    Searches ``state['features_by_token']`` in token order and selects the
    first Feature whose ``feature_id`` equals *feature_id*. If nothing matches,
    the current selection is left unchanged. Returns the ``(html, state)``
    pair from ``analyze_features(None, state)``.
    """
    # next() over a flattened generator stops at the first match. The previous
    # nested loop's `break` only left the inner loop, so a later token's match
    # overrode the first one.
    match = next(
        (
            feature
            for features in state['features_by_token'].values()
            for feature in features
            if feature.feature_id == feature_id
        ),
        None,
    )
    if match is not None:
        state['selected_feature'] = match
    return analyze_features(None, state)
def create_interface():
    """Build the Gradio Blocks app for the neural feature analyzer.

    The layout has an input column (text box, analyze button, examples) and
    a wider output column with the rendered HTML. Clicking the button calls
    analyze_features with the text and the per-session state, and writes
    back both the HTML and the updated state.

    Returns:
        The constructed ``gr.Blocks`` app; the caller is responsible for
        launching it.
    """
    with gr.Blocks(theme=theme, css=css) as interface:
        # Session state must be created inside the Blocks context so it is
        # registered with this app and can be wired into the event below.
        state = gr.State({})
        gr.Markdown("# Neural Feature Analyzer", elem_classes="text-2xl font-bold mb-2")
        gr.Markdown(
            "*Analyze text using Gemma's interpretable neural features*",
            elem_classes="text-gray-600 mb-6",
        )
        with gr.Row():
            with gr.Column(scale=1):
                input_text = gr.Textbox(
                    lines=5,
                    placeholder="Enter text to analyze...",
                    label="Input Text",
                )
                analyze_btn = gr.Button("Analyze Features", variant="primary")
                gr.Examples(
                    examples=["WordLift", "Think Different", "Just Do It"],
                    inputs=input_text,
                )
            with gr.Column(scale=2):
                output = gr.HTML()
        # Event handlers
        analyze_btn.click(
            fn=analyze_features,
            inputs=[input_text, state],
            outputs=[output, state],
        )
    return interface
if __name__ == "__main__":
    # Running the file directly builds the UI and starts the Gradio server.
    app = create_interface()
    app.launch()