Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
# MiDaS is vendored as a git submodule; if it is missing we keep this module
# importable but make any later attempt to build the model fail loudly.
try:
    from ...MiDaS.midas.dpt_depth import DPTDepthModel
except ImportError as _midas_err:
    print('Please pull the MiDaS submodule via "git submodule update --init --recursive"!')
    # `except ... as name` unbinds `name` when this block ends, so keep the
    # original error under another module-level name for exception chaining.
    _MIDAS_IMPORT_ERROR = _midas_err

    def DPTDepthModel(*args, **kwargs):
        """Placeholder used when the MiDaS submodule could not be imported.

        Without it, constructing a model would fail with a bare ``NameError``;
        this raises an actionable ``ImportError`` chained to the real cause.

        Raises:
            ImportError: always.
        """
        raise ImportError(
            "DPTDepthModel is unavailable because the MiDaS submodule could not "
            'be imported; run "git submodule update --init --recursive".'
        ) from _MIDAS_IMPORT_ERROR
class MidasDetector(nn.Module):
    """Frozen MiDaS ``DPTDepthModel`` wrapped as an ``nn.Module``.

    The inner network uses the ``vitb_rn50_384`` backbone. At construction time
    gradients are disabled on all of its parameters and it is put in eval mode.
    """

    def __init__(self, model_path="./models/dpt_hybrid-midas-501f0c75.pt"):
        """Build the depth network from a checkpoint.

        Args:
            model_path: checkpoint location, forwarded to ``DPTDepthModel`` as
                its ``path`` argument.
        """
        super().__init__()
        depth_net = DPTDepthModel(
            path=model_path,
            backbone="vitb_rn50_384",
            non_negative=True,
        )
        # Inference-only wrapper: freeze every weight and switch off train-mode
        # behaviour of any layers that have it (e.g. dropout).
        for weight in depth_net.parameters():
            weight.requires_grad_(False)
        depth_net.train(False)
        self.model = depth_net

    def dtype(self):
        """Return the dtype of the first registered parameter.

        Note: this is a regular method (call ``detector.dtype()``), not a property.
        """
        first_param = next(self.parameters())
        return first_param.dtype

    def device(self):
        """Return the device of the first registered parameter.

        Note: this is a regular method (call ``detector.device()``), not a property.
        """
        first_param = next(self.parameters())
        return first_param.device

    def forward(self, images: torch.Tensor):
        """Run the wrapped MiDaS network on a batch of images.

        Input: [b, c, h, w]

        Returns whatever ``self.model`` returns. NOTE(review): presumably a
        relative depth map; confirm its shape against the MiDaS version in use.
        """
        depth = self.model(images)
        return depth