Update modeling_modified.py
Browse files- modeling_modified.py +22 -5
modeling_modified.py
CHANGED
@@ -12,13 +12,30 @@ class ModifiedBartForConditionalGenerationWithHook(BartForConditionalGeneration)
|
|
12 |
def forward(self, *args, **kwargs):
|
13 |
# Log input information
|
14 |
input_ids = kwargs.get('input_ids', args[0] if args else None)
|
15 |
-
|
16 |
-
|
17 |
# Call the parent's forward method
|
18 |
outputs = super().forward(*args, **kwargs)
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
# Log output information
|
21 |
-
|
22 |
-
|
23 |
return outputs
|
24 |
|
|
|
12 |
def forward(self, *args, **kwargs):
|
13 |
# Log input information
|
14 |
input_ids = kwargs.get('input_ids', args[0] if args else None)
|
15 |
+
logger.info(f"Forward pass initiated with input shape: {input_ids.shape if input_ids is not None else 'None'}")
|
16 |
+
|
17 |
# Call the parent's forward method
|
18 |
outputs = super().forward(*args, **kwargs)
|
19 |
+
|
20 |
+
# Log output information
|
21 |
+
logger.info(f"Forward pass completed. Output shape: {outputs.logits.shape if hasattr(outputs, 'logits') else 'N/A'}")
|
22 |
+
|
23 |
+
return outputs
|
24 |
+
|
25 |
+
def generate(self, *args, **kwargs):
|
26 |
+
if not hasattr(super(), 'generate'):
|
27 |
+
logger.warning("Generate method is not available in the parent class.")
|
28 |
+
return None
|
29 |
+
|
30 |
+
# Log input information
|
31 |
+
input_ids = kwargs.get('input_ids', args[0] if args else None)
|
32 |
+
logger.info(f"Generate method called with input shape: {input_ids.shape if input_ids is not None else 'None'}")
|
33 |
+
|
34 |
+
# Call the parent's generate method
|
35 |
+
outputs = super().generate(*args, **kwargs)
|
36 |
+
|
37 |
# Log output information
|
38 |
+
logger.info(f"Generate method completed. Output shape: {outputs.shape}")
|
39 |
+
|
40 |
return outputs
|
41 |
|