Spaces:
Sleeping
Sleeping
import gradio as gr | |
from PIL import Image | |
import torch | |
import numpy as np | |
import requests | |
from io import BytesIO | |
from torchvision import transforms | |
import onnxruntime as ort | |
# ======================
# Model & Preprocessing
# ======================
# Hugging Face Hub repository and file name of the exported ONNX model;
# MODEL_ONNX_URL is the direct "resolve" download link built from them.
_HF_MODEL_REPO = "Adilbai/bone-age-resnet-80m"
_HF_MODEL_FILE = "resnet_bone_age_80m.onnx"
MODEL_ONNX_URL = f"https://huggingface.co/{_HF_MODEL_REPO}/resolve/main/{_HF_MODEL_FILE}"
def download_model(url, filename, *, timeout=60, chunk_size=1 << 20):
    """Download ``url`` to ``filename`` unless ``filename`` already exists.

    The response is streamed to a temporary ``<filename>.part`` file and only
    moved into place once the whole body has been written. A failed or
    interrupted download therefore never leaves a truncated file behind that
    later runs would mistake for a valid cached model.

    Args:
        url: HTTP(S) URL of the file to fetch.
        filename: Local destination path. If it already exists, nothing is done.
        timeout: Seconds to wait for the connection and between received
            bytes (passed to ``requests.get``), so a stalled server cannot
            hang startup forever.
        chunk_size: Number of bytes written to disk per streamed chunk.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
            Previously the error page would have been saved as the model file.
        requests.RequestException: on connection or timeout errors.
        OSError: if the destination cannot be written.
    """
    if os.path.exists(filename):
        return

    print(f"Downloading model from {url}")
    tmp_path = f"{filename}.part"
    try:
        # stream=True avoids holding the whole (large) model in memory.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
        # Atomic rename on the same filesystem: readers see either no file
        # or the complete one.
        os.replace(tmp_path, filename)
    finally:
        # After a successful replace the temp file is gone. Otherwise remove
        # the partial download so it is not left lying around.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
import os | |
# Local path where the ONNX weights are cached (relative to the working
# directory); downloaded on first start only.
MODEL_PATH = "resnet_bone_age_80m.onnx"
download_model(MODEL_ONNX_URL, MODEL_PATH)

# Set up ONNX session (created once at import time and reused by every request).
# ``providers`` is passed explicitly: since ONNX Runtime 1.9, builds that
# enable more than one execution provider (e.g. onnxruntime-gpu) raise
# ValueError when it is omitted. CPU is what a CPU-only build uses anyway.
ort_session = ort.InferenceSession(MODEL_PATH, providers=["CPUExecutionProvider"])
# Image preprocessing applied before inference; it must stay identical to the
# pipeline used during training:
#   1. resize to exactly 256x256 (a tuple size does not preserve aspect ratio),
#   2. convert the PIL image to a float tensor with values in [0, 1],
#   3. normalize each of the 3 channels with mean 0.5 / std 0.5 -> [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)
# ======================
# Inference Function
# ======================
def predict_bone_age(image, gender):
    """Predict bone age from a hand X-ray and format it as Markdown.

    Args:
        image: PIL.Image of the hand X-ray, as delivered by
            ``gr.Image(type="pil")``. Any PIL mode is accepted; it is converted
            to RGB before preprocessing.
        gender: ``"Male"`` or ``"Female"`` (compared case-insensitively; any
            value other than "male" is encoded as female).

    Returns:
        Markdown string with the predicted bone age in months and an
        approximate years/months breakdown.

    Raises:
        gr.Error: if no image was uploaded or no gender was selected, so the
            UI shows a readable message instead of a stack trace.
    """
    if image is None:
        raise gr.Error("Please upload a hand X-ray image.")
    if not gender:
        raise gr.Error("Please select the patient's gender.")

    # `transform` normalizes with 3-channel statistics, so the image must have
    # exactly 3 channels: grayscale ("L") or RGBA X-rays would otherwise fail
    # inside Normalize.
    img_tensor = transform(image.convert("RGB")).unsqueeze(0).numpy()

    # Gender: 0=male, 1=female.
    # NOTE(review): confirm this matches the encoding used during training.
    gender_val = 0.0 if gender.lower() == "male" else 1.0
    gender_tensor = np.array([gender_val], dtype=np.float32)

    # ONNX inference
    outputs = ort_session.run(None, {"image": img_tensor, "gender": gender_tensor})
    # Flatten so this works whether the first output is shaped (1,) or (1, 1);
    # float() guarantees the value supports the ":.1f" format below.
    pred_age = float(np.asarray(outputs[0]).reshape(-1)[0])

    # Round once so the years/months breakdown agrees with the displayed
    # month value (e.g. 155.97 -> "156.0 months" -> 13 years, 0 months).
    years, months = divmod(round(pred_age), 12)
    result_str = (
        f"Predicted Bone Age: **{pred_age:.1f} months** \n"
        f"≈ {years} years, {months} months"
    )
    return result_str
# ======================
# Gradio UI
# ======================
app_title = "Bone Age Prediction from Hand X-ray"
app_desc = """
Upload a hand X-ray image and select the patient's gender. This app will predict the bone age (in months) using a powerful deep learning model.
- Model: [bone-age-resnet-80m](https://huggingface.co/Adilbai/bone-age-resnet-80m)
- Data: RSNA Pediatric Bone Age Challenge
- **For research/educational use only.**
"""

# Input widgets, listed in the same order as predict_bone_age's parameters
# (image first, then gender).
_ui_inputs = [
    gr.Image(type="pil", label="Hand X-ray Image"),
    gr.Radio(choices=["Male", "Female"], label="Gender"),
]
# predict_bone_age returns Markdown, so render it with a Markdown component.
_ui_output = gr.Markdown(label="Prediction")

iface = gr.Interface(
    fn=predict_bone_age,
    inputs=_ui_inputs,
    outputs=_ui_output,
    title=app_title,
    description=app_desc,
    allow_flagging="never",
)


def _main():
    """Start the Gradio server for ``iface`` with default launch settings."""
    iface.launch()


if __name__ == "__main__":
    _main()