ahm14 commited on
Commit
9943898
·
verified ·
1 Parent(s): c536f05

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -7
app.py CHANGED
@@ -10,8 +10,15 @@ import io
10
  nltk.download('punkt')
11
 
12
  # Load AI models once to optimize performance
13
- tone_model = pipeline("zero-shot-classification", model="facebook/roberta-large-mnli")
14
- frame_model = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
 
 
 
 
 
 
 
15
 
16
  # Updated tone categories
17
  tone_categories = [
@@ -33,15 +40,23 @@ def detect_language(text):
33
  except Exception:
34
  return "unknown"
35
 
36
- # Analyze tone using RoBERTa model
37
  def analyze_tone(text):
38
- model_result = tone_model(text, candidate_labels=tone_categories)
39
- return model_result["labels"][:2] # Top 2 tone labels
 
 
 
 
40
 
41
  # Extract frames using BART model
42
  def extract_frames(text):
43
- model_result = frame_model(text, candidate_labels=frame_categories)
44
- return model_result["labels"][:2] # Top 2 frame labels
 
 
 
 
45
 
46
  # Extract hashtags
47
  def extract_hashtags(text):
 
10
  nltk.download('punkt')
11
 
12
  # Load AI models once to optimize performance
13
+ try:
14
+ tone_model = pipeline("zero-shot-classification", model="cross-encoder/nli-deberta-v3-large")
15
+ except OSError:
16
+ st.error("Failed to load tone analysis model. Please check internet connection or model availability.")
17
+
18
+ try:
19
+ frame_model = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
20
+ except OSError:
21
+ st.error("Failed to load frame classification model. Please check internet connection or model availability.")
22
 
23
  # Updated tone categories
24
  tone_categories = [
 
40
  except Exception:
41
  return "unknown"
42
 
43
+ # Analyze tone using DeBERTa model
44
  def analyze_tone(text):
45
+ try:
46
+ model_result = tone_model(text, candidate_labels=tone_categories)
47
+ return model_result["labels"][:2] # Top 2 tone labels
48
+ except Exception as e:
49
+ st.error(f"Error analyzing tone: {e}")
50
+ return ["Unknown"]
51
 
52
  # Extract frames using BART model
53
  def extract_frames(text):
54
+ try:
55
+ model_result = frame_model(text, candidate_labels=frame_categories)
56
+ return model_result["labels"][:2] # Top 2 frame labels
57
+ except Exception as e:
58
+ st.error(f"Error extracting frames: {e}")
59
+ return ["Unknown"]
60
 
61
  # Extract hashtags
62
  def extract_hashtags(text):