PromptMeister committed (verified)
Commit e68ca61 · Parent(s): b1935a9

Update app.py

Files changed (1):
  1. app.py +282 -158

app.py CHANGED
@@ -2,20 +2,60 @@ import gradio as gr
  import numpy as np
  import pandas as pd
  import torch
- from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
- from transformers import BertTokenizerFast
  import matplotlib.pyplot as plt
  import json

- # Initialize models
- tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
- ner_pipeline = pipeline("ner", model="dslim/bert-base-NER")
- pos_model = AutoModelForTokenClassification.from_pretrained("vblagoje/bert-english-uncased-finetuned-pos")
- pos_tokenizer = BertTokenizerFast.from_pretrained("vblagoje/bert-english-uncased-finetuned-pos")
- pos_pipeline = pipeline("token-classification", model=pos_model, tokenizer=pos_tokenizer)

- # Intent classification - using zero-shot classification
- intent_classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

  def get_token_colors(token_type):
      colors = {
@@ -41,7 +81,10 @@ def simulate_historical_data(token):
          values = [45, 50, 60, 70, 75, 80]
      else:
          # Standard pattern for common words
-         base = 50 + (hash(token) % 30)
          noise = np.random.normal(0, 5, 6)
          values = [max(5, min(95, base + i*5 + n)) for i, n in enumerate(noise)]

@@ -60,7 +103,7 @@ def generate_origin_data(token):
      ]

      # Deterministic selection based on the token
-     index = hash(token) % len(origins)
      origin = origins[index]

      note = f"First appeared in {origin['era']} texts derived from {origin['language']}."
@@ -105,155 +148,219 @@ def analyze_token_types(tokens):
      return processed_tokens

  def plot_historical_data(historical_data):
-     """Create a plot of historical usage data"""
-     eras = [item[0] for item in historical_data]
-     values = [item[1] for item in historical_data]
-
-     plt.figure(figsize=(8, 3))
-     plt.bar(eras, values, color='skyblue')
-     plt.title('Historical Usage')
-     plt.xlabel('Era')
-     plt.ylabel('Usage Level')
-     plt.ylim(0, 100)
-     plt.xticks(rotation=45)
-     plt.tight_layout()
-
-     return plt

- def analyze_keyword(keyword):
-     if not keyword.strip():
-         return None, None, None, None, None
-
-     # Basic tokenization
-     words = keyword.strip().lower().split()
-
-     # Get token types
-     token_analysis = analyze_token_types(words)
-
-     # Get NER tags
-     ner_results = ner_pipeline(keyword)
-
-     # Get POS tags
-     pos_results = pos_pipeline(keyword)
-
-     # Process and organize results
-     full_token_analysis = []
-     for token in token_analysis:
-         # Find POS tag for this token
-         pos_tag = "NOUN"  # Default
-         for pos_result in pos_results:
-             if pos_result["word"].lower() == token["text"]:
-                 pos_tag = pos_result["entity"]
-                 break
-
-         # Find entity type if any
-         entity_type = None
-         for ner_result in ner_results:
-             if ner_result["word"].lower() == token["text"]:
-                 entity_type = ner_result["entity"]
-                 break
-
-         # Generate historical data
-         historical_data = simulate_historical_data(token["text"])
-
-         # Generate origin data
-         origin = generate_origin_data(token["text"])
-
-         # Calculate importance (simplified algorithm)
-         importance = 60 + (len(token["text"]) * 2)
-         importance = min(95, importance)
-
-         # Generate related terms (simplified)
-         related_terms = [f"{token['text']}-related-1", f"{token['text']}-related-2"]
-
-         full_token_analysis.append({
-             "token": token["text"],
-             "type": token["type"],
-             "posTag": pos_tag,
-             "entityType": entity_type,
-             "importance": importance,
-             "historicalData": historical_data,
-             "origin": origin,
-             "relatedTerms": related_terms
-         })
-
-     # Intent analysis
-     intent_result = intent_classifier(
-         keyword,
-         candidate_labels=["informational", "navigational", "transactional"]
-     )
-
-     intent_analysis = {
-         "type": intent_result["labels"][0].capitalize(),
-         "strength": round(intent_result["scores"][0] * 100),
-         "mutations": [
-             f"{intent_result['labels'][0]}-variation-1",
-             f"{intent_result['labels'][0]}-variation-2"
          ]
-     }
-
-     # Evolution potential (simplified calculation)
-     evolution_potential = min(95, 65 + (len(keyword) % 30))
-
-     # Predicted trends (simplified)
-     trends = [
-         "Voice search adaptation",
-         "Visual search integration"
-     ]
-
-     # Evolution chart data (simulated)
-     evolution_data = [
-         {"month": "Jan", "searchVolume": 1000, "competitionScore": 45, "intentClarity": 80},
-         {"month": "Feb", "searchVolume": 1200, "competitionScore": 48, "intentClarity": 82},
-         {"month": "Mar", "searchVolume": 1100, "competitionScore": 52, "intentClarity": 85},
-         {"month": "Apr", "searchVolume": 1400, "competitionScore": 55, "intentClarity": 88},
-         {"month": "May", "searchVolume": 1800, "competitionScore": 58, "intentClarity": 90},
-         {"month": "Jun", "searchVolume": 2200, "competitionScore": 60, "intentClarity": 92}
-     ]
-
-     # Create plots
-     evolution_chart = create_evolution_chart(evolution_data)
-
-     # Generate HTML for token visualization
-     token_viz_html = generate_token_visualization_html(token_analysis, full_token_analysis)
-
-     # Generate HTML for full analysis
-     analysis_html = generate_full_analysis_html(
-         keyword,
-         full_token_analysis,
-         intent_analysis,
-         evolution_potential,
-         trends
-     )
-
-     # Generate JSON results
-     json_results = {
-         "keyword": keyword,
-         "tokenAnalysis": full_token_analysis,
-         "intentAnalysis": intent_analysis,
-         "evolutionPotential": evolution_potential,
-         "predictedTrends": trends
-     }
-
-     return token_viz_html, analysis_html, json_results, evolution_chart, full_token_analysis

- def create_evolution_chart(data):
-     """Create an evolution chart from data"""
-     df = pd.DataFrame(data)
-
-     plt.figure(figsize=(10, 5))
-     plt.plot(df['month'], df['searchVolume'], marker='o', label='Search Volume')
-     plt.plot(df['month'], df['competitionScore']*20, marker='s', label='Competition Score')
-     plt.plot(df['month'], df['intentClarity']*20, marker='^', label='Intent Clarity')
-
-     plt.title('Predicted Evolution')
-     plt.xlabel('Month')
-     plt.ylabel('Value')
-     plt.legend()
-     plt.grid(True, linestyle='--', alpha=0.7)
-     plt.tight_layout()

-     return plt

  def generate_token_visualization_html(token_analysis, full_analysis):
      """Generate HTML for token visualization"""
@@ -469,6 +576,10 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
      with gr.Row():
          with gr.Column():
              input_text = gr.Textbox(label="Enter keyword to analyze", placeholder="e.g. artificial intelligence")
              analyze_btn = gr.Button("Analyze DNA", variant="primary")

      with gr.Row():
@@ -492,9 +603,15 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:

      # Set up event handlers
      analyze_btn.click(
          analyze_keyword,
          inputs=[input_text],
-         outputs=[token_viz_html, analysis_html, json_output, evolution_chart, None]
      )

      # Example buttons
@@ -503,11 +620,18 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
          lambda btn_text: btn_text,
          inputs=[btn],
          outputs=[input_text]
      ).then(
          analyze_keyword,
          inputs=[input_text],
-         outputs=[token_viz_html, analysis_html, json_output, evolution_chart, None]
      )

  # Launch the app
- demo.launch()
 
  import numpy as np
  import pandas as pd
  import torch
  import matplotlib.pyplot as plt
  import json
+ import time
+ import os
+ from functools import partial

+ # Global variables to store models
+ tokenizer = None
+ ner_pipeline = None
+ pos_pipeline = None
+ intent_classifier = None
+ models_loaded = False

+ def load_models(progress=gr.Progress()):
+     """Lazy-load models only when needed"""
+     global tokenizer, ner_pipeline, pos_pipeline, intent_classifier, models_loaded
+
+     if models_loaded:
+         return True
+
+     try:
+         progress(0.1, desc="Loading models...")
+
+         # Use smaller models and load them sequentially to reduce memory pressure
+         from transformers import AutoTokenizer, pipeline
+
+         progress(0.2, desc="Loading tokenizer...")
+         tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+
+         progress(0.4, desc="Loading NER model...")
+         ner_pipeline = pipeline("ner", model="dslim/bert-base-NER")
+
+         progress(0.6, desc="Loading POS model...")
+         # Use smaller POS model
+         from transformers import AutoModelForTokenClassification, BertTokenizerFast
+         pos_model = AutoModelForTokenClassification.from_pretrained("vblagoje/bert-english-uncased-finetuned-pos")
+         pos_tokenizer = BertTokenizerFast.from_pretrained("vblagoje/bert-english-uncased-finetuned-pos")
+         pos_pipeline = pipeline("token-classification", model=pos_model, tokenizer=pos_tokenizer)
+
+         progress(0.8, desc="Loading intent classifier...")
+         # Use a smaller model for zero-shot classification
+         intent_classifier = pipeline(
+             "zero-shot-classification",
+             model="typeform/distilbert-base-uncased-mnli",  # Smaller than BART
+             device=0 if torch.cuda.is_available() else -1   # Use GPU if available
+         )
+
+         progress(1.0, desc="Models loaded successfully!")
+         models_loaded = True
+         return True
+
+     except Exception as e:
+         print(f"Error loading models: {str(e)}")
+         return f"Error: {str(e)}"

  def get_token_colors(token_type):
      colors = {
 
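The new `load_models` defers every model download to the first request, so the Space starts quickly, and caches the pipelines in module globals so later requests pay nothing. A minimal sketch of the same lazy, cached loading pattern, separated from the app specifics (`get_pipeline` and `_cache` are illustrative names, not part of this commit):

```python
# Sketch: lazy, cached pipeline loading (assumes transformers is installed).
from transformers import pipeline

_cache = {}  # (task, model) pair -> loaded pipeline

def get_pipeline(task: str, model: str):
    """Load a pipeline on first use; reuse the cached instance afterwards."""
    key = (task, model)
    if key not in _cache:
        _cache[key] = pipeline(task, model=model)
    return _cache[key]

# The first call pays the download/load cost; subsequent calls are instant.
ner = get_pipeline("ner", "dslim/bert-base-NER")
```

The `device=0 if torch.cuda.is_available() else -1` argument follows the convention `transformers` pipelines expect: a CUDA device index, or -1 for CPU.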
          values = [45, 50, 60, 70, 75, 80]
      else:
          # Standard pattern for common words
+         # Use a character-sum value instead of hash() so results don't differ across runs
+         base = 50 + (sum(ord(c) for c in token) % 30)
+         # Use a fixed seed for reproducibility
+         np.random.seed(sum(ord(c) for c in token))
          noise = np.random.normal(0, 5, 6)
          values = [max(5, min(95, base + i*5 + n)) for i, n in enumerate(noise)]

 
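The switch away from `hash(token)` matters because CPython salts string hashes per process (see PYTHONHASHSEED), so `hash(token) % 30` changes between runs; a character sum is crude but stable. A hashlib digest is another stable option, sketched here as an alternative (`stable_bucket` is an illustrative name):

```python
import hashlib

def stable_bucket(token: str, buckets: int) -> int:
    """Map a token to a bucket deterministically across processes and runs."""
    digest = hashlib.md5(token.encode("utf-8")).hexdigest()
    return int(digest, 16) % buckets

base = 50 + stable_bucket("artificial", 30)  # same value in every run
```

One caveat with `np.random.seed(...)`: it reseeds NumPy's global generator for all callers. An isolated generator, e.g. `rng = np.random.default_rng(seed)` with `rng.normal(0, 5, 6)`, avoids that side effect.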
      ]

      # Deterministic selection based on the token
+     index = sum(ord(c) for c in token) % len(origins)
      origin = origins[index]

      note = f"First appeared in {origin['era']} texts derived from {origin['language']}."
 
      return processed_tokens

  def plot_historical_data(historical_data):
+     """Create a plot of historical usage data, with error handling"""
+     try:
+         eras = [item[0] for item in historical_data]
+         values = [item[1] for item in historical_data]
+
+         plt.figure(figsize=(8, 3))
+         plt.bar(eras, values, color='skyblue')
+         plt.title('Historical Usage')
+         plt.xlabel('Era')
+         plt.ylabel('Usage Level')
+         plt.ylim(0, 100)
+         plt.xticks(rotation=45)
+         plt.tight_layout()
+
+         return plt
+     except Exception as e:
+         print(f"Error in plot_historical_data: {str(e)}")
+         # Return a simple error plot
+         plt.figure(figsize=(8, 3))
+         plt.text(0.5, 0.5, f"Error creating plot: {str(e)}",
+                  horizontalalignment='center', verticalalignment='center')
+         plt.axis('off')
+         return plt

+ def create_evolution_chart(data):
+     """Create an evolution chart from data, with error handling"""
+     try:
+         df = pd.DataFrame(data)
+
+         plt.figure(figsize=(10, 5))
+         plt.plot(df['month'], df['searchVolume'], marker='o', label='Search Volume')
+         plt.plot(df['month'], df['competitionScore']*20, marker='s', label='Competition Score')
+         plt.plot(df['month'], df['intentClarity']*20, marker='^', label='Intent Clarity')
+
+         plt.title('Predicted Evolution')
+         plt.xlabel('Month')
+         plt.ylabel('Value')
+         plt.legend()
+         plt.grid(True, linestyle='--', alpha=0.7)
+         plt.tight_layout()
+
+         return plt
+     except Exception as e:
+         print(f"Error in create_evolution_chart: {str(e)}")
+         # Return a simple error plot
+         plt.figure(figsize=(10, 5))
+         plt.text(0.5, 0.5, f"Error creating chart: {str(e)}",
+                  horizontalalignment='center', verticalalignment='center')
+         plt.axis('off')
+         return plt
+
+ def analyze_keyword(keyword, progress=gr.Progress()):
+     """Main function to analyze a keyword"""
+     if not keyword or not keyword.strip():
+         return (
+             "<div>Please enter a keyword to analyze</div>",
+             "<div>Please enter a keyword to analyze</div>",
+             None,
+             None
+         )

+     progress(0.1, desc="Starting analysis...")

+     # Load models if not already loaded
+     model_status = load_models(progress)
+     if isinstance(model_status, str) and model_status.startswith("Error"):
+         return (
+             f"<div style='color:red;'>{model_status}</div>",
+             f"<div style='color:red;'>{model_status}</div>",
+             None,
+             None
+         )

+     try:
+         # Basic tokenization - just split on spaces for simplicity
+         words = keyword.strip().lower().split()
+         progress(0.2, desc="Analyzing tokens...")
+
+         # Get token types
+         token_analysis = analyze_token_types(words)
+
+         progress(0.3, desc="Running NER...")
+         # Get NER tags - handle potential errors
+         try:
+             ner_results = ner_pipeline(keyword)
+         except Exception as e:
+             print(f"NER error: {str(e)}")
+             ner_results = []
+
+         progress(0.4, desc="Running POS tagging...")
+         # Get POS tags - handle potential errors
+         try:
+             pos_results = pos_pipeline(keyword)
+         except Exception as e:
+             print(f"POS error: {str(e)}")
+             pos_results = []
+
+         # Process and organize results
+         full_token_analysis = []
+         for token in token_analysis:
+             # Find POS tag for this token
+             pos_tag = "NOUN"  # Default
+             for pos_result in pos_results:
+                 if pos_result["word"].lower() == token["text"]:
+                     pos_tag = pos_result["entity"]
+                     break
+
+             # Find entity type if any
+             entity_type = None
+             for ner_result in ner_results:
+                 if ner_result["word"].lower() == token["text"]:
+                     entity_type = ner_result["entity"]
+                     break
+
+             # Generate historical data
+             historical_data = simulate_historical_data(token["text"])
+
+             # Generate origin data
+             origin = generate_origin_data(token["text"])
+
+             # Calculate importance (simplified algorithm)
+             importance = 60 + (len(token["text"]) * 2)
+             importance = min(95, importance)
+
+             # Generate related terms (simplified)
+             related_terms = [f"{token['text']}-related-1", f"{token['text']}-related-2"]
+
+             full_token_analysis.append({
+                 "token": token["text"],
+                 "type": token["type"],
+                 "posTag": pos_tag,
+                 "entityType": entity_type,
+                 "importance": importance,
+                 "historicalData": historical_data,
+                 "origin": origin,
+                 "relatedTerms": related_terms
+             })
+
+         progress(0.6, desc="Analyzing intent...")
+         # Intent analysis - handle potential errors
+         try:
+             intent_result = intent_classifier(
+                 keyword,
+                 candidate_labels=["informational", "navigational", "transactional"]
+             )
+
+             intent_analysis = {
+                 "type": intent_result["labels"][0].capitalize(),
+                 "strength": round(intent_result["scores"][0] * 100),
+                 "mutations": [
+                     f"{intent_result['labels'][0]}-variation-1",
+                     f"{intent_result['labels'][0]}-variation-2"
+                 ]
+             }
+         except Exception as e:
+             print(f"Intent classification error: {str(e)}")
+             intent_analysis = {
+                 "type": "Informational",  # Default fallback
+                 "strength": 70,
+                 "mutations": ["fallback-variation-1", "fallback-variation-2"]
+             }
+
+         # Evolution potential (simplified calculation)
+         evolution_potential = min(95, 65 + (len(keyword) % 30))
+
+         # Predicted trends (simplified)
+         trends = [
+             "Voice search adaptation",
+             "Visual search integration"
          ]
+
+         # Evolution chart data (simulated)
+         evolution_data = [
+             {"month": "Jan", "searchVolume": 1000, "competitionScore": 45, "intentClarity": 80},
+             {"month": "Feb", "searchVolume": 1200, "competitionScore": 48, "intentClarity": 82},
+             {"month": "Mar", "searchVolume": 1100, "competitionScore": 52, "intentClarity": 85},
+             {"month": "Apr", "searchVolume": 1400, "competitionScore": 55, "intentClarity": 88},
+             {"month": "May", "searchVolume": 1800, "competitionScore": 58, "intentClarity": 90},
+             {"month": "Jun", "searchVolume": 2200, "competitionScore": 60, "intentClarity": 92}
+         ]
+
+         progress(0.8, desc="Creating visualizations...")
+         # Create plots
+         evolution_chart = create_evolution_chart(evolution_data)
+
+         # Generate HTML for token visualization
+         token_viz_html = generate_token_visualization_html(token_analysis, full_token_analysis)
+
+         # Generate HTML for full analysis
+         analysis_html = generate_full_analysis_html(
+             keyword,
+             full_token_analysis,
+             intent_analysis,
+             evolution_potential,
+             trends
+         )
+
+         # Generate JSON results
+         json_results = {
+             "keyword": keyword,
+             "tokenAnalysis": full_token_analysis,
+             "intentAnalysis": intent_analysis,
+             "evolutionPotential": evolution_potential,
+             "predictedTrends": trends
+         }
+
+         progress(1.0, desc="Analysis complete!")
+         return token_viz_html, analysis_html, json_results, evolution_chart

+     except Exception as e:
+         error_message = f"<div style='color:red;padding:20px;'>Error analyzing keyword: {str(e)}</div>"
+         print(f"Error in analyze_keyword: {str(e)}")
+         return error_message, error_message, None, None

  def generate_token_visualization_html(token_analysis, full_analysis):
      """Generate HTML for token visualization"""
 
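Each chart helper now traps its own exceptions and returns a placeholder plot, and `analyze_keyword` wraps the per-model calls in their own `try`/`except`, so a single failing pipeline degrades to an empty result instead of aborting the whole response. One thing to note: returning the `plt` module relies on matplotlib's implicit current-figure state, which is shared across requests. A sketch of the same fallback pattern built on an explicit `Figure` instead (the function name is illustrative, not from this commit):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, typical for server apps
import matplotlib.pyplot as plt

def safe_bar_chart(data):
    """Build a bar chart; on failure, return a figure showing the error."""
    try:
        fig, ax = plt.subplots(figsize=(8, 3))
        ax.bar([era for era, _ in data], [value for _, value in data], color="skyblue")
        ax.set_ylim(0, 100)
        fig.tight_layout()
        return fig
    except Exception as e:
        fig, ax = plt.subplots(figsize=(8, 3))
        ax.text(0.5, 0.5, f"Error creating plot: {e}", ha="center", va="center")
        ax.axis("off")
        return fig
```

`gr.Plot` accepts a `Figure` just as readily as the `plt` module, and explicit figures avoid cross-request interference.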
      with gr.Row():
          with gr.Column():
              input_text = gr.Textbox(label="Enter keyword to analyze", placeholder="e.g. artificial intelligence")
+
+             # Add loading indicator
+             status_html = gr.HTML('<div style="color:gray;text-align:center;">Enter a keyword and click "Analyze DNA"</div>')
+
              analyze_btn = gr.Button("Analyze DNA", variant="primary")

      with gr.Row():
 

      # Set up event handlers
      analyze_btn.click(
+         lambda: '<div style="color:blue;text-align:center;">Loading models and analyzing... This may take a moment.</div>',
+         outputs=status_html
+     ).then(
          analyze_keyword,
          inputs=[input_text],
+         outputs=[token_viz_html, analysis_html, json_output, evolution_chart]
+     ).then(
+         lambda: '<div style="color:green;text-align:center;">Analysis complete!</div>',
+         outputs=status_html
      )

      # Example buttons
 
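The chained `.click(...).then(...)` calls above run in order: an instant lambda flips `status_html` to a loading message, the heavy `analyze_keyword` runs second, and a final lambda reports completion. A stripped-down sketch of that wiring with a stand-in slow function (all names here are illustrative):

```python
import time
import gradio as gr

def slow_task(text):
    time.sleep(2)  # stand-in for model loading + inference
    return f"Result for: {text}"

with gr.Blocks() as sketch:
    inp = gr.Textbox(label="Input")
    status = gr.HTML("Idle")
    out = gr.Textbox(label="Output")
    btn = gr.Button("Run")

    btn.click(
        lambda: "Working...", outputs=status   # step 1: immediate feedback
    ).then(
        slow_task, inputs=inp, outputs=out     # step 2: the long job
    ).then(
        lambda: "Done!", outputs=status        # step 3: completion notice
    )
```

Because each `.then` step starts only when the previous one finishes, the status text is visible before the slow step begins.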
          lambda btn_text: btn_text,
          inputs=[btn],
          outputs=[input_text]
+     ).then(
+         lambda: '<div style="color:blue;text-align:center;">Loading models and analyzing... This may take a moment.</div>',
+         outputs=status_html
      ).then(
          analyze_keyword,
          inputs=[input_text],
+         outputs=[token_viz_html, analysis_html, json_output, evolution_chart]
+     ).then(
+         lambda: '<div style="color:green;text-align:center;">Analysis complete!</div>',
+         outputs=status_html
      )

  # Launch the app
+ if __name__ == "__main__":
+     demo.launch()
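The `__main__` guard means importing `app.py` (as tests or other launchers might) no longer starts a server as a side effect. Given how long a cold `analyze_keyword` call can take, enabling Gradio's request queue would be a natural companion change; a sketch of that variant, offered as an assumption rather than part of this commit:

```python
if __name__ == "__main__":
    # queue() routes long-running handlers through a request queue so
    # concurrent users don't pile up (assumed extension, not in this commit).
    demo.queue().launch()
```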