johnowhitaker commited on
Commit
3ea613d
·
1 Parent(s): 28d55fa

Testing sentiment func

Browse files
Files changed (1) hide show
  1. app.py +13 -2
app.py CHANGED
@@ -24,11 +24,21 @@ def preprocess(text):
24
  new_text.append(t)
25
  return " ".join(new_text)
26
 
 
27
  MODEL = "cardiffnlp/twitter-roberta-base-sentiment"
28
  tokenizer = AutoTokenizer.from_pretrained(MODEL)
29
  model = AutoModelForSequenceClassification.from_pretrained(MODEL)
30
  model.save_pretrained(MODEL)
31
 
 
 
 
 
 
 
 
 
 
32
  # https://stackoverflow.com/questions/492519/timeout-on-a-function-call
33
  def timeout(max_timeout):
34
  """Timeout decorator, parameter in seconds."""
@@ -44,8 +54,7 @@ def timeout(max_timeout):
44
  return func_wrapper
45
  return timeout_decorator
46
 
47
- # nest_asyncio.apply()
48
-
49
  @timeout(120.0)
50
  def get_tweets(username, limit=500, save_name=None):
51
  #nest_asyncio.apply() # Helps avoid RuntimeError: This event loop is already running
@@ -80,5 +89,7 @@ with st.form("my_form"):
80
  st.write("Fetching user", user, "n_tweets", n_tweets)
81
  tweets = get_tweets(user, limit=n_tweets)
82
  st.write(st.dataframe(tweets.head()))
 
 
83
 
84
  st.write("Outside the form")
 
24
  new_text.append(t)
25
  return " ".join(new_text)
26
 
27
+ # Loading pretrained model
28
  MODEL = "cardiffnlp/twitter-roberta-base-sentiment"
29
  tokenizer = AutoTokenizer.from_pretrained(MODEL)
30
  model = AutoModelForSequenceClassification.from_pretrained(MODEL)
31
  model.save_pretrained(MODEL)
32
 
33
+ # Func to get a score using the above model
34
+ def combined_score(text):
35
+ text = preprocess(text)
36
+ encoded_input = tokenizer(text, return_tensors='pt')
37
+ output = model(**encoded_input)
38
+ scores = output[0][0].detach().numpy()
39
+ scores = softmax(scores)
40
+ return -scores[0] + scores[2] # scores = [negative, neutral, positive]
41
+
42
  # https://stackoverflow.com/questions/492519/timeout-on-a-function-call
43
  def timeout(max_timeout):
44
  """Timeout decorator, parameter in seconds."""
 
54
  return func_wrapper
55
  return timeout_decorator
56
 
57
+ # Getting tweets from a user
 
58
  @timeout(120.0)
59
  def get_tweets(username, limit=500, save_name=None):
60
  #nest_asyncio.apply() # Helps avoid RuntimeError: This event loop is already running
 
89
  st.write("Fetching user", user, "n_tweets", n_tweets)
90
  tweets = get_tweets(user, limit=n_tweets)
91
  st.write(st.dataframe(tweets.head()))
92
+ tweets['sentiment'] = tweets['tweet'].map(lambda s: combined_score(s))
93
+ st.write(st.dataframe(tweets[['tweet', 'sentiment']].head()))
94
 
95
  st.write("Outside the form")