lisa-on-cuda / train_ds.py
Xin Lai
Fix initialize_lisa_module during validation in train_ds.py
95dfca2
raw
history blame
21.4 kB
import argparse
import os
import shutil
import sys
import time
from functools import partial
import deepspeed
import numpy as np
import torch
import tqdm
import transformers
from peft import LoraConfig, get_peft_model
from torch.utils.tensorboard import SummaryWriter
from model.LISA import LISAForCausalLM
from model.llava import conversation as conversation_lib
from utils.dataset import HybridDataset, ValDataset, collate_fn
from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
AverageMeter, ProgressMeter, Summary, dict_to_cuda,
intersectionAndUnionGPU)
def parse_args(args):
    """Parse the command-line options of the LISA training script.

    Args:
        args: Sequence of argument strings, typically ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: One attribute per option declared below.
    """
    parser = argparse.ArgumentParser(description="LISA Model Training")

    def add_default_on_flag(name):
        # A ``store_true`` flag whose default is already True can never be
        # switched off from the command line. Keep the original flag (so
        # existing launch scripts still parse) and pair it with an explicit
        # ``--no_<name>`` switch that writes False to the same destination.
        parser.add_argument(
            "--" + name, dest=name, action="store_true", default=True
        )
        parser.add_argument(
            "--no_" + name,
            dest=name,
            action="store_false",
            help="disable --" + name,
        )

    parser.add_argument("--local_rank", default=0, type=int, help="node rank")
    parser.add_argument(
        "--version", default="liuhaotian/llava-llama-2-13b-chat-lightning-preview"
    )
    parser.add_argument("--vis_save_path", default="./vis_output", type=str)
    parser.add_argument(
        "--precision",
        default="bf16",
        type=str,
        choices=["fp32", "bf16", "fp16"],
        help="precision for inference",
    )
    parser.add_argument("--image_size", default=1024, type=int, help="image size")
    parser.add_argument("--model_max_length", default=512, type=int)
    parser.add_argument("--lora_r", default=8, type=int)
    # ``--vision-tower`` is the historical spelling; the underscore alias
    # matches every other option. Both populate ``args.vision_tower``.
    parser.add_argument(
        "--vision-tower",
        "--vision_tower",
        default="openai/clip-vit-large-patch14",
        type=str,
    )
    parser.add_argument("--load_in_8bit", action="store_true", default=False)
    parser.add_argument("--load_in_4bit", action="store_true", default=False)

    # Dataset selection; ``||`` separates entries inside one option.
    parser.add_argument(
        "--dataset", default="sem_seg||refer_seg||vqa||reason_seg", type=str
    )
    parser.add_argument("--sample_rates", default="9,3,3,1", type=str)
    parser.add_argument(
        "--sem_seg_data",
        default="ade20k||cocostuff||pascal_part||paco_lvis||mapillary",
        type=str,
    )
    parser.add_argument(
        "--refer_seg_data", default="refclef||refcoco||refcoco+||refcocog", type=str
    )
    parser.add_argument("--vqa_data", default="llava_instruct_150k", type=str)
    parser.add_argument("--reason_seg_data", default="ReasonSeg|train", type=str)
    parser.add_argument("--val_dataset", default="ReasonSeg|val", type=str)
    parser.add_argument("--dataset_dir", default="./dataset", type=str)

    # Logging and schedule.
    parser.add_argument("--log_base_dir", default="./runs", type=str)
    parser.add_argument("--exp_name", default="lisa", type=str)
    parser.add_argument("--epochs", default=10, type=int)
    parser.add_argument("--steps_per_epoch", default=500, type=int)
    parser.add_argument(
        "--batch_size", default=2, type=int, help="batch size per device per step"
    )
    parser.add_argument(
        "--grad_accumulation_steps",
        default=10,
        type=int,
    )
    parser.add_argument("--val_batch_size", default=1, type=int)
    parser.add_argument("--workers", default=4, type=int)

    # Optimisation, loss weights and LoRA hyper-parameters.
    parser.add_argument("--lr", default=0.0003, type=float)
    parser.add_argument("--ce_loss_weight", default=1.0, type=float)
    parser.add_argument("--dice_loss_weight", default=0.5, type=float)
    parser.add_argument("--bce_loss_weight", default=2.0, type=float)
    parser.add_argument("--lora_alpha", default=16, type=int)
    parser.add_argument("--lora_dropout", default=0.05, type=float)
    parser.add_argument("--lora_target_modules", default="q_proj,v_proj", type=str)
    parser.add_argument("--explanatory", default=0.1, type=float)
    parser.add_argument("--beta1", default=0.9, type=float)
    parser.add_argument("--beta2", default=0.95, type=float)
    parser.add_argument("--num_classes_per_sample", default=3, type=int)

    # Run-mode switches.
    parser.add_argument("--exclude_val", action="store_true", default=False)
    parser.add_argument("--no_eval", action="store_true", default=False)
    parser.add_argument("--eval_only", action="store_true", default=False)
    parser.add_argument("--vision_pretrained", default="PATH_TO_SAM_ViT-H", type=str)
    parser.add_argument("--out_dim", default=256, type=int)
    parser.add_argument("--resume", default="", type=str)
    parser.add_argument("--print_freq", default=1, type=int)
    parser.add_argument("--start_epoch", default=0, type=int)
    add_default_on_flag("gradient_checkpointing")
    add_default_on_flag("train_mask_decoder")
    add_default_on_flag("use_mm_start_end")
    add_default_on_flag("auto_resume")
    parser.add_argument(
        "--conv_type",
        default="llava_v1",
        type=str,
        choices=["llava_v1", "llava_llama_2"],
    )
    return parser.parse_args(args)
def main(args):
    """Train and/or evaluate LISA from a list of command-line arguments.

    The function parses ``args``, builds the tokenizer and the
    ``LISAForCausalLM`` model (wrapped with LoRA adapters when
    ``--lora_r > 0``), creates the training and validation datasets, hands
    them to ``deepspeed.initialize``, optionally resumes from a checkpoint
    and then runs the epoch loop. A checkpoint is written after every epoch
    when validation is disabled, otherwise only when the validation gIoU
    improves on the best value seen so far in this process.

    Args:
        args: Sequence of CLI argument strings (e.g. ``sys.argv[1:]``).

    Raises:
        ValueError: If ``--eval_only`` and ``--no_eval`` are combined.
        SystemExit: After validation finishes in ``--eval_only`` mode.
    """
    args = parse_args(args)

    # Fail fast: evaluation-only mode needs the validation set that
    # ``--no_eval`` suppresses. Previously this combination loaded the whole
    # model and then crashed with a NameError on the unbound ``val_loader``.
    if args.eval_only and args.no_eval:
        raise ValueError("--eval_only cannot be combined with --no_eval")

    args.log_dir = os.path.join(args.log_base_dir, args.exp_name)
    if args.local_rank == 0:
        os.makedirs(args.log_dir, exist_ok=True)
        writer = SummaryWriter(args.log_dir)
    else:
        # Only rank 0 writes TensorBoard events.
        writer = None

    # Create model
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.version,
        cache_dir=None,
        model_max_length=args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.add_tokens("[SEG]")
    args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]

    if args.use_mm_start_end:
        tokenizer.add_tokens(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
        )

    model_args = {
        "train_mask_decoder": args.train_mask_decoder,
        "out_dim": args.out_dim,
        "ce_loss_weight": args.ce_loss_weight,
        "dice_loss_weight": args.dice_loss_weight,
        "bce_loss_weight": args.bce_loss_weight,
        "seg_token_idx": args.seg_token_idx,
        "vision_pretrained": args.vision_pretrained,
        "vision_tower": args.vision_tower,
        "use_mm_start_end": args.use_mm_start_end,
    }
    torch_dtype = torch.float32
    if args.precision == "bf16":
        torch_dtype = torch.bfloat16
    elif args.precision == "fp16":
        torch_dtype = torch.half
    model = LISAForCausalLM.from_pretrained(
        args.version, torch_dtype=torch_dtype, low_cpu_mem_usage=True, **model_args
    )
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    model.enable_input_require_grads()
    # Honour the CLI flag instead of enabling checkpointing unconditionally;
    # the flag defaults to True, so default behaviour is unchanged.
    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable()

    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch_dtype, device=args.local_rank)
    if not args.eval_only:
        model.get_model().initialize_lisa_modules(model.get_model().config)

    # The vision tower and the multimodal projector stay frozen.
    for p in vision_tower.parameters():
        p.requires_grad = False
    for p in model.get_model().mm_projector.parameters():
        p.requires_grad = False

    conversation_lib.default_conversation = conversation_lib.conv_templates[
        args.conv_type
    ]

    lora_r = args.lora_r
    if lora_r > 0:

        def find_linear_layers(model, lora_target_modules):
            """Return sorted names of the nn.Linear modules LoRA should wrap."""
            excluded = (
                "visual_model",
                "vision_tower",
                "mm_projector",
                "text_hidden_fcs",
            )
            lora_module_names = set()
            for name, module in model.named_modules():
                if (
                    isinstance(module, torch.nn.Linear)
                    and all(x not in name for x in excluded)
                    and any(x in name for x in lora_target_modules)
                ):
                    lora_module_names.add(name)
            return sorted(lora_module_names)

        lora_target_modules = find_linear_layers(
            model, args.lora_target_modules.split(",")
        )
        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=args.lora_alpha,
            target_modules=lora_target_modules,
            lora_dropout=args.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    model.resize_token_embeddings(len(tokenizer))

    # make text_hidden_fcs, mask_decoder, lm_head, embed_tokens trainable
    trainable_keys = ("lm_head", "embed_tokens", "mask_decoder", "text_hidden_fcs")
    for n, p in model.named_parameters():
        if any(x in n for x in trainable_keys):
            print("n: ", n, "p.shape: ", p.shape)
            p.requires_grad = True

    # NOTE(review): device_count() is the number of GPUs visible to this
    # process, not the distributed world size - confirm for multi-node runs.
    world_size = torch.cuda.device_count()
    args.distributed = world_size > 1
    train_dataset = HybridDataset(
        args.dataset_dir,
        tokenizer,
        args.vision_tower,
        samples_per_epoch=args.batch_size
        * args.grad_accumulation_steps
        * args.steps_per_epoch
        * world_size,
        precision=args.precision,
        image_size=args.image_size,
        num_classes_per_sample=args.num_classes_per_sample,
        exclude_val=args.exclude_val,
        dataset=args.dataset,
        sample_rate=[float(x) for x in args.sample_rates.split(",")],
        sem_seg_data=args.sem_seg_data,
        refer_seg_data=args.refer_seg_data,
        vqa_data=args.vqa_data,
        reason_seg_data=args.reason_seg_data,
        explanatory=args.explanatory,
    )

    if not args.no_eval:
        val_dataset = ValDataset(
            args.dataset_dir,
            tokenizer,
            args.vision_tower,
            args.val_dataset,
            args.image_size,
        )
        print(
            f"Training with {len(train_dataset)} examples and validating with {len(val_dataset)} examples."
        )
    else:
        val_dataset = None
        print(f"Training with {len(train_dataset)} examples.")

    ds_config = {
        "train_micro_batch_size_per_gpu": args.batch_size,
        "gradient_accumulation_steps": args.grad_accumulation_steps,
        "optimizer": {
            "type": "AdamW",
            "params": {
                "lr": args.lr,
                "weight_decay": 0.0,
                "betas": (args.beta1, args.beta2),
            },
        },
        "scheduler": {
            "type": "WarmupDecayLR",
            "params": {
                "total_num_steps": args.epochs * args.steps_per_epoch,
                "warmup_min_lr": 0,
                "warmup_max_lr": args.lr,
                "warmup_num_steps": 100,
                "warmup_type": "linear",
            },
        },
        "fp16": {
            "enabled": args.precision == "fp16",
        },
        "bf16": {
            "enabled": args.precision == "bf16",
        },
        "gradient_clipping": 1.0,
        "zero_optimization": {
            "stage": 2,
            "contiguous_gradients": True,
            "overlap_comm": True,
            "reduce_scatter": True,
            "reduce_bucket_size": 5e8,
            "allgather_bucket_size": 5e8,
        },
    }
    model_engine, _optimizer, train_loader, scheduler = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        training_data=train_dataset,
        collate_fn=partial(
            collate_fn,
            tokenizer=tokenizer,
            conv_type=args.conv_type,
            use_mm_start_end=args.use_mm_start_end,
            local_rank=args.local_rank,
        ),
        config=ds_config,
    )

    # resume deepspeed checkpoint
    if args.auto_resume and len(args.resume) == 0:
        resume = os.path.join(args.log_dir, "ckpt_model")
        if os.path.exists(resume):
            args.resume = resume

    if args.resume:
        model_engine.load_checkpoint(args.resume)
        # The "latest" file holds the checkpoint tag ("global_step<N>");
        # the epoch to resume from is derived from that step count.
        with open(os.path.join(args.resume, "latest"), "r") as f:
            ckpt_dir = f.readline().strip()
        args.start_epoch = (
            int(ckpt_dir.replace("global_step", "")) // args.steps_per_epoch
        )
        print(
            "resume training from {}, start from epoch {}".format(
                args.resume, args.start_epoch
            )
        )

    # validation dataset
    val_loader = None
    if val_dataset is not None:
        assert args.val_batch_size == 1
        val_sampler = torch.utils.data.distributed.DistributedSampler(
            val_dataset, shuffle=False, drop_last=False
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=args.val_batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=False,
            sampler=val_sampler,
            collate_fn=partial(
                collate_fn,
                tokenizer=tokenizer,
                conv_type=args.conv_type,
                use_mm_start_end=args.use_mm_start_end,
                local_rank=args.local_rank,
            ),
        )

    train_iter = iter(train_loader)
    best_score, cur_ciou = 0.0, 0.0

    if args.eval_only:
        validate(val_loader, model_engine, 0, writer, args)
        # Close the writer so the validation scalars are flushed to disk
        # before the process terminates.
        if writer is not None:
            writer.close()
        sys.exit()

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train_iter = train(
            train_loader,
            model_engine,
            epoch,
            scheduler,
            writer,
            train_iter,
            args,
        )

        if args.no_eval:
            # No validation score to compare against: checkpoint every epoch.
            should_save = True
        else:
            giou, ciou = validate(val_loader, model_engine, epoch, writer, args)
            is_best = giou > best_score
            best_score = max(giou, best_score)
            cur_ciou = ciou if is_best else cur_ciou
            should_save = is_best

        if should_save:
            save_dir = os.path.join(args.log_dir, "ckpt_model")
            if args.local_rank == 0:
                torch.save(
                    {"epoch": epoch},
                    os.path.join(
                        args.log_dir,
                        "meta_log_giou{:.3f}_ciou{:.3f}.pth".format(
                            best_score, cur_ciou
                        ),
                    ),
                )
                # Only one checkpoint directory is kept; drop the old one.
                if os.path.exists(save_dir):
                    shutil.rmtree(save_dir)
            # Every rank waits until rank 0 has removed the old directory
            # before the collective save starts writing into it.
            torch.distributed.barrier()
            model_engine.save_checkpoint(save_dir)

    if writer is not None:
        writer.close()
def train(
    train_loader,
    model,
    epoch,
    scheduler,
    writer,
    train_iter,
    args,
):
    """Run ``args.steps_per_epoch`` optimisation steps and log the losses.

    Each step consumes ``args.grad_accumulation_steps`` micro-batches; every
    micro-batch is followed by ``model.backward`` and ``model.step``.

    Args:
        train_loader: Iterable used to rebuild ``train_iter`` once exhausted.
        model: Engine exposing ``train()``, ``backward()`` and ``step()`` and
            returning a dict with ``loss``, ``ce_loss``, ``mask_bce_loss``,
            ``mask_dice_loss`` and ``mask_loss`` when called.
        epoch: Index of the current epoch (used for display and log steps).
        scheduler: LR scheduler exposing ``get_last_lr()``.
        writer: TensorBoard writer; only used when ``args.local_rank == 0``.
        train_iter: Iterator over ``train_loader`` carried across epochs.
        args: Parsed CLI namespace.

    Returns:
        The iterator to pass to the next call (it may have been re-created).
    """
    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":.4f")
    ce_losses = AverageMeter("CeLoss", ":.4f")
    mask_bce_losses = AverageMeter("MaskBCELoss", ":.4f")
    mask_dice_losses = AverageMeter("MaskDICELoss", ":.4f")
    mask_losses = AverageMeter("MaskLoss", ":.4f")

    progress = ProgressMeter(
        args.steps_per_epoch,
        [
            batch_time,
            losses,
            ce_losses,
            mask_losses,
            mask_bce_losses,
            mask_dice_losses,
        ],
        prefix="Epoch: [{}]".format(epoch),
    )
    all_meters = (
        batch_time,
        data_time,
        losses,
        ce_losses,
        mask_bce_losses,
        mask_dice_losses,
        mask_losses,
    )

    # Image tensors are cast to the dtype selected by --precision.
    image_dtype = {"fp16": torch.half, "bf16": torch.bfloat16}.get(
        args.precision, torch.float32
    )

    # switch to train mode
    model.train()
    end = time.time()
    for global_step in range(args.steps_per_epoch):
        for _ in range(args.grad_accumulation_steps):
            try:
                input_dict = next(train_iter)
            except StopIteration:
                # Loader exhausted: start a fresh pass. Only StopIteration
                # is handled so real data-loading errors and Ctrl-C surface
                # instead of being silently retried.
                train_iter = iter(train_loader)
                input_dict = next(train_iter)

            data_time.update(time.time() - end)
            input_dict = dict_to_cuda(input_dict)
            input_dict["images"] = input_dict["images"].to(image_dtype)
            input_dict["images_clip"] = input_dict["images_clip"].to(image_dtype)

            output_dict = model(**input_dict)

            loss = output_dict["loss"]
            batch_size = input_dict["images"].size(0)
            losses.update(loss.item(), batch_size)
            ce_losses.update(output_dict["ce_loss"].item(), batch_size)
            mask_bce_losses.update(output_dict["mask_bce_loss"].item(), batch_size)
            mask_dice_losses.update(
                output_dict["mask_dice_loss"].item(), batch_size
            )
            mask_losses.update(output_dict["mask_loss"].item(), batch_size)
            model.backward(loss)
            model.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # Run-wide step for TensorBoard. Logging against the within-epoch
        # index made every epoch overwrite the same x-range.
        log_step = epoch * args.steps_per_epoch + global_step

        if global_step % args.print_freq == 0:
            if args.distributed:
                for meter in all_meters:
                    meter.all_reduce()

            if args.local_rank == 0:
                progress.display(global_step + 1)
                writer.add_scalar("train/loss", losses.avg, log_step)
                writer.add_scalar("train/ce_loss", ce_losses.avg, log_step)
                writer.add_scalar(
                    "train/mask_bce_loss", mask_bce_losses.avg, log_step
                )
                writer.add_scalar(
                    "train/mask_dice_loss", mask_dice_losses.avg, log_step
                )
                writer.add_scalar("train/mask_loss", mask_losses.avg, log_step)
                writer.add_scalar(
                    "metrics/total_secs_per_batch", batch_time.avg, log_step
                )
                writer.add_scalar(
                    "metrics/data_secs_per_batch", data_time.avg, log_step
                )

            for meter in all_meters:
                meter.reset()

        if global_step != 0:
            curr_lr = scheduler.get_last_lr()
            if args.local_rank == 0:
                writer.add_scalar("train/lr", curr_lr[0], log_step)

    return train_iter
def validate(val_loader, model_engine, epoch, writer, args):
    """Evaluate ``model_engine`` on ``val_loader`` and return ``(giou, ciou)``.

    For every batch the predicted masks are thresholded at 0 and compared
    with the ground-truth masks through ``intersectionAndUnionGPU``. The
    per-batch intersection, union and mean per-mask IoU are accumulated in
    meters, reduced across ranks, and logged under ``val/giou`` and
    ``val/ciou`` by rank 0.

    Args:
        val_loader: Iterable of validation batches (one sample per batch).
        model_engine: Model/engine; called with the batch dict as kwargs.
        epoch: Step index used for the TensorBoard scalars.
        writer: TensorBoard writer; only used when ``args.local_rank == 0``.
        args: Parsed CLI namespace (``precision`` and ``local_rank`` are read).

    Returns:
        Tuple ``(giou, ciou)``; both are taken at index 1 of the two-class
        statistics (presumably the foreground class - TODO confirm).
    """
    intersection_meter = AverageMeter("Intersec", ":6.3f", Summary.SUM)
    union_meter = AverageMeter("Union", ":6.3f", Summary.SUM)
    acc_iou_meter = AverageMeter("gIoU", ":6.3f", Summary.SUM)

    model_engine.eval()

    for batch in tqdm.tqdm(val_loader):
        torch.cuda.empty_cache()

        batch = dict_to_cuda(batch)
        # Cast both image tensors to the dtype selected by --precision.
        for key in ("images", "images_clip"):
            if args.precision == "fp16":
                batch[key] = batch[key].half()
            elif args.precision == "bf16":
                batch[key] = batch[key].bfloat16()
            else:
                batch[key] = batch[key].float()

        with torch.no_grad():
            outputs = model_engine(**batch)

        pred_masks = outputs["pred_masks"]
        gt_masks = outputs["gt_masks"][0].int()
        binary_preds = (pred_masks[0] > 0).int()
        assert len(pred_masks) == 1

        intersection, union, acc_iou = 0.0, 0.0, 0.0
        for gt_mask, pred_mask in zip(gt_masks, binary_preds):
            inter_i, union_i, _ = intersectionAndUnionGPU(
                pred_mask.contiguous().clone(),
                gt_mask.contiguous(),
                2,
                ignore_index=255,
            )
            intersection += inter_i
            union += union_i
            acc_iou += inter_i / (union_i + 1e-5)
            # no-object target: an empty union is credited with IoU 1.0
            acc_iou[union_i == 0] += 1.0

        num_masks = gt_masks.shape[0]
        intersection_np = intersection.cpu().numpy()
        union_np = union.cpu().numpy()
        mean_iou_np = acc_iou.cpu().numpy() / num_masks

        intersection_meter.update(intersection_np)
        union_meter.update(union_np)
        acc_iou_meter.update(mean_iou_np, n=num_masks)

    for meter in (intersection_meter, union_meter, acc_iou_meter):
        meter.all_reduce()

    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    ciou = iou_class[1]
    giou = acc_iou_meter.avg[1]

    if args.local_rank == 0:
        writer.add_scalar("val/giou", giou, epoch)
        writer.add_scalar("val/ciou", ciou, epoch)
        print("giou: {:.4f}, ciou: {:.4f}".format(giou, ciou))

    return giou, ciou
# Script entry point: forward the CLI arguments (program name dropped) to main().
if __name__ == "__main__":
    cli_args = sys.argv[1:]
    main(cli_args)