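"""Streamlit demo: rewriting negative reviews into positive ones.

Two fine-tuned BERT masked LMs (one positive-style, one negative-style) score
each token of the input; tokens the negative model likes far more than the
positive model are replaced via beam search, and a BERT classifier ranks the
resulting hypotheses by their probability of being positive.
"""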
import streamlit as st
from termcolor import colored
import torch
from transformers import BertTokenizer, BertForMaskedLM, BertForSequenceClassification

device = 'cuda' if torch.cuda.is_available() else 'cpu'


@st.cache(allow_output_mutation=True)
def load_models():
    """Load the tokenizer, both style MLMs, and the style classifier once per session."""
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    # The models are only used for inference; each inference helper below
    # switches them to eval() mode before running.
    bert_mlm_positive = BertForMaskedLM.from_pretrained(
        'any0019/text_style_mlm_positive', return_dict=True).to(device)
    bert_mlm_negative = BertForMaskedLM.from_pretrained(
        'any0019/text_style_mlm_negative', return_dict=True).to(device)
    bert_classifier = BertForSequenceClassification.from_pretrained(
        'any0019/text_style_classifier', num_labels=2).to(device)
    return tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier


tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier = load_models()


def highlight_diff(sent, sent_main):
    """Mark tokens of `sent` that differ from the original `sent_main` in red."""
    tokens = tokenizer.tokenize(sent)
    tokens_main = tokenizer.tokenize(sent_main)

    new_toks = []
    for tok, tok_main in zip(tokens, tokens_main):
        if tok != tok_main:
            # Streamlit renders inline LaTeX, so color the replaced token red.
            new_toks.append(f'$\\textcolor{{red}}{{\\textbf{{{tok}}}}}$')
        else:
            new_toks.append(tok)

    return ' '.join(new_toks)
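# Example (hypothetical input):
#   highlight_diff("the food was great", "the food was awful")
#   -> 'the food was $\textcolor{red}{\textbf{great}}$'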


def get_classifier_prob(sent):
    """Return the classifier's [p(label 0), p(label 1)] distribution;
    label 1 is treated as the positive class throughout this script."""
    bert_classifier.eval()
    with torch.no_grad():
        inputs = {k: v.to(device) for k, v in tokenizer(sent, return_tensors='pt').items()}
        return bert_classifier(**inputs).logits.softmax(dim=-1)[0].cpu().numpy()
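# Sanity check (hypothetical values): get_classifier_prob("the food was great")
# might return array([0.01, 0.99], dtype=float32), i.e. ~99% positive.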


def beam_get_replacements(current_beam, beam_size, epsilon=1e-3, used_positions=None):
    """
    Perform one beam-search step over a {sentence: positive_probability} beam:
    - split each sentence in :current_beam: into tokens using the INGSOC-approved BERT tokenizer
    - pick the not-yet-used position whose current token the positive MLM likes
      least relative to the negative MLM, and try :beam_size: replacements there
    - keep the best :beam_size: hypotheses by classifier probability
    :return: (new_beam, replaced_position), or (current_beam, None) if no new hypotheses
    """
    if used_positions is None:
        used_positions = []
    bert_mlm_positive.eval()
    bert_mlm_negative.eval()
    new_beam = []
    with torch.no_grad():
        for sentence in current_beam:
            input_ = {k: v.to(device) for k, v in tokenizer(sentence, return_tensors='pt').items()}
            probs_negative = bert_mlm_negative(**input_).logits.softmax(dim=-1)[0]
            probs_positive = bert_mlm_positive(**input_).logits.softmax(dim=-1)[0]
            ids = input_['input_ids'][0].cpu().numpy()
            seq_len = probs_positive.shape[0]
            # Probability each MLM assigns to the token actually present at each position.
            p_pos = probs_positive[torch.arange(seq_len), ids]
            p_neg = probs_negative[torch.arange(seq_len), ids]
            # Replace first where the positive model likes the current token least
            # relative to the negative model; epsilon guards against division by zero.
            order_of_replacement = ((p_pos + epsilon) / (p_neg + epsilon)).argsort()
            for pos in order_of_replacement:
                pos = int(pos)
                # Skip positions already replaced and the [CLS]/[SEP] special tokens.
                if pos in used_positions or pos == 0 or pos == len(ids) - 1:
                    continue
                used_position = pos
                # Top :beam_size: candidate tokens under the positive MLM.
                replacement_ids = (-probs_positive[pos, :]).argsort()[:beam_size]
                for replacement_id in replacement_ids.tolist():
                    if replacement_id == ids[pos]:
                        continue
                    new_ids = ids.copy()
                    new_ids[pos] = replacement_id
                    new_beam.append(new_ids)
                break
        if len(new_beam) > 0:
            # Decode without [CLS]/[SEP], then score each hypothesis with the classifier.
            new_beam = [tokenizer.decode(new_ids[1:-1]) for new_ids in new_beam]
            new_beam = {sent: get_classifier_prob(sent)[1] for sent in new_beam}
            # Old hypotheses stay in the pool and compete with the new ones.
            for sent, prob in current_beam.items():
                new_beam[sent] = prob

            if len(new_beam) > beam_size:
                new_beam = {k: v for k, v in sorted(new_beam.items(), key=lambda el: el[1], reverse=True)[:beam_size]}
            return new_beam, used_position
        else:
            st.write("No more new hypotheses")
            return current_beam, None
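# One step on a toy beam (hypothetical numbers):
#   beam = {"the food was awful": 0.02}
#   beam, pos = beam_get_replacements(beam, beam_size=3)
#   # beam might become {"the food was great": 0.99, "the food was good": 0.97, ...}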


def get_best_hypotheses(sentence, beam_size, max_steps, epsilon=1e-3, pretty_output=False):
    """Run up to :max_steps: beam-search steps, printing the beam after each one."""
    current_beam = {sentence: get_classifier_prob(sentence)[1]}
    used_poss = []

    st.write("step #0:")
    st.write(f"-- 1: (positive probability ~ {round(current_beam[sentence], 5)})")
    st.write(f"$\\qquad${sentence}")

    for step in range(max_steps):
        current_beam, used_pos = beam_get_replacements(current_beam, beam_size, epsilon, used_poss)

        st.write(f"\nstep #{step + 1}:")
        for i, (sent, prob) in enumerate(current_beam.items()):
            st.write(f"-- {i + 1}: (positive probability ~ {round(prob, 5)})")
            st.write(f"$\\qquad${highlight_diff(sent, sentence) if pretty_output else sent}")

        if used_pos is None:
            return current_beam, used_poss
        else:
            used_poss.append(used_pos)

    return current_beam, used_poss


st.title("Correcting opinions of fellow comrades")

default_value = "write your review here (in lower case - vocab reasons)"
sentence = st.text_area("Text", default_value, height=275)
beam_size = st.sidebar.slider("Beam size", value=3, min_value=1, max_value=20, step=1)
max_steps = st.sidebar.slider("Max steps", value=3, min_value=1, max_value=10, step=1)
# prettyfy = st.sidebar.slider("Highlight changes", value=0, min_value=0, max_value=1, step=1)

# beam, used_poss = get_best_hypotheses(sentence, beam_size=beam_size, max_steps=max_steps, pretty_output=bool(prettyfy))
beam, used_poss = get_best_hypotheses(sentence, beam_size=beam_size, max_steps=max_steps, pretty_output=False)
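# To try this locally (assuming the file is saved as app.py):
#   streamlit run app.py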