A19grey's picture
Added ability to save raw depth output in meters and updated title plot
182cf21
raw
history blame
3.76 kB
import gradio as gr
from PIL import Image
import src.depth_pro as depth_pro
import numpy as np
import matplotlib.pyplot as plt
import subprocess
import spaces
import torch
import tempfile
import os
# Run the script to get pretrained models.
# NOTE(review): the return code of the script is not checked, so a failed
# download presumably only surfaces later, when the model is created below.
# Consider check=True if a hard, early failure is preferred.
subprocess.run(["bash", "get_pretrained_models.sh"])
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Load model and preprocessing transform once at import time; both are module
# globals shared by every predict_depth call.
model, transform = depth_pro.create_model_and_transforms()
model = model.to(device)
# Put the model in evaluation mode for inference (assumes model is a
# torch.nn.Module, as the .to()/.eval() calls suggest — TODO confirm).
model.eval()
def resize_image(image_path, max_size=1024):
    """Copy an image to a temporary PNG, downscaling it if it is too large.

    Images whose longest side exceeds ``max_size`` pixels are shrunk with
    Lanczos resampling so that the longest side becomes (about) ``max_size``,
    preserving the aspect ratio. Smaller images are copied at their original
    size instead of being upscaled, which would only add interpolation
    artifacts and work.

    Args:
        image_path: Path of the image to read.
        max_size: Maximum allowed length, in pixels, of the longest side.

    Returns:
        Path of a new temporary ``.png`` file. It is created with
        ``delete=False``, so the caller is responsible for removing it.
    """
    with Image.open(image_path) as img:
        longest_side = max(img.size)
        if longest_side > max_size:
            scale = max_size / longest_side
            # Clamp to 1 px so extreme aspect ratios never yield a 0-sized axis
            # (Image.resize rejects zero dimensions).
            new_size = tuple(max(1, int(side * scale)) for side in img.size)
            img = img.resize(new_size, Image.LANCZOS)
        # Always write a fresh copy, even when no resize was needed, so the
        # returned path can be deleted without touching the caller's input.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            img.save(temp_file, format="PNG")
    return temp_file.name
def _save_depth_plot(depth, output_path):
    """Render a 2-D metric depth map as a viridis heat map and save it.

    The title reports the minimum and maximum depth in meters. The figure is
    always closed, even if rendering or saving fails, so a failing request
    does not leak a pyplot figure into the long-running app.

    Args:
        depth: 2-D numpy array of depths in meters.
        output_path: Destination image path passed to ``savefig``.
    """
    depth_min = float(np.min(depth))
    depth_max = float(np.max(depth))
    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        # Plot the raw depth on purpose, without normalizing it, so that the
        # colorbar reads directly in meters.
        im = ax.imshow(depth, cmap='viridis')
        fig.colorbar(im, ax=ax, label='Depth [m]')
        ax.set_title(f'Predicted Depth Map - Min: {depth_min:.1f}m, Max: {depth_max:.1f}m')
        ax.axis('off')
        fig.savefig(output_path)
    finally:
        plt.close(fig)


@spaces.GPU(duration=20)
def predict_depth(input_image):
    """Predict a metric depth map for an uploaded image.

    Args:
        input_image: Path to the input image (the Gradio input component uses
            ``type="filepath"``).

    Returns:
        A 3-tuple ``(plot_path, message, csv_path)``:
        ``plot_path`` is the rendered depth-map PNG,
        ``message`` reports the focal length in pixels, and
        ``csv_path`` holds the raw depth values (meters) as CSV.
        On any failure, returns ``(None, "An error occurred: ...", None)``
        so the UI shows the error instead of crashing.
    """
    temp_file = None
    try:
        # Work on a (possibly downscaled) temporary PNG copy of the upload.
        # NOTE(review): re-encoding to PNG presumably drops EXIF metadata; if
        # load_rgb derives f_px from EXIF, it will be unavailable here — verify.
        temp_file = resize_image(input_image)

        # Preprocess the image.
        result = depth_pro.load_rgb(temp_file)
        image = result[0]
        f_px = result[-1]  # Assuming f_px is the last item in the returned tuple
        image = transform(image).to(device)

        # Run inference.
        prediction = model.infer(image, f_px=f_px)
        depth = prediction["depth"]  # Depth in [m]
        focallength_px = prediction["focallength_px"]  # Focal length in pixels

        # Bring depth to a 2-D numpy array for plotting and CSV export.
        if isinstance(depth, torch.Tensor):
            depth = depth.cpu().numpy()
        if depth.ndim != 2:
            depth = depth.squeeze()

        # NOTE(review): these fixed filenames in the working directory are
        # shared by all requests; that is only safe while calls are serialized.
        output_path = "depth_map.png"
        _save_depth_plot(depth, output_path)

        # Save raw depth data (meters) as CSV.
        raw_depth_path = "raw_depth_map.csv"
        np.savetxt(raw_depth_path, depth, delimiter=',')

        return output_path, f"Focal length: {focallength_px:.2f} pixels", raw_depth_path
    except Exception as e:
        # UI boundary: surface the error text in the message textbox.
        return None, f"An error occurred: {str(e)}", None
    finally:
        # Clean up the temporary resized copy.
        if temp_file and os.path.exists(temp_file):
            os.remove(temp_file)
# Build the Gradio UI: one uploaded image (passed to predict_depth as a file
# path) in; the rendered depth plot, a focal-length/error message and the raw
# depth CSV out, matching predict_depth's 3-tuple return.
_DEPTH_INPUT = gr.Image(type="filepath")
_DEPTH_OUTPUTS = [
    gr.Image(type="filepath", label="Depth Map"),
    gr.Textbox(label="Focal Length or Error Message"),
    gr.File(label="Download Raw Depth Map (CSV)"),
]
_DESCRIPTION = (
    "[DepthPro](https://huggingface.co/apple/DepthPro) is a fast metric depth "
    "prediction model. Simply upload an image to predict its depth map and "
    "focal length. Large images will be automatically resized. You can also "
    "download the raw depth map data as a CSV file."
)

iface = gr.Interface(
    fn=predict_depth,
    inputs=_DEPTH_INPUT,
    outputs=_DEPTH_OUTPUTS,
    title="DepthPro Demo in Meters",
    description=_DESCRIPTION,
)

# Start the web server (blocks until the app is stopped).
iface.launch()