import html

import streamlit as st
from PIL import Image, ImageDraw
from transformers import pipeline


# Configure the browser tab title and use a wide layout with the sidebar
# collapsed.
st.set_page_config(
    page_title="Fraktur Detektion",
    layout="wide",
    initial_sidebar_state="collapsed",
)


# Inject global CSS overrides: a non-scrolling full-height app, tight
# paddings, small headings, a 150px cap on image height, and the helper
# classes (`result-box`, `center-container`) used by the markup in main().
st.markdown("""
<style>
.stApp {
    padding: 0 !important;
    height: 100vh !important;
    overflow: hidden !important;
}

.block-container {
    padding: 0.25rem !important;
    max-width: 100% !important;
}

.stImage > img {
    max-height: 150px !important;
    object-fit: contain !important;
}

h2, h3 {
    font-size: 0.9rem !important;
}

.result-box {
    font-size: 0.8rem !important;
    margin: 0.2rem 0 !important;
}

.center-container {
    display: flex;
    flex-direction: column;
    align-items: center;
    justify-content: center;
    height: 100%;
}
</style>
""", unsafe_allow_html=True)


@st.cache_resource
def load_models():
    """Create the three inference pipelines used by the app.

    Wrapped in ``st.cache_resource`` so the pipelines are shared instead of
    being rebuilt on every script run.

    Returns:
        dict: Maps the display name of each model to its pipeline:
        ``"KnochenAuge"`` (object detection), ``"KnochenWächter"`` and
        ``"RöntgenMeister"`` (both image classification).
    """
    return {
        "KnochenAuge": pipeline(
            "object-detection",
            model="D3STRON/bone-fracture-detr",
        ),
        "KnochenWächter": pipeline(
            "image-classification",
            model="Heem2/bone-fracture-detection-using-xray",
        ),
        "RöntgenMeister": pipeline(
            "image-classification",
            model="nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388",
        ),
    }


# English model labels (lower-case, words separated by single spaces) mapped
# to the German text shown in the UI.
_LABEL_TRANSLATIONS = {
    "fracture": "Knochenbruch",
    "no fracture": "Kein Bruch",
    "normal": "Normal",
    "abnormal": "Auffällig",
}


def translate_label(label):
    """Translate a model label into its German display text.

    Matching ignores case, surrounding whitespace, and whether words are
    separated by spaces, underscores or hyphens, so ``"No_Fracture"`` and
    ``" no fracture "`` both map to ``"Kein Bruch"``.

    Args:
        label: Label as emitted by a model.

    Returns:
        The German translation if the label is known; otherwise ``label``
        unchanged (including non-string values).
    """
    if not isinstance(label, str):
        return label

    # Collapse the separators to single spaces before the lookup.
    words = label.replace("_", " ").replace("-", " ").split()
    return _LABEL_TRANSLATIONS.get(" ".join(words).lower(), label)


def draw_boxes(image, predictions):
    """Draw detection boxes, centre markers and captions onto ``image``.

    Args:
        image: PIL image to annotate. ``RGB``/``RGBA`` images are drawn on
            in place. Any other mode is first converted to ``RGB`` (which
            yields a new image object), because Pillow reduces colour
            strings to grey values on non-colour images and the
            blue/yellow confidence coding would be lost.
        predictions: Iterable of dicts, each with a ``score`` (float), a
            ``label`` and a ``box`` dict holding ``xmin``, ``ymin``,
            ``xmax`` and ``ymax``.

    Returns:
        The annotated image. Always use the return value: it is a different
        object from ``image`` when a mode conversion took place.
    """
    if image.mode not in ("RGB", "RGBA"):
        image = image.convert("RGB")

    draw = ImageDraw.Draw(image)
    radius = 5
    for pred in predictions:
        box = pred['box']
        label = f"{translate_label(pred['label'])} ({pred['score']:.2%})"
        # Blue above the 0.7 score threshold, yellow otherwise.
        color = "#2563eb" if pred['score'] > 0.7 else "#eab308"

        draw.rectangle(
            [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
            outline=color,
            width=2,
        )

        # Filled dot at the centre of the box.
        center_x = (box['xmin'] + box['xmax']) / 2
        center_y = (box['ymin'] + box['ymax']) / 2
        draw.ellipse(
            [(center_x - radius, center_y - radius), (center_x + radius, center_y + radius)],
            fill=color,
        )

        # Caption sits 15px above the box; clamp it so it stays on the
        # canvas for boxes that touch the top edge.
        text_y = max(box['ymin'] - 15, 0)
        draw.text((box['xmin'], text_y), label, fill="white")
    return image


def _render_classification(heading, predictions):
    """Render one classifier's heading followed by a row per prediction."""
    st.markdown(heading)
    for pred in predictions:
        # Green above the 0.7 score threshold, yellow otherwise.
        score_color = "#22c55e" if pred['score'] > 0.7 else "#eab308"
        # Labels come from the model and unknown ones pass through
        # translate_label() unchanged, so escape before embedding in HTML.
        label_text = html.escape(str(translate_label(pred['label'])))
        st.markdown(f"""
            <div class='result-box'>
                <span style='color: {score_color}; font-weight: 500;'>
                    {pred['score']:.1%}
                </span> - {label_text}
            </div>
        """, unsafe_allow_html=True)


def main():
    """Run the app: upload step, analysis trigger, then the result view.

    State kept in ``st.session_state``:
        ``uploaded`` - whether an image has been stored,
        ``file``     - the stored upload,
        ``analyze``  - whether the user has started the analysis.

    The result view has three columns: classifier scores, the original
    image, and (only when a detection scores at least 0.6) the image with
    the detections drawn in.
    """
    models = load_models()

    if "uploaded" not in st.session_state:
        st.session_state["uploaded"] = False

    # Step 1: ask for an image.
    if not st.session_state["uploaded"]:
        st.markdown("""
            <div class="center-container">
                <h2>📤 Röntgenbild Hochladen</h2>
                <p>Bitte laden Sie ein Röntgenbild hoch, um die Analyse zu starten.</p>
            </div>
        """, unsafe_allow_html=True)
        uploaded_file = st.file_uploader(
            "Röntgenbild auswählen",
            type=['png', 'jpg', 'jpeg'],
            label_visibility="collapsed",
        )

        if uploaded_file:
            st.session_state["uploaded"] = True
            st.session_state["file"] = uploaded_file
            st.session_state["analyze"] = False
            # Restart the script now so the analysis button shows up right
            # away; without this the page keeps showing only the uploader
            # until some other interaction triggers a rerun.
            st.rerun()
        return

    uploaded_file = st.session_state["file"]

    # Step 2: wait for the user to start the analysis.
    if not st.session_state.get("analyze", False):
        if st.button("🔍 Analyse starten"):
            st.session_state["analyze"] = True

    if not st.session_state.get("analyze", False):
        return

    # Step 3: decode the image and show the results.
    try:
        image = Image.open(uploaded_file)
        # Decode eagerly so a corrupt file fails here rather than inside a
        # model call or while rendering.
        image.load()
    except OSError:
        st.error(
            "Das Bild konnte nicht gelesen werden. "
            "Bitte laden Sie eine gültige PNG- oder JPEG-Datei hoch."
        )
        # Drop the stored upload; clicking the button reruns the script,
        # which then shows the uploader again.
        st.session_state["uploaded"] = False
        st.session_state["analyze"] = False
        st.session_state.pop("file", None)
        st.button("🔄 Anderes Bild hochladen")
        return

    col1, col2, col3 = st.columns([1, 1.5, 1])

    with col1:
        st.markdown("### 🎯 KI-Analyse")
        _render_classification(
            "**🛡️ Der KnochenWächter**",
            models["KnochenWächter"](image),
        )
        _render_classification(
            "**🎓 Der RöntgenMeister**",
            models["RöntgenMeister"](image),
        )

    with col2:
        st.image(image, use_container_width=True)

    # Only detections scoring at least 0.6 are shown.
    predictions_auge = models["KnochenAuge"](image)
    filtered_preds = [p for p in predictions_auge if p['score'] >= 0.6]

    if filtered_preds:
        with col3:
            st.markdown("### 👁️ Das KnochenAuge - Lokalisation")
            # Annotate a copy so the image shown in the middle column
            # stays untouched.
            result_image = draw_boxes(image.copy(), filtered_preds)
            st.image(result_image, use_container_width=True)


# Script entry point.
if __name__ == "__main__":
    main()