# chatlawv1/trlx/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py
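"""Fine-tune google/flan-t5-large on CNN/DailyMail summarization with PPO,
using METEOR against the reference highlights as the reward signal."""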
from typing import List
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer
import trlx
from trlx.data.configs import (
ModelConfig,
OptimizerConfig,
SchedulerConfig,
TokenizerConfig,
TrainConfig,
TRLConfig,
)
from trlx.models.modeling_ppo import PPOConfig
try:
import evaluate
except ImportError:
    raise ImportError(
        "To run this example, please install the `evaluate` and `nltk` packages "
        "by running `pip install evaluate nltk`"
    )
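# PPO fine-tuning configuration: flan-t5-large as a seq2seq policy, AdamW with
# cosine annealing, and an adaptive KL penalty against the frozen reference model.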
config = TRLConfig(
train=TrainConfig(
seq_length=612,
epochs=100,
total_steps=100000,
batch_size=12,
checkpoint_interval=10000,
eval_interval=500,
pipeline="PromptPipeline",
trainer="AcceleratePPOTrainer",
),
model=ModelConfig(
model_path="google/flan-t5-large",
model_arch_type="seq2seq",
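        # Only the last 2 transformer layers receive gradient updates; the rest stay frozen.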
num_layers_unfrozen=2,
),
tokenizer=TokenizerConfig(
tokenizer_path="google/flan-t5-large",
truncation_side="right",
),
optimizer=OptimizerConfig(
name="adamw",
kwargs={
"lr": 1.0e-5,
"betas": [0.9, 0.999],
"eps": 1.0e-8,
"weight_decay": 1.0e-6,
},
),
scheduler=SchedulerConfig(
name="cosine_annealing",
kwargs={
"T_max": 10000,
"eta_min": 1.0e-6,
},
),
method=PPOConfig(
name="PPOConfig",
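        # Gather 512 rollouts per PPO update phase, generated in chunks of 12 prompts.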
num_rollouts=512,
chunk_size=12,
ppo_epochs=4,
init_kl_coef=0.05,
target=6,
horizon=10000,
gamma=0.99,
lam=0.95,
cliprange=0.2,
cliprange_value=0.2,
vf_coef=1.0,
scale_reward=None,
ref_mean=None,
ref_std=None,
cliprange_reward=10,
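        # gen_kwargs drive generation at evaluation time (greedy by default);
        # gen_experience_kwargs enable sampling for exploration while collecting rollouts.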
gen_kwargs={
"max_new_tokens": 100,
},
gen_experience_kwargs={
"max_new_tokens": 100,
"do_sample": True,
"temperature": 1.0,
"top_k": 50,
"top_p": 0.95,
},
),
)
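# METEOR measures unigram overlap with stemming and synonym matching; scores lie
# in [0, 1], with higher values indicating closer agreement with the reference.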
meteor = evaluate.load("meteor") # use meteor as the reward function
if __name__ == "__main__":
    def reward_fn(samples: List[str], prompts: List[str], outputs: List[str]) -> List[float]:
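        """Score each generated summary with METEOR against the reference
        highlights, looked up by prompt. `prompt_label` is defined below in
        the main block; as a closure variable it is resolved at call time."""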
original_summaries = [prompt_label[prompt.strip()] for prompt in prompts]
scores = [
meteor.compute(predictions=[output.strip()], references=[original])["meteor"]
for (original, output) in zip(original_summaries, outputs)
]
return scores
dataset = load_dataset("cnn_dailymail", "3.0.0", cache_dir="data")
# take 20,000 samples from the training set as prompts for training
prompts = dataset["train"]["article"][0:20000]
summaries = dataset["train"]["highlights"][0:20000]
prompts = ["Summarize: " + prompt for prompt in prompts]
# take 1,000 samples from the validation set as prompts for evaluation
val_prompts = ["Summarize: " + prompt for prompt in dataset["validation"]["article"][0:1000]]
val_summaries = dataset["validation"]["highlights"][0:1000]
    # Build a dictionary mapping each (truncated) prompt to its reference summary for the reward function
tokenizer = AutoTokenizer.from_pretrained(config.model.model_path)
tokenizer.padding_side = "left"
tokenizer.truncation_side = "right"
tokenizer.sep_token = "<sep>"
prompt_label = {}
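    # Reserve room for generation: prompts are truncated so that prompt tokens plus max_new_tokens fit in seq_length.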
max_length = config.train.seq_length - config.method.gen_kwargs["max_new_tokens"]
    def truncated_prompt(prompt: str) -> str:
        # Re-decode the prompt exactly as trlx will present it to reward_fn
        # (truncated to max_length, special tokens stripped) so the keys match.
        input_ids = tokenizer(prompt, truncation=True, max_length=max_length, add_special_tokens=False)["input_ids"]
        return tokenizer.decode(input_ids, skip_special_tokens=True)

    for prompt, summary in tqdm(zip(prompts, summaries), total=len(prompts)):
        prompt_label[truncated_prompt(prompt).strip()] = summary

    for prompt, summary in tqdm(zip(val_prompts, val_summaries), total=len(val_prompts)):
        prompt_label[truncated_prompt(prompt).strip()] = summary
trlx.train(
reward_fn=reward_fn,
prompts=prompts,
eval_prompts=val_prompts,
config=config,
)
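
# Launch note (assumption, not part of the original script): trlx examples are
# typically run under Accelerate, e.g. `accelerate launch t5_summarize_daily_cnn.py`.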