GitHub Actions commited on
Commit
42d7bae
·
1 Parent(s): 6882062

Sync App from main repo

Browse files
Files changed (1) hide show
  1. app.py +58 -13
app.py CHANGED
@@ -2,29 +2,74 @@ import streamlit as st
2
  import requests
3
  import pandas as pd
4
  import matplotlib.pyplot as plt
5
- st.title("Arrhythmia Detection")
 
 
 
 
 
 
 
 
6
 
7
  models = {
8
- "LSTM Multi": "lstm_multi_model.h5",
9
- "CNN Multi": "cnn_multi_model.h5",
10
- "PCA XGBoost Multi": "pca_xgboost_multi_model.pkl",
11
  "LSTM Binary": "lstm_binary_model.h5",
12
  "CNN Binary": "cnn_binary_model.h5",
13
  "PCA XGBoost Binary": "pca_xgboost_binary_model.pkl",
14
- }
 
 
 
 
 
 
 
 
15
 
16
  # Model selection
17
- model_name = st.selectbox("Select a Model", list(models.keys()))
 
 
 
 
 
 
18
 
19
- # File uploader
20
- uploaded_file = st.file_uploader("Upload a CSV file", type="csv")
21
- if uploaded_file is not None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  df = pd.read_csv(uploaded_file)
23
- # st.write("Uploaded Data:", df)
 
 
 
 
 
 
 
24
 
 
 
25
  st.write("Visualized Data:")
26
  fig, ax = plt.subplots(figsize=(10, 6))
27
- df.plot(ax=ax)
28
  st.pyplot(fig)
29
 
30
  if st.button("Predict"):
@@ -36,11 +81,11 @@ if uploaded_file is not None:
36
  # Call the API with the file directly
37
  response = requests.post(
38
  f"https://fabriciojm-hadt-api.hf.space/predict?model_name={model}",
39
- files={"filepath_csv": (uploaded_file.name, uploaded_file, "text/csv")},
40
  )
41
 
42
  if response.status_code == 200:
43
  prediction = response.json()["prediction"]
44
- st.write(f"Prediction using {model_name}:", prediction)
45
  else:
46
  st.error(f"Error: {response.json().get('detail', 'Unknown error')}")
 
2
  import requests
3
  import pandas as pd
4
  import matplotlib.pyplot as plt
5
+ from huggingface_hub import hf_hub_download
6
+ from io import BytesIO
7
+
8
+ st.title("Heart Arrhythmia Detection Tools (hadt)")
9
+
10
+ st.markdown("""
11
+ This is a demo of the Heart Arrhythmia Detection Tools (hadt) project.
12
+ The project is available on [GitHub](https://github.com/fabriciojm/hadt).
13
+ """)
14
 
15
  models = {
16
+ "LSTM Multiclass": "lstm_multi_model.h5",
17
+ "CNN Multiclass": "cnn_multi_model.h5",
18
+ "PCA XGBoost Multiclass": "pca_xgboost_multi_model.pkl",
19
  "LSTM Binary": "lstm_binary_model.h5",
20
  "CNN Binary": "cnn_binary_model.h5",
21
  "PCA XGBoost Binary": "pca_xgboost_binary_model.pkl",
22
+ }
23
+
24
+ beat_labels = {
25
+ "N": "Normal",
26
+ "Q": "Unknown Beat",
27
+ "S": "Supraventricular Ectopic",
28
+ "V": "Ventricular Ectopic",
29
+ "A": "Abnormal",
30
+ }
31
 
32
  # Model selection
33
+ classification = ["Multiclass", "Binary"]
34
+ model_list = ["LSTM", "CNN", "PCA XGBoost"]
35
+
36
+ model_selected = st.selectbox("Select a Model", model_list)
37
+ classification_selected = st.selectbox("Classification type", classification)
38
+
39
+ model_name = f"{model_selected} {classification_selected}"
40
 
41
+ st.markdown("""Upload a CSV file with single heartbeat (csv with 180 points) or load from available examples
42
+ """)
43
+
44
+ # Option to upload or load a file
45
+ option = st.radio("Choose input method", ("Load example file", "Upload CSV file"))
46
+
47
+ if option == "Load example file":
48
+ # Load example files from Hugging Face dataset
49
+ example_files = ["single_N.csv", "single_Q.csv", "single_S.csv", "single_V.csv"]
50
+ example_selected = st.selectbox("Select an example file", example_files)
51
+
52
+ # Load the selected example file
53
+ file_path = hf_hub_download(repo_id='fabriciojm/ecg-examples', repo_type='dataset', filename=example_selected)
54
+ with open(file_path, 'rb') as f:
55
+ file_content = f.read()
56
+ uploaded_file = BytesIO(file_content)
57
+ uploaded_file.name = example_selected # Set a name attribute to mimic the uploaded file
58
  df = pd.read_csv(uploaded_file)
59
+ # st.write("Loaded Data:", df)
60
+
61
+ else:
62
+ # File uploader
63
+ uploaded_file = st.file_uploader("Upload a CSV file", type="csv")
64
+ if uploaded_file is not None:
65
+ df = pd.read_csv(uploaded_file)
66
+ # st.write("Uploaded Data:", df)
67
 
68
+ # Visualize data
69
+ if 'df' in locals():
70
  st.write("Visualized Data:")
71
  fig, ax = plt.subplots(figsize=(10, 6))
72
+ df.iloc[0].plot(ax=ax)
73
  st.pyplot(fig)
74
 
75
  if st.button("Predict"):
 
81
  # Call the API with the file directly
82
  response = requests.post(
83
  f"https://fabriciojm-hadt-api.hf.space/predict?model_name={model}",
84
+ files={"filepath_csv": (uploaded_file.name, uploaded_file, "text/csv")}
85
  )
86
 
87
  if response.status_code == 200:
88
  prediction = response.json()["prediction"]
89
+ st.write(f"Prediction using {model_name}: {beat_labels[prediction]} (class {prediction}) heartbeat")
90
  else:
91
  st.error(f"Error: {response.json().get('detail', 'Unknown error')}")