# radpid / app.py
# Last commit edcb3d3 "Update app.py" by yassonee (verified) — 10.8 kB
import streamlit as st
from transformers import pipeline
from PIL import Image, ImageDraw
import numpy as np
import colorsys
import os
# WebSocket / CORS server settings, passed to Streamlit through environment
# variables.
# NOTE(review): Streamlit appears to parse its server configuration once at
# startup, before this script is executed, so setting these here probably has
# no effect — confirm, and prefer `.streamlit/config.toml`. A wildcard
# Access-Control-Allow-Origin is also worth revisiting.
os.environ.update(
    {
        "STREAMLIT_SERVER_WEBSOCKET_HEADERS": '{"Access-Control-Allow-Origin": "*"}',
        "STREAMLIT_SERVER_ENABLE_CORS": "true",
    }
)
# Page setup: wide layout, collapsed sidebar, German page title.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="collapsed",
    page_title="Fraktur Detektion",
)
# NOTE: an "Edge WebSocket fix" used to be injected here via
# st.markdown("<script>...</script>", unsafe_allow_html=True). It was removed:
# * Browsers do not execute <script> elements that are inserted as HTML
#   markup, so the snippet never ran.
# * Had it run, it would have rewritten the app's `wss://` stream URL to
#   `ws://`. That downgrades the connection to an unencrypted socket, which
#   HTTPS pages block as mixed content anyway.
# * Its user-agent test ("Edge") only matched legacy EdgeHTML. Chromium-based
#   Edge reports "Edg/".
# If a real Edge-specific workaround is ever needed, it must run in a context
# that executes JavaScript (e.g. streamlit.components.v1.html).
# Page-wide styling. It is injected as raw HTML so the <style> rules reach the
# whole Streamlit page; the `!important` flags let them win over Streamlit's
# own CSS.
#
# Deliberately NOT hidden:
# * Alerts. A former rule `.element-container:has(>.stAlert)` was meant to
#   suppress a deprecation warning, but it also hid every st.error(...) box,
#   so analysis failures were never shown to the user.
# * Iframes. They used to be set to `visibility: hidden` and re-shown by an
#   inline <script>. <script> elements inserted through st.markdown are not
#   executed by the browser, so any iframe would have stayed hidden for good.
_PAGE_STYLE = """
<style>
.stApp {
    background: #f0f2f5 !important;
}
.block-container {
    padding-top: 0 !important;
    padding-bottom: 0 !important;
    max-width: 1400px !important;
}
.upload-container {
    background: white;
    padding: 1.5rem;
    border-radius: 10px;
    box-shadow: 0 2px 4px rgba(0,0,0,0.1);
    margin-bottom: 1rem;
    text-align: center;
}
.results-container {
    background: white;
    padding: 1.5rem;
    border-radius: 10px;
    box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.result-box {
    background: #f8f9fa;
    padding: 0.75rem;
    border-radius: 8px;
    margin: 0.5rem 0;
    border: 1px solid #e9ecef;
}
h1, h2, h3, h4, p {
    color: #1a1a1a !important;
    margin: 0.5rem 0 !important;
}
.stImage {
    background: white;
    padding: 0.5rem;
    border-radius: 8px;
    box-shadow: 0 1px 3px rgba(0,0,0,0.1);
}
.stImage > img {
    max-height: 300px !important;
    width: auto !important;
    margin: 0 auto !important;
    display: block !important;
}
[data-testid="stFileUploader"] {
    width: 100% !important;
}
.stFileUploaderFileName {
    color: #1a1a1a !important;
}
.stButton > button {
    width: 200px;
    background-color: #f8f9fa !important;
    color: #1a1a1a !important;
    border: 1px solid #e9ecef !important;
    padding: 0.5rem 1rem !important;
    border-radius: 5px !important;
    transition: all 0.3s ease !important;
}
.stButton > button:hover {
    background-color: #e9ecef !important;
    transform: translateY(-1px);
}
#MainMenu, footer, header, [data-testid="stToolbar"] {
    display: none !important;
}
/* Hide expanders */
[data-testid="stExpander"] {
    display: none !important;
}
</style>
"""
st.markdown(_PAGE_STYLE, unsafe_allow_html=True)
# Model loading, wrapped in st.cache_resource so the pipelines are reused
# instead of being rebuilt on every script rerun.
@st.cache_resource(show_spinner=False)
def load_models():
    """Build the three Hugging Face pipelines used by the app.

    Returns:
        dict mapping display name to pipeline, in this order:
        "KnochenAuge" (object detection), "KnochenWächter" and
        "RöntgenMeister" (both image classification).
    """
    specs = [
        ("KnochenAuge", "object-detection", "D3STRON/bone-fracture-detr"),
        ("KnochenWächter", "image-classification", "Heem2/bone-fracture-detection-using-xray"),
        (
            "RöntgenMeister",
            "image-classification",
            "nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388",
        ),
    ]
    return {name: pipeline(task, model=model_id) for name, task, model_id in specs}
# Model label -> German display text. Keys must be lower-case because
# translate_label lower-cases its input before the lookup (the former "F1" and
# "NF" keys could therefore never match).
_LABEL_TRANSLATIONS = {
    "fracture": "Knochenbruch",
    "no fracture": "Kein Knochenbruch",
    "normal": "Normal",
    "abnormal": "Auffällig",
    "f1": "Knochenbruch",
    "nf": "Kein Knochenbruch",
}


def translate_label(label):
    """Translate a model label into German display text.

    The lookup ignores letter case and surrounding whitespace.

    Args:
        label: label string as returned by a pipeline prediction.

    Returns:
        The German translation, or *label* unchanged if it is not known.
    """
    return _LABEL_TRANSLATIONS.get(label.strip().lower(), label)
def create_heatmap_overlay(image, box, score):
    """Return a transparent RGBA layer with *box* shaded by its confidence.

    Colour bands: red for score > 0.8, orange for score > 0.6, yellow
    otherwise. The box interior uses the band colour at alpha 100, and the
    2-pixel outline uses it at full opacity.

    Args:
        image: PIL image the box belongs to; only its size is used.
        box: dict with 'xmin', 'ymin', 'xmax' and 'ymax' pixel coordinates.
        score: detection confidence.

    Returns:
        A new RGBA image with the same size as *image*.
    """
    if score > 0.8:
        rgb = (255, 0, 0)
    elif score > 0.6:
        rgb = (255, 165, 0)
    else:
        rgb = (255, 255, 0)
    corners = [box["xmin"], box["ymin"], box["xmax"], box["ymax"]]

    layer = Image.new("RGBA", image.size, (0, 0, 0, 0))
    painter = ImageDraw.Draw(layer)
    painter.rectangle(corners, fill=rgb + (100,))
    painter.rectangle(corners, outline=rgb + (255,), width=2)
    return layer
def draw_boxes(image, predictions):
    """Render detection boxes and their labels on top of *image*.

    Each prediction is shaded with create_heatmap_overlay. Labels are drawn on
    a separate transparent layer that is composited last. This means later
    boxes never paint over earlier labels, and the semi-transparent label
    background blends with the picture. Drawing directly on the RGBA result
    would instead overwrite those pixels with a partly transparent colour.

    Args:
        image: PIL image the box coordinates refer to.
        predictions: iterable of dicts with 'box' (xmin/ymin/xmax/ymax),
            'score' and 'label', like the object-detection pipeline output
            that main() passes in.

    Returns:
        A new RGBA image; *image* itself is not modified.
    """
    result_image = image.copy().convert("RGBA")
    label_layer = Image.new("RGBA", result_image.size, (0, 0, 0, 0))
    label_draw = ImageDraw.Draw(label_layer)

    for pred in predictions:
        box = pred["box"]
        score = pred["score"]
        result_image = Image.alpha_composite(
            result_image, create_heatmap_overlay(image, box, score)
        )

        # Decorative "heat" value derived from the score (36.5 °C at 0 % up to
        # 39.0 °C at 100 %). It is not a measured temperature.
        temp = 36.5 + (score * 2.5)
        # The old format string lacked a separator ("85.0%38.6°C").
        label = f"{translate_label(pred['label'])} ({score:.1%}, {temp:.1f}°C)"

        # Put the label 20 px above the box, but clamp it to the top edge so it
        # is not cut off for boxes that start near the top of the image.
        anchor = (box["xmin"], max(0, box["ymin"] - 20))
        label_draw.rectangle(label_draw.textbbox(anchor, label), fill=(0, 0, 0, 180))
        label_draw.text(anchor, label, fill=(255, 255, 255, 255))

    return Image.alpha_composite(result_image, label_layer)
def _is_fracture_label(label):
    """Return True if a classifier label reports a fracture, not its absence.

    A plain substring test (`'fracture' in label`) is wrong because negative
    labels such as "no fracture" or "not fractured" contain the word too.
    """
    words = label.lower().replace("_", " ").split()
    if words == ["f1"]:
        return True
    fracture_words = [w for w in words if "fracture" in w]
    if not fracture_words:
        return False
    negated = any(w in ("no", "not", "non", "without") for w in words) or any(
        w.startswith("no") for w in fracture_words  # e.g. "nofracture", "non-fracture"
    )
    return not negated


def _render_scores(title, predictions):
    """Write a markdown sub-heading, then one result box per classifier prediction."""
    st.markdown(title)
    for pred in predictions:
        # Blue for confident predictions (> 70 %), orange otherwise.
        confidence_color = '#0066cc' if pred['score'] > 0.7 else '#ffa500'
        st.markdown(
            f'<div class="result-box" style="color: #1a1a1a;">'
            f'<span style="color: {confidence_color}; font-weight: 500;">{pred["score"]:.1%}</span>'
            f' - {translate_label(pred["label"])}'
            f'</div>',
            unsafe_allow_html=True,
        )


def main():
    """Render the upload form and, when requested, the fracture analysis.

    On "Analysieren" the uploaded X-ray is run through both classifiers and
    the detector, each only once. The page then shows:
    * every classifier prediction;
    * the strongest KnochenWächter fracture score at or above the confidence
      threshold;
    * the image with the detection boxes at or above the threshold.
    Any exception is reported to the user with st.error.
    """
    try:
        models = load_models()

        with st.container():
            st.write("### 📤 Röntgenbild hochladen")
            uploaded_file = st.file_uploader(
                "Bild auswählen", type=['png', 'jpg', 'jpeg'], label_visibility="collapsed"
            )
            col1, col2 = st.columns([2, 1])
            with col1:
                conf_threshold = st.slider(
                    "Konfidenzschwelle",
                    min_value=0.0, max_value=1.0,
                    value=0.60, step=0.05,
                    label_visibility="visible",
                )
            with col2:
                analyze_button = st.button("Analysieren")

        if not (uploaded_file and analyze_button):
            return

        with st.spinner("Bild wird analysiert..."):
            image = Image.open(uploaded_file)
            results_container = st.container()
            predictions_watcher = models["KnochenWächter"](image)
            predictions_master = models["RöntgenMeister"](image)
            # The detector used to be run a second time for the overlay; its
            # result is now computed once and reused.
            predictions_locator = models["KnochenAuge"](image)

            filtered_preds = [p for p in predictions_locator if p['score'] >= conf_threshold]
            # Highest KnochenWächter score among labels that actually report a
            # fracture; 0 when none reaches the threshold.
            max_fracture_score = max(
                (
                    p['score']
                    for p in predictions_watcher
                    if p['score'] >= conf_threshold and _is_fracture_label(p['label'])
                ),
                default=0,
            )

            with results_container:
                st.write("### 🔍 Analyse Ergebnisse")
                col1, col2 = st.columns(2)
                with col1:
                    st.write("#### 🤖 KI-Diagnose")
                    _render_scores("#### 🛡️ KnochenWächter", predictions_watcher)
                    _render_scores("#### 🎓 RöntgenMeister", predictions_master)
                    if max_fracture_score > 0:
                        st.write("#### 📊 Wahrscheinlichkeit")
                        no_fracture_prob = 1 - max_fracture_score
                        st.markdown(
                            '<div class="result-box" style="color: #1a1a1a;">'
                            f'Knochenbruch: <strong style="color: #0066cc">{max_fracture_score:.1%}</strong><br>'
                            f'Kein Knochenbruch: <strong style="color: #ffa500">{no_fracture_prob:.1%}</strong>'
                            '</div>',
                            unsafe_allow_html=True,
                        )
                with col2:
                    if filtered_preds:
                        st.write("#### 🎯 Fraktur Lokalisation")
                        st.image(draw_boxes(image, filtered_preds), use_container_width=True)
                    else:
                        st.write("#### 🖼️ Röntgenbild")
                        st.image(image, use_container_width=True)
    except Exception as e:
        # Top-level UI boundary: show the failure instead of crashing the page.
        st.error(f"Ein Fehler ist aufgetreten: {e}")
# Script entry point: render the app when this module is executed as the
# main program.
if __name__ == "__main__":
    main()