|
import logging |
|
from transformers import BartForSequenceClassification |
|
|
|
# Shared named logger used by the forward hooks in this module.
logger = logging.getLogger("ModelLogger")
|
|
|
class ModifiedBartForSequenceClassificationWithHook(BartForSequenceClassification):
    """BART sequence classifier that reports each forward pass.

    Construction is delegated unchanged to ``BartForSequenceClassification``;
    afterwards :meth:`forward_hook` is registered on the instance as a
    forward hook.
    """

    def __init__(self, *args, **kwargs):
        """Build the parent model, then attach the reporting hook."""
        super().__init__(*args, **kwargs)
        hook = self.forward_hook
        self.register_forward_hook(hook)

    @staticmethod
    def forward_hook(module, inputs, outputs):
        """Log and print the class name of the module that was just called.

        ``inputs`` and ``outputs`` are accepted to match the forward-hook
        signature but are not read; nothing is returned.
        """
        message = f"Called forward method of {module.__class__.__name__}"
        logger.info(message)
        print(message)
|
|
|
|
|
import logging |
|
from transformers import AutoModel |
|
|
|
# Named logger used by the forward hook below; logging.getLogger returns the
# same logger object every time it is asked for the same name.
logger = logging.getLogger("ModelLogger")
|
|
|
class ModifiedAutoModelWithHook(AutoModel):
    """``AutoModel`` factory whose loaded models report each forward pass.

    ``AutoModel`` is a factory rather than an ``nn.Module``: it cannot be
    instantiated directly (its ``__init__`` raises) and it has no
    ``register_forward_hook`` of its own. The hook is therefore attached to
    the concrete model returned by :meth:`from_pretrained` or
    :meth:`from_config`. Instantiating this class directly still raises,
    exactly as ``AutoModel`` itself does.
    """

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Load a model via ``AutoModel.from_pretrained`` and attach the hook.

        All arguments are forwarded untouched; the return value is whatever
        the parent returns.
        """
        return cls._attach_hook(super().from_pretrained(*args, **kwargs))

    @classmethod
    def from_config(cls, *args, **kwargs):
        """Build a model via ``AutoModel.from_config`` and attach the hook.

        All arguments are forwarded untouched; the return value is whatever
        the parent returns.
        """
        return cls._attach_hook(super().from_config(*args, **kwargs))

    @classmethod
    def _attach_hook(cls, loaded):
        """Register :meth:`forward_hook` on a new model and return ``loaded``."""
        # from_pretrained(..., output_loading_info=True) returns a
        # (model, loading_info) tuple; the model is always the first item.
        model = loaded[0] if isinstance(loaded, tuple) else loaded
        model.register_forward_hook(cls.forward_hook)
        return loaded

    @staticmethod
    def forward_hook(module, inputs, outputs):
        """Log and print the class name of the module that was just called.

        ``inputs`` and ``outputs`` are accepted to match the forward-hook
        signature but are not read; nothing is returned.
        """
        logger.info(f"Called forward method of {module.__class__.__name__}")
        print(f"Called forward method of {module.__class__.__name__}")
|
|
|
|