WARP: On the Benefits of Weight Averaged Rewarded Policies

Original paper: https://arxiv.org/abs/2406.16768.

Report with experiment results: link.

Checkpoints folder: https://drive.google.com/drive/folders/1iZ7S603cg9yZ6q1kW0EgSuYKafsG-uCY.

The goal is to verify whether the algorithm works on the IMDB reviews dataset. We fine-tune the model to generate more positive reviews.

How to run

Environment

Python version: 3.12.4 (3.12.* versions should be fine). Dependencies can be installed with the command:

pip install -r requirements.txt

To run the experiments, you need at least one GPU. All training scripts log to WANDB, so either set the mode to offline (using the environment variable WANDB_MODE=offline) or log in locally (run wandb login and provide your API key).
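For example (a minimal sketch; the scripts read the same environment variable, so exporting it in the shell works just as well), offline mode can be set from Python before wandb is initialized:

import os

# Store runs locally instead of uploading them; they can be synced later with `wandb sync`.
os.environ["WANDB_MODE"] = "offline"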

Reward model

distilbert-base-cased is used as a reward model. Details about reward modeling can be found in the report. To train a reward model, run the script from the project's root folder:

python -m scripts.train_reward

To see the full list of supported arguments, run:

python -m scripts.train_reward --help

In the later experiments, we used the reward model from this run.
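As a rough sketch of how such a reward model is typically used (this is not the code from scripts.train_reward; the checkpoint path and the single-logit head are assumptions), the trained distilbert-base-cased classifier can score generated reviews like this:

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Assumption: the reward model is a standard HF sequence classifier with one output logit.
checkpoint = "distilbert-base-cased"  # replace with your trained reward model checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
reward_model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=1)
reward_model.eval()

texts = ["This movie was an absolute delight from start to finish."]
batch = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    rewards = reward_model(**batch).logits.squeeze(-1)  # one scalar reward per review
print(rewards)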

SFT model

lvwerra/gpt2-imdb is used as the reference model. Because we use LoRA during the RLHF stage, we can't use the plain lvwerra/gpt2-imdb model directly (it has no LoRA layers). So before the RLHF stage, we fine-tune the model with LoRA for a number of iterations:

python -m scripts.train_sft_model --fp16 --per_device_train_batch_size=32 --max_steps=4000

To see the full list of supported arguments, run:

python -m scripts.train_sft_model --help

In the later experiments, we used the SFT model from this run.
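For reference, the LoRA wrapping that makes this extra SFT step necessary can be sketched with peft as follows; the rank, alpha and target modules below are illustrative defaults, not necessarily what scripts.train_sft_model uses:

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Load the GPT-2 model fine-tuned on IMDB that serves as the reference policy.
base_model = AutoModelForCausalLM.from_pretrained("lvwerra/gpt2-imdb")

# Assumed LoRA hyperparameters; the actual script may use different values.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["c_attn"],  # GPT-2 fused attention projection
    task_type="CAUSAL_LM",
)
policy = get_peft_model(base_model, lora_config)
policy.print_trainable_parameters()  # only the LoRA adapter weights are trainable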

RLHF

WARP

lvwerra/gpt2-imdb is fine-tuned with RLHF to generate more positive reviews. To fine-tune the model, run the script:

python -m scripts.train_policy --learning_rate=0.0001 --max_new_tokens=128 --warmup_steps=20 --num_iterations=3

To see the full list of supported arguments, run:

python -m scripts.train_policy --help

If you are logged in to WANDB, the script logs runs into the group specified with the --group_name parameter (warp by default). If you run the script again with the same --group_name, the new runs will go to the same group, so it's recommended to use a new group name every time (or delete the previous runs from the group). An example experiment is here.
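Conceptually, each WARP iteration trains several rewarded policies from the same initialization, merges their weights (the paper uses SLERP; a plain average is shown here for brevity), and then interpolates the merged weights back towards the initialization (the LITI step). A minimal state-dict sketch of these two operations, independent of the actual implementation in scripts.train_policy and with an illustrative eta value:

import torch

def average_policies(state_dicts):
    # Uniform average of several fine-tuned policies' weights (the paper merges with SLERP).
    return {key: torch.stack([sd[key].float() for sd in state_dicts]).mean(dim=0)
            for key in state_dicts[0]}

def liti(init_state, merged_state, eta=0.5):
    # LITI step: interpolate between the initialization and the merged policy.
    return {key: (1.0 - eta) * init_state[key].float() + eta * merged_state[key].float()
            for key in init_state}

# Usage sketch (policy_1, policy_2 and sft_model are hypothetical models from one iteration):
# merged = average_policies([policy_1.state_dict(), policy_2.state_dict()])
# next_init = liti(sft_model.state_dict(), merged, eta=0.5)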

WARP with RLOO

We use RLOO as an alternative to the plain REINFORCE used in the original paper. This leads to a better KL/reward trade-off (see the report). The run is here. Run the script to reproduce the results:

python -m scripts.train_policy --save_folder=warp_rloo --group_name=warp_rloo --learning_rate=0.0001 --max_new_tokens=128 --warmup_steps=20 --per_device_train_batch_size=16 --per_device_train_eval_size=4 --num_return_sequences=4
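RLOO replaces REINFORCE's baseline with a leave-one-out baseline computed from the other completions sampled for the same prompt (matching --num_return_sequences=4 above). A minimal sketch of the advantage computation, detached from the actual training loop:

import torch

def rloo_advantages(rewards: torch.Tensor) -> torch.Tensor:
    # rewards has shape (num_prompts, k): one reward per sampled completion,
    # with k completions drawn for each prompt.
    k = rewards.shape[1]
    # Baseline for each sample is the mean reward of the other k - 1 samples for the same prompt.
    baseline = (rewards.sum(dim=1, keepdim=True) - rewards) / (k - 1)
    return rewards - baseline

# Example: 2 prompts, 4 completions each.
rewards = torch.tensor([[0.1, 0.5, 0.3, 0.9],
                        [0.2, 0.2, 0.8, 0.4]])
print(rloo_advantages(rewards))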

Results

Details about parameter choices and experiment results are in the report. Plots can be reproduced with the notebook. An example of how to run the code in Kaggle's environment (with 2x Tesla T4) is here.
