# AdilzhanB
# App change
# 9c08496
import gradio as gr
from PIL import Image
import torch
import numpy as np
import requests
from io import BytesIO
from torchvision import transforms
import onnxruntime as ort
# ======================
# Model & Preprocessing
# ======================
# Direct-download URL of the ONNX model file; passed to download_model() below.
MODEL_ONNX_URL = "https://huggingface.co/Adilbai/bone-age-resnet-80m/resolve/main/resnet_bone_age_80m.onnx"
def download_model(url, filename, timeout=60):
    """Download the file at *url* to *filename* unless it already exists.

    Args:
        url: HTTP(S) location of the model file.
        filename: Local destination path. If something already exists at
            this path, no request is made and the file is left untouched.
        timeout: Seconds to wait for the server before giving up.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection failures or timeouts.
        OSError: If the destination cannot be written.
    """
    if os.path.exists(filename):
        return

    print(f"Downloading model from {url}")

    # Download into a sibling ".part" file first: an interrupted or failed
    # transfer must never leave a truncated file at `filename`, because the
    # existence check above would treat it as a valid cached model forever.
    partial_path = f"{filename}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this check an HTML error page would be saved as the model.
            response.raise_for_status()
            with open(partial_path, "wb") as f:
                # Stream in 1 MiB chunks instead of buffering the whole
                # model in memory.
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        os.replace(partial_path, filename)
    finally:
        # After a successful os.replace() the partial file no longer exists,
        # so this only cleans up after a failure.
        if os.path.exists(partial_path):
            os.remove(partial_path)
# NOTE: download_model() looks up `os` when it is *called*, not when it is
# defined, so this import works even though it comes after the function
# definition -- but it must stay above the download_model() call below.
import os
# Local path the model file is stored at and loaded from.
MODEL_PATH = "resnet_bone_age_80m.onnx"
# Runs at import time: fetches the model unless MODEL_PATH already exists.
download_model(MODEL_ONNX_URL, MODEL_PATH)
# Set up ONNX session (module-level, used by predict_bone_age for every request)
ort_session = ort.InferenceSession(MODEL_PATH)
# Define image preprocessing (must match training)
# Preprocessing pipeline applied to every uploaded image:
#   1. resize to a fixed 256x256,
#   2. convert the PIL image to a tensor,
#   3. normalise each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)
# ======================
# Inference Function
# ======================
def predict_bone_age(image, gender):
    """Predict bone age from a hand X-ray and format the result as Markdown.

    Args:
        image: PIL.Image of the hand X-ray, or None when nothing was uploaded.
        gender: "Male" or "Female" (case-insensitive, surrounding whitespace
            ignored), or None/empty when nothing was selected. Any other
            value is encoded as female, as before.

    Returns:
        A Markdown string with the predicted age in months and as
        years/months, or a short prompt naming the missing input.
    """
    # The UI can submit with either field left empty; answer with a prompt
    # instead of failing with an AttributeError on None.
    if image is None:
        return "Please upload a hand X-ray image."
    if not gender:
        return "Please select the patient's gender."

    # `transform` normalises exactly three channels, so grayscale, palette or
    # RGBA inputs are converted to RGB before preprocessing.
    img_tensor = transform(image.convert("RGB")).unsqueeze(0).numpy()

    # Gender: 0=male, 1=female
    gender_val = 0.0 if gender.strip().lower() == "male" else 1.0
    gender_tensor = np.array([gender_val], dtype=np.float32)

    # ONNX inference
    outputs = ort_session.run(None, {"image": img_tensor, "gender": gender_tensor})

    # The first output may come back as shape (1,) or (1, 1); flatten it and
    # take a plain Python float so the ":.1f" format below always works.
    pred_age = float(np.asarray(outputs[0]).reshape(-1)[0])

    # Display as years and months
    years, months = divmod(pred_age, 12)
    result_str = (
        f"Predicted Bone Age: **{pred_age:.1f} months** \n"
        f"≈ {int(years)} years, {int(months)} months"
    )
    return result_str
# ======================
# Gradio UI
# ======================
# Page heading, passed to gr.Interface as `title`.
app_title = "Bone Age Prediction from Hand X-ray"
# Markdown blurb, passed to gr.Interface as `description`.
app_desc = """
Upload a hand X-ray image and select the patient's gender. This app will predict the bone age (in months) using a powerful deep learning model.
- Model: [bone-age-resnet-80m](https://huggingface.co/Adilbai/bone-age-resnet-80m)
- Data: RSNA Pediatric Bone Age Challenge
- **For research/educational use only.**
"""
# Wire the UI: an image upload (delivered to the callback as a PIL image) and
# a gender radio button feed predict_bone_age; the string it returns is shown
# in a Markdown component.
iface = gr.Interface(
    fn=predict_bone_age,
    inputs=[
        gr.Image(type="pil", label="Hand X-ray Image"),
        gr.Radio(choices=["Male", "Female"], label="Gender"),
    ],
    outputs=gr.Markdown(label="Prediction"),
    title=app_title,
    description=app_desc,
    allow_flagging="never",
)

# Launch the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    iface.launch()