import os

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from torchvision.transforms import Compose

from geobench.modeling.archs.dam.dam import DepthAnything
from geobench.utils.image_util import colorize_depth_maps
from geobench.midas.transforms import Resize, NormalizeImage, PrepareForNet

# Placeholder: point this at a directory containing the .safetensors weights.
CHECKPOINT_PATH = "your_checkpoint_path_here"


def load_model_by_name(arch_name, checkpoint_path, device):
    """Load a depth model by architecture name from a local checkpoint."""
    if arch_name == 'depthanything':
        if '.safetensors' in checkpoint_path:
            model = DepthAnything.from_pretrained(os.path.dirname(checkpoint_path)).to(device)
        else:
            raise NotImplementedError("Only .safetensors checkpoints are supported.")
    else:
        raise NotImplementedError(f"Unknown architecture: {arch_name}")
    return model


# Cache the model at module level so it is loaded once,
# not on every Gradio request.
_model = None


def get_model(device):
    global _model
    if _model is None:
        _model = load_model_by_name("depthanything", CHECKPOINT_PATH, device)
    return _model


def process_image(image, model, device, mode='rel_depth'):
    """Run depth inference on a PIL image and return a colorized depth map.

    `mode` currently does not change the rendering; it is kept so the
    signature matches the Gradio dropdown.
    """
    # PIL yields RGB; reverse the channel order to BGR to match the
    # cv2-based loading path this pipeline was written for, then scale to [0, 1].
    image_np = np.array(image)[..., ::-1] / 255

    transform = Compose([
        Resize(512, 512, resize_target=None, keep_aspect_ratio=False,
               ensure_multiple_of=32, image_interpolation_method=cv2.INTER_CUBIC),
        NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        PrepareForNet(),
    ])

    image_tensor = transform({'image': image_np})['image']
    image_tensor = torch.from_numpy(image_tensor).unsqueeze(0).to(device)

    with torch.no_grad():  # Inference only: no gradients needed.
        pred_disp, _ = model(image_tensor)

    # Min-max normalize the prediction to [0, 1] for visualization.
    pred_disp_np = pred_disp.detach().cpu().numpy()[0, :, :, :].transpose(1, 2, 0)
    pred_disp = (pred_disp_np - pred_disp_np.min()) / (pred_disp_np.max() - pred_disp_np.min())

    # Colorize the depth map. The original mode check selected the same
    # colormap in both branches, so the choice is unconditional here.
    depth_colored = colorize_depth_maps(pred_disp[None, ...], 0, 1, cmap="Spectral_r").squeeze()
    depth_colored = (depth_colored * 255).astype(np.uint8)
    return Image.fromarray(depth_colored)


def gradio_interface(image, mode='rel_depth'):
    device = torch.device("cpu")  # Run on CPU explicitly.
    model = get_model(device)
    return process_image(image, model, device, mode)


# Build the Gradio interface.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Image(type="pil"),
        gr.Dropdown(choices=['rel_depth', 'metric_depth', 'disparity'],
                    value='rel_depth', label="Mode"),
    ],
    outputs=gr.Image(type="pil"),
    title="Depth Estimation Demo",
    description="Upload an image to see the depth estimation results.",
)

if __name__ == "__main__":
    iface.launch()
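
# Quick sanity check without the UI (a sketch: "example.jpg" is a
# hypothetical input file, and CHECKPOINT_PATH must point at real weights):
#
#   device = torch.device("cpu")
#   model = get_model(device)
#   depth = process_image(Image.open("example.jpg").convert("RGB"), model, device)
#   depth.save("example_depth.png")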