Spaces:
Sleeping
Sleeping
import json
from typing import Dict, List, Optional, Tuple

import gradio as gr
import requests
# Sample brand slogans offered as one-click inputs via gr.Examples.
BRAND_EXAMPLES: List[str] = [
    "Nike - Just Do It. The power of determination.",
    "Apple - Think Different. Innovation redefined.",
    "McDonald's - I'm Lovin' It. Creating joy.",
    "BMW - The Ultimate Driving Machine.",
    "L'Oréal - Because You're Worth It.",
]
def get_top_features(text: str, k: int = 5, timeout: float = 30.0) -> Optional[Dict]:
    """Ask Neuronpedia's ``search-with-topk`` endpoint to analyze *text*.

    The request targets the ``gemma-2-2b`` model, source
    ``0-gemmascope-mlp-16k`` of the ``gemma-scope`` source set, with
    ``maxDensity`` 0.01 and ``ignoreBos`` enabled.

    Args:
        text: Text sent in the payload's ``text`` field.
        k: Value sent in the payload's ``k`` field.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` can block forever on a stalled connection.

    Returns:
        The decoded JSON response on HTTP 200. ``None`` when the request
        fails, the status is not 200, or the body is not valid JSON.
    """
    url = "https://www.neuronpedia.org/api/search-with-topk"
    payload = {
        "modelId": "gemma-2-2b",
        "layer": "0-gemmascope-mlp-16k",
        "sourceSet": "gemma-scope",
        "text": text,
        "k": k,
        "maxDensity": 0.01,
        "ignoreBos": True,
    }

    try:
        response = requests.post(
            url,
            headers={"Content-Type": "application/json"},
            json=payload,
            timeout=timeout,
        )
    except requests.RequestException:
        # Connection errors and timeouts take the same "no data" path as a
        # non-200 status, so the caller only has one failure case to handle.
        return None

    if response.status_code != 200:
        return None

    try:
        return response.json()
    except ValueError:
        # A 200 response whose body is not JSON (e.g. an HTML error page).
        return None
def format_output(data: Dict) -> Tuple[str, str, str]:
    """Render an analysis response as markdown, a dashboard iframe and a label.

    Args:
        data: Response dictionary. It is read as
            ``{"results": [{"token": str, "top_features":
            [{"feature_index": ..., "activation_value": ...}, ...]}, ...]}``.
            A falsy value (``None`` or an empty dict) is treated as a failed
            analysis.

    Returns:
        A ``(markdown, iframe_html, label)`` tuple. On failed analysis this is
        ``("Error analyzing text", "", "")``. When no feature has an
        activation above zero, ``iframe_html`` is empty and ``label`` is
        ``"No significant features found"``.
    """
    if not data:
        return "Error analyzing text", "", ""

    # Tolerate responses that lack the expected keys instead of raising.
    results = data.get('results') or []

    # Per-token feature listing; '<bos>' tokens and tokens without features
    # are left out of the markdown.
    parts = ["# Neural Feature Analysis\n\n"]
    for result in results:
        token = result['token']
        if token == '<bos>':
            continue
        features = result.get('top_features') or []
        if not features:
            continue
        parts.append(f"\n## Token: '{token}'\n")
        parts.extend(
            f"- **Feature {feat['feature_index']}**: "
            f"activation = {feat['activation_value']:.2f}\n"
            for feat in features
        )
    output = "".join(parts)

    # Pick the single most strongly activated feature for the dashboard.
    # Only activations strictly greater than zero are considered.
    max_activation = 0
    max_feature = None
    for result in results:
        for feature in result.get('top_features') or []:
            if feature['activation_value'] > max_activation:
                max_activation = feature['activation_value']
                max_feature = feature['feature_index']

    # Compare against None explicitly: feature index 0 is a valid feature
    # and must not be mistaken for "nothing found".
    if max_feature is not None:
        dashboard_url = (
            f"https://www.neuronpedia.org/gemma-2-2b/0-gemmascope-mlp-16k/{max_feature}"
            "?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
        )
        iframe = (
            f'<iframe src="{dashboard_url}" width="100%" height="600px" '
            'frameborder="0" style="border:1px solid #eee;border-radius:8px;"></iframe>'
        )
        feature_label = (
            f"Feature {max_feature} Dashboard (Highest Activation: {max_activation:.2f})"
        )
    else:
        iframe = ""
        feature_label = "No significant features found"

    return output, iframe, feature_label
def create_interface():
    """Build the Gradio Blocks UI and return it without launching.

    The page has two columns: the left holds the input textbox, the analyze
    button and the example slogans; the right holds the markdown report, a
    text label and an HTML area for the embedded dashboard. Clicking the
    button runs the text through ``get_top_features`` and ``format_output``.
    """

    def _run_analysis(text):
        # format_output returns (markdown, iframe, label), matching the
        # order of the outputs wired to the click handler below.
        return format_output(get_top_features(text))

    with gr.Blocks() as interface:
        gr.Markdown("# Neural Feature Analyzer")
        gr.Markdown(
            "Analyze text using Gemma's interpretable neural features\n\nShows top 5 most activated features for each token with density < 1%"
        )

        with gr.Row():
            # Left column: input controls.
            with gr.Column():
                input_text = gr.Textbox(
                    lines=5,
                    placeholder="Enter text to analyze...",
                    label="Input Text",
                )
                analyze_btn = gr.Button("Analyze Neural Features", variant="primary")
                gr.Examples(BRAND_EXAMPLES, inputs=input_text)

            # Right column: analysis results.
            with gr.Column():
                output_text = gr.Markdown()
                feature_label = gr.Text(show_label=False)
                dashboard = gr.HTML()

        analyze_btn.click(
            fn=_run_analysis,
            inputs=input_text,
            outputs=[output_text, dashboard, feature_label],
        )

    return interface
# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()