theia / theia /models /vfm.py
Brandon May
Add theia
26791f7
raw
history blame
8.1 kB
# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
from typing import Any, Optional
import torch
import torch.nn as nn
from theia.foundation_models import get_clip_model, get_deit_model, get_dinov2_model, get_sam_model, get_vit_model
from transformers import AutoImageProcessor, AutoModel
from theia.models.utils import handle_feature_output
class VFMEncoder(nn.Module):
    """Wrapper class of an individual VFM Encoder for feature extraction.

    The backbone is selected by substring match on `model_name`; the same match decides how inputs are
    pre-processed and which model output is used as the feature in `get_feature()`.

    Attrs:
        model_name (str): name of the model.
        feature_reduce_method (Optional[str]): how to select the output feature token and shape.
        model: the loaded backbone.
        processor: input pre-processor; `None` for the mvp / vip / r3m backbones.
        image_size: only set when `image_size` is passed as a keyword argument.
        final_spatial: only set when `final_spatial` is passed as a keyword argument.
    """

    def __init__(self, model_name: str, feature_reduce_method: Optional[str] = None, **kwargs: Any):
        """Instantiate an (off-the-shelf) VFM encoder.

        Args:
            model_name (str): name of the encoder.
            feature_reduce_method (Optional[str]): how to select the output feature token and shape.
                Defaults to None.
            **kwargs (Any): pass-through arguments. Only `image_size` and `final_spatial` are read
                (and stored as attributes of the same name); everything else is ignored.

        Raises:
            NotImplementedError: if `model_name` does not match any supported backbone.
        """
        super().__init__()
        self.model_name = model_name

        if "google/vit" in model_name:
            model, processor = get_vit_model(model_name, device="cpu")
        elif "facebook/dino" in model_name:
            model, processor = get_dinov2_model(model_name, device="cpu")
        elif "facebook/sam" in model_name:
            model, processor = get_sam_model(model_name, device="cpu")
        elif "openai/clip" in model_name:
            model, processor = get_clip_model(model_name, device="cpu")
        elif "facebook/deit" in model_name:
            model, processor = get_deit_model(model_name, device="cpu")
        elif "nvidia" in model_name:
            model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
            processor = AutoImageProcessor.from_pretrained(model_name)
        elif "mvp" in model_name:
            # Optional dependency: imported lazily so it is only required when actually used.
            import mvp

            model_name_mvp = model_name.replace("mvp-", "")
            model = mvp.load(model_name_mvp)
            processor = None
        elif "vip" in model_name:
            # Optional dependency: imported lazily so it is only required when actually used.
            from vip import load_vip

            model = load_vip()
            processor = None
        elif "r3m" in model_name:
            # Optional dependency: imported lazily so it is only required when actually used.
            from r3m import load_r3m

            model_name_r3m = model_name.replace("r3m-", "")
            model = load_r3m(model_name_r3m)
            processor = None
        else:
            raise NotImplementedError(f"{model_name} is not supported in theia.models.vfm.VFM")

        self.model = model
        self.processor = processor
        self.feature_reduce_method = feature_reduce_method

        # These attributes are intentionally left undefined when the caller does not provide them.
        if "image_size" in kwargs:
            self.image_size = kwargs["image_size"]
        if "final_spatial" in kwargs:
            self.final_spatial = kwargs["final_spatial"]

    def get_feature(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """Return the feature from the VFM.

        Args:
            x (torch.Tensor): input image.
            kwargs: any arguments pass-through (mainly for processor currently).
                For example, `do_rescale`, `do_resize`, `interpolate_pos_encoding`
                to control image preprocessing pipeline.

        Returns:
            torch.Tensor: feature.

        Raises:
            NotImplementedError: if `self.model_name` does not match any supported backbone.
        """
        name = self.model_name

        # "facebook/dino" (not "facebook/dinov2") so that every name accepted by __init__ is handled here.
        if "google/vit" in name or "facebook/dino" in name or "facebook/deit" in name:
            inputs = self.processor(x, return_tensors="pt", **kwargs).to(self.model.device)
            feature = self.model(**inputs).last_hidden_state
        elif "openai/clip" in name:
            inputs = self.processor(images=x, return_tensors="pt", **kwargs).to(self.model.device)
            feature = self.model(**inputs).last_hidden_state
        elif "facebook/sam" in name:
            inputs = self.processor(x, return_tensors="pt", **kwargs).to(self.model.device)
            feature = self.model(**inputs).image_embeddings
        elif "nvidia" in name:
            # Pixel values are cast to bfloat16 before the forward pass and the output is cast back to float32.
            # NOTE(review): presumably the model weights are expected to be bfloat16 as well -- confirm.
            inputs = (
                self.processor(images=x, return_tensors="pt", **kwargs)
                .pixel_values.to(torch.bfloat16)
                .to(self.model.device)
            )
            summary, feature = self.model(inputs)
            if self.feature_reduce_method == "cls_identity":
                feature = summary.to(torch.float32)
            else:
                feature = feature.to(torch.float32)
        elif "mvp" in name or "vip" in name or "r3m" in name:
            # These backbones have no processor; the input tensor is fed to the model directly.
            feature = self.model(x)
        else:
            # Previously this fell through and failed with UnboundLocalError on `return feature`.
            raise NotImplementedError(f"{self.model_name} is not supported in theia.models.vfm.VFM")

        return feature

    def forward(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """Forward method, including getting the feature and handling the output token / shape.

        Args:
            x (torch.Tensor): input image.
            kwargs: forwarded unchanged to `get_feature()`.

        Returns:
            torch.Tensor: output feature with token or shape handled.
        """
        feature = self.get_feature(x, **kwargs)  # [B, 1+H*W, C]
        return handle_feature_output(feature, self.feature_reduce_method)

    def forward_feature(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """Alias of forward() to accommodate some downstream usage.

        Args:
            x (torch.Tensor): input image.
            kwargs: forwarded unchanged to `forward()`.

        Returns:
            torch.Tensor: output feature with token or shape handled.
        """
        return self.forward(x, **kwargs)
class ConcatVFMEncoder(nn.Module):
    """Wrapper class that combines features from multiple VFM Encoders.

    The combination is a concatenation along the last (channel) dimension.

    Attrs:
        model_names (list[str]): names of the models.
        feature_reduce_method (Optional[str]): how to select the output feature token and shape.
        model (nn.ModuleDict): a dict to hold different VFM encoders, keyed by model name.
    """

    def __init__(self, model_names: list[str], feature_reduce_method: Optional[str] = None, **kwargs: Any):
        """Instantiate one (off-the-shelf) VFM encoder per model name.

        Args:
            model_names (list[str]): names of the encoders.
            feature_reduce_method (str, optional): how to select the output feature token and shape.
                Defaults to None.
            **kwargs (Any): passed through unchanged to every `VFMEncoder`.
        """
        super().__init__()
        self.model_names = model_names
        # Registering the encoders in a ModuleDict makes them proper sub-modules of this wrapper.
        self.model = nn.ModuleDict(
            {
                model_name: VFMEncoder(model_name, feature_reduce_method=feature_reduce_method, **kwargs)
                for model_name in model_names
            }
        )
        self.feature_reduce_method = feature_reduce_method

    def get_feature(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """Get different features from VFMs.

        Each sub-encoder's `get_feature()` is used (not its `forward()`), so the output handling in
        `forward()` is applied exactly once, to the concatenated feature, instead of once per encoder
        and then a second time on the concatenation.

        Args:
            x (torch.Tensor): input image.
            kwargs: forwarded unchanged to every sub-encoder's `get_feature()`.

        Returns:
            torch.Tensor: features concatenated at channel dimension.
        """
        features = [self.model[model_name].get_feature(x, **kwargs) for model_name in self.model_names]
        return torch.cat(features, dim=-1)

    def forward(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """Forward method, including getting the feature and handling the output token / shape.

        Args:
            x (torch.Tensor): input image.
            kwargs: forwarded unchanged to `get_feature()`.

        Returns:
            torch.Tensor: output feature with token or shape handled.
        """
        feature = self.get_feature(x, **kwargs)  # [B, 1+H*W, C]
        return handle_feature_output(feature, self.feature_reduce_method)

    def forward_feature(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """Alias of forward() to accommodate some downstream usage.

        Args:
            x (torch.Tensor): input image.
            kwargs: forwarded unchanged to `forward()`.

        Returns:
            torch.Tensor: output feature with token or shape handled.
        """
        return self.forward(x, **kwargs)