yassonee commited on
Commit
f0f1078
·
verified ·
1 Parent(s): 5219a6e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -54
app.py CHANGED
@@ -1,73 +1,62 @@
1
  import streamlit as st
2
- from transformers import pipeline, AutoImageProcessor, AutoModelForImageClassification
3
  from PIL import Image
4
- import torch
5
  import numpy as np
6
  import cv2
7
 
8
- st.set_page_config(page_title="Détection de fractures osseuses par rayons X")
9
 
10
  st.title("Détection de fractures osseuses par rayons X")
11
 
12
  @st.cache_resource
13
- def load_models():
14
- processor = AutoImageProcessor.from_pretrained("Heem2/bone-fracture-detection-using-xray")
15
- model = AutoModelForImageClassification.from_pretrained("Heem2/bone-fracture-detection-using-xray")
16
- return processor, model
17
 
18
- processor, model = load_models()
19
-
20
- def generate_heatmap(image, model, processor):
21
- # Préparer l'image
22
- inputs = processor(images=image, return_tensors="pt")
23
-
24
- # Obtenir les activations
25
- with torch.no_grad():
26
- outputs = model(**inputs)
27
- # Utiliser les dernières activations
28
- features = model.classifier.weight.data
29
-
30
- # Créer la carte de chaleur
31
- cam = torch.matmul(outputs.logits, features)
32
- cam = cam.reshape(7, 7) # Ajuster selon la taille de votre modèle
33
- cam = cam.detach().numpy()
34
-
35
- # Normaliser
36
- cam = (cam - cam.min()) / (cam.max() - cam.min())
37
-
38
- # Redimensionner à la taille de l'image
39
- cam = cv2.resize(cam, (image.size[0], image.size[1]))
40
- heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
41
- heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
42
-
43
- return heatmap
44
 
45
  uploaded_file = st.file_uploader("Téléchargez une image radiographique", type=["jpg", "jpeg", "png"])
46
 
47
  if uploaded_file:
48
- # Afficher l'image originale
49
  image = Image.open(uploaded_file)
50
- st.image(image, caption="Image originale", use_column_width=True)
 
 
 
 
51
 
52
- # Prédiction
53
- pipe = pipeline("image-classification", model=model, feature_extractor=processor)
54
- results = pipe(image)
55
 
56
- # Afficher les résultats
57
- st.subheader("Résultats de l'analyse")
58
- for result in results:
59
- confidence = result['score'] * 100
60
- label = "Fracture détectée" if result['label'] == "FRACTURE" else "Pas de fracture"
61
- st.write(f"{label} (Confiance: {confidence:.2f}%)")
62
-
63
- # Barre de confiance colorée
64
- color = "red" if label == "Fracture détectée" else "green"
65
- st.progress(result['score'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
- # Générer et afficher la carte de chaleur si fracture détectée
68
- if label == "Fracture détectée":
69
- st.subheader("Localisation probable de la fracture")
70
- heatmap = generate_heatmap(image, model, processor)
71
- st.image(heatmap, caption="Carte de chaleur de la fracture", use_column_width=True)
72
  else:
73
- st.write("Veuillez télécharger une image radiographique pour l'analyse.")
 
1
  import streamlit as st
2
+ from transformers import pipeline
3
  from PIL import Image
 
4
  import numpy as np
5
  import cv2
6
 
7
+ st.set_page_config(page_title="Détection de fractures osseuses")
8
 
9
  st.title("Détection de fractures osseuses par rayons X")
10
 
11
  @st.cache_resource
12
+ def load_model():
13
+ return pipeline("image-classification", model="Heem2/bone-fracture-detection-using-xray")
 
 
14
 
15
+ model = load_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  uploaded_file = st.file_uploader("Téléchargez une image radiographique", type=["jpg", "jpeg", "png"])
18
 
19
  if uploaded_file:
20
+ # Load and resize image
21
  image = Image.open(uploaded_file)
22
+ # Resize to max 800px width while maintaining aspect ratio
23
+ if image.size[0] > 800:
24
+ ratio = 800.0 / image.size[0]
25
+ size = (800, int(image.size[1] * ratio))
26
+ image = image.resize(size, Image.Resampling.LANCZOS)
27
 
28
+ # Convert to array for overlay
29
+ image_array = np.array(image)
 
30
 
31
+ # Make prediction
32
+ result = model(image)[0] # Get only top prediction
33
+
34
+ # Create columns for side by side display
35
+ col1, col2 = st.columns(2)
36
+
37
+ with col1:
38
+ st.image(image, caption="Image originale", use_container_width=True)
39
+
40
+ with col2:
41
+ # Create colored overlay based on prediction
42
+ overlay = np.zeros_like(image_array)
43
+ if result['label'] == "FRACTURE":
44
+ overlay[..., 0] = 255 # Red tint for fracture
45
+ alpha = 0.3
46
+ else:
47
+ overlay[..., 1] = 255 # Green tint for normal
48
+ alpha = 0.2
49
+
50
+ # Blend images
51
+ output = cv2.addWeighted(image_array, 1, overlay, alpha, 0)
52
+ st.image(output, caption="Image analysée", use_container_width=True)
53
+
54
+ # Display result
55
+ st.subheader("Résultat")
56
+ if result['label'] == "FRACTURE":
57
+ st.error(f"⚠️ Fracture détectée (Confiance: {result['score']*100:.1f}%)")
58
+ else:
59
+ st.success(f"✅ Pas de fracture détectée (Confiance: {result['score']*100:.1f}%)")
60
 
 
 
 
 
 
61
  else:
62
+ st.info("Veuillez télécharger une image radiographique pour l'analyse.")