# Spaces: Running on Zero
import gradio as gr | |
from PIL import Image | |
import src.depth_pro as depth_pro | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import subprocess | |
import spaces | |
import torch | |
import tempfile | |
import os | |
# Module-level setup: everything below runs once, when the file is loaded.

# Invoke the helper shell script that gets the pretrained models.
subprocess.run(["bash", "get_pretrained_models.sh"])

# Use the first CUDA device when torch reports one, otherwise stay on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Build the model together with its preprocessing transform, move the model
# to the selected device and put it in evaluation mode.
model, transform = depth_pro.create_model_and_transforms()
model = model.to(device)
model.eval()
def resize_image(image_path, max_size=1024):
    """Write a size-limited PNG copy of an image to a temporary file.

    The longest side of the copy is at most ``max_size`` pixels and the
    aspect ratio is preserved. Images that already fit within ``max_size``
    are copied without resampling rather than being enlarged.

    Args:
        image_path: Path of the image to open with PIL.
        max_size: Upper bound, in pixels, for the longest side of the copy.

    Returns:
        Path of the temporary ``.png`` file. The file is created with
        ``delete=False``, so the caller is responsible for removing it.

    Raises:
        Exception: Whatever PIL raises when the image cannot be opened,
            resized or saved. If saving fails, the temporary file is removed
            before the error propagates.
    """
    with Image.open(image_path) as img:
        # Only ever shrink: a ratio above 1 would enlarge small images.
        ratio = min(1.0, max_size / max(img.size))
        if ratio < 1.0:
            # Clamp each side to at least 1 px so that extreme aspect ratios
            # cannot truncate a dimension to zero.
            new_size = tuple(max(1, int(side * ratio)) for side in img.size)
            img = img.resize(new_size, Image.LANCZOS)

        # delete=False keeps the file on disk after the handle is closed.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            try:
                img.save(temp_file, format="PNG")
            except Exception:
                # The caller never learns the name on failure, so clean up
                # here instead of leaking the file.
                temp_file.close()
                os.remove(temp_file.name)
                raise
            return temp_file.name
# NOTE(review): `spaces` is imported at the top of the file but nothing visible
# here is decorated with it -- confirm whether this handler needs `@spaces.GPU`.
def predict_depth(input_image):
    """Predict a depth map for an image and write the results to disk.

    Args:
        input_image: Path of the input image (the Gradio input component is
            configured with ``type="filepath"``).

    Returns:
        A 3-tuple ``(depth_map_png, message, raw_depth_csv)``. On success the
        first and last items are file paths and ``message`` reports the focal
        length in pixels. On any failure both paths are ``None`` and
        ``message`` carries the error text, so this function does not raise.
    """
    temp_file = None
    try:
        # Work on a size-limited temporary copy of the upload.
        temp_file = resize_image(input_image)

        # Load and preprocess the image.
        result = depth_pro.load_rgb(temp_file)
        image = result[0]
        f_px = result[-1]  # assumes f_px is the last item of the returned tuple
        image = transform(image).to(device)

        # Run inference.
        prediction = model.infer(image, f_px=f_px)
        depth = prediction["depth"]  # Depth in [m]
        focallength_px = prediction["focallength_px"]  # Focal length in pixels

        # Bring the depth map to a 2D numpy array.
        if isinstance(depth, torch.Tensor):
            depth = depth.cpu().numpy()
        if depth.ndim != 2:
            depth = depth.squeeze()

        # The depth is deliberately NOT normalised: it is plotted as-is so the
        # colour bar and the CSV both carry the values in metres.
        depth_min = np.min(depth)
        depth_max = np.max(depth)

        # NOTE(review): both output names are fixed, relative paths, so every
        # call overwrites the files of the previous one -- confirm that this
        # is acceptable for the way the app is deployed.
        output_path = "depth_map.png"
        raw_depth_path = "raw_depth_map.csv"

        # Render the colour-mapped depth image. The figure is closed in a
        # `finally` so a failure while plotting or saving cannot leak it.
        fig, ax = plt.subplots(figsize=(10, 10))
        try:
            depth_image = ax.imshow(depth, cmap='viridis')
            fig.colorbar(depth_image, ax=ax, label='Depth [m]')
            ax.set_title(
                f'Predicted Depth Map - Min: {depth_min:.1f}m, Max: {depth_max:.1f}m'
            )
            ax.axis('off')
            fig.savefig(output_path)
        finally:
            plt.close(fig)

        # Save the raw depth data as CSV.
        np.savetxt(raw_depth_path, depth, delimiter=',')

        return output_path, f"Focal length: {focallength_px:.2f} pixels", raw_depth_path
    except Exception as e:
        # UI boundary: report the failure in the text output instead of raising.
        return None, f"An error occurred: {str(e)}", None
    finally:
        # Clean up the temporary resized copy.
        if temp_file and os.path.exists(temp_file):
            os.remove(temp_file)
# Build the Gradio UI. The components are created first (input, then the three
# outputs in the order predict_depth returns them) and handed to the Interface.
_image_input = gr.Image(type="filepath")
_result_outputs = [
    gr.Image(type="filepath", label="Depth Map"),
    gr.Textbox(label="Focal Length or Error Message"),
    gr.File(label="Download Raw Depth Map (CSV)"),
]

iface = gr.Interface(
    fn=predict_depth,
    inputs=_image_input,
    outputs=_result_outputs,
    title="DepthPro Demo in Meters",
    description="[DepthPro](https://huggingface.co/apple/DepthPro) is a fast metric depth prediction model. Simply upload an image to predict its depth map and focal length. Large images will be automatically resized. You can also download the raw depth map data as a CSV file.",
)

# Start serving the interface.
iface.launch()