File size: 1,041 Bytes
bc22ed1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import logging
from transformers import BartForSequenceClassification

# Module-level logger, looked up by the fixed name "ModelLogger" (not __name__),
# used by the forward hooks below. Its INFO records are dropped unless the
# application configures logging (the root logger defaults to WARNING).
logger = logging.getLogger("ModelLogger")

class ModifiedBartForSequenceClassificationWithHook(BartForSequenceClassification):
    """``BartForSequenceClassification`` that reports each forward call.

    A forward hook is registered on the model itself during construction, so the
    hook runs for models built directly and for models built through
    ``from_pretrained`` (which calls the constructor). Each call logs an INFO
    record to the ``"ModelLogger"`` logger and prints the same text to stdout.
    """

    def __init__(self, *args, **kwargs):
        # Forward every argument unchanged so all parent construction paths work.
        super().__init__(*args, **kwargs)
        self.register_forward_hook(self.forward_hook)

    @staticmethod
    def forward_hook(module, inputs, outputs):
        """Report that ``module``'s forward pass ran.

        PyTorch calls forward hooks after ``forward`` has produced its output, and
        only when the module is invoked via ``module(...)`` — calling
        ``module.forward(...)`` directly bypasses hooks.

        Args:
            module: The module whose forward just ran.
            inputs: Positional inputs passed to ``forward`` (unused).
            outputs: The value ``forward`` returned (unused).

        Returns:
            None, so PyTorch leaves the module's output unchanged.
        """
        name = module.__class__.__name__
        # Lazy %-style arguments: the message is only formatted if INFO is enabled.
        logger.info("Called forward method of %s", name)
        print(f"Called forward method of {name}")


import logging
from transformers import AutoModel

# Re-binds the module-level logger. logging.getLogger returns the same instance
# for the same name, so this is the identical "ModelLogger" object and the
# re-binding is harmless.
logger = logging.getLogger("ModelLogger")

class ModifiedAutoModelWithHook(AutoModel):
    """Load a model through ``AutoModel`` and attach a logging forward hook.

    ``AutoModel`` is a factory, not an ``nn.Module``:

    * Calling its constructor always raises.
    * ``AutoModel.from_pretrained`` / ``AutoModel.from_config`` return an
      instance of the concrete architecture class, not of a subclass like this
      one.

    Registering the hook in ``__init__`` can therefore never take effect.
    Instead, the factory classmethods below build the concrete model and attach
    :meth:`forward_hook` to it. Direct construction
    (``ModifiedAutoModelWithHook(...)``) still raises the inherited
    ``AutoModel`` error, as before.
    """

    @classmethod
    def _attach_hook(cls, model):
        """Register :meth:`forward_hook` on ``model`` and return ``model``."""
        model.register_forward_hook(cls.forward_hook)
        return model

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Load a pretrained model via ``AutoModel.from_pretrained`` and attach the hook.

        All arguments are forwarded unchanged to ``AutoModel.from_pretrained``.

        Returns:
            The concrete model instance (not an instance of this class) with
            :meth:`forward_hook` registered as a forward hook.
        """
        # Delegate to AutoModel itself rather than super(): the auto classes may
        # key behaviour on the class name (e.g. remote-code ``auto_map``
        # lookups), which a subclass name would not match.
        return cls._attach_hook(AutoModel.from_pretrained(*args, **kwargs))

    @classmethod
    def from_config(cls, *args, **kwargs):
        """Build a model via ``AutoModel.from_config`` and attach the hook.

        All arguments are forwarded unchanged to ``AutoModel.from_config``.

        Returns:
            The concrete model instance with :meth:`forward_hook` registered.
        """
        return cls._attach_hook(AutoModel.from_config(*args, **kwargs))

    @staticmethod
    def forward_hook(module, inputs, outputs):
        """Report that ``module``'s forward pass ran.

        The hook logs an INFO record and prints the same text to stdout. It
        returns None, so PyTorch leaves the module's output unchanged.
        """
        name = module.__class__.__name__
        # Lazy %-style arguments: the message is only formatted if INFO is enabled.
        logger.info("Called forward method of %s", name)
        print(f"Called forward method of {name}")