radpid / app.py
yassonee's picture
Update app.py
5db7880 verified
raw
history blame
5.49 kB
import streamlit as st
from transformers import pipeline
from PIL import Image, ImageDraw
# Global CSS injected once at start-up: removes outer padding, pins the app to
# the viewport height, shrinks headings/result text and caps preview images at
# 150px so the whole analysis fits on a single screen.
_PAGE_CSS = """
<style>
.stApp {
padding: 0 !important;
height: 100vh !important;
overflow: hidden !important;
}
.block-container {
padding: 0.25rem !important;
max-width: 100% !important;
}
.stImage > img {
max-height: 150px !important;
object-fit: contain !important;
}
h2, h3 {
font-size: 0.9rem !important;
}
.result-box {
font-size: 0.8rem !important;
margin: 0.2rem 0 !important;
}
.center-container {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
height: 100%;
}
</style>
"""

# Page configuration is issued before any other Streamlit call.
st.set_page_config(
    page_title="Fraktur Detektion",
    layout="wide",
    initial_sidebar_state="collapsed",
)
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
@st.cache_resource
def load_models():
    """Create the three Hugging Face pipelines used by the app.

    Decorated with ``st.cache_resource`` so the pipelines are shared instead of
    being rebuilt on every script run.

    Returns:
        dict mapping the display name of each model to its pipeline:
        ``"KnochenAuge"`` (object detection), ``"KnochenWächter"`` and
        ``"RöntgenMeister"`` (image classification).
    """
    # (display name, pipeline task, Hub model id) -- built in this order.
    model_specs = (
        ("KnochenAuge", "object-detection", "D3STRON/bone-fracture-detr"),
        ("KnochenWächter", "image-classification",
         "Heem2/bone-fracture-detection-using-xray"),
        ("RöntgenMeister", "image-classification",
         "nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388"),
    )
    return {
        display_name: pipeline(task, model=model_id)
        for display_name, task, model_id in model_specs
    }
# Lower-cased model labels -> German display text.
_LABEL_TRANSLATIONS = {
    "fracture": "Knochenbruch",
    "no fracture": "Kein Bruch",
    "normal": "Normal",
    "abnormal": "Auffällig",
}


def translate_label(label):
    """Return the German display text for a model label.

    The lookup is case-insensitive (the label is lower-cased with
    ``str.lower``). Labels without a known translation are returned unchanged.
    """
    normalized = label.lower()
    if normalized in _LABEL_TRANSLATIONS:
        return _LABEL_TRANSLATIONS[normalized]
    return label
def draw_boxes(image, predictions, *, high_confidence=0.7, label_offset=15):
    """Draw detection boxes, centre markers and labels onto an image.

    Args:
        image: PIL image to annotate. RGB/RGBA images are drawn on in place.
            Any other mode (e.g. grayscale "L" or 16-bit X-rays) is first
            converted to an RGB copy, because PIL turns colour inks into grey
            levels on single-band images and the blue/yellow confidence
            coding would be lost. Always use the returned image.
        predictions: object-detection results; each item is a dict with
            ``'box'`` (keys ``'xmin'``, ``'ymin'``, ``'xmax'``, ``'ymax'``),
            ``'label'`` and ``'score'``.
        high_confidence: scores strictly above this value are drawn in blue,
            all others in yellow.
        label_offset: distance in pixels between the label and the top edge
            of its box.

    Returns:
        The annotated image.
    """
    if image.mode not in ("RGB", "RGBA"):
        image = image.convert("RGB")
    draw = ImageDraw.Draw(image)
    radius = 5
    for pred in predictions:
        box = pred['box']
        label = f"{translate_label(pred['label'])} ({pred['score']:.2%})"
        color = "#2563eb" if pred['score'] > high_confidence else "#eab308"
        draw.rectangle(
            [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
            outline=color,
            width=2,
        )
        # Mark the centre of each detected region with a filled "heat" dot.
        center_x = (box['xmin'] + box['xmax']) / 2
        center_y = (box['ymin'] + box['ymax']) / 2
        draw.ellipse(
            [(center_x - radius, center_y - radius),
             (center_x + radius, center_y + radius)],
            fill=color,
        )
        # Compact label just above the box; clamp to the top edge so boxes
        # touching the image border do not get an off-canvas (invisible) label.
        text_y = max(0, box['ymin'] - label_offset)
        draw.text((box['xmin'], text_y), label, fill="white")
    return image
def _reset_session():
    """Forget the stored upload and analysis flag so a new image can be chosen."""
    for key in ("uploaded", "file", "analyze"):
        st.session_state.pop(key, None)


def _show_upload_step():
    """Render the upload prompt and store the chosen file in the session.

    When a file arrives the script is rerun immediately. This pass renders only
    the upload step, so without the rerun the "Analyse starten" button would
    not appear until some unrelated interaction triggered another run.
    """
    st.markdown(
        '<div class="center-container">'
        "<h2>📤 Röntgenbild Hochladen</h2>"
        "<p>Bitte laden Sie ein Röntgenbild hoch, um die Analyse zu starten.</p>"
        "</div>",
        unsafe_allow_html=True,
    )
    uploaded_file = st.file_uploader(
        "Röntgenbild auswählen",
        type=['png', 'jpg', 'jpeg'],
        label_visibility="collapsed",
    )
    if uploaded_file is not None:
        st.session_state["uploaded"] = True
        st.session_state["file"] = uploaded_file
        st.session_state["analyze"] = False
        st.rerun()


def _show_classification(predictions):
    """Render image-classification results as coloured "score - label" lines.

    Scores above 0.7 are shown in green, all others in yellow.
    """
    for pred in predictions:
        score_color = "#22c55e" if pred['score'] > 0.7 else "#eab308"
        label = translate_label(pred['label'])
        st.markdown(
            f"<div class='result-box'>"
            f"<span style='color: {score_color}; font-weight: 500;'>"
            f"{pred['score']:.1%}"
            f"</span> - {label}"
            f"</div>",
            unsafe_allow_html=True,
        )


def main():
    """Streamlit entry point: upload an X-ray, then run and show the three models.

    UI state lives in ``st.session_state``:
    ``"uploaded"`` means a file is stored under ``"file"``, and ``"analyze"``
    means the user pressed the start button for that file.
    """
    models = load_models()

    if not st.session_state.get("uploaded", False):
        _show_upload_step()
        return

    # Without this there is no way back to the uploader within the session.
    if st.button("🔄 Neues Bild hochladen"):
        _reset_session()
        st.rerun()

    if not st.session_state.get("analyze", False):
        if not st.button("🔍 Analyse starten"):
            return
        st.session_state["analyze"] = True

    try:
        # convert() forces a full decode, so corrupt or non-image uploads fail
        # here (UnidentifiedImageError is an OSError) rather than inside a
        # model. RGB also keeps the coloured boxes visible on grayscale X-rays.
        image = Image.open(st.session_state["file"]).convert("RGB")
    except OSError:
        st.error("Die Datei konnte nicht als Bild gelesen werden. "
                 "Bitte laden Sie ein anderes Röntgenbild hoch.")
        return

    col1, col2, col3 = st.columns([1, 1.5, 1])
    with col1:
        st.markdown("### 🎯 KI-Analyse")
        st.markdown("**🛡️ Der KnochenWächter**")
        _show_classification(models["KnochenWächter"](image))
        st.markdown("**🎓 Der RöntgenMeister**")
        _show_classification(models["RöntgenMeister"](image))
    with col2:
        st.image(image, use_container_width=True)
        predictions_auge = models["KnochenAuge"](image)

    # Only detections with at least 60 % confidence are localized.
    filtered_preds = [p for p in predictions_auge if p['score'] >= 0.6]
    with col3:
        st.markdown("### 👁️ Das KnochenAuge - Lokalisation")
        if filtered_preds:
            result_image = draw_boxes(image.copy(), filtered_preds)
            st.image(result_image, use_container_width=True)
        else:
            st.info("Keine Fraktur mit ausreichender Konfidenz lokalisiert.")
# Start the app only when this file is executed as the main script;
# importing the module (e.g. to reuse its helpers) does not launch the UI.
if __name__ == "__main__":
    main()