eternalBlissard commited on
Commit
12b7fe4
·
verified ·
1 Parent(s): 0e341ec

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +3 -0
model.py CHANGED
@@ -4,7 +4,10 @@ import torchvision
4
  from torch import nn
5
  from helper import setAllSeeds
6
  from ViT import ViT
 
7
 
 
 
8
  def getViT(seed,classNames,DEVICE):
9
  setAllSeeds(seed)
10
  ViTModel = ViT(3,768,16,224,3072,12,0.1,12,len(classNames)).to(DEVICE)
 
4
  from torch import nn
5
  from helper import setAllSeeds
6
  from ViT import ViT
7
+ import spaces
8
 
9
+
10
+ @spaces.GPU
11
  def getViT(seed,classNames,DEVICE):
12
  setAllSeeds(seed)
13
  ViTModel = ViT(3,768,16,224,3072,12,0.1,12,len(classNames)).to(DEVICE)