"""Streamlit app that runs a DETR bone-fracture detector on uploaded X-rays.

Each upload is normalised once (EXIF orientation applied, converted to RGB)
so that the image shown to the user, the image sent to the model and the
image the boxes are drawn on all share the same pixel coordinates.
"""

import io  # noqa: F401  (kept: may be relied on elsewhere / historically used)

import streamlit as st
import torch  # noqa: F401  (presumably imported to ensure the PyTorch backend is present — confirm)
from PIL import Image, ImageDraw, ImageOps
from transformers import pipeline

st.set_page_config(page_title="Détection de Fractures Osseuses", layout="wide")


@st.cache_resource
def load_model():
    """Build the object-detection pipeline once and reuse it across reruns."""
    return pipeline("object-detection", model="D3STRON/bone-fracture-detr")


def draw_boxes(image, predictions, outline_color="red", line_width=3):
    """Draw each prediction's bounding box and "label (score)" caption.

    Args:
        image: PIL image to annotate. It is modified in place. It should be in
            RGB mode: on an ``L``-mode image a colour such as "red" is drawn
            as a grey level.
        predictions: iterable of dicts shaped like the transformers
            object-detection pipeline output:
            ``{"label": str, "score": float,
               "box": {"xmin", "ymin", "xmax", "ymax"}}``.
        outline_color: colour of the box outline and caption background.
        line_width: box outline width in pixels.

    Returns:
        The same ``image`` object, for convenience.
    """
    draw = ImageDraw.Draw(image)
    for pred in predictions:
        box = pred['box']
        label = f"{pred['label']} ({pred['score']:.2%})"

        # Bounding box.
        draw.rectangle(
            [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
            outline=outline_color,
            width=line_width,
        )

        # Place the caption just above the box when it fits inside the image,
        # so it does not hide the top of the detected region. Otherwise fall
        # back to the box's top-left corner.
        _, text_top, _, text_bottom = draw.textbbox((0, 0), label)
        text_y = box['ymin'] - text_bottom
        if text_y + text_top < 0:
            text_y = box['ymin']
        text_origin = (box['xmin'], text_y)

        # Caption background, then caption text.
        draw.rectangle(draw.textbbox(text_origin, label), fill=outline_color)
        draw.text(text_origin, label, fill="white")
    return image


def _load_uploaded_image(uploaded_file):
    """Open an upload as an upright RGB image; return None if unreadable."""
    try:
        image = Image.open(uploaded_file)
        # The transformers pipeline applies EXIF orientation and converts to
        # RGB internally. Doing the same here keeps displayed pixels, drawn
        # boxes and predicted coordinates in one frame. It also makes
        # coloured outlines render as colours on grayscale X-rays.
        return ImageOps.exif_transpose(image).convert("RGB")
    except (OSError, Image.DecompressionBombError):
        # OSError covers PIL.UnidentifiedImageError and truncated files.
        return None


def main():
    """Render the page: uploader, confidence slider and detection results."""
    st.title("🦴 Détecteur de Fractures Osseuses")
    st.write("Téléchargez une radiographie pour détecter les fractures osseuses.")

    pipe = load_model()

    uploaded_file = st.file_uploader(
        "Choisissez une image de radiographie",
        type=['png', 'jpg', 'jpeg'],
    )

    conf_threshold = st.slider(
        "Seuil de confiance",
        min_value=0.0,
        max_value=1.0,
        value=0.5,
        step=0.05,
    )

    if not uploaded_file:
        return

    image = _load_uploaded_image(uploaded_file)
    if image is None:
        st.error(
            "Impossible de lire cette image. "
            "Veuillez fournir un fichier PNG ou JPEG valide."
        )
        return

    col1, col2 = st.columns(2)

    # Original (normalised) image.
    col1.header("Image originale")
    col1.image(image)

    with st.spinner("Analyse en cours..."):
        # The pipeline discards detections at or below its own threshold
        # (0.5 by default). Forward the user's value so that lowering the
        # slider below 0.5 actually surfaces more detections.
        predictions = pipe(image, threshold=conf_threshold)

    # Re-filter locally as well, to enforce the slider's inclusive bound.
    filtered_preds = [
        pred for pred in predictions
        if pred['score'] >= conf_threshold
    ]

    # Annotate a copy so the original shown in col1 stays untouched.
    result_image = draw_boxes(image.copy(), filtered_preds)

    col2.header("Résultats de la détection")
    col2.image(result_image)

    if filtered_preds:
        st.subheader("Détails des détections")
        for pred in filtered_preds:
            st.write(
                f"• Type: {pred['label']} - "
                f"Confiance: {pred['score']:.2%}"
            )
    else:
        st.warning(
            "Aucune fracture détectée avec le seuil de confiance actuel. "
            "Essayez de baisser le seuil pour plus de résultats."
        )


if __name__ == "__main__":
    main()