|
import json |
|
import os |
|
import pathlib |
|
|
|
import deepspeed |
|
import torch |
|
import transformers |
|
from deepspeed import get_accelerator |
|
from torch.utils.data import ConcatDataset |
|
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel, AutoConfig |
|
from transformers import Trainer |
|
from transformers.integrations.deepspeed import unset_hf_deepspeed_config, set_hf_deepspeed_config |
|
|
|
from callback import TuneTauCallback, MonitorCallback |
|
from ovis.model.configuration_ovis import OvisConfig |
|
from ovis.model.modeling_ovis import Ovis |
|
from ovis.train.arguments import ModelArguments, TrainingArguments |
|
from ovis.train.dataset.caption_dataset import CaptionDataset |
|
from ovis.train.dataset.conversation_dataset import ConversationDataset |
|
from ovis.train.dataset.multimodal_dataset import DataCollatorForMultimodalDataset |
|
from ovis.util.constants import BEGIN_LINE, END_LINE |
|
from ovis.util.utils import smart_unit, rank0_print |
|
|
|
|
|
def train():
    """Entry point for Ovis multimodal training.

    Pipeline (driven entirely by command-line arguments):
      1. Parse ``ModelArguments`` and ``TrainingArguments`` from the CLI.
      2. On rank 0, dump the parsed arguments to
         ``<output_dir>/model_training_args.json``.
      3. Build the Ovis model — either from scratch (separate LLM, text
         tokenizer, and visual tokenizer) or from a pretrained Ovis
         checkpoint.
      4. Freeze all parameters, then unfreeze only the modules listed in
         ``training_args.train_modules``.
      5. Assemble the training datasets, run ``Trainer.train()`` (resuming
         from the last checkpoint in ``output_dir`` if one exists), and
         save the final model.
    """
    parser = transformers.HfArgumentParser(
        (ModelArguments, TrainingArguments))
    model_args, training_args = parser.parse_args_into_dataclasses()

    # Rank 0 logs and persists the exact arguments used for this run; the
    # other ranks wait inside main_process_first so they cannot race on
    # output_dir creation.
    with training_args.main_process_first(local=False):
        if training_args.process_index == 0:
            def args2dict(args):
                # Stringify every field so both dataclasses are JSON-serializable.
                return {k: str(v) for k, v in args.__dict__.items()}

            args_log = json.dumps(dict(
                model_args=args2dict(model_args),
                training_args=args2dict(training_args)
            ), ensure_ascii=False, indent=2)
            print(args_log)
            os.makedirs(training_args.output_dir, exist_ok=True)
            with open(os.path.join(training_args.output_dir, 'model_training_args.json'), 'w',
                      encoding='utf-8') as f:
                f.write(args_log + '\n')

    # ---- Model construction -------------------------------------------------
    if not training_args.ovis_pretrained_path:
        # Build Ovis from scratch: a causal LLM, its text tokenizer, and a
        # visual tokenizer are assembled into one multimodal model below.
        ovis_config = OvisConfig(
            multimodal_max_length=model_args.multimodal_max_length,
            conversation_formatter_class=model_args.conversation_formatter_class,
            llm_attn_implementation=model_args.llm_attn_implementation
        )

        # Only forward attn_implementation when explicitly configured, so the
        # transformers default is used otherwise.
        attn_kwargs = dict()
        if model_args.llm_attn_implementation:
            attn_kwargs['attn_implementation'] = model_args.llm_attn_implementation
        llm = AutoModelForCausalLM.from_pretrained(model_args.llm_name_or_path, **attn_kwargs)
        text_tokenizer = AutoTokenizer.from_pretrained(model_args.llm_name_or_path)
        # Some tokenizers ship without a pad token; fall back to the
        # explicitly configured id when one is provided.
        if text_tokenizer.pad_token_id is None and model_args.pad_token_id is not None:
            text_tokenizer.pad_token_id = model_args.pad_token_id

        # NOTE(review): the global HF DeepSpeed config is unset while the
        # visual tokenizer is created and restored right after — presumably
        # to keep its weights from being ZeRO-partitioned at construction
        # time; confirm against the DeepSpeed integration docs.
        unset_hf_deepspeed_config()
        if training_args.visual_tokenizer_pretrained_path is not None:
            visual_tokenizer = AutoModel.from_pretrained(
                training_args.visual_tokenizer_pretrained_path,
                image_processor_name_or_path=training_args.visual_tokenizer_pretrained_path
            )
        else:
            # No pretrained visual tokenizer: derive a fresh config from the
            # CLI options and train it from scratch.
            visual_tokenizer_config = AutoConfig.for_model(
                model_type=model_args.visual_tokenizer_type + "_visual_tokenizer",
                vocab_size=model_args.visual_vocab_size,
                tokenize_function=model_args.visual_tokenize_function,
                tau=model_args.visual_tau,
                depths=model_args.visual_depths,
                drop_cls_token=model_args.visual_drop_cls_token,
                hidden_stride=model_args.visual_hidden_stride,
            )
            visual_tokenizer = AutoModel.from_config(visual_tokenizer_config, train_from_scratch=True)
        # Move the visual tokenizer onto this process's accelerator device
        # (LOCAL_RANK is set by the distributed launcher; if unset, the
        # accelerator's default device name is used).
        visual_tokenizer = visual_tokenizer.to(
            device=torch.device(get_accelerator().device_name(os.getenv("LOCAL_RANK"))))
        if getattr(training_args, 'hf_deepspeed_config', None) is not None:
            set_hf_deepspeed_config(training_args.hf_deepspeed_config)

        model = Ovis(ovis_config, llm=llm, text_tokenizer=text_tokenizer, visual_tokenizer=visual_tokenizer,
                     train_from_scratch=True)
    else:
        # Continue from a full pretrained Ovis checkpoint.
        model, loading_info = Ovis.from_pretrained(training_args.ovis_pretrained_path,
                                                   multimodal_max_length=model_args.multimodal_max_length,
                                                   output_loading_info=True)
        rank0_print(BEGIN_LINE)
        rank0_print(f'Loading info of Ovis:\n{loading_info}')
        rank0_print(END_LINE)
        # The checkpoint already contains a trained visual token embedding;
        # never re-initialize it in this branch.
        training_args.vte_re_init = False

    # The KV cache is not used during training; it is re-enabled just before
    # the final save at the bottom of this function.
    model.get_llm().config.use_cache = False
    model.config.use_cache = False
    text_tokenizer = model.get_text_tokenizer()

    rank0_print(BEGIN_LINE)
    rank0_print(f'model.config:\n{model.config}')
    rank0_print(END_LINE)

    # ---- Optional VTE re-initialization -------------------------------------
    # Re-initialize the visual token embedding from the mean/std of the LLM's
    # word embedding table. GatheredParameters materializes the (possibly
    # ZeRO-partitioned) weight so the statistics see the full tensor.
    if training_args.vte_re_init:
        with deepspeed.zero.GatheredParameters([model.get_wte().weight]):
            mean = model.get_wte().weight.mean().item()
            std = model.get_wte().weight.std().item()
            rank0_print(f'Statistics of embedding table of LLM: {mean=}, {std=}')
        model.re_init_vte(mean, std)

    # ---- Select trainable modules -------------------------------------------
    # Freeze everything, then unfreeze exactly the '|'-separated module names
    # requested via training_args.train_modules.
    model.requires_grad_(False)
    for module in training_args.train_modules.split('|'):
        if module == 'all':
            model.requires_grad_(True)
        elif module == 'llm':
            model.get_llm().requires_grad_(True)
        elif module == 'visual_tokenizer':
            model.get_visual_tokenizer().requires_grad_(True)
        elif module == 'visual_tokenizer.backbone':
            model.get_visual_tokenizer().get_backbone().requires_grad_(True)
        elif module.startswith('visual_tokenizer.backbone.layer.'):
            # Unfreeze a single backbone layer, addressed by its index suffix.
            layer_index = int(module[len('visual_tokenizer.backbone.layer.'):])
            layer = model.get_visual_tokenizer().get_backbone_layer(layer_index)
            layer.requires_grad_(True)
        elif module == 'visual_tokenizer.head':
            model.get_visual_tokenizer().get_head().requires_grad_(True)
        elif module == 'vte':
            model.get_vte().requires_grad_(True)
        else:
            raise ValueError(f'Invalid train module name: {module}')

    rank0_print(BEGIN_LINE)
    rank0_print('Parameters to train:')
    for name, param in model.named_parameters():
        if param.requires_grad:
            rank0_print(name)
    rank0_print(f'LLM\'s attn implementation: {model.get_llm().config._attn_implementation}')
    rank0_print(END_LINE)

    # ---- Datasets -----------------------------------------------------------
    # The dataset-info JSON lives next to this file under dataset/.
    datasets = []
    dataset_info_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                     f'dataset/{training_args.dataset_info}.json')
    with open(dataset_info_path, 'r', encoding='utf-8') as f:
        dataset_info = json.load(f)
    for name in training_args.dataset_names.split('|'):
        info = dataset_info[name]
        data_format = info['data_format']
        if data_format == 'caption':
            dataset = CaptionDataset(name, info, model, training_args)
        elif data_format == 'conversation':
            dataset = ConversationDataset(name, info, model, training_args)
        else:
            raise ValueError(f'Invalid data format `{data_format}` for dataset `{name}`')
        datasets.append(dataset)
    data_module = dict(
        train_dataset=ConcatDataset(datasets),
        data_collator=DataCollatorForMultimodalDataset(text_tokenizer)
    )

    # ---- Trainer ------------------------------------------------------------
    train_callbacks = [MonitorCallback]
    # gumbel_argmax visual tokenization uses a temperature (tau) that is
    # annealed over training by TuneTauCallback.
    if model_args.visual_tokenize_function == 'gumbel_argmax':
        train_callbacks.append(TuneTauCallback)
    trainer = Trainer(
        model=model,
        args=training_args,
        callbacks=train_callbacks,
        **data_module
    )
    # Sanity-check logging: print one raw sample plus its decoded input_ids
    # and labels. Negative ids (e.g. label-masked / non-text positions) are
    # filtered out before decoding.
    rank0_print(BEGIN_LINE)
    rank0_print('Dataset sample tensor:')
    rank0_print(data_module['train_dataset'][0])
    rank0_print(END_LINE)
    rank0_print(BEGIN_LINE)
    rank0_print('Dataset sample input_ids decoding:')
    rank0_print(text_tokenizer.decode([x for x in data_module['train_dataset'][0]['input_ids'] if x >= 0]))
    rank0_print(END_LINE)
    rank0_print(BEGIN_LINE)
    rank0_print('Dataset sample labels decoding:')
    rank0_print(text_tokenizer.decode([x for x in data_module['train_dataset'][0]['labels'] if x >= 0]))
    rank0_print(END_LINE)
    rank0_print(BEGIN_LINE)
    rank0_print(f'#param of model: {smart_unit(model.num_parameters())}')
    rank0_print(f'#param of llm: {smart_unit(model.get_llm().num_parameters())}')
    rank0_print(f'#param of visual_tokenizer: {smart_unit(model.get_visual_tokenizer().num_parameters())}')
    rank0_print(f'#param of vte: {smart_unit(model.get_vte().weight.numel())}')
    rank0_print(END_LINE)
    # Resume automatically when output_dir already holds a checkpoint.
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()

    # Restore the KV cache for inference-time use before exporting the model.
    model.get_llm().config.use_cache = True
    model.config.use_cache = True
    trainer.save_model()
|
|
|
|
|
def _main():
    # Script entry point: delegate straight to the training routine.
    train()


if __name__ == '__main__':
    _main()
|
|