Original paper: https://arxiv.org/abs/2406.16768.
Report with experiment results: link.
Checkpoints folder: https://drive.google.com/drive/folders/1iZ7S603cg9yZ6q1kW0EgSuYKafsG-uCY.
The goal is to verify whether the algorithm works on the IMDB reviews dataset: we fine-tune the model to generate more positive reviews.
Python version: 3.12.4 (3.12.* versions should be fine). Dependencies can be installed with the command:
pip install -r requirements.txt
To run the experiments, you need at least one GPU. All the training scripts log to WANDB, so either set offline mode (via the environment variable WANDB_MODE='offline') or log in locally (run wandb login and provide your API key).
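If it's more convenient, the WANDB mode can also be set from Python before launching training. A minimal sketch (the API key is a placeholder):

# Minimal sketch: configure WANDB before any training code imports it.
import os

# Option 1: keep all logs local; they can be synced later with wandb sync.
os.environ["WANDB_MODE"] = "offline"

# Option 2: authenticate programmatically instead of running wandb login.
# import wandb
# wandb.login(key="YOUR_API_KEY")  # placeholder key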
distilbert-base-cased is used as a reward model. Details about reward modeling can be found in the report. To train a reward model, run the script from the project's root folder:
python -m scripts.train_reward
To see the full list of supported arguments, run:
python -m scripts.train_reward --help
In the later experiments, we used the reward model from this run.
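For intuition about how the trained reward model is used, here is a minimal scoring sketch. It assumes the checkpoint is a binary sentiment classifier whose positive class has index 1; the checkpoint path is a placeholder, and the real training logic lives in scripts/train_reward and the report.

# Sketch: scoring generated texts with a DistilBERT-based reward model.
# "path/to/reward_checkpoint" is a placeholder, not a path in this repo.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased")
model = AutoModelForSequenceClassification.from_pretrained("path/to/reward_checkpoint")
model.eval()

texts = ["This movie was an absolute delight!", "A dull, lifeless mess."]
batch = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

with torch.no_grad():
    logits = model(**batch).logits

# Assumption: the positive-sentiment class has index 1; its logit is the reward.
rewards = logits[:, 1]
print(rewards.tolist())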
lvwerra/gpt2-imdb is used as the reference model. Because LoRA is used for training during the RLHF stage, the plain lvwerra/gpt2-imdb model can't be used directly (it has no LoRA layers). So before the RLHF stage, we fine-tune the model with LoRA for a number of steps:
python -m scripts.train_sft_model --fp16 --per_device_train_batch_size=32 --max_steps=4000
To see the full list of supported arguments, run:
python -m scripts.train_sft_model --help
In the later experiments, we used the SFT model from this run.
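To illustrate what "fine-tune with LoRA" means here, the sketch below attaches LoRA adapters to lvwerra/gpt2-imdb with peft. The rank, alpha, dropout, and target modules are illustrative guesses, not the defaults used by scripts/train_sft_model; check the script's arguments for the real configuration.

# Sketch: wrapping lvwerra/gpt2-imdb with LoRA adapters before SFT.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("lvwerra/gpt2-imdb")
lora_config = LoraConfig(
    r=16,                        # illustrative rank, not the repo's default
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["c_attn"],   # GPT-2's fused attention projection
    task_type="CAUSAL_LM",
)
model = get_peft_model(base, lora_config)
model.print_trainable_parameters()  # only the LoRA weights require gradients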
lvwerra/gpt2-imdb is fine-tuned with RLHF to generate more positive reviews. To fine-tune the model, run the script:
python -m scripts.train_policy --learning_rate=0.0001 --max_new_tokens=128 --warmup_steps=20 --num_iterations=3
To see the full list of supported arguments, run:
python -m scripts.train_policy --help
If you logged in to WANDB, the script will log runs into the group specified with the --group_name parameter (warp by default). If you run the script again with the same --group_name, all new runs will go to the same group, so it's recommended to use a new group name every time (or delete the previous runs from the group). An example experiment is here.
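For intuition about what the policy script implements: each WARP iteration runs several RL fine-tunings from the same initialization, merges the resulting policies by spherically interpolating their task vectors (the weight deltas from that initialization), and then interpolates the merged weights back toward the initialization. The sketch below shows that merging step for two policies; it is a simplified illustration, not the repo's actual implementation (see scripts/train_policy and the report for the details).

# Sketch: merge two fine-tuned policies via SLERP on their task vectors,
# then interpolate the result back toward the initialization (the LITI step).
import torch


def slerp(v0: torch.Tensor, v1: torch.Tensor, t: float = 0.5) -> torch.Tensor:
    """Spherical interpolation between two flattened task vectors."""
    n0, n1 = v0.norm(), v1.norm()
    if n0 < 1e-8 or n1 < 1e-8:
        return (1 - t) * v0 + t * v1  # degenerate case: fall back to LERP
    omega = torch.acos(torch.clamp(torch.dot(v0 / n0, v1 / n1), -1.0, 1.0))
    if omega < 1e-7:
        return (1 - t) * v0 + t * v1  # nearly parallel: fall back to LERP
    return (torch.sin((1 - t) * omega) * v0 + torch.sin(t * omega) * v1) / torch.sin(omega)


def warp_merge(init_sd, policy_sds, eta: float = 0.5):
    """SLERP-merge two policies' task vectors, then move a fraction eta from init."""
    merged = {}
    for name, w_init in init_sd.items():
        d0 = (policy_sds[0][name] - w_init).flatten().float()
        d1 = (policy_sds[1][name] - w_init).flatten().float()
        delta = slerp(d0, d1).reshape(w_init.shape)
        merged[name] = w_init + eta * delta  # LITI toward the merged policy
    return merged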
We use RLOO as an alternative to the plain REINFORCE used in the original paper; this leads to a better KL/reward trade-off (see the report). The run is here. To reproduce the results, run the script:
python -m scripts.train_policy --save_folder=warp_rloo --group_name=warp_rloo --learning_rate=0.0001 --max_new_tokens=128 --warmup_steps=20 --per_device_train_batch_size=16 --per_device_train_eval_size=4 --num_return_sequences=4
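For reference, RLOO replaces the REINFORCE baseline with a leave-one-out baseline: each completion's reward is compared against the mean reward of the other completions sampled for the same prompt (the --num_return_sequences samples above). A minimal sketch of the advantage computation, as an illustration rather than the repo's exact code:

# Sketch: RLOO (leave-one-out) advantages for k completions per prompt.
import torch


def rloo_advantages(rewards: torch.Tensor) -> torch.Tensor:
    """rewards: (batch, k) tensor of per-completion rewards; requires k > 1."""
    _, k = rewards.shape
    total = rewards.sum(dim=1, keepdim=True)    # (batch, 1)
    baseline = (total - rewards) / (k - 1)      # mean reward of the other k-1 samples
    return rewards - baseline


rewards = torch.tensor([[0.9, 0.2, 0.7, 0.4]])
print(rloo_advantages(rewards))  # higher-reward completions get positive advantages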
Details about parameter choices and experiment results are in the report. Plots can be reproduced with the notebook. An example of how to run the code in Kaggle's environment (with 2x Tesla T4) is here.