yassonee commited on
Commit
005d8cf
·
verified ·
1 Parent(s): 9be43fa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -41
app.py CHANGED
@@ -1,43 +1,59 @@
1
  import streamlit as st
2
- from transformers import AutoProcessor, AutoModelForImageClassification
3
  from PIL import Image
4
- import requests
5
-
6
- # Titre de l'application
7
- st.title("RADPID: Assistant de diagnostic radiologique")
8
- st.markdown("**Chargez une radiographie et sélectionnez la tâche souhaitée :**")
9
-
10
- # Sélection des tâches
11
- task = st.radio("Sélectionnez une tâche", ["Fracture Detection", "Pneumothorax Detection", "Pneumonia Detection"])
12
-
13
- # Modèles
14
- models = {
15
- "Fracture Detection": "facebook/detr-resnet-50",
16
- "Pneumothorax Detection": "RGDancer/Pneumothorax_detection",
17
- "Pneumonia Detection": "wanghaoy/Chest_XRay_Pneumonia",
18
- }
19
-
20
- # Charger le modèle et le processeur correspondant
21
- model_name = models[task]
22
- processor = AutoProcessor.from_pretrained(model_name)
23
- model = AutoModelForImageClassification.from_pretrained(model_name)
24
-
25
- # Upload de l'image
26
- uploaded_file = st.file_uploader("Upload your Chest X-Ray image", type=["png", "jpg", "jpeg"])
27
-
28
- if uploaded_file is not None:
29
- # Charger l'image
30
- image = Image.open(uploaded_file).convert("RGB")
31
- st.image(image, caption="Image Uploadée", use_column_width=True)
32
-
33
- # Effectuer la prédiction
34
- st.markdown("### Résultat de la prédiction :")
35
- with st.spinner("Analyse en cours..."):
36
- inputs = processor(images=image, return_tensors="pt")
37
- outputs = model(**inputs)
38
- predictions = outputs.logits.softmax(dim=-1).tolist()
39
-
40
- # Afficher les scores
41
- st.write(f"Scores pour la tâche '{task}':")
42
- st.json(predictions)
43
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from transformers import pipeline
3
  from PIL import Image
4
+ import torch
5
+
6
+ st.set_page_config(page_title="Aide au diagnostic radiologique", layout="wide")
7
+
8
+ def load_models():
9
+ models = {
10
+ 'Fracture': "kathleen/vit-base-fracture-detection",
11
+ 'Pneumothorax': "nickmuchi/pneumothorax-detection-vit",
12
+ 'Pneumonie': "nickmuchi/chest-xray-pneumonia-detection"
13
+ }
14
+
15
+ loaded_models = {}
16
+ for name, model_id in models.items():
17
+ loaded_models[name] = pipeline("image-classification", model=model_id)
18
+ return loaded_models
19
+
20
+ @st.cache_resource
21
+ def get_models():
22
+ return load_models()
23
+
24
+ def main():
25
+ st.title("Assistant de diagnostic radiologique")
26
+
27
+ models = get_models()
28
+
29
+ uploaded_file = st.file_uploader("Télécharger une image radiologique", type=["jpg", "jpeg", "png"])
30
+
31
+ if uploaded_file:
32
+ image = Image.open(uploaded_file)
33
+ st.image(image, caption="Image téléchargée", use_column_width=True)
34
+
35
+ col1, col2, col3 = st.columns(3)
36
+
37
+ with col1:
38
+ if st.button("Détecter Fracture"):
39
+ with st.spinner("Analyse en cours..."):
40
+ result = models['Fracture'](image)
41
+ st.write(f"Résultat: {result[0]['label']}")
42
+ st.write(f"Confiance: {result[0]['score']:.2%}")
43
+
44
+ with col2:
45
+ if st.button("Détecter Pneumothorax"):
46
+ with st.spinner("Analyse en cours..."):
47
+ result = models['Pneumothorax'](image)
48
+ st.write(f"Résultat: {result[0]['label']}")
49
+ st.write(f"Confiance: {result[0]['score']:.2%}")
50
+
51
+ with col3:
52
+ if st.button("Détecter Pneumonie"):
53
+ with st.spinner("Analyse en cours..."):
54
+ result = models['Pneumonie'](image)
55
+ st.write(f"Résultat: {result[0]['label']}")
56
+ st.write(f"Confiance: {result[0]['score']:.2%}")
57
+
58
+ if __name__ == "__main__":
59
+ main()