Boriscii commited on
Commit
bc22ed1
·
verified ·
1 Parent(s): d47dd75

Create modeling_bart.py

Browse files
Files changed (1) hide show
  1. modeling_bart.py +31 -0
modeling_bart.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from transformers import BartForSequenceClassification
3
+
4
+ logger = logging.getLogger("ModelLogger")
5
+
6
+ class ModifiedBartForSequenceClassificationWithHook(BartForSequenceClassification):
7
+ def __init__(self, *args, **kwargs):
8
+ super().__init__(*args, **kwargs)
9
+ self.register_forward_hook(self.forward_hook)
10
+
11
+ @staticmethod
12
+ def forward_hook(module, inputs, outputs):
13
+ logger.info(f"Called forward method of {module.__class__.__name__}")
14
+ print(f"Called forward method of {module.__class__.__name__}")
15
+
16
+
17
+ import logging
18
+ from transformers import AutoModel
19
+
20
+ logger = logging.getLogger("ModelLogger")
21
+
22
+ class ModifiedAutoModelWithHook(AutoModel):
23
+ def __init__(self, *args, **kwargs):
24
+ super().__init__(*args, **kwargs)
25
+ self.register_forward_hook(self.forward_hook)
26
+
27
+ @staticmethod
28
+ def forward_hook(module, inputs, outputs):
29
+ logger.info(f"Called forward method of {module.__class__.__name__}")
30
+ print(f"Called forward method of {module.__class__.__name__}")
31
+