#!/usr/bin/env bash
#
# Copyright 2023-2024 PKU-Alignment Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# Launches the `safe_rlhf.algorithms.ppo_lag` module through `deepspeed`.
#
# Options (each accepts both `--name value` and `--name=value`):
#   --actor_model_name_or_path          default: PKU-Alignment/alpaca-7b-reproduced
#   --reward_model_name_or_path         default: <root>/output/rm
#   --reward_critic_model_name_or_path  default: the reward model
#   --cost_model_name_or_path           default: <root>/output/cm
#   --cost_critic_model_name_or_path    default: the cost model
#   --output_dir                        default: <root>/output/ppo-lag
#   --hostfile                          forwarded to deepspeed only when given
#   --zero_stage                        default: 1
#   --offload                           default: none
#
# <root> is the parent of the directory containing this script.
#
# Environment:
#   LOGLEVEL       exported, defaults to WARNING.
#   PYTHONPATH     <root> is prepended.
#   WANDB_API_KEY  when unset or empty, WANDB_MODE=offline is exported.

if [ -z "${BASH_VERSION}" ]; then
  echo "Please use bash to run this script." >&2
  exit 1
fi

set -x

SCRIPT_DIR="$(cd "$(dirname "$0")" &>/dev/null && pwd)"
ROOT_DIR="$(dirname "${SCRIPT_DIR}")"
export PYTHONPATH="${ROOT_DIR}${PYTHONPATH:+:${PYTHONPATH}}"
export LOGLEVEL="${LOGLEVEL:-WARNING}"

ACTOR_MODEL_NAME_OR_PATH="PKU-Alignment/alpaca-7b-reproduced"
REWARD_MODEL_NAME_OR_PATH="${ROOT_DIR}/output/rm"
COST_MODEL_NAME_OR_PATH="${ROOT_DIR}/output/cm"
# The critic paths are deliberately left unset so that "not given on the
# command line" can be told apart from "given as an empty string" below.
unset {REWARD,COST}_CRITIC_MODEL_NAME_OR_PATH
OUTPUT_DIR="${ROOT_DIR}/output/ppo-lag"
unset HOSTFILE
ZERO_STAGE=1
OFFLOAD="none"

#######################################
# Abort when a space-separated option is not followed by a value.
# Arguments:
#   $1 - option name as typed by the user
#   $2 - number of command-line arguments remaining after the option
# Outputs:
#   Writes an error message to stderr and exits with status 1 on failure.
#######################################
require_value() {
  if [[ "$2" -lt 1 ]]; then
    echo "Missing value for parameter: '$1'" >&2
    exit 1
  fi
}

while [[ "$#" -gt 0 ]]; do
  arg="$1"
  shift
  case "${arg}" in
    --actor_model_name_or_path)
      require_value "${arg}" "$#"
      ACTOR_MODEL_NAME_OR_PATH="$1"
      shift
      ;;
    --actor_model_name_or_path=*)
      ACTOR_MODEL_NAME_OR_PATH="${arg#*=}"
      ;;
    --reward_model_name_or_path)
      require_value "${arg}" "$#"
      REWARD_MODEL_NAME_OR_PATH="$1"
      shift
      ;;
    --reward_model_name_or_path=*)
      REWARD_MODEL_NAME_OR_PATH="${arg#*=}"
      ;;
    --reward_critic_model_name_or_path)
      require_value "${arg}" "$#"
      REWARD_CRITIC_MODEL_NAME_OR_PATH="$1"
      shift
      ;;
    --reward_critic_model_name_or_path=*)
      REWARD_CRITIC_MODEL_NAME_OR_PATH="${arg#*=}"
      ;;
    --cost_model_name_or_path)
      require_value "${arg}" "$#"
      COST_MODEL_NAME_OR_PATH="$1"
      shift
      ;;
    --cost_model_name_or_path=*)
      COST_MODEL_NAME_OR_PATH="${arg#*=}"
      ;;
    --cost_critic_model_name_or_path)
      require_value "${arg}" "$#"
      COST_CRITIC_MODEL_NAME_OR_PATH="$1"
      shift
      ;;
    --cost_critic_model_name_or_path=*)
      COST_CRITIC_MODEL_NAME_OR_PATH="${arg#*=}"
      ;;
    --output_dir)
      require_value "${arg}" "$#"
      OUTPUT_DIR="$1"
      shift
      ;;
    --output_dir=*)
      OUTPUT_DIR="${arg#*=}"
      ;;
    --hostfile)
      require_value "${arg}" "$#"
      HOSTFILE="$1"
      shift
      ;;
    --hostfile=*)
      HOSTFILE="${arg#*=}"
      ;;
    --zero_stage)
      require_value "${arg}" "$#"
      ZERO_STAGE="$1"
      shift
      ;;
    --zero_stage=*)
      ZERO_STAGE="${arg#*=}"
      ;;
    --offload)
      require_value "${arg}" "$#"
      OFFLOAD="$1"
      shift
      ;;
    --offload=*)
      OFFLOAD="${arg#*=}"
      ;;
    *)
      echo "Unknown parameter passed: '${arg}'" >&2
      exit 1
      ;;
  esac
done

# A critic that was not specified falls back to its reward/cost model.
if [[ -z "${REWARD_CRITIC_MODEL_NAME_OR_PATH+x}" ]]; then
  REWARD_CRITIC_MODEL_NAME_OR_PATH="${REWARD_MODEL_NAME_OR_PATH}"
fi
if [[ -z "${COST_CRITIC_MODEL_NAME_OR_PATH+x}" ]]; then
  COST_CRITIC_MODEL_NAME_OR_PATH="${COST_MODEL_NAME_OR_PATH}"
fi

# Create the output directory and resolve it to an absolute path. Stop on
# failure: an empty OUTPUT_DIR would make every later write target "/".
mkdir -p "${OUTPUT_DIR}" || exit 1
OUTPUT_DIR="$(cd "${OUTPUT_DIR}" &>/dev/null && pwd)"
if [[ -z "${OUTPUT_DIR}" ]]; then
  echo "Failed to resolve the output directory." >&2
  exit 1
fi
if [[ ! -f "${OUTPUT_DIR}/.gitignore" ]]; then
  echo '*' >"${OUTPUT_DIR}/.gitignore"
fi

# Keep a copy of this script next to the results.
cp -f "$0" "${OUTPUT_DIR}/script.sh"

# Test via `:+x` so that `set -x` never prints the API key itself; the
# expansion is "x" when the key is set and non-empty, and empty otherwise.
if [[ -z "${WANDB_API_KEY:+x}" ]]; then
  export WANDB_MODE="offline"
fi

# Pick a random port from the range that does not appear in the 4th column
# (text after the last ':') of `ss -Htan` output.
MASTER_PORT_START=10000
MASTER_PORT_END=65535
MASTER_PORT="$(
  comm -23 \
    <(seq "${MASTER_PORT_START}" "${MASTER_PORT_END}" | sort) \
    <(ss -Htan | awk '{ print $4 }' | awk -F ':' '{ print $NF }' | sort -u) |
    shuf | head -n 1
)"
if [[ -z "${MASTER_PORT}" ]]; then
  echo "Failed to find a free port in ${MASTER_PORT_START}-${MASTER_PORT_END}." >&2
  exit 1
fi

# Launcher-level arguments; --master_port is supplied here and only here.
DEEPSPEED_ARGS=()
if [[ -n "${HOSTFILE+x}" ]]; then
  DEEPSPEED_ARGS+=("--hostfile" "${HOSTFILE}")
fi
DEEPSPEED_ARGS+=("--master_port" "${MASTER_PORT}")

# Mirror stdout/stderr into log files while still showing them on the console.
exec 1> >(tee "${OUTPUT_DIR}/stdout.log" >&1) 2> >(tee "${OUTPUT_DIR}/stderr.log" >&2)

deepspeed "${DEEPSPEED_ARGS[@]}" \
  --module safe_rlhf.algorithms.ppo_lag \
  --train_datasets PKU-SafeRLHF/train \
  --ptx_datasets alpaca \
  --actor_model_name_or_path "${ACTOR_MODEL_NAME_OR_PATH}" \
  --reward_model_name_or_path "${REWARD_MODEL_NAME_OR_PATH}" \
  --reward_critic_model_name_or_path "${REWARD_CRITIC_MODEL_NAME_OR_PATH}" \
  --cost_model_name_or_path "${COST_MODEL_NAME_OR_PATH}" \
  --cost_critic_model_name_or_path "${COST_CRITIC_MODEL_NAME_OR_PATH}" \
  --max_length 512 \
  --temperature 1.0 \
  --num_return_sequences 1 \
  --repetition_penalty 1.0 \
  --trust_remote_code True \
  --epochs 1 \
  --update_iters 1 \
  --per_device_prompt_batch_size 16 \
  --per_device_train_batch_size 16 \
  --gradient_accumulation_steps 1 \
  --actor_lr 1e-5 \
  --actor_weight_decay 0.01 \
  --actor_lr_scheduler_type cosine \
  --actor_lr_warmup_ratio 0.03 \
  --actor_gradient_checkpointing \
  --critic_lr 5e-6 \
  --critic_weight_decay 0.0 \
  --critic_lr_scheduler_type constant \
  --critic_lr_warmup_ratio 0.03 \
  --critic_gradient_checkpointing \
  --normalize_reward False \
  --normalize_cost False \
  --seed 42 \
  --threshold 0.0 \
  --lambda_init 1.0 \
  --lambda_lr 0.1 \
  --lambda_max 5.0 \
  --lambda_update_delay_steps 0 \
  --episode_cost_window_size 128 \
  --kl_coeff 0.01 \
  --clip_range_ratio 0.2 \
  --clip_range_score 50.0 \
  --clip_range_value 5.0 \
  --ptx_coeff 16.0 \
  --output_dir "${OUTPUT_DIR}" \
  --log_type wandb \
  --log_project Safe-RLHF-PPO \
  --zero_stage "${ZERO_STAGE}" \
  --offload "${OFFLOAD}" \
  --fp16 True
# Disabled alternatives to --fp16 kept from the original script:
#   --bf16 True \
#   --tf32 True