MohametSena commited on
Commit
e53f483
·
1 Parent(s): bbe990e

End of training

Browse files
Files changed (2) hide show
  1. config.json +4 -0
  2. index.py +49 -0
config.json CHANGED
@@ -2,6 +2,10 @@
2
  "architectures": [
3
  "TimesheetEstimator"
4
  ],
 
 
 
 
5
  "encoder_model_name": "distilroberta-base",
6
  "hidden_size": 768,
7
  "id2label": {
 
2
  "architectures": [
3
  "TimesheetEstimator"
4
  ],
5
+ "auto_map": {
6
+ "AutoConfig": "index.TimesheetEstimatorConfig",
7
+ "AutoModel": "index.TimesheetEstimator"
8
+ },
9
  "encoder_model_name": "distilroberta-base",
10
  "hidden_size": 768,
11
  "id2label": {
index.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel, PretrainedConfig, AutoModel
2
+ from transformers.modeling_outputs import ModelOutput
3
+ import torch.nn as nn
4
+ import torch
5
+
6
+ class TimesheetEstimatorConfig(PretrainedConfig):
7
+ def __init__(
8
+ self,
9
+ encoder_model_name = "bert-base-cased",
10
+ hidden_size=768,
11
+ **kwargs
12
+ ):
13
+ super().__init__(**kwargs)
14
+
15
+ self.num_labels = 1
16
+ self.hidden_size = hidden_size
17
+ self.encoder_model_name = encoder_model_name
18
+
19
+ class TimesheetEstimator(PreTrainedModel):
20
+ config_class = TimesheetEstimatorConfig
21
+
22
+ def __init__(self, config: TimesheetEstimatorConfig):
23
+ super().__init__(config)
24
+
25
+ self.encoder = AutoModel.from_pretrained(config.encoder_model_name)
26
+ self.estimate_layer = nn.Linear(config.hidden_size, config.num_labels)
27
+ self.loss = nn.MSELoss()
28
+
29
+ def forward(self, input_ids, attention_mask, labels=None):
30
+ encoder_outputs = self.encoder(
31
+ input_ids=input_ids,
32
+ attention_mask=attention_mask,
33
+ )
34
+
35
+ represent_vectors = encoder_outputs[0]
36
+
37
+ estimate = self.estimate_layer(represent_vectors[:, 0, :])
38
+
39
+ loss = None
40
+ if labels is not None:
41
+ loss = self.loss(estimate, labels.reshape(-1, 1))
42
+
43
+ return ModelOutput(
44
+ loss=loss,
45
+ logits=estimate,
46
+ )
47
+
48
+
49
+