any0019 committed on
Commit f5382f0 · 1 Parent(s): 53014ec

Create app.py

Files changed (1)
  1. app.py +103 -0
app.py ADDED
@@ -0,0 +1,103 @@
+ import streamlit as st
+ from termcolor import colored
+ import torch
+ from transformers import BertTokenizer, BertForMaskedLM, BertForSequenceClassification
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ # Cache the models across Streamlit reruns; allow_output_mutation keeps
+ # st.cache from trying to hash the (mutable, unhashable) torch modules.
+ @st.cache(allow_output_mutation=True)
+ def load_models():
+     tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+     # fine-tuned checkpoints: one MLM per style plus a sentiment classifier;
+     # the models are only ever used for inference, so no .train() call is needed
+     bert_mlm_positive = BertForMaskedLM.from_pretrained('text_style_mlm_positive', return_dict=True).to(device)
+     bert_mlm_negative = BertForMaskedLM.from_pretrained('text_style_mlm_negative', return_dict=True).to(device)
+     bert_classifier = BertForSequenceClassification.from_pretrained('text_style_classifier', num_labels=2).to(device)
+     return tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier
+
+ tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier = load_models()
+
+ def highlight_diff(sent, sent_main):
+     """Return `sent` with tokens that differ from `sent_main` marked up."""
+     tokens = tokenizer.tokenize(sent)
+     tokens_main = tokenizer.tokenize(sent_main)
+
+     new_toks = []
+     for tok, tok_main in zip(tokens, tokens_main):
+         if tok != tok_main:
+             # note: termcolor emits terminal ANSI codes, which st.write
+             # does not render as colors
+             new_toks.append(colored(tok, 'red', attrs=['bold', 'underline']))
+         else:
+             new_toks.append(tok)
+
+     return ' '.join(new_toks)
+
+ def get_classifier_prob(sent):
+     """Return the classifier's [p(negative), p(positive)] for a sentence."""
+     bert_classifier.eval()
+     with torch.no_grad():
+         inputs = {k: v.to(device) for k, v in tokenizer(sent, return_tensors='pt').items()}
+         return bert_classifier(**inputs).logits.softmax(dim=-1)[0].cpu().numpy()
+
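+ # Example (hypothetical call): get_classifier_prob("the food was great")[1]
+ # gives the probability of the positive class, which is the score the beam
+ # search below tries to maximize.
+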
+ def beam_get_replacements(current_beam, beam_size, epsilon=1e-3, used_positions=None):
+     """
+     - for each sentence in :current_beam: (a {sentence: positive_prob} dict),
+       split it into tokens using the INGSOC-approved BERT tokenizer
+     - pick the token with the lowest positive/negative probability ratio
+       (skipping [CLS], [SEP] and positions in :used_positions:)
+     - try the top :beam_size: replacements for that token
+     - keep the best :beam_size: hypotheses ranked by classifier probability
+     :return: ({sentence: positive_prob} beam, replaced position or None)
+     """
+     if used_positions is None:
+         used_positions = []
+     bert_mlm_positive.eval()
+     bert_mlm_negative.eval()
+     new_beam = []
+     used_position = None
+     with torch.no_grad():
+         for sentence in current_beam:
+             input_ = {k: v.to(device) for k, v in tokenizer(sentence, return_tensors='pt').items()}
+             probs_negative = bert_mlm_negative(**input_).logits.softmax(dim=-1)[0]
+             probs_positive = bert_mlm_positive(**input_).logits.softmax(dim=-1)[0]
+             ids = input_['input_ids'][0].cpu().numpy()
+             seq_len = probs_positive.shape[0]
+             # probability of each *current* token under the two style MLMs
+             p_pos = probs_positive[torch.arange(seq_len), ids]
+             p_neg = probs_negative[torch.arange(seq_len), ids]
+             # tokens that look far more "negative" than "positive" come first
+             order_of_replacement = ((p_pos + epsilon) / (p_neg + epsilon)).argsort()
+             for pos in order_of_replacement.tolist():
+                 if pos in used_positions or pos == 0 or pos == len(ids) - 1:
+                     continue  # skip already-replaced positions, [CLS] and [SEP]
+                 used_position = pos
+                 replacement_ids = (-probs_positive[pos, :]).argsort()[:beam_size]
+                 for replacement_id in replacement_ids.tolist():
+                     if replacement_id == ids[pos]:
+                         continue
+                     new_ids = ids.copy()
+                     new_ids[pos] = replacement_id
+                     new_beam.append(new_ids)
+                 break  # replace one position per sentence per step
+     if len(new_beam) > 0:
+         # strip [CLS]/[SEP], rescore with the classifier, merge with the old beam
+         new_beam = [tokenizer.decode(ids[1:-1]) for ids in new_beam]
+         new_beam = {sent: get_classifier_prob(sent)[1] for sent in new_beam}
+         for sent, prob in current_beam.items():
+             new_beam[sent] = prob
+
+         if len(new_beam) > beam_size:
+             new_beam = dict(sorted(new_beam.items(), key=lambda el: el[1], reverse=True)[:beam_size])
+         return new_beam, used_position
+     else:
+         st.write("No more new hypotheses")
+         return current_beam, None
+
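+ # Illustration (hypothetical numbers): for "the food was terrible", the
+ # negative MLM assigns "terrible" a much higher probability than the positive
+ # MLM does, so (p_pos + eps) / (p_neg + eps) is smallest at that position and
+ # it is the first token the search tries to replace.
+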
+ def get_best_hypotheses(sentence, beam_size, max_steps, epsilon=1e-3, pretty_output=False):
+     current_beam = {sentence: get_classifier_prob(sentence)[1]}
+     used_poss = []
+
+     st.write("step #0:")
+     st.write(f"-- 1: (positive probability ~ {round(float(current_beam[sentence]), 5)})\n {sentence}")
+
+     for step in range(max_steps):
+         current_beam, used_pos = beam_get_replacements(current_beam, beam_size, epsilon, used_poss)
+
+         st.write(f"\nstep #{step + 1}:")
+         for i, (sent, prob) in enumerate(current_beam.items()):
+             st.write(f"-- {i + 1}: (positive probability ~ {round(float(prob), 5)})\n {highlight_diff(sent, sentence) if pretty_output else sent}")
+
+         if used_pos is None:  # no replaceable position left, stop early
+             return current_beam, used_poss
+         used_poss.append(used_pos)
+
+     return current_beam, used_poss
+
+ st.title("Correcting opinions")
+ default_value = "write your review here (in lower case - vocab reasons)"
+ sentence = st.text_area("Text", default_value, height=275)
+ beam_size = st.sidebar.slider("Beam size", value=3, min_value=1, max_value=20, step=1)
+ max_steps = st.sidebar.slider("Max steps", value=3, min_value=1, max_value=10, step=1)
+ prettify = st.sidebar.checkbox("Highlight changes", value=False)
+ beam, used_poss = get_best_hypotheses(sentence, beam_size=beam_size, max_steps=max_steps, pretty_output=prettify)
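+ # To run locally: `streamlit run app.py`, assuming the
+ # 'text_style_mlm_positive', 'text_style_mlm_negative' and
+ # 'text_style_classifier' checkpoints resolve for transformers
+ # (local directories or Hub repos).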