Spaces:
Sleeping
Sleeping
GitHub Actions
commited on
Commit
·
42d7bae
1
Parent(s):
6882062
Sync App from main repo
Browse files
app.py
CHANGED
@@ -2,29 +2,74 @@ import streamlit as st
|
|
2 |
import requests
|
3 |
import pandas as pd
|
4 |
import matplotlib.pyplot as plt
|
5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
models = {
|
8 |
-
"LSTM
|
9 |
-
"CNN
|
10 |
-
"PCA XGBoost
|
11 |
"LSTM Binary": "lstm_binary_model.h5",
|
12 |
"CNN Binary": "cnn_binary_model.h5",
|
13 |
"PCA XGBoost Binary": "pca_xgboost_binary_model.pkl",
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
# Model selection
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
-
|
20 |
-
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
df = pd.read_csv(uploaded_file)
|
23 |
-
# st.write("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
|
|
|
|
25 |
st.write("Visualized Data:")
|
26 |
fig, ax = plt.subplots(figsize=(10, 6))
|
27 |
-
df.plot(ax=ax)
|
28 |
st.pyplot(fig)
|
29 |
|
30 |
if st.button("Predict"):
|
@@ -36,11 +81,11 @@ if uploaded_file is not None:
|
|
36 |
# Call the API with the file directly
|
37 |
response = requests.post(
|
38 |
f"https://fabriciojm-hadt-api.hf.space/predict?model_name={model}",
|
39 |
-
files={"filepath_csv": (uploaded_file.name, uploaded_file, "text/csv")}
|
40 |
)
|
41 |
|
42 |
if response.status_code == 200:
|
43 |
prediction = response.json()["prediction"]
|
44 |
-
st.write(f"Prediction using {model_name}:
|
45 |
else:
|
46 |
st.error(f"Error: {response.json().get('detail', 'Unknown error')}")
|
|
|
2 |
import requests
|
3 |
import pandas as pd
|
4 |
import matplotlib.pyplot as plt
|
5 |
+
from huggingface_hub import hf_hub_download
|
6 |
+
from io import BytesIO
|
7 |
+
|
8 |
+
st.title("Heart Arrhythmia Detection Tools (hadt)")
|
9 |
+
|
10 |
+
st.markdown("""
|
11 |
+
This is a demo of the Heart Arrhythmia Detection Tools (hadt) project.
|
12 |
+
The project is available on [GitHub](https://github.com/fabriciojm/hadt).
|
13 |
+
""")
|
14 |
|
15 |
models = {
|
16 |
+
"LSTM Multiclass": "lstm_multi_model.h5",
|
17 |
+
"CNN Multiclass": "cnn_multi_model.h5",
|
18 |
+
"PCA XGBoost Multiclass": "pca_xgboost_multi_model.pkl",
|
19 |
"LSTM Binary": "lstm_binary_model.h5",
|
20 |
"CNN Binary": "cnn_binary_model.h5",
|
21 |
"PCA XGBoost Binary": "pca_xgboost_binary_model.pkl",
|
22 |
+
}
|
23 |
+
|
24 |
+
beat_labels = {
|
25 |
+
"N": "Normal",
|
26 |
+
"Q": "Unknown Beat",
|
27 |
+
"S": "Supraventricular Ectopic",
|
28 |
+
"V": "Ventricular Ectopic",
|
29 |
+
"A": "Abnormal",
|
30 |
+
}
|
31 |
|
32 |
# Model selection
|
33 |
+
classification = ["Multiclass", "Binary"]
|
34 |
+
model_list = ["LSTM", "CNN", "PCA XGBoost"]
|
35 |
+
|
36 |
+
model_selected = st.selectbox("Select a Model", model_list)
|
37 |
+
classification_selected = st.selectbox("Classification type", classification)
|
38 |
+
|
39 |
+
model_name = f"{model_selected} {classification_selected}"
|
40 |
|
41 |
+
st.markdown("""Upload a CSV file with single heartbeat (csv with 180 points) or load from available examples
|
42 |
+
""")
|
43 |
+
|
44 |
+
# Option to upload or load a file
|
45 |
+
option = st.radio("Choose input method", ("Load example file", "Upload CSV file"))
|
46 |
+
|
47 |
+
if option == "Load example file":
|
48 |
+
# Load example files from Hugging Face dataset
|
49 |
+
example_files = ["single_N.csv", "single_Q.csv", "single_S.csv", "single_V.csv"]
|
50 |
+
example_selected = st.selectbox("Select an example file", example_files)
|
51 |
+
|
52 |
+
# Load the selected example file
|
53 |
+
file_path = hf_hub_download(repo_id='fabriciojm/ecg-examples', repo_type='dataset', filename=example_selected)
|
54 |
+
with open(file_path, 'rb') as f:
|
55 |
+
file_content = f.read()
|
56 |
+
uploaded_file = BytesIO(file_content)
|
57 |
+
uploaded_file.name = example_selected # Set a name attribute to mimic the uploaded file
|
58 |
df = pd.read_csv(uploaded_file)
|
59 |
+
# st.write("Loaded Data:", df)
|
60 |
+
|
61 |
+
else:
|
62 |
+
# File uploader
|
63 |
+
uploaded_file = st.file_uploader("Upload a CSV file", type="csv")
|
64 |
+
if uploaded_file is not None:
|
65 |
+
df = pd.read_csv(uploaded_file)
|
66 |
+
# st.write("Uploaded Data:", df)
|
67 |
|
68 |
+
# Visualize data
|
69 |
+
if 'df' in locals():
|
70 |
st.write("Visualized Data:")
|
71 |
fig, ax = plt.subplots(figsize=(10, 6))
|
72 |
+
df.iloc[0].plot(ax=ax)
|
73 |
st.pyplot(fig)
|
74 |
|
75 |
if st.button("Predict"):
|
|
|
81 |
# Call the API with the file directly
|
82 |
response = requests.post(
|
83 |
f"https://fabriciojm-hadt-api.hf.space/predict?model_name={model}",
|
84 |
+
files={"filepath_csv": (uploaded_file.name, uploaded_file, "text/csv")}
|
85 |
)
|
86 |
|
87 |
if response.status_code == 200:
|
88 |
prediction = response.json()["prediction"]
|
89 |
+
st.write(f"Prediction using {model_name}: {beat_labels[prediction]} (class {prediction}) heartbeat")
|
90 |
else:
|
91 |
st.error(f"Error: {response.json().get('detail', 'Unknown error')}")
|