File size: 3,132 Bytes
7b0aaea
 
 
b5b39b7
42d7bae
 
 
 
 
 
 
 
 
7b0aaea
b5b39b7
42d7bae
 
 
b5b39b7
 
 
42d7bae
 
 
 
 
 
 
 
 
7b0aaea
 
42d7bae
 
 
 
 
 
 
7b0aaea
42d7bae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b0aaea
42d7bae
 
 
 
 
 
 
 
b5b39b7
42d7bae
 
b5b39b7
 
42d7bae
b5b39b7
7b0aaea
 
 
 
b5b39b7
 
 
 
7b0aaea
b5b39b7
42d7bae
7b0aaea
 
 
 
42d7bae
7b0aaea
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import streamlit as st
import requests
import pandas as pd
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download
from io import BytesIO

st.title("Heart Arrhythmia Detection Tools (hadt)")

# Intro paragraph rendered right under the page title.
_intro_markdown = (
    "\n"
    "This is a demo of the Heart Arrhythmia Detection Tools (hadt) project.\n"
    "The project is available on [GitHub](https://github.com/fabriciojm/hadt).\n"
)
st.markdown(_intro_markdown)

# Choices offered by the two selectboxes below.
classification = ["Multiclass", "Binary"]
model_list = ["LSTM", "CNN", "PCA XGBoost"]

# Each "<family> <task>" label maps to the artifact name sent to the prediction
# API as `model_name`, following the pattern "<prefix>_<task>_model.<ext>".
_family_artifacts = {
    "LSTM": ("lstm", "h5"),
    "CNN": ("cnn", "h5"),
    "PCA XGBoost": ("pca_xgboost", "pkl"),
}
_task_suffixes = {"Multiclass": "multi", "Binary": "binary"}
models = {
    f"{family} {task}": f"{prefix}_{_task_suffixes[task]}_model.{ext}"
    for task in classification
    for family, (prefix, ext) in _family_artifacts.items()
}

# Human-readable names for the class codes returned by the prediction API.
beat_labels = dict(
    N="Normal",
    Q="Unknown Beat",
    S="Supraventricular Ectopic",
    V="Ventricular Ectopic",
    A="Abnormal",
)

# Model selection: the two choices combine into a key of `models`.
model_selected = st.selectbox("Select a Model", model_list)
classification_selected = st.selectbox("Classification type", classification)
model_name = f"{model_selected} {classification_selected}"

st.markdown(
    "Upload a CSV file with single heartbeat (csv with 180 points) "
    "or load from available examples \n"
)

# Option to upload or load a file
option = st.radio("Choose input method", ("Load example file", "Upload CSV file"))

# Start from explicit "nothing loaded" values so the section below can test
# `df is not None` instead of probing locals() for a name that may not exist.
uploaded_file = None
df = None

if option == "Load example file":
    # Load example files from Hugging Face dataset
    example_files = ["single_N.csv", "single_Q.csv", "single_S.csv", "single_V.csv"]
    example_selected = st.selectbox("Select an example file", example_files)

    # Fetch the selected example from the Hugging Face dataset repo and wrap its
    # bytes in a buffer so it can be handled exactly like an uploaded file.
    file_path = hf_hub_download(repo_id='fabriciojm/ecg-examples', repo_type='dataset', filename=example_selected)
    with open(file_path, 'rb') as f:
        uploaded_file = BytesIO(f.read())
    uploaded_file.name = example_selected  # mimic the .name attribute of an uploaded file
else:
    # File uploader; the result may be None when nothing has been uploaded yet.
    uploaded_file = st.file_uploader("Upload a CSV file", type="csv")

if uploaded_file is not None:
    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.ParserError, pd.errors.EmptyDataError, UnicodeDecodeError) as err:
        # Report unreadable input in the UI instead of crashing the script.
        st.error(f"Could not read the CSV file: {err}")

# Prediction API endpoint and request timeout in seconds.
# NOTE(review): the timeout value is a guess; tune it to the API's real latency.
PREDICT_URL = "https://fabriciojm-hadt-api.hf.space/predict"
PREDICT_TIMEOUT_S = 120

# Visualize data
if df is not None and df.empty:
    # pd.read_csv treats the first row as a header, so a file holding only a
    # single headerless row ends up with no data rows to plot or classify.
    st.error("The CSV file contains no data rows (the first row is read as a header).")
elif df is not None:
    st.write("Visualized Data:")
    fig, ax = plt.subplots(figsize=(10, 6))
    df.iloc[0].plot(ax=ax)  # only the first row is plotted
    st.pyplot(fig)
    plt.close(fig)  # release the figure so pyplot does not keep accumulating open figures

    if st.button("Predict"):
        model = models[model_name]

        # pd.read_csv consumed the buffer above; rewind before re-sending it.
        uploaded_file.seek(0)

        try:
            response = requests.post(
                PREDICT_URL,
                params={"model_name": model},  # let requests URL-encode the query string
                files={"filepath_csv": (uploaded_file.name, uploaded_file, "text/csv")},
                timeout=PREDICT_TIMEOUT_S,  # never block the app indefinitely
            )
        except requests.RequestException as err:
            st.error(f"Error: could not reach the prediction API ({err})")
        else:
            if response.status_code == 200:
                prediction = response.json()["prediction"]
                # Tolerate class codes missing from beat_labels instead of raising KeyError.
                label = beat_labels.get(prediction, "Unrecognized")
                st.write(f"Prediction using {model_name}: {label} (class {prediction}) heartbeat")
            else:
                try:
                    body = response.json()
                except ValueError:
                    # Error body is not JSON (e.g. an HTML error page).
                    body = None
                if isinstance(body, dict):
                    detail = body.get('detail', 'Unknown error')
                else:
                    detail = f"HTTP {response.status_code}"
                st.error(f"Error: {detail}")