|
import logging |
|
|
|
import logging |
|
from transformers import BartForConditionalGeneration |
|
|
|
# Module-level logger for the model class below. It uses the fixed name
# "ModelLogger" (not __name__), so logging configuration must target that name.
logger = logging.getLogger(name="ModelLogger")
|
|
|
class ModifiedBartForConditionalGenerationWithHook(BartForConditionalGeneration):
    """``BartForConditionalGeneration`` that logs shapes around ``forward`` and ``generate``.

    All model computation is delegated unchanged to the parent class; this
    subclass only emits INFO records on the module-level ``logger``.
    """

    @staticmethod
    def _shape_for_log(value, missing="None"):
        """Return a printable shape for *value* without ever raising.

        Args:
            value: Object whose ``shape`` should be reported (typically a tensor).
            missing: Text returned when *value* is ``None``.

        Returns:
            ``str(value.shape)``; *missing* if *value* is ``None``; ``"N/A"`` if
            *value* has no usable ``shape`` attribute. Logging must never break
            a model call, hence no exception path.
        """
        if value is None:
            return missing
        shape = getattr(value, "shape", None)
        return "N/A" if shape is None else str(shape)

    def forward(self, *args, **kwargs):
        """Run the parent ``forward`` and log the input and output shapes.

        Arguments are passed through untouched. ``input_ids`` is read from the
        keyword arguments, falling back to the first positional argument.

        Returns:
            Exactly what the parent ``forward`` returns.
        """
        input_ids = kwargs.get("input_ids", args[0] if args else None)
        logger.info(
            "Forward pass initiated with input shape: %s",
            self._shape_for_log(input_ids),
        )

        outputs = super().forward(*args, **kwargs)

        # Outputs without a ``logits`` attribute (e.g. tuple-style outputs)
        # are reported as N/A instead of raising.
        logits = getattr(outputs, "logits", None)
        logger.info(
            "Forward pass completed. Output shape: %s",
            self._shape_for_log(logits, missing="N/A"),
        )
        return outputs

    def generate(self, *args, **kwargs):
        """Run the parent ``generate`` and log the input and output shapes.

        The input tensor is looked up as the ``input_ids=`` keyword, then the
        ``inputs=`` keyword (the parent's first parameter name), then the first
        positional argument. Arguments are passed through untouched.

        Returns:
            Exactly what the parent ``generate`` returns, or ``None`` (after a
            warning) if the parent class has no ``generate`` method.
        """
        if not hasattr(super(), "generate"):
            logger.warning("Generate method is not available in the parent class.")
            return None

        if "input_ids" in kwargs:
            input_ids = kwargs["input_ids"]
        elif "inputs" in kwargs:
            input_ids = kwargs["inputs"]
        else:
            input_ids = args[0] if args else None
        logger.info(
            "Generate method called with input shape: %s",
            self._shape_for_log(input_ids),
        )

        outputs = super().generate(*args, **kwargs)

        # With return_dict_in_generate=True the result is an output object that
        # stores the token ids in ``.sequences`` and has no ``.shape`` itself;
        # fall back to the result when there is no such attribute.
        sequences = getattr(outputs, "sequences", outputs)
        logger.info(
            "Generate method completed. Output shape: %s",
            self._shape_for_log(sequences, missing="N/A"),
        )
        return outputs
|
|
|
|