# theia / app.py
# Last commit by bmay: "Update app.py" (6669d02, verified); file size 1.92 kB.
import gradio as gr
import spaces
import torch
import torchvision.transforms
import numpy as np
from transformers import AutoModel
from theia.decoding import load_feature_stats, prepare_depth_decoder, prepare_mask_generator, decode_everything
def _to_uint8_image(values):
    """Convert a float image to ``uint8`` by scaling by 255 and clipping.

    The 255 scale implies the decoded visualisations are presumably floats in
    [0, 1]. The clip ensures that any value slightly outside that range
    saturates to 0/255. Without it, a float -> uint8 cast wraps around or
    yields platform-dependent values.
    """
    return np.clip(255.0 * np.asarray(values), 0.0, 255.0).astype(np.uint8)


@spaces.GPU()
def run_theia(image):
    """Decode Theia features for one image and return them next to ground truth.

    The function loads the Theia-tiny checkpoint, per-model feature statistics,
    a SAM mask generator and a Depth-Anything decoder. It then runs
    ``theia.decoding.decode_everything`` on the resized image with ``gt=True``,
    so reference ("ground truth") outputs are produced alongside Theia's.

    Args:
        image: The PIL image from ``gr.Image(type="pil")``, or ``None`` when
            the user submits without uploading anything.

    Returns:
        A list of two ``(uint8 ndarray, caption)`` pairs for ``gr.Gallery``:
        Theia's decoded visualisation and the ground-truth visualisation.

    Raises:
        gr.Error: If no image was provided.
    """
    # Show a user-visible message instead of failing with an opaque error
    # inside the Resize transform when the input box is empty.
    if image is None:
        raise gr.Error("Please upload an image first.")

    # NOTE(review): every model below is reloaded on each request. Presumably
    # this keeps all CUDA work inside the @spaces.GPU call; caching the loaded
    # models could cut latency considerably. TODO: evaluate under ZeroGPU.
    theia_model = AutoModel.from_pretrained(
        "theaiinstitute/theia-tiny-patch16-224-cddsv", trust_remote_code=True
    )
    theia_model = theia_model.to("cuda")

    depth_anything_model_name = "LiheYoung/depth-anything-large-hf"
    # Models whose feature means/variances are loaded from "feature_stats".
    # Presumably these are the teacher models Theia was distilled from; confirm.
    target_model_names = [
        "google/vit-huge-patch14-224-in21k",
        "facebook/dinov2-large",
        "openai/clip-vit-large-patch14",
        "facebook/sam-vit-huge",
        depth_anything_model_name,
    ]
    feature_means, feature_vars = load_feature_stats(
        target_model_names, stat_file_root="feature_stats"
    )
    mask_generator, sam_model = prepare_mask_generator("cuda")
    depth_anything_decoder, _ = prepare_depth_decoder(depth_anything_model_name, "cuda")

    # The checkpoint name ("...-patch16-224-...") indicates 224x224 inputs.
    image = torchvision.transforms.Resize(size=(224, 224))(image)

    theia_decode_results, gt_decode_results = decode_everything(
        theia_model=theia_model,
        feature_means=feature_means,
        feature_vars=feature_vars,
        images=[image],
        mask_generator=mask_generator,
        sam_model=sam_model,
        depth_anything_decoder=depth_anything_decoder,
        pred_iou_thresh=0.5,
        stability_score_thresh=0.7,
        gt=True,
        device="cuda",
    )

    # Presumably one result per input image; only a single image was passed.
    return [
        (_to_uint8_image(theia_decode_results[0]), "Theia Results"),
        (_to_uint8_image(gt_decode_results[0]), "Ground Truth"),
    ]
# Web UI: a single PIL image goes in, and a gallery of
# (image, caption) pairs produced by run_theia comes out.
_image_input = gr.Image(type="pil")
_results_gallery = gr.Gallery(
    label="Input, DINOv2, SAM, Depth Anything",
    type="numpy",
)
demo = gr.Interface(
    fn=run_theia,
    inputs=_image_input,
    outputs=_results_gallery,
)
demo.launch()