bart-large-mnli-logged / modeling_bart.py
Boriscii's picture
Create modeling_bart.py
bc22ed1 verified
import logging
from transformers import BartForSequenceClassification
# Module-level logger used by the forward hooks in this file; looked up under
# the fixed name "ModelLogger" rather than the module's own name.
logger = logging.getLogger("ModelLogger")
class ModifiedBartForSequenceClassificationWithHook(BartForSequenceClassification):
    """BART sequence classifier that announces every forward pass.

    Behaves like ``BartForSequenceClassification`` except that a forward hook
    is registered on the instance at construction time. The hook only reports
    that a forward call happened; it returns nothing.
    """

    def __init__(self, *args, **kwargs):
        """Build the parent model, then register the reporting hook.

        All positional and keyword arguments are passed through unchanged to
        the parent constructor.
        """
        super().__init__(*args, **kwargs)
        hook = self.forward_hook
        self.register_forward_hook(hook)

    @staticmethod
    def forward_hook(module, inputs, outputs):
        """Report a forward call on *module* to the logger and to stdout.

        Args:
            module: The module the hook was invoked for; only its class name
                is read.
            inputs: Unused.
            outputs: Unused.
        """
        message = f"Called forward method of {module.__class__.__name__}"
        logger.info(message)
        print(message)
import logging
from transformers import AutoModel
# NOTE(review): rebinds `logger` using the same "ModelLogger" name as the
# logger bound earlier in this file; logging.getLogger returns the same logger
# object for a given name, so this assignment appears redundant.
logger = logging.getLogger("ModelLogger")
class ModifiedAutoModelWithHook(AutoModel):
    """``AutoModel`` factory whose loaded models announce every forward pass.

    ``AutoModel`` is a factory rather than a model: its constructor raises,
    and ``from_pretrained`` / ``from_config`` return an instance of a concrete
    architecture class, not of this subclass. A hook registered inside
    ``__init__`` can therefore never be attached, so the factory methods are
    overridden here to register the hook on the model they return.
    """

    def __init__(self, *args, **kwargs):
        """Delegate to ``AutoModel.__init__``.

        Kept so the constructor signature is unchanged. The parent constructor
        is expected to raise, directing callers to ``from_pretrained`` or
        ``from_config``; no hook registration is attempted here because there
        is no model instance to attach it to.
        """
        super().__init__(*args, **kwargs)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Load a model via ``AutoModel.from_pretrained`` and attach the hook.

        All arguments are passed through unchanged.

        Returns:
            The loaded model, with ``forward_hook`` registered on it.
        """
        return cls._attach_hook(super().from_pretrained(*args, **kwargs))

    @classmethod
    def from_config(cls, *args, **kwargs):
        """Build a model via ``AutoModel.from_config`` and attach the hook.

        All arguments are passed through unchanged.

        Returns:
            The new model, with ``forward_hook`` registered on it.
        """
        return cls._attach_hook(super().from_config(*args, **kwargs))

    @classmethod
    def _attach_hook(cls, model):
        """Register ``forward_hook`` on *model* and return the same object."""
        model.register_forward_hook(cls.forward_hook)
        return model

    @staticmethod
    def forward_hook(module, inputs, outputs):
        """Report a forward call on *module* to the logger and to stdout.

        Args:
            module: The module the hook was invoked for; only its class name
                is read.
            inputs: Unused.
            outputs: Unused.
        """
        message = f"Called forward method of {module.__class__.__name__}"
        logger.info(message)
        print(message)