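# Training entry point for Ovis: builds the multimodal model (or loads a
# pretrained one), selects which modules to train, assembles the datasets,
# and runs the Hugging Face Trainer (optionally under DeepSpeed).
#
# Illustrative launch (flag names follow the dataclasses in
# ovis/train/arguments.py; adjust to your setup):
#   deepspeed ovis/train/train.py \
#       --llm_name_or_path <hf-llm> \
#       --dataset_info <info-file-stem> --dataset_names <name_a|name_b> \
#       --train_modules <module|module> --output_dir <dir>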
import json
import os
import pathlib
import deepspeed
import torch
import transformers
from deepspeed import get_accelerator
from torch.utils.data import ConcatDataset
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel, AutoConfig
from transformers import Trainer
from transformers.integrations.deepspeed import unset_hf_deepspeed_config, set_hf_deepspeed_config
from ovis.train.callback import TuneTauCallback, MonitorCallback
from ovis.model.configuration_ovis import OvisConfig
from ovis.model.modeling_ovis import Ovis
from ovis.train.arguments import ModelArguments, TrainingArguments
from ovis.train.dataset.caption_dataset import CaptionDataset
from ovis.train.dataset.conversation_dataset import ConversationDataset
from ovis.train.dataset.multimodal_dataset import DataCollatorForMultimodalDataset
from ovis.util.constants import BEGIN_LINE, END_LINE
from ovis.util.utils import smart_unit, rank0_print


def train():
    # parse args
    parser = transformers.HfArgumentParser(
        (ModelArguments, TrainingArguments))
    model_args, training_args = parser.parse_args_into_dataclasses()

    # save args to checkpoint dir
    with training_args.main_process_first(local=False):
        if training_args.process_index == 0:
            def args2dict(args):
                return {k: str(v) for k, v in args.__dict__.items()}

            args_log = json.dumps(dict(
                model_args=args2dict(model_args),
                training_args=args2dict(training_args)
            ), ensure_ascii=False, indent=2)
            print(args_log)
            os.makedirs(training_args.output_dir, exist_ok=True)
            with open(os.path.join(training_args.output_dir, 'model_training_args.json'), 'w',
                      encoding='utf-8') as f:
                f.write(args_log + '\n')
    # construct or load ovis model
    if not training_args.ovis_pretrained_path:  # construct model (S1)
        # 1. construct ovis config
        ovis_config = OvisConfig(
            multimodal_max_length=model_args.multimodal_max_length,
            conversation_formatter_class=model_args.conversation_formatter_class,
            llm_attn_implementation=model_args.llm_attn_implementation
        )
        # 2. load pretrained llm and text tokenizer
        attn_kwargs = dict()
        if model_args.llm_attn_implementation:
            attn_kwargs['attn_implementation'] = model_args.llm_attn_implementation
        llm = AutoModelForCausalLM.from_pretrained(model_args.llm_name_or_path, **attn_kwargs)
        text_tokenizer = AutoTokenizer.from_pretrained(model_args.llm_name_or_path)
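        # some base-LLM tokenizers ship without a pad token; fall back to the id
        # given in the model args so that batch padding works during training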
        if text_tokenizer.pad_token_id is None and model_args.pad_token_id is not None:
            text_tokenizer.pad_token_id = model_args.pad_token_id
        # 3. construct visual tokenizer
        # deepspeed zero.Init with bfloat16 fails for visual_tokenizer, so temporarily disable zero.Init here
        unset_hf_deepspeed_config()
        if training_args.visual_tokenizer_pretrained_path is not None:
            visual_tokenizer = AutoModel.from_pretrained(
                training_args.visual_tokenizer_pretrained_path,
                image_processor_name_or_path=training_args.visual_tokenizer_pretrained_path
            )
        else:
            visual_tokenizer_config = AutoConfig.for_model(
                model_type=model_args.visual_tokenizer_type + "_visual_tokenizer",
                vocab_size=model_args.visual_vocab_size,
                tokenize_function=model_args.visual_tokenize_function,
                tau=model_args.visual_tau,
                depths=model_args.visual_depths,
                drop_cls_token=model_args.visual_drop_cls_token,
                hidden_stride=model_args.visual_hidden_stride,
            )
            visual_tokenizer = AutoModel.from_config(visual_tokenizer_config, train_from_scratch=True)
        visual_tokenizer = visual_tokenizer.to(
            device=torch.device(get_accelerator().device_name(os.getenv("LOCAL_RANK"))))
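        # with zero.Init disabled above, the tokenizer's weights live on CPU;
        # move them onto this rank's accelerator device explicitly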
        if getattr(training_args, 'hf_deepspeed_config', None) is not None:
            set_hf_deepspeed_config(training_args.hf_deepspeed_config)
        # 4. construct ovis model
        model = Ovis(ovis_config, llm=llm, text_tokenizer=text_tokenizer, visual_tokenizer=visual_tokenizer,
                     train_from_scratch=True)
    else:  # load pretrained ovis model
        model, loading_info = Ovis.from_pretrained(training_args.ovis_pretrained_path,
                                                   multimodal_max_length=model_args.multimodal_max_length,
                                                   output_loading_info=True)
        rank0_print(BEGIN_LINE)
        rank0_print(f'Loading info of Ovis:\n{loading_info}')
        rank0_print(END_LINE)
        training_args.vte_re_init = False
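    # the KV cache is only useful at generation time; keep it off during training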
    model.get_llm().config.use_cache = False
    model.config.use_cache = False
    text_tokenizer = model.get_text_tokenizer()
    rank0_print(BEGIN_LINE)
    rank0_print(f'model.config:\n{model.config}')
    rank0_print(END_LINE)

    # maybe re-init vte
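    # (vte is the visual token embedding table; re-initialization draws it to
    # match the mean/std of the LLM's word embedding table, gathered here in
    # case it is ZeRO-partitioned)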
    if training_args.vte_re_init:
        with deepspeed.zero.GatheredParameters([model.get_wte().weight]):
            mean = model.get_wte().weight.mean().item()
            std = model.get_wte().weight.std().item()
        rank0_print(f'Statistics of embedding table of LLM: {mean=}, {std=}')
        model.re_init_vte(mean, std)

    # select train modules
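    # freeze everything first, then re-enable gradients only for the modules
    # listed in `training_args.train_modules` (pipe-separated)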
    model.requires_grad_(False)
    for module in training_args.train_modules.split('|'):
        if module == 'all':
            model.requires_grad_(True)
        elif module == 'llm':
            model.get_llm().requires_grad_(True)
        elif module == 'visual_tokenizer':
            model.get_visual_tokenizer().requires_grad_(True)
        elif module == 'visual_tokenizer.backbone':
            model.get_visual_tokenizer().get_backbone().requires_grad_(True)
        elif module.startswith('visual_tokenizer.backbone.layer.'):
            layer_index = int(module[len('visual_tokenizer.backbone.layer.'):])
            layer = model.get_visual_tokenizer().get_backbone_layer(layer_index)
            layer.requires_grad_(True)
        elif module == 'visual_tokenizer.head':
            model.get_visual_tokenizer().get_head().requires_grad_(True)
        elif module == 'vte':
            model.get_vte().requires_grad_(True)
        else:
            raise ValueError(f'Invalid train module name: {module}')
    rank0_print(BEGIN_LINE)
    rank0_print('Parameters to train:')
    for name, param in model.named_parameters():
        if param.requires_grad:
            rank0_print(name)
    rank0_print(f"LLM's attn implementation: {model.get_llm().config._attn_implementation}")
    rank0_print(END_LINE)
    # construct data module
    datasets = []
    dataset_info_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                     f'dataset/{training_args.dataset_info}.json')
    with open(dataset_info_path, 'r', encoding='utf-8') as f:
        dataset_info = json.load(f)
    for name in training_args.dataset_names.split('|'):
        info = dataset_info[name]
        data_format = info['data_format']
        if data_format == 'caption':
            dataset = CaptionDataset(name, info, model, training_args)
        elif data_format == 'conversation':
            dataset = ConversationDataset(name, info, model, training_args)
        else:
            raise ValueError(f'Invalid data format `{data_format}` for dataset `{name}`')
        datasets.append(dataset)
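    # all selected datasets are concatenated into a single training set and share
    # one collator that batches the text and visual inputs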
    data_module = dict(
        train_dataset=ConcatDataset(datasets),
        data_collator=DataCollatorForMultimodalDataset(text_tokenizer)
    )
    # train
    train_callbacks = [MonitorCallback]
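    # when visual tokenization uses gumbel_argmax, TuneTauCallback presumably
    # anneals the Gumbel temperature `tau` over the course of training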
    if model_args.visual_tokenize_function == 'gumbel_argmax':
        train_callbacks.append(TuneTauCallback)
    trainer = Trainer(
        model=model,
        args=training_args,
        callbacks=train_callbacks,
        **data_module
    )
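    # sanity-check logging: dump one raw sample and its decoded input_ids/labels;
    # negative ids (image placeholders / ignored label positions) are skipped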
    rank0_print(BEGIN_LINE)
    rank0_print('Dataset sample tensor:')
    rank0_print(data_module['train_dataset'][0])
    rank0_print(END_LINE)
    rank0_print(BEGIN_LINE)
    rank0_print('Dataset sample input_ids decoding:')
    rank0_print(text_tokenizer.decode([x for x in data_module['train_dataset'][0]['input_ids'] if x >= 0]))
    rank0_print(END_LINE)
    rank0_print(BEGIN_LINE)
    rank0_print('Dataset sample labels decoding:')
    rank0_print(text_tokenizer.decode([x for x in data_module['train_dataset'][0]['labels'] if x >= 0]))
    rank0_print(END_LINE)
    rank0_print(BEGIN_LINE)
    rank0_print(f'#param of model: {smart_unit(model.num_parameters())}')
    rank0_print(f'#param of llm: {smart_unit(model.get_llm().num_parameters())}')
    rank0_print(f'#param of visual_tokenizer: {smart_unit(model.get_visual_tokenizer().num_parameters())}')
    rank0_print(f'#param of vte: {smart_unit(model.get_vte().weight.numel())}')
    rank0_print(END_LINE)
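    # auto-resume from the latest checkpoint if the output dir already has any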
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()

    # save model
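    # re-enable the KV cache so the saved model is ready for generation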
    model.get_llm().config.use_cache = True
    model.config.use_cache = True
    trainer.save_model()


if __name__ == '__main__':
    train()