# brand-llms / app.py
# Source captured from the Hugging Face Space file view
# (author: cyberandy; commit 01d3df7, message "update"; file size 10.3 kB).
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download
import numpy as np
from dataclasses import dataclass
from typing import List, Dict, Optional
import logging
# Send INFO-and-above records to the root handler and give this module
# its own named logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
@dataclass
class MarketingFeature:
    """Record describing one SAE latent the analyzer reports on.

    Every field is positional in the order listed below; only ``threshold``
    has a default (0.1).
    """

    feature_id: int  # id of the Gemma Scope latent this entry describes
    name: str  # human-readable label printed in the report
    category: str  # grouping key used when results are aggregated
    description: str  # short explanation of what the latent is meant to detect
    interpretation_guide: str  # advice on how to read the activation level
    layer: int  # layer whose SAE this latent belongs to
    threshold: float = 0.1  # optional per-feature activation threshold
# Gemma Scope latents reported by the analyzer. Each tuple below is
# (feature_id, name, category, description, interpretation_guide);
# every entry is attached to layer 6.
# NOTE(review): the names/descriptions are the author's labels for these
# latent ids; confirm them against the published explanations for this SAE.
MARKETING_FEATURES = [
    MarketingFeature(
        feature_id=feature_id,
        name=name,
        category=category,
        description=description,
        interpretation_guide=guide,
        layer=6,
    )
    for feature_id, name, category, description, guide in (
        (
            35,
            "Technical Term Detector",
            "technical",
            "Detects technical and specialized terminology",
            "High activation indicates strong technical focus",
        ),
        (
            6680,
            "Compound Technical Terms",
            "technical",
            "Identifies complex technical concepts",
            "Consider simplifying language if activation is too high",
        ),
        (
            2,
            "SEO Keyword Detector",
            "seo",
            "Identifies potential SEO keywords",
            "High activation suggests strong SEO potential",
        ),
    )
]
class MarketingAnalyzer:
    """Score marketing text against Gemma Scope sparse-autoencoder (SAE) latents.

    Construction loads a Gemma 2 base model and tokenizer, then the
    residual-stream SAE for each layer referenced by ``MARKETING_FEATURES``.
    ``analyze_content`` runs the model on a piece of text and reports, for each
    configured feature, the mean and max activation of that feature's latent.
    """

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Size suffix shared by the base-model and SAE repository names.
        self.model_size = "2b"
        self._initialize_model()
        self._load_saes()

    def _initialize_model(self):
        """Load the base model and tokenizer and switch the model to eval mode.

        Raises:
            Exception: any loading error is logged and re-raised.
        """
        try:
            # Gemma Scope SAEs were trained on Gemma 2 residual streams (width
            # 2304 for the 2B model). The Gemma 1 checkpoint "google/gemma-2b"
            # has width 2048, so its activations cannot be multiplied by W_enc.
            model_name = f"google/gemma-2-{self.model_size}"
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name, device_map="auto"
            )
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model.eval()
            logger.info(f"Initialized model: {model_name}")
        except Exception as e:
            logger.error(f"Error initializing model: {str(e)}")
            raise

    def _load_saes(self):
        """Load the Gemma Scope SAE parameters needed by MARKETING_FEATURES.

        Fills ``self.saes`` as ``{feature_id: {"params": tensors, "feature":
        MarketingFeature}}``. Features on the same layer share one tensor dict,
        so each ``params.npz`` is downloaded and moved to the device only once.
        A failure for one feature is logged and that feature is skipped.
        """
        self.saes = {}
        params_by_layer = {}  # layer -> SAE tensors shared by its features
        for feature in MARKETING_FEATURES:
            try:
                params = params_by_layer.get(feature.layer)
                if params is None:
                    path = hf_hub_download(
                        repo_id=f"google/gemma-scope-{self.model_size}-pt-res",
                        # NOTE(review): the average_l0_* suffix varies per layer
                        # in the Gemma Scope repo; confirm this file exists for
                        # every layer used in MARKETING_FEATURES.
                        filename=f"layer_{feature.layer}/width_16k/average_l0_71/params.npz",
                    )
                    # An .npz archive keeps its file open until closed; the
                    # context manager releases it once the arrays are copied.
                    with np.load(path) as npz:
                        params = {
                            k: torch.from_numpy(v).to(self.device)
                            for k, v in npz.items()
                        }
                    params_by_layer[feature.layer] = params
                # feature_id selects a column of W_enc, so it must be in range.
                n_latents = params["W_enc"].shape[-1]
                if not 0 <= feature.feature_id < n_latents:
                    raise ValueError(
                        f"feature_id out of range for SAE with {n_latents} latents"
                    )
                self.saes[feature.feature_id] = {
                    "params": params,
                    "feature": feature,
                }
                logger.info(f"Loaded SAE for feature {feature.feature_id}")
            except Exception as e:
                logger.error(
                    f"Error loading SAE for feature {feature.feature_id}: {str(e)}"
                )
        if not self.saes:
            logger.warning("No SAEs loaded; analyze_content will report no features")

    def analyze_content(self, text: str) -> Dict:
        """Analyze *text* with every loaded SAE feature.

        Returns:
            A dict with keys ``"text"``; ``"features"`` (feature_id -> dict with
            ``name``, ``category``, ``activation_score`` (mean activation),
            ``max_activation`` and ``interpretation``); ``"categories"``
            (category -> list of those feature dicts); and
            ``"recommendations"`` (list of str).

        Raises:
            Exception: errors from tokenization or the forward pass are logged
            and re-raised.
        """
        results = {
            "text": text,
            "features": {},
            "categories": {},
            "recommendations": [],
        }
        try:
            inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
            with torch.no_grad():
                outputs = self.model(**inputs, output_hidden_states=True)

            for feature_id, sae_data in self.saes.items():
                feature = sae_data["feature"]
                # hidden_states[0] is the embedding output and hidden_states[i]
                # the output of decoder block i - 1. Gemma Scope's layer_N SAE
                # reads the residual stream after block N, i.e. index N + 1.
                layer_output = outputs.hidden_states[feature.layer + 1]
                latents = self._apply_sae(
                    layer_output, sae_data["params"], feature.threshold
                )
                # Keep only this feature's latent and drop position 0
                # (assumed to be the BOS token added by the tokenizer).
                acts = latents[:, 1:, feature_id]
                if acts.numel() > 0:
                    mean_activation = float(acts.mean())
                    max_activation = float(acts.max())
                else:
                    mean_activation = 0.0
                    max_activation = 0.0

                feature_result = {
                    "name": feature.name,
                    "category": feature.category,
                    "activation_score": mean_activation,
                    "max_activation": max_activation,
                    "interpretation": self._interpret_activation(
                        mean_activation, feature
                    ),
                }
                results["features"][feature_id] = feature_result
                results["categories"].setdefault(feature.category, []).append(
                    feature_result
                )

            results["recommendations"] = self._generate_recommendations(results)
        except Exception as e:
            logger.error(f"Error analyzing content: {str(e)}")
            raise
        return results

    def _apply_sae(
        self,
        activations: torch.Tensor,
        sae_params: Dict[str, torch.Tensor],
        threshold: float,
    ) -> torch.Tensor:
        """Encode residual activations with a JumpReLU SAE.

        Args:
            activations: residual-stream tensor whose last dim matches the rows
                of ``W_enc``.
            sae_params: tensors from ``params.npz``; reads ``W_enc``, ``b_enc``
                and the per-latent ``threshold``.
            threshold: currently unused. The per-latent threshold stored in
                ``sae_params`` is applied instead; kept for interface
                compatibility.

        Returns:
            ``relu(pre)`` where ``pre = x @ W_enc + b_enc`` exceeds the SAE's
            per-latent threshold, and 0 elsewhere.
        """
        w_enc = sae_params["W_enc"]
        # Match the SAE's device and dtype so a half-precision or sharded model
        # still produces a valid matmul.
        activations = activations.to(device=w_enc.device, dtype=w_enc.dtype)
        pre_acts = activations @ w_enc + sae_params["b_enc"]
        mask = pre_acts > sae_params["threshold"]
        return mask * torch.nn.functional.relu(pre_acts)

    def _interpret_activation(
        self, activation: float, feature: MarketingFeature
    ) -> str:
        """Map a mean activation to a short strength description (>0.8, >0.5, else)."""
        if activation > 0.8:
            return f"Very strong presence of {feature.name.lower()}"
        elif activation > 0.5:
            return f"Moderate presence of {feature.name.lower()}"
        else:
            return f"Limited presence of {feature.name.lower()}"

    def _generate_recommendations(self, results: Dict) -> List[str]:
        """Derive writing recommendations from the averaged technical-feature score.

        Errors are logged and whatever recommendations were built so far are
        returned.
        """
        recommendations = []
        try:
            tech_features = [
                f for f in results["features"].values() if f["category"] == "technical"
            ]
            if tech_features:
                tech_score = np.mean([f["activation_score"] for f in tech_features])
                if tech_score > 0.8:
                    recommendations.append(
                        "Consider simplifying technical language for broader audience"
                    )
                elif tech_score < 0.3:
                    recommendations.append(
                        "Could benefit from more specific technical details"
                    )
        except Exception as e:
            logger.error(f"Error generating recommendations: {str(e)}")
        return recommendations
def create_gradio_interface():
    """Build the Gradio UI for the marketing analyzer.

    If ``MarketingAnalyzer()`` raises (for example when the model download or
    authentication fails), a placeholder interface is returned that only
    reports the error, so the app still starts.

    Returns:
        gr.Interface: the interface to launch.
    """
    try:
        analyzer = MarketingAnalyzer()
    except Exception as e:
        # logger.exception keeps the traceback needed to debug auth/download failures.
        logger.exception(f"Failed to initialize analyzer: {str(e)}")
        return gr.Interface(
            fn=lambda x: "Error: Failed to initialize model. Please check authentication.",
            inputs=gr.Textbox(),
            outputs=gr.Textbox(),
            title="Marketing Content Analyzer (Error)",
            description="Failed to initialize. Please check if HF_TOKEN is properly set.",
        )

    def analyze(text):
        """Run the analyzer on *text* and render the results as a plain-text report."""
        if not text or not text.strip():
            return "Please enter some marketing content to analyze."
        try:
            results = analyzer.analyze_content(text)
        except Exception:
            # analyze_content already logged the failure; show a short message
            # in the UI instead of an unhandled error.
            return "Error: analysis failed. Please check the server logs."

        parts = ["Content Analysis Results\n\n", "Category Scores:\n"]
        # Average score per category
        for category, features in results["categories"].items():
            if features:
                avg_score = np.mean([f["activation_score"] for f in features])
                parts.append(f"{category.title()}: {avg_score:.2f}\n")

        # Per-feature details
        parts.append("\nFeature Details:\n")
        for feature in results["features"].values():
            parts.append(f"\n{feature['name']}:\n")
            parts.append(f"Score: {feature['activation_score']:.2f}\n")
            parts.append(f"Interpretation: {feature['interpretation']}\n")

        if results["recommendations"]:
            parts.append("\nRecommendations:\n")
            parts.extend(f"- {rec}\n" for rec in results["recommendations"])
        return "".join(parts)

    custom_theme = gr.themes.Soft(
        primary_hue="indigo", secondary_hue="blue", neutral_hue="gray"
    )
    interface = gr.Interface(
        fn=analyze,
        inputs=gr.Textbox(
            lines=5,
            placeholder="Enter your marketing content here...",
            label="Marketing Content",
        ),
        outputs=gr.Textbox(label="Analysis Results"),
        title="Marketing Content Analyzer",
        description="Analyze your marketing content using Gemma Scope's neural features",
        examples=[
            ["WordLift is an AI-powered SEO tool"],
            ["Our advanced machine learning algorithms optimize your content"],
            ["Simple and effective website optimization"],
        ],
        theme=custom_theme,
    )
    return interface
if __name__ == "__main__":
    # Script entry point: build the UI (an error-only UI if model loading
    # failed) and start the Gradio server.
    demo = create_gradio_interface()
    demo.launch()