# NOTE: "Spaces: Paused" -- web page header residue captured during extraction; not part of the script.
import argparse | |
import os | |
import shutil | |
import sys | |
import time | |
from functools import partial | |
import deepspeed | |
import numpy as np | |
import torch | |
import tqdm | |
import transformers | |
from peft import LoraConfig, get_peft_model | |
from torch.utils.tensorboard import SummaryWriter | |
from model.LISA import LISAForCausalLM | |
from model.llava import conversation as conversation_lib | |
from utils.dataset import HybridDataset, ValDataset, collate_fn | |
from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, | |
AverageMeter, ProgressMeter, Summary, dict_to_cuda, | |
intersectionAndUnionGPU) | |
def parse_args(args):
    """Parse the command-line options of the LISA training script.

    Args:
        args: Sequence of argument strings (typically ``sys.argv[1:]``).

    Returns:
        argparse.Namespace with one attribute per option defined below.
        ``--vision-tower`` is exposed as ``vision_tower`` (argparse turns the
        hyphen into an underscore).
    """
    parser = argparse.ArgumentParser(description="LISA Model Training")

    # --- runtime / model ---------------------------------------------------
    parser.add_argument("--local_rank", default=0, type=int, help="node rank")
    parser.add_argument(
        "--version", default="liuhaotian/llava-llama-2-13b-chat-lightning-preview"
    )
    parser.add_argument("--vis_save_path", default="./vis_output", type=str)
    parser.add_argument(
        "--precision",
        default="bf16",
        type=str,
        choices=["fp32", "bf16", "fp16"],
        help="precision for inference",
    )
    parser.add_argument("--image_size", default=1024, type=int, help="image size")
    parser.add_argument("--model_max_length", default=512, type=int)
    parser.add_argument("--lora_r", default=8, type=int)
    parser.add_argument(
        "--vision-tower", default="openai/clip-vit-large-patch14", type=str
    )
    parser.add_argument("--load_in_8bit", action="store_true", default=False)
    parser.add_argument("--load_in_4bit", action="store_true", default=False)

    # --- datasets ----------------------------------------------------------
    parser.add_argument(
        "--dataset", default="sem_seg||refer_seg||vqa||reason_seg", type=str
    )
    parser.add_argument("--sample_rates", default="9,3,3,1", type=str)
    parser.add_argument(
        "--sem_seg_data",
        default="ade20k||cocostuff||pascal_part||paco_lvis||mapillary",
        type=str,
    )
    parser.add_argument(
        "--refer_seg_data", default="refclef||refcoco||refcoco+||refcocog", type=str
    )
    parser.add_argument("--vqa_data", default="llava_instruct_150k", type=str)
    parser.add_argument("--reason_seg_data", default="ReasonSeg|train", type=str)
    parser.add_argument("--val_dataset", default="ReasonSeg|val", type=str)
    parser.add_argument("--dataset_dir", default="./dataset", type=str)

    # --- logging / schedule ------------------------------------------------
    parser.add_argument("--log_base_dir", default="./runs", type=str)
    parser.add_argument("--exp_name", default="lisa", type=str)
    parser.add_argument("--epochs", default=10, type=int)
    parser.add_argument("--steps_per_epoch", default=500, type=int)
    parser.add_argument(
        "--batch_size", default=2, type=int, help="batch size per device per step"
    )
    parser.add_argument(
        "--grad_accumulation_steps",
        default=10,
        type=int,
    )
    parser.add_argument("--val_batch_size", default=1, type=int)
    parser.add_argument("--workers", default=4, type=int)

    # --- optimisation / losses / LoRA --------------------------------------
    parser.add_argument("--lr", default=0.0003, type=float)
    parser.add_argument("--ce_loss_weight", default=1.0, type=float)
    parser.add_argument("--dice_loss_weight", default=0.5, type=float)
    parser.add_argument("--bce_loss_weight", default=2.0, type=float)
    parser.add_argument("--lora_alpha", default=16, type=int)
    parser.add_argument("--lora_dropout", default=0.05, type=float)
    parser.add_argument("--lora_target_modules", default="q_proj,v_proj", type=str)
    parser.add_argument("--explanatory", default=0.1, type=float)
    parser.add_argument("--beta1", default=0.9, type=float)
    parser.add_argument("--beta2", default=0.95, type=float)
    parser.add_argument("--num_classes_per_sample", default=3, type=int)
    parser.add_argument("--exclude_val", action="store_true", default=False)
    parser.add_argument("--no_eval", action="store_true", default=False)
    parser.add_argument("--eval_only", action="store_true", default=False)
    parser.add_argument("--vision_pretrained", default="PATH_TO_SAM_ViT-H", type=str)
    parser.add_argument("--out_dim", default=256, type=int)
    parser.add_argument("--resume", default="", type=str)
    parser.add_argument("--print_freq", default=1, type=int)
    parser.add_argument("--start_epoch", default=0, type=int)

    # These switches default to True, so the positive flag alone can never
    # turn them off.  Each one keeps its original spelling and additionally
    # gets a ``--no_<name>`` counterpart that stores False in the same
    # destination.
    default_on_switches = (
        "gradient_checkpointing",
        "train_mask_decoder",
        "use_mm_start_end",
        "auto_resume",
    )
    for name in default_on_switches:
        parser.add_argument(f"--{name}", action="store_true", default=True)
        parser.add_argument(
            f"--no_{name}",
            dest=name,
            action="store_false",
            help=f"disable {name} (enabled by default)",
        )

    parser.add_argument(
        "--conv_type",
        default="llava_v1",
        type=str,
        choices=["llava_v1", "llava_llama_2"],
    )
    return parser.parse_args(args)
def _find_lora_target_modules(model, lora_target_modules):
    """Return the sorted names of ``nn.Linear`` modules that LoRA should wrap.

    A module qualifies when its qualified name contains one of
    ``lora_target_modules`` and none of the excluded sub-model markers.
    """
    excluded = ("visual_model", "vision_tower", "mm_projector", "text_hidden_fcs")
    return sorted(
        {
            name
            for name, module in model.named_modules()
            if isinstance(module, torch.nn.Linear)
            and not any(marker in name for marker in excluded)
            and any(target in name for target in lora_target_modules)
        }
    )


def _build_ds_config(args):
    """Build the DeepSpeed configuration dict from the parsed options."""
    return {
        "train_micro_batch_size_per_gpu": args.batch_size,
        "gradient_accumulation_steps": args.grad_accumulation_steps,
        "optimizer": {
            "type": "AdamW",
            "params": {
                "lr": args.lr,
                "weight_decay": 0.0,
                "betas": (args.beta1, args.beta2),
            },
        },
        "scheduler": {
            "type": "WarmupDecayLR",
            "params": {
                "total_num_steps": args.epochs * args.steps_per_epoch,
                "warmup_min_lr": 0,
                "warmup_max_lr": args.lr,
                "warmup_num_steps": 100,
                "warmup_type": "linear",
            },
        },
        "fp16": {
            "enabled": args.precision == "fp16",
        },
        "bf16": {
            "enabled": args.precision == "bf16",
        },
        "gradient_clipping": 1.0,
        "zero_optimization": {
            "stage": 2,
            "contiguous_gradients": True,
            "overlap_comm": True,
            "reduce_scatter": True,
            "reduce_bucket_size": 5e8,
            "allgather_bucket_size": 5e8,
        },
    }


def _make_collate_fn(tokenizer, args):
    """Bind the tokenizer and conversation options into ``collate_fn``."""
    return partial(
        collate_fn,
        tokenizer=tokenizer,
        conv_type=args.conv_type,
        use_mm_start_end=args.use_mm_start_end,
        local_rank=args.local_rank,
    )


def main(args):
    """Train and/or evaluate LISA with DeepSpeed.

    Steps: parse options, build tokenizer and model, optionally wrap the
    model with LoRA, build the datasets, initialise DeepSpeed, optionally
    resume from ``<log_dir>/ckpt_model``, then either run one validation
    pass (``--eval_only``) or the epoch loop with checkpointing.

    Args:
        args: Sequence of command-line argument strings.

    Raises:
        ValueError: if ``--eval_only`` is combined with ``--no_eval`` or if
            ``--val_batch_size`` is not 1 while validation is enabled.
        SystemExit: after the validation pass in ``--eval_only`` mode.
    """
    args = parse_args(args)
    if args.eval_only and args.no_eval:
        # Without this check the script would load the whole model and then
        # crash on an undefined validation loader.
        raise ValueError("--eval_only cannot be combined with --no_eval")

    args.log_dir = os.path.join(args.log_base_dir, args.exp_name)
    if args.local_rank == 0:
        os.makedirs(args.log_dir, exist_ok=True)
        writer = SummaryWriter(args.log_dir)
    else:
        # Only rank 0 writes TensorBoard events.
        writer = None

    # Create model
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.version,
        cache_dir=None,
        model_max_length=args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.add_tokens("[SEG]")
    args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]

    if args.use_mm_start_end:
        tokenizer.add_tokens(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
        )

    model_args = {
        "train_mask_decoder": args.train_mask_decoder,
        "out_dim": args.out_dim,
        "ce_loss_weight": args.ce_loss_weight,
        "dice_loss_weight": args.dice_loss_weight,
        "bce_loss_weight": args.bce_loss_weight,
        "seg_token_idx": args.seg_token_idx,
        "vision_pretrained": args.vision_pretrained,
        "vision_tower": args.vision_tower,
        "use_mm_start_end": args.use_mm_start_end,
    }
    torch_dtype = {"bf16": torch.bfloat16, "fp16": torch.half}.get(
        args.precision, torch.float32
    )
    model = LISAForCausalLM.from_pretrained(
        args.version, torch_dtype=torch_dtype, low_cpu_mem_usage=True, **model_args
    )
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    model.enable_input_require_grads()
    # Honour the parsed switch; it defaults to True, matching the previous
    # unconditional call.
    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable()

    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch_dtype, device=args.local_rank)
    if not args.eval_only:
        model.get_model().initialize_lisa_modules(model.get_model().config)

    # The vision tower and the multimodal projector stay frozen.
    for p in vision_tower.parameters():
        p.requires_grad = False
    for p in model.get_model().mm_projector.parameters():
        p.requires_grad = False

    conversation_lib.default_conversation = conversation_lib.conv_templates[
        args.conv_type
    ]

    if args.lora_r > 0:
        lora_config = LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            target_modules=_find_lora_target_modules(
                model, args.lora_target_modules.split(",")
            ),
            lora_dropout=args.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    # The tokenizer gained new tokens above, so the embeddings must follow.
    model.resize_token_embeddings(len(tokenizer))

    # make text_hidden_fcs, mask_decoder, lm_head, embed_tokens trainable
    trainable_markers = ("lm_head", "embed_tokens", "mask_decoder", "text_hidden_fcs")
    for n, p in model.named_parameters():
        if any(marker in n for marker in trainable_markers):
            print("n: ", n, "p.shape: ", p.shape)
            p.requires_grad = True

    # NOTE(review): this is the number of GPUs visible to this process, which
    # equals the world size only for single-node runs -- confirm for multi-node.
    world_size = torch.cuda.device_count()
    args.distributed = world_size > 1
    train_dataset = HybridDataset(
        args.dataset_dir,
        tokenizer,
        args.vision_tower,
        samples_per_epoch=args.batch_size
        * args.grad_accumulation_steps
        * args.steps_per_epoch
        * world_size,
        precision=args.precision,
        image_size=args.image_size,
        num_classes_per_sample=args.num_classes_per_sample,
        exclude_val=args.exclude_val,
        dataset=args.dataset,
        sample_rate=[float(x) for x in args.sample_rates.split(",")],
        sem_seg_data=args.sem_seg_data,
        refer_seg_data=args.refer_seg_data,
        vqa_data=args.vqa_data,
        reason_seg_data=args.reason_seg_data,
        explanatory=args.explanatory,
    )

    if not args.no_eval:
        val_dataset = ValDataset(
            args.dataset_dir,
            tokenizer,
            args.vision_tower,
            args.val_dataset,
            args.image_size,
        )
        print(
            f"Training with {len(train_dataset)} examples and validating with {len(val_dataset)} examples."
        )
    else:
        val_dataset = None
        print(f"Training with {len(train_dataset)} examples.")

    batch_collate = _make_collate_fn(tokenizer, args)
    model_engine, _, train_loader, scheduler = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        training_data=train_dataset,
        collate_fn=batch_collate,
        config=_build_ds_config(args),
    )

    # resume deepspeed checkpoint
    if args.auto_resume and len(args.resume) == 0:
        resume = os.path.join(args.log_dir, "ckpt_model")
        if os.path.exists(resume):
            args.resume = resume

    if args.resume:
        model_engine.load_checkpoint(args.resume)
        # The "latest" file holds the tag of the newest checkpoint, of the
        # form "global_step<N>"; N determines the epoch to continue from.
        with open(os.path.join(args.resume, "latest"), "r") as f:
            ckpt_dir = f.readline().strip()
        args.start_epoch = (
            int(ckpt_dir.replace("global_step", "")) // args.steps_per_epoch
        )
        print(
            "resume training from {}, start from epoch {}".format(
                args.resume, args.start_epoch
            )
        )

    # validation dataset
    val_loader = None
    if val_dataset is not None:
        if args.val_batch_size != 1:
            # validate() only reads the first element of each batch.
            raise ValueError("--val_batch_size must be 1")
        val_sampler = torch.utils.data.distributed.DistributedSampler(
            val_dataset, shuffle=False, drop_last=False
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=args.val_batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=False,
            sampler=val_sampler,
            collate_fn=batch_collate,
        )

    train_iter = iter(train_loader)
    best_score, cur_ciou = 0.0, 0.0

    if args.eval_only:
        validate(val_loader, model_engine, 0, writer, args)
        sys.exit()

    is_best = False
    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train_iter = train(
            train_loader,
            model_engine,
            epoch,
            scheduler,
            writer,
            train_iter,
            args,
        )

        if not args.no_eval:
            giou, ciou = validate(val_loader, model_engine, epoch, writer, args)
            is_best = giou > best_score
            best_score = max(giou, best_score)
            cur_ciou = ciou if is_best else cur_ciou

        # Without validation every epoch is saved; otherwise only improvements.
        if args.no_eval or is_best:
            save_dir = os.path.join(args.log_dir, "ckpt_model")
            if args.local_rank == 0:
                torch.save(
                    {"epoch": epoch},
                    os.path.join(
                        args.log_dir,
                        "meta_log_giou{:.3f}_ciou{:.3f}.pth".format(
                            best_score, cur_ciou
                        ),
                    ),
                )
                if os.path.exists(save_dir):
                    shutil.rmtree(save_dir)
            # Every rank waits until rank 0 has removed the old checkpoint
            # before the new one is written.
            torch.distributed.barrier()
            model_engine.save_checkpoint(save_dir)
def train(
    train_loader,
    model,
    epoch,
    scheduler,
    writer,
    train_iter,
    args,
):
    """Run one training epoch of ``args.steps_per_epoch`` optimisation steps.

    Each step consumes ``args.grad_accumulation_steps`` micro-batches; for
    every micro-batch ``model.backward`` and ``model.step`` are called
    (``main`` passes the DeepSpeed engine as ``model``).

    Args:
        train_loader: Loader used to restart iteration once ``train_iter``
            is exhausted.
        model: Engine exposing ``train()``, ``backward()`` and ``step()``.
        epoch: Zero-based epoch index, used for the progress prefix and the
            TensorBoard step.
        scheduler: LR scheduler; only ``get_last_lr()`` is read, for logging.
        writer: TensorBoard writer; used on ``local_rank == 0`` only.
        train_iter: Iterator over ``train_loader`` carried across epochs.
        args: Parsed options (``steps_per_epoch``, ``grad_accumulation_steps``,
            ``precision``, ``print_freq``, ``distributed``, ``local_rank``).

    Returns:
        The (possibly re-created) training iterator for the next epoch.
    """
    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":.4f")
    ce_losses = AverageMeter("CeLoss", ":.4f")
    mask_bce_losses = AverageMeter("MaskBCELoss", ":.4f")
    mask_dice_losses = AverageMeter("MaskDICELoss", ":.4f")
    mask_losses = AverageMeter("MaskLoss", ":.4f")

    progress = ProgressMeter(
        args.steps_per_epoch,
        [
            batch_time,
            losses,
            ce_losses,
            mask_losses,
            mask_bce_losses,
            mask_dice_losses,
        ],
        prefix="Epoch: [{}]".format(epoch),
    )
    # Every meter is reduced and reset together after each logging step.
    all_meters = (
        batch_time,
        data_time,
        losses,
        ce_losses,
        mask_bce_losses,
        mask_dice_losses,
        mask_losses,
    )

    # switch to train mode
    model.train()
    end = time.time()
    for global_step in range(args.steps_per_epoch):
        # Step index across epochs, so TensorBoard curves continue from one
        # epoch to the next instead of overwriting the same x-range.
        log_step = epoch * args.steps_per_epoch + global_step

        for _ in range(args.grad_accumulation_steps):
            try:
                input_dict = next(train_iter)
            except StopIteration:
                # Loader exhausted: start a fresh pass.  Any other error
                # (e.g. a failing worker) is allowed to propagate.
                train_iter = iter(train_loader)
                input_dict = next(train_iter)

            data_time.update(time.time() - end)
            input_dict = dict_to_cuda(input_dict)

            # Match the image tensors to the training precision.
            if args.precision == "fp16":
                input_dict["images"] = input_dict["images"].half()
                input_dict["images_clip"] = input_dict["images_clip"].half()
            elif args.precision == "bf16":
                input_dict["images"] = input_dict["images"].bfloat16()
                input_dict["images_clip"] = input_dict["images_clip"].bfloat16()
            else:
                input_dict["images"] = input_dict["images"].float()
                input_dict["images_clip"] = input_dict["images_clip"].float()

            output_dict = model(**input_dict)

            loss = output_dict["loss"]
            ce_loss = output_dict["ce_loss"]
            mask_bce_loss = output_dict["mask_bce_loss"]
            mask_dice_loss = output_dict["mask_dice_loss"]
            mask_loss = output_dict["mask_loss"]

            batch_size = input_dict["images"].size(0)
            losses.update(loss.item(), batch_size)
            ce_losses.update(ce_loss.item(), batch_size)
            mask_bce_losses.update(mask_bce_loss.item(), batch_size)
            mask_dice_losses.update(mask_dice_loss.item(), batch_size)
            mask_losses.update(mask_loss.item(), batch_size)
            model.backward(loss)
            model.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if global_step % args.print_freq == 0:
            if args.distributed:
                for meter in all_meters:
                    meter.all_reduce()

            if args.local_rank == 0:
                progress.display(global_step + 1)
                writer.add_scalar("train/loss", losses.avg, log_step)
                writer.add_scalar("train/ce_loss", ce_losses.avg, log_step)
                writer.add_scalar(
                    "train/mask_bce_loss", mask_bce_losses.avg, log_step
                )
                writer.add_scalar(
                    "train/mask_dice_loss", mask_dice_losses.avg, log_step
                )
                writer.add_scalar("train/mask_loss", mask_losses.avg, log_step)
                writer.add_scalar(
                    "metrics/total_secs_per_batch", batch_time.avg, log_step
                )
                writer.add_scalar(
                    "metrics/data_secs_per_batch", data_time.avg, log_step
                )

            for meter in all_meters:
                meter.reset()

        if global_step != 0:
            curr_lr = scheduler.get_last_lr()
            if args.local_rank == 0:
                writer.add_scalar("train/lr", curr_lr[0], log_step)

    return train_iter
def validate(val_loader, model_engine, epoch, writer, args):
    """Evaluate the model on ``val_loader`` and return ``(giou, ciou)``.

    Intersection and union are accumulated per mask and reduced across
    ranks. ``ciou`` is summed intersection over summed union at index 1;
    ``giou`` is the averaged per-mask IoU at index 1.
    NOTE(review): index 1 is presumably the foreground class of the 2-class
    result of ``intersectionAndUnionGPU`` -- confirm against that helper.

    Args:
        val_loader: Iterable of validation batches (one sample per batch).
        model_engine: Model/engine called with the batch as keyword arguments.
        epoch: Step value used for the TensorBoard scalars.
        writer: TensorBoard writer; used on ``local_rank == 0`` only.
        args: Parsed options (``precision`` and ``local_rank`` are read).

    Returns:
        Tuple ``(giou, ciou)``.
    """
    inter_meter = AverageMeter("Intersec", ":6.3f", Summary.SUM)
    uni_meter = AverageMeter("Union", ":6.3f", Summary.SUM)
    giou_meter = AverageMeter("gIoU", ":6.3f", Summary.SUM)

    model_engine.eval()

    # Name of the tensor method that casts to the configured precision.
    cast_name = {"fp16": "half", "bf16": "bfloat16"}.get(args.precision, "float")

    for batch in tqdm.tqdm(val_loader):
        torch.cuda.empty_cache()

        batch = dict_to_cuda(batch)
        for key in ("images", "images_clip"):
            batch[key] = getattr(batch[key], cast_name)()

        with torch.no_grad():
            outputs = model_engine(**batch)

        pred_masks = outputs["pred_masks"]
        gt_masks = outputs["gt_masks"][0].int()
        # Positive logits count as predicted foreground.
        binary_preds = (pred_masks[0] > 0).int()
        assert len(pred_masks) == 1

        num_masks = gt_masks.shape[0]
        inter_total, union_total, iou_total = 0.0, 0.0, 0.0
        for gt_mask, pred_mask in zip(gt_masks, binary_preds):
            inter_one, union_one, _ = intersectionAndUnionGPU(
                pred_mask.contiguous().clone(),
                gt_mask.contiguous(),
                2,
                ignore_index=255,
            )
            inter_total += inter_one
            union_total += union_one
            iou_total += inter_one / (union_one + 1e-5)
            iou_total[union_one == 0] += 1.0  # no-object target

        inter_np = inter_total.cpu().numpy()
        union_np = union_total.cpu().numpy()
        mean_iou_np = iou_total.cpu().numpy() / num_masks

        inter_meter.update(inter_np)
        uni_meter.update(union_np)
        giou_meter.update(mean_iou_np, n=num_masks)

    # Combine the statistics gathered on every rank.
    for meter in (inter_meter, uni_meter, giou_meter):
        meter.all_reduce()

    per_class_iou = inter_meter.sum / (uni_meter.sum + 1e-10)
    ciou = per_class_iou[1]
    giou = giou_meter.avg[1]

    if args.local_rank == 0:
        for tag, value in (("val/giou", giou), ("val/ciou", ciou)):
            writer.add_scalar(tag, value, epoch)
        print("giou: {:.4f}, ciou: {:.4f}".format(giou, ciou))

    return giou, ciou
# Script entry point: hand the command-line arguments (without the program
# name) to main().
if __name__ == "__main__":
    cli_args = sys.argv[1:]
    main(cli_args)