Spaces:
Runtime error
Runtime error
Update model.py
Browse files
model.py
CHANGED
@@ -338,5 +338,5 @@ class MultiHeadClassification(nn.Module):
|
|
338 |
backbone = AutoModel.from_pretrained(os.path.join(model_path, 'pretrained/backbone'))
|
339 |
instance = cls(backbone, head_config, dropout, l2_reg)
|
340 |
instance.load(os.path.join(model_path, 'multi-head-sequence-classification-model-model.pth'))
|
341 |
-
instance.head_config = {k: v.
|
342 |
return instance
|
|
|
338 |
backbone = AutoModel.from_pretrained(os.path.join(model_path, 'pretrained/backbone'))
|
339 |
instance = cls(backbone, head_config, dropout, l2_reg)
|
340 |
instance.load(os.path.join(model_path, 'multi-head-sequence-classification-model-model.pth'))
|
341 |
+
instance.head_config = {k: v.weight.shape[1] for k, v in instance.heads.items()}
|
342 |
return instance
|