Update modeling_modified.py
Browse files- modeling_modified.py +14 -6
modeling_modified.py
CHANGED
@@ -6,11 +6,19 @@ from transformers import BartForConditionalGeneration
|
|
6 |
logger = logging.getLogger("ModelLogger")
|
7 |
|
8 |
class ModifiedBartForConditionalGenerationWithHook(BartForConditionalGeneration):
|
9 |
-
def __init__(self,
|
10 |
-
super().__init__(
|
11 |
-
self.register_forward_hook(self.forward_hook)
|
12 |
|
13 |
-
|
14 |
-
|
15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
|
|
6 |
logger = logging.getLogger("ModelLogger")
|
7 |
|
8 |
class ModifiedBartForConditionalGenerationWithHook(BartForConditionalGeneration):
|
9 |
+
def __init__(self, config):
|
10 |
+
super().__init__(config)
|
|
|
11 |
|
12 |
+
def forward(self, *args, **kwargs):
|
13 |
+
# Log input information
|
14 |
+
input_ids = kwargs.get('input_ids', args[0] if args else None)
|
15 |
+
print(f"Forward pass initiated with input shape: {input_ids.shape if input_ids is not None else 'None'}")
|
16 |
+
|
17 |
+
# Call the parent's forward method
|
18 |
+
outputs = super().forward(*args, **kwargs)
|
19 |
+
|
20 |
+
# Log output information
|
21 |
+
print(f"Forward pass completed. Output shape: {outputs.logits.shape if hasattr(outputs, 'logits') else 'N/A'}")
|
22 |
+
|
23 |
+
return outputs
|
24 |
|