Spaces:
Runtime error
Runtime error
Commit
·
3ea613d
1
Parent(s):
28d55fa
Testing sentiment func
Browse files
app.py
CHANGED
@@ -24,11 +24,21 @@ def preprocess(text):
|
|
24 |
new_text.append(t)
|
25 |
return " ".join(new_text)
|
26 |
|
|
|
27 |
MODEL = "cardiffnlp/twitter-roberta-base-sentiment"
|
28 |
tokenizer = AutoTokenizer.from_pretrained(MODEL)
|
29 |
model = AutoModelForSequenceClassification.from_pretrained(MODEL)
|
30 |
model.save_pretrained(MODEL)
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
# https://stackoverflow.com/questions/492519/timeout-on-a-function-call
|
33 |
def timeout(max_timeout):
|
34 |
"""Timeout decorator, parameter in seconds."""
|
@@ -44,8 +54,7 @@ def timeout(max_timeout):
|
|
44 |
return func_wrapper
|
45 |
return timeout_decorator
|
46 |
|
47 |
-
#
|
48 |
-
|
49 |
@timeout(120.0)
|
50 |
def get_tweets(username, limit=500, save_name=None):
|
51 |
#nest_asyncio.apply() # Helps avoid RuntimeError: This event loop is already running
|
@@ -80,5 +89,7 @@ with st.form("my_form"):
|
|
80 |
st.write("Fetching user", user, "n_tweets", n_tweets)
|
81 |
tweets = get_tweets(user, limit=n_tweets)
|
82 |
st.write(st.dataframe(tweets.head()))
|
|
|
|
|
83 |
|
84 |
st.write("Outside the form")
|
|
|
24 |
new_text.append(t)
|
25 |
return " ".join(new_text)
|
26 |
|
27 |
+
# Loading pretrained model
|
28 |
MODEL = "cardiffnlp/twitter-roberta-base-sentiment"
|
29 |
tokenizer = AutoTokenizer.from_pretrained(MODEL)
|
30 |
model = AutoModelForSequenceClassification.from_pretrained(MODEL)
|
31 |
model.save_pretrained(MODEL)
|
32 |
|
33 |
+
# Func to get a score using the above model
|
34 |
+
def combined_score(text):
|
35 |
+
text = preprocess(text)
|
36 |
+
encoded_input = tokenizer(text, return_tensors='pt')
|
37 |
+
output = model(**encoded_input)
|
38 |
+
scores = output[0][0].detach().numpy()
|
39 |
+
scores = softmax(scores)
|
40 |
+
return -scores[0] + scores[2] # scores = [negative, neutral, positive]
|
41 |
+
|
42 |
# https://stackoverflow.com/questions/492519/timeout-on-a-function-call
|
43 |
def timeout(max_timeout):
|
44 |
"""Timeout decorator, parameter in seconds."""
|
|
|
54 |
return func_wrapper
|
55 |
return timeout_decorator
|
56 |
|
57 |
+
# Getting tweets from a user
|
|
|
58 |
@timeout(120.0)
|
59 |
def get_tweets(username, limit=500, save_name=None):
|
60 |
#nest_asyncio.apply() # Helps avoid RuntimeError: This event loop is already running
|
|
|
89 |
st.write("Fetching user", user, "n_tweets", n_tweets)
|
90 |
tweets = get_tweets(user, limit=n_tweets)
|
91 |
st.write(st.dataframe(tweets.head()))
|
92 |
+
tweets['sentiment'] = tweets['tweet'].map(lambda s: combined_score(s))
|
93 |
+
st.write(st.dataframe(tweets[['tweet', 'sentiment']].head()))
|
94 |
|
95 |
st.write("Outside the form")
|