# File-viewer header residue: last commit b95c986 ("fix index") by geetu040; file size 2.3 kB.
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from segmentation import predict as segmentation_predict
from depth_estimation import predict as depth_estimation_predict
def predict(image, color_map):
    """Segment *image*, estimate depth on the foreground, and colorize the depth map.

    Args:
        image: Input PIL image (the Gradio input component uses ``type="pil"``),
            or ``None`` when the user submitted without uploading anything.
        color_map: Name of a matplotlib colormap, e.g. ``"viridis"``.

    Returns:
        An RGB PIL image of the depth map rendered with *color_map*.

    Raises:
        gr.Error: If no image was provided (shown to the user by Gradio).
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    # Image.composite requires both images to share mode and size; the
    # background below is RGB, so normalize uploads (RGBA, L, P, ...) to RGB.
    image = image.convert("RGB")

    # inference
    mask_image = segmentation_predict(image)
    # Where the mask is 255 keep the original pixel, where it is 0 use black;
    # intermediate mask values blend the two.
    segmented_image = Image.composite(
        image,
        Image.new("RGB", image.size, (0, 0, 0)),
        mask_image.convert("L"),
    )
    depth_image = depth_estimation_predict(segmented_image)

    # Apply the matplotlib colormap.
    # NOTE(review): dividing by 255 assumes depth_image is a single-channel
    # 8-bit image with values in [0, 255] -- confirm against depth_estimation.
    depth_array = np.asarray(depth_image, dtype=np.float64) / 255.0
    colormap = plt.get_cmap(color_map)
    # bytes=True yields uint8 RGBA directly; drop the alpha channel so the
    # returned image is RGB, as the original comment intended.
    depth_rgba = colormap(depth_array, bytes=True)
    return Image.fromarray(np.ascontiguousarray(depth_rgba[..., :3]))
# Colormap names offered in the dropdown, grouped (in display order) by the
# categories used in matplotlib's colormap reference.
_COLOR_MAP_GROUPS = {
    "Perceptually Uniform Sequential": [
        "viridis", "plasma", "inferno", "magma", "cividis",
    ],
    "Sequential": [
        "Greys", "Purples", "Blues", "Greens", "Oranges", "Reds",
        "YlOrBr", "YlOrRd", "OrRd", "PuRd", "RdPu", "BuPu",
        "GnBu", "PuBu", "YlGnBu", "PuBuGn", "BuGn", "YlGn",
    ],
    "Sequential (2)": [
        "binary", "gist_yarg", "gist_gray", "gray", "bone",
        "pink", "spring", "summer", "autumn", "winter", "cool",
        "Wistia", "hot", "afmhot", "gist_heat", "copper",
    ],
    "Diverging": [
        "PiYG", "PRGn", "BrBG", "PuOr", "RdGy", "RdBu", "RdYlBu",
        "RdYlGn", "Spectral", "coolwarm", "bwr", "seismic",
    ],
    "Cyclic": [
        "twilight", "twilight_shifted", "hsv",
    ],
    "Qualitative": [
        "Pastel1", "Pastel2", "Paired", "Accent", "Dark2",
        "Set1", "Set2", "Set3", "tab10", "tab20", "tab20b", "tab20c",
    ],
    "Miscellaneous": [
        "flag", "prism", "ocean", "gist_earth", "terrain",
        "gist_stern", "gnuplot", "gnuplot2", "CMRmap",
        "cubehelix", "brg", "gist_rainbow", "rainbow", "jet",
        "turbo", "nipy_spectral", "gist_ncar",
    ],
}
# Flatten the groups into one ordered list (dicts preserve insertion order).
color_maps = [name for group in _COLOR_MAP_GROUPS.values() for name in group]
# Each example row is [image path, colormap name], in the same order as the
# interface inputs (image first, then the colormap dropdown).
_EXAMPLE_IMAGE = "assets/examples/myself.jpeg"
examples = [[_EXAMPLE_IMAGE, cmap] for cmap in ("afmhot", "inferno")]
# Gradio UI: an image plus a colormap choice in, the colorized depth map out.
interface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Input image"),
        # Pin an explicit default so `predict` never receives an unset
        # colormap, regardless of the installed Gradio version's default.
        gr.Dropdown(choices=color_maps, value="viridis", label="Color map"),
    ],
    outputs=gr.Image(type="pil", label="Colorized depth"),
    title="DepthPro: Colorify",
    description="Applies segmentation on the input image, then creates the depth map and finally colorizes it.",
    examples=examples,
)
def _main():
    """Start the Gradio app defined by ``interface``."""
    interface.launch()


if __name__ == "__main__":
    _main()