|
|
import streamlit as st |
|
|
from transformers import pipeline, AutoImageProcessor, AutoModelForImageClassification |
|
|
from PIL import Image |
|
|
import torch |
|
|
import numpy as np |
|
|
import cv2 |
|
|
|
|
|
# Single source of truth for the app title, shown both in the browser tab and
# as the page heading.
_APP_TITLE = "Détection de fractures osseuses par rayons X"

# Keep set_page_config ahead of every other st.* call; older Streamlit
# releases reject it otherwise.
st.set_page_config(page_title=_APP_TITLE)
st.title(_APP_TITLE)
|
|
|
|
|
@st.cache_resource
def load_models():
    """Load the bone-fracture image processor and classifier from the Hub.

    Wrapped in ``st.cache_resource`` so Streamlit keeps the returned objects
    across script reruns instead of re-instantiating them on every interaction.

    Returns:
        ``(processor, model)`` — the ``AutoImageProcessor`` and the
        ``AutoModelForImageClassification`` for the same checkpoint.
    """
    checkpoint = "Heem2/bone-fracture-detection-using-xray"
    image_processor = AutoImageProcessor.from_pretrained(checkpoint)
    classifier = AutoModelForImageClassification.from_pretrained(checkpoint)
    return image_processor, classifier


# Module-level handles used by the rest of the script.
processor, model = load_models()
|
|
|
|
|
def _class_activation_map(features, class_weights):
    """Project one image's final feature tensor onto a class's classifier weights.

    Class-activation-map (CAM) heuristic: every spatial position scores the dot
    product of its feature vector with the linear-classifier row of the target
    class, so positions that push that class's logit up score high.

    Args:
        features: final hidden state for a single image (batch axis removed).
            Two layouts are handled:
            - ``(C, H, W)`` channel-first feature maps (convolutional backbones);
            - ``(N, C)`` token sequences (transformer backbones). Leading special
              tokens (e.g. a CLS token) are dropped by keeping only the last
              ``side * side`` tokens with ``side = isqrt(N)``. This assumes the
              patch tokens form a square grid — TODO confirm for the checkpoint.
        class_weights: 1-D tensor ``(C,)``, the classifier weight row of the
            class being explained.

    Returns:
        A 2-D ``float32`` numpy array scaled to ``[0, 1]``; all zeros when the
        map is constant (avoids the 0/0 of naive min-max normalisation).

    Raises:
        ValueError: if the feature layout or channel count does not match
            ``class_weights``.
    """
    import math  # only this helper needs it

    feats = features.detach().float().cpu()
    weights = class_weights.detach().float().cpu()
    channels = weights.shape[0]

    if feats.dim() == 3:
        if feats.shape[0] != channels:
            raise ValueError(
                f"Feature maps have {feats.shape[0]} channels, classifier expects {channels}."
            )
        cam = torch.einsum("c,chw->hw", weights, feats)
    elif feats.dim() == 2:
        if feats.shape[1] != channels:
            raise ValueError(
                f"Tokens have dimension {feats.shape[1]}, classifier expects {channels}."
            )
        side = math.isqrt(feats.shape[0])
        if side == 0:
            raise ValueError("No tokens available to build an activation map.")
        patches = feats[-side * side:]  # drop leading CLS/distillation tokens
        cam = (patches @ weights).reshape(side, side)
    else:
        raise ValueError(f"Unsupported feature tensor shape {tuple(feats.shape)}.")

    cam = cam.numpy().astype(np.float32)
    cam -= cam.min()
    peak = cam.max()
    if peak > 0:
        cam /= peak
    return cam


def generate_heatmap(image, model, processor, class_index=None):
    """Render an RGB heatmap of where *model* finds evidence for a class.

    The spatial map comes from the model's last hidden state weighted by the
    classifier row of the target class (see ``_class_activation_map``); the
    logits alone cannot be used because they carry no spatial layout. This is
    an approximate, qualitative visualisation, not a calibrated localisation.

    Args:
        image: input picture; ``image.size`` is read as ``(width, height)``
            (presumably a PIL RGB image, as produced by the upload flow).
        model: image-classification model exposing ``classifier.weight`` as a
            linear layer and accepting ``output_hidden_states=True``.
        processor: image processor matching *model*.
        class_index: classifier output index to explain. Defaults to the
            model's top-scoring class for this image.

    Returns:
        ``uint8`` array of shape ``(height, width, 3)`` in RGB order (JET map).

    Raises:
        ValueError: if the model returns no hidden states or their layout does
            not match the classifier.
    """
    inputs = processor(images=image, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
        if outputs.hidden_states is None:
            raise ValueError("Model did not return hidden states; cannot build a heatmap.")
        if class_index is None:
            class_index = int(outputs.logits[0].argmax())
        cam = _class_activation_map(
            outputs.hidden_states[-1][0],
            model.classifier.weight[class_index],
        )

    # cv2 takes dsize as (width, height), which is exactly PIL's image.size.
    width, height = image.size
    cam = cv2.resize(cam, (width, height), interpolation=cv2.INTER_LINEAR)
    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    # applyColorMap produces BGR; Streamlit expects RGB.
    return cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
|
|
|
|
|
uploaded_file = st.file_uploader("Téléchargez une image radiographique", type=["jpg", "jpeg", "png"])

# Label string the model emits for the positive class.
# NOTE(review): confirm against model.config.id2label — if the checkpoint spells
# it differently, no prediction will ever be reported as a fracture.
FRACTURE_LABEL = "FRACTURE"

if uploaded_file:
    try:
        # X-rays are often grayscale or carry an alpha channel; the classifier
        # and the heatmap both expect 3-channel RGB input.
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:
        # PIL raises UnidentifiedImageError (an OSError subclass) for
        # unreadable or corrupt files; report it instead of crashing.
        st.error("Impossible de lire ce fichier image.")
        st.stop()

    st.image(image, caption="Image originale", use_column_width=True)

    # `image_processor` replaces the deprecated `feature_extractor` argument.
    pipe = pipeline("image-classification", model=model, image_processor=processor)
    results = pipe(image)

    st.subheader("Résultats de l'analyse")
    for result in results:
        score = result["score"]
        label = "Fracture détectée" if result["label"] == FRACTURE_LABEL else "Pas de fracture"
        st.write(f"{label} (Confiance: {score * 100:.2f}%)")
        # st.progress only accepts floats within [0.0, 1.0].
        st.progress(min(max(score, 0.0), 1.0))

    # The pipeline returns labels sorted by descending score, so results[0] is
    # the model's actual prediction. Only then is a fracture heatmap meaningful;
    # a low-scoring FRACTURE entry further down the list must not trigger it.
    if results and results[0]["label"] == FRACTURE_LABEL:
        st.subheader("Localisation probable de la fracture")
        heatmap = generate_heatmap(image, model, processor)
        st.image(heatmap, caption="Carte de chaleur de la fracture", use_column_width=True)
else:
    st.write("Veuillez télécharger une image radiographique pour l'analyse.")