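# radpid / app.py
# Last change: "Update app.py" by yassonee (commit 8ab1fd2, verified), 10.7 kB.
# (File-viewer header captured together with the source; "raw", "history"
#  and "blame" were viewer links, not part of the program.)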
import streamlit as st
from transformers import pipeline
from PIL import Image, ImageDraw
import numpy as np
import colorsys
# Page-level Streamlit settings: browser-tab title, wide layout, and the
# sidebar collapsed on first load. set_page_config must be the first
# Streamlit call, so it is issued before the CSS below.
_PAGE_CONFIG = {
    "page_title": "Fraktur Detektion",
    "layout": "wide",
    "initial_sidebar_state": "collapsed",
}

# Custom look for the app: card-style upload/results containers, result rows,
# image framing, the blue action button, and hidden Streamlit chrome.
# NOTE(review): the last rule hides EVERY st.expander and every alert
# (st.error / st.warning / st.info), not only a deprecation warning.
_CUSTOM_CSS = """
<style>
.stApp {
background: #f0f2f5 !important;
}
.block-container {
padding-top: 0 !important;
padding-bottom: 0 !important;
max-width: 1400px !important;
}
.upload-container {
background: white;
padding: 1.5rem;
border-radius: 10px;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
margin-bottom: 1rem;
text-align: center;
}
.results-container {
background: white;
padding: 1.5rem;
border-radius: 10px;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.result-box {
background: #f8f9fa;
padding: 0.75rem;
border-radius: 8px;
margin: 0.5rem 0;
border: 1px solid #e9ecef;
}
h1, h2, h3, h4, p {
color: #1a1a1a !important;
margin: 0.5rem 0 !important;
}
.stImage {
background: white;
padding: 0.5rem;
border-radius: 8px;
box-shadow: 0 1px 3px rgba(0,0,0,0.1);
}
.stImage > img {
max-height: 300px !important;
width: auto !important;
margin: 0 auto !important;
display: block !important;
}
[data-testid="stFileUploader"] {
width: 100% !important;
}
.stButton > button {
width: 200px;
background-color: #0066cc !important;
color: white !important;
border: none !important;
padding: 0.5rem 1rem !important;
border-radius: 5px !important;
transition: all 0.3s ease !important;
}
.stButton > button:hover {
background-color: #0052a3 !important;
transform: translateY(-1px);
}
#MainMenu, footer, header, [data-testid="stToolbar"] {
display: none !important;
}
/* Hide deprecation warning */
[data-testid="stExpander"], .element-container:has(>.stAlert) {
display: none !important;
}
</style>
"""

st.set_page_config(**_PAGE_CONFIG)
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
@st.cache_resource
def load_models():
    """Construct the three Hugging Face pipelines the app uses.

    Wrapped in ``st.cache_resource`` so the pipelines are built once and
    shared across Streamlit reruns instead of being reloaded on every click.

    Returns:
        dict mapping the display name to its pipeline:
        "KnochenAuge" (object detection), "KnochenWächter" and
        "RöntgenMeister" (image classification).
    """
    model_specs = [
        ("KnochenAuge", "object-detection", "D3STRON/bone-fracture-detr"),
        ("KnochenWächter", "image-classification",
         "Heem2/bone-fracture-detection-using-xray"),
        ("RöntgenMeister", "image-classification",
         "nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388"),
    ]
    models = {}
    for display_name, task, model_id in model_specs:
        models[display_name] = pipeline(task, model=model_id)
    return models
# Lower-case model label -> German display text. Keys MUST be lower-case:
# translate_label() normalises the incoming label before the lookup.
_LABEL_TRANSLATIONS = {
    "fracture": "Knochenbruch",
    "no fracture": "Kein Knochenbruch",
    "normal": "Normal",
    "abnormal": "Auffällig",
    "f1": "Knochenbruch",
    "nf": "Kein Knochenbruch",
}


def translate_label(label):
    """Return the German display text for a model output label.

    The lookup is case-insensitive and ignores surrounding whitespace.

    Args:
        label: label string produced by one of the models.

    Returns:
        The German translation, or ``label`` unchanged when no translation
        is known.
    """
    # Previously the table held "F1"/"NF" while the key was lower-cased,
    # so those two entries could never match.
    return _LABEL_TRANSLATIONS.get(label.strip().lower(), label)
def temperature_color(value):
    """Map a confidence score to an RGB colour on a yellow-to-red scale.

    Args:
        value: confidence score (the callers pass values in [0, 1]).

    Returns:
        (r, g, b) tuple: bright red above 0.8, orange-red above 0.6,
        orange above 0.4, yellow otherwise.
    """
    if value > 0.8:
        return (255, 0, 0)  # bright red
    if value > 0.6:
        return (255, 69, 0)  # orange-red
    if value > 0.4:
        return (255, 165, 0)  # orange
    return (255, 255, 0)  # yellow


def create_heatmap_overlay(image, box, score, steps=30):
    """Build a transparent layer with a graded, score-coloured box.

    Args:
        image: PIL image; only its ``size`` is used, so the overlay can be
            alpha-composited onto an RGBA version of it.
        box: dict with "xmin", "ymin", "xmax", "ymax" pixel coordinates.
        score: confidence score selecting the colour via temperature_color().
        steps: number of concentric rectangles forming the gradient
            (default 30, the previously hard-coded value).

    Returns:
        A new RGBA image of ``image.size``, fully transparent outside the box.
    """
    overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
    draw = ImageDraw.Draw(overlay)
    # The colour depends only on the score, so compute it once, not per step.
    base_color = temperature_color(score)

    x1, y1 = box['xmin'], box['ymin']
    x2, y2 = box['xmax'], box['ymax']
    width = x2 - x1
    height = y2 - y1

    # Each rectangle is smaller and more transparent than the previous one.
    # ImageDraw writes RGBA values without blending on an RGBA image, so each
    # inner rectangle replaces the pixels beneath it. As a result, opacity is
    # highest (~70%) at the box edge and fades toward the centre.
    for i in range(steps):
        alpha = int(255 * (1 - (i / steps)) * 0.7)
        shrink_x = (i * width) / (steps * 2)
        shrink_y = (i * height) / (steps * 2)
        draw.rectangle(
            [x1 + shrink_x, y1 + shrink_y, x2 - shrink_x, y2 - shrink_y],
            fill=base_color + (alpha,),
            outline=None,
        )

    # Solid-ish outline so the box boundary stays visible.
    draw.rectangle([x1, y1, x2, y2], outline=base_color + (200,), width=2)
    return overlay
def draw_boxes(image, predictions):
    """Draw each prediction as a heat-map box with a caption on a copy of *image*.

    Args:
        image: PIL image the predictions were computed on.
        predictions: sequence of dicts with keys "label", "score" and "box";
            "box" holds "xmin", "ymin", "xmax", "ymax" pixel coordinates.

    Returns:
        A new RGBA image; *image* itself is left untouched.
    """
    label_offset = 25  # caption normally sits this many pixels above the box
    padding = 3  # margin around the caption's dark backdrop
    # convert() always returns a new image, so no explicit copy() is needed.
    result_image = image.convert('RGBA')

    # Ascending score: the most confident detection is painted last (on top).
    for pred in sorted(predictions, key=lambda p: p['score']):
        box = pred['box']
        score = pred['score']
        heatmap = create_heatmap_overlay(image, box, score)
        result_image = Image.alpha_composite(result_image, heatmap)

        # NOTE(review): this "temperature" is purely cosmetic. It is a linear
        # map of the confidence score onto 36.5-39.0 °C, not a measured value.
        temp = 36.5 + (score * 2.5)
        label = f"{translate_label(pred['label'])} ({score:.1%}) • {temp:.1f}°C"

        draw = ImageDraw.Draw(result_image)
        text_pos = (box['xmin'], box['ymin'] - label_offset)
        text_bbox = draw.textbbox(text_pos, label)
        if text_bbox[1] - padding < 0:
            # Above the box the caption would be clipped by the image's top
            # edge, so place it just inside the box instead.
            text_pos = (box['xmin'], box['ymin'] + padding)
            text_bbox = draw.textbbox(text_pos, label)
        backdrop_bbox = (
            text_bbox[0] - padding, text_bbox[1] - padding,
            text_bbox[2] + padding, text_bbox[3] + padding,
        )

        # Drawing "#000000CC" directly on an RGBA image stores that alpha
        # instead of blending, leaving semi-transparent holes in the X-ray.
        # Draw the backdrop on its own layer and composite it instead.
        backdrop = Image.new('RGBA', result_image.size, (0, 0, 0, 0))
        ImageDraw.Draw(backdrop).rectangle(backdrop_bbox, fill="#000000CC")
        result_image = Image.alpha_composite(result_image, backdrop)

        ImageDraw.Draw(result_image).text(
            text_pos,
            label,
            fill="#FFFFFF",
            stroke_width=1,
            stroke_fill="#000000",
        )
    return result_image
def is_fracture_label(label):
    """Return True if a model label reports a fracture (not its negation).

    A plain ``'fracture' in label`` test also matches negative labels such as
    "no fracture". So a label whose first word is "no", "not" or "non" is
    treated as negative. "F1" counts as a fracture, consistent with
    translate_label(), which maps "F1" to "Knochenbruch".

    Args:
        label: label string produced by a model.

    Returns:
        True for positive fracture findings, False otherwise.
    """
    text = label.strip().lower()
    if text == "f1":
        return True
    if "fracture" not in text:
        return False
    words = text.replace("_", " ").replace("-", " ").split()
    return words[0] not in {"no", "not", "non"}


def _render_score_box(score, text):
    """Render one "<score> - <text>" row inside a styled result box."""
    # Scores above 0.7 are shown in blue, lower ones in orange.
    confidence_color = '#0066cc' if score > 0.7 else '#ffa500'
    st.markdown(
        '<div class="result-box" style="color: #1a1a1a;">'
        f'<span style="color: {confidence_color}; font-weight: 500;">'
        f'{score:.1%}</span> - {text}</div>',
        unsafe_allow_html=True,
    )


def main():
    """Streamlit entry point: upload an X-ray, run the models, show results."""
    # Local import: only needed here, to apply EXIF orientation.
    from PIL import ImageOps

    models = load_models()

    with st.container():
        st.write("### 📤 Röntgenbild hochladen")
        uploaded_file = st.file_uploader(
            "Bild auswählen", type=['png', 'jpg', 'jpeg'], label_visibility="collapsed"
        )
        col1, col2 = st.columns([2, 1])
        with col1:
            conf_threshold = st.slider(
                "Konfidenzschwelle",
                min_value=0.0, max_value=1.0,
                value=0.60, step=0.05,
                label_visibility="visible",
            )
        with col2:
            analyze_button = st.button("Analysieren")

    if not (uploaded_file and analyze_button):
        return

    with st.spinner("Bild wird analysiert..."):
        try:
            # Normalise orientation and mode once, so the image that is drawn
            # on and displayed matches what the models see.
            # NOTE(review): transformers' load_image applies exif_transpose +
            # RGB to pipeline inputs; verify for the installed version.
            image = ImageOps.exif_transpose(Image.open(uploaded_file)).convert("RGB")
        except OSError:  # includes PIL.UnidentifiedImageError
            # Rendered as a result box because the page CSS hides st.error.
            st.markdown(
                '<div class="result-box" style="color: #1a1a1a;">'
                "Das Bild konnte nicht gelesen werden.</div>",
                unsafe_allow_html=True,
            )
            return
        predictions_watcher = models["KnochenWächter"](image)
        predictions_master = models["RöntgenMeister"](image)
        # The detector runs once; its output feeds the localisation view.
        predictions_locator = models["KnochenAuge"](image)

    watcher_hits = [p for p in predictions_watcher if p['score'] >= conf_threshold]
    master_hits = [p for p in predictions_master if p['score'] >= conf_threshold]
    fracture_locations = [
        p for p in predictions_locator
        if p['score'] >= conf_threshold and is_fracture_label(p['label'])
    ]
    # The fracture probability is taken from KnochenWächter only.
    max_fracture_score = max(
        (p['score'] for p in watcher_hits if is_fracture_label(p['label'])),
        default=0,
    )

    st.write("### 🔍 Analyse Ergebnisse")
    col1, col2 = st.columns(2)
    with col1:
        st.write("#### 🤖 KI-Diagnose")
        st.write("##### 🛡️ KnochenWächter")
        for pred in watcher_hits:
            _render_score_box(pred['score'], translate_label(pred['label']))

        st.write("#### 🎓 RöntgenMeister")
        for pred in master_hits:
            _render_score_box(pred['score'], translate_label(pred['label']))

        if max_fracture_score > 0:
            st.write("#### 📊 Wahrscheinlichkeit")
            no_fracture_prob = 1 - max_fracture_score
            st.markdown(
                '<div class="result-box" style="color: #1a1a1a;">'
                f'Knochenbruch: <strong style="color: #0066cc">{max_fracture_score:.1%}</strong><br>'
                f'Kein Knochenbruch: <strong style="color: #ffa500">{no_fracture_prob:.1%}</strong>'
                '</div>',
                unsafe_allow_html=True,
            )

    with col2:
        if fracture_locations:
            st.write("#### 🎯 Fraktur Lokalisation")
            st.image(draw_boxes(image, fracture_locations), use_container_width=True)
        else:
            st.write("#### 🖼️ Röntgenbild")
            st.image(image, use_container_width=True)


if __name__ == "__main__":
    main()