File size: 5,716 Bytes
1a3b3aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import pandas as pd
import plotly.graph_objects as go
from typing import List, Dict, Any, Tuple, Union


_PUBMED_LINK= "https://pubmed.ncbi.nlm.nih.gov/{article_id}/"
_PMC_LINK = "https://www.ncbi.nlm.nih.gov/pmc/articles/{article_id}/"
_MARKDOWN_TEMPLATE = """# [{article_title}]({article_link})
# Filtered sections :

{sections_md}"""

# entities highlighted text
def get_highlighted_text(entities:List[Dict[str,Any]], original_text:str) -> List[Tuple[str,Union[str,None]]] :
    """Turn model entity spans into ``(text, label)`` pairs for
    `gradio.HighlightedText` output.

    Each entity's text is sliced from ``original_text`` using its ``start``
    and ``end`` offsets. Entities whose ``entity_group`` is ``"O"`` get a
    ``None`` label. ``"PrimaryOutcome"`` and ``"SecondaryOutcome"`` become
    ``"primary"`` and ``"secondary"``. Any other group raises ``KeyError``.
    """
    short_label = {"PrimaryOutcome": "primary", "SecondaryOutcome": "secondary"}

    def _as_pair(ent):
        span_text = original_text[ent["start"]:ent["end"]]
        group = ent["entity_group"]
        return (span_text, None) if group == "O" else (span_text, short_label[group])

    return [_as_pair(ent) for ent in entities]

# article filtered sections markdown output
def get_article_markdown(
        article_id:str,
        article_sections:dict[str,list[str]],
        filtered_sections:dict[str,list[str]]) -> str:
    """Render an article's title and its filtered sections as Markdown.

    Args:
        article_id: IDs starting with ``"PMC"`` link to PubMed Central;
            any other ID links to PubMed.
        article_sections: the first entry of its ``"Title"`` list is used as
            the page title.
        filtered_sections: mapping of section title -> list of text chunks.
            Each section becomes a ``## title`` line followed by its chunks
            joined with single spaces.

    Returns:
        The filled-in ``_MARKDOWN_TEMPLATE`` string.
    """
    # Pick the URL template by ID family, then fill in the ID.
    link_template = _PMC_LINK if article_id.startswith("PMC") else _PUBMED_LINK
    link = link_template.format(article_id=article_id)
    title_text = article_sections["Title"][0]
    # Build all section snippets, then join once instead of repeated `+=`.
    section_chunks = [
        f"## {heading}\n" + " ".join(paragraphs) + "\n"
        for heading, paragraphs in filtered_sections.items()
    ]
    return _MARKDOWN_TEMPLATE.format(
        article_link=link,
        article_title=title_text,
        sections_md="".join(section_chunks),
    )

# registry dataframe display
def _highlight_df_rows(row):
    if row['type'] =='primary':
        return ['background-color: lightcoral'] * len(row)
    elif row['type'] == 'secondary':
        return ['background-color: lightgreen'] * len(row)
    else :
        return ['background-color: lightgrey'] * len(row)

def get_registry_dataframe(registry_outcomes: list[dict[str,str]]) -> "pd.io.formats.style.Styler":
    """Wrap registry outcomes in a DataFrame whose rows are colored by outcome type.

    Args:
        registry_outcomes: one dict per outcome, used as the DataFrame rows.
            Each is expected to provide a ``'type'`` key, which
            ``_highlight_df_rows`` reads to pick the row color.

    Returns:
        A pandas ``Styler``. The previous ``-> str`` annotation was wrong:
        ``DataFrame.style.apply`` returns a Styler, not a string. The
        annotation is a string so ``pandas.io.formats.style`` is not
        resolved at import time.
    """
    return pd.DataFrame(registry_outcomes).style.apply(_highlight_df_rows, axis=1)

# fcts for sankey diagram
def _sent_line_formatting(sentence:str, max_words:int=10) -> str:
    """format a sentence to be displayed in a sankey diagram so that
    each line has a maximum of `max_words` words"""
    words = sentence.split()
    batchs = [words[i:i+max_words] for i in range(0, len(words), max_words)]
    return "<br>".join([" ".join(batch) for batch in batchs])

def _find_entity_score(entity_text, raw_entities):
    for tc_output in raw_entities:
        if entity_text == tc_output["word"]:
            return tc_output["score"]

def get_sankey_diagram(
        registry_outcomes: list[tuple[str,str]],
        article_outcomes: list[tuple[str,str]],
        connections: set[tuple[int,int,float]],
        raw_entities: list[Dict[str,Any]],
        cosine_threshold: float=0.44,
    ) -> go.Figure:
    """Build a Sankey diagram linking registry outcomes to article outcomes.

    Registry outcomes form the first group of nodes and article outcomes the
    second. Each ``(i, j, cosine)`` in ``connections`` draws one unit-width
    link from ``registry_outcomes[i]`` to ``article_outcomes[j]``.

    Args:
        registry_outcomes: ``(type, sentence)`` pairs from the registry.
        article_outcomes: ``(type, sentence)`` pairs from the article.
        connections: ``(registry_index, article_index, cosine_similarity)``
            triples.
        raw_entities: raw model outputs. Their ``word``/``score`` fields give
            the confidence shown when hovering an article node.
        cosine_threshold: links with similarity strictly above this value are
            drawn in mediumaquamarine, the others in light gray.

    Returns:
        A plotly ``Figure`` holding the Sankey trace.

    Raises:
        IndexError: if a connection references an outcome index outside the
            corresponding list.
    """
    color_map = {
        "primary": "red",
        "secondary": "green",
        "other": "grey",
    }
    n_registry = len(registry_outcomes)
    n_article = len(article_outcomes)

    # Node labels (sentences wrapped for display) and colors: registry nodes
    # first, then article nodes. Unknown outcome types fall back to grey,
    # matching the fallback used for registry dataframe rows.
    labels = [_sent_line_formatting(sent) for _, sent in registry_outcomes]
    labels += [_sent_line_formatting(sent) for _, sent in article_outcomes]
    colors = [color_map.get(typ, "grey") for typ, _ in registry_outcomes]
    colors += [color_map.get(typ, "grey") for typ, _ in article_outcomes]

    # Iterate the set exactly once so link endpoints, colors and hover data
    # all share the same ordering.
    ordered_connections = list(connections)
    for i, j, _ in ordered_connections:
        if not (0 <= i < n_registry and 0 <= j < n_article):
            raise IndexError(
                f"connection ({i}, {j}) out of range for {n_registry} registry "
                f"and {n_article} article outcomes"
            )

    # Link endpoints are positional node indices: article outcome j is node
    # n_registry + j. Looking nodes up by label text (labels.index) picked
    # the first node with identical text, so duplicated wording sent links to
    # the wrong node (e.g. a registry->registry self-link).
    sources = [i for i, _, _ in ordered_connections]
    targets = [n_registry + j for _, j, _ in ordered_connections]
    values = [1] * len(ordered_connections)
    connection_colors = [
        "mediumaquamarine" if cosine > cosine_threshold else "lightgray"
        for _, _, cosine in ordered_connections
    ]

    # Data shown when hovering each node (outcome).
    node_customdata = [f"from: registry<br>type: {t}" for t, _ in registry_outcomes]
    node_customdata += [
        f"from: article<br>type: {t}<br>confidence: " + str(_find_entity_score(s, raw_entities))
        for t, s in article_outcomes
    ]
    node_hovertemplate = "outcome: %{label}<br>%{customdata} <extra></extra>"
    # Data shown when hovering each link (similarity of the connected outcomes).
    link_customdata = [cosine for _, _, cosine in ordered_connections]
    link_hovertemplate = "similarity: %{customdata} <extra></extra>"

    sankey = go.Sankey(
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color="black", width=0.5),
            label=labels,
            color=colors,
            customdata=node_customdata,
            hovertemplate=node_hovertemplate
        ),
        link=dict(
            source=sources,
            target=targets,
            value=values,
            customdata=link_customdata,
            color=connection_colors,
            hovertemplate=link_hovertemplate
        )
    )
    fig = go.Figure(data=[sankey])
    fig.update_layout(
        title_text="Registry outcomes (left) connections with article outcomes (right), similarity threshold = " + str(cosine_threshold),
        font_size=10,
        width=1200,
        xaxis=dict(rangeslider=dict(visible=True),type="linear")
    )
    return fig