|
import os
from typing import Dict, List, Tuple

import torch
from tqdm import tqdm
from transformers import PreTrainedTokenizer

from .modeling import GECToR
from .predict import (
    edit_src_by_tags,
    _predict,
)
|
|
|
# Sentinel token prepended to every source sentence before prediction.
_START_TOKEN = '$START'


def _detokenize(tokens: List[str]) -> str:
    """Join a token list back into a sentence, dropping the `$START` sentinel.

    Filtering at token level (instead of ``str.replace('$START ', '')`` on the
    joined string) also handles a list that is only ``['$START']`` and never
    glues a token that happens to end in ``$START`` onto the next token.
    """
    return ' '.join(tok for tok in tokens if tok != _START_TOKEN)


def predict_verbose(
    model: GECToR,
    tokenizer: PreTrainedTokenizer,
    srcs: List[str],
    encode: dict,
    decode: dict,
    keep_confidence: float = 0,
    min_error_prob: float = 0,
    batch_size: int = 128,
    n_iteration: int = 5,
) -> Tuple[List[str], List[List[Dict]]]:
    """Iteratively correct sentences and record every intermediate step.

    Each round runs ``_predict`` on the sentences that are still pending.
    Sentences that ``_predict`` reports as needing no correction are
    finalized. The rest are rewritten with ``edit_src_by_tags`` and fed to
    the next round. After ``n_iteration`` rounds, any still-pending sentence
    is finalized in its current form.

    Args:
        model: GECToR model, forwarded to ``_predict``.
        tokenizer: Tokenizer, forwarded to ``_predict``.
        srcs: Input sentences. They are tokenized by splitting on a single
            space, so they are presumably already whitespace-tokenized text.
        encode: Mapping forwarded to ``edit_src_by_tags`` (presumably a
            label -> id table; confirm against that helper).
        decode: Mapping forwarded to ``edit_src_by_tags`` (presumably an
            id -> label table; confirm against that helper).
        keep_confidence: Forwarded unchanged to ``_predict``.
        min_error_prob: Forwarded unchanged to ``_predict``.
        batch_size: Forwarded unchanged to ``_predict``.
        n_iteration: Maximum number of predict/edit rounds.

    Returns:
        A tuple ``(final_sents, iteration_log)``.

        ``final_sents[i]`` is the corrected version of ``srcs[i]``.

        ``iteration_log[i]`` holds one ``{'src': tokens, 'tag': labels}``
        dict per step of sentence ``i``. ``'src'`` is the token list,
        including the leading ``$START``. ``'tag'`` is the predicted labels
        used to edit that ``src``. It is ``None`` for the last entry, which
        was not edited further.
    """
    if not srcs:
        return [], []

    token_lists = [[_START_TOKEN] + src.split(' ') for src in srcs]

    # None marks "not finalized yet". A string sentinel such as '-1' could
    # collide with a genuine output sentence.
    final_edited_sents = [None] * len(token_lists)

    iteration_log: List[List[Dict]] = [
        [{'src': tokens, 'tag': None}] for tokens in token_lists
    ]

    to_be_processed = token_lists
    original_sent_idx = list(range(len(token_lists)))

    for itr in range(n_iteration):
        print(f'Iteration {itr}. the number of to_be_processed: {len(to_be_processed)}')
        pred_labels, no_corrections = _predict(
            model,
            tokenizer,
            to_be_processed,
            keep_confidence,
            min_error_prob,
            batch_size,
        )

        # Split this round's batch into finished sentences and ones to edit.
        current_srcs = []
        current_pred_labels = []
        current_orig_idx = []
        for i, no_correction in enumerate(no_corrections):
            orig_id = original_sent_idx[i]
            if no_correction:
                final_edited_sents[orig_id] = _detokenize(to_be_processed[i])
            else:
                current_srcs.append(to_be_processed[i])
                current_pred_labels.append(pred_labels[i])
                current_orig_idx.append(orig_id)

        if not current_srcs:
            # Every sentence converged, so nothing is left pending.
            to_be_processed, original_sent_idx = [], []
            break

        edited_srcs = edit_src_by_tags(
            current_srcs,
            current_pred_labels,
            encode,
            decode,
        )

        for i, orig_id in enumerate(current_orig_idx):
            # The last log entry is the src that was just predicted on:
            # attach its tags, then record the edited result as a new step.
            iteration_log[orig_id][-1]['tag'] = current_pred_labels[i]
            iteration_log[orig_id].append({
                'src': edited_srcs[i],
                'tag': None,
            })

        to_be_processed = edited_srcs
        original_sent_idx = current_orig_idx

    # Sentences still pending after the final round are taken as they are.
    for orig_id, tokens in zip(original_sent_idx, to_be_processed):
        final_edited_sents[orig_id] = _detokenize(tokens)

    assert all(sent is not None for sent in final_edited_sents), \
        'some sentences were never finalized'
    return final_edited_sents, iteration_log