"""Training entry point for the EasyLM Flax LLaMA causal language model.

Defines the command-line flags and a ``main`` function that builds the
dataset, model and optimizer, shards the train state with ``pjit`` and runs
the training loop with optional evaluation, logging and checkpointing.
"""
import pprint
from functools import partial

from tqdm import tqdm, trange
import numpy as np
import mlxu

import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as PS
from flax.training.train_state import TrainState

from EasyLM.data import DatasetFactory
from EasyLM.checkpoint import StreamingCheckpointer
from EasyLM.optimizers import OptimizerFactory
from EasyLM.jax_utils import (
    JaxRNG, next_rng, match_partition_rules,
    cross_entropy_loss_and_accuracy, global_norm, get_float_dtype_by_name,
    set_random_seed, average_metrics, get_weight_decay_mask,
    make_shard_and_gather_fns, with_sharding_constraint,
)
from EasyLM.models.llama.llama_model import (
    LLaMAConfig, FlaxLLaMAForCausalLMModule
)


# Command-line flags and their defaults. A frequency of 0 disables the
# corresponding periodic action (see the `> 0` guards in `main`).
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
    seed=42,
    initialize_jax_distributed=False,
    mesh_dim='1,-1,1',
    dtype='fp32',
    param_dtype='fp32',
    total_steps=10000,
    load_llama_config='',
    update_llama_config='',
    load_checkpoint='',
    load_dataset_state='',
    log_freq=50,
    save_model_freq=0,
    save_milestone_freq=0,
    eval_freq=0,
    tokenizer=LLaMAConfig.get_tokenizer_config(),
    train_dataset=DatasetFactory.get_default_config(),
    eval_dataset=DatasetFactory.get_default_config(),
    optimizer=OptimizerFactory.get_default_config(),
    checkpointer=StreamingCheckpointer.get_default_config(),
    llama=LLaMAConfig.get_default_config(),
    logger=mlxu.WandBLogger.get_default_config(),
    log_all_worker=False,
)


def main(argv):
    """Run the full training job described by ``FLAGS``.

    Steps, in order: optional JAX distributed initialization, logger and
    seed setup, dataset and model configuration, construction of the
    ``pjit``-sharded init/train/eval functions, train-state creation or
    restoration, and finally the training loop.

    Args:
        argv: Positional command-line arguments handed over by ``mlxu.run``;
            not read by this function.
    """
    if FLAGS.initialize_jax_distributed:
        jax.distributed.initialize()

    variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
    flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
    # Logging is enabled on process 0 only, unless log_all_worker is set.
    logger = mlxu.WandBLogger(
        config=FLAGS.logger,
        variant=variant,
        enable=FLAGS.log_all_worker or (jax.process_index() == 0),
    )
    set_random_seed(FLAGS.seed)

    # ----- Data -----------------------------------------------------------
    tokenizer = LLaMAConfig.get_tokenizer(FLAGS.tokenizer)
    dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer)
    if FLAGS.load_dataset_state != '':
        # NOTE(review): presumably this unpickles the file at the given path;
        # only point this flag at trusted files.
        dataset.load_state_dict(mlxu.load_pickle(FLAGS.load_dataset_state))

    # eval_dataset is only bound (and only used) when evaluation is enabled.
    if FLAGS.eval_freq > 0:
        eval_dataset = DatasetFactory.load_dataset(
            FLAGS.eval_dataset, dataset.tokenizer, eval_dataset=True
        )

    seq_length = dataset.seq_length

    # ----- Model configuration ---------------------------------------------
    if FLAGS.load_llama_config != '':
        llama_config = LLaMAConfig.load_config(FLAGS.load_llama_config)
    else:
        llama_config = LLaMAConfig(**FLAGS.llama)

    if FLAGS.update_llama_config != '':
        # NOTE(review): eval() executes arbitrary Python taken from the
        # command line. Acceptable for a trusted operator; do not expose this
        # flag to untrusted input.
        llama_config.update(dict(eval(FLAGS.update_llama_config)))

    llama_config.update({
        'bos_token_id': dataset.tokenizer.bos_token_id,
        'eos_token_id': dataset.tokenizer.eos_token_id,
    })
    # Grow the model vocabulary if the dataset reports a larger one.
    if llama_config.vocab_size < dataset.vocab_size:
        llama_config.update({'vocab_size': dataset.vocab_size})

    model = FlaxLLaMAForCausalLMModule(
        llama_config,
        dtype=get_float_dtype_by_name(FLAGS.dtype),
        param_dtype=get_float_dtype_by_name(FLAGS.param_dtype),
    )

    optimizer, optimizer_info = OptimizerFactory.get_optimizer(
        FLAGS.optimizer,
        get_weight_decay_mask(LLaMAConfig.get_weight_decay_exclusions()),
    )

    # ----- Functions that get wrapped in pjit --------------------------------
    def create_trainstate_from_params(params):
        """Wrap ``params`` in a new TrainState that uses ``optimizer``."""
        return TrainState.create(params=params, tx=optimizer, apply_fn=None)

    def init_fn(rng):
        """Initialize model parameters from ``rng`` and build a TrainState."""
        key_source = JaxRNG(rng)
        dummy_shape = (4, seq_length)
        params = model.init(
            input_ids=jnp.zeros(dummy_shape, dtype=jnp.int32),
            position_ids=jnp.zeros(dummy_shape, dtype=jnp.int32),
            attention_mask=jnp.ones(dummy_shape, dtype=jnp.int32),
            rngs=key_source(llama_config.rng_keys()),
        )
        return create_trainstate_from_params(params)

    def train_step(train_state, rng, batch):
        """Apply one gradient update.

        Returns:
            A tuple ``(train_state, rng, metrics)``: the state after
            ``apply_gradients``, the result of calling the RNG helper with no
            arguments (presumably a fresh key), and a dict of metrics.
        """
        key_source = JaxRNG(rng)
        batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))

        def loss_and_accuracy(params):
            model_output = model.apply(
                params,
                batch['input_tokens'],
                deterministic=False,
                rngs=key_source(llama_config.rng_keys()),
            )
            return cross_entropy_loss_and_accuracy(
                model_output.logits,
                batch['target_tokens'],
                batch['loss_masks'],
            )

        # has_aux=True: the loss is differentiated, accuracy is carried along.
        value_and_grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
        (loss, accuracy), grads = value_and_grad_fn(train_state.params)
        train_state = train_state.apply_gradients(grads=grads)
        # The learning rate and parameter norm below are read from the
        # already-updated train_state.
        metrics = {
            'loss': loss,
            'accuracy': accuracy,
            'learning_rate': optimizer_info['learning_rate_schedule'](
                train_state.step
            ),
            'gradient_norm': global_norm(grads),
            'param_norm': global_norm(train_state.params),
        }
        return train_state, key_source(), metrics

    def eval_step(train_state, rng, batch):
        """Compute loss and accuracy on one batch with ``deterministic=True``.

        Returns:
            A tuple ``(rng, metrics)`` where ``metrics`` holds ``eval_loss``
            and ``eval_accuracy``.
        """
        key_source = JaxRNG(rng)
        batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
        model_output = model.apply(
            train_state.params,
            batch['input_tokens'],
            deterministic=True,
            rngs=key_source(llama_config.rng_keys()),
        )
        loss, accuracy = cross_entropy_loss_and_accuracy(
            model_output.logits,
            batch['target_tokens'],
            batch['loss_masks'],
        )
        metrics = {
            'eval_loss': loss,
            'eval_accuracy': accuracy,
        }
        return key_source(), metrics

    # ----- Sharding setup ----------------------------------------------------
    # eval_shape traces init_fn abstractly, so no parameters are allocated.
    train_state_shapes = jax.eval_shape(init_fn, next_rng())
    train_state_partition = match_partition_rules(
        LLaMAConfig.get_partition_rules(), train_state_shapes
    )

    shard_fns, gather_fns = make_shard_and_gather_fns(
        train_state_partition, train_state_shapes
    )
    checkpointer = StreamingCheckpointer(
        FLAGS.checkpointer,
        logger.output_dir,
        enable=jax.process_index() == 0,
    )

    sharded_init_fn = pjit(
        init_fn,
        in_shardings=PS(),
        out_shardings=train_state_partition,
    )

    sharded_create_trainstate_from_params = pjit(
        create_trainstate_from_params,
        in_shardings=(train_state_partition.params, ),
        out_shardings=train_state_partition,
        donate_argnums=(0, ),
    )

    sharded_train_step = pjit(
        train_step,
        in_shardings=(train_state_partition, PS(), PS()),
        out_shardings=(train_state_partition, PS(), PS()),
        donate_argnums=(0, 1),
    )

    sharded_eval_step = pjit(
        eval_step,
        in_shardings=(train_state_partition, PS(), PS()),
        out_shardings=(PS(), PS()),
        donate_argnums=(1,),
    )

    def save_checkpoint(train_state, milestone=False):
        """Save ``train_state``, dataset state and run metadata."""
        current_step = int(jax.device_get(train_state.step))
        metadata = {
            'step': current_step,
            'variant': variant,
            'flags': flags_config_dict,
            'llama_config': llama_config.to_dict(),
        }
        checkpointer.save_all(
            train_state=train_state,
            gather_fns=gather_fns,
            metadata=metadata,
            dataset=dataset.get_state_dict(),
            milestone=milestone,
        )

    def is_due(frequency, completed_steps):
        """Return True when a periodic action with ``frequency`` should fire."""
        return frequency > 0 and completed_steps % frequency == 0

    # ----- Training ------------------------------------------------------------
    mesh = LLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)
    with mesh:
        train_state, restored_params = None, None
        if FLAGS.load_checkpoint != '':
            train_state, restored_params = checkpointer.load_trainstate_checkpoint(
                FLAGS.load_checkpoint, train_state_shapes, shard_fns
            )

        if train_state is None:
            if restored_params is None:
                # Initialize from scratch
                train_state = sharded_init_fn(next_rng())
            else:
                # Restore from params but initialize train_state
                train_state = sharded_create_trainstate_from_params(
                    restored_params
                )
                del restored_params

        start_step = int(jax.device_get(train_state.step))

        # Write a checkpoint of the starting state before any update.
        if FLAGS.save_model_freq > 0:
            save_checkpoint(train_state)

        sharded_rng = next_rng()

        step_counter = trange(start_step, FLAGS.total_steps, ncols=0)

        # zip stops at whichever ends first: the step range or the dataset.
        for step, (batch, dataset_metrics) in zip(step_counter, dataset):
            train_state, sharded_rng, metrics = sharded_train_step(
                train_state, sharded_rng, batch
            )
            completed_steps = step + 1

            if is_due(FLAGS.eval_freq, completed_steps):
                eval_metric_list = []
                for eval_batch, _ in iter(eval_dataset):
                    sharded_rng, eval_metrics = sharded_eval_step(
                        train_state, sharded_rng, eval_batch
                    )
                    eval_metric_list.append(eval_metrics)
                metrics.update(average_metrics(eval_metric_list))

            if is_due(FLAGS.log_freq, completed_steps):
                log_metrics = {"step": step, **metrics}
                log_metrics.update(dataset_metrics)
                log_metrics = jax.device_get(log_metrics)
                logger.log(log_metrics)
                tqdm.write("\n" + pprint.pformat(log_metrics) + "\n")

            # A milestone save takes precedence over a regular save.
            if is_due(FLAGS.save_milestone_freq, completed_steps):
                save_checkpoint(train_state, milestone=True)
            elif is_due(FLAGS.save_model_freq, completed_steps):
                save_checkpoint(train_state)

        # Final checkpoint once the loop has finished.
        if FLAGS.save_model_freq > 0:
            save_checkpoint(train_state)


if __name__ == "__main__":
    mlxu.run(main)