Spaces:
Sleeping
Sleeping
import html
import logging
from typing import Dict, List, Optional, Tuple

import gradio as gr
import requests
# Custom CSS: the Open Sans web font plus the blue accent colours applied to
# the generated feature cards, token headings and dashboard panel.
css = """
@import url('https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;600;700&display=swap');
body {
    font-family: 'Open Sans', sans-serif !important;
}
.primary-btn {
    background-color: #3452db !important;
}
.primary-btn:hover {
    background-color: #2a41af !important;
}
/* Cards carry an onclick handler, so give them a click affordance. */
.feature-card {
    border: 1px solid #e0e5ff;
    background-color: #ffffff;
    cursor: pointer;
    transition: all 0.2s ease;
}
.feature-card:hover {
    border-color: #3452db;
    box-shadow: 0 2px 4px rgba(52, 82, 219, 0.1);
}
.feature-card.selected {
    border: 2px solid #3452db;
    background-color: #eef1ff;
}
.show-more-btn {
    color: #3452db;
    font-weight: 600;
    cursor: pointer;
}
.show-more-btn:hover {
    color: #2a41af;
}
.token-header {
    color: #152156;
    font-weight: 700;
}
.dashboard-container {
    border: 1px solid #e0e5ff;
    border-radius: 8px;
    background-color: #ffffff;
}
"""
# Brand palette for the primary hue, ordered from lightest (c50) to darkest
# (c950). c600 is the same #3452db used for .primary-btn in the custom CSS.
_BRAND_BLUE = gr.themes.colors.Color(
    name="blue",
    c50="#eef1ff",
    c100="#e0e5ff",
    c200="#c3cbff",
    c300="#a5b2ff",
    c400="#8798ff",
    c500="#6a7eff",
    c600="#3452db",
    c700="#2a41af",
    c800="#1f3183",
    c900="#152156",
    c950="#0a102b",
)

# Gradio's Soft base theme, recoloured with the brand palette.
theme = gr.themes.Soft(primary_hue=_BRAND_BLUE)
def get_features(text: str, timeout: float = 30.0) -> Optional[Dict]:
    """Fetch the top-k feature activations for *text* from Neuronpedia.

    Posts to the ``search-with-topk`` endpoint with model ``gemma-2-2b`` and
    source ``20-gemmascope-res-16k``. These are the same identifiers used in the
    dashboard embed URLs.

    Args:
        text: The text to analyze.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The decoded JSON object on success. Returns ``None`` on any of these
        failures: the request errors or times out, the server answers with an
        HTTP error status, or the body is not a JSON object.
    """
    url = "https://www.neuronpedia.org/api/search-with-topk"
    payload = {
        "modelId": "gemma-2-2b",
        "text": text,
        "layer": "20-gemmascope-res-16k",
    }
    logger = logging.getLogger(__name__)
    try:
        # Without a timeout a stalled API would block the Gradio handler forever.
        response = requests.post(
            url,
            headers={"Content-Type": "application/json"},
            json=payload,
            timeout=timeout,
        )
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError) as exc:
        # RequestException covers connection errors, timeouts and HTTP error
        # statuses. ValueError covers an undecodable body on requests versions
        # whose JSON error is not a RequestException.
        logger.warning("Neuronpedia request failed: %s", exc)
        return None
    if not isinstance(data, dict):
        # Downstream rendering indexes the result by key, so anything else is unusable.
        logger.warning("Unexpected Neuronpedia response type: %s", type(data).__name__)
        return None
    return data
# Inline click handler shared by all feature cards. gr.HTML inserts its value
# with innerHTML, and <script> elements inserted that way never execute, so a
# card cannot call a page-level helper such as selectFeature(). Everything the
# click needs is spelled out here. The handler does three things:
#   1. clears the previous selection,
#   2. marks the clicked card as selected,
#   3. copies the card's pre-rendered <template> into #dashboard-container.
# It must not contain double quotes: it is interpolated into a double-quoted
# HTML attribute.
_SELECT_FEATURE_ONCLICK = (
    "document.querySelectorAll('.feature-card.selected').forEach("
    "function (card) { card.classList.remove('selected'); });"
    " this.classList.add('selected');"
    " var panel = document.getElementById('dashboard-container');"
    " if (panel) { panel.innerHTML = this.querySelector('template').innerHTML; }"
)


def create_feature_html(feature_id: int, activation: float, selected: bool = False) -> str:
    """Render a clickable card for a single feature.

    Args:
        feature_id: Feature index. It is shown on the card and used for the
            dashboard embed.
        activation: Activation value, displayed with two decimals.
        selected: Render the card with the ``selected`` CSS class.

    Returns:
        An HTML fragment. The card carries that feature's dashboard markup
        (from ``create_dashboard_html``) inside a ``<template>``. Clicking the
        card swaps that markup into the ``#dashboard-container`` element.
    """
    card_classes = "feature-card selected" if selected else "feature-card"
    # <template> content is parsed but not rendered, so the embedded iframe is
    # only loaded once the click handler copies it into the dashboard panel.
    dashboard_markup = create_dashboard_html(feature_id, activation)
    return f"""
    <div class="{card_classes} p-4 rounded-lg mb-4"
         data-feature-id="{feature_id}"
         onclick="{_SELECT_FEATURE_ONCLICK}">
        <div class="flex justify-between items-center">
            <div>
                <span class="font-semibold">Feature {feature_id}</span>
                <span class="ml-2 text-gray-600">(Activation: {activation:.2f})</span>
            </div>
        </div>
        <template>{dashboard_markup}</template>
    </div>
    """
def create_token_section(token: str, features: List[Dict], initial_count: int = 3) -> str:
    """Render one token's heading followed by its feature cards.

    The first ``initial_count`` features are always visible. Any remainder is
    wrapped in a native ``<details>`` disclosure, so it can be expanded without
    JavaScript.

    Args:
        token: Raw token text. It is HTML-escaped before rendering.
        features: Feature dicts, each read for ``'feature_index'`` and
            ``'activation_value'``.
        initial_count: Number of features shown before the disclosure.

    Returns:
        An HTML fragment for this token.
    """
    # The token is untrusted model output. It may contain '<', '&' or quotes,
    # and the same token can occur several times in one text. It is therefore
    # only escaped for display and never used in ids or JS string literals.
    safe_token = html.escape(token)
    visible, overflow = features[:initial_count], features[initial_count:]

    visible_html = "".join(
        create_feature_html(f["feature_index"], f["activation_value"]) for f in visible
    )

    more_html = ""
    if overflow:
        overflow_html = "".join(
            create_feature_html(f["feature_index"], f["activation_value"]) for f in overflow
        )
        # <details>/<summary> toggles natively. Script-based toggles cannot run
        # inside gr.HTML output.
        more_html = f"""
        <details class="more-features">
            <summary class="show-more-btn text-sm mt-2">Show {len(overflow)} More Features</summary>
            {overflow_html}
        </details>
        """

    return f"""
    <div class="mb-6">
        <h2 class="token-header text-xl mb-4">Token: {safe_token}</h2>
        <div class="token-features">
            {visible_html}
        </div>
        {more_html}
    </div>
    """
def create_dashboard_html(feature_id: int, activation: float) -> str:
    """Build the dashboard panel that embeds the Neuronpedia page for a feature.

    Args:
        feature_id: Feature index, used in the heading and in the embed URL.
        activation: Activation value, shown in the heading with two decimals.

    Returns:
        An HTML fragment with a heading and a 600px-tall embed iframe.
    """
    heading = f"Feature {feature_id} Dashboard (Activation: {activation:.2f})"
    embed_src = (
        "https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/"
        f"{feature_id}"
        "?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
    )
    return f"""
<div class="dashboard-container p-4">
<h3 class="text-lg font-semibold mb-4 text-gray-900">
{heading}
</h3>
<iframe
src="{embed_src}"
width="100%"
height="600"
frameborder="0"
class="rounded-lg"
></iframe>
</div>
"""
def create_interface_html(data: Dict) -> str:
    """Assemble the results page: per-token feature lists plus a dashboard.

    Args:
        data: Decoded API response. It is read for a ``'results'`` list whose
            items carry ``'token'`` and ``'top_features'``. Missing keys are
            treated as empty rather than raising, so an error-shaped response
            still renders.

    Returns:
        An HTML fragment. The token sections go in the left column. The right
        column holds ``#dashboard-container``, pre-filled with the dashboard of
        the first feature found.
    """
    # No <script> is embedded: gr.HTML inserts this string via innerHTML, and
    # scripts inserted that way never execute. The markup must therefore work
    # without any page-level JavaScript.
    sections = []
    dashboard_html = ""
    for result in data.get("results") or []:
        token = result.get("token", "")
        if token == "<bos>":
            # Skip the beginning-of-sequence marker; it is not part of the user's text.
            continue
        features = result.get("top_features") or []
        sections.append(create_token_section(token, features))
        if not dashboard_html and features:
            first = features[0]
            dashboard_html = create_dashboard_html(
                first["feature_index"],
                first["activation_value"],
            )
    tokens_html = "".join(sections)

    return f"""
    <div class="p-6">
        <div class="grid grid-cols-1 lg:grid-cols-2 gap-8">
            <div class="space-y-6">
                {tokens_html}
            </div>
            <div class="lg:sticky lg:top-6">
                <div id="dashboard-container">
                    {dashboard_html}
                </div>
            </div>
        </div>
    </div>
    """
def analyze_features(text: str) -> Tuple[str, str, str]:
    """Gradio click handler: fetch features for *text* and render them.

    Args:
        text: Contents of the input textbox.

    Returns:
        A 3-tuple for the handler's three outputs. The first item is the
        rendered results HTML, or a plain-text message when there is nothing
        to analyze or the API call failed. The other two are always empty
        strings.
    """
    if not text or not text.strip():
        # Don't spend a network round-trip on empty input.
        return "Please enter some text to analyze.", "", ""
    data = get_features(text)
    if not data:
        return "Error analyzing text", "", ""
    return create_interface_html(data), "", ""
def create_interface():
    """Build (but do not launch) the Gradio Blocks app.

    Layout: a header, then a row with the text input, analyze button and
    examples on the left and the results HTML on the right. The button runs
    ``analyze_features`` and routes its three return values into the results
    HTML and two hidden components.
    """
    with gr.Blocks(theme=theme, css=css) as app:
        gr.Markdown("# Neural Feature Analyzer", elem_classes="text-2xl font-bold text-gray-900 mb-2")
        gr.Markdown(
            "*Analyze text using Gemma's interpretable neural features*",
            elem_classes="text-gray-600 mb-6",
        )

        with gr.Row():
            # Left column: input controls.
            with gr.Column():
                input_text = gr.Textbox(
                    label="Input Text",
                    placeholder="Enter text to analyze...",
                    lines=5,
                    elem_classes="mb-4",
                )
                analyze_btn = gr.Button("Analyze Features", variant="primary", elem_classes="primary-btn")
                gr.Examples(
                    ["WordLift", "Think Different", "Just Do It"],
                    inputs=input_text,
                    elem_classes="mt-4",
                )

            # Right column: rendered results.
            with gr.Column():
                output_html = gr.HTML()
                # Hidden sinks for the second and third values analyze_features returns.
                feature_label = gr.Text(show_label=False, visible=False)
                dashboard = gr.HTML(visible=False)

        analyze_btn.click(
            fn=analyze_features,
            inputs=input_text,
            outputs=[output_html, feature_label, dashboard],
        )

    return app
if __name__ == "__main__":
    # Build the Blocks app and start Gradio's web server.
    app = create_interface()
    app.launch()