Wendy committed on
Commit
cdcda28
·
verified ·
1 Parent(s): 822087d

Upload trainer.py

Browse files
Files changed (1) hide show
  1. LLAVA-Cherry/trainer.py +37 -3
LLAVA-Cherry/trainer.py CHANGED
@@ -17,6 +17,11 @@
17
  The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task.
18
  """
19
 
 
 
 
 
 
20
  import contextlib
21
  import copy
22
  import functools
@@ -2762,6 +2767,13 @@ class Trainer:
2762
  Return:
2763
  `torch.Tensor`: The tensor with training loss on this batch.
2764
  """
 
 
 
 
 
 
 
2765
  model.train()
2766
  inputs = self._prepare_inputs(inputs)
2767
 
@@ -2776,16 +2788,31 @@ class Trainer:
2776
  del inputs['dataset_id']
2777
  del inputs['data_info']
2778
  #######################################################
 
2779
 
2780
  with self.compute_loss_context_manager():
2781
- loss = self.compute_loss(model, inputs)
 
 
 
 
 
 
 
 
 
2782
 
2783
  #######################################################
2784
  import json
2785
  for i in range(len(data_info_temp)):
2786
- data_info_temp[i]['loss'] = float(loss[0][i])
 
 
2787
 
2788
- file_path = '/data/zbz5349/ICLR_2024/ACL_2025/LLaVA_Fliter/inference_demo/cherry_loss_infer_result.jsonl'
 
 
 
2789
  with open(file_path, 'a', encoding='utf-8') as file:
2790
  # json.dump(data_info_temp[0], file, ensure_ascii=False, indent=4)
2791
  for content in data_info_temp:
@@ -2825,6 +2852,13 @@ class Trainer:
2825
  else:
2826
  labels = None
2827
  outputs = model(**inputs)
 
 
 
 
 
 
 
2828
  # Save past state if it exists
2829
  # TODO: this needs to be fixed and made cleaner later.
2830
  if self.args.past_index >= 0:
 
17
  The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task.
18
  """
19
 
20
+ #########################################################
21
+ from datetime import datetime
22
+ has_run = False
23
+ #########################################################
24
+
25
  import contextlib
26
  import copy
27
  import functools
 
2767
  Return:
2768
  `torch.Tensor`: The tensor with training loss on this batch.
2769
  """
2770
+
2771
+ # #######################################################
2772
+ # # import pdb; pdb.set_trace()
2773
+ # import pprint
2774
+ # pprint.pprint(inputs)
2775
+ # #######################################################
2776
+
2777
  model.train()
2778
  inputs = self._prepare_inputs(inputs)
2779
 
 
2788
  del inputs['dataset_id']
2789
  del inputs['data_info']
2790
  #######################################################
2791
+
2792
 
2793
  with self.compute_loss_context_manager():
2794
+ # loss = self.compute_loss(model, inputs)
2795
+ (loss, outputs) = self.compute_loss(model, inputs,return_outputs=True)
2796
+
2797
+
2798
+ import pprint
2799
+ # pprint.pprint(outputs)
2800
+ # import pdb; pdb.set_trace()
2801
+ last_token_logits_yes = outputs.logits[:, -1, :]
2802
+ yes_target_token_id = 4874
2803
+ yes_target_logprob = torch.log_softmax(last_token_logits_yes, dim=-1)[0, yes_target_token_id].item()
2804
 
2805
  #######################################################
2806
  import json
2807
  for i in range(len(data_info_temp)):
2808
+ # data_info_temp[i]['loss'] = float(loss[0][i])
2809
+ data_info_temp[i]['yes_target_logprob'] = yes_target_logprob
2810
+ data_info_temp[i]['logits_shape'] = outputs.logits.shape
2811
 
2812
+ from datetime import datetime
2813
+ current_time = datetime.now().strftime('%Y_%m_%d')
2814
+
2815
+ file_path = '/data/zbz5349/ICLR_2024/ACL_2025/LLaVA_Fliter/inference_demo/cherry_AskLLM_infer_result_' + current_time + '.jsonl'
2816
  with open(file_path, 'a', encoding='utf-8') as file:
2817
  # json.dump(data_info_temp[0], file, ensure_ascii=False, indent=4)
2818
  for content in data_info_temp:
 
2852
  else:
2853
  labels = None
2854
  outputs = model(**inputs)
2855
+
2856
+ # #######################################################
2857
+ # import pdb; pdb.set_trace()
2858
+ # import pprint
2859
+ # pprint.pprint(outputs)
2860
+ # #######################################################
2861
+
2862
  # Save past state if it exists
2863
  # TODO: this needs to be fixed and made cleaner later.
2864
  if self.args.past_index >= 0: