# Streamlit demo (Hugging Face Spaces): rewrites a review toward positive
# sentiment using two style-conditioned BERT masked LMs and a BERT
# sentiment classifier, via greedy token-replacement beam search.
import streamlit as st
from termcolor import colored
import torch
from transformers import BertTokenizer, BertForMaskedLM, BertForSequenceClassification
device = 'cuda' if torch.cuda.is_available() else 'cpu'
@st.cache(allow_output_mutation=True)
def load_models():
    """Load (once, cached by Streamlit) the shared tokenizer, the two
    style-conditioned masked LMs, and the sentiment classifier, all moved
    to ``device``.

    Returns:
        (tokenizer, mlm_positive, mlm_negative, classifier)
    """
    tok = BertTokenizer.from_pretrained('bert-base-uncased')
    mlm_positive = (BertForMaskedLM
                    .from_pretrained('any0019/text_style_mlm_positive', return_dict=True)
                    .to(device)
                    .train(True))
    mlm_negative = (BertForMaskedLM
                    .from_pretrained('any0019/text_style_mlm_negative', return_dict=True)
                    .to(device)
                    .train(True))
    classifier = (BertForSequenceClassification
                  .from_pretrained('any0019/text_style_classifier', num_labels=2)
                  .to(device)
                  .train(True))
    return tok, mlm_positive, mlm_negative, classifier

tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier = load_models()
def highlight_diff(sent, sent_main):
    """Return ``sent`` with every token that differs from ``sent_main``
    wrapped in LaTeX red-bold markup (rendered by ``st.write``).

    Tokens are compared position-wise after BERT tokenization; if the two
    sentences tokenize to different lengths, the tail beyond the shorter
    one is dropped by ``zip`` (pre-existing behavior).
    """
    tokens = tokenizer.tokenize(sent)
    tokens_main = tokenizer.tokenize(sent_main)
    new_toks = []
    for tok, tok_main in zip(tokens, tokens_main):
        if tok != tok_main:
            # BUG FIX: the original f-string used `\textcolor\{red\}` — the
            # `\t` became a literal tab and `\{red\}` is invalid f-string
            # syntax.  Use a raw string (keeps the backslash of \textcolor)
            # and doubled braces to emit literal LaTeX braces.
            new_toks.append(rf'\textcolor{{red}}{{**{tok}**}}')
        else:
            new_toks.append(tok)
    return ' '.join(new_toks)
def get_classifier_prob(sent):
    """Return the classifier's softmax distribution over the two sentiment
    classes for ``sent`` as a NumPy array ``[p_class0, p_class1]``."""
    bert_classifier.eval()
    encoded = tokenizer(sent, return_tensors='pt')
    batch = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        logits = bert_classifier(**batch).logits
        return logits.softmax(dim=-1)[0].cpu().numpy()
def beam_get_replacements(current_beam, beam_size, epsilon=1e-3, used_positions=None):
    """Run one step of beam search over single-token replacements.

    For each hypothesis in ``current_beam``, find the token the negative MLM
    likes much more than the positive MLM does (lowest smoothed probability
    ratio), and propose up to ``beam_size`` replacements for it drawn from
    the positive MLM.  New hypotheses are scored with the sentiment
    classifier and merged with the old beam; the best ``beam_size`` survive.

    :param current_beam: dict {sentence: positive-class probability}
    :param beam_size: number of hypotheses to keep / replacements to try
    :param epsilon: smoothing constant for the probability ratio
    :param used_positions: token positions replaced on earlier steps
        (skipped here); not mutated by this function
    :return: (new beam dict, int position replaced this step or None)
    """
    # BUG FIX: the original signature used a mutable default ``used_positions=[]``.
    if used_positions is None:
        used_positions = []
    bert_mlm_positive.eval()
    bert_mlm_negative.eval()
    new_beam_ids = []
    # BUG FIX: the original could hit UnboundLocalError returning
    # ``used_position`` when every position was skipped; initialize it.
    used_position = None
    with torch.no_grad():
        for sentence in current_beam:
            input_ = {k: v.to(device) for k, v in tokenizer(sentence, return_tensors='pt').items()}
            probs_negative = bert_mlm_negative(**input_).logits.softmax(dim=-1)[0]
            probs_positive = bert_mlm_positive(**input_).logits.softmax(dim=-1)[0]
            ids = input_['input_ids'][0].cpu().numpy()
            seq_len = probs_positive.shape[0]
            # Probability each MLM assigns to the token actually present.
            p_pos = probs_positive[torch.arange(seq_len), ids]
            p_neg = probs_negative[torch.arange(seq_len), ids]
            # Replace first where the token looks "negative": low pos/neg ratio.
            order_of_replacement = ((p_pos + epsilon) / (p_neg + epsilon)).argsort()
            for pos in order_of_replacement:
                pos = int(pos)  # BUG FIX: keep plain ints, not 0-d tensors
                # Skip already-replaced positions and the [CLS]/[SEP] markers.
                if pos in used_positions or pos == 0 or pos == len(ids) - 1:
                    continue
                used_position = pos
                # Top-k candidate tokens according to the positive MLM.
                replacement_ids = (-probs_positive[pos, :]).argsort()[:beam_size]
                for replacement_id in replacement_ids:
                    if replacement_id == ids[pos]:
                        continue  # identical token — not a new hypothesis
                    new_ids = ids.copy()
                    new_ids[pos] = replacement_id
                    new_beam_ids.append(new_ids)
                break  # only one position per sentence per step
    if not new_beam_ids:
        st.write("No more new hypotheses")
        return current_beam, None
    # Score new hypotheses, merge with the old beam, keep the best.
    decoded = [tokenizer.decode(ids[1:-1]) for ids in new_beam_ids]
    new_beam = {sent: get_classifier_prob(sent)[1] for sent in decoded}
    for sent, prob in current_beam.items():
        new_beam[sent] = prob
    if len(new_beam) > beam_size:
        new_beam = dict(sorted(new_beam.items(), key=lambda kv: kv[1], reverse=True)[:beam_size])
    return new_beam, used_position
def get_best_hypotheses(sentence, beam_size, max_steps, epsilon=1e-3, pretty_output=False):
    """Iteratively rewrite ``sentence`` toward positive sentiment.

    Runs up to ``max_steps`` beam-search steps, writing each step's beam to
    the Streamlit page.  With ``pretty_output`` the changed tokens are
    highlighted via ``highlight_diff``.

    :return: (final beam dict {sentence: positive probability},
              list of token positions replaced along the way)
    """
    current_beam = {sentence: get_classifier_prob(sentence)[1]}
    used_poss = []
    st.write(f"step #0:")
    st.write(f"-- 1: (positive probability ~ {round(current_beam[sentence], 5)})")
    st.write(f"$\qquad${sentence}")
    for step in range(max_steps):
        current_beam, used_pos = beam_get_replacements(current_beam, beam_size, epsilon, used_poss)
        st.write(f"\nstep #{step+1}:")
        for rank, (sent, prob) in enumerate(current_beam.items(), start=1):
            st.write(f"-- {rank}: (positive probability ~ {round(prob, 5)})")
            shown = highlight_diff(sent, sentence) if pretty_output else sent
            st.write(f"$\qquad${shown}")
        if used_pos is None:
            break  # no new hypotheses could be generated — stop early
        used_poss.append(used_pos)
    return current_beam, used_poss
# --- Streamlit UI: input area, sidebar controls, and the search run. ---
st.title("Correcting opinions of fellow comrades")
default_value = "write your review here (in lower case - vocab reasons)"
sentence = st.text_area("Text", default_value, height=275)
beam_size = st.sidebar.slider("Beam size", value=3, min_value=1, max_value=20, step=1)
max_steps = st.sidebar.slider("Max steps", value=3, min_value=1, max_value=10, step=1)
# prettyfy = st.sidebar.slider("Higlight changes", value = 0, min_value = 0, max_value=1, step=1)
# beam, used_poss = get_best_hypotheses(sentence, beam_size=beam_size, max_steps=max_steps, pretty_output=bool(prettyfy))
# NOTE: stray trailing '|' residue removed from the line below (syntax error).
beam, used_poss = get_best_hypotheses(sentence, beam_size=beam_size, max_steps=max_steps, pretty_output=False)