tapas-tokenizer-viz / tapas_visualizer.py
import os
from collections import defaultdict
from typing import Dict, List, Tuple

import pandas as pd

# Load the CSS used to style the token spans, relative to this module.
dirname = os.path.dirname(__file__)
css_filename = os.path.join(dirname, "tapas-styles.css")
with open(css_filename) as f:
    css = f.read()


def HTMLBody(table_html: str, css_styles: str = css) -> str:
    """Wrap the rendered table markup in a full HTML document with inline CSS.

    Args:
        table_html (:obj:`str`):
            The HTML markup of the tokenized table.
        css_styles (:obj:`str`, `optional`):
            Optional alternative implementation of the CSS.

    Returns:
        :obj:`str`: An HTML string with style markup.
    """
    return f"""
    <html>
        <head>
            <style>
                {css_styles}
            </style>
        </head>
        <body>
            <div class="tokenized-text" dir="auto">
                {table_html}
            </div>
        </body>
    </html>
    """


class TapasVisualizer:
    def __init__(self, tokenizer) -> None:
        self.tokenizer = tokenizer

    def normalize_token_str(self, token_str: str) -> str:
        # Strip the WordPiece continuation prefix ("##") so tokens can be
        # matched against the original, untokenized cell text.
        return token_str.replace("##", "")

    def style_span(self, span_text: str, css_classes: List[str]) -> str:
        css_attr = f'''class="{' '.join(css_classes)}"'''
        return f"<span {css_attr} >{span_text}</span>"

    def text_to_html(self, org_text: str, tokens: List[str]) -> List[str]:
        """Create styled HTML spans from the original text and its tokens.

        Note: The tokens need to be in the same order as in the original text.

        Args:
            org_text (str): Original string before tokenization
            tokens (List[str]): The tokens of org_text

        Returns:
            List[str]: HTML spans with styling for the tokens
        """
        if len(tokens) == 0:
            print(f'Empty tokens for: {org_text}')
            return []

        cur_token_id = 0
        cur_token = self.normalize_token_str(tokens[cur_token_id])

        # Loop through each character
        next_start = 0
        last_end = 0
        spans = []

        while next_start < len(org_text):
            candidate = org_text[next_start: next_start + len(cur_token)]

            # The tokenizer performs lowercasing; so check against lowercase
            if candidate.lower() == cur_token:
                if last_end != next_start:
                    # There was token-less text (probably whitespace)
                    # in the middle
                    spans.append(self.style_span(org_text[last_end: next_start], ['non-token']))

                odd_or_even = 'even-token' if cur_token_id % 2 == 0 else 'odd-token'
                spans.append(self.style_span(candidate, ['token', odd_or_even]))

                next_start += len(cur_token)
                last_end = next_start
                cur_token_id += 1
                if cur_token_id >= len(tokens):
                    break
                cur_token = self.normalize_token_str(tokens[cur_token_id])
            else:
                next_start += 1

        if last_end != len(org_text):
            # Any trailing text not covered by tokens is rendered as non-token
            spans.append(self.style_span(org_text[last_end:], ['non-token']))

        return spans
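
    # Illustrative example of `text_to_html` output (assumed WordPiece-style
    # lowercased tokens, as produced by the TAPAS tokenizer):
    #   text_to_html("Hello world", ["hello", "world"])
    #   -> ['<span class="token even-token" >Hello</span>',
    #       '<span class="non-token" > </span>',
    #       '<span class="token odd-token" >world</span>']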

    def cells_to_html(
        self,
        cell_vals: List[List[str]],
        cell_tokens: Dict,
        row_id_start: int = 0,
        cell_element: str = "td",
        cumulative_cnt: int = 0,
        table_html: str = "",
    ) -> Tuple[str, int]:
        """Render rows of cells as HTML, appending per-row and cumulative token counts."""
        for row_id, row in enumerate(cell_vals, start=row_id_start):
            row_html = ""
            row_token_cnt = 0
            for col_id, cell in enumerate(row, start=1):
                cur_cell_tokens = cell_tokens[(row_id, col_id)]
                span_htmls = self.text_to_html(cell, cur_cell_tokens)
                cell_html = "".join(span_htmls)
                row_html += f"<{cell_element}>{cell_html}</{cell_element}>"
                row_token_cnt += len(cur_cell_tokens)
            cumulative_cnt += row_token_cnt
            # Two extra, borderless cells: per-row token count and running total
            row_html += f'<td style="border: none;" align="right">{self.style_span(str(row_token_cnt), ["non-token", "count"])}</td>'
            row_html += f'<td style="border: none;" align="right">{self.style_span(str(cumulative_cnt), ["non-token", "count"])}</td>'
            table_html += f'<tr>{row_html}</tr>'
        return table_html, cumulative_cnt

    def __call__(self, table: pd.DataFrame) -> str:
        tokenized = self.tokenizer(table)

        # Group token texts by (row, column); row 0 holds the column headers.
        cell_tokens = defaultdict(list)

        for id_ind, input_id in enumerate(tokenized['input_ids']):
            input_id = int(input_id)
            # 'prev_label', 'column_rank', 'inv_column_rank', 'numeric_relation' not required
            segment_id, col_id, row_id, *_ = tokenized['token_type_ids'][id_ind]
            token_text = self.tokenizer._convert_id_to_token(input_id)
            # segment_id == 1 marks table tokens (0 is the query/special-token segment)
            if int(segment_id) == 1:
                cell_tokens[(row_id, col_id)].append(token_text)

        # Header row (row_id 0), then the data rows (row_id starting at 1)
        table_html, cumulative_cnt = self.cells_to_html(
            cell_vals=[table.columns],
            cell_tokens=cell_tokens,
            row_id_start=0,
            cell_element="th",
            cumulative_cnt=0,
            table_html="",
        )
        table_html, cumulative_cnt = self.cells_to_html(
            cell_vals=table.values,
            cell_tokens=cell_tokens,
            row_id_start=1,
            cell_element="td",
            cumulative_cnt=cumulative_cnt,
            table_html=table_html,
        )
        table_html = f'<table>{table_html}</table>'
        return HTMLBody(table_html)
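

# Minimal usage sketch (not part of the original module): assumes the
# `transformers` TapasTokenizer; the checkpoint name and output path below are
# placeholders. TAPAS expects every cell in the DataFrame to be a string.
if __name__ == "__main__":
    from transformers import TapasTokenizer

    tokenizer = TapasTokenizer.from_pretrained("google/tapas-base")
    table = pd.DataFrame(
        {"City": ["Paris", "Berlin"], "Population": ["2148000", "3645000"]}
    ).astype(str)

    visualizer = TapasVisualizer(tokenizer)
    html = visualizer(table)
    with open("tapas_tokens.html", "w") as out_f:
        out_f.write(html)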