Spaces:
Running
Running
import collections | |
import itertools | |
from dataclasses import dataclass | |
from typing import List, Optional, Set, Tuple | |
from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer | |
from relik.reader.data.relik_reader_sample import RelikReaderSample | |
@dataclass
class Window:
    """A contiguous slice of a tokenized document.

    The field set matches the keyword arguments that ``WindowManager``
    passes to ``RelikReaderSample`` when it creates windows.
    """

    # identifier of the document the window was cut from
    doc_id: int
    # position of this window among the windows of the same document
    window_id: int
    # the window's text
    text: str
    # the window's tokens
    tokens: List[str]
    # topic of the source document (may be None)
    doc_topic: Optional[str]
    # character offset at which ``text`` starts in the source document
    offset: int
    # window-local token index (stringified) -> start char offset;
    # presumably relative to the whole document — confirm with producer
    token2char_start: dict
    # window-local token index (stringified) -> end char offset
    token2char_end: dict
    # optional candidate labels for this window
    window_candidates: Optional[List[str]] = None
class WindowManager:
    """Split documents into token windows and merge per-window results back.

    ``create_windows`` turns one document into one or more ``RelikReaderSample``
    windows; ``merge_windows`` groups windows by ``doc_id`` and stitches each
    group back into a single ``RelikReaderSample``.
    """

    def __init__(self, tokenizer: BaseTokenizer) -> None:
        # Called on a raw string; the tokens it yields must expose ``text``,
        # ``start_char`` and ``end_char`` (see ``tokenize``).
        self.tokenizer = tokenizer

    def tokenize(self, document: str) -> Tuple[List[str], List[Tuple[int, int]]]:
        """Tokenize ``document`` with the configured tokenizer.

        Returns:
            ``(tokens, tokens_char_mapping)`` where ``tokens[i]`` is the text of
            the i-th token and ``tokens_char_mapping[i]`` its
            ``(start_char, end_char)`` pair as reported by the tokenizer.
        """
        tokens = []
        tokens_char_mapping = []
        for token in self.tokenizer(document):
            tokens.append(token.text)
            tokens_char_mapping.append((token.start_char, token.end_char))
        return tokens, tokens_char_mapping

    @staticmethod
    def _window_token_ranges(
        num_tokens: int, window_size: int, stride: int
    ) -> List[Tuple[int, int]]:
        """Return the half-open ``(start, end)`` token ranges of sliding windows.

        A window starts every ``stride`` tokens. If a later window would run
        past the end of the document, it is shifted left so that it ends exactly
        on the last token (including more tokens from the previous window). If
        that shift would be ``stride`` tokens or more, the previous window
        already reaches the end and iteration stops.

        Raises:
            ValueError: if ``window_size`` or ``stride`` is not positive.
        """
        if window_size <= 0:
            raise ValueError(f"window_size must be positive, got {window_size}")
        if stride <= 0:
            raise ValueError(f"stride must be positive, got {stride}")

        ranges = []
        for start in range(0, num_tokens, stride):
            if start != 0 and start + window_size > num_tokens:
                overflowing_tokens = start + window_size - num_tokens
                if overflowing_tokens >= stride:
                    break
                start -= overflowing_tokens
            # The end is exclusive and bounded by ``num_tokens`` (not
            # ``num_tokens - 1``) so the document's last token is always covered.
            ranges.append((start, min(start + window_size, num_tokens)))
        return ranges

    @staticmethod
    def _token2char_maps(char_spans: List[Tuple[int, int]]) -> Tuple[dict, dict]:
        """Map stringified window-local token indices to start/end char offsets."""
        token2char_start = {str(i): start for i, (start, _) in enumerate(char_spans)}
        token2char_end = {str(i): end for i, (_, end) in enumerate(char_spans)}
        return token2char_start, token2char_end

    def create_windows(
        self,
        document: str,
        window_size: int,
        stride: int,
        doc_id: int = 0,
        doc_topic: Optional[str] = None,
    ) -> List[RelikReaderSample]:
        """Split ``document`` into (possibly overlapping) token windows.

        Args:
            document: raw text to split.
            window_size: maximum number of tokens per window.
            stride: number of tokens between the starts of consecutive windows.
            doc_id: identifier copied onto every window.
            doc_topic: topic copied onto every window; defaults to the first
                token of the document, or ``""`` for an empty document.

        Returns:
            One ``RelikReaderSample`` per window. ``token2char_start`` and
            ``token2char_end`` map the stringified window-local token index to
            the character offsets produced by ``tokenize`` on the whole
            document; ``offset`` is where the window's ``text`` starts.

        Raises:
            ValueError: if the document needs more than one window and
                ``window_size`` or ``stride`` is not positive.
        """
        document_tokens, tokens_char_mapping = self.tokenize(document)
        if doc_topic is None:
            doc_topic = document_tokens[0] if len(document_tokens) > 0 else ""

        if len(document_tokens) <= window_size:
            # The whole document fits in one window: keep the full original text
            # (including any surrounding whitespace) and a zero offset.
            token2char_start, token2char_end = self._token2char_maps(
                tokens_char_mapping
            )
            return [
                RelikReaderSample(
                    doc_id=doc_id,
                    window_id=0,
                    text=document,
                    tokens=document_tokens,
                    doc_topic=doc_topic,
                    offset=0,
                    token2char_start=token2char_start,
                    token2char_end=token2char_end,
                )
            ]

        document_windows = []
        token_ranges = self._window_token_ranges(
            len(document_tokens), window_size, stride
        )
        for window_id, (start, end) in enumerate(token_ranges):
            window_char_spans = tokens_char_mapping[start:end]
            window_text_start = window_char_spans[0][0]
            window_text_end = window_char_spans[-1][1]
            token2char_start, token2char_end = self._token2char_maps(
                window_char_spans
            )
            document_windows.append(
                RelikReaderSample(
                    doc_id=doc_id,
                    window_id=window_id,
                    text=document[window_text_start:window_text_end],
                    tokens=document_tokens[start:end],
                    doc_topic=doc_topic,
                    offset=window_text_start,
                    token2char_start=token2char_start,
                    token2char_end=token2char_end,
                )
            )
        return document_windows

    def merge_windows(
        self, windows: List[RelikReaderSample]
    ) -> List[RelikReaderSample]:
        """Merge windows into one sample per ``doc_id``.

        Windows are grouped by ``doc_id`` and each group is merged with
        ``merge_doc_windows``. Results follow the order in which each
        ``doc_id`` is first seen in ``windows``.
        """
        windows_by_doc_id = collections.defaultdict(list)
        for window in windows:
            windows_by_doc_id[window.doc_id].append(window)

        merged_window_by_doc = {
            doc_id: self.merge_doc_windows(doc_windows)
            for doc_id, doc_windows in windows_by_doc_id.items()
        }
        return list(merged_window_by_doc.values())

    def merge_doc_windows(self, windows: List[RelikReaderSample]) -> RelikReaderSample:
        """Fold the (non-empty) windows of one document into a single sample.

        A single window is returned unchanged. Otherwise windows are sorted by
        ``offset`` (when the first one has an offset) and merged pairwise from
        left to right with ``_merge_window_pair``.
        """
        if len(windows) == 1:
            return windows[0]

        if len(windows) > 0 and getattr(windows[0], "offset", None) is not None:
            windows = sorted(windows, key=(lambda x: x.offset))

        window_accumulator = windows[0]
        for next_window in windows[1:]:
            window_accumulator = self._merge_window_pair(
                window_accumulator, next_window
            )
        return window_accumulator

    def _merge_tokens(
        self, window1: RelikReaderSample, window2: RelikReaderSample
    ) -> Tuple[list, dict, dict]:
        """Concatenate the tokens of two overlapping windows.

        The first and last token of each window are treated as special tokens
        (labelled CLS/SEP below) and stripped before searching for the longest
        suffix of ``window1`` that equals a prefix of ``window2``.
        NOTE(review): ``create_windows`` in this class does not add such tokens;
        presumably ``tokens`` are replaced downstream before merging — confirm.

        Returns:
            ``(tokens, token2char_start, token2char_end)`` for the merged window.
            Window2 entries of the char mappings with an index below the
            overlap length are dropped; the rest are re-keyed by the number
            of non-overlapping window1 tokens.

        Raises:
            AssertionError: if no overlap of at least one token is found.
        """
        w1_tokens = window1.tokens[1:-1]
        w2_tokens = window2.tokens[1:-1]

        # find the longest overlap (suffix of w1 == prefix of w2)
        tokens_intersection = None
        for k in reversed(range(1, len(w1_tokens))):
            if w1_tokens[-k:] == w2_tokens[:k]:
                tokens_intersection = k
                break
        # getattr: windows built by create_windows carry no ``sent_id``, and the
        # diagnostic itself must not raise while reporting the failure.
        assert tokens_intersection is not None, (
            f"{window1.doc_id} - {getattr(window1, 'sent_id', None)} - {window1.offset}"
            + f" {window2.doc_id} - {getattr(window2, 'sent_id', None)} - {window2.offset}\n"
            + f"w1 tokens: {w1_tokens}\n"
            + f"w2 tokens: {w2_tokens}\n"
        )

        final_tokens = (
            [window1.tokens[0]]  # CLS
            + w1_tokens
            + w2_tokens[tokens_intersection:]
            + [window1.tokens[-1]]  # SEP
        )

        w2_starting_offset = len(w1_tokens) - tokens_intersection

        def merge_char_mapping(t2c1: dict, t2c2: dict) -> dict:
            # keep window1 entries; shift non-overlapping window2 entries
            final_t2c = dict(t2c1)
            for t, c in t2c2.items():
                t = int(t)
                if t < tokens_intersection:
                    continue
                final_t2c[str(t + w2_starting_offset)] = c
            return final_t2c

        return (
            final_tokens,
            merge_char_mapping(window1.token2char_start, window2.token2char_start),
            merge_char_mapping(window1.token2char_end, window2.token2char_end),
        )

    def _merge_span_annotation(
        self, span_annotation1: List[list], span_annotation2: List[list]
    ) -> List[list]:
        """Union two annotation lists, dropping duplicates, sorted by first item.

        Duplicates are detected by ``tuple(annotation)``; the first occurrence
        (from ``span_annotation1`` before ``span_annotation2``) is kept.
        """
        uniq_store = set()
        final_span_annotation_store = []
        for span_annotation in itertools.chain(span_annotation1, span_annotation2):
            span_annotation_id = tuple(span_annotation)
            if span_annotation_id not in uniq_store:
                uniq_store.add(span_annotation_id)
                final_span_annotation_store.append(span_annotation)
        return sorted(final_span_annotation_store, key=lambda x: x[0])

    def _merge_predictions(
        self,
        window1: RelikReaderSample,
        window2: RelikReaderSample,
    ) -> Tuple[Set[Tuple[int, int, str]], dict]:
        """Union the character-level predictions of two windows.

        Returns the union of ``predicted_window_labels_chars`` and a merged
        ``probs_window_labels_chars`` where, for spans present in both windows,
        window1's probabilities take precedence.
        """
        merged_predictions = window1.predicted_window_labels_chars.union(
            window2.predicted_window_labels_chars
        )

        span_title_probabilities = dict()
        # first occurrence wins: window1 entries are visited before window2's
        for span_prediction, predicted_probs in itertools.chain(
            window1.probs_window_labels_chars.items(),
            window2.probs_window_labels_chars.items(),
        ):
            if span_prediction not in span_title_probabilities:
                span_title_probabilities[span_prediction] = predicted_probs

        return merged_predictions, span_title_probabilities

    def _merge_window_pair(
        self,
        window1: RelikReaderSample,
        window2: RelikReaderSample,
    ) -> RelikReaderSample:
        """Merge two consecutive windows of the same document into one sample.

        The result carries ``doc_id``, ``offset``, merged tokens and char
        mappings, merged ``window_labels`` (only when window1 has them) and
        merged predictions.
        """
        merging_output = dict()

        if getattr(window1, "doc_id", None) is not None:
            assert window1.doc_id == window2.doc_id

        if getattr(window1, "offset", None) is not None:
            assert (
                window1.offset < window2.offset
            ), f"window 2 offset ({window2.offset}) is smaller that window 1 offset({window1.offset})"

        merging_output["doc_id"] = window1.doc_id
        # NOTE(review): the merged sample takes window2's offset, which keeps
        # the ordering assertion above valid when folding left to right.
        merging_output["offset"] = window2.offset

        m_tokens, m_token2char_start, m_token2char_end = self._merge_tokens(
            window1, window2
        )

        window_labels = None
        if getattr(window1, "window_labels", None) is not None:
            window_labels = self._merge_span_annotation(
                window1.window_labels, window2.window_labels
            )
        (
            predicted_window_labels_chars,
            probs_window_labels_chars,
        ) = self._merge_predictions(
            window1,
            window2,
        )

        merging_output.update(
            dict(
                tokens=m_tokens,
                token2char_start=m_token2char_start,
                token2char_end=m_token2char_end,
                window_labels=window_labels,
                predicted_window_labels_chars=predicted_window_labels_chars,
                probs_window_labels_chars=probs_window_labels_chars,
            )
        )

        return RelikReaderSample(**merging_output)