File size: 1,475 Bytes
8f0212f
 
 
311c890
8f0212f
 
 
311c890
193150f
 
8f0212f
193150f
 
 
df53c44
 
193150f
 
df53c44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193150f
df53c44
 
193150f
8f0212f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import logging

import logging
from transformers import BartForConditionalGeneration

# Logger used by the forward/generate hooks in this module. It is looked up by
# the fixed name "ModelLogger" (not __name__), so handlers and levels must be
# configured on that name.
logger = logging.getLogger("ModelLogger")

class ModifiedBartForConditionalGenerationWithHook(BartForConditionalGeneration):
    """BART seq2seq model that logs shapes around ``forward`` and ``generate``.

    Both overrides delegate to the parent implementation with the same
    arguments and return its result unchanged. They only emit INFO records on
    the module-level ``logger`` before and after the call. The logging itself
    never raises: values without a ``shape`` are reported with a placeholder
    instead of triggering an ``AttributeError``.
    """

    def __init__(self, config):
        """Build the model from *config* exactly as the parent class does."""
        super().__init__(config)

    @staticmethod
    def _shape_of(obj, default):
        """Return ``obj.shape`` when it exists and is not None, else *default*.

        Used only to build log messages, so it must not raise for ``None``,
        tuples, lists or any other shapeless value.
        """
        shape = getattr(obj, "shape", None)
        return default if shape is None else shape

    def forward(self, *args, **kwargs):
        """Run the parent ``forward`` and log the input and output shapes.

        ``input_ids`` is read from the keyword argument when given, otherwise
        from the first positional argument. Returns the parent's output
        unchanged.
        """
        input_ids = kwargs.get("input_ids", args[0] if args else None)
        # %-style arguments defer string formatting until a handler actually
        # emits the record, so disabled INFO logging costs almost nothing.
        logger.info(
            "Forward pass initiated with input shape: %s",
            self._shape_of(input_ids, "None"),
        )

        outputs = super().forward(*args, **kwargs)

        # Outputs without a usable ``logits`` tensor (missing attribute or
        # None) are reported as "N/A" rather than raising.
        logger.info(
            "Forward pass completed. Output shape: %s",
            self._shape_of(getattr(outputs, "logits", None), "N/A"),
        )
        return outputs

    def generate(self, *args, **kwargs):
        """Run the parent ``generate`` and log the input and output shapes.

        Returns ``None`` after logging a warning if the parent class has no
        ``generate`` method. Otherwise returns the parent's result unchanged.
        """
        if not hasattr(super(), "generate"):
            logger.warning("Generate method is not available in the parent class.")
            return None

        # The prompt may arrive as ``input_ids=``, as ``inputs=`` (generate's
        # own keyword for it), or positionally. Check them in that order.
        input_ids = kwargs.get("input_ids")
        if input_ids is None:
            input_ids = kwargs.get("inputs", args[0] if args else None)
        logger.info(
            "Generate method called with input shape: %s",
            self._shape_of(input_ids, "None"),
        )

        outputs = super().generate(*args, **kwargs)

        # The result may be a plain tensor or a structured output that carries
        # the generated ids in ``.sequences`` (e.g. with
        # return_dict_in_generate=True). Log the ids' shape in both cases
        # instead of failing on a missing ``.shape`` after a successful run.
        sequences = getattr(outputs, "sequences", outputs)
        logger.info(
            "Generate method completed. Output shape: %s",
            self._shape_of(sequences, "N/A"),
        )
        return outputs