File size: 10,252 Bytes
f85532f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
574ab91
9186441
f85532f
9186441
 
 
f85532f
 
9186441
f85532f
 
9186441
574ab91
9186441
deaf693
 
9186441
deaf693
9186441
574ab91
9186441
f85532f
574ab91
f85532f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9186441
f85532f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
574ab91
f85532f
 
 
 
 
574ab91
f85532f
 
 
 
574ab91
f85532f
 
 
 
 
 
574ab91
9186441
 
 
 
 
 
 
 
574ab91
f85532f
 
 
 
9186441
 
f85532f
9186441
f85532f
 
 
574ab91
f85532f
574ab91
f85532f
 
 
 
574ab91
f85532f
 
574ab91
f85532f
 
 
574ab91
f85532f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9186441
f85532f
 
 
9186441
f85532f
9186441
f85532f
 
 
 
 
 
 
574ab91
9186441
 
 
 
 
 
574ab91
9186441
 
 
574ab91
9186441
 
 
 
 
 
 
 
 
 
574ab91
f85532f
 
 
 
deaf693
 
 
 
 
 
 
 
 
 
 
574ab91
f85532f
 
574ab91
f85532f
 
574ab91
f85532f
 
 
9186441
 
 
574ab91
f85532f
 
 
 
 
 
574ab91
f85532f
9186441
 
 
 
574ab91
f85532f
574ab91
8d9dd3f
04f6ca9
8d9dd3f
 
 
 
574ab91
9186441
f85532f
 
 
8d9dd3f
 
f85532f
8d9dd3f
f85532f
 
 
 
 
 
9186441
04f6ca9
8d9dd3f
f85532f
574ab91
9186441
f85532f
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download
import numpy as np
from dataclasses import dataclass
from typing import List, Dict, Optional
import logging

# Configure root logging at INFO level and create this module's named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)

@dataclass()
class MarketingFeature:
    """One Gemma Scope feature tracked by the marketing analyzer.

    Couples the feature's identifier and source layer with the human-readable
    text used when its activation is reported.
    """
    feature_id: int            # Gemma Scope feature id (also the key used by the analyzer)
    name: str                  # display name
    category: str              # grouping key, e.g. "technical" or "seo"
    description: str           # what the feature is meant to detect
    interpretation_guide: str  # guidance for reading its activation
    layer: int                 # model layer whose SAE contains the feature
    threshold: float = 0.10    # NOTE(review): handed to the SAE encoder but not applied there — confirm intent

# Marketing-relevant features from Gemma Scope. Each row below is
# (feature_id, name, category, description, interpretation_guide);
# all of them live in the layer-20 SAE.
MARKETING_FEATURES = [
    MarketingFeature(
        feature_id=feature_id,
        name=name,
        category=category,
        description=description,
        interpretation_guide=guide,
        layer=20,
    )
    for feature_id, name, category, description, guide in (
        (35, "Technical Term Detector", "technical",
         "Detects technical and specialized terminology",
         "High activation indicates strong technical focus"),
        (6680, "Compound Technical Terms", "technical",
         "Identifies complex technical concepts",
         "Consider simplifying language if activation is too high"),
        (2, "SEO Keyword Detector", "seo",
         "Identifies potential SEO keywords",
         "High activation suggests strong SEO potential"),
    )
]

class MarketingAnalyzer:
    """Score marketing copy against selected Gemma Scope SAE features.

    Construction loads the Gemma 2 base model and tokenizer, plus the
    residual-stream SAEs referenced by ``MARKETING_FEATURES``.
    ``analyze_content`` runs a text through the model and encodes the relevant
    hidden states with each SAE. It then reports per-feature activation
    statistics and simple recommendations.
    """

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Model size shared by the model id and the Gemma Scope SAE repo id.
        self.model_size = "2b"
        self._initialize_model()
        self._load_saes()

    def _initialize_model(self):
        """Load the Gemma 2 base model and its tokenizer, and put the model in eval mode.

        Gemma Scope SAEs were trained on Gemma 2 activations, so the model id
        must be ``google/gemma-2-<size>``. The Gemma 1 checkpoint
        ``google/gemma-<size>`` has a different hidden size and layer count,
        and its hidden states cannot be encoded by these SAEs.

        Raises:
            Exception: any loading error (e.g. no access to the gated repo) is
                logged and re-raised.
        """
        try:
            model_name = f"google/gemma-2-{self.model_size}"

            # No token is passed explicitly. Credentials come from the Hugging
            # Face environment/config (presumably HF_TOKEN).
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map='auto'
            )
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)

            self.model.eval()
            logger.info(f"Initialized model: {model_name}")

        except Exception as e:
            logger.error(f"Error initializing model: {str(e)}")
            raise

    def _load_saes(self):
        """Load the SAE needed by each entry of ``MARKETING_FEATURES``.

        Populates ``self.saes`` as
        ``{feature_id: {'params': {name: tensor}, 'feature': MarketingFeature}}``.
        Features on the same layer share one parameter dict, so each layer's
        weight matrices are downloaded and moved to the device only once.
        A failure for one feature is logged and skipped, so the analyzer can
        still run with the remaining features.
        """
        self.saes = {}
        params_by_layer = {}  # layer -> SAE parameter dict shared across features
        for feature in MARKETING_FEATURES:
            try:
                if feature.layer not in params_by_layer:
                    params_by_layer[feature.layer] = self._load_sae_params(feature.layer)
                self.saes[feature.feature_id] = {
                    'params': params_by_layer[feature.layer],
                    'feature': feature
                }
                logger.info(f"Loaded SAE for feature {feature.feature_id}")
            except Exception as e:
                logger.error(f"Error loading SAE for feature {feature.feature_id}: {str(e)}")

    def _load_sae_params(self, layer: int) -> Dict[str, torch.Tensor]:
        """Download one Gemma Scope residual SAE and return its tensors on ``self.device``.

        NOTE(review): the ``width_16k/average_l0_71`` variant is hard-coded.
        Confirm that this directory exists for any other layer before adding
        features there.
        """
        path = hf_hub_download(
            repo_id=f"google/gemma-scope-{self.model_size}-pt-res",
            filename=f"layer_{layer}/width_16k/average_l0_71/params.npz"
        )
        # np.load on an .npz returns an NpzFile that keeps the file open.
        # The context manager closes it once the arrays have been read.
        with np.load(path) as npz:
            return {k: torch.from_numpy(v).to(self.device) for k, v in npz.items()}

    def analyze_content(self, text: str) -> Dict:
        """Run ``text`` through the model and score every loaded SAE feature.

        Args:
            text: Marketing copy to analyse.

        Returns:
            A dict with these keys:

            - ``'text'``: the input text.
            - ``'features'``: maps feature_id to a result dict.
            - ``'categories'``: maps category to a list of result dicts.
            - ``'recommendations'``: a list of strings.

            Each result dict holds ``name``, ``category``, ``activation_score``
            (mean over the non-BOS tokens), ``max_activation`` and
            ``interpretation``.

        Raises:
            Exception: any error from tokenization, the forward pass or SAE
                encoding is logged and re-raised.
        """
        results = {
            'text': text,
            'features': {},
            'categories': {},
            'recommendations': []
        }

        try:
            # Send inputs to wherever device_map placed the model.
            inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)
            with torch.no_grad():
                outputs = self.model(**inputs, output_hidden_states=True)

            for feature_id, sae_data in self.saes.items():
                feature = sae_data['feature']
                # hidden_states[0] is the embedding output, so the residual
                # stream after decoder block L is hidden_states[L + 1]. That
                # block output is what a layer_L residual SAE encodes.
                layer_output = outputs.hidden_states[feature.layer + 1]

                sae_acts = self._apply_sae(
                    layer_output,
                    sae_data['params'],
                    feature.threshold
                )
                mean_activation, max_activation = self._summarize_feature(
                    sae_acts, feature_id
                )

                feature_result = {
                    'name': feature.name,
                    'category': feature.category,
                    'activation_score': mean_activation,
                    'max_activation': max_activation,
                    'interpretation': self._interpret_activation(
                        mean_activation,
                        feature
                    )
                }

                results['features'][feature_id] = feature_result
                results['categories'].setdefault(feature.category, []).append(feature_result)

            results['recommendations'] = self._generate_recommendations(results)

        except Exception as e:
            logger.error(f"Error analyzing content: {str(e)}")
            raise

        return results

    def _apply_sae(
        self,
        activations: torch.Tensor,
        sae_params: Dict[str, torch.Tensor],
        threshold: float
    ) -> torch.Tensor:
        """Encode hidden states with a JumpReLU SAE.

        Args:
            activations: Hidden states, presumably (batch, seq, d_model) as
                returned by the model — TODO confirm for other callers.
            sae_params: Parameter dict. Uses ``W_enc`` (d_model, d_sae),
                ``b_enc`` and the per-latent ``threshold``.
            threshold: The feature's configured threshold. NOTE(review): it is
                currently unused; only the SAE's own per-latent threshold is
                applied.

        Returns:
            Latent activations (..., d_sae). Each value is ReLU(pre-activation)
            where the pre-activation exceeds the SAE threshold, and 0 otherwise.
        """
        w_enc = sae_params['W_enc']
        # Hidden states may sit on a different device/dtype than the SAE
        # weights (e.g. with device_map='auto'), so align them before the matmul.
        activations = activations.to(device=w_enc.device, dtype=w_enc.dtype)
        pre_acts = activations @ w_enc + sae_params['b_enc']
        mask = pre_acts > sae_params['threshold']
        return mask * torch.nn.functional.relu(pre_acts)

    @staticmethod
    def _summarize_feature(sae_acts: torch.Tensor, feature_id: int) -> tuple:
        """Return ``(mean, max)`` of one SAE latent over all tokens after BOS.

        Args:
            sae_acts: Output of ``_apply_sae``, indexed as (batch, token, latent).
            feature_id: Index of the latent to report.

        Returns:
            ``(mean, max)`` as floats, or ``(0.0, 0.0)`` when there are no
            tokens after BOS.
        """
        # Select only this feature's latent. Averaging over the whole latent
        # dimension would mix in thousands of unrelated features.
        per_token = sae_acts[:, 1:, feature_id]  # skip BOS token
        if per_token.numel() == 0:
            return 0.0, 0.0
        return float(per_token.mean()), float(per_token.max())

    def _interpret_activation(
        self,
        activation: float,
        feature: "MarketingFeature"
    ) -> str:
        """Map a mean activation onto a coarse verbal level.

        The levels are above 0.8, above 0.5, and otherwise.
        """
        if activation > 0.8:
            return f"Very strong presence of {feature.name.lower()}"
        elif activation > 0.5:
            return f"Moderate presence of {feature.name.lower()}"
        else:
            return f"Limited presence of {feature.name.lower()}"

    def _generate_recommendations(self, results: Dict) -> List[str]:
        """Derive text recommendations from the 'technical' feature scores.

        Best-effort: any error is logged, and the recommendations gathered so
        far are returned.
        """
        recommendations = []

        try:
            tech_scores = [
                f['activation_score'] for f in results['features'].values()
                if f['category'] == 'technical'
            ]

            if tech_scores:
                tech_score = np.mean(tech_scores)

                if tech_score > 0.8:
                    recommendations.append(
                        "Consider simplifying technical language for broader audience"
                    )
                elif tech_score < 0.3:
                    recommendations.append(
                        "Could benefit from more specific technical details"
                    )
        except Exception as e:
            logger.error(f"Error generating recommendations: {str(e)}")

        return recommendations

def create_gradio_interface():
    """Build the Gradio app for the marketing analyzer.

    Returns:
        A ``gr.Interface`` that calls ``MarketingAnalyzer.analyze_content``.
        If the analyzer fails to initialise, it instead returns a fallback
        interface that only reports the failure.
    """
    try:
        analyzer = MarketingAnalyzer()
    except Exception as e:
        logger.error(f"Failed to initialize analyzer: {str(e)}")
        return gr.Interface(
            fn=lambda x: "Error: Failed to initialize model. Please check authentication.",
            inputs=gr.Textbox(),
            outputs=gr.Textbox(),
            title="Marketing Content Analyzer (Error)",
            description="Failed to initialize. Please check if HF_TOKEN is properly set."
        )

    def analyze(text):
        """Analyse ``text`` and format the results as a plain-text report."""
        results = analyzer.analyze_content(text)

        parts = ["Content Analysis Results\n\n"]

        # Overall category scores: the mean activation_score of each category's features
        parts.append("Category Scores:\n")
        for category, features in results['categories'].items():
            if features:
                avg_score = np.mean([f['activation_score'] for f in features])
                parts.append(f"{category.title()}: {avg_score:.2f}\n")

        # Per-feature details
        parts.append("\nFeature Details:\n")
        for feature in results['features'].values():
            parts.append(f"\n{feature['name']}:\n")
            parts.append(f"Score: {feature['activation_score']:.2f}\n")
            parts.append(f"Interpretation: {feature['interpretation']}\n")

        # Recommendations, only when there are any
        if results['recommendations']:
            parts.append("\nRecommendations:\n")
            parts.extend(f"- {rec}\n" for rec in results['recommendations'])

        return "".join(parts)

    custom_theme = gr.themes.Soft(
        primary_hue="indigo",
        secondary_hue="blue",
        neutral_hue="gray"
    )

    # The original had an extra ')' after this call, which made the module a SyntaxError.
    interface = gr.Interface(
        fn=analyze,
        inputs=gr.Textbox(
            lines=5,
            placeholder="Enter your marketing content here...",
            label="Marketing Content"
        ),
        outputs=gr.Textbox(label="Analysis Results"),
        title="Marketing Content Analyzer",
        description="Analyze your marketing content using Gemma Scope's neural features",
        examples=[
            ["WordLift is an AI-powered SEO tool"],
            ["Our advanced machine learning algorithms optimize your content"],
            ["Simple and effective website optimization"]
        ],
        theme=custom_theme
    )

    return interface

if __name__ == "__main__":
    # Script entry point: build the Gradio app, then launch it.
    iface = create_gradio_interface()
    iface.launch()