[NEW Feature] Add hook-based refined_recompute support (#9396)
* Implement refined_recompute

* Update the refined_recompute implementation; add unit tests covering pp and non-pp

* update llama and support refined recompute

* update rr

* update

* update create_skip_config_for_refined_recompute config.num_hidden_layers

* update llama pp recompute

* refined recompute only supports recompute_use_reentrant=False

* LOD_TENSOR

* typo

* Support refined_recompute for the Qwen model

* support RRColumnParallelLinear & RRRowParallelLinear

* fix

* update llm test

* fix

* update test_refined_recompute
JunnYu authored Nov 27, 2024
1 parent f5ca96e commit b1466d7
Showing 14 changed files with 1,585 additions and 15 deletions.
5 changes: 5 additions & 0 deletions llm/run_finetune.py
@@ -65,6 +65,7 @@
register_sequence_parallel_allreduce_hooks,
)
from paddlenlp.transformers.configuration_utils import LlmMetaConfig
from paddlenlp.transformers.refined_recompute import update_refined_recompute
from paddlenlp.trl import SFTTrainer
from paddlenlp.trl.llm_utils import (
ZeroPaddingIterDatasetCallback,
@@ -146,6 +147,10 @@ def main():
)

LlmMetaConfig.set_llm_config(model_config, training_args)
model_config.refined_recompute = update_refined_recompute(
training_args.refined_recompute,
model_args.lora,
)
model_config.use_fast_layer_norm = model_args.use_fast_layer_norm

# Config for model using dropout, such as GPT.
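The run-script change above amounts to one extra call after `LlmMetaConfig.set_llm_config`: the user-facing `refined_recompute` string is converted and stored on the model config before the model is built. A minimal sketch of that wiring, assuming a standalone config object; the checkpoint name and the value syntax (`op:num_layers`, with `-1` meaning every layer) are illustrative assumptions rather than something this diff specifies:

```python
# Minimal sketch mirroring the run_finetune.py hunk above; not the exact script flow.
from paddlenlp.transformers import AutoConfig
from paddlenlp.transformers.refined_recompute import update_refined_recompute

model_config = AutoConfig.from_pretrained("facebook/llama-7b")  # placeholder checkpoint
model_config.recompute = True
model_config.recompute_use_reentrant = False  # refined_recompute needs non-reentrant recompute

# Convert the user-facing string into the representation the model config expects.
# The "op:num_layers" value below is an assumed syntax for illustration only.
model_config.refined_recompute = update_refined_recompute("flash_attn:-1,attention_column_ln:-1")
```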
4 changes: 4 additions & 0 deletions llm/run_pretrain.py
@@ -44,6 +44,7 @@
register_sequence_parallel_allreduce_hooks,
)
from paddlenlp.transformers.configuration_utils import LlmMetaConfig, llmmetaclass
from paddlenlp.transformers.refined_recompute import update_refined_recompute
from paddlenlp.utils.batch_sampler import DistributedBatchSampler
from paddlenlp.utils.log import logger
from paddlenlp.utils.tools import get_env_device
@@ -413,6 +414,9 @@ def main():
config = AutoConfig.from_pretrained(model_args.model_name_or_path)
# set all llm config
LlmMetaConfig.set_llm_config(config, training_args)
config.refined_recompute = update_refined_recompute(
training_args.refined_recompute,
)
config.use_fast_layer_norm = model_args.use_fast_layer_norm

config.seq_length = data_args.max_seq_length
8 changes: 8 additions & 0 deletions paddlenlp/transformers/configuration_utils.py
@@ -268,6 +268,14 @@ class LlmMetaConfig:
"Recompute granularity, Choose among ['full', 'core_attn', 'full_attn']",
),
("recompute_use_reentrant", bool, False, "recompute_use_reentrant"),
# refined_recompute attributes
(
"refined_recompute",
str,
"",
"refined_recompute, Choose from 'mlp_row_ln', 'mlp_column_ln', 'attention_row_ln', 'attention_column_ln', 'flash_attn']",
),
("skip_recompute_ops", Optional[Dict[str, int]], None, "skip_recompute_ops"),
]

@classmethod
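The two new meta-config entries are a user-facing string (`refined_recompute`) and a derived `skip_recompute_ops` mapping typed as `Optional[Dict[str, int]]`. As a hedged sketch of how the former could be parsed into the latter — only the op names and the dict type come from the hunk above; the `op:num_layers` syntax and the `-1` = all-layers convention are assumptions:

```python
# Hedged sketch: parse a refined_recompute string into a Dict[str, int].
# Only the op names and the Dict[str, int] type are taken from the diff above;
# the "op:num_layers" syntax and -1 meaning "all layers" are assumptions.
SUPPORTED_OPS = ("mlp_row_ln", "mlp_column_ln", "attention_row_ln", "attention_column_ln", "flash_attn")


def parse_refined_recompute(spec: str) -> dict:
    """E.g. "flash_attn:-1,mlp_row_ln:4" -> {"flash_attn": -1, "mlp_row_ln": 4}."""
    ops = {}
    for item in filter(None, (part.strip() for part in spec.split(","))):
        name, _, count = item.partition(":")
        if name not in SUPPORTED_OPS:
            raise ValueError(f"unknown refined_recompute op: {name!r}")
        ops[name] = int(count) if count else -1
    return ops
```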
14 changes: 11 additions & 3 deletions paddlenlp/transformers/llama/fusion_ops.py
@@ -51,6 +51,7 @@ def swiglu(x, y=None):
except:
flash_attention = None

from paddlenlp.transformers.refined_recompute import no_recompute
from paddlenlp.transformers.ring_flash_attention import RingFlashAttention


@@ -174,6 +175,7 @@ def fusion_flash_attention(
sequence_parallel=False,
reshard_layer=None,
npu_is_casual=False,
skip_recompute=False,
):
bsz, q_len, num_heads, head_dim = query_states.shape
_, kv_seq_len, _, _ = value_states.shape
@@ -257,28 +259,34 @@
attn_mask_startend_row_indices = paddle.unsqueeze(attn_mask_startend_row_indices, axis=1)

if hasattr(F, "flashmask_attention"):
attn_output = F.flashmask_attention(
attn_output = no_recompute(
F.flashmask_attention,
query_states,
key_states,
value_states,
startend_row_indices=attn_mask_startend_row_indices.unsqueeze(-1),
causal=True,
enable=skip_recompute,
)
else:
attn_output = F.flash_attention_with_sparse_mask(
attn_output = no_recompute(
F.flash_attention_with_sparse_mask,
query_states,
key_states,
value_states,
attn_mask_start_row_indices=attn_mask_startend_row_indices,
is_causal=True,
enable=skip_recompute,
)
else:
attn_output = F.scaled_dot_product_attention(
attn_output = no_recompute(
F.scaled_dot_product_attention,
query_states,
key_states,
value_states,
attn_mask=attention_mask,
is_causal=query_states.shape[1] != 1,
enable=skip_recompute,
)
attn_weights = None

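Every branch of `fusion_flash_attention` above now routes its attention kernel through `no_recompute`, passing the kernel as the first argument, keeping the original keyword arguments, and adding `enable=skip_recompute`. A degenerate stand-in that shows only this calling contract; the real helper in `paddlenlp.transformers.refined_recompute` is additionally expected to exempt the wrapped op from backward-time recomputation when `enable=True`, which is not reproduced here:

```python
# Stand-in for the call pattern used in fusion_flash_attention above.
# enable=False -> plain pass-through, identical to fn(*args, **kwargs).
# enable=True  -> same result, but the real helper keeps fn's output so it is
#                 not recomputed during the backward pass (machinery omitted).
def no_recompute_sketch(fn, *args, enable=False, **kwargs):
    return fn(*args, **kwargs)


# Usage mirroring the diff (query/key/value/mask being paddle tensors):
#   attn_output = no_recompute_sketch(
#       F.scaled_dot_product_attention, query_states, key_states, value_states,
#       attn_mask=attention_mask, is_causal=False, enable=skip_recompute,
#   )
```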
54 changes: 52 additions & 2 deletions paddlenlp/transformers/llama/modeling.py
@@ -29,7 +29,15 @@
from paddle.autograd import PyLayer
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
from paddle.distributed.fleet.utils import recompute

from paddlenlp.transformers.refined_recompute import (
RRColumnParallelLinear,
RRColumnSequenceParallelLinear,
RRRowParallelLinear,
RRRowSequenceParallelLinear,
create_skip_config_for_refined_recompute,
recompute,
)

try:
from paddle.incubate.nn.functional import fused_rotary_position_embedding
@@ -216,6 +224,7 @@ def scaled_dot_product_attention(
sequence_parallel=False,
reshard_layer=None,
npu_is_casual=False,
skip_recompute=False,
):
bsz, q_len, num_heads, head_dim = query_states.shape
_, kv_seq_len, _, _ = value_states.shape
@@ -233,6 +242,7 @@
sequence_parallel,
reshard_layer,
npu_is_casual,
skip_recompute=skip_recompute,
)

# Paddle Flash Attention input [ bz, seqlen, nhead, head_dim]
@@ -605,10 +615,24 @@ def __init__(self, config):
if config.sequence_parallel:
ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
RowParallelLinear = linear_utils.RowSequenceParallelLinear

# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if config.recompute and not config.recompute_use_reentrant:
if config.skip_recompute_ops.get("mlp_column_ln", False):
ColumnParallelLinear = RRColumnSequenceParallelLinear
if config.skip_recompute_ops.get("mlp_row_ln", False):
RowParallelLinear = RRRowSequenceParallelLinear
else:
ColumnParallelLinear = linear_utils.ColumnParallelLinear
RowParallelLinear = linear_utils.RowParallelLinear

# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if config.recompute and not config.recompute_use_reentrant:
if config.skip_recompute_ops.get("mlp_column_ln", False):
ColumnParallelLinear = RRColumnParallelLinear
if config.skip_recompute_ops.get("mlp_row_ln", False):
RowParallelLinear = RRRowParallelLinear

if config.tensor_parallel_degree > 1:
if config.fuse_attention_ffn:
self.gate_up_fused_proj = ColumnParallelLinear(
@@ -719,9 +743,22 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
if config.sequence_parallel:
ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
RowParallelLinear = linear_utils.RowSequenceParallelLinear

# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if config.recompute and not config.recompute_use_reentrant:
if config.skip_recompute_ops.get("attention_column_ln", False):
ColumnParallelLinear = RRColumnSequenceParallelLinear
if config.skip_recompute_ops.get("attention_row_ln", False):
RowParallelLinear = RRRowSequenceParallelLinear
else:
ColumnParallelLinear = linear_utils.ColumnParallelLinear
RowParallelLinear = linear_utils.RowParallelLinear
# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if config.recompute and not config.recompute_use_reentrant:
if config.skip_recompute_ops.get("attention_column_ln", False):
ColumnParallelLinear = RRColumnParallelLinear
if config.skip_recompute_ops.get("attention_row_ln", False):
RowParallelLinear = RRRowParallelLinear

if config.tensor_parallel_degree > 1:
if self.fuse_attention_qkv:
@@ -821,6 +858,14 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):

self.attn_func = scaled_dot_product_attention

# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if (
config.recompute
and not config.recompute_use_reentrant
and config.skip_recompute_ops.get("flash_attn", False)
):
self.attn_func = partial(scaled_dot_product_attention, skip_recompute=True)

def _init_rope(self):
if (
hasattr(self.config, "rope_scaling")
@@ -1471,7 +1516,12 @@ def __init__(self, config: LlamaConfig):
)

self.layers = nn.LayerList(
[LlamaDecoderLayer(config, i not in self.no_recompute_layers) for i in range(config.num_hidden_layers)]
[
LlamaDecoderLayer(
create_skip_config_for_refined_recompute(i, config), i not in self.no_recompute_layers
)
for i in range(config.num_hidden_layers)
]
)
self.norm = LlamaRMSNorm(config)

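`LlamaModel` now hands each decoder layer a per-layer config built by `create_skip_config_for_refined_recompute(i, config)`, and the attention/MLP constructors above only consult boolean flags in `config.skip_recompute_ops`. A hedged sketch of what such a helper could look like, assuming that `config.refined_recompute` has already been parsed into an op-to-int dict and that the integer counts how many leading layers skip an op, with `-1` meaning all of them — that layer-selection rule is an assumption, not something this diff confirms:

```python
# Hedged sketch of a per-layer skip config, under the assumptions stated above.
import copy


def create_skip_config_sketch(layer_idx, config):
    skip_config = copy.deepcopy(config)
    skip_config.skip_recompute_ops = {}
    # Matches the NOTE comments above: without non-reentrant recompute the
    # feature is a no-op and no ops are skipped.
    if not (config.recompute and not config.recompute_use_reentrant):
        return skip_config
    for op, num_layers in (getattr(config, "refined_recompute", None) or {}).items():
        skip_config.skip_recompute_ops[op] = num_layers == -1 or layer_idx < num_layers
    return skip_config
```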
11 changes: 9 additions & 2 deletions paddlenlp/transformers/llama/modeling_pp.py
@@ -22,9 +22,12 @@
PipelineLayer,
SharedLayerDesc,
)
from paddle.distributed.fleet.utils import recompute

from paddlenlp.transformers.model_utils import PipelinePretrainedModel
from paddlenlp.transformers.refined_recompute import (
create_skip_config_for_refined_recompute,
recompute,
)
from paddlenlp.utils.tools import get_env_device

from .modeling import (
@@ -371,7 +374,11 @@ def get_hcg():

for i in range(config.num_hidden_layers):
self.add_sequential_layer(
LayerDesc(LlamaDecoderLayerPipe, config=config, layerwise_recompute=i not in self.no_recompute_layers),
LayerDesc(
LlamaDecoderLayerPipe,
config=create_skip_config_for_refined_recompute(i, config),
layerwise_recompute=i not in self.no_recompute_layers,
),
f"llama.layers.{i}",
)
self.add_sequential_layer(LayerDesc(LlamaRMSNormPipe, config=config), "llama")
48 changes: 45 additions & 3 deletions paddlenlp/transformers/qwen/modeling.py
@@ -24,9 +24,18 @@
from paddle import Tensor, nn
from paddle.distributed import fleet
from paddle.distributed.fleet.layers.mpu.random import get_rng_state_tracker
from paddle.distributed.fleet.utils import recompute
from paddle.utils import try_import

from paddlenlp.transformers.refined_recompute import (
RRColumnParallelLinear,
RRColumnSequenceParallelLinear,
RRRowParallelLinear,
RRRowSequenceParallelLinear,
create_skip_config_for_refined_recompute,
no_recompute,
recompute,
)

try:
from paddle.incubate.nn.functional import swiglu
except ImportError:
@@ -154,9 +163,22 @@ def __init__(self, config):
if config.sequence_parallel:
ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
RowParallelLinear = linear_utils.RowSequenceParallelLinear

# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if config.recompute and not config.recompute_use_reentrant:
if config.skip_recompute_ops.get("attention_column_ln", False):
ColumnParallelLinear = RRColumnSequenceParallelLinear
if config.skip_recompute_ops.get("attention_row_ln", False):
RowParallelLinear = RRRowSequenceParallelLinear
else:
ColumnParallelLinear = linear_utils.ColumnParallelLinear
RowParallelLinear = linear_utils.RowParallelLinear
# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if config.recompute and not config.recompute_use_reentrant:
if config.skip_recompute_ops.get("attention_column_ln", False):
ColumnParallelLinear = RRColumnParallelLinear
if config.skip_recompute_ops.get("attention_row_ln", False):
RowParallelLinear = RRRowParallelLinear

if config.tensor_parallel_degree > 1:
if config.num_attention_heads % config.tensor_parallel_degree != 0:
@@ -227,12 +249,19 @@ def _attn(self, query, key, value, attention_mask=None):
return_softmax=self.config.attn_dropout_prob > 0.0,
)
else:
attn_output = F.scaled_dot_product_attention(
skip_recompute = (
self.config.recompute
and not self.config.recompute_use_reentrant
and self.config.skip_recompute_ops.get("flash_attn", False)
)
attn_output = no_recompute(
F.scaled_dot_product_attention,
query,
key,
value,
attn_mask=attention_mask,
is_causal=attention_mask is None,
enable=skip_recompute,
)
attn_weights = None

@@ -388,9 +417,22 @@ def __init__(self, config):
if config.sequence_parallel:
ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
RowParallelLinear = linear_utils.RowSequenceParallelLinear

# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if config.recompute and not config.recompute_use_reentrant:
if config.skip_recompute_ops.get("mlp_column_ln", False):
ColumnParallelLinear = RRColumnSequenceParallelLinear
if config.skip_recompute_ops.get("mlp_row_ln", False):
RowParallelLinear = RRRowSequenceParallelLinear
else:
ColumnParallelLinear = linear_utils.ColumnParallelLinear
RowParallelLinear = linear_utils.RowParallelLinear
# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if config.recompute and not config.recompute_use_reentrant:
if config.skip_recompute_ops.get("mlp_column_ln", False):
ColumnParallelLinear = RRColumnParallelLinear
if config.skip_recompute_ops.get("mlp_row_ln", False):
RowParallelLinear = RRRowParallelLinear

if config.tensor_parallel_degree > 1:
if self.fuse_attention_ffn:
@@ -684,7 +726,7 @@ def __init__(self, config):
self.h = nn.LayerList(
[
QWenBlock(
config,
create_skip_config_for_refined_recompute(i, config),
)
for i in range(config.num_hidden_layers)
]
5 changes: 4 additions & 1 deletion paddlenlp/transformers/qwen/modeling_pp.py
@@ -18,6 +18,9 @@
from paddle.distributed.fleet.meta_parallel import LayerDesc, PipelineLayer

from paddlenlp.transformers.model_utils import PipelinePretrainedModel
from paddlenlp.transformers.refined_recompute import (
create_skip_config_for_refined_recompute,
)

from .modeling import (
QWenBlock,
@@ -170,7 +173,7 @@ def get_hcg():
self.add_sequential_layer(LayerDesc(QWenEmbeddingPipe, config=config), "qwen")
for i in range(config.num_hidden_layers):
self.add_sequential_layer(
LayerDesc(QWenBlockPipe, config=config),
LayerDesc(QWenBlockPipe, config=create_skip_config_for_refined_recompute(i, config)),
f"qwen.h.{i}",
)
self.add_sequential_layer(LayerDesc(QWenRMSNormPipe, config=config), "qwen.ln_f")
