|
import json |
|
import logging |
|
from collections import defaultdict |
|
from typing import Dict, List, Optional, Union |
|
|
|
from annotation_utils import labeled_span_to_id |
|
from pytorch_ie.annotations import BinaryRelation, LabeledSpan |
|
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions |
|
from rendering_utils_displacy import EntityRenderer |
|
|
|
logger = logging.getLogger(__name__)


# Entity template handed to EntityRenderer by render_displacy (it overrides
# entity_options["template"] with this string). It is the per-entity <mark> element,
# extended with an `id` attribute so that inject_relation_data can look each rendered
# entity up by its annotation id.
# Placeholders: {id}, {bg}, {text}, {label}. NOTE(review): presumably the renderer
# fills {bg}/{text}/{label} itself and {id} from each ent's "params" dict — confirm
# against rendering_utils_displacy.EntityRenderer.
# Keep the style starting with "background: {bg};" — inject_relation_data extracts the
# original color by splitting the style on "background:" and ";".
TPL_ENT_WITH_ID = """
<mark class="entity" id="{id}" style="background: {bg}; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;">
{text}
<span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5rem">{label}</span>
</mark>
"""
|
|
|
|
|
def render_pretty_table(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, **render_kwargs
):
    """Render every binary relation of ``document`` as an HTML table.

    Gold relations (``document.binary_relations``) come first, followed by the
    predicted ones (``document.binary_relations.predictions``). Each row holds the
    string form of the head span, the string form of the tail span, and the label.

    ``render_kwargs`` is accepted but ignored (presumably so that all renderers in
    this module can be called with the same keyword arguments).

    Returns:
        The table HTML wrapped in a ``<div>`` limited to 360px height with scrolling.
    """
    from prettytable import PrettyTable

    all_relations = [*document.binary_relations, *document.binary_relations.predictions]

    table = PrettyTable()
    table.field_names = ["head", "tail", "relation"]
    table.align = "l"
    for rel in all_relations:
        table.add_row([str(rel.head), str(rel.tail), rel.label])

    table_html = table.get_html_string(format=True)
    return f"<div style='max-width:100%; max-height:360px; overflow:auto'>{table_html}</div>"
|
|
|
|
|
def render_displacy(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    inject_relations=True,
    colors_hover=None,
    entity_options: Optional[dict] = None,
    **render_kwargs,
):
    """Render the document's labeled spans as displaCy-style entity HTML.

    Gold spans (``document.labeled_spans``) and predicted spans
    (``document.labeled_spans.predictions``) are rendered together. Each span is
    passed to the renderer with ``params={"id": labeled_span_to_id(span)}`` and the
    ``TPL_ENT_WITH_ID`` template, so every rendered entity carries its annotation id.

    Args:
        document: The document to render.
        inject_relations: If True, attach gold and predicted binary relations to the
            entity elements via ``inject_relation_data``.
        colors_hover: Forwarded to ``inject_relation_data`` as ``additional_colors``.
        entity_options: Options for ``EntityRenderer``. The mapping is copied, so the
            caller's object is never modified; its ``"template"`` entry is always
            replaced by ``TPL_ENT_WITH_ID``. ``None`` means no extra options.
        **render_kwargs: Accepted for interface compatibility; not used.

    Returns:
        The rendered HTML wrapped in a ``<div>`` limited to 360px height with scrolling.
    """
    labeled_spans = list(document.labeled_spans) + list(document.labeled_spans.predictions)

    ents = [
        {
            "start": labeled_span.start,
            "end": labeled_span.end,
            "label": labeled_span.label,
            "params": {"id": labeled_span_to_id(labeled_span)},
        }
        for labeled_span in labeled_spans
    ]
    # Gold spans are followed by predicted spans, so `ents` is not in text order.
    # spaCy's `displacy.render` sorts manual ents by (start, end) before rendering, but
    # we call the renderer directly and bypass that step. The sort is stable, so a gold
    # span and a prediction at the same offsets keep their relative order.
    ents.sort(key=lambda ent: (ent["start"], ent["end"]))
    spacy_doc = {"text": document.text, "ents": ents, "title": None}

    # Work on a copy: the caller's options must not be mutated (the previous `{}`
    # default would otherwise have been shared across calls).
    entity_options = {} if entity_options is None else entity_options.copy()
    entity_options["template"] = TPL_ENT_WITH_ID
    renderer = EntityRenderer(options=entity_options)
    html = renderer.render([spacy_doc], page=True, minify=True).strip()

    html = "<div style='max-width:100%; max-height:360px; overflow:auto'>" + html + "</div>"

    if inject_relations:
        binary_relations = list(document.binary_relations) + list(
            document.binary_relations.predictions
        )
        html = inject_relation_data(
            html,
            labeled_spans=labeled_spans,
            binary_relations=binary_relations,
            additional_colors=colors_hover,
        )
    return html
|
|
|
|
|
def inject_relation_data(
    html: str,
    labeled_spans: List[LabeledSpan],
    binary_relations: List[BinaryRelation],
    additional_colors: Optional[Dict[str, Union[str, dict]]] = None,
) -> str:
    """Annotate every ``class="entity"`` element in ``html`` with color and relation data.

    For each entity element (looked up in ``labeled_spans`` by its ``id`` attribute via
    ``labeled_span_to_id``) the following attributes are set:

    - ``data-color-original``: the color from the element's inline ``background:`` style
      (empty string if there is none);
    - ``data-color-<key>`` for each entry of ``additional_colors`` (dict values are
      JSON-encoded, other values are stored as-is);
    - ``data-label``: the span label;
    - ``data-relation-tails`` / ``data-relation-heads``: JSON lists of
      ``{"entity-id": ..., "label": ...}`` for relations where this span is the head /
      the tail.

    Args:
        html: HTML containing the rendered entity elements.
        labeled_spans: The spans that were rendered; their ids must match the elements.
        binary_relations: Relations whose heads/tails are (equal to) those spans.
        additional_colors: Optional extra colors to attach to every entity element.

    Returns:
        The modified HTML as a string.

    Raises:
        KeyError: If an entity element has no ``id`` or its id matches no labeled span.
    """
    from bs4 import BeautifulSoup

    soup = BeautifulSoup(html, "html.parser")

    # Index the relations in both directions:
    # head -> [(tail, label), ...] and tail -> [(head, label), ...].
    entity2tails = defaultdict(list)
    entity2heads = defaultdict(list)
    for relation in binary_relations:
        entity2heads[relation.tail].append((relation.head, relation.label))
        entity2tails[relation.head].append((relation.tail, relation.label))

    ann_id2annotation = {labeled_span_to_id(entity): entity for entity in labeled_spans}

    # The extra color attributes are identical for every entity: serialize them once
    # instead of re-encoding them inside the loop.
    color_attributes = {
        f"data-color-{key}": json.dumps(color) if isinstance(color, dict) else color
        for key, color in (additional_colors or {}).items()
    }

    for entity in soup.find_all(class_="entity"):
        # Take the value of the first `background:` declaration of the inline style.
        # A missing style/background yields "" instead of raising.
        style = entity.get("style", "")
        original_color = style.partition("background:")[2].split(";")[0].strip()
        entity["data-color-original"] = original_color
        # Set after data-color-original so that an "original" key still overrides it.
        for attribute, value in color_attributes.items():
            entity[attribute] = value

        entity_annotation = ann_id2annotation[entity["id"]]

        # Sanity check that the element really shows this span. Newlines are dropped
        # from the span text, presumably because the HTML was minified — confirm.
        annotation_text_without_newline = str(entity_annotation).replace("\n", "")
        if not entity.text.startswith(annotation_text_without_newline):
            logger.warning("Entity text mismatch: %s != %s", entity_annotation, entity.text)

        entity["data-label"] = entity_annotation.label
        entity["data-relation-tails"] = json.dumps(
            [
                {"entity-id": labeled_span_to_id(tail), "label": label}
                for tail, label in entity2tails.get(entity_annotation, [])
            ]
        )
        entity["data-relation-heads"] = json.dumps(
            [
                {"entity-id": labeled_span_to_id(head), "label": label}
                for head, label in entity2heads.get(entity_annotation, [])
            ]
        )

    return str(soup)
|
|