# bart-large-mnli-logged / modeling_modified.py
# Author: Boriscii
# Commit 8f0212f (verified): "Upload modified model with logging"
# Original file size: 930 bytes
import logging
import logging
from transformers import AutoModelForSequenceClassification
# Named logger shared by the forward hooks in this module. No handler or level
# is configured here, so records follow the application's logging setup.
logger = logging.getLogger(name="ModelLogger")
class ModifiedAutoModelForSequenceClassificationWithHook(AutoModelForSequenceClassification):
    """Sequence-classification auto factory whose models log each forward call.

    Transformers' Auto classes are factories. They cannot be instantiated
    directly (their ``__init__`` raises). Their ``from_pretrained`` and
    ``from_config`` methods return an instance of the concrete architecture
    class rather than of this subclass.

    As a result, a hook registered in ``__init__`` is never installed. The hook
    is therefore attached to the model returned by the factory methods.
    Direct construction is deliberately not overridden and still raises, just
    as it does for the parent factory.
    """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load a pretrained model via the parent factory and attach the logging hook.

        All arguments are forwarded unchanged to the parent ``from_pretrained``.
        Returns the concrete model instance with ``forward_hook`` registered on it.
        """
        # NOTE(review): not verified when this class is resolved through a config
        # `auto_map` entry with trust_remote_code; confirm the parent factory does
        # not route back to this class.
        model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        model.register_forward_hook(cls.forward_hook)
        return model

    @classmethod
    def from_config(cls, config, **kwargs):
        """Build a model from ``config`` via the parent factory and attach the logging hook.

        No pretrained weights are loaded. Keyword arguments are forwarded
        unchanged to the parent ``from_config``.
        """
        model = super().from_config(config, **kwargs)
        model.register_forward_hook(cls.forward_hook)
        return model

    @staticmethod
    def forward_hook(module, inputs, outputs):
        """Log the class name of ``module`` after its forward pass.

        Follows the forward-hook signature ``hook(module, args, output)``.
        ``inputs`` and ``outputs`` are unused, and returning ``None`` leaves the
        forward output unchanged.
        """
        # %-style arguments defer formatting until the record is actually emitted.
        logger.info("Called forward method of %s", type(module).__name__)
import logging
from transformers import AutoModel
# `logger` (logging.getLogger("ModelLogger")) is already bound earlier in this
# module; getLogger returns the same instance, so it is not re-created here.
class ModifiedAutoModelWithHook(AutoModel):
    """Base-model auto factory whose models log each forward call.

    Transformers' Auto classes are factories. They cannot be instantiated
    directly (their ``__init__`` raises). Their ``from_pretrained`` and
    ``from_config`` methods return an instance of the concrete architecture
    class rather than of this subclass.

    As a result, a hook registered in ``__init__`` is never installed. The hook
    is therefore attached to the model returned by the factory methods.
    Direct construction is deliberately not overridden and still raises, just
    as it does for the parent factory.
    """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load a pretrained model via the parent factory and attach the logging hook.

        All arguments are forwarded unchanged to the parent ``from_pretrained``.
        Returns the concrete model instance with ``forward_hook`` registered on it.
        """
        # NOTE(review): not verified when this class is resolved through a config
        # `auto_map` entry with trust_remote_code; confirm the parent factory does
        # not route back to this class.
        model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        model.register_forward_hook(cls.forward_hook)
        return model

    @classmethod
    def from_config(cls, config, **kwargs):
        """Build a model from ``config`` via the parent factory and attach the logging hook.

        No pretrained weights are loaded. Keyword arguments are forwarded
        unchanged to the parent ``from_config``.
        """
        model = super().from_config(config, **kwargs)
        model.register_forward_hook(cls.forward_hook)
        return model

    @staticmethod
    def forward_hook(module, inputs, outputs):
        """Log the class name of ``module`` after its forward pass.

        Follows the forward-hook signature ``hook(module, args, output)``.
        ``inputs`` and ``outputs`` are unused, and returning ``None`` leaves the
        forward output unchanged.
        """
        # %-style arguments defer formatting until the record is actually emitted.
        logger.info("Called forward method of %s", type(module).__name__)