Spaces:
Running
Running
import pandas as pd | |
import plotly.graph_objects as go | |
from typing import List, Dict, Any, Tuple, Union | |
_PUBMED_LINK= "https://pubmed.ncbi.nlm.nih.gov/{article_id}/" | |
_PMC_LINK = "https://www.ncbi.nlm.nih.gov/pmc/articles/{article_id}/" | |
_MARKDOWN_TEMPLATE = """# [{article_title}]({article_link}) | |
# Filtered sections : | |
{sections_md}""" | |
# entities highlighted text | |
def get_highlighted_text(entities:List[Dict[str,Any]], original_text:str) -> List[Tuple[str,Union[str,None]]] : | |
"""Convert the output of the model to a list of tuples (entity, label) | |
for `gradio.HighlightedText`output""" | |
conversion = {"PrimaryOutcome":"primary","SecondaryOutcome":"secondary"} | |
highlighted_text = [] | |
for entity in entities: | |
entity_original_text = original_text[entity["start"]:entity["end"]] | |
if entity["entity_group"] == "O": | |
entity_output = (entity_original_text, None) | |
else: | |
entity_output = (entity_original_text, conversion[entity["entity_group"]]) | |
highlighted_text.append(entity_output) | |
return highlighted_text | |
# article filtered sections markdown output | |
def get_article_markdown( | |
article_id:str, | |
article_sections:dict[str,list[str]], | |
filtered_sections:dict[str,list[str]]) -> str: | |
"""Get the markdown of a list of sections""" | |
# link to online article | |
article_link = _PMC_LINK if article_id.startswith("PMC") else _PUBMED_LINK | |
article_link = article_link.format(article_id=article_id) | |
# get title, abstract, and filtered sections | |
article_title = article_sections["Title"][0] | |
sections_md = "" | |
for title, content in filtered_sections.items(): | |
sections_md += f"## {title}\n" | |
sections_md += " ".join(content) + "\n" | |
return _MARKDOWN_TEMPLATE.format( | |
article_link=article_link, | |
article_title=article_title, | |
sections_md=sections_md | |
) | |
# registry dataframe display | |
def _highlight_df_rows(row): | |
if row['type'] =='primary': | |
return ['background-color: lightcoral'] * len(row) | |
elif row['type'] == 'secondary': | |
return ['background-color: lightgreen'] * len(row) | |
else : | |
return ['background-color: lightgrey'] * len(row) | |
def get_registry_dataframe(registry_outcomes: list[dict[str,str]]) -> str: | |
return pd.DataFrame(registry_outcomes).style.apply(_highlight_df_rows, axis=1) | |
# fcts for sankey diagram | |
def _sent_line_formatting(sentence:str, max_words:int=10) -> str: | |
"""format a sentence to be displayed in a sankey diagram so that | |
each line has a maximum of `max_words` words""" | |
words = sentence.split() | |
batchs = [words[i:i+max_words] for i in range(0, len(words), max_words)] | |
return "<br>".join([" ".join(batch) for batch in batchs]) | |
def _find_entity_score(entity_text, raw_entities): | |
for tc_output in raw_entities: | |
if entity_text == tc_output["word"]: | |
return tc_output["score"] | |
def get_sankey_diagram( | |
registry_outcomes: list[tuple[str,str]], | |
article_outcomes: list[tuple[str,str]], | |
connections: set[tuple[int,int,float]], | |
raw_entities: list[Dict[str,Any]], | |
cosine_threshold: float=0.44, | |
) -> go.Figure: | |
color_map = { | |
"primary": "red", | |
"secondary": "green", | |
"other": "grey", | |
} | |
# Create lists of formatted sentences and colors for the nodes | |
list1 = [(_sent_line_formatting(sent), color_map[typ]) for typ, sent in registry_outcomes] | |
list2 = [(_sent_line_formatting(sent), color_map[typ]) for typ, sent in article_outcomes] | |
display_connections = [ | |
(list1[i][0],list2[j][0],"mediumaquamarine") if cosine > cosine_threshold | |
else (list1[i][0],list2[j][0],"lightgray") for i,j,cosine in connections | |
] | |
# Create a list of labels and colors for the nodes | |
labels = [x[0] for x in list1 + list2] | |
colors = [x[1] for x in list1 + list2] | |
# Create lists of sources and targets for the connections | |
sources = [labels.index(x[0]) for x in display_connections] | |
targets = [labels.index(x[1]) for x in display_connections] | |
# Create a list of values and colors for the connections | |
values = [1] * len(display_connections) | |
connection_colors = [x[2] for x in display_connections] | |
# data appearing on hover of each node (outcome) | |
node_customdata = [f"from: registry<br>type:{t}" for t,_ in registry_outcomes] | |
node_customdata += [f"from: article<br>type: {t}<br>confidence: " + str(_find_entity_score(s, raw_entities)) for t,s in article_outcomes] | |
node_hovertemplate = "outcome: %{label}<br>%{customdata} <extra></extra>" | |
# data appearing on hover of each link (node connections) | |
link_customdata = [cosine for _,_,cosine in connections] | |
link_hovertemplate = "similarity: %{customdata} <extra></extra>" | |
# sankey diagram data filling | |
sankey = go.Sankey( | |
node=dict( | |
pad=15, | |
thickness=20, | |
line=dict(color="black", width=0.5), | |
label=labels, | |
color=colors, | |
customdata=node_customdata, | |
hovertemplate=node_hovertemplate | |
), | |
link=dict( | |
source=sources, | |
target=targets, | |
value=values, | |
customdata=link_customdata, | |
color=connection_colors, | |
hovertemplate=link_hovertemplate | |
) | |
) | |
# conversion to figure | |
fig = go.Figure(data=[sankey]) | |
fig.update_layout( | |
title_text="Registry outcomes (left) connections with article outcomes (right), similarity threshold = " + str(cosine_threshold), | |
font_size=10, | |
width=1200, | |
xaxis=dict(rangeslider=dict(visible=True),type="linear") | |
) | |
return fig |