wchai committed on
Commit
fd32045
1 Parent(s): 8cd4417

Upload xtuner_config.py with huggingface_hub

Files changed (1)
  1. xtuner_config.py +288 -0
xtuner_config.py ADDED
@@ -0,0 +1,288 @@
import torch
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)

from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig,
                          CLIPImageProcessor, CLIPVisionModel,
                          SiglipVisionModel, SiglipImageProcessor, AutoProcessor)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR

from peft import LoraConfig
from torch.optim import AdamW
from xtuner.dataset import LLaVADataset, CambrianDataset, ConcatDataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import llava_map_fn, cambrian_map_fn, template_map_fn_factory
from xtuner.dataset.samplers import LengthGroupedSampler
from xtuner.engine import DatasetInfoHook, EvaluateChatHook
from xtuner.model import LLaVAModel, PikaModel
from xtuner.utils import PROMPT_TEMPLATE


#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Model
llm_name_or_path = 'meta-llama/Meta-Llama-3.1-8B-Instruct'
visual_encoder_name_or_path = 'google/siglip-so400m-patch14-384'
# pretrained_pth = '/data/wenhao/projects/xtuner/work_dirs/final_siglip_llama31_P/projector'

prompt_template = PROMPT_TEMPLATE.llama3_chat
max_length = 4096
size = 378

batch_size = 8  # per_device
accumulative_counts = 2
lr = 1e-3
dataloader_num_workers = 0
max_epochs = 1
optim_type = AdamW
betas = (0.9, 0.999)
weight_decay = 0
max_norm = 1  # grad clip
warmup_ratio = 0.03

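# Effective batch size: batch_size * accumulative_counts = 16 samples per device
# per optimizer step, multiplied by the number of devices the job is launched on.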
# Save
save_steps = 200
save_total_limit = 2  # Maximum checkpoints to keep (-1 means unlimited)

#######################################################################
#            PART 2  Model & Tokenizer & Image Processor              #
#######################################################################
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=llm_name_or_path,
    trust_remote_code=True,
    padding_side='right')

image_processor = dict(
    type=CLIPImageProcessor.from_pretrained,
    pretrained_model_name_or_path='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k',
    trust_remote_code=True,
    size=size,
    crop_size=size)

model = dict(
    type=PikaModel,
    freeze_llm=True,
    freeze_visual_encoder=True,
    # pretrained_pth=pretrained_pth,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=llm_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.float16),
    visual_encoder=dict(
        type=SiglipVisionModel.from_pretrained,
        pretrained_model_name_or_path=visual_encoder_name_or_path))

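# With freeze_llm and freeze_visual_encoder both True, only the parameters that
# PikaModel leaves unfrozen (presumably the vision-to-language projector, as in
# XTuner's LLaVA-style configs) are updated, at the relatively high lr of 1e-3.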
#######################################################################
#                     PART 3  Dataset & Dataloader                    #
#######################################################################
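# These dataset blocks read pre-tokenized text from offline_processed_text_folder,
# which is why tokenizer= and data_path= are omitted or commented out; the raw
# .jsonl paths are kept only for reference.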
dense_data_root = '/data/wenhao/projects/xtuner/data/DenseFusion-1M/'
dense_data_path = dense_data_root + 'DenseFusion-1M/DenseFusion-1M-instruct.jsonl'
dense_image_folder = dense_data_root + '1M_data'
dense_processed_text_folder = dense_data_root + 'pre_token_llama3'
dense_dataset = dict(
    type=CambrianDataset,
    image_folder=dense_image_folder,
    image_processor=image_processor,
    # data_path=dense_data_path,
    # tokenizer=tokenizer,
    offline_processed_text_folder=dense_processed_text_folder,
    dataset_map_fn=cambrian_map_fn,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    pad_image_to_square=True)

laion_data_root = '/data/wenhao/projects/xtuner/data/LLaVA-Pretrain/'
laion_data_path = laion_data_root + 'laion_558k.jsonl'
laion_image_folder = laion_data_root
laion_dataset = dict(
    type=CambrianDataset,
    offline_processed_text_folder='/data/wenhao/projects/xtuner/data/LLaVA-Pretrain/pre_token_llama31',
    image_folder=laion_image_folder,
    image_processor=image_processor,
    dataset_map_fn=cambrian_map_fn,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    pad_image_to_square=True)

face_data_root = '/data/wenhao/projects/xtuner/data/FaceCaption-15M/'
face_data_path = face_data_root + 'FaceCaption-100K.jsonl'
face_image_folder = face_data_root + 'full_data'
face_processed_text_folder = face_data_root + 'pre_token_llama3'
face_dataset = dict(
    type=CambrianDataset,
    offline_processed_text_folder=face_processed_text_folder,
    image_folder=face_image_folder,
    image_processor=image_processor,
    dataset_map_fn=cambrian_map_fn,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    pad_image_to_square=True)

allava_data_root = '/data/wenhao/projects/xtuner/data/ALLaVA-4V'
allava_cl_data_path = '/data/wenhao/projects/xtuner/data/ALLaVA-4V/ALLaVA-Caption-LAION-4V.jsonl'
allava_cl_image_folder = allava_data_root
allava_cl_dataset = dict(
    type=CambrianDataset,
    offline_processed_text_folder='/data/wenhao/projects/xtuner/data/ALLaVA-4V/pre_token_cl_llama31',
    # tokenizer=tokenizer,
    # data_path=allava_cl_data_path,
    image_folder=allava_cl_image_folder,
    image_processor=image_processor,
    dataset_map_fn=cambrian_map_fn,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    pad_image_to_square=True)

allava_cv_data_path = '/data/wenhao/projects/xtuner/data/ALLaVA-4V/ALLaVA-Caption-VFLAN-4V.jsonl'
allava_image_folder = allava_data_root
allava_cv_dataset = dict(
    type=CambrianDataset,
    offline_processed_text_folder='/data/wenhao/projects/xtuner/data/ALLaVA-4V/pre_token_cv_llama31',
    # tokenizer=tokenizer,
    # data_path=allava_cv_data_path,
    image_folder=allava_image_folder,
    image_processor=image_processor,
    dataset_map_fn=cambrian_map_fn,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    pad_image_to_square=True)

sharept_data_root = '/data/wenhao/projects/xtuner/data/ShareGPT4V/'
sharept_data_path = sharept_data_root + 'sharegpt4v_pt.jsonl'
sharept_image_folder = '/data/wenhao/projects/xtuner/data/'
sharept_dataset = dict(
    type=CambrianDataset,
    offline_processed_text_folder='/data/wenhao/projects/xtuner/data/ShareGPT4V/pre_token_llama31',
    # tokenizer=tokenizer,
    # data_path='/data/wenhao/projects/xtuner/data/ShareGPT4V/sharegpt4v_pt.jsonl',
    image_folder=sharept_image_folder,
    image_processor=image_processor,
    dataset_map_fn=cambrian_map_fn,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    pad_image_to_square=True)

train_dataset = dict(
    type=ConcatDataset,
    datasets=[laion_dataset, dense_dataset, face_dataset, sharept_dataset,
              allava_cl_dataset, allava_cv_dataset],
)

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=train_dataset,
    sampler=dict(type=DefaultSampler, shuffle=True),
    collate_fn=dict(type=default_collate_fn))

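# The six datasets are simply concatenated and shuffled with DefaultSampler;
# LengthGroupedSampler is imported above but not used in this config.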
#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# optimizer
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True),
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        T_max=max_epochs,
        convert_to_iter_based=True)
]

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=max_epochs, val_interval=1)

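# Learning-rate schedule: linear warmup from lr * 1e-5 over the first 3% of the
# single epoch (warmup_ratio * max_epochs), then cosine annealing down to 0.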
#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Evaluate the generation performance during the training
evaluation_freq = 100
SYSTEM = ''
evaluation_images = 'https://llava-vl.github.io/static/images/view.jpg'
# The first prompt is the Chinese equivalent of the second ('Please describe this picture').
evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture']


# Log the dialogue periodically during the training process, optional
custom_hooks = [
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
    dict(
        type=EvaluateChatHook,
        tokenizer=tokenizer,
        image_processor=image_processor,
        every_n_iters=evaluation_freq,
        evaluation_inputs=evaluation_inputs,
        evaluation_images=evaluation_images,
        system=SYSTEM,
        prompt_template=prompt_template)
]

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every 10 iterations.
    logger=dict(type=LoggerHook, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save a checkpoint every `save_steps` iterations.
    checkpoint=dict(
        type=CheckpointHook,
        by_epoch=False,
        interval=save_steps,
        max_keep_ckpts=save_total_limit),
    # set sampler seed in distributed environment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)
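# Example launch (assuming XTuner's standard CLI; the GPU count and DeepSpeed
# variant are illustrative and should be adjusted to the actual setup):
#   NPROC_PER_NODE=8 xtuner train xtuner_config.py --deepspeed deepspeed_zero2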