yassonee commited on
Commit
0000f4a
·
verified ·
1 Parent(s): aa66005

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -23
app.py CHANGED
@@ -1,33 +1,73 @@
1
  import streamlit as st
2
- from transformers import pipeline
3
  from PIL import Image
 
 
 
 
 
 
 
4
 
5
  @st.cache_resource
6
- def load_model():
7
- return pipeline("image-classification", model="mrm8488/vit-base-patch16-224_finetuned-pneumothorax", top_k=2)
 
 
 
 
8
 
9
- def main():
10
- st.title("Détection de Pneumothorax")
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- model = load_model()
 
13
 
14
- uploaded_file = st.file_uploader("Télécharger une radiographie", type=["jpg", "jpeg", "png"])
 
 
 
15
 
16
- if uploaded_file:
17
- image = Image.open(uploaded_file).convert('RGB')
18
- resized_image = image.resize((224, 224))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- col1, col2 = st.columns(2)
21
- with col1:
22
- st.image(resized_image, width=300, caption="Image originale")
23
 
24
- if st.button("Analyser"):
25
- with st.spinner("Analyse en cours..."):
26
- results = model(resized_image)
27
- with col2:
28
- st.write("Résultats:")
29
- for result in results:
30
- st.write(f"{result['label']}: {result['score']:.2%}")
31
-
32
- if __name__ == "__main__":
33
- main()
 
1
  import streamlit as st
2
+ from transformers import pipeline, AutoImageProcessor, AutoModelForImageClassification
3
  from PIL import Image
4
+ import torch
5
+ import numpy as np
6
+ import cv2
7
+
8
+ st.set_page_config(page_title="Détection de fractures osseuses par rayons X")
9
+
10
+ st.title("Détection de fractures osseuses par rayons X")
11
 
12
  @st.cache_resource
13
+ def load_models():
14
+ processor = AutoImageProcessor.from_pretrained("Heem2/bone-fracture-detection-using-xray")
15
+ model = AutoModelForImageClassification.from_pretrained("Heem2/bone-fracture-detection-using-xray")
16
+ return processor, model
17
+
18
+ processor, model = load_models()
19
 
20
+ def generate_heatmap(image, model, processor):
21
+ # Préparer l'image
22
+ inputs = processor(images=image, return_tensors="pt")
23
+
24
+ # Obtenir les activations
25
+ with torch.no_grad():
26
+ outputs = model(**inputs)
27
+ # Utiliser les dernières activations
28
+ features = model.classifier.weight.data
29
+
30
+ # Créer la carte de chaleur
31
+ cam = torch.matmul(outputs.logits, features)
32
+ cam = cam.reshape(7, 7) # Ajuster selon la taille de votre modèle
33
+ cam = cam.detach().numpy()
34
 
35
+ # Normaliser
36
+ cam = (cam - cam.min()) / (cam.max() - cam.min())
37
 
38
+ # Redimensionner à la taille de l'image
39
+ cam = cv2.resize(cam, (image.size[0], image.size[1]))
40
+ heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
41
+ heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
42
 
43
+ return heatmap
44
+
45
+ uploaded_file = st.file_uploader("Téléchargez une image radiographique", type=["jpg", "jpeg", "png"])
46
+
47
+ if uploaded_file:
48
+ # Afficher l'image originale
49
+ image = Image.open(uploaded_file)
50
+ st.image(image, caption="Image originale", use_column_width=True)
51
+
52
+ # Prédiction
53
+ pipe = pipeline("image-classification", model=model, feature_extractor=processor)
54
+ results = pipe(image)
55
+
56
+ # Afficher les résultats
57
+ st.subheader("Résultats de l'analyse")
58
+ for result in results:
59
+ confidence = result['score'] * 100
60
+ label = "Fracture détectée" if result['label'] == "FRACTURE" else "Pas de fracture"
61
+ st.write(f"{label} (Confiance: {confidence:.2f}%)")
62
 
63
+ # Barre de confiance colorée
64
+ color = "red" if label == "Fracture détectée" else "green"
65
+ st.progress(result['score'])
66
 
67
+ # Générer et afficher la carte de chaleur si fracture détectée
68
+ if label == "Fracture détectée":
69
+ st.subheader("Localisation probable de la fracture")
70
+ heatmap = generate_heatmap(image, model, processor)
71
+ st.image(heatmap, caption="Carte de chaleur de la fracture", use_column_width=True)
72
+ else:
73
+ st.write("Veuillez télécharger une image radiographique pour l'analyse.")