Spaces:
Sleeping
Sleeping
Fix: App issue
Browse files
app.py
CHANGED
@@ -115,7 +115,15 @@ class GPT(nn.Module):
 def load_model():
     device = 'cuda' if torch.cuda.is_available() else 'cpu'
     model = GPT(GPTConfig())
-
+
+    # Load the state dict with weights_only=True and remove the _orig_mod prefix
+    state_dict = torch.load('nano_gpt_model.pt', map_location=device, weights_only=True)
+    new_state_dict = {}
+    for key in state_dict.keys():
+        new_key = key.replace('_orig_mod.', '')
+        new_state_dict[new_key] = state_dict[key]
+
+    model.load_state_dict(new_state_dict)
     model.to(device)
     model.eval()
     return model, device