import gradio as gr
from PIL import Image
import src.depth_pro as depth_pro
import numpy as np
import matplotlib.pyplot as plt
import subprocess
import spaces
import torch
import io

# Run the setup script to fetch the pretrained model checkpoints; fail fast if the download fails
subprocess.run(["bash", "get_pretrained_models.sh"], check=True)

# Load model and preprocessing transform
model, transform = depth_pro.create_model_and_transforms()
model.eval()

def resize_image(image_path, max_size=1024):
    with Image.open(image_path) as img:
        # Compute the new size, keeping the aspect ratio and never upscaling small images
        ratio = min(max_size / max(img.size), 1.0)
        new_size = tuple(int(x * ratio) for x in img.size)
        
        # Resize the image
        img = img.resize(new_size, Image.LANCZOS)
        
        # Save the resized image to a bytes buffer
        buffer = io.BytesIO()
        img.save(buffer, format="PNG")
        buffer.seek(0)
        
        return buffer
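
# Usage sketch (the path is hypothetical): resize_image returns an in-memory PNG
# buffer, which PIL-based loaders such as depth_pro.load_rgb can open directly:
#   buf = resize_image("examples/photo.jpg")  # longest side capped at max_size
#   img = Image.open(buf)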

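# On Hugging Face ZeroGPU Spaces, @spaces.GPU requests a GPU for each call to the
# decorated function; duration=120 declares an expected runtime of up to 120 seconds.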
@spaces.GPU(duration=120)
def predict_depth(input_image):
    try:
        # Resize the input image
        resized_image = resize_image(input_image)
        
        # Preprocess the image
        result = depth_pro.load_rgb(resized_image)
        image = result[0]
        f_px = result[-1]  # Assuming f_px is the last item in the returned tuple
        image = transform(image)

        # Run inference
        prediction = model.infer(image, f_px=f_px)
        depth = prediction["depth"]  # Depth in [m]
        focallength_px = prediction["focallength_px"]  # Focal length in pixels

        # Convert depth to numpy array if it's a torch tensor
        if isinstance(depth, torch.Tensor):
            depth = depth.cpu().numpy()

        # Ensure depth is a 2D numpy array
        if depth.ndim != 2:
            depth = depth.squeeze()

        # Normalize depth to [0, 1] for visualization (guard against a constant depth map)
        depth_min = np.min(depth)
        depth_max = np.max(depth)
        depth_range = depth_max - depth_min
        depth_normalized = (depth - depth_min) / depth_range if depth_range > 0 else np.zeros_like(depth)
        
        # Create a color map
        plt.figure(figsize=(10, 10))
        plt.imshow(depth_normalized, cmap='viridis')
        plt.colorbar(label='Depth')
        plt.title('Predicted Depth Map')
        plt.axis('off')
        
        # Save the plot to a file
        output_path = "depth_map.png"
        plt.savefig(output_path)
        plt.close()

        return output_path, f"Focal length: {focallength_px:.2f} pixels"
    except Exception as e:
        return None, f"An error occurred: {e}"

# Create Gradio interface
iface = gr.Interface(
    fn=predict_depth,
    inputs=gr.Image(type="filepath"),
    outputs=[gr.Image(type="filepath", label="Depth Map"), gr.Textbox(label="Focal Length or Error Message")],
    title="Depth Prediction Demo",
    description="Upload an image to predict its depth map and focal length. Large images will be automatically resized."
)
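
# For a quick local check without the web UI, the handler can be called directly
# (the example path is hypothetical):
#   depth_png, info = predict_depth("examples/room.jpg")
#   print(info)  # "Focal length: <value> pixels" on success, or an error message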

# Launch the interface
iface.launch()