import logging

from transformers import BartForConditionalGeneration

logger = logging.getLogger("ModelLogger")


class ModifiedBartForConditionalGenerationWithHook(BartForConditionalGeneration):
    """``BartForConditionalGeneration`` that logs input/output shapes around ``forward``.

    Messages are emitted on the ``"ModelLogger"`` logger at DEBUG level, so they
    are silent unless that logger is configured to emit DEBUG records. The
    constructor is inherited unchanged from ``BartForConditionalGeneration``.
    """

    def forward(self, *args, **kwargs):
        """Log the input shape, run the parent ``forward``, then log the logits shape.

        All positional and keyword arguments are passed through unchanged to
        ``BartForConditionalGeneration.forward``.

        Returns:
            Exactly what the parent ``forward`` returns.
        """
        # input_ids may be passed by keyword or as the first positional argument.
        input_ids = kwargs.get("input_ids", args[0] if args else None)
        if input_ids is None:
            input_shape = "None"
        else:
            # getattr so a shapeless input (e.g. a list) cannot make logging raise.
            input_shape = getattr(input_ids, "shape", "N/A")
        logger.debug("Forward pass initiated with input shape: %s", input_shape)

        outputs = super().forward(*args, **kwargs)

        # With return_dict=False the parent may return a tuple, which has no
        # `logits` attribute; report N/A instead of raising.
        logits = getattr(outputs, "logits", None)
        logger.debug(
            "Forward pass completed. Output shape: %s",
            getattr(logits, "shape", "N/A"),
        )
        return outputs