any0019 committed on
Commit
53014ec
·
1 Parent(s): b87506a

Delete api.py

Browse files
Files changed (1) hide show
  1. api.py +0 -118
api.py DELETED
@@ -1,118 +0,0 @@
1
- import re
2
- import streamlit as st
3
- from termcolor import colored
4
- import torch
5
- from transformers import BertTokenizer, BertForMaskedLM, BertForSequenceClassification
6
-
7
# Run the models on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
8
-
9
# allow_output_mutation=True: the returned models are large torch modules that
# callers switch with .eval(); without it st.cache tries to hash them on every
# rerun and reports the mode switch as an illegal mutation of a cached object.
@st.cache(allow_output_mutation=True)
def load_models():
    """Load the tokenizer and the three fine-tuned BERT models used by the app.

    The tokenizer is the stock ``bert-base-uncased`` one; the models are read
    from the local checkpoints ``text_style_mlm_positive``,
    ``text_style_mlm_negative`` and ``text_style_classifier`` and moved to
    the module-level ``device``.

    The app only runs inference, so every model is put in evaluation mode
    (dropout disabled) right after loading.

    :return: tuple ``(tokenizer, bert_mlm_positive, bert_mlm_negative,
        bert_classifier)``.
    """
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    bert_mlm_positive = BertForMaskedLM.from_pretrained(
        'text_style_mlm_positive', return_dict=True
    ).to(device).eval()
    bert_mlm_negative = BertForMaskedLM.from_pretrained(
        'text_style_mlm_negative', return_dict=True
    ).to(device).eval()
    bert_classifier = BertForSequenceClassification.from_pretrained(
        'text_style_classifier', num_labels=2
    ).to(device).eval()
    return tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier
16
-
17
# Module-level handles shared by every function below; load_models() is
# wrapped in Streamlit's cache, so reruns of the script reuse the same objects.
(
    tokenizer,
    bert_mlm_positive,
    bert_mlm_negative,
    bert_classifier,
) = load_models()
18
-
19
-
20
def highlight_diff(sent, sent_main):
    """Return ``sent`` as space-joined tokens with changed tokens highlighted.

    Both strings are split with the module-level ``tokenizer`` and compared
    position by position. A token of ``sent`` is wrapped with
    ``termcolor.colored`` (bold, underlined red) when it differs from the
    token at the same position in ``sent_main``, or when ``sent_main`` has no
    token at that position. Every token of ``sent`` is kept in the output,
    even if ``sent`` tokenizes to more pieces than ``sent_main``.

    :param sent: the modified sentence to display.
    :param sent_main: the reference sentence to compare against.
    :return: the tokens of ``sent`` joined by single spaces.
    """
    tokens = tokenizer.tokenize(sent)
    tokens_main = tokenizer.tokenize(sent_main)
    main_len = len(tokens_main)

    new_toks = []
    for i, tok in enumerate(tokens):
        # Tokens past the end of the reference have no counterpart and are
        # therefore treated as changed rather than being dropped.
        if i >= main_len or tok != tokens_main[i]:
            new_toks.append(colored(tok, 'red', attrs=['bold', 'underline']))
        else:
            new_toks.append(tok)

    return ' '.join(new_toks)
32
-
33
-
34
def get_classifier_prob(sent):
    """Return the classifier's class probabilities for ``sent``.

    The sentence is tokenized with the module-level ``tokenizer``, run
    through ``bert_classifier`` in evaluation mode without gradient tracking,
    and the logits are turned into probabilities with a softmax over the
    last dimension.

    :param sent: the text to classify.
    :return: NumPy array with the probabilities of the first (only) batch
        item; callers in this file read index 1 as the "positive" class.
    """
    bert_classifier.eval()
    with torch.no_grad():
        encoded = tokenizer(sent, return_tensors='pt')
        model_inputs = {name: tensor.to(device) for name, tensor in encoded.items()}
        logits = bert_classifier(**model_inputs).logits
        probabilities = logits.softmax(dim=-1)
        return probabilities[0].cpu().numpy()
38
-
39
-
40
def beam_get_replacements(current_beam, beam_size, epsilon=1e-3, used_positions=None):
    """Run one beam-search step of single-token replacements.

    For each sentence in ``current_beam``:

    - tokenize it with the module-level BERT tokenizer and compute, for every
      position, the probability both masked-LM models assign to the token
      that is currently there;
    - order positions by ``(p_pos + epsilon) / (p_neg + epsilon)`` ascending
      and take the first one that is neither the first nor the last id
      (presumably the [CLS]/[SEP] markers) nor listed in ``used_positions``;
    - at that position propose the ``beam_size`` most probable tokens of the
      positive model, skipping the token already present.

    The proposals are decoded, scored with ``get_classifier_prob`` and merged
    with the sentences of ``current_beam``; the ``beam_size`` entries with
    the highest positive probability are kept.

    :param current_beam: dict mapping sentence -> positive probability.
    :param beam_size: number of replacements tried per sentence and maximum
        size of the returned beam.
    :param epsilon: smoothing term added to both probabilities in the ratio.
    :param used_positions: token positions that must not be edited again;
        ``None`` means no restriction.
    :return: ``(new_beam, used_position)`` where ``used_position`` is the
        position (a plain ``int``) chosen for the last sentence processed, or
        ``(current_beam, None)`` when no new hypothesis could be produced.
    """
    if used_positions is None:
        used_positions = []
    used_position = None

    bert_mlm_positive.eval()
    bert_mlm_negative.eval()
    candidate_ids = []
    with torch.no_grad():
        for sentence in current_beam:
            input_ = {k: v.to(device) for k, v in tokenizer(sentence, return_tensors='pt').items()}
            probs_negative = bert_mlm_negative(**input_).logits.softmax(dim=-1)[0]
            probs_positive = bert_mlm_positive(**input_).logits.softmax(dim=-1)[0]

            ids = input_['input_ids'][0].cpu().numpy()
            seq_len = probs_positive.shape[0]
            rows = torch.arange(seq_len)
            p_pos = probs_positive[rows, ids]
            p_neg = probs_negative[rows, ids]

            # .tolist() yields plain ints, so membership tests, numpy indexing
            # and the returned position do not depend on tensor semantics or
            # on which device the tensors live.
            order_of_replacement = ((p_pos + epsilon) / (p_neg + epsilon)).argsort().tolist()
            last_pos = len(ids) - 1
            for pos in order_of_replacement:
                if pos in used_positions or pos == 0 or pos == last_pos:
                    continue
                used_position = pos
                replacement_ids = (-probs_positive[pos, :]).argsort()[:beam_size].tolist()
                for replacement_id in replacement_ids:
                    if replacement_id == ids[pos]:
                        continue
                    new_ids = ids.copy()
                    new_ids[pos] = replacement_id
                    candidate_ids.append(new_ids)
                # Only the single best eligible position is edited per sentence.
                break

    if not candidate_ids:
        st.write("No more new hypotheses")
        return current_beam, None

    new_beam = {}
    for ids in candidate_ids:
        # Strip the first and last ids before decoding back to text.
        sent = tokenizer.decode(ids[1:-1])
        new_beam[sent] = get_classifier_prob(sent)[1]
    # Existing hypotheses compete with the new ones and keep their old scores.
    new_beam.update(current_beam)

    if len(new_beam) > beam_size:
        ranked = sorted(new_beam.items(), key=lambda el: el[1], reverse=True)
        new_beam = dict(ranked[:beam_size])
    return new_beam, used_position
87
-
88
-
89
def get_best_hypotheses(sentence, beam_size, max_steps, epsilon=1e-3, pretty_output=False):
    """Iteratively rewrite ``sentence`` with beam search and report each step.

    Starts from a beam holding only ``sentence`` with its positive
    probability from ``get_classifier_prob``, then calls
    ``beam_get_replacements`` up to ``max_steps`` times. The beam is written
    to the Streamlit page after every step. The search stops early when a
    step reports no used position (``None``).

    :param sentence: the text to start from.
    :param beam_size: beam width passed to ``beam_get_replacements``.
    :param max_steps: maximum number of replacement steps.
    :param epsilon: smoothing term passed to ``beam_get_replacements``.
    :param pretty_output: when true, each hypothesis is shown through
        ``highlight_diff`` against the original sentence.
    :return: ``(current_beam, used_poss)`` - the final beam and the list of
        positions returned by the steps that produced new hypotheses.
    """
    used_poss = []
    current_beam = {sentence: get_classifier_prob(sentence)[1]}

    st.write("step #0:")
    st.write(f"-- 1: (positive probability ~ {round(current_beam[sentence], 5)})\n {sentence}")

    for step_no in range(1, max_steps + 1):
        current_beam, used_pos = beam_get_replacements(current_beam, beam_size, epsilon, used_poss)

        st.write(f"\nstep #{step_no}:")
        for rank, (sent, prob) in enumerate(current_beam.items(), start=1):
            shown = highlight_diff(sent, sentence) if pretty_output else sent
            st.write(f"-- {rank}: (positive probability ~ {round(prob, 5)})\n {shown}")

        # A step without a used position produced nothing new: stop here.
        if used_pos is None:
            break
        used_poss.append(used_pos)

    return current_beam, used_poss
109
-
110
-
111
# --- Streamlit page: input widgets, then one search run per script rerun ---
st.title("Correcting opinions")

# Placeholder text shown in the input box until the user types a review.
default_value = "write your review here (in lower case - vocab reasons)"
sentence = st.text_area("Text", default_value, height=275)

# Search settings live in the sidebar.
beam_size = st.sidebar.slider("Beam size", min_value=1, max_value=20, value=3, step=1)
max_steps = st.sidebar.slider("Max steps", min_value=1, max_value=10, value=3, step=1)
# A 0/1 slider serves as the on/off switch for highlighting.
prettyfy = st.sidebar.slider("Higlight changes", min_value=0, max_value=1, value=0, step=1)

beam, used_poss = get_best_hypotheses(
    sentence,
    beam_size=beam_size,
    max_steps=max_steps,
    pretty_output=bool(prettyfy),
)