CHEMISTral7Bv0.3 / train.py
Clemspace's picture
Initial model upload
cb9e677
raw
history blame
9.48 kB
import dataclasses
import logging
import os
import pprint
from contextlib import ExitStack
from pathlib import Path
from typing import TYPE_CHECKING
import fire
import torch.cuda
import torch.distributed as dist
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from torch.optim import AdamW, lr_scheduler
from finetune.args import TrainArgs
from finetune.checkpointing import Checkpointer
from finetune.data.data_loader import build_data_loader
from finetune.distributed import (
BACKEND,
avg_aggregate,
get_rank,
get_world_size,
is_torchrun,
set_device,
)
from finetune.eval import evaluate
from finetune.loss import compute_loss_with_mask
from finetune.mixed_precision import (
downcast_mixed_precision,
prepare_mixed_precision,
upcast_mixed_precision,
)
from finetune.monitoring.metrics_logger import (
MetricsLogger,
eval_log_msg,
get_eval_logs,
get_train_logs,
train_log_msg,
)
from finetune.monitoring.utils import set_logger
from finetune.utils import (
TrainState,
logged_closing,
set_random_seed,
)
from finetune.wrapped_model import load_model
if TYPE_CHECKING:
from mistral_common.tokens.tokenizers.sentencepiece import InstructTokenizerBase
logger = logging.getLogger("train")
def main_logger_info(message: str) -> None:
if get_rank() == 0:
logger.info(message)
#wandb.log({"info": message})
def train(config: str):
args: TrainArgs = TrainArgs.load(config, drop_extra_fields=False)
print(f"args: {args}")
set_logger(logging.INFO)
#if get_rank() == 0:
# wandb.init(project="CHEMISTral7b-ft",entity = "oops")
# wandb.config.update(dataclasses.asdict(args))
with ExitStack() as exit_stack:
_train(args, exit_stack)
logger.info("Closed everything!")
def _train(
args: TrainArgs,
exit_stack: ExitStack,
):
# 1. Initial setup and checks
set_random_seed(args.seed)
# Init NCCL
if "LOCAL_RANK" in os.environ:
set_device()
logger.info("Going to init comms...")
dist.init_process_group(backend=BACKEND)
else:
logger.error(
"PyTorch environment is not correctly initialized. This message should only be displayed when testing."
)
# 2. Init run dir
main_logger_info(f"Run dir: {args.run_dir}")
run_dir = Path(args.run_dir)
if is_torchrun():
if run_dir.exists():
raise RuntimeError(
f"Run dir {run_dir} already exists. Make sure to either rename `run_dir` or remove {run_dir}."
)
dist.barrier()
run_dir.mkdir(exist_ok=True, parents=True)
args_path = run_dir / "args.yaml"
if not args_path.exists():
args.save(args_path)
main_logger_info(f"TrainArgs: {pprint.pformat(dataclasses.asdict(args))}")
# 3. Get loggers
metrics_logger: MetricsLogger = MetricsLogger(
run_dir,
tag="train",
is_master=get_rank() == 0,
wandb_args=args.wandb,
mlflow_args=args.mlflow,
config=dataclasses.asdict(args),
)
exit_stack.enter_context(logged_closing(metrics_logger, "metrics_logger"))
eval_logger: MetricsLogger = MetricsLogger(
run_dir,
tag="eval",
is_master=get_rank() == 0,
wandb_args=args.wandb,
mlflow_args=args.mlflow,
config=dataclasses.asdict(args),
)
exit_stack.enter_context(logged_closing(eval_logger, "eval_logger"))
# 5. Potentially download model
if Path(args.model_id_or_path).is_dir():
model_folder = Path(args.model_id_or_path)
else:
raise ValueError(
"Invalid folder path. Please set `args.initial_model` to a valid folder path."
)
# 6. Load function calling instruct tokenizer
instruct_tokenizer: InstructTokenizerBase = MistralTokenizer.v3().instruct_tokenizer # type: ignore
# 7. Load data loaders
data_loader = build_data_loader(
instruct_tokenizer=instruct_tokenizer,
args=args.data,
seq_len=args.seq_len,
batch_size=args.batch_size,
seed=args.seed,
rank=get_rank(), # DDP rank
world_size=get_world_size(), # DDP world_size
is_eval=False,
)
if not args.no_eval:
assert (
args.data.eval_instruct_data != ""
), "Either set `no_eval` to True or provide evaluation samples under `data.eval_instruct_data`"
eval_data_loader = build_data_loader(
instruct_tokenizer=instruct_tokenizer,
args=args.data,
seq_len=args.seq_len,
batch_size=args.batch_size,
seed=None,
rank=get_rank(), # DDP rank
world_size=get_world_size(), # DDP world_size
is_eval=True,
)
# pre-load all eval tokens
eval_batches = list(eval_data_loader)
# 8. Load model
# Define mixed precision
param_dtype = torch.bfloat16
optim_dtype = torch.float32
assert args.lora is not None, "`args.lora` should be set to a valid value."
model = load_model(
folder=model_folder,
lora=args.lora,
checkpoint=args.checkpoint,
param_dtype=param_dtype,
)
# 9. Load optimizer
optimizer = AdamW(
model.parameters(),
lr=args.optim.lr,
betas=(0.9, 0.95),
eps=1e-08,
weight_decay=args.optim.weight_decay,
)
scheduler = lr_scheduler.OneCycleLR(
optimizer,
max_lr=args.optim.lr,
total_steps=args.max_steps,
pct_start=args.optim.pct_start,
)
state = TrainState(args.max_steps)
# 10. Initialize checkpointer
checkpointer = Checkpointer(
model=model,
state=state,
run_dir=run_dir,
optimizer=optimizer,
num_ckpt_keep=args.num_ckpt_keep,
)
# 11. Prepare mixed precision
prepare_mixed_precision(
model.parameters(), param_dtype=param_dtype, optim_dtype=optim_dtype
)
# 12. train!
model.train()
torch.cuda.empty_cache()
while state.step < args.max_steps:
state.start_step()
is_last_step = state.step == args.max_steps
optimizer.zero_grad()
loss = torch.tensor([0.0], device="cuda")
n_batch_tokens: int = 0
for i in range(args.num_microbatches):
# batch
batch = next(data_loader)
x = torch.from_numpy(batch.x).cuda(non_blocking=True)
y = torch.from_numpy(batch.y).cuda(non_blocking=True)
y_mask = (
torch.from_numpy(batch.y_mask).cuda(non_blocking=True)
if batch.y_mask is not None
else None
)
# forward / backward
output = model(
input_ids=x,
seqlens=batch.sizes,
)
mb_loss = compute_loss_with_mask(output, y, y_mask)
mb_loss.backward()
loss += mb_loss.detach()
n_batch_tokens += x.numel()
if i < args.num_microbatches - 1:
# synchronize CUDA to re-run backward
assert args.num_microbatches > 1 # should not happen
torch.cuda.synchronize()
if args.num_microbatches > 1:
loss /= args.num_microbatches
for p in model.parameters():
if p.requires_grad:
assert p.grad is not None
p.grad.div_(args.num_microbatches)
# upcast params for optimizer update
upcast_mixed_precision(model.parameters(), optim_dtype=optim_dtype)
# clip gra d norm
model.clip_grad_norm_(max_norm=args.max_norm)
# optimizer step
optimizer.step()
# downcast params for forward & backward
downcast_mixed_precision(model.parameters(), param_dtype=param_dtype)
last_lr = scheduler.get_last_lr()[0]
scheduler.step()
# Host sync
loss_item = loss.item()
avg_loss = avg_aggregate(loss_item)
if not args.no_eval and (
(args.eval_freq > 0 and state.step % args.eval_freq == 0) or is_last_step
):
# write perplexity to state
evaluate(model, eval_batches, state)
eval_logs = get_eval_logs(
state.step, avg_loss, state.this_eval_perplexity, state.this_eval_loss
)
main_logger_info(eval_log_msg(eval_logs))
eval_logger.log(eval_logs, step=state.step)
# Timing
state.end_step(n_batch_tokens)
if state.step % args.log_freq == 0:
train_logs = get_train_logs(
state,
avg_loss,
last_lr,
torch.cuda.max_memory_allocated(),
torch.cuda.memory_allocated(),
args,
)
main_logger_info(train_log_msg(state, logs=train_logs, loss=avg_loss))
metrics_logger.log(train_logs, step=state.step)
if not args.no_ckpt and (
(args.ckpt_freq > 0 and state.step % args.ckpt_freq == 0) or is_last_step
):
checkpointer.save_checkpoint(
save_only_lora=args.ckpt_only_lora,
dtype=param_dtype,
instruct_tokenizer=instruct_tokenizer,
)
main_logger_info("done!")
if __name__ == "__main__":
"""See README.md for usage."""
fire.Fire(train)