Shivam29rathore commited on
Commit
f0515e4
·
1 Parent(s): ac9c278

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -1
app.py CHANGED
@@ -2,11 +2,24 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
  import pickle
3
 
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  checkpoint = "my_t5.sav"
6
 
7
  #load model from drive
8
  with open(checkpoint, "rb") as f:
9
- model= pickle.load(f)
10
 
11
  #tokenizer = AutoTokenizer.from_pretrained(checkpoint)
12
  #model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
 
2
  import pickle
3
 
4
 
5
+ import io
6
+
7
+ class CPU_Unpickler(pickle.Unpickler):
8
+ def find_class(self, module, name):
9
+ if module == 'torch.storage' and name == '_load_from_bytes':
10
+ return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
11
+ else:
12
+ return super().find_class(module, name)
13
+
14
+ #contents = pickle.load(f) becomes...
15
+ #contents = CPU_Unpickler(f).load()
16
+
17
+
18
  checkpoint = "my_t5.sav"
19
 
20
  #load model from drive
21
  with open(checkpoint, "rb") as f:
22
+ model= CPU_Unpickler(f).load()
23
 
24
  #tokenizer = AutoTokenizer.from_pretrained(checkpoint)
25
  #model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)