# TAPAS tokenization visualizer.
# (The original paste carried Hugging Face Spaces page chrome — "Spaces:" /
# "Runtime error" banners — which was not source code and has been removed.)
import os
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple

import pandas as pd
# Load the stylesheet that ships alongside this module.
dirname = os.path.dirname(__file__)
css_filename = os.path.join(dirname, "tapas-styles.css")
# Read explicitly as UTF-8: the platform-default encoding can differ and
# would mis-decode any non-ASCII characters in the CSS.
with open(css_filename, encoding="utf-8") as f:
    css = f.read()
def HTMLBody(table_html: str, css_styles: Optional[str] = None) -> str:
    """
    Wrap pre-rendered table HTML in a complete HTML document with CSS.

    Args:
        table_html (:obj:`str`):
            An HTML string (the rendered table) to embed in the document body.
        css_styles (:obj:`str`, `optional`):
            Optional alternative implementation of the css. Defaults to the
            module-level stylesheet loaded from ``tapas-styles.css``.

    Returns:
        :obj:`str`: An HTML string with style markup
    """
    if css_styles is None:
        # Resolve lazily instead of freezing the module-level `css` into the
        # signature at def time; callers may still pass their own stylesheet.
        css_styles = css
    return f"""
<html>
<head>
<style>
{css_styles}
</style>
</head>
<body>
<div class="tokenized-text" dir=auto>
{table_html}
</div>
</body>
</html>
"""
class TapasVisualizer:
    """Render a pandas table as HTML with its TAPAS tokenization highlighted.

    Each cell's text is matched against the tokens the TAPAS tokenizer
    produced for it; alternating tokens get different CSS classes so token
    boundaries are visible, and per-row / cumulative token counts are
    appended as extra cells on every row.
    """

    def __init__(self, tokenizer) -> None:
        # Expected to be a Hugging Face TAPAS tokenizer: callable on a
        # DataFrame and providing `_convert_id_to_token`.
        self.tokenizer = tokenizer

    def normalize_token_str(self, token_str: str) -> str:
        """Strip WordPiece continuation markers ("##") from a token string."""
        return token_str.replace("##", "")

    def style_span(self, span_text: str, css_classes: List[str]) -> str:
        """Wrap ``span_text`` in a <span> carrying the given CSS classes."""
        css = f'''class="{' '.join(css_classes)}"'''
        return f"<span {css} >{span_text}</span>"

    def text_to_html(self, org_text: str, tokens: List[str]) -> List[str]:
        """Create html spans based on the original text and its tokens.

        Note: The tokens need to be in same order as in the original text

        Args:
            org_text (str): Original string before tokenization
            tokens (List[str]): The tokens of org_text

        Returns:
            List[str]: html <span> strings with styling for the tokens
        """
        if len(tokens) == 0:
            print(f'Empty tokens for: {org_text}')
            return []

        cur_token_id = 0
        cur_token = self.normalize_token_str(tokens[cur_token_id])

        next_start = 0  # scan position in org_text
        last_end = 0    # end of the last emitted span
        spans = []
        # Walk the original text character by character, greedily matching
        # the current (normalized) token at each position.
        while next_start < len(org_text):
            candidate = org_text[next_start: next_start + len(cur_token)]
            # The tokenizer performs lowercasing; so check against lowercase
            if candidate.lower() == cur_token:
                if last_end != next_start:
                    # There was token-less text (probably whitespace)
                    # in the middle
                    spans.append(self.style_span(org_text[last_end: next_start], ['non-token']))
                odd_or_even = 'even-token' if cur_token_id % 2 == 0 else 'odd-token'
                spans.append(self.style_span(candidate, ['token', odd_or_even]))
                next_start += len(cur_token)
                last_end = next_start
                cur_token_id += 1
                if cur_token_id >= len(tokens):
                    break
                cur_token = self.normalize_token_str(tokens[cur_token_id])
            else:
                next_start += 1

        if last_end != len(org_text):
            # Emit any trailing text not covered by a token. Slice to the end
            # of the string (not to next_start): when the token list was
            # exhausted via the break above, last_end == next_start and the
            # old [last_end:next_start] slice silently dropped leftover text.
            spans.append(self.style_span(org_text[last_end:], ['non-token']))
        return spans

    def cells_to_html(self,
                      cell_vals: List[List[str]],
                      cell_tokens: Dict,
                      row_id_start: int = 0,
                      cell_element: str = "td",
                      cumulative_cnt: int = 0,
                      table_html: str = "") -> Tuple[str, int]:
        """Render rows of cell values as HTML table rows with token styling.

        Args:
            cell_vals: Rows of raw cell strings.
            cell_tokens: Mapping of (row_id, col_id) -> list of token strings.
            row_id_start: Row id of the first row in ``cell_vals`` (the header
                row is id 0; body rows start at 1 in the TAPAS encoding).
            cell_element: "th" for header rows, "td" for body rows.
            cumulative_cnt: Running token count carried over from earlier rows.
            table_html: Previously rendered rows to append to.

        Returns:
            Tuple[str, int]: The accumulated table HTML and the updated
            cumulative token count.
        """
        for row_id, row in enumerate(cell_vals, start=row_id_start):
            row_html = ""
            row_token_cnt = 0
            for col_id, cell in enumerate(row, start=1):
                cur_cell_tokens = cell_tokens[(row_id, col_id)]
                span_htmls = self.text_to_html(cell, cur_cell_tokens)
                cell_html = "".join(span_htmls)
                row_html += f"<{cell_element}>{cell_html}</{cell_element}>"
                row_token_cnt += len(cur_cell_tokens)
            cumulative_cnt += row_token_cnt
            # Append per-row and cumulative token counts as borderless cells.
            row_html += f'<td style="border: none;" align="right">{self.style_span(str(row_token_cnt), ["non-token", "count"])}</td>'
            row_html += f'<td style="border: none;" align="right">{self.style_span(str(cumulative_cnt), ["non-token", "count"])}</td>'
            table_html += f'<tr>{row_html}</tr>'
        return table_html, cumulative_cnt

    def __call__(self, table: pd.DataFrame) -> str:
        """Tokenize ``table`` and return a full HTML page visualizing it."""
        tokenized = self.tokenizer(table)

        # Group token strings by their (row, column) position in the table.
        cell_tokens = defaultdict(list)
        for id_ind, input_id in enumerate(tokenized['input_ids']):
            input_id = int(input_id)
            # 'prev_label', 'column_rank', 'inv_column_rank', 'numeric_relation' not required
            segment_id, col_id, row_id, *_ = tokenized['token_type_ids'][id_ind]
            token_text = self.tokenizer._convert_id_to_token(input_id)
            # NOTE(review): segment_id == 1 presumably marks table (non-query)
            # tokens — confirm against the TAPAS tokenizer docs.
            if int(segment_id) == 1:
                cell_tokens[(row_id, col_id)].append(token_text)

        # Header row (row id 0) rendered with <th>, then body rows from id 1,
        # threading the cumulative token count through both passes.
        table_html, cumulative_cnt = self.cells_to_html(cell_vals=[table.columns],
                                                        cell_tokens=cell_tokens,
                                                        row_id_start=0,
                                                        cell_element="th",
                                                        cumulative_cnt=0,
                                                        table_html="")
        table_html, cumulative_cnt = self.cells_to_html(cell_vals=table.values,
                                                        cell_tokens=cell_tokens,
                                                        row_id_start=1,
                                                        cell_element="td",
                                                        cumulative_cnt=cumulative_cnt,
                                                        table_html=table_html)
        table_html = f'<table>{table_html}</table>'
        return HTMLBody(table_html)