
[NEW Feature] Add hook-based refined_recompute support #9396

Merged
20 commits merged into PaddlePaddle:develop on Nov 27, 2024

Conversation

@JunnYu (Member) commented Nov 8, 2024

PR types

New features

PR changes

APIs

Description

  • Currently only the llama, qwen, and qwen2 models are supported.
  • refined_recompute supports the mlp_row_ln, attention_row_ln, attention_column_ln, mlp_column_ln, and flash_attn ops; during LoRA training the *_ln ops are not supported, only flash_attn (see the format sketch after this list).
  • Tested llama model: meta-llama/Meta-Llama-3-8B.
  • refined_recompute currently only takes effect when recompute_use_reentrant=False; it has no effect otherwise.
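
For readers who have not used this flag before, the --refined_recompute values that appear in the scripts below are comma-separated op:num pairs. The helper below is only a hypothetical sketch of that string format (it is not the parser shipped in this PR), and the interpretation of num as "how many layers skip recomputation for this op, with -1 meaning all layers" is my reading of the PR rather than a documented contract.

def parse_refined_recompute(spec: str) -> dict:
    """Hypothetical illustration: "mlp_row_ln:-1,flash_attn:-1" -> {"mlp_row_ln": -1, "flash_attn": -1}."""
    ops = {}
    for item in spec.split(","):
        item = item.strip()
        if not item:
            continue
        name, _, num = item.partition(":")
        ops[name] = int(num)  # assumed: -1 means "apply to all layers"
    return ops

print(parse_refined_recompute("mlp_row_ln:-1,attention_row_ln:-1,flash_attn:-1"))
# {'mlp_row_ln': -1, 'attention_row_ln': -1, 'flash_attn': -1}
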
1. Simple refined_recompute test code
import paddle
from paddle.nn.functional.flash_attention import flashmask_attention

# NOTE: the next two imports are assumptions added so the snippet is self-contained.
# no_recompute is the API introduced by this PR; the exact source of recompute here is
# an assumption (the PR states refined_recompute only takes effect when
# recompute_use_reentrant=False).
from paddle.distributed.fleet.utils import recompute
from paddlenlp.transformers.refined_recompute import no_recompute

paddle.seed(2024)

dtype = "float16"
paddle.set_default_dtype(dtype)

in_weight_shape = (32, 3 * 2 * 32)
linear1 = paddle.nn.Linear(
    in_weight_shape[0],
    in_weight_shape[-1],
)
paddle.seed(2024)
in_weight = paddle.create_parameter(shape=in_weight_shape, dtype=dtype, name="in_weight")
in_weight.set_value(paddle.normal(0, 0.02, in_weight_shape))
in_weight.main_grad = paddle.normal(0, 0.02, in_weight.shape).cast("float32")
linear1.weight.set_value(in_weight)
in_bias = paddle.create_parameter(shape=(in_weight.shape[-1],), dtype=dtype, name="in_bias", is_bias=True)
in_bias.main_grad = paddle.normal(0, 0.02, in_bias.shape).cast("float32")
linear1.bias.set_value(in_bias)
linear1.weight.main_grad = in_weight.main_grad
linear1.bias.main_grad = in_bias.main_grad

out_weight_shape = (2 * 32, 32)
out_weight = paddle.create_parameter(shape=out_weight_shape, dtype=dtype, name="out_weight")
out_weight.set_value(paddle.normal(0, 0.02, out_weight_shape))
out_weight.main_grad = paddle.normal(0, 0.02, out_weight.shape).cast("float32")

class cus_multiply(paddle.autograd.PyLayer):
    @staticmethod
    def forward(ctx, a, b):
        y = paddle.multiply(a, b)
        ctx.save_for_backward(a, b)
        return y

    @staticmethod
    def backward(ctx, dy):
        a, b = ctx.saved_tensor()
        # product rule for y = a * b: da = dy * b, db = dy * a
        grad_a = dy * b
        grad_b = dy * a
        return grad_a, grad_b

multiply = cus_multiply.apply

def fwd(x, startend_row_indices, enable=True):
    def fwd_linear(x):
        weight = multiply(linear1.weight, linear1.weight * 0.1)
        bias = multiply(linear1.bias, linear1.bias * 0.1)
        qkv = paddle.nn.functional.silu(paddle.nn.functional.linear(x, weight, bias))
        q, k, v = paddle.chunk(qkv, 3, axis=-1)
        q = q.reshape([q.shape[0], q.shape[1], 2, q.shape[2] // 2])
        k = k.reshape([k.shape[0], k.shape[1], 2, k.shape[2] // 2])
        v = v.reshape([v.shape[0], v.shape[1], 2, v.shape[2] // 2])
        return q, k, v

    q, k, v = no_recompute(fwd_linear, x, enable=enable)

    q, k, v = q * q, k * k, v * v
    out = no_recompute(
        flashmask_attention,
        q,
        k,
        v,
        startend_row_indices=startend_row_indices,
        causal=True,
        enable=enable,
    )
    out = out.flatten(-2, -1)
    out = paddle.matmul(out, out_weight)
    return out

x = paddle.normal(0, 0.02, (1, 128, 32))
x.stop_gradient = False
x_input = x
startend_row_indices = paddle.randint(0, 128, (1, 2, 128, 1), dtype="int32")

enable = True
# layer 1
o1 = recompute(fwd, x, startend_row_indices, enable=enable)
# layer 2
o2 = recompute(fwd, o1 + x, startend_row_indices, enable=enable)
# layer 3
o3 = recompute(fwd, o2 + x, startend_row_indices, enable=enable)

o3.sum().backward()
print(x_input.grad.mean())
print(linear1.weight.grad.mean())
print(out_weight.grad.mean())
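
A quick sanity check for the snippet above (my suggestion, not part of the original test): run it once with enable = True and once with enable = False and compare the three printed gradient means. They should match, because refined recompute only controls whether the wrapped calls are recomputed during the backward pass; it does not change the computation itself, which is consistent with the bit-identical loss reported in section 2 below.
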
2. llama SFT 8k accuracy comparison (rr on vs. rr off). Conclusion: the loss over 10 steps is exactly identical, so accuracy is unchanged, as expected.

2.1 [Accuracy] refined_recompute off
(screenshot)

2.2 [Accuracy] refined_recompute on, "flash_attn:-1"
(screenshot)

3. llama SFT 8k performance comparison (rr on vs. rr off). Conclusion: step-2 ips ratio is 1.1894 / 1.1636 ≈ 1.022, i.e. roughly a 2.2% speedup.

3.1 [Performance] refined_recompute off
(screenshot)

3.2 [Performance] refined_recompute on, "flash_attn:-1"
(screenshot)

4. PP accuracy test: comparison of no recompute, standard recompute, and RR (refined) recompute

(screenshots)

export NVIDIA_TF32_OVERRIDE=0
export FLAGS_embedding_deterministic=1
export FLAGS_cudnn_deterministic=1

recompute=1
output_dir=ckpt_pp_with_rr

python -u  -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" --log_dir ${output_dir}/logs run_finetune.py \
    ./config/llama/sft_argument.json \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --max_length 1024 \
    --tensor_parallel_degree 2 \
    --pipeline_parallel_degree 4 \
    --virtual_pp_degree 4 \
    --sharding stage1 \
    --sharding_parallel_degree 1 \
    --pipeline_parallel_config "disable_partial_send_recv enable_delay_scale_loss enable_sharding_comm_overlap enable_overlap_p2p_comm" \
    --recompute ${recompute} \
    --recompute_granularity full \
    --refined_recompute "mlp_row_ln:-1,attention_row_ln:-1,attention_column_ln:-1,mlp_column_ln:-1,flash_attn:-1" \
    --sequence_parallel 1 \
    --zero_padding 1 \
    --use_flash_attention 1 \
    --max_steps 20 \
    --autotuner_benchmark 0 \
    --save_strategy no \
    --evaluation_strategy no \
    --output_dir ${output_dir}
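
A brief note on the three exports at the top of this script (my gloss, not from the PR text): NVIDIA_TF32_OVERRIDE=0 disables TF32 math on NVIDIA GPUs, and FLAGS_embedding_deterministic=1 together with FLAGS_cudnn_deterministic=1 switch Paddle to deterministic kernels, so the no-recompute, standard-recompute, and RR-recompute runs can be compared bit-for-bit.
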
5. llama 16k PostPretrain performance comparison code

Speed improves by roughly 6~7%. Because fused head loss has not been added yet, the 32k and 64k configurations cannot be trained; in theory the speedup could be even larger (over 10%).
(screenshots)

export NVIDIA_TF32_OVERRIDE=0

wget https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwebtext_100k.bin
wget https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwebtext_100k.idx
mkdir -p data
mv llama_openwebtext_100k.bin ./data
mv llama_openwebtext_100k.idx ./data
cd ../slm/model_zoo/gpt-3/external_ops/ && python3 setup.py install && cd -

recompute=1
output_dir=ppt/ckpt_pp_w_rr_recompute


python -u  -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" --log_dir ${output_dir}/logs run_pretrain.py \
    ./config/llama/pretrain_argument.json \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --max_seq_length 16384 \
    --tensor_parallel_degree 1 \
    --pipeline_parallel_degree 1 \
    --virtual_pp_degree 1 \
    --sharding stage2 \
    --sharding_parallel_degree 8 \
    --recompute ${recompute} \
    --recompute_granularity full \
    --dataloader_num_workers 4 \
    --refined_recompute "flash_attn:-1" \
    --use_flash_attention 1 \
    --max_steps 20 \
    --autotuner_benchmark 0 \
    --save_strategy no \
    --evaluation_strategy no \
    --output_dir ${output_dir}
6. Added tp+sp vs. tp comparison

(screenshots)


paddle-bot bot commented Nov 8, 2024

Thanks for your contribution!

@CLAassistant commented Nov 8, 2024

CLA assistant check
All committers have signed the CLA.


codecov bot commented Nov 8, 2024

Codecov Report

Attention: Patch coverage is 49.88290% with 214 lines in your changes missing coverage. Please review.

Project coverage is 52.93%. Comparing base (d6d181b) to head (418a259).
Report is 4 commits behind head on develop.

Current head 418a259 differs from pull request most recent head 9f5e306

Please upload reports for the commit 9f5e306 to get more accurate results.

Files with missing lines                       Patch %   Lines missing
paddlenlp/transformers/refined_recompute.py    56.41%    153
paddlenlp/transformers/qwen/modeling.py        13.04%     20
paddlenlp/transformers/llama/modeling.py       17.39%     19
paddlenlp/transformers/qwen2/modeling.py       17.39%     19
paddlenlp/transformers/llama/fusion_ops.py     25.00%      3
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #9396      +/-   ##
===========================================
+ Coverage    52.84%   52.93%   +0.09%     
===========================================
  Files          688      689       +1     
  Lines       109378   109796     +418     
===========================================
+ Hits         57801    58121     +320     
- Misses       51577    51675      +98     


pylayer_matmul = PyLayerMatmul.apply


class BertConfig:
Collaborator

Why not just do this directly in bert?

Member Author

Mainly for testing; bert does not have flash attn.

@JunnYu JunnYu changed the title [Draft] Add refined_recompute support → [Draft] Add hook-based refined_recompute support on Nov 18, 2024
@DesmonDay DesmonDay marked this pull request as ready for review November 19, 2024 03:03
@JunnYu JunnYu changed the title [Draft] Add hook-based refined_recompute support → [NEW Feature] Add hook-based refined_recompute support on Nov 19, 2024
@@ -51,6 +51,7 @@ def swiglu(x, y=None):
except:
flash_attention = None

from paddlenlp.transformers.refined_recompute import no_recompute
Contributor

Why is it called no_recompute? It sounds a bit odd.

Member Author

Renaming it to skip_recompute would also work.

Member Author

recompute(func, xxxxx) vs no_recompute(func, xxxxxx)

@ZHUI (Collaborator) commented Nov 25, 2024

Please also adapt the qwen model.

@JunnYu (Member Author) commented Nov 25, 2024

@ZHUI qwen and qwen2 are now supported.

@@ -268,6 +268,14 @@ class LlmMetaConfig:
"Recompute granularity, Choose among ['full', 'core_attn', 'full_attn']",
),
("recompute_use_reentrant", bool, False, "recompute_use_reentrant"),
Collaborator

Will this configuration be passed down to downstream tasks?

Collaborator

Is _set_unsavable_keys needed?

Member Author

Not needed. zhonghui knows this usage well; I checked the implementation and it can meet the requirement: (1) llmmetaclass is added, (2) LlmMetaConfig.set_llm_config(model_config, training_args) is called.

@dataclass
@llmmetaclass
@add_start_docstrings(TrainingArguments.__doc__)
class TrainingArguments(TrainingArguments):

return output


class RRRowSequenceParallelLinear(RowSequenceParallelLinear):
Collaborator

For RowParallelLinear no code rewrite is needed, but RRRowSequenceParallelLinear needs its own reimplementation?

Member Author

Non-SequenceParallel parallelism is not supported yet; of course we could look into supporting it too.

Member Author

DONE, now supported.

wawltor previously approved these changes Nov 26, 2024

@wawltor (Collaborator) left a comment

LGTM

@wawltor (Collaborator) left a comment

LGTM

@wawltor wawltor merged commit b1466d7 into PaddlePaddle:develop Nov 27, 2024
9 of 12 checks passed