Wendy
committed on
Upload trainer.py
Browse files- LLAVA-Cherry/trainer.py +37 -3
LLAVA-Cherry/trainer.py
CHANGED
@@ -17,6 +17,11 @@
|
|
17 |
The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task.
|
18 |
"""
|
19 |
|
|
|
|
|
|
|
|
|
|
|
20 |
import contextlib
|
21 |
import copy
|
22 |
import functools
|
@@ -2762,6 +2767,13 @@ class Trainer:
|
|
2762 |
Return:
|
2763 |
`torch.Tensor`: The tensor with training loss on this batch.
|
2764 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2765 |
model.train()
|
2766 |
inputs = self._prepare_inputs(inputs)
|
2767 |
|
@@ -2776,16 +2788,31 @@ class Trainer:
|
|
2776 |
del inputs['dataset_id']
|
2777 |
del inputs['data_info']
|
2778 |
#######################################################
|
|
|
2779 |
|
2780 |
with self.compute_loss_context_manager():
|
2781 |
-
loss = self.compute_loss(model, inputs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2782 |
|
2783 |
#######################################################
|
2784 |
import json
|
2785 |
for i in range(len(data_info_temp)):
|
2786 |
-
data_info_temp[i]['loss'] = float(loss[0][i])
|
|
|
|
|
2787 |
|
2788 |
-
|
|
|
|
|
|
|
2789 |
with open(file_path, 'a', encoding='utf-8') as file:
|
2790 |
# json.dump(data_info_temp[0], file, ensure_ascii=False, indent=4)
|
2791 |
for content in data_info_temp:
|
@@ -2825,6 +2852,13 @@ class Trainer:
|
|
2825 |
else:
|
2826 |
labels = None
|
2827 |
outputs = model(**inputs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2828 |
# Save past state if it exists
|
2829 |
# TODO: this needs to be fixed and made cleaner later.
|
2830 |
if self.args.past_index >= 0:
|
|
|
17 |
The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task.
|
18 |
"""
|
19 |
|
20 |
+
#########################################################
|
21 |
+
from datetime import datetime
|
22 |
+
has_run = False
|
23 |
+
#########################################################
|
24 |
+
|
25 |
import contextlib
|
26 |
import copy
|
27 |
import functools
|
|
|
2767 |
Return:
|
2768 |
`torch.Tensor`: The tensor with training loss on this batch.
|
2769 |
"""
|
2770 |
+
|
2771 |
+
# #######################################################
|
2772 |
+
# # import pdb; pdb.set_trace()
|
2773 |
+
# import pprint
|
2774 |
+
# pprint.pprint(inputs)
|
2775 |
+
# #######################################################
|
2776 |
+
|
2777 |
model.train()
|
2778 |
inputs = self._prepare_inputs(inputs)
|
2779 |
|
|
|
2788 |
del inputs['dataset_id']
|
2789 |
del inputs['data_info']
|
2790 |
#######################################################
|
2791 |
+
|
2792 |
|
2793 |
with self.compute_loss_context_manager():
|
2794 |
+
# loss = self.compute_loss(model, inputs)
|
2795 |
+
(loss, outputs) = self.compute_loss(model, inputs,return_outputs=True)
|
2796 |
+
|
2797 |
+
|
2798 |
+
import pprint
|
2799 |
+
# pprint.pprint(outputs)
|
2800 |
+
# import pdb; pdb.set_trace()
|
2801 |
+
last_token_logits_yes = outputs.logits[:, -1, :]
|
2802 |
+
yes_target_token_id = 4874
|
2803 |
+
yes_target_logprob = torch.log_softmax(last_token_logits_yes, dim=-1)[0, yes_target_token_id].item()
|
2804 |
|
2805 |
#######################################################
|
2806 |
import json
|
2807 |
for i in range(len(data_info_temp)):
|
2808 |
+
# data_info_temp[i]['loss'] = float(loss[0][i])
|
2809 |
+
data_info_temp[i]['yes_target_logprob'] = yes_target_logprob
|
2810 |
+
data_info_temp[i]['logits_shape'] = outputs.logits.shape
|
2811 |
|
2812 |
+
from datetime import datetime
|
2813 |
+
current_time = datetime.now().strftime('%Y_%m_%d')
|
2814 |
+
|
2815 |
+
file_path = '/data/zbz5349/ICLR_2024/ACL_2025/LLaVA_Fliter/inference_demo/cherry_AskLLM_infer_result_' + current_time + '.jsonl'
|
2816 |
with open(file_path, 'a', encoding='utf-8') as file:
|
2817 |
# json.dump(data_info_temp[0], file, ensure_ascii=False, indent=4)
|
2818 |
for content in data_info_temp:
|
|
|
2852 |
else:
|
2853 |
labels = None
|
2854 |
outputs = model(**inputs)
|
2855 |
+
|
2856 |
+
# #######################################################
|
2857 |
+
# import pdb; pdb.set_trace()
|
2858 |
+
# import pprint
|
2859 |
+
# pprint.pprint(outputs)
|
2860 |
+
# #######################################################
|
2861 |
+
|
2862 |
# Save past state if it exists
|
2863 |
# TODO: this needs to be fixed and made cleaner later.
|
2864 |
if self.args.past_index >= 0:
|