Boriscii commited on
Commit
193150f
·
verified ·
1 Parent(s): 311c890

Update modeling_modified.py

Browse files
Files changed (1) hide show
  1. modeling_modified.py +14 -6
modeling_modified.py CHANGED
@@ -6,11 +6,19 @@ from transformers import BartForConditionalGeneration
6
  logger = logging.getLogger("ModelLogger")
7
 
8
  class ModifiedBartForConditionalGenerationWithHook(BartForConditionalGeneration):
9
- def __init__(self, *args, **kwargs):
10
- super().__init__(*args, **kwargs)
11
- self.register_forward_hook(self.forward_hook)
12
 
13
- @staticmethod
14
- def forward_hook(module, inputs, outputs):
15
- logger.info(f"Called forward method of {module.__class__.__name__}")
 
 
 
 
 
 
 
 
 
16
 
 
6
  logger = logging.getLogger("ModelLogger")
7
 
8
  class ModifiedBartForConditionalGenerationWithHook(BartForConditionalGeneration):
9
+ def __init__(self, config):
10
+ super().__init__(config)
 
11
 
12
+ def forward(self, *args, **kwargs):
13
+ # Log input information
14
+ input_ids = kwargs.get('input_ids', args[0] if args else None)
15
+ print(f"Forward pass initiated with input shape: {input_ids.shape if input_ids is not None else 'None'}")
16
+
17
+ # Call the parent's forward method
18
+ outputs = super().forward(*args, **kwargs)
19
+
20
+ # Log output information
21
+ print(f"Forward pass completed. Output shape: {outputs.logits.shape if hasattr(outputs, 'logits') else 'N/A'}")
22
+
23
+ return outputs
24