WinstonHu commited on
Commit
1dceb55
·
verified ·
1 Parent(s): 4fd1c52

Upload folder stage1_qwen25_both_hf to stage_1/token_merge_plus_resampler/stage1_qwen25_both_hf

Browse files
stage_1/token_merge_plus_resampler/stage1_qwen25_both_hf/projector/projector.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d844e6e365f6cb5d69687471a53f3f08d9ee1aca5456b681b31b2b9fc5d9e394
3
+ size 40384848
stage_1/token_merge_plus_resampler/stage1_qwen25_both_hf/resampler/resampler.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:18512111537ba6526fa1c1f7891e5ac14b7e549de381ef962efb319da11c9ccc
3
+ size 129956960
stage_1/token_merge_plus_resampler/stage1_qwen25_both_hf/token_merger/merger.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:86f24c13a9f8e7cc06aa2f8708df819386f9c6b9a6a0f10e716a8331f1282a9a
3
+ size 20784
stage_1/token_merge_plus_resampler/stage1_qwen25_both_hf/xtuner_config.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.dataset import DefaultSampler
4
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
5
+ LoggerHook, ParamSchedulerHook)
6
+ from mmengine.visualization import Visualizer, WandbVisBackend
7
+ from mmengine.optim import AmpOptimWrapper, ConstantLR, LinearLR, CosineAnnealingLR
8
+ from torch.optim import AdamW
9
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
10
+ BitsAndBytesConfig, CLIPImageProcessor,
11
+ CLIPVisionModel)
12
+
13
+ from xtuner.dataset import LLaVADataset
14
+ from xtuner.dataset.collate_fns import default_collate_fn
15
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
16
+ from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHookResampler, HFCheckpointHook, ThroughputHook
17
+ from xtuner.engine.runner import TrainLoop
18
+ from xtuner.model.llava_no_longnet_simple_sampler import LLaVAModel
19
+ from xtuner.utils import PROMPT_TEMPLATE
20
+
21
+ #######################################################################
22
+ # PART 1 Settings #
23
+ #######################################################################
24
+ # Model
25
+ llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
26
+ # Data
27
+ data_path = '/mnt/bn/yuxuanwang/meilong/code/projects/efficient_foundation_wsi_llava/merged_dataset/stage1_morph2.json'
28
+ image_path_list = None
29
+
30
+ prompt_template = PROMPT_TEMPLATE.qwen_chat
31
+
32
+ # 长序列:保持 per_image_length == sample_num
33
+ max_length = 15836
34
+ per_image_length = 10240
35
+ sample_type = 'wsi' # 'wsi' or 'image'
36
+
37
+ # Scheduler & Optimizer (epoch-based)
38
+ batch_size = 1
39
+ accumulative_counts = 256 # 8 * 256 = 2048
40
+ dataloader_num_workers = 10
41
+ seed = 42
42
+ optim_type = AdamW
43
+ lr = 1e-3
44
+ betas = (0.9, 0.999)
45
+ weight_decay = 0.0 # 适度WD抑制漂移
46
+ max_norm = 1 # 更紧的梯度裁剪
47
+
48
+ # 以 epoch 为主
49
+ max_epochs = 2
50
+ warmup_ratio = 0.08 # 预热占比(相对 max_iters)
51
+
52
+ # Save
53
+ save_steps = 5120
54
+ save_total_limit = 8 # Maximum checkpoints to keep (-1 means unlimited)
55
+
56
+ # Evaluate the generation performance during the training
57
+ evaluation_freq = 512
58
+ SYSTEM = ''
59
+ evaluation_images = '/mnt/bn/xudong-va/meilong/datasets/Token_Compression/skcm_224x224_b20_t15/h5_files/TCGA-EB-A5UN-06Z-00-DX1.h5'
60
+ evaluation_inputs = ['Are the tumor cells organized in a lobulated pattern within the slide?']
61
+
62
+ #######################################################################
63
+ # PART 2 Model & Tokenizer & Image Processor #
64
+ #######################################################################
65
+ tokenizer = dict(
66
+ type=AutoTokenizer.from_pretrained,
67
+ pretrained_model_name_or_path=llm_name_or_path,
68
+ trust_remote_code=True,
69
+ padding_side='right')
70
+
71
+ bnb = dict(
72
+ type=BitsAndBytesConfig,
73
+ load_in_4bit=True,
74
+ load_in_8bit=False,
75
+ llm_int8_threshold=6.0,
76
+ llm_int8_has_fp16_weight=False,
77
+ bnb_4bit_compute_dtype=torch.bfloat16,
78
+ bnb_4bit_use_double_quant=True,
79
+ bnb_4bit_quant_type="nf4",
80
+ )
81
+
82
+ model = dict(
83
+ type=LLaVAModel,
84
+ freeze_llm=True,
85
+ train_stage='1',
86
+ llm=dict(
87
+ type=AutoModelForCausalLM.from_pretrained,
88
+ pretrained_model_name_or_path=llm_name_or_path,
89
+ trust_remote_code=True,
90
+ torch_dtype=torch.bfloat16,
91
+ attn_implementation='flash_attention_2',
92
+ quantization_config=bnb
93
+ ),
94
+
95
+ max_position_embeddings = None, # original 32000 +
96
+ enable_token_merge = True,
97
+ # 建议:前期用 resampler 更稳(也更省显存);如不需要可改回 False
98
+ use_resampler=True,
99
+ resampler_num_latents=100,
100
+ )
101
+
102
+ #######################################################################
103
+ # PART 3 Dataset & Dataloader #
104
+ #######################################################################
105
+ llava_dataset = dict(
106
+ type=LLaVADataset,
107
+ data_path=data_path,
108
+ image_folder='',
109
+ image_path_list=image_path_list,
110
+ tokenizer=tokenizer,
111
+ dataset_map_fn=llava_map_fn,
112
+ template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
113
+ max_length=max_length,
114
+ per_image_length=per_image_length,
115
+ pad_image_to_square=False,
116
+ sample_num=per_image_length,
117
+ image_feature_prefix='/mnt/bn/xudong-va/meilong/datasets/Token_Compression',
118
+ image_feature_suffix='.h5',
119
+ identifier='_224x224_b20_t15',
120
+ unwanted_prefix_csv='/mnt/bn/yuxuanwang/meilong/code/projects/efficient_foundation_wsi_llava/merged_dataset/missing_slides3.csv',
121
+ sample_strategy='linspace', #use linspace
122
+ )
123
+
124
+
125
+ # cying: add: per_image_length=per_image_length,
126
+
127
+ train_dataloader = dict(
128
+ batch_size=batch_size,
129
+ num_workers=dataloader_num_workers,
130
+ pin_memory=True,
131
+ persistent_workers=True,
132
+ prefetch_factor=4,
133
+ dataset=llava_dataset,
134
+ sampler=dict(type=DefaultSampler, shuffle=True),
135
+ collate_fn=dict(type=default_collate_fn)
136
+ )
137
+
138
+
139
+
140
+ #######################################################################
141
+ # PART 4 Scheduler & Optimizer #
142
+ #######################################################################
143
+ # optimizer
144
+ optim_wrapper = dict(
145
+ type=AmpOptimWrapper,
146
+ optimizer=dict(type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
147
+ paramwise_cfg = dict(
148
+ norm_decay_mult=0.0,
149
+ bias_decay_mult=0.0,
150
+ paramwise_cfg=dict(
151
+ custom_keys={'^projector\\.': dict(lr_mult=1.0)},
152
+ # 关键:明确只收集 projector,其他丢弃
153
+ # 有些实现没有这个开关;那就用 EnsureProjectorInOptimHook 热修
154
+ ),
155
+ ),
156
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=True), # 遇到 NaN 立刻报错
157
+ accumulative_counts=accumulative_counts,
158
+ loss_scale='dynamic',
159
+ dtype='bfloat16',
160
+ )
161
+
162
+ param_scheduler = [
163
+ dict(
164
+ type=LinearLR,
165
+ start_factor=0.01, # 从 1% 的 lr 慢启动
166
+ by_epoch=True,
167
+ begin=0,
168
+ end=warmup_ratio * max_epochs,
169
+ convert_to_iter_based=True # 按 iter 计算
170
+ ),
171
+ dict(
172
+ type=CosineAnnealingLR,
173
+ eta_min=0.0,
174
+ by_epoch=True,
175
+ begin=warmup_ratio * max_epochs,
176
+ end=max_epochs,
177
+ convert_to_iter_based=True
178
+ )
179
+ ]
180
+
181
+ # train, val, test setting
182
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
183
+
184
+ #######################################################################
185
+ # PART 5 Runtime #
186
+ #######################################################################
187
+ # Log the dialogue periodically during the training process, optional
188
+ custom_hooks = [
189
+ dict(type=DatasetInfoHook, tokenizer=tokenizer),
190
+ dict(
191
+ type=EvaluateChatHookResampler,
192
+ tokenizer=tokenizer,
193
+ every_n_iters=evaluation_freq,
194
+ evaluation_inputs=evaluation_inputs,
195
+ evaluation_images=evaluation_images,
196
+ system=SYSTEM,
197
+ prompt_template=prompt_template),
198
+ dict(type = ThroughputHook)
199
+ ]
200
+
201
+ # configure default hooks
202
+ default_hooks = dict(
203
+ # record the time of every iteration.
204
+ timer=dict(type=IterTimerHook),
205
+ # print log every 10 iterations.
206
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
207
+ # enable the parameter scheduler.
208
+ param_scheduler=dict(type=ParamSchedulerHook),
209
+ # save checkpoint per `save_steps`.
210
+ checkpoint=dict(
211
+ type=CheckpointHook,
212
+ by_epoch=False,
213
+ interval=save_steps,
214
+ max_keep_ckpts=save_total_limit),
215
+ # set sampler seed in distributed evrionment.
216
+ sampler_seed=dict(type=DistSamplerSeedHook),
217
+ )
218
+
219
+ # configure environment
220
+ env_cfg = dict(
221
+ # whether to enable cudnn benchmark
222
+ cudnn_benchmark=False,
223
+ # set multi process parameters
224
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
225
+ # set distributed parameters
226
+ dist_cfg=dict(backend='nccl'),
227
+ )
228
+
229
+
230
+ # visualizer = dict(
231
+ # type=Visualizer,
232
+ # vis_backends=[
233
+ # dict(type=WandbVisBackend, init_kwargs=dict(project='stage1_no_longnet_simple_resampler_projector100'))])
234
+ visualizer = None
235
+ # set log level
236
+ log_level = 'INFO'
237
+
238
+ # load from which checkpoint
239
+ load_from = None
240
+
241
+ # whether to resume training from the loaded checkpoint
242
+ resume = False
243
+
244
+ # Defaults to use random seed and disable `deterministic`
245
+ randomness = dict(seed=seed, deterministic=False)
246
+
247
+ # set log processor
248
+ log_processor = dict(
249
+ by_epoch=False,
250
+ window_size=1,
251
+ mean_pattern=r".*(loss|time|data_time|grad_norm|tflops).*",
252
+ )