File size: 1,475 Bytes
8f0212f 311c890 8f0212f 311c890 193150f 8f0212f 193150f df53c44 193150f df53c44 193150f df53c44 193150f 8f0212f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
import logging
import logging
from transformers import BartForConditionalGeneration
# Module-level logger, registered under the fixed name "ModelLogger"; the model
# wrapper below writes its forward/generate trace messages to it.
logger = logging.getLogger("ModelLogger")
class ModifiedBartForConditionalGenerationWithHook(BartForConditionalGeneration):
    """BART conditional-generation model that logs around ``forward`` and ``generate``.

    Both overrides delegate to the parent implementation unchanged and only
    add INFO-level messages describing the input and output shapes. The
    logging is best-effort: a value without a ``shape`` attribute is reported
    with a placeholder instead of raising.
    """

    def __init__(self, config):
        """Build the model from ``config`` by delegating to the parent constructor."""
        super().__init__(config)

    @staticmethod
    def _shape_or(value, fallback):
        """Return ``value.shape`` if that attribute exists, otherwise ``fallback``."""
        return getattr(value, "shape", fallback)

    def forward(self, *args, **kwargs):
        """Run the parent's ``forward`` and log input/output shapes.

        Args:
            *args: Positional arguments forwarded verbatim to the parent; the
                first one, if present, is treated as ``input_ids`` for logging.
            **kwargs: Keyword arguments forwarded verbatim to the parent.

        Returns:
            Whatever the parent's ``forward`` returns, unmodified.
        """
        input_ids = kwargs.get("input_ids", args[0] if args else None)
        # Lazy %-style arguments: the message is only formatted when INFO is
        # enabled, which matters because forward runs on every decoding step.
        logger.info(
            "Forward pass initiated with input shape: %s",
            self._shape_or(input_ids, "None"),
        )

        outputs = super().forward(*args, **kwargs)

        # Tuple outputs (return_dict=False) have no ``logits`` attribute.
        logits = getattr(outputs, "logits", None)
        logger.info(
            "Forward pass completed. Output shape: %s",
            self._shape_or(logits, "N/A"),
        )
        return outputs

    def generate(self, *args, **kwargs):
        """Run the parent's ``generate`` and log input/output shapes.

        Args:
            *args: Positional arguments forwarded verbatim to the parent; the
                first one, if present, is treated as the prompt for logging.
            **kwargs: Keyword arguments forwarded verbatim to the parent. The
                prompt is looked up under ``input_ids`` and then ``inputs``.

        Returns:
            Whatever the parent's ``generate`` returns, unmodified, or ``None``
            if the parent class exposes no ``generate`` method.
        """
        if not hasattr(super(), "generate"):
            logger.warning("Generate method is not available in the parent class.")
            return None

        # The prompt may arrive as ``input_ids=``, as ``inputs=`` (the name of
        # the parent's first parameter), or positionally.
        prompt = kwargs.get("input_ids")
        if prompt is None:
            prompt = kwargs.get("inputs")
        if prompt is None and args:
            prompt = args[0]
        logger.info(
            "Generate method called with input shape: %s",
            self._shape_or(prompt, "None"),
        )

        outputs = super().generate(*args, **kwargs)

        # With return_dict_in_generate=True the parent returns a ModelOutput
        # that carries the token ids in ``.sequences`` and has no ``.shape``;
        # otherwise it returns the token tensor itself.
        sequences = getattr(outputs, "sequences", outputs)
        logger.info(
            "Generate method completed. Output shape: %s",
            self._shape_or(sequences, "N/A"),
        )
        return outputs
|