Spaces:
Runtime error
Runtime error
File size: 618 Bytes
9bb0389 12b7fe4 9bb0389 12b7fe4 71c2fd4 9bb0389 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
import torch
import torchvision
from torch import nn
from helper import setAllSeeds
from ViT import ViT
import spaces
# @spaces.GPU(duration=5)
def getViT(seed, classNames, DEVICE):
    """Build a frozen, pretrained ViT-B/16 with a fresh classification head.

    The torchvision ``vit_b_16`` backbone is loaded with its default
    pretrained weights (``ViT_B_16_Weights.DEFAULT``) and every backbone
    parameter is frozen (``requires_grad = False``). Its ``heads`` module is
    then replaced by a single, trainable ``nn.Linear`` mapping the backbone's
    hidden size to ``len(classNames)`` outputs.

    Args:
        seed: Seed forwarded to ``setAllSeeds`` before any module is built,
            so the randomly initialised head is reproducible for a given seed.
        classNames: Sized collection of class labels; only its length is used.
        DEVICE: Target device passed to ``Module.to`` (e.g. ``"cpu"``,
            ``"cuda"`` or a ``torch.device``).

    Returns:
        A ``(model, transforms)`` tuple: the modified ``vit_b_16`` placed on
        ``DEVICE``, and the preprocessing transforms bundled with the
        pretrained weights.

    Raises:
        ValueError: If ``classNames`` is empty.
    """
    numClasses = len(classNames)
    if numClasses == 0:
        raise ValueError("classNames must contain at least one class")

    # Seed before building any module so the new head's random init is
    # determined by `seed`.
    setAllSeeds(seed)

    vitWeights = torchvision.models.ViT_B_16_Weights.DEFAULT
    vitTransforms = vitWeights.transforms()
    vit = torchvision.models.vit_b_16(weights=vitWeights)

    # Freeze the pretrained backbone; only the replacement head will train.
    for param in vit.parameters():
        param.requires_grad = False

    # Swap the ImageNet classifier for a task-sized linear head. Its input
    # width comes from the backbone itself (768 for ViT-B/16) rather than a
    # hard-coded constant. Newly created params default to requires_grad=True.
    vit.heads = nn.Linear(in_features=vit.hidden_dim, out_features=numClasses)

    # Move the whole model (backbone + new head) to the device in one step.
    return vit.to(DEVICE), vitTransforms
|