# radpid / app.py (Hugging Face Space by yassonee)
# Last commit: 69c9de9 "Update app.py" (verified), 7.99 kB
import re

import numpy as np
import streamlit as st
from PIL import Image, ImageDraw
from transformers import pipeline
# Page configuration.
def _configure_page_once():
    """Apply the page configuration unless this session has already done so.

    A ``"page_config"`` flag in ``st.session_state`` records that
    ``st.set_page_config`` was called, so later reruns of the script in the
    same session skip the call.
    """
    if "page_config" in st.session_state:
        return
    st.set_page_config(
        page_title="Fraktur Detektion",
        layout="wide",
        initial_sidebar_state="collapsed",
        menu_items=None,
    )
    st.session_state["page_config"] = True


_configure_page_once()
@st.cache_resource
def load_models():
    """Create the three Hugging Face pipelines used by the app.

    The result is cached with ``st.cache_resource`` so the pipelines are not
    rebuilt on every script rerun.

    Returns:
        dict mapping display name to pipeline:
        ``"KnochenAuge"`` (object detection), ``"KnochenWächter"`` and
        ``"RöntgenMeister"`` (both image classification).
    """
    # (display name, pipeline task, model repository) in creation order.
    model_specs = (
        ("KnochenAuge", "object-detection", "D3STRON/bone-fracture-detr"),
        ("KnochenWächter", "image-classification", "Heem2/bone-fracture-detection-using-xray"),
        (
            "RöntgenMeister",
            "image-classification",
            "nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388",
        ),
    )
    return {name: pipeline(task, model=repo) for name, task, repo in model_specs}
def translate_label(label):
    """Translate a model output label into German for display.

    The lookup is case-insensitive. Labels without a known translation are
    returned unchanged.

    Args:
        label: label string produced by one of the models.

    Returns:
        The German display text, or ``label`` itself if it is unknown.
    """
    # All keys must be lowercase because the lookup lowercases the label;
    # the former "F1"/"NF" keys could therefore never match.
    translations = {
        "fracture": "Knochenbruch",
        "no fracture": "Kein Knochenbruch",
        "normal": "Normal",
        "abnormal": "Auffällig",
        "f1": "Knochenbruch",
        "nf": "Kein Knochenbruch",
    }
    return translations.get(label.lower(), label)
def create_heatmap_overlay(image, box, score):
    """Return a transparent RGBA layer, sized like *image*, highlighting *box*.

    The rectangle colour encodes the score: red above 0.8, orange above 0.6,
    yellow otherwise. The interior is filled with alpha 100 and the border is
    drawn 2 px wide at full opacity.

    Args:
        image: image whose ``size`` the overlay matches (nothing else is read).
        box: dict with ``xmin``, ``ymin``, ``xmax`` and ``ymax`` coordinates.
        score: confidence value used to choose the colour.

    Returns:
        A new RGBA ``PIL.Image`` meant to be alpha-composited onto the image.
    """
    layer = Image.new('RGBA', image.size, (0, 0, 0, 0))
    painter = ImageDraw.Draw(layer)
    corners = [box['xmin'], box['ymin'], box['xmax'], box['ymax']]

    if score > 0.8:
        rgb = (255, 0, 0)        # red
    elif score > 0.6:
        rgb = (255, 165, 0)      # orange
    else:
        rgb = (255, 255, 0)      # yellow

    # Translucent fill first, then an opaque outline on top of it.
    painter.rectangle(corners, fill=rgb + (100,))
    painter.rectangle(corners, outline=rgb + (255,), width=2)
    return layer
def draw_boxes(image, predictions):
    """Draw detection boxes and labels on a copy of *image*.

    For each prediction, a score-coloured overlay from
    ``create_heatmap_overlay`` is alpha-composited onto the result. A text
    label on a dark background is then drawn just above the box.

    Args:
        image: PIL image the boxes refer to; it is not modified.
        predictions: iterable of dicts with ``'box'`` (``xmin``/``ymin``/
            ``xmax``/``ymax``), ``'score'`` and ``'label'`` keys.

    Returns:
        A new RGBA ``PIL.Image`` with all annotations drawn.
    """
    result_image = image.copy().convert('RGBA')
    for pred in predictions:
        box = pred['box']
        score = pred['score']
        overlay = create_heatmap_overlay(image, box, score)
        result_image = Image.alpha_composite(result_image, overlay)
        draw = ImageDraw.Draw(result_image)
        # Decorative "temperature" mapped linearly from the score
        # (36.5 at score 0 up to 39.0 at score 1); it is not a measurement.
        temp = 36.5 + (score * 2.5)
        label = f"{translate_label(pred['label'])} ({score:.1%}, {temp:.1f}°C)"
        # Place the label 20 px above the box, but never above the image's
        # top edge; otherwise labels of boxes near the top are clipped away.
        text_pos = (box['xmin'], max(0, box['ymin'] - 20))
        text_bbox = draw.textbbox(text_pos, label)
        draw.rectangle(text_bbox, fill=(0, 0, 0, 180))
        draw.text(
            text_pos,
            label,
            fill=(255, 255, 255, 255)
        )
    return result_image
# Classifier labels such as "no fracture", "not fractured" or "non-fracture"
# contain the substring "fractur" but denote the ABSENCE of a fracture.
_NEGATED_FRACTURE_RE = re.compile(r"\b(?:no|not|non)[\s_-]*fractur")


def _is_fracture_label(label):
    """Return True if a classifier label denotes a fracture (not its absence).

    A label is positive when it contains "fractur" without a leading negation
    ("no", "not", "non"), or when it is the code "F1" (listed as
    "Knochenbruch" in ``translate_label``'s table). Matching is
    case-insensitive.
    """
    lowered = label.lower()
    if lowered == "f1":
        return True
    return "fractur" in lowered and _NEGATED_FRACTURE_RE.search(lowered) is None


def _render_prediction(pred):
    """Show one classification result (score and translated label) in a result box.

    Scores above 0.7 are coloured blue, all others orange.
    """
    confidence_color = '#0066cc' if pred['score'] > 0.7 else '#ffa500'
    st.markdown(f"""
        <div class="result-box">
            <span style="color: {confidence_color}; font-weight: 500;">
                {pred['score']:.1%}
            </span> - {translate_label(pred['label'])}
        </div>
    """, unsafe_allow_html=True)


def main():
    """Render the fracture-detection page.

    The user uploads an X-ray image and picks a confidence threshold.
    Pressing "Analysieren" runs the image through all three models returned
    by ``load_models``:

    * "KnochenWächter": results are listed, and its highest above-threshold
      fracture score drives the probability summary;
    * "RöntgenMeister": results are listed for reference only;
    * "KnochenAuge": detections above the threshold are drawn on the image.

    Any exception is caught and shown on the page via ``st.error``.
    """
    st.markdown("""
        <style>
        .stApp {background: #f0f2f5}
        div[data-testid="stToolbar"] {display: none}
        #MainMenu {visibility: hidden}
        footer {visibility: hidden}
        header {visibility: hidden}
        .result-box {
            background: #f8f9fa;
            padding: 0.75rem;
            border-radius: 8px;
            margin: 0.5rem 0;
            border: 1px solid #e9ecef;
        }
        </style>
    """, unsafe_allow_html=True)
    try:
        models = load_models()
        st.write("### 📤 Röntgenbild hochladen")
        uploaded_file = st.file_uploader("Bild auswählen", type=['png', 'jpg', 'jpeg'], label_visibility="collapsed")
        col1, col2 = st.columns([2, 1])
        with col1:
            conf_threshold = st.slider(
                "Konfidenzschwelle",
                min_value=0.0, max_value=1.0,
                value=0.60, step=0.05
            )
        with col2:
            analyze_button = st.button("Analysieren")
        if uploaded_file and analyze_button:
            with st.spinner("Bild wird analysiert..."):
                image = Image.open(uploaded_file)
                results_container = st.container()
                predictions_watcher = models["KnochenWächter"](image)
                predictions_master = models["RöntgenMeister"](image)
                # Run the detector once; its filtered output is reused for
                # the localisation image below.
                predictions_locator = models["KnochenAuge"](image)
                filtered_locations = [
                    p for p in predictions_locator if p['score'] >= conf_threshold
                ]
                # Highest above-threshold score among POSITIVE fracture labels.
                # A plain substring test would also match "no fracture" and
                # report the negative class as the fracture probability.
                max_fracture_score = max(
                    (
                        p['score'] for p in predictions_watcher
                        if p['score'] >= conf_threshold and _is_fracture_label(p['label'])
                    ),
                    default=0,
                )
                with results_container:
                    st.write("### 🔍 Analyse Ergebnisse")
                    col1, col2 = st.columns(2)
                    with col1:
                        st.write("#### 🤖 KI-Diagnose")
                        st.markdown("#### 🛡️ KnochenWächter")
                        for pred in predictions_watcher:
                            _render_prediction(pred)
                        st.markdown("#### 🎓 RöntgenMeister")
                        for pred in predictions_master:
                            _render_prediction(pred)
                        if max_fracture_score > 0:
                            st.write("#### 📊 Wahrscheinlichkeit")
                            no_fracture_prob = 1 - max_fracture_score
                            st.markdown(f"""
                                <div class="result-box">
                                    Knochenbruch: <strong style="color: #0066cc">{max_fracture_score:.1%}</strong><br>
                                    Kein Knochenbruch: <strong style="color: #ffa500">{no_fracture_prob:.1%}</strong>
                                </div>
                            """, unsafe_allow_html=True)
                    with col2:
                        if filtered_locations:
                            st.write("#### 🎯 Fraktur Lokalisation")
                            result_image = draw_boxes(image, filtered_locations)
                            st.image(result_image, use_container_width=True)
                        else:
                            st.write("#### 🖼️ Röntgenbild")
                            st.image(image, use_container_width=True)
    except Exception as e:
        # Top-level UI boundary: report any failure on the page instead of crashing.
        st.error(f"Ein Fehler ist aufgetreten: {str(e)}")
if __name__ == "__main__":
    # Run the app when this file is executed as the main script.
    main()