radpid / app.py
yassonee's picture
Update app.py
9aecd9e verified
raw
history blame
7.75 kB
import streamlit as st
from transformers import pipeline
from PIL import Image, ImageDraw
import torch
# Page configuration: browser-tab title, wide layout, sidebar collapsed on load.
_PAGE_CONFIG = {
    "page_title": "Knochenbrucherkennung",
    "layout": "wide",
    "initial_sidebar_state": "collapsed",
}
st.set_page_config(**_PAGE_CONFIG)
# Inject a custom <style> block into the page (raw HTML is enabled through
# unsafe_allow_html=True). The rules below make the app and container
# backgrounds transparent, tighten paddings, restyle the upload area, tabs,
# images, spinner and expander header, define the `.result-box` class used for
# prediction rows, hide #MainMenu / footer / header, and add overrides inside a
# `prefers-color-scheme: dark` media query.
st.markdown("""
<style>
/* Reset et base */
.stApp {
background-color: transparent !important;
padding: 0 !important;
}
.block-container {
padding: 1rem !important;
max-width: 100% !important;
}
/* En-tête compact */
h1 {
font-size: 1.5rem !important;
margin-bottom: 1rem !important;
}
/* Conteneurs */
.main > div {
padding: 1rem !important;
background: transparent !important;
border-radius: 0.5rem !important;
box-shadow: none !important;
}
/* Upload plus compact */
.uploadedFile {
border: 1px dashed #ccc;
border-radius: 0.5rem;
padding: 0.5rem;
background: rgba(255, 255, 255, 0.05);
}
/* Tabs style */
.stTabs [data-baseweb="tab-list"] {
gap: 1rem;
background-color: transparent;
}
.stTabs [data-baseweb="tab"] {
padding: 0.5rem 1rem;
border-radius: 0.5rem;
background: rgba(255, 255, 255, 0.1);
}
/* Résultats */
.result-box {
padding: 0.5rem;
border-radius: 0.375rem;
margin: 0.25rem 0;
background: rgba(255, 255, 255, 0.05);
border: 1px solid rgba(255, 255, 255, 0.1);
}
/* Images */
.stImage img {
max-height: 300px !important;
width: auto !important;
border-radius: 0.375rem;
margin: 0 auto;
}
/* Spinner plus petit */
.stSpinner > div {
height: 2rem !important;
width: 2rem !important;
}
/* Cacher éléments Streamlit */
#MainMenu {display: none;}
footer {display: none;}
header {display: none;}
/* Dark mode support */
@media (prefers-color-scheme: dark) {
.main > div {
background: rgba(0, 0, 0, 0.2) !important;
}
.uploadedFile {
border-color: #4a5568;
background: rgba(255, 255, 255, 0.05);
}
.stTabs [data-baseweb="tab"] {
background: rgba(255, 255, 255, 0.05);
}
.result-box {
background: rgba(255, 255, 255, 0.05);
border-color: rgba(255, 255, 255, 0.2);
}
}
/* Expander plus compact */
.streamlit-expanderHeader {
padding: 0.5rem !important;
background: rgba(255, 255, 255, 0.05) !important;
border-radius: 0.375rem !important;
}
</style>
""", unsafe_allow_html=True)
@st.cache_resource
def load_models():
    """Create the three Hugging Face pipelines used by the app.

    Returns:
        A dict mapping a short model name ("D3STRON", "Heem2",
        "Nandodeomkar") to its ``transformers`` pipeline. "D3STRON" is an
        object-detection pipeline; the other two are image-classification
        pipelines.
    """
    # (short name, pipeline task, Hugging Face model id) -- order is preserved
    # in the returned dict.
    pipeline_specs = (
        ("D3STRON", "object-detection", "D3STRON/bone-fracture-detr"),
        ("Heem2", "image-classification", "Heem2/bone-fracture-detection-using-xray"),
        (
            "Nandodeomkar",
            "image-classification",
            "nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388",
        ),
    )
    return {
        short_name: pipeline(task, model=model_id)
        for short_name, task, model_id in pipeline_specs
    }
def translate_label(label):
    """Return the German display text for a model label.

    The lookup ignores case. A label with no known translation is returned
    exactly as it was passed in.
    """
    german_by_label = {
        "fracture": "Knochenbruch",
        "no fracture": "Kein Bruch",
        "normal": "Normal",
        "abnormal": "Abnormal",
    }
    lookup_key = label.lower()
    try:
        return german_by_label[lookup_key]
    except KeyError:
        # Unknown label: hand back the caller's original text, not the lowered key.
        return label
def draw_boxes(image, predictions):
    """Draw a captioned rectangle on ``image`` for every prediction.

    Drawing happens directly on the image that is passed in; callers that need
    the untouched original should pass a copy.

    Args:
        image: PIL image to annotate.
        predictions: Iterable of dicts, each with a ``'box'`` dict holding
            ``xmin``/``ymin``/``xmax``/``ymax``, a ``'label'`` string and a
            numeric ``'score'``.

    Returns:
        The same ``image`` object that was passed in.
    """
    draw = ImageDraw.Draw(image)
    for pred in predictions:
        box = pred['box']
        score = pred['score']
        label = f"{translate_label(pred['label'])} ({score:.2%})"
        # Colour encodes the score: red above 0.7, orange otherwise.
        color = "#FF6B6B" if score > 0.7 else "#FFA500"
        draw.rectangle(
            [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
            outline=color,
            width=2,
        )
        # The caption sits 15 px above the box. Clamp it to the top edge so a
        # box that starts within the first 15 rows still gets a visible
        # caption instead of one drawn at a negative y and clipped away.
        label_origin = (box['xmin'], max(box['ymin'] - 15, 0))
        text_bbox = draw.textbbox(label_origin, label)
        draw.rectangle(text_bbox, fill=color)
        draw.text(label_origin, label, fill="white")
    return image
def _load_display_image(uploaded_file):
    """Open an upload as an RGB image that fits within 300x300 pixels.

    Returns:
        The prepared PIL image, or ``None`` when the file cannot be read as
        an image.
    """
    try:
        # X-ray uploads are often single-channel; converting to RGB lets the
        # coloured boxes drawn later keep their colour. convert() also forces
        # the pixel data to be decoded, so a corrupt file fails here.
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:
        # Pillow reports unreadable / unrecognised image files as OSError.
        return None
    image.thumbnail((300, 300), Image.Resampling.LANCZOS)
    return image


def _render_classification(models, image, conf_threshold):
    """Show every classifier prediction whose score reaches the threshold."""
    for name in ("Heem2", "Nandodeomkar"):
        for pred in models[name](image):
            if pred['score'] < conf_threshold:
                continue
            score_color = "green" if pred['score'] > 0.7 else "orange"
            st.markdown(f"""
<div class='result-box'>
<span style='color: {score_color}; font-weight: bold;'>
{pred['score']:.1%}
</span> - {translate_label(pred['label'])}
</div>
""", unsafe_allow_html=True)


def _render_localization(detector, image, conf_threshold):
    """Show the annotated image and a list of detections above the threshold."""
    filtered_preds = [p for p in detector(image) if p['score'] >= conf_threshold]
    if not filtered_preds:
        st.info("Keine Erkennungen über dem Schwellenwert")
        return
    # Annotate a copy so the image shown as "Originalbild" stays clean.
    result_image = draw_boxes(image.copy(), filtered_preds)
    st.image(result_image, use_column_width=True)
    for pred in filtered_preds:
        st.markdown(f"""
<div class='result-box'>
{translate_label(pred['label'])}: {pred['score']:.1%}
</div>
""", unsafe_allow_html=True)


def main():
    """Render the page: title, settings, upload widget and prediction results.

    With an uploaded image, the left column shows the (down-scaled) image and
    the right column shows two tabs: classification results from the "Heem2"
    and "Nandodeomkar" models, and detections from the "D3STRON" model drawn
    onto the image. Predictions below the confidence slider value are hidden.
    Without an upload, a short instruction box is shown instead.
    """
    st.markdown("<h1>🦴 KI-Fraktur Detektion</h1>", unsafe_allow_html=True)
    models = load_models()

    # Settings live in a collapsed expander to keep the page compact.
    with st.expander("⚙️ Einstellungen", expanded=False):
        conf_threshold = st.slider(
            "Konfidenzschwelle",
            min_value=0.0,
            max_value=1.0,
            value=0.60,
            step=0.05,
        )

    # Give the widget a real label and hide it visually, rather than passing
    # an empty label (which Streamlit discourages for accessibility reasons).
    uploaded_file = st.file_uploader(
        "Röntgenbild hochladen",
        type=['png', 'jpg', 'jpeg'],
        help="Unterstützte Formate: JPEG, PNG | Max: 5MB",
        label_visibility="collapsed",
    )

    if uploaded_file:
        image = _load_display_image(uploaded_file)
        if image is None:
            # Report a bad upload to the user instead of crashing the script.
            st.error("Die Datei konnte nicht als Bild gelesen werden.")
        else:
            # Two equal columns: original image left, results right.
            col1, col2 = st.columns([1, 1])
            with col1:
                st.image(image, caption="Originalbild", use_column_width=True)
            with col2:
                tab1, tab2 = st.tabs(["📊 Klassifizierung", "🔍 Lokalisierung"])
                with tab1:
                    _render_classification(models, image, conf_threshold)
                with tab2:
                    _render_localization(models["D3STRON"], image, conf_threshold)
    else:
        # Instruction box shown while no file has been uploaded.
        st.markdown("""
<div style='padding: 1rem; background: rgba(59, 130, 246, 0.1); border-radius: 0.5rem;'>
<h4 style='margin: 0 0 0.5rem 0; font-size: 1rem;'>📤 Röntgenbild hochladen</h4>
<ul style='margin: 0; padding-left: 1rem; font-size: 0.875rem;'>
<li>Unterstützte Formate: JPEG, PNG</li>
<li>Maximale Größe: 5 MB</li>
<li>Optimale Auflösung: 512x512 Pixel</li>
</ul>
</div>
""", unsafe_allow_html=True)

    # Theme-handling script.
    # NOTE(review): <script> tags passed through st.markdown are presumably not
    # executed by the browser, so this listener likely never runs -- confirm,
    # and move it to an HTML component if the behaviour is actually needed.
    st.markdown("""
<script>
window.addEventListener('message', function(e) {
if (e.data.type === 'theme-change') {
document.body.classList.toggle('dark', e.data.theme === 'dark');
}
});
</script>
""", unsafe_allow_html=True)
# Build the page only when this file is run as the main script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()