Spaces:
Sleeping
Sleeping
File size: 5,716 Bytes
1a3b3aa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
import pandas as pd
import plotly.graph_objects as go
from typing import List, Dict, Any, Tuple, Union
# URL templates for the online page of an article; the `{article_id}`
# placeholder is filled in with `str.format`.
_PUBMED_LINK= "https://pubmed.ncbi.nlm.nih.gov/{article_id}/"
_PMC_LINK = "https://www.ncbi.nlm.nih.gov/pmc/articles/{article_id}/"
# Markdown skeleton for one article: a top-level heading whose text is the
# article title linking to the article page, then the rendered sections.
# Placeholders: `article_title`, `article_link`, `sections_md`.
_MARKDOWN_TEMPLATE = """# [{article_title}]({article_link})
# Filtered sections :
{sections_md}"""
# entities highlighted text
def get_highlighted_text(entities:List[Dict[str,Any]], original_text:str) -> List[Tuple[str,Union[str,None]]] :
    """Turn model entities into `(text, label)` pairs for `gradio.HighlightedText`.

    The text of each entity is sliced out of `original_text` using the
    entity's "start" and "end" offsets. Entities whose "entity_group" is
    "O" are paired with `None`; "PrimaryOutcome" and "SecondaryOutcome" are
    paired with "primary" and "secondary" respectively. Any other group
    raises `KeyError`.
    """
    label_by_group = {"PrimaryOutcome":"primary","SecondaryOutcome":"secondary"}

    def _to_pair(ent):
        # always quote the original text span, not the model's own rendering
        span = original_text[ent["start"]:ent["end"]]
        group = ent["entity_group"]
        return (span, None) if group == "O" else (span, label_by_group[group])

    return [_to_pair(ent) for ent in entities]
# article filtered sections markdown output
def get_article_markdown(
    article_id:str,
    article_sections:dict[str,list[str]],
    filtered_sections:dict[str,list[str]]) -> str:
    """Render an article's linked title and its filtered sections as markdown.

    The link targets the PMC template when `article_id` starts with "PMC"
    and the PubMed template otherwise. The heading text is the first entry
    of `article_sections["Title"]`. Each entry of `filtered_sections`
    becomes a "## <title>" line followed by its content strings joined with
    single spaces.
    """
    # pick the URL template matching the identifier family
    link_template = _PMC_LINK if article_id.startswith("PMC") else _PUBMED_LINK
    article_url = link_template.format(article_id=article_id)

    # the first entry of the "Title" section is the heading text
    heading = article_sections["Title"][0]

    # one "## title" line plus one content line per filtered section
    body = "".join(
        f"## {name}\n" + " ".join(parts) + "\n"
        for name, parts in filtered_sections.items()
    )

    return _MARKDOWN_TEMPLATE.format(
        article_link=article_url,
        article_title=heading,
        sections_md=body,
    )
# registry dataframe display
def _highlight_df_rows(row):
if row['type'] =='primary':
return ['background-color: lightcoral'] * len(row)
elif row['type'] == 'secondary':
return ['background-color: lightgreen'] * len(row)
else :
return ['background-color: lightgrey'] * len(row)
def get_registry_dataframe(registry_outcomes: list[dict[str,str]]) -> "pd.io.formats.style.Styler":
    """Build a row-colored table of the registry outcomes.

    Args:
        registry_outcomes: one dict per outcome; each dict becomes a row.
            The rows are colored by `_highlight_df_rows`, which reads the
            "type" column.

    Returns:
        A pandas `Styler` wrapping the DataFrame (not a string), with the
        row-coloring function registered on it.
    """
    outcomes_df = pd.DataFrame(registry_outcomes)
    # axis=1 hands each row (rather than each column) to the callback
    return outcomes_df.style.apply(_highlight_df_rows, axis=1)
# helper functions for the Sankey diagram
def _sent_line_formatting(sentence:str, max_words:int=10) -> str:
"""format a sentence to be displayed in a sankey diagram so that
each line has a maximum of `max_words` words"""
words = sentence.split()
batchs = [words[i:i+max_words] for i in range(0, len(words), max_words)]
return "<br>".join([" ".join(batch) for batch in batchs])
def _find_entity_score(entity_text, raw_entities):
for tc_output in raw_entities:
if entity_text == tc_output["word"]:
return tc_output["score"]
def get_sankey_diagram(
    registry_outcomes: list[tuple[str,str]],
    article_outcomes: list[tuple[str,str]],
    connections: set[tuple[int,int,float]],
    raw_entities: list[Dict[str,Any]],
    cosine_threshold: float=0.44,
) -> go.Figure:
    """Draw registry outcomes (left) linked to article outcomes (right).

    Args:
        registry_outcomes: `(type, sentence)` pairs shown as the first group
            of nodes. `type` must be "primary", "secondary" or "other".
        article_outcomes: `(type, sentence)` pairs shown as the second group
            of nodes, with the same `type` values.
        connections: `(i, j, cosine)` triples linking `registry_outcomes[i]`
            to `article_outcomes[j]`. Links with `cosine` strictly above
            `cosine_threshold` are drawn in "mediumaquamarine", the others
            in "lightgray".
        raw_entities: model outputs; the "score" of the entity whose "word"
            equals an article outcome sentence is shown as that node's
            confidence on hover.
        cosine_threshold: similarity cut-off used for link coloring and
            echoed in the figure title.

    Returns:
        A plotly figure containing a single Sankey trace.

    Raises:
        KeyError: if an outcome type is not one of the known types.
        IndexError: if a connection refers to a non-existent outcome.
    """
    color_map = {
        "primary": "red",
        "secondary": "green",
        "other": "grey",
    }
    # (formatted sentence, color) for every node, registry nodes first
    registry_nodes = [(_sent_line_formatting(sent), color_map[typ]) for typ, sent in registry_outcomes]
    article_nodes = [(_sent_line_formatting(sent), color_map[typ]) for typ, sent in article_outcomes]
    all_nodes = registry_nodes + article_nodes
    labels = [label for label, _ in all_nodes]
    colors = [color for _, color in all_nodes]

    # Link endpoints are node positions. They are derived from the connection
    # indices rather than looked up by label text: two outcomes with the same
    # text would otherwise both resolve to the first matching node.
    # Indexing a range keeps list semantics (negative indices are normalized,
    # out-of-range indices raise IndexError).
    registry_positions = range(len(registry_nodes))
    article_positions = range(len(article_nodes))
    article_offset = len(registry_nodes)
    sources = []
    targets = []
    connection_colors = []
    link_customdata = []  # similarity shown on hover of each link
    for i, j, cosine in connections:
        sources.append(registry_positions[i])
        targets.append(article_offset + article_positions[j])
        connection_colors.append(
            "mediumaquamarine" if cosine > cosine_threshold else "lightgray"
        )
        link_customdata.append(cosine)
    # every link is drawn with the same weight
    values = [1] * len(sources)

    # data appearing on hover of each node (outcome)
    node_customdata = [f"from: registry<br>type:{t}" for t,_ in registry_outcomes]
    node_customdata += [
        f"from: article<br>type: {t}<br>confidence: " + str(_find_entity_score(s, raw_entities))
        for t,s in article_outcomes
    ]
    node_hovertemplate = "outcome: %{label}<br>%{customdata} <extra></extra>"
    link_hovertemplate = "similarity: %{customdata} <extra></extra>"

    # sankey diagram data filling
    sankey = go.Sankey(
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color="black", width=0.5),
            label=labels,
            color=colors,
            customdata=node_customdata,
            hovertemplate=node_hovertemplate
        ),
        link=dict(
            source=sources,
            target=targets,
            value=values,
            customdata=link_customdata,
            color=connection_colors,
            hovertemplate=link_hovertemplate
        )
    )
    # conversion to figure
    fig = go.Figure(data=[sankey])
    # NOTE(review): the xaxis/rangeslider setting is kept from the original;
    # it presumably has no effect on a Sankey trace -- confirm before removing.
    fig.update_layout(
        title_text="Registry outcomes (left) connections with article outcomes (right), similarity threshold = " + str(cosine_threshold),
        font_size=10,
        width=1200,
        xaxis=dict(rangeslider=dict(visible=True),type="linear")
    )
    return fig