brand-llms / app.py
cyberandy's picture
Update app.py
e53e16b verified
raw
history blame
7.68 kB
import html
import json
import logging
import uuid
from typing import Dict, List, Optional, Tuple

import gradio as gr
import requests
def get_features(text: str, timeout: float = 30.0) -> Optional[Dict]:
    """Fetch per-token top-k features for ``text`` from the Neuronpedia API.

    Uses the same parameters as the Neuronpedia website: model ``gemma-2-2b``
    and layer ``20-gemmascope-res-16k``.

    Args:
        text: Text to send to the ``search-with-topk`` endpoint.
        timeout: Seconds to wait for the HTTP request. Without a timeout
            ``requests`` may block forever on an unresponsive server.

    Returns:
        The decoded JSON response body, or ``None`` if the request failed,
        the server returned an HTTP error status, or the body was not JSON.
    """
    url = "https://www.neuronpedia.org/api/search-with-topk"
    payload = {
        "modelId": "gemma-2-2b",
        "text": text,
        "layer": "20-gemmascope-res-16k",
    }
    try:
        response = requests.post(
            url,
            headers={"Content-Type": "application/json"},
            json=payload,
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()
    except (requests.RequestException, ValueError) as exc:
        # Callers treat None as "analysis failed". Record why instead of
        # discarding the cause; ValueError covers a non-JSON response body.
        logging.getLogger(__name__).warning("Neuronpedia request failed: %s", exc)
        return None
def create_feature_html(feature_id: int, activation: float, selected: bool = False) -> str:
    """Render a single clickable feature card as an HTML fragment.

    Clicking the card calls the page-level ``selectFeature`` JavaScript
    function with the card element and ``feature_id``.

    Args:
        feature_id: Feature index shown on the card and passed to the click handler.
        activation: Activation value, displayed with two decimal places.
        selected: When true the card gets the highlighted (blue) border.

    Returns:
        The card's HTML, starting and ending with a newline.
    """
    if selected:
        card_border = "border-blue-500 border-2"
    else:
        card_border = "border border-gray-200"

    # Joined with newlines; the empty first/last entries reproduce the
    # leading and trailing line breaks of the fragment.
    fragment_lines = [
        "",
        f'<div class="feature-card mb-4 {card_border} rounded-lg shadow hover:shadow-md transition-all cursor-pointer p-4"',
        f'data-feature-id="{feature_id}" onclick="selectFeature(this, {feature_id})">',
        '<div class="flex justify-between items-center">',
        "<div>",
        f'<span class="font-semibold">Feature {feature_id}</span>',
        f'<span class="ml-2 text-gray-600">(Activation: {activation:.2f})</span>',
        "</div>",
        "</div>",
        "</div>",
        "",
    ]
    return "\n".join(fragment_lines)
def create_token_section(
    token: str,
    features: List[Dict],
    initial_count: int = 3,
    section_id: Optional[str] = None,
) -> str:
    """Render a token heading followed by its feature cards.

    The first ``initial_count`` cards are visible. Any remaining cards go into
    a container with the ``hidden`` class, plus a "Show N More Features"
    button that calls the page-level ``toggleFeatures`` JavaScript function
    with this section's id.

    Args:
        token: Token text from the API. It is HTML-escaped for display and is
            never used to build element ids or JavaScript string literals.
        features: Feature dicts, each with ``feature_index`` and
            ``activation_value`` keys.
        initial_count: Number of cards shown before expanding.
        section_id: Identifier used in this section's element ids and passed
            to ``toggleFeatures``. It must contain only characters that are
            safe in an HTML id and a single-quoted JS string (letters, digits,
            ``-``). When omitted, a random unique id is generated.

    Returns:
        The section's HTML fragment.
    """
    if section_id is None:
        # Token text may repeat within one input and may contain quotes or
        # markup, so it cannot serve as a unique, safe element id.
        section_id = f"tok-{uuid.uuid4().hex[:12]}"

    def render_cards(items: List[Dict]) -> str:
        # Concatenate one feature card per item, in order.
        return "".join(
            create_feature_html(f["feature_index"], f["activation_value"])
            for f in items
        )

    visible = features[:initial_count]
    extra = features[initial_count:]
    features_html = render_cards(visible)

    show_more = ""
    if extra:
        show_more = f"""
<div class="hidden" id="more-features-{section_id}">{render_cards(extra)}</div>
<button class="text-blue-500 hover:text-blue-700 text-sm mt-2"
onclick="toggleFeatures('{section_id}', this)">
Show {len(extra)} More Features
</button>
"""
    return f"""
<div class="mb-6">
<h2 class="text-xl font-bold mb-4">Token: {html.escape(token)}</h2>
<div id="features-{section_id}">
{features_html}
</div>
{show_more}
</div>
"""
def create_dashboard_html(feature_id: int, activation: float) -> str:
    """Render the bordered panel embedding the Neuronpedia page for a feature.

    Args:
        feature_id: Feature index, used in the heading and the iframe URL.
        activation: Activation value shown in the heading (two decimals).

    Returns:
        HTML containing a heading and a 600px-high embedded iframe.
    """
    embed_query = "&".join([
        "embed=true",
        "embedexplanation=true",
        "embedplots=true",
        "embedtest=true",
        "height=300",
    ])
    iframe_src = (
        "https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/"
        f"{feature_id}?{embed_query}"
    )
    heading = f"Feature {feature_id} Dashboard (Activation: {activation:.2f})"
    panel_lines = [
        "",
        '<div class="border border-gray-200 rounded-lg p-4">',
        f'<h3 class="text-lg font-semibold mb-4">{heading}</h3>',
        "<iframe",
        f'src="{iframe_src}"',
        'width="100%"',
        'height="600"',
        'frameborder="0"',
        'class="rounded-lg"',
        "></iframe>",
        "</div>",
        "",
    ]
    return "\n".join(panel_lines)
def create_interface_html(data: Dict) -> str:
    """Build the complete results view: token sections, dashboard and page script.

    Args:
        data: Decoded response from the search-with-topk endpoint. This
            function reads a ``results`` list whose items carry ``token`` and
            ``top_features``. Missing keys are treated as empty.

    Returns:
        HTML with two columns. The left column holds the feature cards for
        each token (the ``<bos>`` token is skipped). The right column holds
        the dashboard for the first feature of the first remaining token that
        has any features, or a "No features found" message.
    """
    # NOTE(review): gr.HTML likely inserts this markup via innerHTML, and
    # browsers do not execute <script> elements inserted that way. Confirm
    # that selectFeature/toggleFeatures actually exist on the page; they may
    # need to be registered through gr.Blocks instead.
    js_code = """
<script>
function selectFeature(element, featureId) {
    // Reset every card to the unselected border, then highlight the clicked one.
    document.querySelectorAll('.feature-card').forEach(card => {
        card.classList.remove('border-blue-500', 'border-2');
        card.classList.add('border', 'border-gray-200');
    });
    element.classList.remove('border', 'border-gray-200');
    element.classList.add('border-blue-500', 'border-2');
    const container = document.getElementById('dashboard-container');
    if (!container) {
        return;
    }
    // Reuse the card's "(Activation: x.xx)" label so the header matches the
    // server-rendered dashboard.
    const label = element.querySelector('.text-gray-600');
    const activationText = label ? ' ' + label.textContent.trim() : '';
    container.innerHTML =
        `<div class="border border-gray-200 rounded-lg p-4">
<h3 class="text-lg font-semibold mb-4">Feature ${featureId} Dashboard${activationText}</h3>
<iframe src="https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/${featureId}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
width="100%" height="600" frameborder="0" class="rounded-lg"></iframe>
</div>`;
}
function toggleFeatures(sectionId, button) {
    // Show or hide the extra cards in place. They live only in the
    // more-features container; copying them into the visible list as well
    // would display every extra card twice.
    const moreFeatures = document.getElementById(`more-features-${sectionId}`);
    if (!moreFeatures) {
        return;
    }
    const nowHidden = moreFeatures.classList.toggle('hidden');
    button.textContent = nowHidden
        ? `Show ${moreFeatures.querySelectorAll('.feature-card').length} More Features`
        : 'Show Less';
}
</script>
"""
    token_sections = []
    first_feature = None
    for result in (data or {}).get("results") or []:
        token = result.get("token", "")
        if token == "<bos>":
            # Skip the "<bos>" marker (presumably the model's
            # beginning-of-sequence token, not part of the user's text).
            continue
        top_features = result.get("top_features") or []
        token_sections.append(create_token_section(token, top_features))
        if first_feature is None and top_features:
            first_feature = top_features[0]
    tokens_html = "".join(token_sections)

    if first_feature is not None:
        dashboard_html = create_dashboard_html(
            first_feature["feature_index"],
            first_feature["activation_value"],
        )
    else:
        dashboard_html = '<p class="text-gray-600">No features found for this text.</p>'

    return f"""
<div class="p-6">
{js_code}
<div class="grid grid-cols-1 lg:grid-cols-2 gap-8">
<div class="space-y-6">
{tokens_html}
</div>
<div class="lg:sticky lg:top-6">
<div id="dashboard-container">
{dashboard_html}
</div>
</div>
</div>
</div>
"""
def analyze_features(text: str) -> Tuple[str, str, str]:
    """Gradio click handler: analyze ``text`` and render the results view.

    Args:
        text: Contents of the input textbox.

    Returns:
        A 3-tuple matching the three outputs wired to the Analyze button.
        Only the first element carries content (the results HTML or a
        message); the other two are always empty strings.
    """
    if not text or not text.strip():
        # Avoid a pointless API round-trip for empty input.
        return "Please enter some text to analyze.", "", ""
    data = get_features(text)
    if not data:
        return "Error analyzing text", "", ""
    try:
        interface_html = create_interface_html(data)
    except (KeyError, TypeError, AttributeError, ValueError):
        # An unexpected response shape gets the same user-facing error
        # instead of crashing the callback.
        return "Error analyzing text", "", ""
    return interface_html, "", ""
def create_interface():
    """Assemble the Gradio Blocks UI and wire the Analyze button to analyze_features.

    Layout: a title and subtitle, then a row with two columns. The left column
    has the input textbox, the Analyze button and example prompts. The right
    column has the results HTML plus two invisible outputs that
    analyze_features fills with empty strings.

    Returns:
        The Blocks app, not yet launched.
    """
    example_prompts = [
        "WordLift",
        "Think Different",
        "Just Do It",
    ]
    with gr.Blocks(css="") as demo:
        gr.Markdown("# Neural Feature Analyzer")
        gr.Markdown("*Analyze text using Gemma's interpretable neural features*")
        with gr.Row():
            with gr.Column():
                text_box = gr.Textbox(
                    lines=5,
                    placeholder="Enter text to analyze...",
                    label="Input Text",
                )
                run_button = gr.Button("Analyze Features", variant="primary")
                gr.Examples(example_prompts, inputs=text_box)
            with gr.Column():
                results_view = gr.HTML()
                hidden_label = gr.Text(show_label=False, visible=False)
                hidden_dashboard = gr.HTML(visible=False)
        run_button.click(
            fn=analyze_features,
            inputs=text_box,
            outputs=[results_view, hidden_label, hidden_dashboard],
        )
    return demo
if __name__ == "__main__":
    # Build the UI and start the Gradio server when run as a script.
    app = create_interface()
    app.launch()