Shilpaj commited on
Commit
100d65e
·
1 Parent(s): 70388af

Fix: App issue

Browse files
Files changed (1) hide show
  1. app.py +9 -1
app.py CHANGED
@@ -115,7 +115,15 @@ class GPT(nn.Module):
115
  def load_model():
116
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
117
  model = GPT(GPTConfig())
118
- model.load_state_dict(torch.load('nano_gpt_model.pt', map_location=device))
 
 
 
 
 
 
 
 
119
  model.to(device)
120
  model.eval()
121
  return model, device
 
115
  def load_model():
116
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
117
  model = GPT(GPTConfig())
118
+
119
+ # Load the state dict with weights_only=True and remove the _orig_mod prefix
120
+ state_dict = torch.load('nano_gpt_model.pt', map_location=device, weights_only=True)
121
+ new_state_dict = {}
122
+ for key in state_dict.keys():
123
+ new_key = key.replace('_orig_mod.', '')
124
+ new_state_dict[new_key] = state_dict[key]
125
+
126
+ model.load_state_dict(new_state_dict)
127
  model.to(device)
128
  model.eval()
129
  return model, device