File size: 618 Bytes
9bb0389
 
 
 
 
 
12b7fe4
9bb0389
12b7fe4
71c2fd4
9bb0389
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21

import torch
import torchvision
from torch import nn
from helper import setAllSeeds
from ViT import ViT
import spaces


# @spaces.GPU(duration=5)
def getViT(seed,classNames,DEVICE):
  """Build a pretrained torchvision ViT-B/16 with a frozen backbone and a new head.

  Args:
    seed: Passed to ``setAllSeeds`` before the new classifier head is created,
      so the head's random initialisation depends only on this seed.
    classNames: Sequence of class labels; only ``len(classNames)`` is used,
      as the number of outputs of the new head.
    DEVICE: Device the model (backbone and new head) is moved to.

  Returns:
    ``(vit, vitTransforms)`` where ``vit`` is ``torchvision.models.vit_b_16``
    loaded with ``ViT_B_16_Weights.DEFAULT``, with every pretrained parameter
    frozen and ``vit.heads`` replaced by a trainable
    ``nn.Linear(vit.hidden_dim, len(classNames))``; ``vitTransforms`` is the
    preprocessing pipeline that matches those pretrained weights.
  """
  setAllSeeds(seed)
  vitWeights = torchvision.models.ViT_B_16_Weights.DEFAULT
  vitTransforms = vitWeights.transforms()
  vit = torchvision.models.vit_b_16(weights=vitWeights).to(DEVICE)
  # Freeze the pretrained backbone; only the head created below stays trainable.
  for param in vit.parameters():
    param.requires_grad = False
  # Kept as a bare Linear (not torchvision's Sequential(head=...)) so state_dict
  # keys stay "heads.weight"/"heads.bias" -- presumably what saved checkpoints
  # for this model expect; confirm before changing.
  vit.heads = nn.Linear(in_features=vit.hidden_dim, out_features=len(classNames)).to(DEVICE)
  return vit,vitTransforms