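"""Streamlit demo: "Correcting opinions" - rewrite a review so it sounds more positive.

Beam search replaces one token per step. Replacement positions are ranked by
comparing two fine-tuned BERT masked LMs (one tuned on positive reviews, one
on negative), substitutes come from the positive MLM, and hypotheses are
scored with a BERT sentiment classifier.

Run with `streamlit run app.py` (the filename here is an assumption).
"""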
import streamlit as st
import torch
from transformers import BertTokenizer, BertForMaskedLM, BertForSequenceClassification
device = 'cuda' if torch.cuda.is_available() else 'cpu'
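# NOTE: 'text_style_mlm_positive', 'text_style_mlm_negative' and
# 'text_style_classifier' are assumed to be local directories containing
# fine-tuned BERT checkpoints saved with save_pretrained().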
@st.cache(allow_output_mutation=True)  # torch modules are mutated in place and are not hashable
def load_models():
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    bert_mlm_positive = BertForMaskedLM.from_pretrained('text_style_mlm_positive', return_dict=True).to(device).eval()
    bert_mlm_negative = BertForMaskedLM.from_pretrained('text_style_mlm_negative', return_dict=True).to(device).eval()
    bert_classifier = BertForSequenceClassification.from_pretrained('text_style_classifier', num_labels=2).to(device).eval()
    return tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier
tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier = load_models()
def highlight_diff(sent, sent_main):
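    # Bold every token of `sent` that differs from the original `sent_main`;
    # st.write renders markdown, so the bold markers show up as highlighting.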
    tokens = tokenizer.tokenize(sent)
    tokens_main = tokenizer.tokenize(sent_main)
    
    new_toks = []
    for tok, tok_main in zip(tokens, tokens_main):
        if tok != tok_main:
            new_toks.append(f"**{tok}**")
        else:
            new_toks.append(tok)
    
    return ' '.join(new_toks)
    
def get_classifier_prob(sent):
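    # Returns [p(negative), p(positive)] for one sentence, as a numpy array.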
    bert_classifier.eval()
    with torch.no_grad():
        inputs = {k: v.to(device) for k, v in tokenizer(sent, return_tensors='pt').items()}
        logits = bert_classifier(**inputs).logits
        return logits.softmax(dim=-1)[0].cpu().numpy()
def beam_get_replacements(current_beam, beam_size, epsilon=1e-3, used_positions=None):
    """
    One beam-search step:
    - split each sentence in :current_beam: into tokens using the INGSOC-approved BERT tokenizer
    - pick the most promising unused position and propose :beam_size: substitutes for it
    - keep the best :beam_size: hypotheses by classifier score
    :return: (new beam as {sentence: positive probability}, the position replaced on this step)
    """
    if used_positions is None:  # avoid the mutable-default-argument pitfall
        used_positions = []
    bert_mlm_positive.eval()
    bert_mlm_negative.eval()
    new_beam = []
    with torch.no_grad():
        for sentence in current_beam:
            input_ = {k: v.to(device) for k, v in tokenizer(sentence, return_tensors='pt').items()}
            probs_negative = bert_mlm_negative(**input_).logits.softmax(dim=-1)[0]
            probs_positive = bert_mlm_positive(**input_).logits.softmax(dim=-1)[0]
            ids = input_['input_ids'][0].cpu().numpy()
            seq_len = probs_positive.shape[0]
            p_pos = probs_positive[torch.arange(seq_len), ids]
            p_neg = probs_negative[torch.arange(seq_len), ids]
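            # Rank positions for replacement: the lower the positive-MLM /
            # negative-MLM probability ratio of the current token, the more
            # "negative" that token looks, so it is replaced first.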
            order_of_replacement = ((p_pos + epsilon) / (p_neg + epsilon)).argsort()
            for pos in order_of_replacement.tolist():
                if pos in used_positions or pos == 0 or pos == len(ids) - 1:
                    continue
                used_position = pos
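                # Propose the beam_size most likely substitutes at this
                # position according to the positive MLM.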
                replacement_ids = probs_positive[pos].argsort(descending=True)[:beam_size]
                for replacement_id in replacement_ids.tolist():
                    if replacement_id == ids[pos]:
                        continue
                    new_ids = ids.copy()
                    new_ids[pos] = replacement_id
                    new_beam.append(new_ids)
                break
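        # Decode candidates (dropping [CLS]/[SEP]), score them with the
        # sentiment classifier, merge with the incoming beam and keep the
        # top beam_size hypotheses by positive-class probability.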
        if len(new_beam) > 0:
            new_beam = [tokenizer.decode(ids[1:-1]) for ids in new_beam]
            new_beam = {sent: get_classifier_prob(sent)[1] for sent in new_beam}
            for sent, prob in current_beam.items():
                new_beam[sent] = prob
            
            if len(new_beam) > beam_size:
                new_beam = dict(sorted(new_beam.items(), key=lambda el: el[1], reverse=True)[:beam_size])
            return new_beam, used_position
        else:
            st.write("No more new hypotheses")
            return current_beam, None
def get_best_hypotheses(sentence, beam_size, max_steps, epsilon=1e-3, pretty_output=False):
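    # The beam maps each hypothesis sentence to its positive-class probability,
    # seeded with the input sentence.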
    current_beam = {sentence: get_classifier_prob(sentence)[1]}
    used_poss = []
    
    st.write(f"step #0:")
    st.write(f"-- 1: (positive probability ~ {round(current_beam[sentence], 5)})\n      {sentence}")
    
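    # One beam-search step per iteration; each step commits to one token
    # position, recorded in used_poss so it is not chosen again.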
    for step in range(max_steps):
        current_beam, used_pos = beam_get_replacements(current_beam, beam_size, epsilon, used_poss)
        
        st.write(f"\nstep #{step+1}:")
        for i, (sent, prob) in enumerate(current_beam.items()):
            st.write(f"-- {i+1}: (positive probability ~ {round(prob, 5)})\n      {highlight_diff(sent, sentence) if pretty_output else sent}")
        
        if used_pos is None:
            return current_beam, used_poss
        else:
            used_poss.append(used_pos)
    
    return current_beam, used_poss
st.title("Correcting opinions")
default_value = "write your review here (lower case only: the tokenizer vocab is uncased)"
sentence = st.text_area("Text", default_value, height=275)
beam_size = st.sidebar.slider("Beam size", value=3, min_value=1, max_value=20, step=1)
max_steps = st.sidebar.slider("Max steps", value=3, min_value=1, max_value=10, step=1)
prettify = st.sidebar.checkbox("Highlight changes", value=False)
beam, used_poss = get_best_hypotheses(sentence, beam_size=beam_size, max_steps=max_steps, pretty_output=prettify)