Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	File size: 5,834 Bytes
			
			| 5dc2016 ddff90b 7ce5b82 ddff90b 7ce5b82 ddff90b 3279179 4a37eb1 ddff90b 0416a61 bda3587 445b26a 8bf2961 f2852e3 44803cb e7caceb 6d2e57c e7caceb 6d2e57c e7caceb 4a37eb1 e7caceb 4a37eb1 e7caceb 6d2e57c e7caceb 6d2e57c f2852e3 38efeba 847adc5 a224dfa 0416a61 f2852e3 ddff90b cde5ff7 b102419 2360c00 388fbdd 81c44d2 971a385 81c44d2 8f768aa 4e45f70 31ca6c1 7780086 091df08 7780086 d266b61 0d9531e b28ab8e 0d9531e b28ab8e 062e24e f435314 e1cbd0e b28ab8e f435314 b28ab8e 062e24e b28ab8e b424a32 0d9531e 8f768aa b102419 80b9099 b102419 a8b6710 33ca54e a8b6710 f28ad72 a8b6710 7ead1f4 6d2e57c 8e409e1 091df08 c6171a2 8e409e1 091df08 8e409e1 062e24e 4a37eb1 83629fc f2852e3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 | import nltk
nltk.download('stopwords')
nltk.download('punkt')
import pandas as pd
import classify_abs
import extract_abs
#pd.set_option('display.max_colwidth', None)
import streamlit as st
import spacy
import tensorflow as tf
import pickle
import plotly.graph_objects as go
# ---------- Page header ----------
_NCATS_LOGO_HTML = '''<img src="https://huggingface.co/spaces/ncats/EpiPipeline4GARD/resolve/main/NCATS_logo.png" alt="National Center for Advancing Translational Sciences Logo" width=550>'''
st.markdown(_NCATS_LOGO_HTML, unsafe_allow_html=True)
st.title("Epidemiology Extraction Pipeline for Rare Diseases")

# ---------- Fixed-width sidebar ----------
# Pin the sidebar to 250px when expanded; when collapsed, pull it left by
# 350px (more than its own width) so it is pushed out of view.
_SIDEBAR_CSS = """
    <style>
    [data-testid="stSidebar"][aria-expanded="true"] > div:first-child {
        width: 250px;
    }
    [data-testid="stSidebar"][aria-expanded="false"] > div:first-child {
        width: 250px;
        margin-left: -350px;
    }
    </style>
    """
st.markdown(_SIDEBAR_CSS, unsafe_allow_html=True)

# ---------- Search options (sidebar) ----------
# Upper bound on PubMed IDs retrieved *before* any relevance filtering runs.
max_results = st.sidebar.number_input(
    "Maximum number of articles to find in PubMed",
    min_value=1,
    max_value=None,
    value=50,
)
# The pipeline expects the lower-cased option name: 'strict', 'lenient' or 'none'.
filtering_choice = st.sidebar.radio(
    "What type of filtering would you like?",
    ('Strict', 'Lenient', 'None'),
)
filtering = filtering_choice.lower()
extract_diseases = st.sidebar.checkbox("Extract Rare Diseases", value=False)
@st.experimental_singleton(show_spinner=False)
def load_models_experimental():
    """Build every model/resource the extraction pipeline needs.

    Wrapped in ``st.experimental_singleton`` so the objects are created once
    and reused across reruns and sessions instead of being reloaded on every
    interaction. The decorator's own spinner is disabled because the caller
    wraps this call in an ``st.spinner`` block.

    Returns:
        A 5-tuple ``(classify_model_vars, NER_pipeline, entity_classes,
        GARD_dict, max_length)``. ``classify_model_vars`` is whatever
        ``classify_abs.init_classify_model()`` returns; it is passed through
        unchanged to ``extract_abs.streamlit_extraction``. ``max_length``
        comes from ``extract_abs.load_GARD_diseases()``. Presumably it is the
        longest GARD disease-name length; confirm in extract_abs.
    """
    # Keep the original load order: classifier, then NER, then GARD lookup.
    classifier_bundle = classify_abs.init_classify_model()
    ner_pipeline, ner_classes = extract_abs.init_NER_pipeline()
    gard_lookup, gard_max_len = extract_abs.load_GARD_diseases()
    return (classifier_bundle, ner_pipeline, ner_classes, gard_lookup, gard_max_len)
@st.cache(allow_output_mutation=True)
def load_models():
    """Alternative loader that reads the classifier pieces from local paths.

    Unlike ``load_models_experimental``, this does not call
    ``classify_abs.init_classify_model()``. Instead it loads the tokenizer
    from ``tokenizer.pickle`` and the Keras model from ``LSTM_RNN_Model``
    directly. The main flow of this script does not call it.

    Returns:
        ``(classify_tokenizer, classify_model, NER_pipeline, entity_classes,
        GARD_dict, max_length)``.
    """
    # NOTE(review): pickle.load can execute arbitrary code. This is acceptable
    # only because tokenizer.pickle is presumably a file bundled with the app.
    # Never point this at user-supplied data.
    with open('tokenizer.pickle', 'rb') as tokenizer_file:
        tokenizer = pickle.load(tokenizer_file)

    lstm_model = tf.keras.models.load_model("LSTM_RNN_Model")

    ner_pipeline, ner_classes = extract_abs.init_NER_pipeline()
    gard_lookup, gard_max_len = extract_abs.load_GARD_diseases()
    return tokenizer, lstm_model, ner_pipeline, ner_classes, gard_lookup, gard_max_len
@st.cache
def convert_df(df):
    """Serialize *df* to UTF-8 encoded CSV bytes for ``st.download_button``.

    Cached with ``st.cache`` so Streamlit reruns with the same DataFrame do
    not redo the conversion.
    """
    csv_text = df.to_csv()
    return csv_text.encode('utf-8')
    
@st.cache(allow_output_mutation=True)
def epi_sankey(sankey_data,disease_or_gard_id):
    """Build a Plotly Sankey diagram of how search results were filtered.

    Args:
        sankey_data: ``(gathered, relevant, epidemiologic)`` counts. Presumably
            these are non-negative ints with
            gathered >= relevant >= epidemiologic (otherwise some link values
            go negative); confirm against extract_abs.streamlit_extraction.
        disease_or_gard_id: the search term, shown in the figure title.

    Returns:
        A ``plotly.graph_objects.Figure``. ``allow_output_mutation=True``
        tells ``st.cache`` not to hash or check the returned figure for
        mutation.
    """
    gathered, relevant, epidemiologic = sankey_data

    # Node indices:
    #   0 = all PubMed IDs, 1 = irrelevant, 2 = relevant,
    #   3 = epidemiologic, 4 = not epidemiologic.
    node_labels = [
        "PubMed IDs Gathered",
        "Irrelevant Abstracts",
        "Relevant Abstracts Gathered",
        "Epidemiologic Abstracts",
        "Not Epidemiologic",
    ]
    # Each flow is (source node, target node, abstract count).
    flows = [
        (0, 2, relevant),
        (0, 1, gathered - relevant),
        (2, 3, epidemiologic),
        (2, 4, relevant - epidemiologic),
    ]
    sources, targets, values = map(list, zip(*flows))

    sankey = go.Sankey(
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color="white", width=0.5),
            label=node_labels,
            color="purple",
        ),
        link=dict(source=sources, target=targets, value=values),
    )
    fig = go.Figure(data=[sankey])
    fig.update_layout(
        hovermode='x',
        title="Search for the Epidemiology of "+disease_or_gard_id,
        font=dict(size=10, color='black'),
    )
    return fig
with st.spinner('Loading Epidemiology Models and Dependencies...'):
    # load_models_experimental is a singleton, so after the first run this
    # returns the already-built objects instead of reloading them.
    classify_model_vars, NER_pipeline, entity_classes, GARD_dict, max_length = load_models_experimental()

# Briefly confirm loading finished, then clear the message once the search box
# has been rendered so it does not linger on the page.
loaded = st.success('All Models and Dependencies Loaded!')
# Strip surrounding whitespace so that (a) whitespace-only input does not
# launch a search for a blank term, and (b) stray spaces do not leak into the
# PubMed query, the download label or the CSV filename.
disease_or_gard_id = st.text_input("Input a rare disease term or GARD ID.").strip()
loaded.empty()

st.markdown("Examples of rare diseases include [**Fellman syndrome**](https://rarediseases.info.nih.gov/diseases/1/gracile-syndrome), [**Classic Homocystinuria**](https://rarediseases.info.nih.gov/diseases/6667/classic-homocystinuria) and [**Phenylketonuria**](https://rarediseases.info.nih.gov/diseases/7383/phenylketonuria).")
st.markdown("A full list of rare diseases tracked by GARD can be found [here](https://rarediseases.info.nih.gov/diseases/browse-by-first-letter).")

if disease_or_gard_id:
    df, sankey_data = extract_abs.streamlit_extraction(
        disease_or_gard_id, max_results, filtering,
        NER_pipeline, entity_classes,
        extract_diseases, GARD_dict, max_length,
        classify_model_vars,
    )
    st.dataframe(df, height=100)

    csv = convert_df(df)
    st.download_button(
        label="Download epidemiology results for "+disease_or_gard_id+" as CSV",
        data=csv,
        file_name=disease_or_gard_id+'.csv',
        mime='text/csv',
    )

    # Visual summary of how many abstracts survived each filtering stage.
    fig = epi_sankey(sankey_data, disease_or_gard_id)
    st.plotly_chart(fig, use_container_width=True)
