# Streamlit demo ("Correcting opinions"): rewrites a review toward positive
# sentiment via BERT masked-LM beam search over single-token replacements.
import re | |
import streamlit as st | |
from termcolor import colored | |
import torch | |
from transformers import BertTokenizer, BertForMaskedLM, BertForSequenceClassification | |
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def load_models():
    """Load the shared tokenizer, the two style masked-LMs and the classifier.

    The model directories ('text_style_mlm_positive', 'text_style_mlm_negative',
    'text_style_classifier') are assumed to be local fine-tuned checkpoints —
    TODO confirm they are present where the app runs.

    Returns:
        (tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier)
    """
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    # The models are used for inference only (every call site switches them to
    # .eval() before use), so load them in eval mode; the original `.train(True)`
    # enabled dropout at load time for no benefit.
    bert_mlm_positive = BertForMaskedLM.from_pretrained(
        'text_style_mlm_positive', return_dict=True).to(device).eval()
    bert_mlm_negative = BertForMaskedLM.from_pretrained(
        'text_style_mlm_negative', return_dict=True).to(device).eval()
    bert_classifier = BertForSequenceClassification.from_pretrained(
        'text_style_classifier', num_labels=2).to(device).eval()
    return tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier

tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier = load_models()
def highlight_diff(sent, sent_main):
    """Return `sent` with tokens that differ from `sent_main` emphasised.

    Both sentences are tokenized with the shared BERT tokenizer and compared
    position by position; `zip` stops at the shorter sequence, which is
    acceptable here because replacements are one-token-for-one-token.

    NOTE(review): the original wrapped changed tokens in termcolor ANSI escape
    codes, which `st.write` does not render (the raw escapes showed up in the
    page). Markdown bold is used instead, which Streamlit does render.
    """
    tokens = tokenizer.tokenize(sent)
    tokens_main = tokenizer.tokenize(sent_main)
    new_toks = []
    for tok, tok_main in zip(tokens, tokens_main):
        new_toks.append(f"**{tok}**" if tok != tok_main else tok)
    return ' '.join(new_toks)
def get_classifier_prob(sent):
    """Return the classifier's class-probability vector for `sent` as a 1-D
    numpy array (index 1 is the positive-class probability)."""
    bert_classifier.eval()
    encoded = tokenizer(sent, return_tensors='pt')
    batch = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        logits = bert_classifier(**batch).logits
        return logits.softmax(dim=-1)[0].cpu().numpy()
def beam_get_replacements(current_beam, beam_size, epsilon=1e-3, used_positions=None):
    """Expand each hypothesis in `current_beam` by one single-token replacement.

    For every sentence in `current_beam` (a dict {sentence: positive_prob}),
    pick the token position whose positive/negative MLM probability ratio is
    lowest and that has not been replaced before, substitute it with the top
    `beam_size` candidates from the positive MLM, score all new hypotheses
    with the classifier, and keep the best `beam_size` overall.

    Args:
        current_beam: dict mapping sentence -> positive-class probability.
        beam_size: number of hypotheses (and candidate tokens) to keep.
        epsilon: smoothing constant for the probability ratio.
        used_positions: token positions replaced on earlier steps, or None.

    Returns:
        (new_beam, used_position): the pruned beam dict and the token position
        replaced on this step, or (current_beam, None) when nothing new was
        generated.
    """
    # The original signature used a mutable default (`used_positions=[]`),
    # which is shared across calls; `None` + local initialisation avoids that.
    if used_positions is None:
        used_positions = []
    bert_mlm_positive.eval()
    bert_mlm_negative.eval()
    new_beam = []
    used_position = None  # stays None if no position is replaceable
    with torch.no_grad():
        for sentence in current_beam:
            input_ = {k: v.to(device) for k, v in tokenizer(sentence, return_tensors='pt').items()}
            probs_negative = bert_mlm_negative(**input_).logits.softmax(dim=-1)[0]
            probs_positive = bert_mlm_positive(**input_).logits.softmax(dim=-1)[0]
            ids = input_['input_ids'][0].cpu().numpy()
            seq_len = probs_positive.shape[0]
            # Probability each model assigns to the token actually present.
            p_pos = probs_positive[torch.arange(seq_len), ids]
            p_neg = probs_negative[torch.arange(seq_len), ids]
            # Replace first the token the positive model likes least relative
            # to the negative model.
            order_of_replacement = ((p_pos + epsilon) / (p_neg + epsilon)).argsort()
            for pos in order_of_replacement:
                # Plain int so membership tests against stored positions are
                # reliable (argsort yields tensor elements).
                pos = int(pos)
                # Skip already-replaced positions and the [CLS]/[SEP] specials.
                if pos in used_positions or pos == 0 or pos == len(ids) - 1:
                    continue
                used_position = pos
                replacement_ids = (-probs_positive[pos, :]).argsort()[:beam_size]
                for replacement_id in replacement_ids:
                    if replacement_id == ids[pos]:
                        continue
                    new_ids = ids.copy()
                    new_ids[pos] = replacement_id
                    new_beam.append(new_ids)
                break  # one replaced position per sentence per step
    if len(new_beam) > 0:
        # Decode without [CLS]/[SEP], score, merge with the incoming beam,
        # and keep the top `beam_size` by positive probability.
        new_beam = [tokenizer.decode(ids[1:-1]) for ids in new_beam]
        new_beam = {sent: get_classifier_prob(sent)[1] for sent in new_beam}
        for sent, prob in current_beam.items():
            new_beam[sent] = prob
        if len(new_beam) > beam_size:
            new_beam = {k: v for k, v in sorted(new_beam.items(), key=lambda el: el[1], reverse=True)[:beam_size]}
        return new_beam, used_position
    else:
        st.write("No more new hypotheses")
        return current_beam, None
def get_best_hypotheses(sentence, beam_size, max_steps, epsilon=1e-3, pretty_output=False):
    """Iteratively push `sentence` toward positive sentiment via beam search.

    Runs up to `max_steps` rounds of single-token replacement, writing the
    beam to the Streamlit page after every round, and stops early when no
    new hypotheses can be generated.

    Returns:
        (current_beam, used_poss): the final beam dict and the list of token
        positions that were replaced along the way.
    """
    current_beam = {sentence: get_classifier_prob(sentence)[1]}
    used_poss = []
    st.write("step #0:")
    st.write(f"-- 1: (positive probability ~ {round(current_beam[sentence], 5)})\n {sentence}")
    for step in range(max_steps):
        current_beam, used_pos = beam_get_replacements(current_beam, beam_size, epsilon, used_poss)
        st.write(f"\nstep #{step + 1}:")
        for rank, (hyp, hyp_prob) in enumerate(current_beam.items(), start=1):
            shown = highlight_diff(hyp, sentence) if pretty_output else hyp
            st.write(f"-- {rank}: (positive probability ~ {round(hyp_prob, 5)})\n {shown}")
        if used_pos is None:
            return current_beam, used_poss
        used_poss.append(used_pos)
    return current_beam, used_poss
# --- Streamlit UI -----------------------------------------------------------
st.title("Correcting opinions")
default_value = "write your review here (in lower case - vocab reasons)"
sentence = st.text_area("Text", default_value, height=275)
beam_size = st.sidebar.slider("Beam size", value=3, min_value=1, max_value=20, step=1)
max_steps = st.sidebar.slider("Max steps", value=3, min_value=1, max_value=10, step=1)
# A checkbox is the idiomatic widget for a boolean toggle (the original used a
# 0/1 slider); the label typo "Higlight" is also fixed.
prettyfy = st.sidebar.checkbox("Highlight changes", value=False)
beam, used_poss = get_best_hypotheses(sentence, beam_size=beam_size, max_steps=max_steps, pretty_output=bool(prettyfy))