Commit 
							
							·
						
						162074d
	
1
								Parent(s):
							
							c1c87bf
								
Add HAT implementation files
Browse files- modelling_hat.py +4 -1
    	
        modelling_hat.py
    CHANGED
    
    | @@ -1186,6 +1186,7 @@ class HATModelForDocumentRepresentation(HATPreTrainedModel): | |
| 1186 | 
             
                    super().__init__(config)
         | 
| 1187 | 
             
                    self.num_labels = config.num_labels
         | 
| 1188 | 
             
                    self.config = config
         | 
|  | |
| 1189 |  | 
| 1190 | 
             
                    self.hi_transformer = HATModel(config)
         | 
| 1191 | 
             
                    self.pooler = HATPooler(config, pooling=pooling)
         | 
| @@ -1233,7 +1234,7 @@ class HATModelForDocumentRepresentation(HATPreTrainedModel): | |
| 1233 | 
             
                        return_dict=return_dict,
         | 
| 1234 | 
             
                    )
         | 
| 1235 | 
             
                    sequence_output = outputs[0]
         | 
| 1236 | 
            -
                    pooled_outputs = self.pooler(sequence_output)
         | 
| 1237 |  | 
| 1238 | 
             
                    drp_loss = None
         | 
| 1239 | 
             
                    if labels is not None:
         | 
| @@ -1832,6 +1833,7 @@ class HATForSequenceClassification(HATPreTrainedModel): | |
| 1832 | 
             
                    super().__init__(config)
         | 
| 1833 | 
             
                    self.num_labels = config.num_labels
         | 
| 1834 | 
             
                    self.config = config
         | 
|  | |
| 1835 | 
             
                    self.pooling = pooling
         | 
| 1836 |  | 
| 1837 | 
             
                    self.hi_transformer = HATModel(config)
         | 
| @@ -2043,6 +2045,7 @@ class HATForMultipleChoice(HATPreTrainedModel): | |
| 2043 | 
             
                    super().__init__(config)
         | 
| 2044 |  | 
| 2045 | 
             
                    self.pooling = pooling
         | 
|  | |
| 2046 | 
             
                    self.hi_transformer = HATModel(config)
         | 
| 2047 | 
             
                    classifier_dropout = (
         | 
| 2048 | 
             
                        config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
         | 
|  | |
| 1186 | 
             
                    super().__init__(config)
         | 
| 1187 | 
             
                    self.num_labels = config.num_labels
         | 
| 1188 | 
             
                    self.config = config
         | 
| 1189 | 
            +
                    self.max_sentence_length = config.max_sentence_length
         | 
| 1190 |  | 
| 1191 | 
             
                    self.hi_transformer = HATModel(config)
         | 
| 1192 | 
             
                    self.pooler = HATPooler(config, pooling=pooling)
         | 
|  | |
| 1234 | 
             
                        return_dict=return_dict,
         | 
| 1235 | 
             
                    )
         | 
| 1236 | 
             
                    sequence_output = outputs[0]
         | 
| 1237 | 
            +
                    pooled_outputs = self.pooler(sequence_output[:, ::self.max_sentence_length])
         | 
| 1238 |  | 
| 1239 | 
             
                    drp_loss = None
         | 
| 1240 | 
             
                    if labels is not None:
         | 
|  | |
| 1833 | 
             
                    super().__init__(config)
         | 
| 1834 | 
             
                    self.num_labels = config.num_labels
         | 
| 1835 | 
             
                    self.config = config
         | 
| 1836 | 
            +
                    self.max_sentence_length = config.max_sentence_length
         | 
| 1837 | 
             
                    self.pooling = pooling
         | 
| 1838 |  | 
| 1839 | 
             
                    self.hi_transformer = HATModel(config)
         | 
|  | |
| 2045 | 
             
                    super().__init__(config)
         | 
| 2046 |  | 
| 2047 | 
             
                    self.pooling = pooling
         | 
| 2048 | 
            +
                    self.max_sentence_length = config.max_sentence_length
         | 
| 2049 | 
             
                    self.hi_transformer = HATModel(config)
         | 
| 2050 | 
             
                    classifier_dropout = (
         | 
| 2051 | 
             
                        config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
         | 
