File size: 1,041 Bytes
bc22ed1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
import logging
from transformers import BartForSequenceClassification
logger = logging.getLogger("ModelLogger")


class ModifiedBartForSequenceClassificationWithHook(BartForSequenceClassification):
    """``BartForSequenceClassification`` that reports every top-level forward call.

    A forward hook is registered on the model instance itself during
    construction. Submodules are not hooked, so exactly one message is
    emitted per forward pass of the whole model.
    """

    def __init__(self, *args, **kwargs):
        """Build the underlying BART model, then register :meth:`forward_hook`.

        All positional and keyword arguments are forwarded unchanged to
        ``BartForSequenceClassification.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.register_forward_hook(self.forward_hook)

    @staticmethod
    def forward_hook(module, inputs, outputs):
        """Report that ``module`` just ran its forward method.

        Args:
            module: The module whose forward pass completed.
            inputs: Positional inputs passed to ``forward`` (unused).
            outputs: Value returned by ``forward`` (unused).

        Returns:
            None. Per PyTorch's forward-hook contract, returning None leaves
            ``outputs`` unchanged.
        """
        # Lazy %-style args: the message is only formatted if INFO is enabled.
        logger.info("Called forward method of %s", module.__class__.__name__)
        # Presumably kept so the message is visible even when no logging
        # handler is configured for "ModelLogger" -- confirm with owner.
        print(f"Called forward method of {module.__class__.__name__}")
import logging
from transformers import AutoModel
logger = logging.getLogger("ModelLogger")


class ModifiedAutoModelWithHook(AutoModel):
    """Auto-model factory whose loaded models report every top-level forward call.

    ``AutoModel`` is a factory class rather than a ``torch.nn.Module``. Its
    ``__init__`` raises ``EnvironmentError``, and its ``from_pretrained`` /
    ``from_config`` return an instance of a concrete architecture class
    instead of an instance of this subclass.

    The previous implementation registered the hook in ``__init__``, so it
    could never run. The hook is therefore attached to the model returned by
    the factory methods below.

    Direct construction (``ModifiedAutoModelWithHook(...)``) still raises the
    ``EnvironmentError`` inherited from ``AutoModel``, as before.
    """

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Load a model with ``AutoModel.from_pretrained`` and attach the hook.

        All arguments are forwarded unchanged.

        Returns:
            The concrete model instance produced by ``AutoModel``, with
            :meth:`forward_hook` registered on it. It is not an instance of
            this class.
        """
        return cls._attach_hook(AutoModel.from_pretrained(*args, **kwargs))

    @classmethod
    def from_config(cls, *args, **kwargs):
        """Build a model with ``AutoModel.from_config`` and attach the hook.

        All arguments are forwarded unchanged.

        Returns:
            The concrete model instance produced by ``AutoModel``, with
            :meth:`forward_hook` registered on it.
        """
        return cls._attach_hook(AutoModel.from_config(*args, **kwargs))

    @classmethod
    def _attach_hook(cls, model):
        """Register :meth:`forward_hook` on ``model`` and return the model."""
        # Only the top-level module is hooked, so there is one message per forward pass.
        model.register_forward_hook(cls.forward_hook)
        return model

    @staticmethod
    def forward_hook(module, inputs, outputs):
        """Report that ``module`` just ran its forward method.

        Args:
            module: The module whose forward pass completed.
            inputs: Positional inputs passed to ``forward`` (unused).
            outputs: Value returned by ``forward`` (unused).

        Returns:
            None. Per PyTorch's forward-hook contract, returning None leaves
            ``outputs`` unchanged.
        """
        # Lazy %-style args: the message is only formatted if INFO is enabled.
        logger.info("Called forward method of %s", module.__class__.__name__)
        # Presumably kept so the message is visible even when no logging
        # handler is configured for "ModelLogger" -- confirm with owner.
        print(f"Called forward method of {module.__class__.__name__}")
|