File size: 616 Bytes
9bb0389
 
 
 
 
 
12b7fe4
9bb0389
12b7fe4
29a11ce
9bb0389
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21

import torch
import torchvision
from torch import nn
from helper import setAllSeeds
from ViT import ViT
import spaces


@spaces.GPU(duration=5)
def getViT(seed, classNames, DEVICE):
    """Build an ImageNet-pretrained ViT-B/16 with a frozen backbone and a new head.

    Args:
        seed: Passed to the project's ``setAllSeeds`` helper before any module
            is constructed (presumably seeds torch's RNGs so the new head's
            random initialisation is reproducible — confirm in ``helper``).
        classNames: Sequence of class labels; only its length is used, to size
            the output layer.
        DEVICE: Device spec accepted by ``nn.Module.to`` (e.g. ``"cuda"``).

    Returns:
        A ``(model, transforms)`` tuple:
          * the torchvision ``vit_b_16`` model with every pretrained parameter
            frozen and ``heads`` replaced by a trainable
            ``nn.Linear(hidden_dim, len(classNames))`` on ``DEVICE``;
          * the preprocessing transforms that match the pretrained weights.
    """
    setAllSeeds(seed)
    vitWeights = torchvision.models.ViT_B_16_Weights.DEFAULT
    vitTransforms = vitWeights.transforms()

    vit = torchvision.models.vit_b_16(weights=vitWeights).to(DEVICE)
    # Freeze the whole pretrained backbone; only the head created below trains.
    vit.requires_grad_(False)

    # Size the head from the backbone rather than hard-coding 768, so it stays
    # correct if a different ViT variant is swapped in. The layer is created
    # after freezing, so its parameters keep requires_grad=True.
    vit.heads = nn.Linear(
        in_features=vit.hidden_dim,
        out_features=len(classNames),
    ).to(DEVICE)
    return vit, vitTransforms