radpid / app.py
yassonee's picture
Update app.py
0000f4a verified
raw
history blame
2.75 kB
import streamlit as st
from transformers import pipeline, AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import torch
import numpy as np
import cv2
st.set_page_config(page_title="Détection de fractures osseuses par rayons X")
st.title("Détection de fractures osseuses par rayons X")
@st.cache_resource
def load_models():
processor = AutoImageProcessor.from_pretrained("Heem2/bone-fracture-detection-using-xray")
model = AutoModelForImageClassification.from_pretrained("Heem2/bone-fracture-detection-using-xray")
return processor, model
processor, model = load_models()
def generate_heatmap(image, model, processor):
# Préparer l'image
inputs = processor(images=image, return_tensors="pt")
# Obtenir les activations
with torch.no_grad():
outputs = model(**inputs)
# Utiliser les dernières activations
features = model.classifier.weight.data
# Créer la carte de chaleur
cam = torch.matmul(outputs.logits, features)
cam = cam.reshape(7, 7) # Ajuster selon la taille de votre modèle
cam = cam.detach().numpy()
# Normaliser
cam = (cam - cam.min()) / (cam.max() - cam.min())
# Redimensionner à la taille de l'image
cam = cv2.resize(cam, (image.size[0], image.size[1]))
heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
return heatmap
uploaded_file = st.file_uploader("Téléchargez une image radiographique", type=["jpg", "jpeg", "png"])
if uploaded_file:
# Afficher l'image originale
image = Image.open(uploaded_file)
st.image(image, caption="Image originale", use_column_width=True)
# Prédiction
pipe = pipeline("image-classification", model=model, feature_extractor=processor)
results = pipe(image)
# Afficher les résultats
st.subheader("Résultats de l'analyse")
for result in results:
confidence = result['score'] * 100
label = "Fracture détectée" if result['label'] == "FRACTURE" else "Pas de fracture"
st.write(f"{label} (Confiance: {confidence:.2f}%)")
# Barre de confiance colorée
color = "red" if label == "Fracture détectée" else "green"
st.progress(result['score'])
# Générer et afficher la carte de chaleur si fracture détectée
if label == "Fracture détectée":
st.subheader("Localisation probable de la fracture")
heatmap = generate_heatmap(image, model, processor)
st.image(heatmap, caption="Carte de chaleur de la fracture", use_column_width=True)
else:
st.write("Veuillez télécharger une image radiographique pour l'analyse.")