import timm
import torch
from torch import nn
from loguru import logger
from torch.utils.checkpoint import checkpoint
# from sbp.nn.model_paths import MODEL_PATHS


class ImageEncoder(nn.Module):
    """Encode per-sample lists of images with a truncated EVA-02 backbone into fixed-length token sequences."""

    def __init__(self, output_dim, base_model='eva02_base_patch14_224.mim_in22k', layer_num=6, seq_len=3, device='cpu'):
        super().__init__()
        self.output_dim = output_dim
        if base_model == 'eva02_base_patch14_224.mim_in22k':
            self.img_seq = 257
        elif base_model == 'eva02_large_patch14_448.mim_in22k_ft_in1k':
            self.img_seq = 1025
        else:
            raise ValueError(
                f"unknown base_model {base_model}, supported: "
                "['eva02_base_patch14_224.mim_in22k', 'eva02_large_patch14_448.mim_in22k_ft_in1k']"
            )
        self.base_model = timm.create_model(base_model, pretrained=False)
        # drop the classification head and final norms; only intermediate block features are needed
        del self.base_model.norm, self.base_model.fc_norm, self.base_model.head, self.base_model.head_drop
        # keep only the first `layer_num` transformer blocks
        del self.base_model.blocks[layer_num:]
        self.project = nn.Linear(self.base_model.num_features, output_dim)
        self.final_norm = nn.LayerNorm(output_dim)
        self.seq_len = seq_len
        self.device = device

    def forward(self, image_list):
        # image_list: one tensor of shape [n_images, C, H, W] per sample; n_images may be zero
        splits = [len(lst) for lst in image_list]
        if sum(splits) == 0:
            # no images anywhere in the batch: return an all-zero sequence of the expected shape
            return torch.zeros([len(splits), self.seq_len * self.img_seq, self.output_dim], device=self.device, dtype=torch.bfloat16)
        x = torch.concat(image_list, dim=0).to(device=self.device, dtype=torch.bfloat16)
        # run the truncated EVA-02 backbone manually: patch embedding, rotary position embedding, remaining blocks
        x = self.base_model.patch_embed(x)
        x, rot_pos_embed = self.base_model._pos_embed(x)
        for blk in self.base_model.blocks:
            x = blk(x, rope=rot_pos_embed)
        x = self.project(x)
        x = self.final_norm(x)
        b, img_seq, c = x.shape  # img_seq: number of tokens per image (equals self.img_seq)
        # regroup image tokens per sample, zero-pad each sample up to self.seq_len images,
        # then flatten the image and token dimensions into a single sequence per sample
        split_patches = torch.split(x, splits, dim=0)
        split_patches = [nn.functional.pad(sample, (0, 0, 0, 0, 0, self.seq_len - len(sample))) for sample in split_patches]
        x = torch.stack(split_patches, dim=0)
        x = x.reshape((len(splits), self.seq_len * img_seq, c))
        return x
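

# --- Usage sketch (not part of the original file) ---
# A minimal, hedged example of how ImageEncoder appears intended to be called, assuming
# random weights (pretrained=False), 224x224 inputs for the base EVA-02 model, and that
# bfloat16 ops are available on the chosen device. The parameter values below
# (output_dim=768, seq_len=3) are illustrative, not taken from the original repo.
if __name__ == "__main__":
    encoder = ImageEncoder(output_dim=768, layer_num=6, seq_len=3, device='cpu').to(torch.bfloat16)
    # batch of two samples: the first provides 2 images, the second provides none
    image_list = [torch.randn(2, 3, 224, 224), torch.randn(0, 3, 224, 224)]
    out = encoder(image_list)
    print(out.shape)  # expected: [2, 3 * 257, 768] == [2, 771, 768]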