coyotte508 committed
Commit 7ba8d64 · 1 Parent(s): f40082f

Create new file

Files changed (1)
  1. app.py +167 -0
app.py ADDED
@@ -0,0 +1,167 @@
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
from textblob import TextBlob
from hatesonar import Sonar
from difflib import Differ
from typing import List, Tuple
import gradio as gr
import torch

# Load the trained reframing model and tokenizer
model = AutoModelForSeq2SeqLM.from_pretrained("output/reframer")
tokenizer = AutoTokenizer.from_pretrained("output/reframer")
reframer = pipeline('summarization', model=model, tokenizer=tokenizer)


CHAR_LENGTH_LOWER_BOUND = 15  # The minimum character length threshold for the input text
CHAR_LENGTH_HIGHER_BOUND = 150  # The maximum character length threshold for the input text
SENTIMENT_THRESHOLD = 0.2  # The maximum TextBlob sentiment polarity for the input text
OFFENSIVENESS_CONFIDENCE_THRESHOLD = 0.8  # The threshold on the confidence of the input text being offensive

LENGTH_ERROR = "The input text is too long or too short. Please try again by inputting text with moderate length."
SENTIMENT_ERROR = "The input text is too positive. Please try again by inputting text with negative sentiment."
OFFENSIVE_ERROR = "The input text is offensive. Please try again by inputting non-offensive text."

CACHE = []  # A list storing the most recent reframing records
MAX_STORE = 5  # The maximum number of history records to keep

BEST_N = 3  # The number of best decodes the user would like to see


def input_error_message(error_type):
    # type: (str) -> str
    """Generate an input error message from an error type."""
    return "[Error]: Invalid Input. " + error_type


def update_cache(cache, new_record):
    # type: (List[List[str]], List[str]) -> List[List[str]]
    """Update the cache to keep only the most recent MAX_STORE reframing records."""
    cache.append(new_record)
    if len(cache) > MAX_STORE:
        cache = cache[1:]
    return cache

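# Note on update_cache(): list.append mutates the caller's list in place, while the slicing above
# creates a new list object, so callers must rebind the returned value (reframe() below does this
# for the global CACHE).
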
def reframe(input_text, strategy):
    # type: (str, str) -> str
    """Reframe the input text with a specified strategy.

    The strategy is concatenated to the input text and passed to a finetuned BART model.

    The reframed positive text is returned.
    """
    text_with_strategy = input_text + "Strategy: ['" + strategy + "']"

    # Input control
    # The input text cannot be too short, so that it has substantial content to reframe,
    # nor too long, so that it stays focused on a single idea.
    if len(input_text) < CHAR_LENGTH_LOWER_BOUND or len(input_text) > CHAR_LENGTH_HIGHER_BOUND:
        return input_text + input_error_message(LENGTH_ERROR)
    # The input text cannot be too positive, so that there is something to reframe.
    if TextBlob(input_text).sentiment.polarity > SENTIMENT_THRESHOLD:
        return input_text + input_error_message(SENTIMENT_ERROR)
    # The input text cannot be offensive.
    sonar = Sonar()
    # sonar.ping(input_text) returns a dictionary; the second entry under the 'classes' key is the
    # confidence that the input text is offensive language.
    if sonar.ping(input_text)['classes'][1]['confidence'] > OFFENSIVENESS_CONFIDENCE_THRESHOLD:
        return input_text + input_error_message(OFFENSIVE_ERROR)

    # Reframing
    # The reframer pipeline returns a list containing one dictionary whose 'summary_text' value
    # is the reframed text.
    reframed_text = reframer(text_with_strategy)[0]['summary_text']

    # Update the cache
    global CACHE
    CACHE = update_cache(CACHE, [input_text, strategy, reframed_text])

    return reframed_text

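# Note on reframe(): the prompt is the raw input text with the strategy appended, e.g. for input
# "So stressed about the midterm next week." and strategy "optimism" the model receives
#     "So stressed about the midterm next week.Strategy: ['optimism']"
# (no separator before "Strategy:"). This assumes the finetuned checkpoint was trained on prompts
# built in the same concatenated format.
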
def show_reframe_change(input_text, strategy):
    # type: (str, str) -> List[Tuple[str, str]]
    """Show the character-level edits that turn input_text into the reframed text.

    The returned output is a list of (character, action) tuples, where the action is '+' for an
    inserted character, '-' for a deleted character, and None for an unchanged one.
    """
    reframed_text = reframe(input_text, strategy)
    d = Differ()
    return [
        (token[2:], token[0] if token[0] != " " else None)
        for token in d.compare(input_text, reframed_text)
    ]


def show_n_best_decodes(input_text, strategy):
    # type: (str, str) -> str
    """Sample BEST_N candidate reframings and return them as a numbered, newline-separated string."""
    prompt = [input_text + "Strategy: ['" + strategy + "']"]
    n_best_decodes = model.generate(
        torch.tensor(tokenizer(prompt, padding=True)['input_ids']),
        do_sample=True,
        num_return_sequences=BEST_N,
    )
    best_n_result = ""
    for i in range(len(n_best_decodes)):
        best_n_result += str(i + 1) + " " + tokenizer.decode(n_best_decodes[i], skip_special_tokens=True)
        if i < BEST_N - 1:
            best_n_result += "\n"
    return best_n_result

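# Note on show_n_best_decodes(): because do_sample=True, repeated clicks can return different
# candidates. If reproducible outputs were preferred, a deterministic alternative (not part of the
# original app) would be beam search, e.g. model.generate(..., num_beams=BEST_N,
# num_return_sequences=BEST_N) without sampling.
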
def show_history(cache):
    # type: (List[List[str]]) -> dict
    """Format the cached reframing records and return a Gradio update revealing them in a textbox."""
    history = ""
    for record in cache:
        input_text, strategy, reframed_text = record
        history += "Input text: " + input_text + " Strategy: " + strategy + " -> Reframed text: " + reframed_text + "\n"
    return gr.Textbox.update(value=history, visible=True)

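# Note: show_history() is defined to display the CACHE contents, but it is not wired to any
# component in the Blocks interface below; exposing it would need an extra trigger (e.g. a button)
# and a hidden history textbox.
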
# Build the Gradio interface
with gr.Blocks() as demo:
    # Instructions
    gr.Markdown(
        '''
        # Positive Reframing
        **Start inputting negative text to see how you can view the same event from a positive angle.**
        ''')

    # Input text to be reframed
    text = gr.Textbox(label="Original Text")

    # Strategy to use for the reframing
    gr.Markdown(
        '''
        **Choose one of the six strategies to carry out reframing:** \n
        **Growth Mindset:** Viewing a challenging event as an opportunity for the author specifically to grow or improve themselves. \n
        **Impermanence:** Saying bad things don’t last forever, will get better soon, and/or that others have experienced similar struggles. \n
        **Neutralizing:** Replacing a negative word with a neutral word. For example, “This was a terrible day” becomes “This was a long day.” \n
        **Optimism:** Focusing on things about the situation itself, in that moment, that are good (not just forecasting a better future). \n
        **Self-affirmation:** Talking about what strengths the author already has, or the values they admire, like love, courage, perseverance, etc. \n
        **Thankfulness:** Expressing thankfulness or gratitude with key words like appreciate, glad that, thankful for, good thing, etc.
        ''')
    strategy = gr.Radio(
        ["thankfulness", "neutralizing", "optimism", "growth", "impermanence", "self_affirmation"],
        label="Strategy to use?"
    )

    # Trigger button for reframing
    reframe_btn = gr.Button("Reframe")
    best_output = gr.HighlightedText(
        label="Diff",
        combine_adjacent=True,
    ).style(color_map={"+": "green", "-": "red"})
    reframe_btn.click(fn=show_reframe_change, inputs=[text, strategy], outputs=best_output)

    # Trigger button for showing the n best reframings
    n_best_btn = gr.Button("Show Best {n} Results".format(n=BEST_N))
    n_best_output = gr.Textbox(interactive=False)
    n_best_btn.click(fn=show_n_best_decodes, inputs=[text, strategy], outputs=n_best_output)

    # Default examples of text and strategy pairs for a quick start
    gr.Markdown("## Examples")
    gr.Examples(
        [["I have a lot of homework to do today.", "self_affirmation"],
         ["So stressed about the midterm next week.", "optimism"],
         ["I failed my math quiz I am such a loser.", "growth"]],
        [text, strategy], best_output, show_reframe_change,
        cache_examples=False, run_on_click=False
    )

    # Link to the paper and GitHub repo
    gr.Markdown(
        '''
        For more details, you can read our [paper](https://arxiv.org/abs/2204.02952) or access our [code](https://github.com/SALT-NLP/positive-frames).
        ''')

demo.launch()