"""Gradio demo: predict pediatric bone age from a hand X-ray with an ONNX ResNet."""

import os
from io import BytesIO

import gradio as gr
import numpy as np
import onnxruntime as ort
import requests
import torch
from PIL import Image
from torchvision import transforms

# ======================
# Model & Preprocessing
# ======================
MODEL_ONNX_URL = "https://huggingface.co/Adilbai/bone-age-resnet-80m/resolve/main/resnet_bone_age_80m.onnx"


def download_model(url, filename, timeout=60, chunk_size=1 << 20):
    """Download ``url`` to ``filename`` unless ``filename`` already exists.

    The response is streamed into ``<filename>.part``. It is renamed to
    ``filename`` only after the full body has arrived with a successful HTTP
    status. A failed or interrupted download therefore never leaves an error
    page or a truncated file at ``filename``. Later runs would otherwise treat
    such a file as a valid model and skip re-downloading it.

    Args:
        url: Source URL of the model file.
        filename: Destination path.
        timeout: Seconds to wait for the connection / between received bytes.
        chunk_size: Bytes per streamed chunk written to disk.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
        requests.RequestException: on connection problems or timeouts.
    """
    if os.path.exists(filename):
        return
    print(f"Downloading model from {url}")
    tmp_path = f"{filename}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
        # Atomic on the same filesystem: readers never see a half-written model.
        os.replace(tmp_path, filename)
    finally:
        # After a successful replace the .part file is gone; this only cleans up failures.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


MODEL_PATH = "resnet_bone_age_80m.onnx"
download_model(MODEL_ONNX_URL, MODEL_PATH)

# Set up ONNX session
ort_session = ort.InferenceSession(MODEL_PATH)

# Define image preprocessing (must match training)
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])


# ======================
# Inference Function
# ======================
def predict_bone_age(image, gender):
    """Predict bone age from a hand X-ray and format it as Markdown.

    Args:
        image: PIL.Image of the hand X-ray, in any mode. It is converted to RGB
            because the normalization step expects 3 channels.
        gender: "Male" or "Female" (case-insensitive). Any value other than
            "male" is encoded as female (1.0).

    Returns:
        A Markdown string with the predicted age in months and an approximate
        years/months breakdown. If an input is missing, a short message asking
        for it is returned instead.
    """
    # Gradio passes None when the user submits without filling an input.
    if image is None:
        return "Please upload a hand X-ray image."
    if not gender:
        return "Please select the patient's gender."

    # Normalize() above has 3 channel means, so grayscale/RGBA uploads must be
    # brought to RGB first or the normalization step fails.
    img_tensor = transform(image.convert("RGB")).unsqueeze(0).numpy()

    # Gender: 0=male, 1=female
    gender_val = 0.0 if gender.lower() == "male" else 1.0
    gender_tensor = np.array([gender_val], dtype=np.float32)

    # ONNX inference
    outputs = ort_session.run(None, {"image": img_tensor, "gender": gender_tensor})
    # Flatten so this works whether the model emits shape (1,) or (1, 1);
    # a leftover 1-element array would break the ``:.1f`` format below.
    pred_age = float(np.asarray(outputs[0]).reshape(-1)[0])

    # Round once, then split, so the breakdown agrees with the displayed value
    # (truncation showed e.g. 35.96 as "36.0 months ≈ 2 years, 11 months").
    years, months = divmod(int(round(pred_age)), 12)

    # Two trailing spaces = Markdown hard line break.
    result_str = (
        f"Predicted Bone Age: **{pred_age:.1f} months**  \n"
        f"≈ {years} years, {months} months"
    )
    return result_str


# ======================
# Gradio UI
# ======================
app_title = "Bone Age Prediction from Hand X-ray"
app_desc = """
Upload a hand X-ray image and select the patient's gender.
This app will predict the bone age (in months) using a powerful deep learning model.

- Model: [bone-age-resnet-80m](https://huggingface.co/Adilbai/bone-age-resnet-80m)
- Data: RSNA Pediatric Bone Age Challenge
- **For research/educational use only.**
"""

iface = gr.Interface(
    fn=predict_bone_age,
    inputs=[
        gr.Image(type="pil", label="Hand X-ray Image"),
        gr.Radio(["Male", "Female"], label="Gender"),
    ],
    outputs=gr.Markdown(label="Prediction"),
    title=app_title,
    description=app_desc,
    # NOTE(review): newer Gradio releases deprecate `allow_flagging` in favour of
    # `flagging_mode`; confirm against the pinned Gradio version.
    allow_flagging="never",
)

if __name__ == "__main__":
    iface.launch()