Boriscii commited on
Commit
df53c44
·
verified ·
1 Parent(s): 193150f

Update modeling_modified.py

Browse files
Files changed (1) hide show
  1. modeling_modified.py +22 -5
modeling_modified.py CHANGED
@@ -12,13 +12,30 @@ class ModifiedBartForConditionalGenerationWithHook(BartForConditionalGeneration)
12
  def forward(self, *args, **kwargs):
13
  # Log input information
14
  input_ids = kwargs.get('input_ids', args[0] if args else None)
15
- print(f"Forward pass initiated with input shape: {input_ids.shape if input_ids is not None else 'None'}")
16
-
17
  # Call the parent's forward method
18
  outputs = super().forward(*args, **kwargs)
19
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  # Log output information
21
- print(f"Forward pass completed. Output shape: {outputs.logits.shape if hasattr(outputs, 'logits') else 'N/A'}")
22
-
23
  return outputs
24
 
 
12
  def forward(self, *args, **kwargs):
13
  # Log input information
14
  input_ids = kwargs.get('input_ids', args[0] if args else None)
15
+ logger.info(f"Forward pass initiated with input shape: {input_ids.shape if input_ids is not None else 'None'}")
16
+
17
  # Call the parent's forward method
18
  outputs = super().forward(*args, **kwargs)
19
+
20
+ # Log output information
21
+ logger.info(f"Forward pass completed. Output shape: {outputs.logits.shape if hasattr(outputs, 'logits') else 'N/A'}")
22
+
23
+ return outputs
24
+
25
+ def generate(self, *args, **kwargs):
26
+ if not hasattr(super(), 'generate'):
27
+ logger.warning("Generate method is not available in the parent class.")
28
+ return None
29
+
30
+ # Log input information
31
+ input_ids = kwargs.get('input_ids', args[0] if args else None)
32
+ logger.info(f"Generate method called with input shape: {input_ids.shape if input_ids is not None else 'None'}")
33
+
34
+ # Call the parent's generate method
35
+ outputs = super().generate(*args, **kwargs)
36
+
37
  # Log output information
38
+ logger.info(f"Generate method completed. Output shape: {outputs.shape}")
39
+
40
  return outputs
41