ClemSummer commited on
Commit
88b5781
·
1 Parent(s): 8c88d40

using /tmp, should be writable

Browse files
Files changed (1) hide show
  1. vit_captioning/generate.py +6 -6
vit_captioning/generate.py CHANGED
@@ -24,23 +24,23 @@ class CaptionGenerator:
24
  print("No GPU found, falling back to CPU.")
25
 
26
  # Load tokenizer
27
- self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
28
  #HF needs all model downloads to a special read-write cache dir
29
- #self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', cache_dir="/data")
30
 
31
  # Select encoder, processor, output dim
32
  if model_type == "ViTEncoder":
33
  self.encoder = ViTEncoder().to(self.device)
34
  self.encoder_dim = 768
35
- self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
36
  #HF needs all model downloads to a special read-write cache dir
37
- #self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k", cache_dir="/data")
38
  elif model_type == "CLIPEncoder":
39
  self.encoder = CLIPEncoder().to(self.device)
40
  self.encoder_dim = 512
41
- self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
42
  #HF needs all model downloads to a special read-write cache dir
43
- #self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", cache_dir="/data")
44
  else:
45
  raise ValueError("Unknown model type")
46
 
 
24
  print("No GPU found, falling back to CPU.")
25
 
26
  # Load tokenizer
27
+ #self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
28
  #HF needs all model downloads to a special read-write cache dir
29
+ self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', cache_dir="/tmp")
30
 
31
  # Select encoder, processor, output dim
32
  if model_type == "ViTEncoder":
33
  self.encoder = ViTEncoder().to(self.device)
34
  self.encoder_dim = 768
35
+ #self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
36
  #HF needs all model downloads to a special read-write cache dir
37
+ self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k", cache_dir="/tmp")
38
  elif model_type == "CLIPEncoder":
39
  self.encoder = CLIPEncoder().to(self.device)
40
  self.encoder_dim = 512
41
+ #self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
42
  #HF needs all model downloads to a special read-write cache dir
43
+ self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", cache_dir="/tmp")
44
  else:
45
  raise ValueError("Unknown model type")
46