from typing import List

from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer

import trlx
from trlx.data.configs import (
    ModelConfig,
    OptimizerConfig,
    SchedulerConfig,
    TokenizerConfig,
    TrainConfig,
    TRLConfig,
)
from trlx.models.modeling_ppo import PPOConfig

try:
    import evaluate
except ImportError:
    raise ImportError(
        "To run this example, please install the `evaluate` and `nltk` packages "
        "by running `pip install evaluate nltk`"
    )

config = TRLConfig(
    train=TrainConfig(
        seq_length=612,
        epochs=100,
        total_steps=100000,
        batch_size=12,
        checkpoint_interval=10000,
        eval_interval=500,
        pipeline="PromptPipeline",
        trainer="AcceleratePPOTrainer",
    ),
    model=ModelConfig(
        model_path="google/flan-t5-large",
        model_arch_type="seq2seq",
        num_layers_unfrozen=2,  # only fine-tune the last two layers
    ),
    tokenizer=TokenizerConfig(
        tokenizer_path="google/flan-t5-large",
        truncation_side="right",
    ),
    optimizer=OptimizerConfig(
        name="adamw",
        kwargs={
            "lr": 1.0e-5,
            "betas": [0.9, 0.999],
            "eps": 1.0e-8,
            "weight_decay": 1.0e-6,
        },
    ),
    scheduler=SchedulerConfig(
        name="cosine_annealing",
        kwargs={
            "T_max": 10000,
            "eta_min": 1.0e-6,
        },
    ),
    method=PPOConfig(
        name="PPOConfig",
        num_rollouts=512,
        chunk_size=12,
        ppo_epochs=4,
        init_kl_coef=0.05,
        target=6,
        horizon=10000,
        gamma=0.99,
        lam=0.95,
        cliprange=0.2,
        cliprange_value=0.2,
        vf_coef=1.0,
        scale_reward=None,
        ref_mean=None,
        ref_std=None,
        cliprange_reward=10,
        gen_kwargs={
            "max_new_tokens": 100,
        },
        gen_experience_kwargs={
            "max_new_tokens": 100,
            "do_sample": True,
            "temperature": 1.0,
            "top_k": 50,
            "top_p": 0.95,
        },
    ),
)

meteor = evaluate.load("meteor")  # use METEOR as the reward function


if __name__ == "__main__":

    def reward_fn(samples: List[str], prompts: List[str], outputs: List[str]) -> List[float]:
        # Look up the reference summary for each prompt and score the
        # generated summary against it with METEOR.
        original_summaries = [prompt_label[prompt.strip()] for prompt in prompts]
        scores = [
            meteor.compute(predictions=[output.strip()], references=[original])["meteor"]
            for (original, output) in zip(original_summaries, outputs)
        ]
        return scores

    dataset = load_dataset("cnn_dailymail", "3.0.0", cache_dir="data")

    # take 20,000 samples from the training set as prompts for training
    prompts = dataset["train"]["article"][0:20000]
    summaries = dataset["train"]["highlights"][0:20000]
    prompts = ["Summarize: " + prompt for prompt in prompts]

    # take 1,000 samples from the validation set as prompts for evaluation
    val_prompts = ["Summarize: " + prompt for prompt in dataset["validation"]["article"][0:1000]]
    val_summaries = dataset["validation"]["highlights"][0:1000]

    # make dictionary of prompts and labels to use for reward function
    tokenizer = AutoTokenizer.from_pretrained(config.model.model_path)
    tokenizer.padding_side = "left"
    tokenizer.truncation_side = "right"
    tokenizer.sep_token = "<sep>"
    prompt_label = {}
    max_length = config.train.seq_length - config.method.gen_kwargs["max_new_tokens"]

    for i in tqdm(range(len(prompts))):
        # Truncate and re-decode the prompt the same way trlx does, so the
        # dictionary key matches the prompt string passed to reward_fn.
        key = tokenizer.decode(
            tokenizer(prompts[i], truncation=True, max_length=max_length, add_special_tokens=False)["input_ids"],
            skip_special_tokens=True,
        )
        prompt_label[key.strip()] = summaries[i]

    for i in tqdm(range(len(val_prompts))):
        key = tokenizer.decode(
            tokenizer(val_prompts[i], truncation=True, max_length=max_length, add_special_tokens=False)["input_ids"],
            skip_special_tokens=True,
        )
        prompt_label[key.strip()] = val_summaries[i]

    trlx.train(
        reward_fn=reward_fn,
        prompts=prompts,
        eval_prompts=val_prompts,
        config=config,
    )
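
# Example launch, since the trainer is AcceleratePPOTrainer. This is a sketch:
# the script filename is an assumption (not stated in this file), and it
# presumes `accelerate config` has already been run for your hardware.
#
#   pip install evaluate nltk
#   accelerate launch t5_summarize_daily_cnn.py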