Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,196 Bytes
cf8c487 41a69fa cf8c487 dc9d69c 404967f ee04d83 024a2b8 71014b1 dc9d69c cf8c487 764b436 cf8c487 764b436 cf8c487 501d06f ee04d83 501d06f 764b436 cf8c487 ee04d83 404967f 501d06f ee04d83 501d06f 404967f ee04d83 404967f 764b436 404967f ee04d83 cf8c487 404967f cf8c487 501d06f cf8c487 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
import gradio as gr
from PIL import Image
import src.depth_pro as depth_pro
import numpy as np
import matplotlib.pyplot as plt
import subprocess
import spaces
import torch
import tempfile
import os
# Fetch the pretrained checkpoints before the model is constructed. The exit
# status is inspected so a failed download is reported up front; startup still
# continues because the checkpoints may already be present from an earlier run.
_download = subprocess.run(["bash", "get_pretrained_models.sh"])
if _download.returncode != 0:
    print(
        "Warning: get_pretrained_models.sh exited with status "
        f"{_download.returncode}; model loading may fail if checkpoints are missing.",
        flush=True,
    )

# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load model and preprocessing transform
model, transform = depth_pro.create_model_and_transforms()
model = model.to(device)
# Switch to inference mode; this module only runs predictions.
model.eval()
def resize_image(image_path, max_size=1024):
    """Write a size-limited PNG copy of an image to a temporary file.

    The image is scaled, keeping its aspect ratio, so that its longest side
    is at most ``max_size`` pixels. Images that already fit are not enlarged.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        max_size: Upper bound, in pixels, for the longest side of the output.

    Returns:
        Path of the temporary ``.png`` file. The file is created with
        ``delete=False``, so the caller is responsible for removing it.

    Raises:
        Exception: Whatever ``Image.open``, ``resize`` or ``save`` raise; the
            temporary file is removed before the error propagates.
    """
    with Image.open(image_path) as img:
        # Cap the ratio at 1 so inputs smaller than max_size are never upscaled.
        ratio = min(1.0, max_size / max(img.size))
        # Clamp to 1 px so extreme aspect ratios cannot yield a zero-sized side.
        new_size = tuple(max(1, int(side * ratio)) for side in img.size)
        # resize() returns a new, fully loaded image, so it stays usable after
        # the source file is closed.
        resized = img.resize(new_size, Image.LANCZOS)

    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
    try:
        with temp_file:
            resized.save(temp_file, format="PNG")
    except Exception:
        # The file was created with delete=False; do not leave it behind when
        # encoding fails (e.g. an image mode PNG cannot store).
        if os.path.exists(temp_file.name):
            os.remove(temp_file.name)
        raise
    return temp_file.name
@spaces.GPU(duration=20)
def predict_depth(input_image):
    """Predict a depth map and focal length for an uploaded image.

    Args:
        input_image: Filesystem path of the image to process (the Gradio input
            is declared with ``type="filepath"``). ``None`` is reported as an
            error rather than raised.

    Returns:
        A ``(depth_map_path, message)`` tuple. On success ``depth_map_path`` is
        the path of a PNG rendering of the normalized depth map and ``message``
        reports the focal length in pixels. On failure ``depth_map_path`` is
        ``None`` and ``message`` starts with ``"An error occurred: "``.
    """
    temp_file = None
    output_path = None
    figure = None
    try:
        if input_image is None:
            raise ValueError("no input image was provided")

        # Resize the input image
        temp_file = resize_image(input_image)

        # Preprocess the image
        result = depth_pro.load_rgb(temp_file)
        image = result[0]
        f_px = result[-1]  # Assuming f_px is the last item in the returned tuple
        image = transform(image)
        image = image.to(device)

        # Run inference
        prediction = model.infer(image, f_px=f_px)
        depth = prediction["depth"]  # Depth in [m]
        focallength_px = prediction["focallength_px"]  # Focal length in pixels

        # Convert depth to numpy array if it's a torch tensor
        if isinstance(depth, torch.Tensor):
            depth = depth.cpu().numpy()

        # Ensure depth is a 2D numpy array
        if depth.ndim != 2:
            depth = depth.squeeze()

        # Normalize depth for visualization. A constant depth map has zero
        # range; map it to all zeros instead of dividing by zero.
        depth_min = np.min(depth)
        depth_max = np.max(depth)
        depth_range = depth_max - depth_min
        if depth_range > 0:
            depth_normalized = (depth - depth_min) / depth_range
        else:
            depth_normalized = np.zeros_like(depth)

        # Draw on explicit figure/axes objects rather than pyplot's implicit
        # "current figure", so overlapping requests cannot draw on each
        # other's plot.
        figure, axes = plt.subplots(figsize=(10, 10))
        depth_plot = axes.imshow(depth_normalized, cmap='viridis')
        figure.colorbar(depth_plot, ax=axes, label='Depth')
        axes.set_title('Predicted Depth Map')
        axes.axis('off')

        # Save the plot to a unique file so concurrent requests do not
        # overwrite each other's result.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as output_file:
            output_path = output_file.name
        figure.savefig(output_path)

        # float() also accepts a one-element tensor, which "{:.2f}" alone
        # cannot format unless it is 0-dimensional.
        return output_path, f"Focal length: {float(focallength_px):.2f} pixels"
    except Exception as e:
        # Do not leave a half-written output image behind on failure.
        if output_path and os.path.exists(output_path):
            os.remove(output_path)
        return None, f"An error occurred: {str(e)}"
    finally:
        # Release the figure even when plotting or saving failed.
        if figure is not None:
            plt.close(figure)
        # Clean up the temporary file
        if temp_file and os.path.exists(temp_file):
            os.remove(temp_file)
# Create Gradio interface: one image input (passed to predict_depth as a file
# path) and two outputs matching predict_depth's (image path, message) return.
iface = gr.Interface(
    fn=predict_depth,
    inputs=gr.Image(type="filepath"),
    outputs=[
        gr.Image(type="filepath", label="Depth Map"),
        gr.Textbox(label="Focal Length or Error Message"),
    ],
    title="Depth Prediction Demo",
    description="Upload an image to predict its depth map and focal length. Large images will be automatically resized.",
)

# Launch the interface
iface.launch()