Boriscii commited on
Commit
8f0212f
·
verified ·
1 Parent(s): 39358e3

Upload modified model with logging

Browse files
__pycache__/modeling_modified.cpython-310.pyc ADDED
Binary file (1.44 kB). View file
 
config.json CHANGED
@@ -5,13 +5,13 @@
5
  "activation_function": "gelu",
6
  "add_final_layer_norm": false,
7
  "architectures": [
8
- "ModifiedBartForSequenceClassificationWithHook"
9
  ],
10
  "attention_dropout": 0.0,
11
  "auto_map": {
12
- "AutoModel": "modeling_bart.ModifiedBartForSequenceClassificationWithHook",
13
- "AutoModelForSequenceClassification": "modeling_bart.ModifiedBartForSequenceClassificationWithHook",
14
- "AutoConfig": "config.BartConfig"
15
  },
16
  "bos_token_id": 0,
17
  "classif_dropout": 0.0,
 
5
  "activation_function": "gelu",
6
  "add_final_layer_norm": false,
7
  "architectures": [
8
+ "BartForSequenceClassification"
9
  ],
10
  "attention_dropout": 0.0,
11
  "auto_map": {
12
+ "AutoConfig": "configuring_modified.BartConfig",
13
+ "AutoForSequenceClassification": "modeling_modified.ModifiedAutoModelForSequenceClassificationWithHook",
14
+ "AutoModel": "modeling_modified.ModifiedAutoModelWithHook"
15
  },
16
  "bos_token_id": 0,
17
  "classif_dropout": 0.0,
configuring_modified.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from transformers import BartConfig
modeling_modified.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ import logging
4
+ from transformers import AutoModelForSequenceClassification
5
+
6
+ logger = logging.getLogger("ModelLogger")
7
+
8
+ class ModifiedAutoModelForSequenceClassificationWithHook(AutoModelForSequenceClassification):
9
+ def __init__(self, *args, **kwargs):
10
+ super().__init__(*args, **kwargs)
11
+ self.register_forward_hook(self.forward_hook)
12
+
13
+ @staticmethod
14
+ def forward_hook(module, inputs, outputs):
15
+ logger.info(f"Called forward method of {module.__class__.__name__}")
16
+
17
+
18
+ import logging
19
+ from transformers import AutoModel
20
+
21
+ logger = logging.getLogger("ModelLogger")
22
+
23
+ class ModifiedAutoModelWithHook(AutoModel):
24
+ def __init__(self, *args, **kwargs):
25
+ super().__init__(*args, **kwargs)
26
+ self.register_forward_hook(self.forward_hook)
27
+
28
+ @staticmethod
29
+ def forward_hook(module, inputs, outputs):
30
+ logger.info(f"Called forward method of {module.__class__.__name__}")
31
+