File size: 2,522 Bytes
80af471
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c08496
80af471
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
from io import BytesIO

import gradio as gr
import numpy as np
import onnxruntime as ort
import requests
import torch
from PIL import Image
from torchvision import transforms

# ======================
# Model & Preprocessing
# ======================
# Direct-download ("resolve/main") URL of the ONNX model file on the
# Hugging Face Hub.
MODEL_ONNX_URL = (
    "https://huggingface.co/Adilbai/bone-age-resnet-80m"
    "/resolve/main/resnet_bone_age_80m.onnx"
)

def download_model(url, filename, timeout=60):
    """Download ``url`` to ``filename`` unless ``filename`` already exists.

    The response is streamed into a sibling ``<filename>.part`` file and only
    moved into place after the whole body has been written. An interrupted or
    failed download therefore never leaves a truncated file at ``filename``,
    which later runs would otherwise mistake for a valid cached model.

    Args:
        url: HTTP(S) URL to fetch.
        filename: Destination path. If it already exists, nothing is done.
        timeout: Seconds to wait when connecting and between received
            chunks (forwarded to ``requests.get``).

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection errors or timeouts.
    """
    if os.path.exists(filename):
        return

    print(f"Downloading model from {url}")
    tmp_path = f"{filename}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this check an error page (e.g. a 404 HTML body) would be
            # saved as the "model" and fail later with a confusing ONNX error.
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                # Stream in 1 MiB chunks instead of buffering the whole file.
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        # Atomic rename on the same filesystem: filename is either absent or complete.
        os.replace(tmp_path, filename)
    except BaseException:
        # Drop the partial download so the next run retries from scratch.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

MODEL_PATH = "resnet_bone_age_80m.onnx"
download_model(MODEL_ONNX_URL, MODEL_PATH)

# Set up the ONNX Runtime session. Since onnxruntime 1.9, builds that ship
# execution providers besides CPU (e.g. onnxruntime-gpu) raise ValueError if
# ``providers`` is omitted; passing the available providers (in ORT's
# priority order) works on both CPU-only and GPU installs.
ort_session = ort.InferenceSession(
    MODEL_PATH,
    providers=ort.get_available_providers(),
)

# Image preprocessing (noted by the author as required to match training):
# resize to 256x256, convert to a CHW float tensor in [0, 1], then normalize
# each of the 3 channels with mean 0.5 / std 0.5, i.e. map values to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])

# ======================
# Inference Function
# ======================
def predict_bone_age(image, gender):
    """Predict bone age from a hand X-ray and return it as Markdown.

    Args:
        image: PIL.Image from the upload widget, or ``None`` if the user
            submitted without uploading an image.
        gender: ``"Male"`` or ``"Female"`` from the radio button, or ``None``
            if nothing was selected. Any value other than "male"
            (case-insensitive) is encoded as female.

    Returns:
        Markdown string with the predicted age in months and an approximate
        years/months breakdown.

    Raises:
        gr.Error: If the image or gender is missing, so Gradio shows a
            readable message instead of a traceback.
    """
    if image is None:
        raise gr.Error("Please upload a hand X-ray image.")
    if not gender:
        raise gr.Error("Please select the patient's gender.")

    # Normalize uses 3-channel mean/std, so force RGB: grayscale ("L") or
    # RGBA uploads would otherwise fail with a channel/broadcast error.
    img_tensor = transform(image.convert("RGB")).unsqueeze(0).numpy()
    # Gender encoding fed to the model: 0 = male, 1 = female.
    gender_val = 0.0 if gender.lower() == "male" else 1.0
    gender_tensor = np.array([gender_val], dtype=np.float32)

    # ONNX inference
    outputs = ort_session.run(None, {"image": img_tensor, "gender": gender_tensor})
    # Pull out a plain float whether the output is shaped (1,) or (1, 1);
    # formatting a 1-element ndarray with ":.1f" would raise TypeError.
    pred_age = float(np.asarray(outputs[0]).reshape(-1)[0])
    # A regression head can drift below zero on unusual inputs; bone age
    # cannot be negative, so clamp for display.
    pred_age = max(pred_age, 0.0)

    # Round to whole months first so the breakdown agrees with the displayed
    # value (e.g. 23.97 -> 24 months -> 2 years, 0 months, not 1y 11m).
    years, months = divmod(round(pred_age), 12)
    result_str = (
        # Trailing two spaces force a Markdown line break.
        f"Predicted Bone Age: **{pred_age:.1f} months**  \n"
        f"≈ {years} years, {months} months"
    )
    return result_str

# ======================
# Gradio UI
# ======================
# Page text rendered above the Gradio form.
app_title = "Bone Age Prediction from Hand X-ray"
app_desc = """
Upload a hand X-ray image and select the patient's gender. This app will predict the bone age (in months) using a powerful deep learning model.
- Model: [bone-age-resnet-80m](https://huggingface.co/Adilbai/bone-age-resnet-80m)
- Data: RSNA Pediatric Bone Age Challenge
- **For research/educational use only.**
"""

# Two inputs, in the same order as predict_bone_age(image, gender):
# the X-ray is delivered as a PIL image, the gender as the chosen label.
# The prediction string is rendered as Markdown; the Flag button is disabled.
iface = gr.Interface(
    fn=predict_bone_age,
    title=app_title,
    description=app_desc,
    inputs=[
        gr.Image(label="Hand X-ray Image", type="pil"),
        gr.Radio(choices=["Male", "Female"], label="Gender"),
    ],
    outputs=gr.Markdown(label="Prediction"),
    allow_flagging="never",
)

if __name__ == "__main__":
    # Start the Gradio server only when executed as a script.
    iface.launch()