# File size: 930 Bytes (revision 8f0212f)
import logging
import logging
from transformers import AutoModelForSequenceClassification
logger = logging.getLogger("ModelLogger")


class ModifiedAutoModelForSequenceClassificationWithHook(AutoModelForSequenceClassification):
    """Auto class whose factory methods return models carrying a logging forward hook.

    ``AutoModelForSequenceClassification`` is a factory: its ``__init__`` raises,
    and ``from_pretrained`` / ``from_config`` return an instance of a concrete
    model class rather than of this subclass. A hook registered in ``__init__``
    would therefore never be attached, so the hook is registered on the model
    returned by the factory methods instead. Direct instantiation still raises
    through the inherited ``__init__``, exactly as before.
    """

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Load a model via the parent auto class and attach :meth:`forward_hook`.

        Returns whatever the parent returns: either the model, or a
        ``(model, loading_info)`` tuple when ``output_loading_info=True``.
        """
        # NOTE(review): for trust_remote_code models the parent looks up
        # config.auto_map by cls.__name__; confirm behavior with this subclass.
        return cls._with_hook(super().from_pretrained(*args, **kwargs))

    @classmethod
    def from_config(cls, *args, **kwargs):
        """Build a model from a config via the parent auto class and attach the hook."""
        return cls._with_hook(super().from_config(*args, **kwargs))

    @classmethod
    def _with_hook(cls, result):
        """Register :meth:`forward_hook` on the model in *result* and return *result* unchanged."""
        # from_pretrained(..., output_loading_info=True) returns (model, info).
        model = result[0] if isinstance(result, tuple) else result
        model.register_forward_hook(cls.forward_hook)
        return result

    @staticmethod
    def forward_hook(module, inputs, outputs):
        """Log each forward call; returns None so the module output is left unchanged."""
        logger.info("Called forward method of %s", module.__class__.__name__)
import logging
from transformers import AutoModel
logger = logging.getLogger("ModelLogger")


class ModifiedAutoModelWithHook(AutoModel):
    """Auto class whose factory methods return models carrying a logging forward hook.

    ``AutoModel`` is a factory: its ``__init__`` raises, and
    ``from_pretrained`` / ``from_config`` return an instance of a concrete
    model class rather than of this subclass. A hook registered in ``__init__``
    would therefore never be attached, so the hook is registered on the model
    returned by the factory methods instead. Direct instantiation still raises
    through the inherited ``__init__``, exactly as before.
    """

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Load a model via the parent auto class and attach :meth:`forward_hook`.

        Returns whatever the parent returns: either the model, or a
        ``(model, loading_info)`` tuple when ``output_loading_info=True``.
        """
        # NOTE(review): for trust_remote_code models the parent looks up
        # config.auto_map by cls.__name__; confirm behavior with this subclass.
        return cls._with_hook(super().from_pretrained(*args, **kwargs))

    @classmethod
    def from_config(cls, *args, **kwargs):
        """Build a model from a config via the parent auto class and attach the hook."""
        return cls._with_hook(super().from_config(*args, **kwargs))

    @classmethod
    def _with_hook(cls, result):
        """Register :meth:`forward_hook` on the model in *result* and return *result* unchanged."""
        # from_pretrained(..., output_loading_info=True) returns (model, info).
        model = result[0] if isinstance(result, tuple) else result
        model.register_forward_hook(cls.forward_hook)
        return result

    @staticmethod
    def forward_hook(module, inputs, outputs):
        """Log each forward call; returns None so the module output is left unchanged."""
        logger.info("Called forward method of %s", module.__class__.__name__)
|