File size: 707 Bytes
9ba7005
 
 
 
 
 
 
 
 
 
 
 
 
 
f822986
9ba7005
 
 
 
 
f822986
 
9ba7005
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import os
import yaml
import torch
from transformers import AlbertConfig, AlbertModel

class CustomAlbert(AlbertModel):
    """ALBERT model whose ``forward`` returns only the last hidden state.

    Everything except ``forward`` (construction from an ``AlbertConfig``,
    parameters, ``load_state_dict`` etc.) is inherited unchanged from
    ``transformers.AlbertModel``.
    """

    def forward(self, *args, **kwargs):
        """Run ``AlbertModel.forward`` and return only ``last_hidden_state``.

        All positional and keyword arguments are passed through verbatim to
        ``AlbertModel.forward``.

        Returns:
            The final-layer hidden-states tensor; any other outputs (e.g. the
            pooled output) are discarded. Per the transformers docs this is
            typically shaped ``(batch, seq_len, hidden_size)``.
        """
        outputs = super().forward(*args, **kwargs)

        # ``AlbertModel.forward`` returns a ``ModelOutput`` when return_dict is
        # enabled and a plain tuple otherwise (caller passed
        # ``return_dict=False`` or the config disables it). Index 0 is
        # ``last_hidden_state`` in both forms, whereas attribute access would
        # raise AttributeError on the tuple form.
        return outputs[0]


def load_plbert(wights_path, config_path):
    """Build a ``CustomAlbert`` from a YAML config and load saved weights into it.

    Args:
        wights_path: Path to a checkpoint file readable by ``torch.load``. Its
            loaded contents are passed directly to ``load_state_dict``, so it
            is expected to hold a state dict. (The misspelled name is kept for
            backward compatibility with keyword callers.)
        config_path: Path to a UTF-8 YAML file whose ``model_params`` mapping
            is passed as keyword arguments to ``AlbertConfig``.

    Returns:
        The ``CustomAlbert`` model with the checkpoint weights loaded (tensors
        are mapped to CPU).

    Raises:
        KeyError: If the parsed YAML config has no ``model_params`` key.
    """
    # Use a context manager so the config file handle is closed promptly.
    with open(config_path, encoding="utf-8") as config_file:
        plbert_config = yaml.safe_load(config_file)

    albert_base_configuration = AlbertConfig(**plbert_config['model_params'])
    bert = CustomAlbert(albert_base_configuration)

    # NOTE(security): torch.load unpickles the file unless weights_only=True
    # is in effect (the default depends on the torch version); only load
    # checkpoints from trusted sources.
    state_dict = torch.load(wights_path, map_location='cpu')
    # strict=False: missing or unexpected keys are ignored instead of raising,
    # so a partially mismatched checkpoint will not error here.
    bert.load_state_dict(state_dict, strict=False)

    return bert