"""Transformers models that report each top-level ``forward`` call.

Every model produced here carries a PyTorch forward hook that logs (to the
``"ModelLogger"`` logger) and prints the class name of the hooked module
after its ``forward`` runs.
"""

import logging

from transformers import AutoModel, BartForSequenceClassification

logger = logging.getLogger("ModelLogger")


def _log_forward_call(module, inputs, outputs):
    """Forward hook that reports which module's ``forward`` was called.

    Args:
        module: The module the hook is registered on.
        inputs: Positional inputs passed to ``forward`` (unused).
        outputs: Value returned by ``forward`` (unused).

    Returns:
        None, so PyTorch leaves the module's output unchanged.
    """
    name = module.__class__.__name__
    logger.info("Called forward method of %s", name)
    print(f"Called forward method of {name}")


class ModifiedBartForSequenceClassificationWithHook(BartForSequenceClassification):
    """``BartForSequenceClassification`` that logs each call to its ``forward``.

    The hook is registered in ``__init__`` on this top-level module only;
    submodules are not hooked.
    """

    forward_hook = staticmethod(_log_forward_call)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.register_forward_hook(self.forward_hook)


class ModifiedAutoModelWithHook(AutoModel):
    """``AutoModel`` factory whose returned models log each ``forward`` call.

    ``AutoModel`` cannot be instantiated directly: its ``__init__`` raises
    ``EnvironmentError``, and it is not an ``nn.Module``. Its
    ``from_pretrained`` / ``from_config`` classmethods return an instance of
    the concrete architecture class (e.g. ``BertModel``), not of this
    subclass. The hook is therefore attached to the model those factory
    methods return, rather than in ``__init__``. Direct construction still
    raises, exactly as ``AutoModel`` does.
    """

    forward_hook = staticmethod(_log_forward_call)

    @classmethod
    def _with_hook(cls, result):
        """Attach ``forward_hook`` to the model in *result* and return *result*.

        ``from_pretrained(..., output_loading_info=True)`` returns a
        ``(model, loading_info)`` tuple, so the model is taken from the
        first element in that case.
        """
        model = result[0] if isinstance(result, tuple) else result
        model.register_forward_hook(cls.forward_hook)
        return result

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Load a model via ``AutoModel.from_pretrained`` and attach the hook."""
        return cls._with_hook(super().from_pretrained(*args, **kwargs))

    @classmethod
    def from_config(cls, *args, **kwargs):
        """Build a model via ``AutoModel.from_config`` and attach the hook."""
        return cls._with_hook(super().from_config(*args, **kwargs))