File size: 3,761 Bytes
cf8c487
 
41a69fa
cf8c487
 
dc9d69c
 
404967f
ee04d83
 
024a2b8
71014b1
dc9d69c
cf8c487
764b436
 
cf8c487
 
764b436
cf8c487
 
501d06f
 
 
 
 
 
 
 
 
ee04d83
 
 
 
501d06f
764b436
cf8c487
ee04d83
404967f
501d06f
ee04d83
501d06f
404967f
ee04d83
404967f
 
 
764b436
404967f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182cf21
404967f
 
182cf21
404967f
 
 
 
182cf21
 
404967f
 
 
 
 
 
 
182cf21
 
 
 
 
404967f
182cf21
ee04d83
 
 
 
cf8c487
 
 
 
 
182cf21
 
 
 
 
 
 
cf8c487
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import gradio as gr
from PIL import Image
import src.depth_pro as depth_pro
import numpy as np
import matplotlib.pyplot as plt
import subprocess
import spaces
import torch
import tempfile
import os

# Run the script to get pretrained models.
_fetch_result = subprocess.run(["bash", "get_pretrained_models.sh"])
if _fetch_result.returncode != 0:
    # Do not abort here: the weights may already be on disk from an earlier
    # run. Surface the failure so that a later model-loading error can be
    # traced back to the download step instead of being a mystery.
    print(
        "Warning: get_pretrained_models.sh exited with status "
        f"{_fetch_result.returncode}; model loading may fail.",
        flush=True,
    )

# Use the first CUDA device when torch reports one, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load model and preprocessing transform, move the model to the chosen
# device and switch it to evaluation mode (this app only runs inference).
model, transform = depth_pro.create_model_and_transforms()
model = model.to(device)
model.eval()

def resize_image(image_path, max_size=1024):
    """Write a size-limited PNG copy of an image and return the copy's path.

    The image is scaled, keeping its aspect ratio, so that its longer side is
    at most ``max_size`` pixels. Images that already fit are not enlarged.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (normally a file path).
        max_size: Upper bound, in pixels, for the longer side of the copy.

    Returns:
        Path of a temporary ``.png`` file. The file is created with
        ``delete=False``, so the caller is responsible for removing it.

    Raises:
        ValueError: If ``max_size`` is not positive.
        Exception: Errors raised by PIL while opening, resizing or saving
            propagate unchanged; the temporary file is removed first.
    """
    if max_size <= 0:
        raise ValueError(f"max_size must be positive, got {max_size}")

    with Image.open(image_path) as img:
        # Cap the ratio at 1 so small images are never scaled up.
        ratio = min(1.0, max_size / max(img.size))
        # Clamp to 1 px so extreme aspect ratios cannot yield a 0-pixel side.
        new_size = tuple(max(1, int(x * ratio)) for x in img.size)
        # resize() returns a new image, so it stays usable after `img` closes.
        resized = img.resize(new_size, Image.LANCZOS)

    # NOTE(review): the PNG copy is written without the source's EXIF data;
    # confirm whether depth_pro.load_rgb relies on it (e.g. focal length).
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
    try:
        with temp_file:
            resized.save(temp_file, format="PNG")
    except Exception:
        # delete=False means nobody else will clean up a half-written file.
        if os.path.exists(temp_file.name):
            os.remove(temp_file.name)
        raise
    return temp_file.name

@spaces.GPU(duration=20)
def predict_depth(input_image):
    """Predict a depth map and focal length for an uploaded image.

    The image is first copied to a size-limited temporary PNG via
    ``resize_image``, then loaded, transformed and passed to the module-level
    ``model``. The depth map is rendered to ``depth_map.png`` and the raw
    values are written to ``raw_depth_map.csv`` in the working directory.

    Args:
        input_image: Path of the input image file.

    Returns:
        A 3-tuple ``(depth_map_png_path, message, raw_depth_csv_path)``.
        On success ``message`` reports the focal length in pixels. On any
        exception the tuple is ``(None, "An error occurred: ...", None)``.
    """
    temp_file = None
    try:
        # Resize the input image
        temp_file = resize_image(input_image)

        # Preprocess the image
        result = depth_pro.load_rgb(temp_file)
        image = result[0]
        f_px = result[-1]  # Assuming f_px is the last item in the returned tuple
        image = transform(image).to(device)

        # Run inference
        prediction = model.infer(image, f_px=f_px)
        depth = prediction["depth"]  # Depth in [m]
        focallength_px = prediction["focallength_px"]  # Focal length in pixels

        # Convert depth to numpy array if it's a torch tensor
        if isinstance(depth, torch.Tensor):
            depth = depth.detach().cpu().numpy()

        # Ensure depth is a 2D numpy array
        if depth.ndim != 2:
            depth = depth.squeeze()

        # The depth is deliberately NOT normalized: it is plotted as-is so the
        # colorbar and the title read in metres.
        depth_min = np.min(depth)
        depth_max = np.max(depth)

        # NOTE(review): both output paths are fixed names in the working
        # directory, so overlapping requests would overwrite each other's
        # files - confirm the app never runs this function concurrently.
        output_path = "depth_map.png"
        raw_depth_path = "raw_depth_map.csv"

        # Render the colour-mapped depth image
        fig = plt.figure(figsize=(10, 10))
        try:
            plt.imshow(depth, cmap='viridis')
            plt.colorbar(label='Depth [m]')
            # Must be an f-string, otherwise the placeholders are shown verbatim.
            plt.title(f'Predicted Depth Map - Min: {depth_min:.1f}m, Max: {depth_max:.1f}m')
            plt.axis('off')

            # Save the plot to a file
            plt.savefig(output_path)
        finally:
            # Close this figure even when plotting or saving fails, so figures
            # do not pile up across requests.
            plt.close(fig)

        # Save raw depth data as CSV
        np.savetxt(raw_depth_path, depth, delimiter=',')

        # float() accepts a Python number as well as a single-element tensor.
        focal_text = f"Focal length: {float(focallength_px):.2f} pixels"
        return output_path, focal_text, raw_depth_path
    except Exception as e:
        # UI boundary: report the failure in the text box instead of raising.
        return None, f"An error occurred: {str(e)}", None
    finally:
        # Clean up the temporary file
        if temp_file and os.path.exists(temp_file):
            os.remove(temp_file)

# Build the Gradio UI: one image (passed to the handler as a file path) goes
# in; the rendered depth map, a status line and the raw CSV come out.
image_input = gr.Image(type="filepath")
depth_map_output = gr.Image(type="filepath", label="Depth Map")
status_output = gr.Textbox(label="Focal Length or Error Message")
raw_csv_output = gr.File(label="Download Raw Depth Map (CSV)")

app_title = "DepthPro Demo in Meters"
app_description = "[DepthPro](https://huggingface.co/apple/DepthPro) is a fast metric depth prediction model. Simply upload an image to predict its depth map and focal length. Large images will be automatically resized. You can also download the raw depth map data as a CSV file."

# The output order must match the tuple returned by predict_depth.
iface = gr.Interface(
    fn=predict_depth,
    inputs=image_input,
    outputs=[depth_map_output, status_output, raw_csv_output],
    title=app_title,
    description=app_description,
)

# Start serving the interface
iface.launch()