bart-large-mnli-logged / modeling_modified.py
Boriscii's picture
Update modeling_modified.py
193150f verified
raw
history blame
832 Bytes
import logging
import logging
from transformers import BartForConditionalGeneration
# Module-level logger registered under the fixed name "ModelLogger";
# applications can attach handlers or set a level on that name to control
# what this module emits.
logger = logging.getLogger("ModelLogger")
class ModifiedBartForConditionalGenerationWithHook(BartForConditionalGeneration):
    """BART conditional-generation model that logs tensor shapes around ``forward``.

    The class adds no parameters and does not alter the computation: every
    call is delegated unchanged to the parent ``forward``. Before and after
    the delegation it emits an INFO record on the module-level ``logger``
    describing the input and output shapes.
    """

    def __init__(self, config):
        """Build the model from ``config`` exactly as the parent class does."""
        super().__init__(config)

    def forward(self, *args, **kwargs):
        """Run the parent forward pass, logging input and output shapes.

        Args:
            *args: Positional arguments forwarded untouched to the parent
                ``forward``. When ``input_ids`` is not given as a keyword,
                the first positional argument is treated as ``input_ids``
                for logging purposes only.
            **kwargs: Keyword arguments forwarded untouched to the parent
                ``forward``.

        Returns:
            Whatever the parent ``forward`` returns, unmodified.
        """
        # Look up input_ids for logging only; the call arguments themselves
        # are passed through to the parent unchanged.
        input_ids = kwargs.get("input_ids", args[0] if args else None)
        input_shape = input_ids.shape if input_ids is not None else "None"
        # Lazy %-style arguments: the message is only formatted when an
        # INFO-level handler is actually listening.
        logger.info("Forward pass initiated with input shape: %s", input_shape)

        outputs = super().forward(*args, **kwargs)

        # The parent's result may not expose ``logits`` as an attribute (the
        # original code guarded this with hasattr); report "N/A" in that case
        # instead of failing inside the diagnostic line.
        logits = getattr(outputs, "logits", None)
        output_shape = logits.shape if logits is not None else "N/A"
        logger.info("Forward pass completed. Output shape: %s", output_shape)
        return outputs