File size: 832 Bytes
8f0212f
 
 
311c890
8f0212f
 
 
311c890
193150f
 
8f0212f
193150f
 
 
 
 
 
 
 
 
 
 
 
8f0212f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import logging

import logging
from transformers import BartForConditionalGeneration

# Shared logger for this module; the name "ModelLogger" is kept fixed so
# external logging configuration can target it.
logger = logging.getLogger(name="ModelLogger")

class ModifiedBartForConditionalGenerationWithHook(BartForConditionalGeneration):
    """``BartForConditionalGeneration`` that logs shapes around each forward pass.

    Model behaviour is unchanged: every call is delegated to the parent
    ``forward`` with the original arguments, and the parent's result is
    returned as-is. The subclass only adds two log records per call, emitted
    through the module-level ``logger``.

    No ``__init__`` override is defined; the inherited constructor already
    accepts ``config``.
    """

    def forward(self, *args, **kwargs):
        """Run the parent ``forward`` and log the input and output shapes.

        Args:
            *args: Positional arguments passed through unchanged to
                ``BartForConditionalGeneration.forward``. If present, the
                first one is treated as ``input_ids`` for logging purposes.
            **kwargs: Keyword arguments passed through unchanged to the
                parent. An ``input_ids`` keyword takes precedence over the
                first positional argument for logging.

        Returns:
            Exactly what the parent ``forward`` returns.
        """
        # input_ids may be supplied by keyword or as the first positional arg.
        input_ids = kwargs.get("input_ids", args[0] if args else None)
        # getattr makes a missing attribute degrade the log message instead of
        # raising: a diagnostic hook must never break the forward pass itself.
        if input_ids is None:
            input_shape = "None"
        else:
            input_shape = getattr(input_ids, "shape", "N/A")
        # Lazy %-style args: the message string is only built if the record
        # is actually emitted at the configured level.
        logger.info("Forward pass initiated with input shape: %s", input_shape)

        outputs = super().forward(*args, **kwargs)

        # Outputs without a ``logits`` attribute (presumably a plain tuple when
        # return_dict=False), or whose ``logits`` is None, are logged as N/A.
        logits = getattr(outputs, "logits", None)
        output_shape = "N/A" if logits is None else getattr(logits, "shape", "N/A")
        logger.info("Forward pass completed. Output shape: %s", output_shape)

        return outputs