# radpid / app.py
# Last change: commit 9f1f60f ("Update app.py", verified) by yassonee; file size 10.3 kB.
import streamlit as st
from transformers import pipeline
from PIL import Image, ImageDraw
import numpy as np
import colorsys
from streamlit.web.server.server import Server
import streamlit.components.v1 as components
# Attach permissive cross-origin headers to the running Streamlit server, if any.
# The accessor is looked up defensively and called only once: `Server.get_current`
# may be missing in some Streamlit releases, and calling it unguarded would then
# raise AttributeError while the script is imported.
# NOTE(review): `_websocket_headers` is a private attribute; nothing in this file
# shows Streamlit reading it, so this hook may have no effect — confirm.
# NOTE(review): "Access-Control-Allow-Origin: *" allows any origin.
_get_current_server = getattr(Server, "get_current", None)
_current_server = _get_current_server() if callable(_get_current_server) else None
if _current_server:
    _current_server._websocket_headers = {
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Methods": "GET, POST, OPTIONS",
        "Access-Control-Allow-Headers": "Content-Type",
    }
# Browser-tab title, full-width layout, and a sidebar that starts collapsed.
_page_options = {
    "page_title": "Fraktur Detektion",
    "layout": "wide",
    "initial_sidebar_state": "collapsed",
}
st.set_page_config(**_page_options)
# Browser-side "Edge WebSocket compatibility" snippet, rendered as a
# zero-height HTML component.
# NOTE(review): the script looks like a no-op. Inside the `if`, window.WebSocket
# is already truthy, so `window.WebSocket || window.MozWebSocket` yields the same
# object. It also executes inside the component's own iframe, so it likely could
# not patch the main app window anyway — confirm before relying on it.
_EDGE_WEBSOCKET_SHIM = """
<script>
if (window.WebSocket && navigator.userAgent.indexOf("Edge") > -1) {
window.WebSocket = window.WebSocket || window.MozWebSocket;
}
</script>
"""
components.html(_EDGE_WEBSOCKET_SHIM, height=0)
# Global stylesheet injected as raw HTML:
# - light page background, narrow top/bottom padding, 1400px max width;
# - card styling for `.result-box`;
# - a centred preview image capped at 300px height;
# - a restyled button;
# - hidden Streamlit menu, footer, header and toolbar.
# NOTE(review): the last rule hides every expander and every element containing
# an `.stAlert` (presumably st.warning / st.error / st.info boxes), not only
# deprecation warnings, so genuine error messages would be hidden too.
# NOTE(review): `.upload-container` and `.results-container` appear unused in this file.
_APP_STYLESHEET = """
<style>
.stApp {
background: #f0f2f5 !important;
}
.block-container {
padding-top: 0 !important;
padding-bottom: 0 !important;
max-width: 1400px !important;
}
.upload-container {
background: white;
padding: 1.5rem;
border-radius: 10px;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
margin-bottom: 1rem;
text-align: center;
}
.results-container {
background: white;
padding: 1.5rem;
border-radius: 10px;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.result-box {
background: #f8f9fa;
padding: 0.75rem;
border-radius: 8px;
margin: 0.5rem 0;
border: 1px solid #e9ecef;
}
h1, h2, h3, h4, p {
color: #1a1a1a !important;
margin: 0.5rem 0 !important;
}
.stImage {
background: white;
padding: 0.5rem;
border-radius: 8px;
box-shadow: 0 1px 3px rgba(0,0,0,0.1);
}
.stImage > img {
max-height: 300px !important;
width: auto !important;
margin: 0 auto !important;
display: block !important;
}
[data-testid="stFileUploader"] {
width: 100% !important;
}
.stFileUploaderFileName {
color: #1a1a1a !important;
}
.stButton > button {
width: 200px;
background-color: #f8f9fa !important;
color: #1a1a1a !important;
border: 1px solid #e9ecef !important;
padding: 0.5rem 1rem !important;
border-radius: 5px !important;
transition: all 0.3s ease !important;
}
.stButton > button:hover {
background-color: #e9ecef !important;
transform: translateY(-1px);
}
#MainMenu, footer, header, [data-testid="stToolbar"] {
display: none !important;
}
/* Hide deprecation warning */
[data-testid="stExpander"], .element-container:has(>.stAlert) {
display: none !important;
}
</style>
"""
st.markdown(_APP_STYLESHEET, unsafe_allow_html=True)
@st.cache_resource
def load_models():
    """Create the three transformers pipelines used by the app.

    Cached with ``st.cache_resource`` so the pipelines are built once and
    reused across reruns instead of being reloaded on every interaction.

    Returns:
        dict mapping the display name to its pipeline:
        "KnochenAuge" (object detection), "KnochenWächter" and
        "RöntgenMeister" (both image classification).
    """
    detector = pipeline("object-detection", model="D3STRON/bone-fracture-detr")
    watcher = pipeline(
        "image-classification",
        model="Heem2/bone-fracture-detection-using-xray",
    )
    master = pipeline(
        "image-classification",
        model="nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388",
    )
    return {
        "KnochenAuge": detector,
        "KnochenWächter": watcher,
        "RöntgenMeister": master,
    }
def translate_label(label):
    """Translate a model's class label into the German text shown in the UI.

    Matching is case-insensitive and ignores surrounding whitespace; labels
    without a known translation are returned unchanged.

    Args:
        label: Raw label string produced by one of the pipelines.

    Returns:
        The German display label, or ``label`` itself if it is not recognised.
    """
    # All keys must be lower-case because the lookup lower-cases the label.
    # The former "F1"/"NF" keys could never match and were silently ignored.
    translations = {
        "fracture": "Knochenbruch",
        "no fracture": "Kein Knochenbruch",
        "normal": "Normal",
        "abnormal": "Auffällig",
        "f1": "Knochenbruch",
        "nf": "Kein Knochenbruch",
    }
    return translations.get(label.strip().lower(), label)
def create_heatmap_overlay(image, box, score):
    """Return a transparent RGBA layer that highlights one detection box.

    The box is filled with a translucent colour (alpha 100) and outlined,
    2 px wide, with the opaque version of that colour. The colour encodes the
    confidence: red above 0.8, orange above 0.6, yellow otherwise.

    Args:
        image: PIL image; only its ``size`` is used.
        box: dict with "xmin", "ymin", "xmax", "ymax" pixel coordinates.
        score: Detection confidence.

    Returns:
        A new RGBA image the same size as ``image``, transparent outside the box.
    """
    layer = Image.new('RGBA', image.size, (0, 0, 0, 0))
    painter = ImageDraw.Draw(layer)
    corners = [box['xmin'], box['ymin'], box['xmax'], box['ymax']]

    if score > 0.8:
        rgb = (255, 0, 0)      # red
    elif score > 0.6:
        rgb = (255, 165, 0)    # orange
    else:
        rgb = (255, 255, 0)    # yellow

    painter.rectangle(corners, fill=rgb + (100,))                 # translucent interior
    painter.rectangle(corners, outline=rgb + (255,), width=2)     # opaque border
    return layer
def draw_boxes(image, predictions):
    """Render detection boxes and captions on a copy of ``image``.

    Each prediction gets a coloured, semi-transparent box (from
    ``create_heatmap_overlay``). It also gets a white caption on a translucent
    black background, showing the translated label and the confidence.

    Args:
        image: PIL image to annotate; it is not modified.
        predictions: Iterable of dicts with "box" (xmin/ymin/xmax/ymax),
            "score" and "label" keys, as returned by the object-detection
            pipeline.

    Returns:
        A new RGBA image with all annotations composited on top.
    """
    # convert() always returns a new image, so the input stays untouched.
    result_image = image.convert('RGBA')
    for pred in predictions:
        box = pred['box']
        score = pred['score']

        # Coloured highlight for the detected region.
        result_image = Image.alpha_composite(
            result_image, create_heatmap_overlay(image, box, score)
        )

        # Caption: translated label plus confidence. The former extra
        # "temperature" (36.5 + 2.5 * score, shown in °C) was only a rescaled
        # confidence, misleading in a medical UI, so it is no longer drawn.
        label = f"{translate_label(pred['label'])} ({score:.1%})"

        # Place the caption just above the box, clamped so it stays inside
        # the image when the box touches the top edge.
        text_origin = (box['xmin'], max(0, box['ymin'] - 20))

        # Draw the caption on its own transparent layer and composite it.
        # ImageDraw does not blend when drawing directly on an RGBA image, so the
        # alpha-180 background would otherwise replace the X-ray pixels with a
        # partially transparent colour instead of darkening them.
        caption_layer = Image.new('RGBA', result_image.size, (0, 0, 0, 0))
        draw = ImageDraw.Draw(caption_layer)
        draw.rectangle(draw.textbbox(text_origin, label), fill=(0, 0, 0, 180))
        draw.text(text_origin, label, fill=(255, 255, 255, 255))
        result_image = Image.alpha_composite(result_image, caption_layer)
    return result_image
def _is_fracture_label(label):
    """Return True if a classifier label denotes a fracture (not its negation).

    A plain ``'fracture' in label`` test also matches negative labels such as
    "no fracture" or "not fractured". A confident *no-fracture* vote would then
    be reported as the fracture score.
    """
    normalized = " ".join(label.lower().replace("_", " ").replace("-", " ").split())
    if normalized == "f1":
        return True
    if "fracture" not in normalized:
        return False
    return not normalized.startswith(("no ", "not ", "non ", "nofracture"))


def _render_prediction_boxes(predictions):
    """Render one result card (confidence + German label) per classifier prediction."""
    for pred in predictions:
        # Blue for predictions above 70 % confidence, orange otherwise.
        confidence_color = '#0066cc' if pred['score'] > 0.7 else '#ffa500'
        st.markdown(f"""
<div class="result-box" style="color: #1a1a1a;">
<span style="color: {confidence_color}; font-weight: 500;">
{pred['score']:.1%}
</span> - {translate_label(pred['label'])}
</div>
""", unsafe_allow_html=True)


def main():
    """Render the upload form and, when requested, the fracture analysis.

    When an image is uploaded and "Analysieren" is clicked, the X-ray is run
    once through each model from ``load_models``:
    - both classifiers' outputs are listed in full;
    - the highest KnochenWächter fracture score at or above the slider threshold
      drives the probability summary;
    - KnochenAuge detections at or above the threshold are drawn with ``draw_boxes``.
    """
    models = load_models()

    with st.container():
        st.write("### 📤 Röntgenbild hochladen")
        uploaded_file = st.file_uploader(
            "Bild auswählen", type=['png', 'jpg', 'jpeg'], label_visibility="collapsed"
        )
        col1, col2 = st.columns([2, 1])
        with col1:
            conf_threshold = st.slider(
                "Konfidenzschwelle",
                min_value=0.0, max_value=1.0,
                value=0.60, step=0.05,
                label_visibility="visible",
            )
        with col2:
            analyze_button = st.button("Analysieren")

    if not (uploaded_file and analyze_button):
        return

    with st.spinner("Bild wird analysiert..."):
        image = Image.open(uploaded_file)
        results_container = st.container()

        # Each model runs exactly once. The detector output is reused for the
        # localisation view below; it used to be computed a second time.
        predictions_watcher = models["KnochenWächter"](image)
        predictions_master = models["RöntgenMeister"](image)
        predictions_locator = models["KnochenAuge"](image)

        # Strongest KnochenWächter fracture vote that clears the threshold (0 if none).
        max_fracture_score = max(
            (
                pred['score']
                for pred in predictions_watcher
                if pred['score'] >= conf_threshold and _is_fracture_label(pred['label'])
            ),
            default=0,
        )
        filtered_preds = [p for p in predictions_locator if p['score'] >= conf_threshold]

        with results_container:
            st.write("### 🔍 Analyse Ergebnisse")
            col1, col2 = st.columns(2)
            with col1:
                st.write("#### 🤖 KI-Diagnose")
                st.markdown("#### 🛡️ KnochenWächter")
                _render_prediction_boxes(predictions_watcher)
                st.markdown("#### 🎓 RöntgenMeister")
                _render_prediction_boxes(predictions_master)
                if max_fracture_score > 0:
                    st.write("#### 📊 Wahrscheinlichkeit")
                    # Complement of the fracture score, not a separate model output.
                    no_fracture_prob = 1 - max_fracture_score
                    st.markdown(f"""
<div class="result-box" style="color: #1a1a1a;">
Knochenbruch: <strong style="color: #0066cc">{max_fracture_score:.1%}</strong><br>
Kein Knochenbruch: <strong style="color: #ffa500">{no_fracture_prob:.1%}</strong>
</div>
""", unsafe_allow_html=True)
            with col2:
                if filtered_preds:
                    st.write("#### 🎯 Fraktur Lokalisation")
                    st.image(draw_boxes(image, filtered_preds), use_container_width=True)
                else:
                    st.write("#### 🖼️ Röntgenbild")
                    st.image(image, use_container_width=True)


if __name__ == "__main__":
    main()