Spaces:
Sleeping
Sleeping
File size: 4,844 Bytes
7b0aaea b5b39b7 42d7bae 7b0aaea b5b39b7 42d7bae b5b39b7 42d7bae 7b0aaea 42d7bae 7b0aaea e7181fd 42d7bae e7181fd 42d7bae 7b0aaea 42d7bae e7181fd 42d7bae aaf5034 42d7bae aaf5034 e7181fd aaf5034 e7181fd aaf5034 42d7bae b5b39b7 7b0aaea e7181fd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
import streamlit as st
import requests
import pandas as pd
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download
from io import BytesIO
# Page header: the app title followed by a short introduction that links to
# the project's GitHub repository.
st.title("Heart Arrhythmia Detection Tools (hadt)")
st.markdown(
    "\n"
    "This is a demo of the Heart Arrhythmia Detection Tools (hadt) project.\n"
    "The project is available on [GitHub](https://github.com/fabriciojm/hadt).\n"
)
# UI label "<architecture> <classification>" -> model file name sent to the
# prediction API as its ``model_name`` query parameter. All multiclass models
# come first, then the binary ones, each in LSTM / CNN / PCA XGBoost order.
models = {
    f"{arch} {task}": f"{stem}_{suffix}_model.{ext}"
    for task, suffix in (("Multiclass", "multi"), ("Binary", "binary"))
    for arch, stem, ext in (
        ("LSTM", "lstm", "h5"),
        ("CNN", "cnn", "h5"),
        ("PCA XGBoost", "pca_xgboost", "pkl"),
    )
}

# Class code returned by the API -> human-readable beat description.
beat_labels = dict(
    N="Normal",
    Q="Unknown Beat",
    S="Supraventricular Ectopic",
    V="Ventricular Ectopic",
    A="Abnormal",
)
# Model selection: the chosen architecture and classification type are joined
# with a space to form a key of ``models`` (e.g. "CNN Binary").
model_list = ["LSTM", "CNN", "PCA XGBoost"]
classification = ["Multiclass", "Binary"]
model_selected = st.selectbox(label="Select a Model", options=model_list)
classification_selected = st.selectbox(label="Classification type", options=classification)
model_name = " ".join([model_selected, classification_selected])
def visualize_single(df, st):
    """Plot the first row of ``df`` as a single heartbeat trace.

    Args:
        df: DataFrame whose first row holds the signal samples; the row is
            plotted as a line with the column labels on the x-axis.
        st: the Streamlit module (or any object exposing ``write``,
            ``warning`` and ``pyplot``) used to render the output.
    """
    if df.empty:
        # ``df.iloc[0]`` would raise IndexError on an empty / header-only CSV.
        st.warning("The CSV file contains no rows to visualize.")
        return
    st.write("Visualized Data:")
    fig, ax = plt.subplots(figsize=(10, 6))
    df.iloc[0].plot(ax=ax)
    st.pyplot(fig)
    # Streamlit re-executes the script on every interaction; closing the figure
    # keeps each run from leaving another open figure in pyplot's global state.
    plt.close(fig)
# This function will be used when the API is capable of returning extracted beats
# def visualize_multiple(beats, st):
#     st.write("Visualized Data:")
#     if len(beats) % 4 != 0:
#         nrows = len(beats) // 4 + 1
#     else:
#         nrows = len(beats) // 4
#     fig, axs = plt.subplots(nrows, 4, figsize=(10, nrows*2.5))
#     for i, beat in enumerate(beats):
#         axs.flatten()[i].plot(beat)
#     # delete the unused trailing plots: the grid has nrows*4 slots, so
#     # (-len(beats)) % 4 of them are empty (len(beats) % 4 deleted the wrong
#     # number, e.g. 3 axes for 7 beats, removing filled plots)
#     for j in range(-len(beats) % 4):
#         fig.delaxes(axs.flatten()[-j-1])
#     st.pyplot(fig)
st.markdown("""Upload a CSV file with single heartbeat (csv with 180 points) or load from available examples
""")

# Option to upload or load a file
option = st.radio("Choose input method", ("Load example file", "Upload CSV file", "Upload Apple Watch ECG CSV file (EXPERIMENTAL)"))

# Always bind both names so later code can test them with ``is None`` rather
# than probing ``locals()``; they stay None until a file has been provided.
uploaded_file = None
df = None

if option == "Load example file":
    # Load example files from Hugging Face dataset
    example_files = ["single_N.csv", "single_Q.csv", "single_S.csv", "single_V.csv"]
    example_selected = st.selectbox("Select an example file", example_files)
    # Load the selected example file
    file_path = hf_hub_download(repo_id='fabriciojm/ecg-examples', repo_type='dataset', filename=example_selected)
    with open(file_path, 'rb') as f:
        uploaded_file = BytesIO(f.read())
    uploaded_file.name = example_selected  # Set a name attribute to mimic the uploaded file
elif option == "Upload CSV file":
    # File uploader
    uploaded_file = st.file_uploader("Upload a CSV file", type="csv")
    st.write("The CSV file should have 180 points per row, following the format in [the examples](https://huggingface.co/datasets/fabriciojm/ecg-examples)")
elif option == "Upload Apple Watch ECG CSV file (EXPERIMENTAL)":
    # File uploader
    st.write("DISCLAIMER: this is an experimental feature, and the results may not be accurate. This should not be used as professional medical advice.")
    uploaded_file = st.file_uploader("Upload a CSV file", type="csv")
    st.write("The Apple Watch CSV file should have the same format as [the examples](https://huggingface.co/datasets/fabriciojm/apple-ecg-examples)")

if uploaded_file is not None:
    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.ParserError, pd.errors.EmptyDataError, UnicodeDecodeError) as err:
        # Report malformed input in the UI instead of aborting the whole script
        # run with a traceback; ``df`` stays None so nothing is plotted.
        st.error(f"Could not read the CSV file: {err}")

# As before, Apple Watch uploads are not plotted here (visualize_single draws
# only the first row of the DataFrame).
if df is not None and option != "Upload Apple Watch ECG CSV file (EXPERIMENTAL)":
    visualize_single(df, st)
# Timeout (seconds) for the prediction request. Deliberately generous because
# the hosted API may be slow to answer (assumption: e.g. while it wakes up),
# but finite so an unreachable endpoint cannot hang the app indefinitely.
PREDICT_TIMEOUT_S = 300


def _api_error_detail(response):
    """Return a readable error message for a non-200 API response.

    Uses the ``detail`` field when the body is a JSON object; otherwise, e.g.
    for an HTML error page, falls back to a generic message with the status.
    """
    try:
        body = response.json()
    except ValueError:  # requests' JSON decoding errors subclass ValueError
        body = None
    if isinstance(body, dict):
        return body.get("detail", "Unknown error")
    return f"Unknown error (HTTP {response.status_code})"


if st.button("Predict"):
    if uploaded_file is None:
        # Previously this fell through to ``uploaded_file.seek(0)`` and crashed
        # with AttributeError when no file had been uploaded yet.
        st.warning("Please upload a CSV file or load an example before predicting.")
    else:
        model = models[model_name]
        # pd.read_csv consumed the buffer; rewind so the whole file is sent.
        uploaded_file.seek(0)
        # Call the API with the file directly
        base_url = "https://fabriciojm-hadt-api.hf.space/predict"
        if option == "Upload Apple Watch ECG CSV file (EXPERIMENTAL)":
            base_url += "_multibeats"
        print(f"Request url: {base_url}?model_name={model}")
        try:
            response = requests.post(
                base_url,
                params={"model_name": model},
                files={"filepath_csv": (uploaded_file.name, uploaded_file, "text/csv")},
                timeout=PREDICT_TIMEOUT_S,
            )
        except requests.exceptions.RequestException as err:
            st.error(f"Error: could not reach the prediction API ({err})")
        else:
            if response.status_code == 200:
                prediction = response.json()["prediction"]
                st.write(f"Prediction using {model_name}:")
                for i, p in enumerate(prediction, start=1):
                    # Fall back to a generic label for class codes this app
                    # does not know, instead of raising KeyError.
                    st.write(f"Beat {i}: {beat_labels.get(p, 'Unknown label')} (class {p})")
            else:
                st.error(f"Error: {_api_error_detail(response)}")
|