"""Gradio demo: segment the subject, estimate its depth, and colorize the depth map."""

import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

from segmentation import predict as segmentation_predict
from depth_estimation import predict as depth_estimation_predict


def _colorize_depth(depth_image, color_map):
    """Map a single-channel depth image through a matplotlib colormap.

    Args:
        depth_image: Depth map returned by ``depth_estimation_predict``.
            NOTE(review): assumed to be a single-channel, 8-bit image
            (values 0..255), since it is normalised by 255 -- confirm against
            the depth_estimation module.
        color_map: Name of a matplotlib colormap (e.g. "viridis"). ``None``
            selects matplotlib's default colormap.

    Returns:
        An RGB ``PIL.Image`` with the same width and height as ``depth_image``.
    """
    normalized = np.array(depth_image) / 255.0
    colormap = plt.get_cmap(color_map)
    # bytes=True makes the colormap return uint8 RGBA directly; slice off the
    # alpha channel so the result is a 3-channel RGB image, not RGBA.
    rgba = colormap(normalized, bytes=True)
    return Image.fromarray(np.ascontiguousarray(rgba[..., :3]))


def predict(image, color_map):
    """Segment the input, estimate depth on the segmented subject, and colorize it.

    Args:
        image: Input ``PIL.Image`` supplied by the Gradio image component.
        color_map: Name of the matplotlib colormap chosen in the dropdown.

    Returns:
        An RGB ``PIL.Image`` containing the colorized depth map.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    # Image.composite requires both images to share a mode; the background is
    # RGB, so normalise the input (uploads may be RGBA or greyscale).
    image = image.convert("RGB")

    mask_image = segmentation_predict(image)
    # Keep the masked subject and paint everything else black.
    # NOTE(review): composite also requires the mask to match image.size --
    # presumably segmentation_predict guarantees this; confirm.
    background = Image.new("RGB", image.size, (0, 0, 0))
    segmented_image = Image.composite(image, background, mask_image.convert("L"))

    depth_image = depth_estimation_predict(segmented_image)
    return _colorize_depth(depth_image, color_map)


# Colormap names offered in the UI dropdown (all are matplotlib built-ins).
color_maps = [
    'viridis', 'plasma', 'inferno', 'magma', 'cividis',
    'Greys', 'Purples', 'Blues', 'Greens', 'Oranges', 'Reds',
    'YlOrBr', 'YlOrRd', 'OrRd', 'PuRd', 'RdPu', 'BuPu',
    'GnBu', 'PuBu', 'YlGnBu', 'PuBuGn', 'BuGn', 'YlGn',
    'binary', 'gist_yarg', 'gist_gray', 'gray', 'bone', 'pink',
    'spring', 'summer', 'autumn', 'winter', 'cool', 'Wistia',
    'hot', 'afmhot', 'gist_heat', 'copper',
    'PiYG', 'PRGn', 'BrBG', 'PuOr', 'RdGy', 'RdBu',
    'RdYlBu', 'RdYlGn', 'Spectral', 'coolwarm', 'bwr', 'seismic',
    'twilight', 'twilight_shifted', 'hsv',
    'Pastel1', 'Pastel2', 'Paired', 'Accent', 'Dark2', 'Set1', 'Set2', 'Set3',
    'tab10', 'tab20', 'tab20b', 'tab20c',
    'flag', 'prism', 'ocean', 'gist_earth', 'terrain', 'gist_stern',
    'gnuplot', 'gnuplot2', 'CMRmap', 'cubehelix', 'brg',
    'gist_rainbow', 'rainbow', 'jet', 'turbo', 'nipy_spectral', 'gist_ncar',
]

# Example rows shown under the interface: [image path, colormap name].
examples = [
    ["assets/examples/myself.jpeg", "afmhot"],
    ["assets/examples/myself.jpeg", "inferno"],
]

interface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil"),
        gr.Dropdown(choices=color_maps),
    ],
    outputs=gr.Image(type="pil"),
    title="DepthPro: Colorify",
    description="Applies segmentation on the input image, then creates the depth map and finally colorizes it.",
    examples=examples,
)

if __name__ == "__main__":
    interface.launch()