# brand-llms / app.py
# cyberandy's picture
# Update app.py
# f643580 verified
# raw
# history blame
# 9.9 kB
import html
import logging
from typing import Dict, List, Optional, Tuple

import gradio as gr
import requests
# Custom CSS for the app: loads the Open Sans web font and styles the feature
# cards, "show more" buttons, token headers and dashboard panel in the same
# brand blues used by the theme's primary palette.
#
# The trailing `.hidden` rule is required by the token sections, which park
# their overflow feature cards in a `class="hidden"` container that the
# "Show N More Features" button toggles. Without a global rule for that class
# those cards would always be visible.
css = """
@import url('https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;600;700&display=swap');
body {
font-family: 'Open Sans', sans-serif !important;
}
.primary-btn {
background-color: #3452db !important;
}
.primary-btn:hover {
background-color: #2a41af !important;
}
.feature-card {
border: 1px solid #e0e5ff;
background-color: #ffffff;
transition: all 0.2s ease;
}
.feature-card:hover {
border-color: #3452db;
box-shadow: 0 2px 4px rgba(52, 82, 219, 0.1);
}
.feature-card.selected {
border: 2px solid #3452db;
background-color: #eef1ff;
}
.show-more-btn {
color: #3452db;
font-weight: 600;
}
.show-more-btn:hover {
color: #2a41af;
}
.token-header {
color: #152156;
font-weight: 700;
}
.dashboard-container {
border: 1px solid #e0e5ff;
border-radius: 8px;
background-color: #ffffff;
}
.hidden {
display: none !important;
}
"""
# Brand-blue shades, lightest (c50) to darkest (c950), used as the primary
# palette of the app's theme. c600/c700 match the button colours in `css`.
_PRIMARY_SHADES = {
    "c50": "#eef1ff",
    "c100": "#e0e5ff",
    "c200": "#c3cbff",
    "c300": "#a5b2ff",
    "c400": "#8798ff",
    "c500": "#6a7eff",
    "c600": "#3452db",
    "c700": "#2a41af",
    "c800": "#1f3183",
    "c900": "#152156",
    "c950": "#0a102b",
}

# Gradio's Soft theme, recoloured with the palette above as its primary hue.
theme = gr.themes.Soft(
    primary_hue=gr.themes.colors.Color(name="blue", **_PRIMARY_SHADES),
)
def get_features(text: str, timeout: float = 30.0) -> Optional[Dict]:
    """Fetch per-token top-k neural features for *text* from Neuronpedia.

    Posts *text* to the ``search-with-topk`` endpoint with the same model
    (``gemma-2-2b``) and SAE layer (``20-gemmascope-res-16k``) the Neuronpedia
    website uses.

    Args:
        text: The text to analyze.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The decoded JSON response body. Returns ``None`` in three cases:
        the request fails (connection error or timeout), the server answers
        with an HTTP error status, or the body is not valid JSON.
    """
    url = "https://www.neuronpedia.org/api/search-with-topk"
    payload = {
        "modelId": "gemma-2-2b",
        "text": text,
        "layer": "20-gemmascope-res-16k",
    }
    try:
        # Without an explicit timeout, requests waits indefinitely, which
        # would leave the UI handler hanging on a stalled connection.
        response = requests.post(
            url,
            headers={"Content-Type": "application/json"},
            json=payload,
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()
    except (requests.RequestException, ValueError):
        # RequestException covers connection errors, timeouts and HTTP error
        # statuses; ValueError covers an undecodable JSON body. Callers treat
        # None as "analysis failed", so keep that contract but record why.
        logging.getLogger(__name__).exception("Neuronpedia feature lookup failed")
        return None
def create_feature_html(feature_id: int, activation: float, selected: bool = False) -> str:
    """Render a single clickable feature card.

    The card carries the feature id in ``data-feature-id``. Its ``onclick``
    handler calls the page-level ``selectFeature`` JS function with the
    feature id and raw activation. The visible label shows the activation
    rounded to two decimals.

    Args:
        feature_id: Index of the SAE feature.
        activation: Activation value of the feature for the current token.
        selected: When true, the card gets the ``selected`` CSS class.

    Returns:
        The card's HTML markup.
    """
    state_class = "selected" if selected else ""
    markup_lines = [
        f'<div class="feature-card {state_class} p-4 rounded-lg mb-4"',
        f'data-feature-id="{feature_id}"',
        f'onclick="selectFeature(this, {feature_id}, {activation})">',
        '<div class="flex justify-between items-center">',
        "<div>",
        f'<span class="font-semibold">Feature {feature_id}</span>',
        f'<span class="ml-2 text-gray-600">(Activation: {activation:.2f})</span>',
        "</div>",
        "</div>",
        "</div>",
    ]
    return "\n" + "\n".join(markup_lines) + "\n"
def _token_dom_id(token: str) -> str:
    """Return an HTML-id-safe identifier derived from *token*."""
    # Hex-encoding the UTF-8 bytes keeps the id within [0-9a-f]. That makes it
    # valid as an HTML id and safe inside the single-quoted JS string of the
    # onclick handler. It is also one-to-one, so distinct tokens never collide
    # the way a character-stripping slug could.
    return "tok-" + token.encode("utf-8").hex()


def create_token_section(token: str, features: List[Dict], initial_count: int = 3) -> str:
    """Render a token heading followed by its feature cards.

    The first *initial_count* features are shown immediately. Any remaining
    ones go into a ``hidden`` container, toggled by a button that calls the
    page-level ``toggleFeatures`` JS function with this section's DOM id.

    Args:
        token: The token text. It comes from user input, so it is
            HTML-escaped for display and never used raw in ids or JS.
        features: Feature dicts; each must provide ``feature_index`` and
            ``activation_value``.
        initial_count: How many features to show before "Show N More".

    Returns:
        The section's HTML markup.
    """
    dom_id = _token_dom_id(token)
    visible, overflow = features[:initial_count], features[initial_count:]
    features_html = "".join(
        create_feature_html(f["feature_index"], f["activation_value"])
        for f in visible
    )
    show_more = ""
    if overflow:
        hidden_features = "".join(
            create_feature_html(f["feature_index"], f["activation_value"])
            for f in overflow
        )
        show_more = f"""
<div class="hidden" id="more-features-{dom_id}">{hidden_features}</div>
<button id="toggle-btn-{dom_id}"
class="show-more-btn text-sm mt-2"
onclick="toggleFeatures('{dom_id}')">
Show {len(overflow)} More Features
</button>
"""
    return f"""
<div class="mb-6">
<h2 class="token-header text-xl mb-4">Token: {html.escape(token)}</h2>
<div id="features-{dom_id}">
{features_html}
</div>
{show_more}
</div>
"""
def create_dashboard_html(feature_id: int, activation: float) -> str:
    """Render the dashboard panel for one feature.

    The panel has a heading with the feature id and the activation rounded to
    two decimals. Below it is an iframe embedding that feature's Neuronpedia
    page (gemma-2-2b, layer 20-gemmascope-res-16k) in embed mode.

    Args:
        feature_id: Index of the SAE feature to display.
        activation: Activation value shown in the heading.

    Returns:
        The dashboard's HTML markup.
    """
    embed_url = (
        f"https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/{feature_id}"
        "?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
    )
    parts = [
        '<div class="dashboard-container p-4">',
        '<h3 class="text-lg font-semibold mb-4 text-gray-900">',
        f"Feature {feature_id} Dashboard (Activation: {activation:.2f})",
        "</h3>",
        "<iframe",
        f'src="{embed_url}"',
        'width="100%"',
        'height="600"',
        'frameborder="0"',
        'class="rounded-lg"',
        "></iframe>",
        "</div>",
    ]
    return "\n" + "\n".join(parts) + "\n"
# Client-side behaviour for the rendered results:
# - updateDashboard / selectFeature swap the dashboard iframe when a feature
#   card is clicked.
# - toggleFeatures shows or hides a token's overflow cards.
#
# toggleFeatures only flips the `hidden` class on the overflow container. The
# cards already live in the DOM, so copying them into the visible list as
# well would display every extra feature twice.
#
# NOTE(review): browsers do not execute <script> elements inserted via
# innerHTML. If gr.HTML renders its value that way, selectFeature and
# toggleFeatures will be undefined when the inline onclick handlers fire.
# Consider registering this script through gr.Blocks(head=...) instead.
# TODO confirm against the Gradio version in use.
_FEATURE_SCRIPT = """
<script>
function updateDashboard(featureId, activation) {
const dashboardContainer = document.getElementById('dashboard-container');
dashboardContainer.innerHTML = `
<div class="dashboard-container p-4">
<h3 class="text-lg font-semibold mb-4 text-gray-900">
Feature ${featureId} Dashboard (Activation: ${activation.toFixed(2)})
</h3>
<iframe
src="https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/${featureId}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
width="100%"
height="600"
frameborder="0"
class="rounded-lg"
></iframe>
</div>
`;
}
function selectFeature(element, featureId, activation) {
// Update selected state visually
document.querySelectorAll('.feature-card').forEach(card => {
card.classList.remove('selected');
});
element.classList.add('selected');
// Update dashboard
updateDashboard(featureId, activation);
}
function toggleFeatures(token) {
const moreFeatures = document.getElementById(`more-features-${token}`);
const toggleButton = document.getElementById(`toggle-btn-${token}`);
if (moreFeatures.classList.contains('hidden')) {
// Reveal the overflow cards in place
moreFeatures.classList.remove('hidden');
toggleButton.textContent = 'Show Less';
} else {
// Hide them again
moreFeatures.classList.add('hidden');
toggleButton.textContent = `Show ${moreFeatures.children.length} More Features`;
}
}
</script>
"""


def create_interface_html(data: Dict) -> str:
    """Build the results page: per-token feature lists plus a dashboard.

    Args:
        data: Decoded API response. It is read as a dict with a ``results``
            list. Each result is read for ``token`` and ``top_features``, and
            each feature for ``feature_index`` and ``activation_value``. A
            missing or non-list ``results`` is treated as empty.

    Returns:
        HTML containing the client-side script, one section per token
        (``<bos>`` is skipped) and a dashboard pre-loaded with the first
        feature found (empty if there is none).
    """
    results = data.get("results") if isinstance(data, dict) else None
    if not isinstance(results, list):
        results = []

    sections: List[str] = []
    first_feature = None
    for result in results:
        # The beginning-of-sequence marker carries no user text; skip it.
        if result.get("token") == "<bos>":
            continue
        features = result.get("top_features") or []
        sections.append(create_token_section(result.get("token", ""), features))
        if first_feature is None and features:
            first_feature = features[0]

    dashboard_html = ""
    if first_feature is not None:
        dashboard_html = create_dashboard_html(
            first_feature["feature_index"],
            first_feature["activation_value"],
        )
    tokens_html = "".join(sections)
    js_code = _FEATURE_SCRIPT
    return f"""
<div class="p-6">
{js_code}
<div class="grid grid-cols-1 lg:grid-cols-2 gap-8">
<div class="space-y-6">
{tokens_html}
</div>
<div class="lg:sticky lg:top-6">
<div id="dashboard-container">
{dashboard_html}
</div>
</div>
</div>
</div>
"""
def analyze_features(text: str) -> Tuple[str, str, str]:
    """Gradio click handler: analyze *text* and render the results page.

    Args:
        text: Contents of the input textbox.

    Returns:
        A 3-tuple for the (results HTML, feature label, dashboard) outputs.
        The first element is the rendered page, or a short message when the
        input is empty or the lookup failed. The other two are always ``""``.
    """
    # Don't spend an external API round-trip on blank input.
    if not text or not text.strip():
        return "Please enter some text to analyze.", "", ""
    data = get_features(text)
    if not data:
        return "Error analyzing text", "", ""
    return create_interface_html(data), "", ""
def create_interface():
    """Assemble the Gradio Blocks app and wire the analyze button.

    Layout: a title and subtitle, then a row with two columns. The left
    column holds the text input, the analyze button and example prompts. The
    right column holds the HTML results view plus two invisible outputs that
    absorb the 2nd and 3rd values returned by ``analyze_features``.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    example_prompts = ["WordLift", "Think Different", "Just Do It"]
    with gr.Blocks(theme=theme, css=css) as app:
        gr.Markdown(
            "# Neural Feature Analyzer",
            elem_classes="text-2xl font-bold text-gray-900 mb-2",
        )
        gr.Markdown(
            "*Analyze text using Gemma's interpretable neural features*",
            elem_classes="text-gray-600 mb-6",
        )
        with gr.Row():
            with gr.Column():
                text_box = gr.Textbox(
                    lines=5,
                    placeholder="Enter text to analyze...",
                    label="Input Text",
                    elem_classes="mb-4",
                )
                run_button = gr.Button(
                    "Analyze Features",
                    variant="primary",
                    elem_classes="primary-btn",
                )
                gr.Examples(
                    example_prompts,
                    inputs=text_box,
                    elem_classes="mt-4",
                )
            with gr.Column():
                results_view = gr.HTML()
                hidden_label = gr.Text(show_label=False, visible=False)
                hidden_dashboard = gr.HTML(visible=False)
        run_button.click(
            fn=analyze_features,
            inputs=text_box,
            outputs=[results_view, hidden_label, hidden_dashboard],
        )
    return app
if __name__ == "__main__":
    # Build the UI and start the Gradio server when run as a script.
    demo = create_interface()
    demo.launch()