|
import streamlit as st |
|
from transformers import pipeline |
|
from PIL import Image, ImageDraw |
|
import torch |
|
|
|
# Page-level settings: browser-tab title and a full-width layout so the two
# result columns in main() have room side by side.
st.set_page_config(
    layout="wide",
    page_title="Multi-Model Fracture Detection",
)
|
|
|
@st.cache_resource
def load_models():
    """Build every Hugging Face pipeline used by the app, keyed by display name.

    Wrapped in ``st.cache_resource`` so the constructed pipelines are cached as
    a shared resource rather than rebuilt by every call.

    Returns:
        dict mapping a display name to its ``transformers`` pipeline, in the
        order listed below. The name's suffix ("Classification" or
        "Object Detection") is what ``main`` uses to decide how each
        pipeline's output is rendered.
    """
    # (display name, pipeline task, Hugging Face Hub model id). Pipelines are
    # constructed in this order and the returned dict preserves it.
    specs = (
        ("D3STRON (Object Detection)", "object-detection",
         "D3STRON/bone-fracture-detr"),
        ("Heem2 (Classification)", "image-classification",
         "Heem2/bone-fracture-detection-using-xray"),
        ("Akhileshav8 (Classification)", "image-classification",
         "akhileshav8/image_classification_for_fracture"),
        ("Nandodeomkar (Classification)", "image-classification",
         "nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388"),
        ("Anirban22 (Object Detection)", "object-detection",
         "anirban22/detr-resnet-50-med_fracture"),
    )
    return {name: pipeline(task, model=repo_id) for name, task, repo_id in specs}
|
|
|
def draw_boxes(image, predictions):
    """Draw labelled bounding boxes for object-detection predictions.

    Args:
        image: PIL image the predictions were computed on; box coordinates are
            interpreted in its pixel space.
        predictions: iterable of dicts with a ``'box'`` dict (``xmin``,
            ``ymin``, ``xmax``, ``ymax``), plus ``'label'`` and ``'score'``
            keys, as produced by the ``object-detection`` pipeline.

    Returns:
        The annotated image. RGB/RGBA images are annotated in place and
        returned. Any other mode (e.g. grayscale ``"L"`` or palette ``"P"``)
        is first converted to an RGB copy, and that copy is annotated and
        returned.
    """
    # Pillow maps named colours to the image's own mode (e.g. "red" becomes a
    # grey level on an "L" image), so grayscale X-rays would otherwise get
    # grey, not red, annotations.
    if image.mode not in ("RGB", "RGBA"):
        image = image.convert("RGB")

    draw = ImageDraw.Draw(image)
    for pred in predictions:
        box = pred['box']
        top_left = (box['xmin'], box['ymin'])
        caption = f"{pred['label']} ({pred['score']:.2%})"

        draw.rectangle(
            [top_left, (box['xmax'], box['ymax'])],
            outline="red",
            width=3,
        )

        # Solid background behind the caption keeps the text legible on top
        # of the X-ray.
        draw.rectangle(draw.textbbox(top_left, caption), fill="red")
        draw.text(top_left, caption, fill="white")

    return image
|
|
|
def process_classification(model, image, conf_threshold):
    """Run a classifier and format the labels whose score clears a threshold.

    Args:
        model: callable (e.g. an ``image-classification`` pipeline) that,
            given ``image``, returns a list of dicts with ``'label'`` and
            ``'score'`` keys.
        image: input forwarded unchanged to ``model``.
        conf_threshold: minimum score (inclusive) a prediction needs to be kept.

    Returns:
        list of ``"<label>: <score as percent>"`` strings, in the order
        ``model`` returned the predictions.
    """
    return [
        f"{candidate['label']}: {candidate['score']:.2%}"
        for candidate in model(image)
        if candidate['score'] >= conf_threshold
    ]
|
|
|
def process_detection(model, image, conf_threshold):
    """Run an object detector and keep detections scoring at or above a threshold.

    The ``object-detection`` pipeline drops low-scoring boxes itself, using its
    own default ``threshold``, before returning. ``conf_threshold`` is
    therefore forwarded to the pipeline; without this, a user-selected
    threshold below the pipeline's default could never surface the extra
    detections.

    Args:
        model: ``object-detection`` pipeline (or any callable accepting a
            ``threshold`` keyword) returning a list of dicts with at least a
            ``'score'`` key (plus ``'label'`` and ``'box'`` for rendering).
        image: input forwarded unchanged to ``model``.
        conf_threshold: minimum score (inclusive) a detection needs to be kept.

    Returns:
        list of the prediction dicts with ``score >= conf_threshold``, in the
        order ``model`` returned them.
    """
    predictions = model(image, threshold=conf_threshold)
    # Filter locally too, so the contract holds even if the model ignores or
    # interprets `threshold` differently.
    return [pred for pred in predictions if pred['score'] >= conf_threshold]
|
|
|
def _render_classification(models, image, conf_threshold):
    """Render each classification pipeline's above-threshold labels as bullets."""
    st.subheader("Classification Models")
    for name, model in models.items():
        # load_models() encodes the pipeline kind in the display name.
        if "Classification" not in name:
            continue
        st.write(f"**{name}**")
        with st.spinner(f"Running {name}..."):
            results = process_classification(model, image, conf_threshold)
        if not results:
            # Mirror the detection column so an empty result is explicit.
            st.write("No predictions above threshold")
        for result in results:
            st.write(f"• {result}")


def _render_detection(models, image, conf_threshold):
    """Render each detection pipeline's annotated image and detection list."""
    st.subheader("Object Detection Models")
    for name, model in models.items():
        if "Object Detection" not in name:
            continue
        st.write(f"**{name}**")
        with st.spinner(f"Running {name}..."):
            detections = process_detection(model, image, conf_threshold)
        if not detections:
            st.write("No detections above threshold")
            continue
        # Annotate a copy so every model draws on the clean image.
        result_image = draw_boxes(image.copy(), detections)
        st.image(result_image, caption=f"Results from {name}")
        for det in detections:
            st.write(f"• {det['label']}: {det['score']:.2%}")


def main():
    """Build the fracture-detection page and run all models on an uploaded X-ray.

    The uploaded image is converted to RGB and shrunk to fit 400x400 px. The
    classification pipelines are run in the left column and the detection
    pipelines in the right column, all with the user's confidence threshold.

    NOTE(review): inference is not cached, so every script rerun (e.g. moving
    the threshold slider) re-runs all pipelines on the image.
    """
    st.title("🦴 Multi-Model Fracture Detection")

    models = load_models()

    uploaded_file = st.file_uploader("Upload X-ray image", type=['png', 'jpg', 'jpeg'])

    conf_threshold = st.slider(
        "Confidence threshold",
        min_value=0.0,
        max_value=1.0,
        value=0.3,
        step=0.01,
    )

    if not uploaded_file:
        return

    try:
        # convert() forces the (lazy) decode, so corrupt files fail here too.
        # RGB gives consistent display/annotation colours for grayscale or
        # palette X-rays and lets LANCZOS apply (palette images resize NEAREST).
        image = Image.open(uploaded_file).convert("RGB")
    except (OSError, ValueError):
        # OSError covers PIL's UnidentifiedImageError and truncated files.
        st.error("Could not read the uploaded file as an image.")
        return

    image.thumbnail((400, 400), Image.Resampling.LANCZOS)

    st.image(image, caption="Original Image", width=400)

    col1, col2 = st.columns(2)
    with col1:
        _render_classification(models, image, conf_threshold)
    with col2:
        _render_detection(models, image, conf_threshold)
|
|
|
# Entry point: build the app when this file is executed as a script.
if __name__ == "__main__":
    main()