radpid / app.py
yassonee's picture
Update app.py
ebc0e7a verified
raw
history blame
4.95 kB
import streamlit as st
from PIL import Image, ImageDraw, ImageOps
from transformers import pipeline
# Browser-tab title "Fraktur Detektion", full-width layout, sidebar collapsed
# on first load.
_PAGE_CONFIG = dict(
    page_title="Fraktur Detektion",
    layout="wide",
    initial_sidebar_state="collapsed",
)
st.set_page_config(**_PAGE_CONFIG)
# Compact page styling injected as raw HTML:
# - .stApp is pinned to the viewport height with overflow hidden (no page scroll),
# - the main container uses almost no padding and the full width,
# - images are shown at 70% width and at most 150px tall,
# - h2/h3 headings and the per-prediction `.result-box` lines use small fonts.
_COMPACT_CSS = """
<style>
.stApp {
padding: 0 !important;
height: 100vh !important;
overflow: hidden !important;
}
.block-container {
padding: 0.25rem !important;
max-width: 100% !important;
}
.stImage > img {
width: 70% !important;
height: auto !important;
max-height: 150px !important;
object-fit: contain !important;
}
h2, h3 {
font-size: 0.9rem !important;
}
.result-box {
font-size: 0.8rem !important;
margin: 0.2rem 0 !important;
}
</style>
"""
st.markdown(_COMPACT_CSS, unsafe_allow_html=True)
@st.cache_resource
def load_models():
    """Create the three Hugging Face pipelines used by the app.

    The result is cached with ``st.cache_resource`` so the pipelines are not
    rebuilt on every Streamlit rerun.

    Returns:
        dict mapping display name to pipeline:
        ``"KnochenAuge"`` (object detection),
        ``"KnochenWächter"`` and ``"RöntgenMeister"`` (image classification).
    """
    detector = pipeline("object-detection", model="D3STRON/bone-fracture-detr")
    guardian = pipeline(
        "image-classification",
        model="Heem2/bone-fracture-detection-using-xray",
    )
    master = pipeline(
        "image-classification",
        model="nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388",
    )
    return {
        "KnochenAuge": detector,
        "KnochenWächter": guardian,
        "RöntgenMeister": master,
    }
def translate_label(label):
    """Return the German display text for a model label.

    The lookup is case-insensitive. Labels with no known translation are
    returned unchanged, with their original casing.
    """
    german_by_label = {
        "fracture": "Knochenbruch",
        "no fracture": "Kein Bruch",
        "normal": "Normal",
        "abnormal": "Auffällig",
    }
    key = label.lower()
    if key not in german_by_label:
        return label
    return german_by_label[key]
def draw_boxes(image, predictions):
    """Annotate *image* with detection boxes, centre markers and score labels.

    The image is drawn on in place via ``ImageDraw`` and also returned, so
    callers that want to keep the original should pass a copy.

    Args:
        image: PIL image to annotate.
        predictions: Iterable of dicts with a ``'box'`` dict (``xmin``, ``ymin``,
            ``xmax``, ``ymax``), a ``'label'`` string and a ``'score'`` float.

    Returns:
        The same ``image`` object, now annotated.
    """
    draw = ImageDraw.Draw(image)
    for pred in predictions:
        box = pred['box']
        score = pred['score']
        label = f"{translate_label(pred['label'])} ({score:.2%})"
        # Blue for confident detections (> 0.7), amber for weaker ones.
        color = "#2563eb" if score > 0.7 else "#eab308"

        draw.rectangle(
            [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
            outline=color,
            width=2,
        )

        # Filled "heat" dot at the centre of each detected region.
        center_x = (box['xmin'] + box['xmax']) / 2
        center_y = (box['ymin'] + box['ymax']) / 2
        radius = 5
        draw.ellipse(
            [(center_x - radius, center_y - radius),
             (center_x + radius, center_y + radius)],
            fill=color,
        )

        # Compact label placed just above the box. The anchor is clamped to the
        # image so that boxes touching the top or right edge still get a visible
        # label instead of one drawn off-canvas.
        _, _, text_width, _ = draw.textbbox((0, 0), label)
        label_x = max(0, min(box['xmin'], image.width - text_width))
        label_y = max(0, box['ymin'] - 15)
        text_bbox = draw.textbbox((label_x, label_y), label)
        draw.rectangle(text_bbox, fill=color)
        draw.text((label_x, label_y), label, fill="white")
    return image
def _render_classification(title, classifier, image):
    """Show *title*, run *classifier* on *image* and list each prediction with its score."""
    st.markdown(title)
    for pred in classifier(image):
        # Green for scores above 0.7, amber otherwise.
        score_color = "#22c55e" if pred['score'] > 0.7 else "#eab308"
        # Built on one line so source indentation can never turn the HTML into
        # an indented Markdown code block.
        html = (
            f"<div class='result-box'>"
            f"<span style='color: {score_color}; font-weight: 500;'>"
            f"{pred['score']:.1%}"
            f"</span> - {translate_label(pred['label'])}"
            f"</div>"
        )
        st.markdown(html, unsafe_allow_html=True)


def main():
    """Render the upload/analysis layout and run all three models on an upload.

    Left column: the file uploader and, once a file is present, the confidence
    slider used to filter the detector's boxes. Right column: the two
    classifiers' predictions and the detector's boxes drawn on the image, or an
    upload hint when no file has been provided.
    """
    models = load_models()

    # Two main columns: a narrow upload panel and a wide results panel.
    col1, col2 = st.columns([1, 2])

    with col1:
        st.markdown("### 📤 Röntgenbild Upload")
        # A non-empty label is needed; Streamlit warns about empty labels.
        # The label is hidden visually because the heading above already says it.
        uploaded_file = st.file_uploader(
            "Röntgenbild",
            type=['png', 'jpg', 'jpeg'],
            label_visibility="collapsed",
        )
        if uploaded_file:
            conf_threshold = st.slider(
                "Konfidenzschwelle",
                min_value=0.0, max_value=1.0,
                value=0.60, step=0.05,
            )

    with col2:
        if not uploaded_file:
            st.info("Bitte laden Sie ein Röntgenbild hoch (JPEG, PNG)")
            return

        # Apply the EXIF orientation and convert to RGB once, up front. The
        # pipelines do the same to their own copy, so the detector's box
        # coordinates match this image. Coloured boxes also stay coloured on
        # grayscale X-rays.
        try:
            image = ImageOps.exif_transpose(Image.open(uploaded_file)).convert("RGB")
        except OSError:  # includes PIL.UnidentifiedImageError for corrupt uploads
            st.error("Die Datei konnte nicht als Bild gelesen werden.")
            return

        # The classifier results are always shown.
        st.markdown("### 🎯 KI-Analyse")
        _render_classification("**🛡️ Der KnochenWächter**", models["KnochenWächter"], image)
        _render_classification("**🎓 Der RöntgenMeister**", models["RöntgenMeister"], image)

        # Localisation with KnochenAuge: only boxes at or above the slider threshold.
        predictions_auge = models["KnochenAuge"](image)
        filtered_preds = [p for p in predictions_auge if p['score'] >= conf_threshold]
        if filtered_preds:
            st.markdown("#### 👁️ Das KnochenAuge - Lokalisation")
            result_image = draw_boxes(image.copy(), filtered_preds)
            st.image(result_image, use_container_width=True)
        else:
            st.info(
                "Das KnochenAuge hat keine Auffälligkeit oberhalb der "
                "Konfidenzschwelle lokalisiert."
            )
if __name__ == "__main__":
    # Build the app when this file is executed as the main script.
    main()