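# Gradio demo for Theia: encode an input image with a pretrained Theia model and decode
# the predicted features into visualizations, shown alongside the ground-truth outputs
# of the original target models.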
import gradio as gr
import spaces
import torch
import torchvision.transforms.v2 as T
import numpy as np
from transformers import AutoModel
from theia.decoding import load_feature_stats, prepare_depth_decoder, prepare_mask_generator, decode_everything
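

# Run Theia on a single image and return a visualization that stacks the decoded
# predictions (top) over the ground-truth target-model outputs (bottom).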
@spaces.GPU
def run_theia(image):
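    # Load the pretrained Theia backbone from the Hugging Face Hub and move it to the GPU.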
    theia_model = AutoModel.from_pretrained("theaiinstitute/theia-base-patch16-224-cdiv", trust_remote_code=True)
    theia_model = theia_model.to('cuda')
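    # Target ("teacher") models whose representations Theia distills; their feature
    # statistics are needed to de-normalize the decoded features below.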
    target_model_names = [
        "google/vit-huge-patch14-224-in21k",
        "facebook/dinov2-large",
        "openai/clip-vit-large-patch14",
        "facebook/sam-vit-huge",
        "LiheYoung/depth-anything-large-hf",
    ]
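    # Per-model feature statistics plus the SAM mask generator and Depth-Anything
    # decoder used to turn features into segmentation and depth visualizations.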
    feature_means, feature_vars = load_feature_stats(target_model_names, stat_file_root="feature_stats")
    mask_generator, sam_model = prepare_mask_generator('cuda')
    depth_anything_model_name = "LiheYoung/depth-anything-large-hf"
    depth_anything_decoder, _ = prepare_depth_decoder(depth_anything_model_name, 'cuda')
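    # Resize the input to Theia's 224x224 input resolution.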
    image = T.Resize(size=(224, 224))(image)
    images = [image]
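    # Decode Theia's predicted features and, with gt=True, the ground-truth
    # features of the target models for comparison.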
    theia_decode_results, gt_decode_results = decode_everything(
        theia_model=theia_model,
        feature_means=feature_means,
        feature_vars=feature_vars,
        images=images,
        mask_generator=mask_generator,
        sam_model=sam_model,
        depth_anything_decoder=depth_anything_decoder,
        pred_iou_thresh=0.5,
        stability_score_thresh=0.7,
        gt=True,
        device='cuda',
    )
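    # Stack each Theia decoding on top of the corresponding ground-truth decoding.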
    vis_video = np.stack(
        [np.vstack([tr, gtr]) for tr, gtr in zip(theia_decode_results, gt_decode_results, strict=False)]
    )
    # Only one image is processed, so return the single stacked frame for the Image output.
    return vis_video[0]
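

# Simple Gradio UI: PIL image in, numpy visualization out.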
demo = gr.Interface(fn=run_theia, inputs=gr.Image(type="pil"), outputs=gr.Image(type="numpy"))
demo.launch()