# theia / app.py
# Last commit: "Update app.py" (2eca07c, verified) by bmay; file size 3.71 kB
import gradio as gr
import spaces
import torch
import torchvision.transforms
import numpy as np
from transformers import AutoModel
from theia.decoding import load_feature_stats, prepare_depth_decoder, prepare_mask_generator, decode_everything
def load_description(fp):
    """Return the entire contents of the UTF-8 text file at *fp* as one string."""
    with open(fp, encoding='utf-8') as handle:
        return handle.read()
def _split_decoded_panels(decoded):
    """Convert one decoded visualization to uint8 and return its DINOv2, SAM and depth panels.

    ``decoded`` is unpacked as ``(height, width, channels)`` and split along its
    width into four side-by-side strips. The first strip is skipped (NOTE(review):
    presumably the resized input image — confirm against ``decode_everything``),
    and the other three are returned in order as ``uint8`` arrays.

    Values are scaled by 255 and clipped to ``[0, 255]`` before the ``uint8`` cast.
    Anything outside ``[0, 1]`` therefore saturates instead of wrapping around.
    """
    _, width, _ = decoded.shape
    pixels = np.clip(255.0 * decoded, 0.0, 255.0).astype(np.uint8)
    # Keep the boundary arithmetic as ``k * width // 4`` (not ``k * (width // 4)``)
    # so the panel edges are unchanged for widths not divisible by 4.
    dino = pixels[:, width // 4 : 2 * width // 4, :]
    sam = pixels[:, 2 * width // 4 : 3 * width // 4, :]
    depth = pixels[:, 3 * width // 4 :, :]
    return dino, sam, depth


@spaces.GPU(duration=90)
def run_theia(image):
    """Decode Theia's predicted representations for one image and show them next to the teachers'.

    Args:
        image: The uploaded image. The UI's ``gr.Image`` is configured with
            ``type="pil"``, so this is expected to be a PIL image. It is ``None``
            when Submit is clicked with nothing uploaded.

    Returns:
        A tuple ``(theia_output, gt_output)``. Each element is a list of
        ``(uint8 array, caption)`` pairs for DINOv2, SAM and Depth-Anything.
        The first list is decoded from Theia's predictions. The second is the
        ground truth produced by ``decode_everything(..., gt=True)``.

    Raises:
        gr.Error: If no image was provided.
    """
    # Fail fast with a user-visible message rather than an opaque error from
    # the resize step after every model has already been loaded.
    if image is None:
        raise gr.Error("Please upload an input image before clicking Submit.")

    # NOTE(review): every model below is reloaded on each request. Caching them
    # would cut latency, but first check how ``spaces.GPU`` handles CUDA tensors
    # that outlive a single call.
    device = 'cuda'
    theia_model = AutoModel.from_pretrained(
        "theaiinstitute/theia-tiny-patch16-224-cddsv", trust_remote_code=True
    ).to(device)

    depth_anything_model_name = "LiheYoung/depth-anything-large-hf"
    target_model_names = [
        "google/vit-huge-patch14-224-in21k",
        "facebook/dinov2-large",
        "openai/clip-vit-large-patch14",
        "facebook/sam-vit-huge",
        depth_anything_model_name,
    ]
    # Per-teacher feature means/variances, read from the local "feature_stats"
    # directory (presumably used to de-normalize Theia's predicted features).
    feature_means, feature_vars = load_feature_stats(target_model_names, stat_file_root="feature_stats")
    mask_generator, sam_model = prepare_mask_generator(device)
    depth_anything_decoder, _ = prepare_depth_decoder(depth_anything_model_name, device)

    # 224x224 matches the input size in the Theia model id (patch16-224).
    image = torchvision.transforms.Resize(size=(224, 224))(image)

    theia_decode_results, gt_results = decode_everything(
        theia_model=theia_model,
        feature_means=feature_means,
        feature_vars=feature_vars,
        images=[image],
        mask_generator=mask_generator,
        sam_model=sam_model,
        depth_anything_decoder=depth_anything_decoder,
        # Thresholds forwarded to the SAM mask-generation step.
        pred_iou_thresh=0.5,
        stability_score_thresh=0.7,
        gt=True,
        device=device,
    )

    captions = ("DINOv2", "SAM", "Depth-Anything")
    theia_output = list(zip(_split_decoded_panels(theia_decode_results[0]), captions))
    gt_output = list(zip(_split_decoded_panels(gt_results[0]), captions))
    return theia_output, gt_output
with gr.Blocks() as demo:
    # Page header: the HTML title lives in a Markdown/HTML file next to this script.
    gr.HTML(load_description("gradio_title.md"))
    gr.Markdown(
        "This space demonstrates decoding Theia-predicted VFM representations to their original teacher model outputs. For DINOv2 we apply the PCA visualization, for SAM we use its decoder to generate segmentation masks (but with SAM's pipeline of prompting), and for Depth-Anything we use its decoder head to do depth prediction."
    )

    with gr.Row():
        # Left column: image upload plus the button that triggers inference.
        with gr.Column():
            gr.Markdown("### Input Image")
            input_image = gr.Image(label=None, type="pil")
            submit_button = gr.Button(value="Submit")

        # Right column: Theia's decoded outputs stacked above the teacher ground truth.
        with gr.Column():
            gr.Markdown("### Theia Results")
            theia_gallery = gr.Gallery(type="numpy", label=None)
            gr.Markdown("### Ground Truth")
            gt_gallery = gr.Gallery(type="numpy", label=None)

    # run_theia returns (theia results, ground truth); route them to the galleries in that order.
    submit_button.click(
        fn=run_theia,
        inputs=[input_image],
        outputs=[theia_gallery, gt_gallery],
    )

demo.launch()