import logging | |
import logging | |
from transformers import BartForConditionalGeneration | |
# Module-level logger retrieved by the fixed name "ModelLogger" (rather than
# __name__) so applications can configure its level/handlers by that name.
logger = logging.getLogger("ModelLogger")
class ModifiedBartForConditionalGenerationWithHook(BartForConditionalGeneration):
    """``BartForConditionalGeneration`` that logs tensor shapes around ``forward``.

    Construction and computation are inherited unchanged from the parent
    class; the only addition is an INFO-level record on the module ``logger``
    before and after each forward pass. Configure the "ModelLogger" logger in
    the application to see (or silence) these messages.
    """

    def forward(self, *args, **kwargs):
        """Run the parent forward pass, logging input and output shapes.

        All positional and keyword arguments are passed through unchanged to
        ``BartForConditionalGeneration.forward``.

        Returns:
            Exactly what the parent ``forward`` returns, unmodified.
        """
        # Explicit keyword wins; otherwise assume input_ids is the parent's
        # first positional parameter.
        input_ids = kwargs.get("input_ids", args[0] if args else None)
        # getattr() keeps logging from raising when input_ids is None (e.g.
        # only encoder_outputs / inputs_embeds were supplied) or has no shape;
        # a missing shape is reported as "None", as before.
        logger.info(
            "Forward pass initiated with input shape: %s",
            getattr(input_ids, "shape", None),
        )

        outputs = super().forward(*args, **kwargs)

        # Outputs without a ``logits`` attribute (presumably the tuple form
        # returned when return_dict=False), or with logits set to None, are
        # reported as "N/A" instead of raising AttributeError.
        logits = getattr(outputs, "logits", None)
        logger.info(
            "Forward pass completed. Output shape: %s",
            getattr(logits, "shape", "N/A"),
        )
        return outputs

    # ``*args, **kwargs`` hides the real parameter list from
    # ``inspect.signature``, which Hugging Face tooling (e.g. Trainer's
    # unused-column removal) relies on. Setting ``__wrapped__`` (what
    # ``functools.wraps`` would do) makes introspection report the parent's
    # signature; calls still go through the override above.
    forward.__wrapped__ = BartForConditionalGeneration.forward