import gradio as gr
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download
import numpy as np
from dataclasses import dataclass
from typing import Dict
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class MarketingFeature:
    """Structure to hold marketing-relevant feature information"""

    feature_id: int
    name: str
    category: str
    description: str
    interpretation_guide: str
    layer: int
    threshold: float = 0.1


# Define relevant features
MARKETING_FEATURES = [
    MarketingFeature(
        feature_id=35,
        name="Technical Term Detector",
        category="technical",
        description="Detects technical and specialized terminology",
        interpretation_guide="High activation indicates strong technical focus",
        layer=20,
    ),
    MarketingFeature(
        feature_id=6680,
        name="Compound Technical Terms",
        category="technical",
        description="Identifies complex technical concepts",
        interpretation_guide="Consider simplifying language if activation is too high",
        layer=20,
    ),
    MarketingFeature(
        feature_id=2,
        name="SEO Keyword Detector",
        category="seo",
        description="Identifies potential SEO keywords",
        interpretation_guide="High activation suggests strong SEO potential",
        layer=20,
    ),
]
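
# NOTE: MARKETING_FEATURES is a curated reference list of known interpretable
# features; analyze_content() below samples features dynamically, but the same
# scoring could be applied to this list instead.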


class JumpReLUSAE(nn.Module):
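    """JumpReLU sparse autoencoder, following the Gemma Scope reference
    implementation: pre-activations must clear a learned per-feature threshold
    before the ReLU, which keeps the resulting feature activations sparse."""
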
    def __init__(self, d_model, d_sae):
        super().__init__()
        self.W_enc = nn.Parameter(torch.zeros(d_model, d_sae))
        self.W_dec = nn.Parameter(torch.zeros(d_sae, d_model))
        self.threshold = nn.Parameter(torch.zeros(d_sae))
        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        self.b_dec = nn.Parameter(torch.zeros(d_model))

    def encode(self, input_acts):
        pre_acts = input_acts @ self.W_enc + self.b_enc
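        # JumpReLU: only features whose pre-activation clears the learned
        # per-feature threshold survive; everything else is zeroed out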
        mask = pre_acts > self.threshold
        acts = mask * torch.nn.functional.relu(pre_acts)
        return acts

    def decode(self, acts):
        return acts @ self.W_dec + self.b_dec

    def forward(self, acts):
        acts = self.encode(acts)
        recon = self.decode(acts)
        return recon
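
# Shape note: the Gemma Scope 16k residual SAEs used below have d_model = 2304
# (Gemma 2 2B's hidden size) and d_sae = 16384, so encode() maps
# [batch, seq, 2304] -> [batch, seq, 16384].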


class MarketingAnalyzer:
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        torch.set_grad_enabled(False)  # Inference only; avoids autograd memory overhead
        self._sae_cache: Dict[int, JumpReLUSAE] = {}  # One shared SAE per layer
        self._initialize_model()

    def _initialize_model(self):
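        """Load Gemma 2 2B; the checkpoint is gated on Hugging Face, so prior
        authentication (e.g., `huggingface-cli login`) is assumed."""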
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                "google/gemma-2-2b", device_map="auto"
            )
            self.tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
            self.model.eval()
            logger.info("Model initialized successfully")
        except Exception as e:
            logger.error(f"Error initializing model: {str(e)}")
            raise

    def _load_sae(self, feature_id: int, layer: int = 20):
        """Load the 16k-width residual SAE for a layer (cached after first use).

        All 16,384 features live in a single SAE, so feature_id is only used
        here for error reporting.
        """
        if layer in self._sae_cache:
            return self._sae_cache[layer]
        try:
            # Gemma Scope publishes one params.npz per SAE; "width_16k" is the
            # feature count and "average_l0_71" the sparsity setting for layer 20
            path = hf_hub_download(
                repo_id="google/gemma-scope-2b-pt-res",
                filename=f"layer_{layer}/width_16k/average_l0_71/params.npz",
                force_download=False,
            )
            params = np.load(path)

            # Create SAE with dimensions read from the checkpoint
            d_model = params["W_enc"].shape[0]
            d_sae = params["W_enc"].shape[1]
            sae = JumpReLUSAE(d_model, d_sae).to(self.device)

            # Load parameters
            sae_params = {
                k: torch.from_numpy(v).to(self.device) for k, v in params.items()
            }
            sae.load_state_dict(sae_params)

            self._sae_cache[layer] = sae
            return sae
        except Exception as e:
            logger.error(f"Error loading SAE for feature {feature_id}: {str(e)}")
            return None

    def _gather_activations(self, text: str, layer: int):
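        """Run the model once and capture residual-stream activations at `layer`"""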
        inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
        target_act = None

        def hook(mod, inputs, outputs):
            # Gemma 2 decoder layers return a tuple; its first element is the
            # hidden states (the residual-stream activations we want)
            nonlocal target_act
            target_act = outputs[0]
            return outputs

        handle = self.model.model.layers[layer].register_forward_hook(hook)
        with torch.no_grad():
            _ = self.model(**inputs)
        handle.remove()

        return target_act, inputs

    def _get_feature_activations(
        self, text: str, sae, feature_id: int, layer: int = 20
    ):
        """Get mean and max activation of a single SAE feature on the text"""
        activations, _ = self._gather_activations(text, layer)
        sae_acts = sae.encode(activations.to(torch.float32))
        feature_acts = sae_acts[0, 1:, feature_id]  # Skip BOS token, select feature

        if feature_acts.numel() > 0:
            mean_activation = float(feature_acts.mean())
            max_activation = float(feature_acts.max())
        else:
            mean_activation = 0.0
            max_activation = 0.0

        return mean_activation, max_activation

    def analyze_content(self, text: str) -> Dict:
        """Analyze content and find most relevant features"""
        results = {
            "text": text,
            "features": {},
            "categories": {},
            "recommendations": [],
        }

        try:
            # Sample a subset of the SAE's 16,384 features to score
            feature_pool = list(range(16384))  # Feature indices are 0-based
            sample_size = 50  # Number of features to sample
            sampled_features = np.random.choice(
                feature_pool, sample_size, replace=False
            )

            # Score each sampled feature; the SAE itself is loaded once and cached
            feature_activations = []
            for feature_id in sampled_features:
                feature_id = int(feature_id)  # numpy int -> plain int for keys/URLs
                sae = self._load_sae(feature_id)
                if sae is None:
                    continue

                mean_activation, max_activation = self._get_feature_activations(
                    text, sae, feature_id
                )
                feature_activations.append(
                    {
                        "feature_id": feature_id,
                        "mean_activation": mean_activation,
                        "max_activation": max_activation,
                    }
                )

            # Sort by activation and keep the top 3 features
            top_features = sorted(
                feature_activations, key=lambda x: x["max_activation"], reverse=True
            )[:3]

            # Analyze top features in detail
            for feature_data in top_features:
                feature_id = feature_data["feature_id"]

                # Placeholder naming; human-readable explanations could instead
                # be pulled from Neuronpedia for each feature
                feature_name = f"Feature {feature_id}"
                feature_category = "neural"  # Default category

                feature_result = {
                    "name": feature_name,
                    "category": feature_category,
                    "activation_score": feature_data["mean_activation"],
                    "max_activation": feature_data["max_activation"],
                    "interpretation": self._interpret_activation(
                        feature_data["mean_activation"], feature_id
                    ),
                }

                results["features"][feature_id] = feature_result

                if feature_category not in results["categories"]:
                    results["categories"][feature_category] = []
                results["categories"][feature_category].append(feature_result)

            # Generate recommendations based on activations
            if top_features:
                max_activation = max(f["max_activation"] for f in top_features)
                if max_activation > 0.8:
                    results["recommendations"].append(
                        f"Strong activation detected in feature {top_features[0]['feature_id']}. "
                        "Consider exploring this aspect further."
                    )
                elif max_activation < 0.3:
                    results["recommendations"].append(
                        "Low feature activations overall. Content might benefit from more distinctive elements."
                    )

        except Exception as e:
            logger.error(f"Error analyzing content: {str(e)}")
            raise

        return results

    def _interpret_activation(self, activation: float, feature_id: int) -> str:
        """Interpret activation levels for a feature"""
        if activation > 0.8:
            return f"Very strong activation of feature {feature_id}"
        elif activation > 0.5:
            return f"Moderate activation of feature {feature_id}"
        else:
            return f"Limited activation of feature {feature_id}"


def create_gradio_interface():
    try:
        analyzer = MarketingAnalyzer()
    except Exception as e:
        logger.error(f"Failed to initialize analyzer: {str(e)}")
        return gr.Interface(
            fn=lambda x: "Error: Failed to initialize model. Please check authentication.",
            inputs=gr.Textbox(),
            outputs=gr.Textbox(),
            title="Marketing Content Analyzer (Error)",
            description="Failed to initialize.",
        )

    def analyze(text):
        results = analyzer.analyze_content(text)

        output = "# Content Analysis Results\n\n"

        output += "## Category Scores\n"
        for category, features in results["categories"].items():
            if features:
                avg_score = np.mean([f["activation_score"] for f in features])
                output += f"**{category.title()}**: {avg_score:.2f}\n"

        output += "\n## Feature Details\n"
        for feature_id, feature in results["features"].items():
            output += f"\n### {feature['name']} (Feature {feature_id})\n"
            output += f"**Score**: {feature['activation_score']:.2f}\n\n"
            output += f"**Interpretation**: {feature['interpretation']}\n\n"
            # Add feature explanation from Neuronpedia reference
            output += f"[View feature details on Neuronpedia](https://neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/{feature_id})\n\n"

        if results["recommendations"]:
            output += "\n## Recommendations\n"
            for rec in results["recommendations"]:
                output += f"- {rec}\n"

        # Guard against an empty result set (e.g., if the SAE failed to load)
        if not results["features"]:
            return output, "", None

        feature_id = max(
            results["features"].items(), key=lambda x: x[1]["activation_score"]
        )[0]

        # Build dashboard URL for the highest activating feature
        dashboard_url = f"https://neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/{feature_id}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
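        # The embed* query parameters request Neuronpedia's embeddable view with
        # the explanation, activation plots, and a live test panel included.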

        return output, dashboard_url, feature_id

    with gr.Blocks(
        theme=gr.themes.Default(
            font=[gr.themes.GoogleFont("Open Sans"), "Arial", "sans-serif"],
            primary_hue="indigo",
            secondary_hue="blue",
            neutral_hue="gray",
        )
    ) as interface:
        gr.Markdown("# Marketing Content Analyzer")
        gr.Markdown(
            "Analyze your marketing content using Gemma Scope's neural features"
        )

        with gr.Row():
            with gr.Column(scale=1):
                input_text = gr.Textbox(
                    lines=5,
                    placeholder="Enter your marketing content here...",
                    label="Marketing Content",
                )
                analyze_btn = gr.Button("Analyze", variant="primary")
                gr.Examples(
                    examples=[
                        "WordLift is an AI-powered SEO tool",
                        "Our advanced machine learning algorithms optimize your content",
                        "Simple and effective website optimization",
                    ],
                    inputs=input_text,
                )

            with gr.Column(scale=2):
                output_text = gr.Markdown(label="Analysis Results")
                with gr.Group():
                    gr.Markdown("## Feature Dashboard")
                    feature_id_text = gr.Text(
                        label="Currently viewing feature", show_label=False
                    )
                    dashboard_frame = gr.HTML(
                        value="The feature dashboard will appear here after analysis",
                        label="Feature Dashboard",
                    )

        def update_dashboard(text):
            output, dashboard_url, feature_id = analyze(text)
            if not dashboard_url:
                return output, "No feature dashboard available", ""
            return (
                output,
                f"<iframe src='{dashboard_url}' width='100%' height='600px' frameborder='0' style='border: 1px solid #eee; border-radius: 8px;'></iframe>",
                f"Currently viewing Feature {feature_id} - Most active feature in your content",
            )

        analyze_btn.click(
            fn=update_dashboard,
            inputs=input_text,
            outputs=[output_text, dashboard_frame, feature_id_text],
        )

    return interface


if __name__ == "__main__":
    iface = create_gradio_interface()
    iface.launch()