"""Gradio app that estimates a person's age (in years) from a face image.

At import time a Keras regression model is loaded, first from local files and
then from the Hugging Face Hub, and a ``gradio.Interface`` named ``demo`` is
built around :func:`predict_age`. Run the module directly to serve the app.
"""

import os

import numpy as np
from PIL import Image
import tensorflow as tf
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
import gradio as gr

# Local model candidates, tried in order (useful during development).
local_model_paths = ["saved_model", "best_model.h5", "final_model.h5"]

# Hub repo used when no local model loads; override with the HF_MODEL_ID env var.
HF_MODEL_ID = os.environ.get("HF_MODEL_ID", "Sharris/age_detection_regression")

INPUT_SIZE = (224, 224)


def _load_local_model():
    """Return the first model in ``local_model_paths`` that loads, or ``None``.

    Load failures are printed and the next candidate is tried.
    """
    for path in local_model_paths:
        if not os.path.exists(path):
            continue
        try:
            loaded = tf.keras.models.load_model(path, compile=False)
        except Exception as e:  # best effort: fall through to the next candidate
            print(f"Failed to load local model from {path}: {e}")
            continue
        print(f"Loaded model from local path: {path}")
        return loaded
    return None


def _load_hub_model(repo_id):
    """Download ``best_model.h5`` from the Hub repo *repo_id* and load it.

    Tries ``hf_hub_download`` first, then falls back to ``snapshot_download``
    of the whole repo. Failures are printed, not raised.

    Returns:
        The loaded Keras model, or ``None`` if both attempts fail.
    """
    try:
        from huggingface_hub import hf_hub_download

        model_path = hf_hub_download(repo_id=repo_id, filename="best_model.h5")
        loaded = tf.keras.models.load_model(model_path, compile=False)
        print(f"Loaded model from HF Hub: {repo_id}/best_model.h5")
        return loaded
    except Exception as e:
        print(f"Failed to load model from HF Hub ({repo_id}): {e}")

    # Fallback: download the entire repo and look for the model file there.
    try:
        from huggingface_hub import snapshot_download

        repo_dir = snapshot_download(repo_id=repo_id)
        model_file = os.path.join(repo_dir, "best_model.h5")
        if os.path.exists(model_file):
            loaded = tf.keras.models.load_model(model_file, compile=False)
            print(f"Loaded model from downloaded repo: {model_file}")
            return loaded
    except Exception as e2:
        print(f"Fallback download also failed: {e2}")
    return None


model = _load_local_model()
if model is None:
    model = _load_hub_model(HF_MODEL_ID)
if model is None:
    raise RuntimeError(
        "No model found. Ensure 'best_model.h5' exists locally or set HF_MODEL_ID env var to a Hugging Face model repo containing the model."
    )


def predict_age(image: Image.Image) -> dict:
    """Predict the age (in years) of the face in *image*.

    The image is converted to RGB, resized to ``INPUT_SIZE`` and passed
    through MobileNetV2's ``preprocess_input`` before inference on a batch
    of one.

    Args:
        image: A PIL image, ideally cropped to the face.

    Returns:
        A dict with ``"predicted_age"`` (rounded to 2 decimals) and
        ``"raw_output"`` (the unrounded model output), both floats.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    image = image.resize(INPUT_SIZE)

    # np.array (not asarray) guarantees a fresh, writable float32 copy;
    # preprocess_input may scale numpy inputs in place.
    pixels = preprocess_input(np.array(image, dtype=np.float32))
    batch = np.expand_dims(pixels, 0)

    # verbose=0 stops Keras from printing a progress bar for every request.
    output = model.predict(batch, verbose=0)[0]
    # Collapse the per-sample output to a Python float (presumably a single
    # regression value; squeeze() would leave >1 element and float() raises).
    pred = float(np.asarray(output).squeeze())

    return {"predicted_age": round(pred, 2), "raw_output": pred}


def _predict_for_interface(image):
    """Gradio callback that returns ``(predicted_age, raw_output)``.

    Gradio maps a tuple positionally onto multiple output components. The
    string-keyed dict returned by :func:`predict_age` cannot be mapped onto
    the two ``gr.Number`` outputs, so it is unpacked here.
    """
    if image is None:
        # Submit pressed without an upload: show a message instead of crashing.
        raise gr.Error("Please upload a face image.")
    result = predict_age(image)
    return result["predicted_age"], result["raw_output"]


demo = gr.Interface(
    fn=_predict_for_interface,
    inputs=gr.Image(type='pil', label='Face image (crop to face for best results)'),
    outputs=[
        gr.Number(label='Predicted age (years)'),
        gr.Number(label='Raw model output'),
    ],
    examples=[],
    title='UTKFace Age Estimator',
    description='Upload a cropped face image and the model will predict age in years. For Spaces, set the HF_MODEL_ID environment variable to your Hugging Face model repo if you want the app to download the model file (best_model.h5) from the Hub.',
)

if __name__ == '__main__':
    demo.launch(server_name='0.0.0.0', server_port=int(os.environ.get('PORT', 7860)))