any0019 commited on
Commit
b87506a
·
1 Parent(s): 2f3b323

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +118 -0
api.py CHANGED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import streamlit as st
3
+ from termcolor import colored
4
+ import torch
5
+ from transformers import BertTokenizer, BertForMaskedLM, BertForSequenceClassification
6
+
7
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
8
+
9
+ @st.cache
10
+ def load_models():
11
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
12
+ bert_mlm_positive = BertForMaskedLM.from_pretrained('text_style_mlm_positive', return_dict=True).to(device).train(True)
13
+ bert_mlm_negative = BertForMaskedLM.from_pretrained('text_style_mlm_negative', return_dict=True).to(device).train(True)
14
+ bert_classifier = BertForSequenceClassification.from_pretrained('text_style_classifier', num_labels=2).to(device).train(True)
15
+ return tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier
16
+
17
+ tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier = load_models()
18
+
19
+
20
+ def highlight_diff(sent, sent_main):
21
+ tokens = tokenizer.tokenize(sent)
22
+ tokens_main = tokenizer.tokenize(sent_main)
23
+
24
+ new_toks = []
25
+ for i, (tok, tok_main) in enumerate(zip(tokens, tokens_main)):
26
+ if tok != tok_main:
27
+ new_toks.append(colored(tok, 'red', attrs=['bold', 'underline']))
28
+ else:
29
+ new_toks.append(tok)
30
+
31
+ return ' '.join(new_toks)
32
+
33
+
34
+ def get_classifier_prob(sent):
35
+ bert_classifier.eval()
36
+ with torch.no_grad():
37
+ return bert_classifier(**{k: v.to(device) for k, v in tokenizer(sent, return_tensors='pt').items()}).logits.softmax(dim=-1)[0].cpu().numpy()
38
+
39
+
40
+ def beam_get_replacements(current_beam, beam_size, epsilon=1e-3, used_positions=[]):
41
+ """
42
+ - for each sentence in :current_beam: - split the sentence into tokens using the INGSOC-approved BERT tokenizer
43
+ - check :beam_size: hypotheses on each step for each sentence
44
+ - save best :beam_size: hypotheses
45
+ :return: generator<list of hypotheses on step>
46
+ """
47
+ # <YOUR CODE HERE>
48
+ bert_mlm_positive.eval()
49
+ bert_mlm_negative.eval()
50
+ new_beam = []
51
+ with torch.no_grad():
52
+ for sentence in current_beam:
53
+ input_ = {k: v.to(device) for k, v in tokenizer(sentence, return_tensors='pt').items()}
54
+ probs_negative = bert_mlm_negative(**input_).logits.softmax(dim=-1)[0]
55
+ probs_positive = bert_mlm_positive(**input_).logits.softmax(dim=-1)[0]
56
+
57
+ ids = input_['input_ids'][0].cpu().numpy()
58
+ seq_len = probs_positive.shape[0]
59
+ p_pos = probs_positive[torch.arange(seq_len), ids]
60
+ p_neg = probs_negative[torch.arange(seq_len), ids]
61
+
62
+ order_of_replacement = ((p_pos + epsilon) / (p_neg + epsilon)).argsort()
63
+ for pos in order_of_replacement:
64
+ if pos in used_positions or pos==0 or pos==len(ids)-1:
65
+ continue
66
+ used_position = pos
67
+ replacement_ids = (-probs_positive[pos,:]).argsort()[:beam_size]
68
+ for replacement_id in replacement_ids:
69
+ if replacement_id == ids[pos]:
70
+ continue
71
+ new_ids = ids.copy()
72
+ new_ids[pos] = replacement_id
73
+ new_beam.append(new_ids)
74
+ break
75
+ if len(new_beam) > 0:
76
+ new_beam = [tokenizer.decode(ids[1:-1]) for ids in new_beam]
77
+ new_beam = {sent: get_classifier_prob(sent)[1] for sent in new_beam}
78
+ for sent, prob in current_beam.items():
79
+ new_beam[sent] = prob
80
+
81
+ if len(new_beam) > beam_size:
82
+ new_beam = {k: v for k, v in sorted(new_beam.items(), key = lambda el: el[1], reverse=True)[:beam_size]}
83
+ return new_beam, used_position
84
+ else:
85
+ st.write("No more new hypotheses")
86
+ return current_beam, None
87
+
88
+
89
+ def get_best_hypotheses(sentence, beam_size, max_steps, epsilon=1e-3, pretty_output=False):
90
+ current_beam = {sentence: get_classifier_prob(sentence)[1]}
91
+ used_poss = []
92
+
93
+ st.write(f"step #0:")
94
+ st.write(f"-- 1: (positive probability ~ {round(current_beam[sentence], 5)})\n {sentence}")
95
+
96
+ for step in range(max_steps):
97
+ current_beam, used_pos = beam_get_replacements(current_beam, beam_size, epsilon, used_poss)
98
+
99
+ st.write(f"\nstep #{step+1}:")
100
+ for i, (sent, prob) in enumerate(current_beam.items()):
101
+ st.write(f"-- {i+1}: (positive probability ~ {round(prob, 5)})\n {highlight_diff(sent, sentence) if pretty_output else sent}")
102
+
103
+ if used_pos is None:
104
+ return current_beam, used_poss
105
+ else:
106
+ used_poss.append(used_pos)
107
+
108
+ return current_beam, used_poss
109
+
110
+
111
+ st.title("Correcting opinions")
112
+ default_value = "write your review here (in lower case - vocab reasons)"
113
+ sentence = st.text_area("Text", default_value, height = 275)
114
+ beam_size = st.sidebar.slider("Beam size", value = 3, min_value = 1, max_value=20, step=1)
115
+ max_steps = st.sidebar.slider("Max steps", value = 3, min_value = 1, max_value=10, step=1)
116
+ prettyfy = st.sidebar.slider("Higlight changes", value = 0, min_value = 0, max_value=1, step=1)
117
+
118
+ beam, used_poss = get_best_hypotheses(sentence, beam_size=beam_size, max_steps=max_steps, pretty_output=bool(prettyfy))