import streamlit as st from transformers import pipeline from PIL import Image, ImageDraw import numpy as np import colorsys st.set_page_config( page_title="Fraktur Detektion", layout="wide", initial_sidebar_state="collapsed" ) st.markdown(""" """, unsafe_allow_html=True) @st.cache_resource def load_models(): return { "KnochenAuge": pipeline("object-detection", model="D3STRON/bone-fracture-detr"), "KnochenWächter": pipeline("image-classification", model="Heem2/bone-fracture-detection-using-xray"), "RöntgenMeister": pipeline("image-classification", model="nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388") } def translate_label(label): translations = { "fracture": "Knochenbruch", "no fracture": "Kein Knochenbruch", "normal": "Normal", "abnormal": "Auffällig", "F1": "Knochenbruch", "NF": "Kein Knochenbruch" } return translations.get(label.lower(), label) def create_heatmap_overlay(image, box, score): overlay = Image.new('RGBA', image.size, (0, 0, 0, 0)) draw = ImageDraw.Draw(overlay) def get_temp_color(value): if value > 0.8: return (255, 0, 0) # Rouge vif elif value > 0.6: return (255, 69, 0) # Rouge-orange elif value > 0.4: return (255, 165, 0) # Orange else: return (255, 255, 0) # Jaune x1, y1 = box['xmin'], box['ymin'] x2, y2 = box['xmax'], box['ymax'] width = x2 - x1 height = y2 - y1 steps = 30 for i in range(steps): alpha = int(255 * (1 - (i / steps)) * 0.7) base_color = get_temp_color(score) color = base_color + (alpha,) shrink_x = (i * width) / (steps * 2) shrink_y = (i * height) / (steps * 2) draw.rectangle( [x1 + shrink_x, y1 + shrink_y, x2 - shrink_x, y2 - shrink_y], fill=color, outline=None ) border_color = get_temp_color(score) + (200,) draw.rectangle([x1, y1, x2, y2], outline=border_color, width=2) return overlay def draw_boxes(image, predictions): result_image = image.copy().convert('RGBA') sorted_predictions = sorted(predictions, key=lambda x: x['score']) for pred in sorted_predictions: box = pred['box'] score = pred['score'] heatmap = create_heatmap_overlay(image, box, score) result_image = Image.alpha_composite(result_image, heatmap) draw = ImageDraw.Draw(result_image) temp = 36.5 + (score * 2.5) label = f"{translate_label(pred['label'])} ({score:.1%}) • {temp:.1f}°C" text_bbox = draw.textbbox((box['xmin'], box['ymin']-25), label) padding = 3 text_bbox = ( text_bbox[0]-padding, text_bbox[1]-padding, text_bbox[2]+padding, text_bbox[3]+padding ) draw.rectangle(text_bbox, fill="#000000CC") draw.text( (box['xmin'], box['ymin']-25), label, fill="#FFFFFF", stroke_width=1, stroke_fill="#000000" ) return result_image def main(): models = load_models() with st.container(): st.write("### 📤 Röntgenbild hochladen") uploaded_file = st.file_uploader("Bild auswählen", type=['png', 'jpg', 'jpeg'], label_visibility="collapsed") col1, col2 = st.columns([2, 1]) with col1: conf_threshold = st.slider( "Konfidenzschwelle", min_value=0.0, max_value=1.0, value=0.60, step=0.05, label_visibility="visible" ) with col2: analyze_button = st.button("Analysieren") if uploaded_file and analyze_button: with st.spinner("Bild wird analysiert..."): image = Image.open(uploaded_file) results_container = st.container() predictions_watcher = models["KnochenWächter"](image) predictions_master = models["RöntgenMeister"](image) predictions_locator = models["KnochenAuge"](image) has_fracture = False max_fracture_score = 0 filtered_locations = [p for p in predictions_locator if p['score'] >= conf_threshold and 'fracture' in p['label'].lower()] for pred in predictions_watcher: if pred['score'] >= conf_threshold and 'fracture' in pred['label'].lower(): has_fracture = True max_fracture_score = max(max_fracture_score, pred['score']) with results_container: st.write("### 🔍 Analyse Ergebnisse") col1, col2 = st.columns(2) with col1: st.write("#### 🤖 KI-Diagnose") st.write("##### 🛡️ KnochenWächter") for pred in predictions_watcher: if pred['score'] >= conf_threshold: confidence_color = '#0066cc' if pred['score'] > 0.7 else '#ffa500' label_lower = pred['label'].lower() if 'fracture' in label_lower: has_fracture = True max_fracture_score = max(max_fracture_score, pred['score']) st.markdown(f"""