Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Update Utils/PLBERT/util.py
Browse files- Utils/PLBERT/util.py +19 -19
    	
        Utils/PLBERT/util.py
    CHANGED
    
    | @@ -3,7 +3,6 @@ import yaml | |
| 3 | 
             
            import torch
         | 
| 4 | 
             
            from transformers import AlbertConfig, AlbertModel
         | 
| 5 |  | 
| 6 | 
            -
             | 
| 7 | 
             
            class CustomAlbert(AlbertModel):
         | 
| 8 | 
             
                def forward(self, *args, **kwargs):
         | 
| 9 | 
             
                    # Call the original forward method
         | 
| @@ -16,34 +15,35 @@ class CustomAlbert(AlbertModel): | |
| 16 | 
             
            def load_plbert(log_dir):
         | 
| 17 | 
             
                config_path = os.path.join(log_dir, "config.yml")
         | 
| 18 | 
             
                plbert_config = yaml.safe_load(open(config_path))
         | 
| 19 | 
            -
             | 
| 20 | 
            -
                albert_base_configuration = AlbertConfig(**plbert_config[ | 
| 21 | 
             
                bert = CustomAlbert(albert_base_configuration)
         | 
| 22 |  | 
| 23 | 
             
                files = os.listdir(log_dir)
         | 
| 24 | 
             
                ckpts = []
         | 
| 25 | 
             
                for f in os.listdir(log_dir):
         | 
| 26 | 
            -
                    if f.startswith("step_"):
         | 
| 27 | 
            -
             | 
| 28 | 
            -
             | 
| 29 | 
            -
                iters = [
         | 
| 30 | 
            -
                    int(f.split("_")[-1].split(".")[0])
         | 
| 31 | 
            -
                    for f in ckpts
         | 
| 32 | 
            -
                    if os.path.isfile(os.path.join(log_dir, f))
         | 
| 33 | 
            -
                ]
         | 
| 34 | 
             
                iters = sorted(iters)[-1]
         | 
| 35 |  | 
| 36 | 
            -
                checkpoint = torch.load(log_dir + "/step_" + str(iters) + ".t7", map_location= | 
| 37 | 
            -
                state_dict = checkpoint[ | 
| 38 | 
             
                from collections import OrderedDict
         | 
| 39 | 
            -
             | 
| 40 | 
             
                new_state_dict = OrderedDict()
         | 
| 41 | 
             
                for k, v in state_dict.items():
         | 
| 42 | 
            -
                    name = k[7:] | 
| 43 | 
            -
                    if name.startswith( | 
| 44 | 
            -
                        name = name[8:] | 
| 45 | 
             
                        new_state_dict[name] = v
         | 
| 46 | 
            -
                 | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 47 | 
             
                bert.load_state_dict(new_state_dict, strict=False)
         | 
| 48 | 
            -
             | 
| 49 | 
             
                return bert
         | 
|  | |
|  | 
|  | |
| 3 | 
             
            import torch
         | 
| 4 | 
             
            from transformers import AlbertConfig, AlbertModel
         | 
| 5 |  | 
|  | |
| 6 | 
             
            class CustomAlbert(AlbertModel):
         | 
| 7 | 
             
                def forward(self, *args, **kwargs):
         | 
| 8 | 
             
                    # Call the original forward method
         | 
|  | |
| 15 | 
             
            def load_plbert(log_dir):
         | 
| 16 | 
             
                config_path = os.path.join(log_dir, "config.yml")
         | 
| 17 | 
             
                plbert_config = yaml.safe_load(open(config_path))
         | 
| 18 | 
            +
                
         | 
| 19 | 
            +
                albert_base_configuration = AlbertConfig(**plbert_config['model_params'])
         | 
| 20 | 
             
                bert = CustomAlbert(albert_base_configuration)
         | 
| 21 |  | 
| 22 | 
             
                files = os.listdir(log_dir)
         | 
| 23 | 
             
                ckpts = []
         | 
| 24 | 
             
                for f in os.listdir(log_dir):
         | 
| 25 | 
            +
                    if f.startswith("step_"): ckpts.append(f)
         | 
| 26 | 
            +
             | 
| 27 | 
            +
                iters = [int(f.split('_')[-1].split('.')[0]) for f in ckpts if os.path.isfile(os.path.join(log_dir, f))]
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 28 | 
             
                iters = sorted(iters)[-1]
         | 
| 29 |  | 
| 30 | 
            +
                checkpoint = torch.load(log_dir + "/step_" + str(iters) + ".t7", map_location='cpu')
         | 
| 31 | 
            +
                state_dict = checkpoint['net']
         | 
| 32 | 
             
                from collections import OrderedDict
         | 
|  | |
| 33 | 
             
                new_state_dict = OrderedDict()
         | 
| 34 | 
             
                for k, v in state_dict.items():
         | 
| 35 | 
            +
                    name = k[7:] # remove `module.`
         | 
| 36 | 
            +
                    if name.startswith('encoder.'):
         | 
| 37 | 
            +
                        name = name[8:] # remove `encoder.`
         | 
| 38 | 
             
                        new_state_dict[name] = v
         | 
| 39 | 
            +
                
         | 
| 40 | 
            +
                # Check if 'embeddings.position_ids' exists before attempting to delete it
         | 
| 41 | 
            +
                if not hasattr(bert.embeddings, 'position_ids'):
         | 
| 42 | 
            +
                    del new_state_dict["embeddings.position_ids"]
         | 
| 43 | 
            +
                
         | 
| 44 | 
            +
                
         | 
| 45 | 
             
                bert.load_state_dict(new_state_dict, strict=False)
         | 
| 46 | 
            +
                
         | 
| 47 | 
             
                return bert
         | 
| 48 | 
            +
             | 
| 49 | 
            +
             | 
 
			
