From 5217a3b79524a30ea37c75d37eda1a257052e22d Mon Sep 17 00:00:00 2001 From: chen <103103266+ckl117@users.noreply.github.com> Date: Mon, 4 Nov 2024 15:34:29 +0800 Subject: [PATCH] [Inference] Append attn FP8 quant (#9328) * add fp8 gen files to gitignore * append_attn support fp8 quant * Unified FP8 Network * include cuda_fp8.h * simplify qwen2 network and FusedBlockMultiTransformerFP8 * simplify llama network and code check * check fp8 params * code check * check * default config for fp8 gemm --- .gitignore | 3 + csrc/gpu/append_attention.cu | 287 ++++-- .../append_attn/append_attention_c16_impl.cuh | 34 + .../append_attn/append_attention_c4_impl.cuh | 34 + .../append_attn/append_attention_c8_impl.cuh | 34 + .../gpu/append_attn/append_attention_func.cuh | 49 +- .../gpu/append_attn/append_attention_kernel.h | 14 + ..._attention_c16_bfloat16_bfloat16_kernel.cu | 2 + ...ppend_attention_c16_bfloat16_fp8_kernel.cu | 58 ++ ...pend_attention_c16_bfloat16_int8_kernel.cu | 2 + ...nd_attention_c16_float16_float16_kernel.cu | 2 + ...append_attention_c16_float16_fp8_kernel.cu | 58 ++ ...ppend_attention_c16_float16_int8_kernel.cu | 2 + ...d_attention_c4_bfloat16_bfloat16_kernel.cu | 2 + ...append_attention_c4_bfloat16_fp8_kernel.cu | 58 ++ ...ppend_attention_c4_bfloat16_int8_kernel.cu | 2 + ...end_attention_c4_float16_float16_kernel.cu | 2 + .../append_attention_c4_float16_fp8_kernel.cu | 58 ++ ...append_attention_c4_float16_int8_kernel.cu | 2 + ...d_attention_c8_bfloat16_bfloat16_kernel.cu | 2 + ...append_attention_c8_bfloat16_fp8_kernel.cu | 58 ++ ...ppend_attention_c8_bfloat16_int8_kernel.cu | 2 + ...end_attention_c8_float16_float16_kernel.cu | 2 + .../append_attention_c8_float16_fp8_kerne.cu | 58 ++ .../append_attention_c8_float16_int8_kerne.cu | 2 + csrc/gpu/append_attn/utils.cuh | 8 +- csrc/setup_cuda.py | 7 +- ...uto_gen_fp8_fp8_dual_gemm_fused_kernels.py | 52 +- .../auto_gen_fp8_fp8_gemm_fused_kernels.py | 52 +- .../transformers/fused_transformer_layers.py | 658 ++++--------- .../transformers/llama/modeling.py | 878 +++++++---------- .../transformers/qwen2/modeling.py | 896 ++++++------------ 32 files changed, 1656 insertions(+), 1722 deletions(-) create mode 100644 csrc/gpu/append_attn/template_instantiation/append_attention_c16_bfloat16_fp8_kernel.cu create mode 100644 csrc/gpu/append_attn/template_instantiation/append_attention_c16_float16_fp8_kernel.cu create mode 100644 csrc/gpu/append_attn/template_instantiation/append_attention_c4_bfloat16_fp8_kernel.cu create mode 100644 csrc/gpu/append_attn/template_instantiation/append_attention_c4_float16_fp8_kernel.cu create mode 100644 csrc/gpu/append_attn/template_instantiation/append_attention_c8_bfloat16_fp8_kernel.cu create mode 100644 csrc/gpu/append_attn/template_instantiation/append_attention_c8_float16_fp8_kerne.cu diff --git a/.gitignore b/.gitignore index 8ac817c65d68..386183b26a39 100644 --- a/.gitignore +++ b/.gitignore @@ -129,3 +129,6 @@ FETCH_HEAD csrc/third_party/ dataset/ output/ + +# gen codes +autogen/ \ No newline at end of file diff --git a/csrc/gpu/append_attention.cu b/csrc/gpu/append_attention.cu index badf28d7676c..f80f8cee5d3d 100644 --- a/csrc/gpu/append_attention.cu +++ b/csrc/gpu/append_attention.cu @@ -56,6 +56,8 @@ std::vector AppendAttentionKernel( const std::string& cache_quant_type_str, const bool use_neox_rotary_style, const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, const float out_linear_in_scale, const int encoder_block_shape_q, const int decoder_block_shape_q, @@ -77,7 
+79,7 @@ std::vector AppendAttentionKernel( auto main_stream = qkv.stream(); static cudaEvent_t main_event; - static cudaEvent_t decoder_event; + static cudaEvent_t decoder_event; static cudaStream_t decoder_stream; static bool init_flag = false; if (max_enc_len_this_time_data > 0 && max_dec_len_this_time_data > 0 && @@ -96,10 +98,20 @@ std::vector AppendAttentionKernel( } paddle::Tensor fmha_out; if (out_linear_in_scale > 0.0) { - fmha_out = GetEmptyTensor( + if (fabs(quant_max_bound - 127.0f) < 0.000001) { + fmha_out = GetEmptyTensor( {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims}, paddle::DataType::INT8, qkv.place()); + } + else if (fabs(quant_max_bound - 448.0f) < 0.000001) { + fmha_out = GetEmptyTensor( + {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims}, + paddle::DataType::FLOAT8_E4M3FN, + qkv.place()); + }else{ + PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0']."); + } } else { fmha_out = GetEmptyTensor( {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims}, @@ -167,40 +179,90 @@ std::vector AppendAttentionKernel( const_cast(&value_cache)); } if (out_linear_in_scale > 0.0) { - CascadeAppendAttentionKernel( - meta_data, - qkv_out, - key_cache, - value_cache, - attn_mask, - cache_k_dequant_scales, - cache_v_dequant_scales, - cache_k_zp, - cache_v_zp, - out_linear_shifts, - out_linear_smooths, - seq_lens_this_time, - seq_lens_decoder, - seq_lens_encoder, - padding_offsets, - cum_offsets, - block_tables, - encoder_batch_ids, - encoder_tile_ids_per_batch, - cache_quant_type_str, - encoder_num_blocks_data, - encoder_block_shape_q, - max_input_length, - max_enc_len_this_time_data, - out_linear_in_scale, - max_partition_size, - encoder_max_partition_size, - speculate_max_draft_token_num, - causal, - false, - true, - main_stream, - &fmha_out); + switch (fmha_out.dtype()) { + case paddle::DataType::INT8:{ + CascadeAppendAttentionKernel( + meta_data, + qkv_out, + key_cache, + value_cache, + attn_mask, + cache_k_dequant_scales, + cache_v_dequant_scales, + cache_k_zp, + cache_v_zp, + out_linear_shifts, + out_linear_smooths, + seq_lens_this_time, + seq_lens_decoder, + seq_lens_encoder, + padding_offsets, + cum_offsets, + block_tables, + encoder_batch_ids, + encoder_tile_ids_per_batch, + cache_quant_type_str, + encoder_num_blocks_data, + encoder_block_shape_q, + max_input_length, + max_enc_len_this_time_data, + quant_max_bound, + quant_min_bound, + out_linear_in_scale, + max_partition_size, + encoder_max_partition_size, + speculate_max_draft_token_num, + causal, + false, + true, + main_stream, + &fmha_out); + break; + } + case paddle::DataType::FLOAT8_E4M3FN:{ + CascadeAppendAttentionKernel( + meta_data, + qkv_out, + key_cache, + value_cache, + attn_mask, + cache_k_dequant_scales, + cache_v_dequant_scales, + cache_k_zp, + cache_v_zp, + out_linear_shifts, + out_linear_smooths, + seq_lens_this_time, + seq_lens_decoder, + seq_lens_encoder, + padding_offsets, + cum_offsets, + block_tables, + encoder_batch_ids, + encoder_tile_ids_per_batch, + cache_quant_type_str, + encoder_num_blocks_data, + encoder_block_shape_q, + max_input_length, + max_enc_len_this_time_data, + quant_max_bound, + quant_min_bound, + out_linear_in_scale, + max_partition_size, + encoder_max_partition_size, + speculate_max_draft_token_num, + causal, + false, + true, + main_stream, + &fmha_out); + break; + } + default:{ + PD_THROW("Only supported output fmha_out of quant dtype in ['int8', 'fp8_e4m3']."); + break; + } + } } else { CascadeAppendAttentionKernel( 
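Note: with the two new `quant_max_bound` / `quant_min_bound` attributes, the quantized attention output dtype is no longer hard-wired to INT8; the value of `quant_max_bound` selects it, and the kernel is then dispatched on `fmha_out.dtype()`. A minimal sketch of that mapping (the helper name below is ours, not the PR's; the real code compares `paddle::DataType` values and raises via `PD_THROW`):

```python
import math

def select_quant_out_dtype(quant_max_bound: float) -> str:
    """Sketch of the dtype choice added to AppendAttentionKernel and
    AppendAttentionInferDtype: 127.0 -> int8, 448.0 -> float8_e4m3fn."""
    # Float attributes are compared with a small tolerance rather than ==.
    if math.isclose(quant_max_bound, 127.0, abs_tol=1e-6):
        return "int8"
    if math.isclose(quant_max_bound, 448.0, abs_tol=1e-6):
        return "float8_e4m3fn"
    raise ValueError("quant_max_bound must be 127.0 (int8) or 448.0 (fp8 e4m3)")
```

448 is the largest finite FP8 E4M3 value, just as 127 is for int8, so the same attribute later doubles as the clamp limit in the store path.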
meta_data, @@ -227,6 +289,8 @@ std::vector AppendAttentionKernel( encoder_block_shape_q, max_input_length, max_enc_len_this_time_data, + quant_max_bound, + quant_min_bound, out_linear_in_scale, max_partition_size, encoder_max_partition_size, @@ -346,40 +410,91 @@ std::vector AppendAttentionKernel( } if (out_linear_in_scale > 0.0) { - CascadeAppendAttentionKernel( - meta_data, - qkv_out, - key_cache, - value_cache, - attn_mask, - cache_k_dequant_scales, - cache_v_dequant_scales, - cache_k_zp, - cache_v_zp, - out_linear_shifts, - out_linear_smooths, - seq_lens_this_time, - seq_lens_decoder, - seq_lens_encoder, - padding_offsets, - cum_offsets, - block_tables, - decoder_batch_ids, - decoder_tile_ids_per_batch, - cache_quant_type_str, - decoder_num_blocks_data, - decoder_block_shape_q, - max_input_length, - max_len_kv_data, - out_linear_in_scale, - max_partition_size, - encoder_max_partition_size, - speculate_max_draft_token_num, - causal, - !speculate_decoder, - !speculate_decoder, - exec_stream, - &fmha_out); + switch (fmha_out.dtype()) { + case paddle::DataType::INT8:{ + CascadeAppendAttentionKernel( + meta_data, + qkv_out, + key_cache, + value_cache, + attn_mask, + cache_k_dequant_scales, + cache_v_dequant_scales, + cache_k_zp, + cache_v_zp, + out_linear_shifts, + out_linear_smooths, + seq_lens_this_time, + seq_lens_decoder, + seq_lens_encoder, + padding_offsets, + cum_offsets, + block_tables, + decoder_batch_ids, + decoder_tile_ids_per_batch, + cache_quant_type_str, + decoder_num_blocks_data, + decoder_block_shape_q, + max_input_length, + max_len_kv_data, + quant_max_bound, + quant_min_bound, + out_linear_in_scale, + max_partition_size, + encoder_max_partition_size, + speculate_max_draft_token_num, + causal, + !speculate_decoder, + !speculate_decoder, + exec_stream, + &fmha_out); + break; + } + case paddle::DataType::FLOAT8_E4M3FN:{ + CascadeAppendAttentionKernel( + meta_data, + qkv_out, + key_cache, + value_cache, + attn_mask, + cache_k_dequant_scales, + cache_v_dequant_scales, + cache_k_zp, + cache_v_zp, + out_linear_shifts, + out_linear_smooths, + seq_lens_this_time, + seq_lens_decoder, + seq_lens_encoder, + padding_offsets, + cum_offsets, + block_tables, + decoder_batch_ids, + decoder_tile_ids_per_batch, + cache_quant_type_str, + decoder_num_blocks_data, + decoder_block_shape_q, + max_input_length, + max_len_kv_data, + quant_max_bound, + quant_min_bound, + out_linear_in_scale, + max_partition_size, + encoder_max_partition_size, + speculate_max_draft_token_num, + causal, + !speculate_decoder, + !speculate_decoder, + exec_stream, + &fmha_out); + break; + } + default:{ + PD_THROW("Only supported output fmha_out of quant dtype in ['int8', 'fp8_e4m3']."); + break; + } + } + } else { CascadeAppendAttentionKernel( meta_data, @@ -406,6 +521,8 @@ std::vector AppendAttentionKernel( decoder_block_shape_q, max_input_length, max_len_kv_data, + quant_max_bound, + quant_min_bound, out_linear_in_scale, max_partition_size, encoder_max_partition_size, @@ -463,6 +580,8 @@ std::vector AppendAttention( const std::string& cache_quant_type_str, const bool use_neox_rotary_style, const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, const float out_linear_in_scale, const int encoder_block_shape_q, const int decoder_block_shape_q, @@ -526,6 +645,8 @@ std::vector AppendAttention( cache_quant_type_str, use_neox_rotary_style, max_input_length, + quant_max_bound, + quant_min_bound, out_linear_in_scale, encoder_block_shape_q, decoder_block_shape_q, @@ -574,6 +695,8 @@ std::vector 
AppendAttention( cache_quant_type_str, use_neox_rotary_style, max_input_length, + quant_max_bound, + quant_min_bound, out_linear_in_scale, encoder_block_shape_q, decoder_block_shape_q, @@ -623,6 +746,8 @@ std::vector AppendAttention( cache_quant_type_str, use_neox_rotary_style, max_input_length, + quant_max_bound, + quant_min_bound, out_linear_in_scale, encoder_block_shape_q, decoder_block_shape_q, @@ -670,6 +795,8 @@ std::vector AppendAttention( cache_quant_type_str, use_neox_rotary_style, max_input_length, + quant_max_bound, + quant_min_bound, out_linear_in_scale, encoder_block_shape_q, decoder_block_shape_q, @@ -773,6 +900,8 @@ std::vector AppendAttentionInferDtype( const std::string& cache_quant_type_str, const bool use_neox_rotary_style, const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, const float out_linear_in_scale, const int encoder_block_shape_q, const int decoder_block_shape_q, @@ -783,13 +912,25 @@ std::vector AppendAttentionInferDtype( const bool speculate_decoder) { if (compute_dtype == "bf16") { if (out_linear_in_scale > 0.0) { - return {paddle::DataType::INT8, paddle::DataType::BFLOAT16}; + if (fabs(quant_max_bound - 127.0f) < 0.000001) { + return {paddle::DataType::INT8, paddle::DataType::BFLOAT16}; + } else if (fabs(quant_max_bound - 448.0f) < 0.000001) { + return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::BFLOAT16}; + }else{ + PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0']."); + } } else { return {paddle::DataType::BFLOAT16, paddle::DataType::BFLOAT16}; } } else if (compute_dtype == "fp16") { if (out_linear_in_scale > 0.0) { - return {paddle::DataType::INT8, paddle::DataType::FLOAT16}; + if (fabs(quant_max_bound - 127.0f) < 0.000001) { + return {paddle::DataType::INT8, paddle::DataType::FLOAT16}; + } else if (fabs(quant_max_bound - 448.0f) < 0.000001) { + return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::FLOAT16}; + }else{ + PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0']."); + } } else { return {paddle::DataType::FLOAT16, paddle::DataType::FLOAT16}; } @@ -839,6 +980,8 @@ PD_BUILD_OP(append_attention) "cache_quant_type: std::string", "use_neox_rotary_style: bool", "max_input_length: int", + "quant_max_bound: float", + "quant_min_bound: float", "out_linear_in_scale: float", "encoder_block_shape_q: int", "decoder_block_shape_q: int", diff --git a/csrc/gpu/append_attn/append_attention_c16_impl.cuh b/csrc/gpu/append_attn/append_attention_c16_impl.cuh index 51fa56934dc7..ed181836d73c 100644 --- a/csrc/gpu/append_attn/append_attention_c16_impl.cuh +++ b/csrc/gpu/append_attn/append_attention_c16_impl.cuh @@ -47,6 +47,8 @@ __global__ void multi_query_append_attention_kernel( const int max_dec_len, const int max_block_num_per_seq, const float scale, + const float quant_max_bound, + const float quant_min_bound, const float in_scale, const uint32_t chunk_size, T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, @@ -321,6 +323,8 @@ __global__ void multi_query_append_attention_kernel( smooth_weight, q_base_seq_id_this_block, q_head_idx, + quant_max_bound, + quant_min_bound, in_scale, q_len, partition_kv ? q_n_stride * num_chunks : q_n_stride, @@ -337,6 +341,8 @@ __global__ void multi_query_append_attention_kernel( smooth_weight, q_base_seq_id_this_block, q_head_idx, + quant_max_bound, + quant_min_bound, in_scale, q_len, partition_kv ? 
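The op registration gains the two float attributes, and every call path in `AppendAttention` forwards them just ahead of `out_linear_in_scale`. The Python layers are expected to supply matching values; a hedged sketch of the pairs involved (the `quant_type` strings below are illustrative placeholders, not PaddleNLP's actual config keys):

```python
def quant_bounds(quant_type: str):
    """Illustrative mapping of the (quant_max_bound, quant_min_bound) pair the
    new attributes carry; the C++ op accepts only 127.0 or 448.0 as the max."""
    if quant_type == "int8":   # a8w8-style activation quantization
        return 127.0, -127.0
    if quant_type == "fp8":    # fp8 e4m3 activation quantization
        return 448.0, -448.0
    return 0.0, 0.0            # out_linear_in_scale <= 0: no output quantization
```

This matches the unquantized call sites later in `fused_transformer_layers.py`, which pass `0.0, 0.0` for the bounds alongside a zero `out_linear_in_scale`.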
q_n_stride * num_chunks : q_n_stride, @@ -405,6 +411,8 @@ __global__ void multi_query_append_attention_warp1_4_kernel( const int max_dec_len, const int max_block_num_per_seq, const float scale, + const float quant_max_bound, + const float quant_min_bound, const float in_scale, const uint32_t chunk_size, T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, @@ -688,6 +696,8 @@ __global__ void multi_query_append_attention_warp1_4_kernel( smooth_weight, q_base_seq_id_this_block, q_head_idx, + quant_max_bound, + quant_min_bound, in_scale, q_len, q_n_stride, @@ -704,6 +714,8 @@ __global__ void multi_query_append_attention_warp1_4_kernel( smooth_weight, q_base_seq_id_this_block, q_head_idx, + quant_max_bound, + quant_min_bound, in_scale, q_len, q_n_stride * num_chunks, @@ -771,6 +783,8 @@ void MultiQueryAppendAttention( const int num_blocks_x_cpu, const int max_seq_len, const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, const float in_scale, const int max_partition_size, const int encoder_max_partition_size, @@ -874,6 +888,8 @@ void MultiQueryAppendAttention( max_dec_len, max_block_num_per_seq, scale, + quant_max_bound, + quant_min_bound, in_scale, chunk_size, nullptr, @@ -929,6 +945,8 @@ void MultiQueryAppendAttention( max_dec_len, max_block_num_per_seq, scale, + quant_max_bound, + quant_min_bound, in_scale, chunk_size, reinterpret_cast(tmp_workspace->ptr()), @@ -964,6 +982,8 @@ void MultiQueryAppendAttention( smooth_weight.get().data())) : nullptr, reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, in_scale, max_seq_len, num_chunks, @@ -997,6 +1017,8 @@ void MultiQueryAppendAttention( smooth_weight.get().data())) : nullptr, reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, in_scale, max_seq_len, num_chunks, @@ -1087,6 +1109,8 @@ void MultiQueryAppendAttention( max_dec_len, max_block_num_per_seq, scale, + quant_max_bound, + quant_min_bound, in_scale, chunk_size, nullptr, @@ -1153,6 +1177,8 @@ void MultiQueryAppendAttention( max_dec_len, max_block_num_per_seq, scale, + quant_max_bound, + quant_min_bound, in_scale, chunk_size, reinterpret_cast(tmp_workspace->ptr()), @@ -1189,6 +1215,8 @@ void MultiQueryAppendAttention( smooth_weight.get().data())) : nullptr, reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, in_scale, max_seq_len, num_chunks, @@ -1222,6 +1250,8 @@ void MultiQueryAppendAttention( smooth_weight.get().data())) : nullptr, reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, in_scale, max_seq_len, num_chunks, @@ -1268,6 +1298,8 @@ void CascadeAppendAttentionC16Kernel( const int block_shape_q, const int max_seq_len, const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, const float in_scale, const int max_partition_size, const int encoder_max_partition_size, @@ -1328,6 +1360,8 @@ void CascadeAppendAttentionC16Kernel( num_blocks, max_seq_len, max_dec_len, + quant_max_bound, + quant_min_bound, in_scale, max_partition_size, encoder_max_partition_size, diff --git a/csrc/gpu/append_attn/append_attention_c4_impl.cuh b/csrc/gpu/append_attn/append_attention_c4_impl.cuh index 43e063e90cff..586bde4dc741 100644 --- a/csrc/gpu/append_attn/append_attention_c4_impl.cuh +++ b/csrc/gpu/append_attn/append_attention_c4_impl.cuh @@ -52,6 +52,8 @@ __global__ void multi_query_append_attention_c4_kernel( const int max_dec_len, const int max_block_num_per_seq, const float scale, + const float quant_max_bound, + const float quant_min_bound, const float in_scale, 
const uint32_t chunk_size, T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, @@ -416,6 +418,8 @@ __global__ void multi_query_append_attention_c4_kernel( smooth_weight, q_base_seq_id_this_block, q_head_idx, + quant_max_bound, + quant_min_bound, in_scale, q_len, partition_kv ? q_n_stride * num_chunks : q_n_stride, @@ -432,6 +436,8 @@ __global__ void multi_query_append_attention_c4_kernel( smooth_weight, q_base_seq_id_this_block, q_head_idx, + quant_max_bound, + quant_min_bound, in_scale, q_len, partition_kv ? q_n_stride * num_chunks : q_n_stride, @@ -504,6 +510,8 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel( const int max_dec_len, const int max_block_num_per_seq, const float scale, + const float quant_max_bound, + const float quant_min_bound, const float in_scale, const uint32_t chunk_size, T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, @@ -872,6 +880,8 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel( smooth_weight, q_base_seq_id_this_block, q_head_idx, + quant_max_bound, + quant_min_bound, in_scale, q_len, q_n_stride, @@ -888,6 +898,8 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel( smooth_weight, q_base_seq_id_this_block, q_head_idx, + quant_max_bound, + quant_min_bound, in_scale, q_len, q_n_stride * num_chunks, @@ -958,6 +970,8 @@ void MultiQueryAppendC4Attention( const int num_blocks_x_cpu, const int max_seq_len, const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, const float in_scale, const int max_partition_size, const int encoder_max_partition_size, @@ -1080,6 +1094,8 @@ void MultiQueryAppendC4Attention( max_dec_len, max_block_num_per_seq, scale, + quant_max_bound, + quant_min_bound, in_scale, chunk_size, nullptr, @@ -1141,6 +1157,8 @@ void MultiQueryAppendC4Attention( max_dec_len, max_block_num_per_seq, scale, + quant_max_bound, + quant_min_bound, in_scale, chunk_size, reinterpret_cast(tmp_workspace->ptr()), @@ -1176,6 +1194,8 @@ void MultiQueryAppendC4Attention( smooth_weight.get().data())) : nullptr, reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, in_scale, max_seq_len, num_chunks, @@ -1209,6 +1229,8 @@ void MultiQueryAppendC4Attention( smooth_weight.get().data())) : nullptr, reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, in_scale, max_seq_len, num_chunks, @@ -1317,6 +1339,8 @@ void MultiQueryAppendC4Attention( max_dec_len, max_block_num_per_seq, scale, + quant_max_bound, + quant_min_bound, in_scale, chunk_size, nullptr, @@ -1391,6 +1415,8 @@ void MultiQueryAppendC4Attention( max_dec_len, max_block_num_per_seq, scale, + quant_max_bound, + quant_min_bound, in_scale, chunk_size, reinterpret_cast(tmp_workspace->ptr()), @@ -1426,6 +1452,8 @@ void MultiQueryAppendC4Attention( smooth_weight.get().data())) : nullptr, reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, in_scale, max_seq_len, num_chunks, @@ -1459,6 +1487,8 @@ void MultiQueryAppendC4Attention( smooth_weight.get().data())) : nullptr, reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, in_scale, max_seq_len, num_chunks, @@ -1505,6 +1535,8 @@ void CascadeAppendAttentionC4Kernel( const int block_shape_q, const int max_seq_len, const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, const float in_scale, const int max_partition_size, const int encoder_max_partition_size, @@ -1569,6 +1601,8 @@ void CascadeAppendAttentionC4Kernel( num_blocks, max_seq_len, max_dec_len, + quant_max_bound, + 
quant_min_bound, in_scale, max_partition_size, encoder_max_partition_size, diff --git a/csrc/gpu/append_attn/append_attention_c8_impl.cuh b/csrc/gpu/append_attn/append_attention_c8_impl.cuh index bf4403dc041b..d5d1cc38e1b4 100644 --- a/csrc/gpu/append_attn/append_attention_c8_impl.cuh +++ b/csrc/gpu/append_attn/append_attention_c8_impl.cuh @@ -50,6 +50,8 @@ __global__ void multi_query_append_attention_c8_kernel( const int max_dec_len, const int max_block_num_per_seq, const float scale, + const float quant_max_bound, + const float quant_min_bound, const float in_scale, const uint32_t chunk_size, T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, @@ -359,6 +361,8 @@ __global__ void multi_query_append_attention_c8_kernel( smooth_weight, q_base_seq_id_this_block, q_head_idx, + quant_max_bound, + quant_min_bound, in_scale, q_len, partition_kv ? q_n_stride * num_chunks : q_n_stride, @@ -375,6 +379,8 @@ __global__ void multi_query_append_attention_c8_kernel( smooth_weight, q_base_seq_id_this_block, q_head_idx, + quant_max_bound, + quant_min_bound, in_scale, q_len, partition_kv ? q_n_stride * num_chunks : q_n_stride, @@ -446,6 +452,8 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel( const int max_dec_len, const int max_block_num_per_seq, const float scale, + const float quant_max_bound, + const float quant_min_bound, const float in_scale, const uint32_t chunk_size, T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, @@ -760,6 +768,8 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel( smooth_weight, q_base_seq_id_this_block, q_head_idx, + quant_max_bound, + quant_min_bound, in_scale, q_len, q_n_stride, @@ -776,6 +786,8 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel( smooth_weight, q_base_seq_id_this_block, q_head_idx, + quant_max_bound, + quant_min_bound, in_scale, q_len, q_n_stride * num_chunks, @@ -845,6 +857,8 @@ void MultiQueryAppendC8Attention( const int num_blocks_x_cpu, const int max_seq_len, const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, const float in_scale, const int max_partition_size, const int encoder_max_partition_size, @@ -952,6 +966,8 @@ void MultiQueryAppendC8Attention( max_dec_len, max_block_num_per_seq, scale, + quant_max_bound, + quant_min_bound, in_scale, chunk_size, nullptr, @@ -1007,6 +1023,8 @@ void MultiQueryAppendC8Attention( max_dec_len, max_block_num_per_seq, scale, + quant_max_bound, + quant_min_bound, in_scale, chunk_size, reinterpret_cast(tmp_workspace->ptr()), @@ -1042,6 +1060,8 @@ void MultiQueryAppendC8Attention( smooth_weight.get().data())) : nullptr, reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, in_scale, max_seq_len, num_chunks, @@ -1075,6 +1095,8 @@ void MultiQueryAppendC8Attention( smooth_weight.get().data())) : nullptr, reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, in_scale, max_seq_len, num_chunks, @@ -1167,6 +1189,8 @@ void MultiQueryAppendC8Attention( max_dec_len, max_block_num_per_seq, scale, + quant_max_bound, + quant_min_bound, in_scale, chunk_size, nullptr, @@ -1235,6 +1259,8 @@ void MultiQueryAppendC8Attention( max_dec_len, max_block_num_per_seq, scale, + quant_max_bound, + quant_min_bound, in_scale, chunk_size, reinterpret_cast(tmp_workspace->ptr()), @@ -1265,6 +1291,8 @@ void MultiQueryAppendC8Attention( smooth_weight.get().data())) : nullptr, reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, in_scale, max_seq_len, num_chunks, @@ -1298,6 +1326,8 @@ void 
MultiQueryAppendC8Attention( smooth_weight.get().data())) : nullptr, reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, in_scale, max_seq_len, num_chunks, @@ -1344,6 +1374,8 @@ void CascadeAppendAttentionC8Kernel( const int block_shape_q, const int max_seq_len, const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, const float in_scale, const int max_partition_size, const int encoder_max_partition_size, @@ -1406,6 +1438,8 @@ void CascadeAppendAttentionC8Kernel( num_blocks, max_seq_len, max_dec_len, + quant_max_bound, + quant_min_bound, in_scale, max_partition_size, encoder_max_partition_size, diff --git a/csrc/gpu/append_attn/append_attention_func.cuh b/csrc/gpu/append_attn/append_attention_func.cuh index c9f7a1a0c097..0cb6c14f2b7d 100644 --- a/csrc/gpu/append_attn/append_attention_func.cuh +++ b/csrc/gpu/append_attn/append_attention_func.cuh @@ -1433,6 +1433,8 @@ struct StoreFunc { const AlignedVector& shift_bias_vec, const AlignedVector& smooth_weight_vec, AlignedVector& out_vec, + const float quant_max_bound, + const float quant_min_bound, const float in_scale, const int i) { out_vec[i] = static_cast(ori_out_vec[i]); @@ -1447,20 +1449,41 @@ struct StoreFunc { const AlignedVector& shift_bias_vec, const AlignedVector& smooth_weight_vec, AlignedVector& out_vec, + const float quant_max_bound, + const float quant_min_bound, const float in_scale, const int i) { float quant_value = - 127.0f * + quant_max_bound * static_cast((ori_out_vec[i] + shift_bias_vec[i]) * smooth_weight_vec[i]) * in_scale; quant_value = rintf(quant_value); - quant_value = quant_value > 127.0f ? 127.0f : quant_value; - quant_value = quant_value < -127.0f ? -127.0f : quant_value; + quant_value = quant_value > quant_max_bound ? quant_max_bound : quant_value; + quant_value = quant_value < quant_min_bound ? quant_min_bound : quant_value; out_vec[i] = static_cast(quant_value); } }; +template +struct StoreFunc { + __device__ __forceinline__ void operator()( + const AlignedVector& ori_out_vec, + const AlignedVector& shift_bias_vec, + const AlignedVector& smooth_weight_vec, + AlignedVector<__nv_fp8_e4m3, VEC_SIZE>& out_vec, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int i) { + float quant_value = + quant_max_bound * static_cast(ori_out_vec[i]) * in_scale; + quant_value = quant_value > quant_max_bound ? quant_max_bound : quant_value; + quant_value = quant_value < quant_min_bound ? 
quant_min_bound : quant_value; + out_vec[i] = static_cast<__nv_fp8_e4m3>(quant_value); + } +}; + template struct StoreFunc { __device__ __forceinline__ void operator()( @@ -1468,6 +1491,8 @@ struct StoreFunc { const AlignedVector& shift_bias_vec, const AlignedVector& smooth_weight_vec, AlignedVector& out_vec, + const float quant_max_bound, + const float quant_min_bound, const float in_scale, const int i) { out_vec[i] = ori_out_vec[i]; @@ -1488,6 +1513,8 @@ __device__ __forceinline__ void write_o_reg_gmem_multi_warps_shift_smooth_quant( const T* smooth_weight, uint32_t o_idx_base, const uint32_t q_head_idx_base, + const float quant_max_bound, + const float quant_min_bound, const float in_scale, const uint32_t qo_upper_bound, const uint32_t qo_n_stride, @@ -1565,6 +1592,8 @@ __device__ __forceinline__ void write_o_reg_gmem_multi_warps_shift_smooth_quant( shift_bias_vec, smooth_weight_vec, out_vec, + quant_max_bound, + quant_min_bound, in_scale, i); } @@ -1598,6 +1627,8 @@ __device__ __forceinline__ void write_o_reg_gmem_shift_smooth_quant( const T* smooth_weight, uint32_t o_idx_base, const uint32_t q_head_idx_base, + const float quant_max_bound, + const float quant_min_bound, const float in_scale, const uint32_t qo_upper_bound, const uint32_t qo_n_stride, @@ -1668,6 +1699,8 @@ __device__ __forceinline__ void write_o_reg_gmem_shift_smooth_quant( shift_bias_vec, smooth_weight_vec, out_vec, + quant_max_bound, + quant_min_bound, in_scale, i); } @@ -1791,6 +1824,8 @@ __global__ void merge_multi_chunks_kernel( const T* __restrict__ shift_bias, // [q_num_heads * HEAD_DIM] const T* __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] T* __restrict__ out, + const float quant_max_bound, + const float quant_min_bound, const float in_scale, const int max_seq_len, const int num_chunks, @@ -2048,6 +2083,8 @@ __global__ void merge_multi_chunks_decoder_kernel( const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] OutT *__restrict__ out, + const float quant_max_bound, + const float quant_min_bound, const float in_scale, const int max_seq_len, const int num_chunks, @@ -2149,7 +2186,7 @@ __global__ void merge_multi_chunks_decoder_kernel( #pragma unroll for (int i = 0; i < vec_size; ++i) { StoreFunc()( - st.o, shift_bias_vec, smooth_weight_vec, out_vec, in_scale, i); + st.o, shift_bias_vec, smooth_weight_vec, out_vec, quant_max_bound, quant_min_bound, in_scale, i); } Store( out_vec, @@ -2175,6 +2212,8 @@ __global__ void merge_multi_chunks_v2_kernel( const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] OutT *__restrict__ out, + const float quant_max_bound, + const float quant_min_bound, const float in_scale, const int max_seq_len, const int num_chunks, @@ -2299,7 +2338,7 @@ __global__ void merge_multi_chunks_v2_kernel( #pragma unroll for (int i = 0; i < vec_size; ++i) { StoreFunc()( - st.o, shift_bias_vec, smooth_weight_vec, out_vec, in_scale, i); + st.o, shift_bias_vec, smooth_weight_vec, out_vec, quant_max_bound, quant_min_bound, in_scale, i); } Store( out_vec, &out[(qid * num_heads + hid) * head_dim + vid * vec_size]); diff --git a/csrc/gpu/append_attn/append_attention_kernel.h b/csrc/gpu/append_attn/append_attention_kernel.h index e8738af74296..59532b2400c5 100644 --- a/csrc/gpu/append_attn/append_attention_kernel.h +++ b/csrc/gpu/append_attn/append_attention_kernel.h @@ -49,6 +49,8 @@ void CascadeAppendAttentionC16Kernel( const int block_shape_q, const int 
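The store functors above are the numerical core of the change: the existing `int8` specialization applies shift/smooth, scales by `quant_max_bound * in_scale`, rounds, and clamps, while the new `__nv_fp8_e4m3` specialization scales and clamps only (no shift/smooth, no explicit rounding) before the device-side cast to FP8. A host-side Python sketch of the same arithmetic, useful for sanity-checking scales offline (function names are ours):

```python
def quantize_store_int8(x, shift, smooth, in_scale,
                        max_bound=127.0, min_bound=-127.0):
    # Mirrors the int8 StoreFunc: shift/smooth, scale, round, clamp.
    q = max_bound * (x + shift) * smooth * in_scale
    q = round(q)
    return int(min(max(q, min_bound), max_bound))


def quantize_store_fp8(x, in_scale, max_bound=448.0, min_bound=-448.0):
    # Mirrors the fp8 e4m3 StoreFunc: scale and clamp only; the final cast to
    # the 8-bit e4m3 encoding happens on the device.
    q = max_bound * x * in_scale
    return min(max(q, min_bound), max_bound)
```

Both `merge_multi_chunks` kernels now forward the bounds into the same functor, which is why their `StoreFunc` call sites gain the two extra arguments.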
max_seq_len, const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, const float in_scale, const int max_partition_size, const int encoder_max_partition_size, @@ -92,6 +94,8 @@ void CascadeAppendAttentionC8Kernel( const int block_shape_q, const int max_seq_len, const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, const float in_scale, const int max_partition_size, const int encoder_max_partition_size, @@ -135,6 +139,8 @@ void CascadeAppendAttentionC4Kernel( const int block_shape_q, const int max_seq_len, const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, const float in_scale, const int max_partition_size, const int encoder_max_partition_size, @@ -179,6 +185,8 @@ void CascadeAppendAttentionKernel( const int block_shape_q, const int max_seq_len, const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, const float in_scale, const int max_partition_size, const int encoder_max_partition_size, @@ -212,6 +220,8 @@ void CascadeAppendAttentionKernel( block_shape_q, max_seq_len, max_dec_len, + quant_max_bound, + quant_min_bound, in_scale, max_partition_size, encoder_max_partition_size, @@ -245,6 +255,8 @@ void CascadeAppendAttentionKernel( block_shape_q, max_seq_len, max_dec_len, + quant_max_bound, + quant_min_bound, in_scale, max_partition_size, encoder_max_partition_size, @@ -278,6 +290,8 @@ void CascadeAppendAttentionKernel( block_shape_q, max_seq_len, max_dec_len, + quant_max_bound, + quant_min_bound, in_scale, max_partition_size, encoder_max_partition_size, diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c16_bfloat16_bfloat16_kernel.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c16_bfloat16_bfloat16_kernel.cu index e2930b517896..7dafef74ba88 100644 --- a/csrc/gpu/append_attn/template_instantiation/append_attention_c16_bfloat16_bfloat16_kernel.cu +++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c16_bfloat16_bfloat16_kernel.cu @@ -46,6 +46,8 @@ template void CascadeAppendAttentionC16Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] + const paddle::Tensor& + cache_k, // [max_block_num, num_heads, block_size, head_dim] + const paddle::Tensor& + cache_v, // [max_block_num, num_heads, head_dim, block_size] + const paddle::optional& attn_mask, + const paddle::optional& + cache_k_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_k_zp, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_zp, // [num_kv_heads, head_dim] + const paddle::optional& + shift_bias, // [num_kv_heads, head_dim] + const paddle::optional& + smooth_weight, // [num_kv_heads, head_dim] + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& padding_offsets, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + cudaStream_t& stream, + 
paddle::Tensor* out); \ No newline at end of file diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c16_bfloat16_int8_kernel.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c16_bfloat16_int8_kernel.cu index 425a5f14478d..deadb933db0f 100644 --- a/csrc/gpu/append_attn/template_instantiation/append_attention_c16_bfloat16_int8_kernel.cu +++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c16_bfloat16_int8_kernel.cu @@ -45,6 +45,8 @@ template void CascadeAppendAttentionC16Kernel( const int block_shape_q, const int max_seq_len, const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, const float in_scale, const int max_partition_size, const int encoder_max_partition_size, diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c16_float16_float16_kernel.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c16_float16_float16_kernel.cu index d9420b287ff0..806eecbb529d 100644 --- a/csrc/gpu/append_attn/template_instantiation/append_attention_c16_float16_float16_kernel.cu +++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c16_float16_float16_kernel.cu @@ -45,6 +45,8 @@ template void CascadeAppendAttentionC16Kernel( const int block_shape_q, const int max_seq_len, const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, const float in_scale, const int max_partition_size, const int encoder_max_partition_size, diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c16_float16_fp8_kernel.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c16_float16_fp8_kernel.cu new file mode 100644 index 000000000000..c677686d68aa --- /dev/null +++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c16_float16_fp8_kernel.cu @@ -0,0 +1,58 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include "../append_attention_c16_impl.cuh" + +template void CascadeAppendAttentionC16Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] + const paddle::Tensor& + cache_k, // [max_block_num, num_heads, block_size, head_dim] + const paddle::Tensor& + cache_v, // [max_block_num, num_heads, head_dim, block_size] + const paddle::optional& attn_mask, + const paddle::optional& + cache_k_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_k_zp, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_zp, // [num_kv_heads, head_dim] + const paddle::optional& + shift_bias, // [num_kv_heads, head_dim] + const paddle::optional& + smooth_weight, // [num_kv_heads, head_dim] + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& padding_offsets, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + cudaStream_t& stream, + paddle::Tensor* out); \ No newline at end of file diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c16_float16_int8_kernel.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c16_float16_int8_kernel.cu index 32beb39f5951..9d20f15f8c29 100644 --- a/csrc/gpu/append_attn/template_instantiation/append_attention_c16_float16_int8_kernel.cu +++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c16_float16_int8_kernel.cu @@ -45,6 +45,8 @@ template void CascadeAppendAttentionC16Kernel( const int block_shape_q, const int max_seq_len, const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, const float in_scale, const int max_partition_size, const int encoder_max_partition_size, diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c4_bfloat16_bfloat16_kernel.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c4_bfloat16_bfloat16_kernel.cu index 794c2338dca8..75c6e80c3056 100644 --- a/csrc/gpu/append_attn/template_instantiation/append_attention_c4_bfloat16_bfloat16_kernel.cu +++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c4_bfloat16_bfloat16_kernel.cu @@ -45,6 +45,8 @@ template void CascadeAppendAttentionC4Kernel const int block_shape_q, const int max_seq_len, const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, const float in_scale, const int max_partition_size, const int encoder_max_partition_size, diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c4_bfloat16_fp8_kernel.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c4_bfloat16_fp8_kernel.cu new file mode 100644 index 000000000000..065834d6d0d8 --- /dev/null +++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c4_bfloat16_fp8_kernel.cu @@ -0,0 +1,58 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "../append_attention_c4_impl.cuh" + +template void CascadeAppendAttentionC4Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] + const paddle::Tensor& + cache_k, // [max_block_num, num_heads, block_size, head_dim] + const paddle::Tensor& + cache_v, // [max_block_num, num_heads, head_dim, block_size] + const paddle::optional& attn_mask, + const paddle::optional& + cache_k_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_k_zp, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_zp, // [num_kv_heads, head_dim] + const paddle::optional& + shift_bias, // [num_kv_heads, head_dim] + const paddle::optional& + smooth_weight, // [num_kv_heads, head_dim] + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& padding_offsets, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + cudaStream_t& stream, + paddle::Tensor* out); \ No newline at end of file diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c4_bfloat16_int8_kernel.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c4_bfloat16_int8_kernel.cu index ca905d9bc52a..8c699265d116 100644 --- a/csrc/gpu/append_attn/template_instantiation/append_attention_c4_bfloat16_int8_kernel.cu +++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c4_bfloat16_int8_kernel.cu @@ -45,6 +45,8 @@ template void CascadeAppendAttentionC4Kernel( const int block_shape_q, const int max_seq_len, const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, const float in_scale, const int max_partition_size, const int encoder_max_partition_size, diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c4_float16_float16_kernel.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c4_float16_float16_kernel.cu index b87c37006409..3a2b13a89045 100644 --- a/csrc/gpu/append_attn/template_instantiation/append_attention_c4_float16_float16_kernel.cu +++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c4_float16_float16_kernel.cu @@ -46,6 +46,8 @@ template void CascadeAppendAttentionC4Kernel( const int block_shape_q, const int max_seq_len, const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, const float in_scale, const int max_partition_size, const int 
encoder_max_partition_size, diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c4_float16_fp8_kernel.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c4_float16_fp8_kernel.cu new file mode 100644 index 000000000000..4f5dedb15dc5 --- /dev/null +++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c4_float16_fp8_kernel.cu @@ -0,0 +1,58 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "../append_attention_c4_impl.cuh" + +template void CascadeAppendAttentionC4Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] + const paddle::Tensor& + cache_k, // [max_block_num, num_heads, block_size, head_dim] + const paddle::Tensor& + cache_v, // [max_block_num, num_heads, head_dim, block_size] + const paddle::optional& attn_mask, + const paddle::optional& + cache_k_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_k_zp, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_zp, // [num_kv_heads, head_dim] + const paddle::optional& + shift_bias, // [num_kv_heads, head_dim] + const paddle::optional& + smooth_weight, // [num_kv_heads, head_dim] + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& padding_offsets, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + cudaStream_t& stream, + paddle::Tensor* out); \ No newline at end of file diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c4_float16_int8_kernel.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c4_float16_int8_kernel.cu index 3a53f7189275..a747080ef49c 100644 --- a/csrc/gpu/append_attn/template_instantiation/append_attention_c4_float16_int8_kernel.cu +++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c4_float16_int8_kernel.cu @@ -45,6 +45,8 @@ template void CascadeAppendAttentionC4Kernel( const int block_shape_q, const int max_seq_len, const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, const float in_scale, const int max_partition_size, const int encoder_max_partition_size, diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c8_bfloat16_bfloat16_kernel.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c8_bfloat16_bfloat16_kernel.cu index f1d9ba006667..606c9128a973 100644 
--- a/csrc/gpu/append_attn/template_instantiation/append_attention_c8_bfloat16_bfloat16_kernel.cu +++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c8_bfloat16_bfloat16_kernel.cu @@ -47,6 +47,8 @@ CascadeAppendAttentionC8Kernel( const int block_shape_q, const int max_seq_len, const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, const float in_scale, const int max_partition_size, const int encoder_max_partition_size, diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c8_bfloat16_fp8_kernel.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c8_bfloat16_fp8_kernel.cu new file mode 100644 index 000000000000..efc54738fafc --- /dev/null +++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c8_bfloat16_fp8_kernel.cu @@ -0,0 +1,58 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "../append_attention_c8_impl.cuh" + +template void CascadeAppendAttentionC8Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] + const paddle::Tensor& + cache_k, // [max_block_num, num_heads, block_size, head_dim] + const paddle::Tensor& + cache_v, // [max_block_num, num_heads, head_dim, block_size] + const paddle::optional& attn_mask, + const paddle::optional& + cache_k_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_k_zp, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_zp, // [num_kv_heads, head_dim] + const paddle::optional& + shift_bias, // [num_kv_heads, head_dim] + const paddle::optional& + smooth_weight, // [num_kv_heads, head_dim] + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& padding_offsets, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + cudaStream_t& stream, + paddle::Tensor* out); \ No newline at end of file diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c8_bfloat16_int8_kernel.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c8_bfloat16_int8_kernel.cu index f680c1c64510..baff43d49424 100644 --- a/csrc/gpu/append_attn/template_instantiation/append_attention_c8_bfloat16_int8_kernel.cu +++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c8_bfloat16_int8_kernel.cu @@ -45,6 +45,8 @@ template void CascadeAppendAttentionC8Kernel( const int 
block_shape_q, const int max_seq_len, const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, const float in_scale, const int max_partition_size, const int encoder_max_partition_size, diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c8_float16_float16_kernel.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c8_float16_float16_kernel.cu index 2613164b8d9f..83728df8d409 100644 --- a/csrc/gpu/append_attn/template_instantiation/append_attention_c8_float16_float16_kernel.cu +++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c8_float16_float16_kernel.cu @@ -45,6 +45,8 @@ template void CascadeAppendAttentionC8Kernel( const int block_shape_q, const int max_seq_len, const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, const float in_scale, const int max_partition_size, const int encoder_max_partition_size, diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c8_float16_fp8_kerne.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c8_float16_fp8_kerne.cu new file mode 100644 index 000000000000..35267a59f55b --- /dev/null +++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c8_float16_fp8_kerne.cu @@ -0,0 +1,58 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include "../append_attention_c8_impl.cuh" + +template void CascadeAppendAttentionC8Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] + const paddle::Tensor& + cache_k, // [max_block_num, num_heads, block_size, head_dim] + const paddle::Tensor& + cache_v, // [max_block_num, num_heads, head_dim, block_size] + const paddle::optional& attn_mask, + const paddle::optional& + cache_k_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_k_zp, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_zp, // [num_kv_heads, head_dim] + const paddle::optional& + shift_bias, // [num_kv_heads, head_dim] + const paddle::optional& + smooth_weight, // [num_kv_heads, head_dim] + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& padding_offsets, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + cudaStream_t& stream, + paddle::Tensor* out); \ No newline at end of file diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c8_float16_int8_kerne.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c8_float16_int8_kerne.cu index 5159c7944e21..9b489ff96001 100644 --- a/csrc/gpu/append_attn/template_instantiation/append_attention_c8_float16_int8_kerne.cu +++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c8_float16_int8_kerne.cu @@ -45,6 +45,8 @@ template void CascadeAppendAttentionC8Kernel( const int block_shape_q, const int max_seq_len, const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, const float in_scale, const int max_partition_size, const int encoder_max_partition_size, diff --git a/csrc/gpu/append_attn/utils.cuh b/csrc/gpu/append_attn/utils.cuh index a1ada6d97eef..d152e2c32add 100644 --- a/csrc/gpu/append_attn/utils.cuh +++ b/csrc/gpu/append_attn/utils.cuh @@ -14,9 +14,10 @@ #pragma once #include +#include #include #include "mem_util.cuh" - + struct AppendAttnMetaData { int batch_size; int block_size; @@ -50,6 +51,11 @@ struct cascade_attn_type_traits { using type = half; }; +template <> +struct cascade_attn_type_traits { + using type = __nv_fp8_e4m3; +}; + template struct cascade_attn_nv_type2_traits { using type = T; diff --git a/csrc/setup_cuda.py b/csrc/setup_cuda.py index a42072bd4232..3978f8c17de8 100644 --- a/csrc/setup_cuda.py +++ b/csrc/setup_cuda.py @@ -149,16 +149,17 @@ def get_gencode_flags(): ] cc = get_sm_version() +cuda_version = float(paddle.version.cuda()) if cc >= 80: sources += ["gpu/int8_gemm_with_cutlass/gemm_dequant.cu"] -if cc >= 89: +if cc >= 89 and cuda_version >= 12.4: + os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py") + os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py") sources += find_end_files("gpu/cutlass_kernels/fp8_gemm_fused/autogen", ".cu") sources += [ "gpu/fp8_gemm_with_cutlass/fp8_fp8_half_gemm.cu", - 
"gpu/cutlass_kernels/fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.cu", "gpu/fp8_gemm_with_cutlass/fp8_fp8_fp8_dual_gemm.cu", - "gpu/cutlass_kernels/fp8_gemm_fused/fp8_fp8_dual_gemm_scale_bias_act.cu", ] setup( diff --git a/csrc/utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py b/csrc/utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py index 9404d880c6aa..12f802980a22 100644 --- a/csrc/utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py +++ b/csrc/utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py @@ -136,7 +136,9 @@ def get_dual_gemm_candidate_configs(sm, min_split_k, max_split_k, min_stages, ma code_part0 = """// Generated by generate_code_gemm_fused_kernels.py - Do not edit. #include -#include "fp8_fp8_dual_gemm_scale_bias_act.h" +#include +#include +#include "fp8_gemm_fused/fp8_fp8_dual_gemm_scale_bias_act.h" #include "launch_dual_gemm_kernel.h" COMMON_DECLARE_string(use_cutlass_device_best_config_path); @@ -173,6 +175,35 @@ def get_dual_gemm_candidate_configs(sm, min_split_k, max_split_k, min_stages, ma return false; } +template +T get_relative_best(nlohmann::json* json_data, + const std::string& target_key, + const std::string& regex_key, + const int& m, + const T& default_value) { + if (json_data->contains(target_key)) { + return json_data->at(target_key); + } else { + std::regex pattern(regex_key); + std::string closest_key; + int closest_diff = std::numeric_limits::max(); + T closest_value = default_value; + + for (const auto& [key, value] : json_data->items()) { + std::smatch matches; + if (std::regex_search(key, matches, pattern)) { + int relative_m = std::stoi(matches[1].str()); + int diff = std::abs(relative_m - m); + if (diff < closest_diff) { + closest_diff = diff; + closest_value = value; + } + } + } + json_data->push_back({target_key, closest_value}); + return closest_value; + } +} bool fp8_fp8_dual_gemm_scale_bias_act(DualGemmEpilogueAllParams params) { if (dual_gemm_type_map.find(params.fuse_gemm_config) == dual_gemm_type_map.end()) { @@ -185,25 +216,22 @@ def get_dual_gemm_candidate_configs(sm, min_split_k, max_split_k, min_stages, ma int K = params.K; std::string mnk_string = "dual_gemm<"+ std::to_string(M)+ ", " +std::to_string(N) + ", "+ std::to_string(K)+ ">"; + std::string regex_mnk_string = "dual_gemm<(\\d+), " + std::to_string(N) + ", " + std::to_string(K) + ">"; std::string mnk_split_k_string = "dual_gemm<"+ std::to_string(M)+ ", " +std::to_string(N) + ", "+ std::to_string(K)+ ">" + ", split_k"; + std::string regex_mnk_split_k_string = "dual_gemm<(\\d+), " + std::to_string(N) + ", " + std::to_string(K) + ">, split_k"; int split_k; int kernel_id; std::string best_config; CutlassGemmConfigMannager& best_config_mannager = CutlassGemmConfigMannager::getInstance(); if(getenv("FLAGS_use_cutlass_device_best_config_path")){ // run kernel std::string config_file_path = getenv("FLAGS_use_cutlass_device_best_config_path"); - nlohmann::json* config_json = best_config_mannager.get_gemm_best_configs(config_file_path); - if (config_json->contains(mnk_string)) { - best_config = config_json->at(mnk_string); - } else { - std::cerr << "Can not find the config for this gemm shape, please tune this shape: " << mnk_string <contains(mnk_split_k_string)) { - split_k = config_json->at(mnk_split_k_string); - } else { - std::cerr << "Can not find the config(split_k) for this gemm shape, please tune this shape: " << mnk_string <(config_json, mnk_string, regex_mnk_string, M, "<64, 64, 64>, <32, 32, 64>, <16, 8, 32>, 3"); + split_k = get_relative_best(config_json, mnk_split_k_string, 
regex_mnk_split_k_string, M, 1); if (dual_gemm_config_map.find(best_config) == dual_gemm_config_map.end()) { throw std::runtime_error("This config'kernel not be generate, please check generate_code_gemm_fused_kernels.py and re-generate."); @@ -601,7 +629,7 @@ def generate_dispatch_dual_gemm_cu( max_split_k, ) # hard code for act_tag - file_name = "gpu/cutlass_kernels/fp8_gemm_fused/fp8_fp8_dual_gemm_scale_bias_act.cu" + file_name = "gpu/cutlass_kernels/fp8_gemm_fused/autogen/fp8_fp8_dual_gemm_scale_bias_act.cu" all_code = generate_dispatch_dual_gemm_cu( inputs_type, outputs_type, diff --git a/csrc/utils/auto_gen_fp8_fp8_gemm_fused_kernels.py b/csrc/utils/auto_gen_fp8_fp8_gemm_fused_kernels.py index 35df16761859..63fa839e18b8 100644 --- a/csrc/utils/auto_gen_fp8_fp8_gemm_fused_kernels.py +++ b/csrc/utils/auto_gen_fp8_fp8_gemm_fused_kernels.py @@ -153,7 +153,9 @@ def get_candidate_configs(sm, min_split_k, max_split_k, min_stages, max_stages): code_part0 = """// Generated by generate_code_gemm_fused_kernels.py - Do not edit. #include -#include "fp8_fp8_gemm_scale_bias_act.h" +#include +#include +#include "fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.h" #include "launch_gemm_kernel.h" COMMON_DECLARE_string(use_cutlass_device_best_config_path); @@ -190,6 +192,35 @@ def get_candidate_configs(sm, min_split_k, max_split_k, min_stages, max_stages): return false; } +template +T get_relative_best(nlohmann::json* json_data, + const std::string& target_key, + const std::string& regex_key, + const int& m, + const T& default_value) { + if (json_data->contains(target_key)) { + return json_data->at(target_key); + } else { + std::regex pattern(regex_key); + std::string closest_key; + int closest_diff = std::numeric_limits::max(); + T closest_value = default_value; + + for (const auto& [key, value] : json_data->items()) { + std::smatch matches; + if (std::regex_search(key, matches, pattern)) { + int relative_m = std::stoi(matches[1].str()); + int diff = std::abs(relative_m - m); + if (diff < closest_diff) { + closest_diff = diff; + closest_value = value; + } + } + } + json_data->push_back({target_key, closest_value}); + return closest_value; + } +} bool fp8_fp8_gemm_scale_bias_act(GemmEpilogueAllParams params) { if (gemm_type_map.find(params.fuse_gemm_config) == gemm_type_map.end()) { @@ -202,25 +233,22 @@ def get_candidate_configs(sm, min_split_k, max_split_k, min_stages, max_stages): int K = params.K; std::string mnk_string = "gemm<"+ std::to_string(M)+ ", " +std::to_string(N) + ", "+ std::to_string(K)+ ">"; + std::string regex_mnk_string = "gemm<(\\d+), " + std::to_string(N) + ", " + std::to_string(K) + ">"; std::string mnk_split_k_string = "gemm<"+ std::to_string(M)+ ", " +std::to_string(N) + ", "+ std::to_string(K)+ ">" + ", split_k"; + std::string regex_mnk_split_k_string = "gemm<(\\d+), " + std::to_string(N) + ", " + std::to_string(K) + ">, split_k"; int split_k; int kernel_id; std::string best_config; CutlassGemmConfigMannager& best_config_mannager = CutlassGemmConfigMannager::getInstance(); if(getenv("FLAGS_use_cutlass_device_best_config_path")){ // run kernel std::string config_file_path = getenv("FLAGS_use_cutlass_device_best_config_path"); - nlohmann::json* config_json = best_config_mannager.get_gemm_best_configs(config_file_path); - if (config_json->contains(mnk_string)) { - best_config = config_json->at(mnk_string); - } else { - std::cerr << "Can not find the config for this gemm shape, please tune this shape: " << mnk_string <contains(mnk_split_k_string)) { - split_k = 
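The get_relative_best helper added to both generated launchers replaces the old hard failure when the tuned-config JSON has no entry for the exact GEMM shape: it regex-matches the entries that share the same N and K, reuses the one whose M is closest to the requested M, caches that choice back into the JSON, and only falls back to the hard-coded default config when nothing matches. A minimal Python sketch of the same lookup, with an illustrative key format and made-up tuning data:

import re

def get_relative_best(configs, target_key, regex_key, m, default):
    # Nearest-M fallback mirroring the C++ helper emitted by the generators;
    # `configs` stands in for the tuned-config JSON loaded at runtime.
    if target_key in configs:
        return configs[target_key]
    pattern = re.compile(regex_key)
    closest_diff, closest_value = float("inf"), default
    for key, value in configs.items():
        match = pattern.search(key)
        if match:
            diff = abs(int(match.group(1)) - m)
            if diff < closest_diff:
                closest_diff, closest_value = diff, value
    configs[target_key] = closest_value  # cache the decision, as the C++ code does
    return closest_value

# No entry for M=48, so the M=32 config (closest M for this N, K) is reused.
tuned = {"gemm<32, 4096, 4096>": "<64, 64, 64>, <32, 32, 64>, <16, 8, 32>, 3"}
best = get_relative_best(tuned, "gemm<48, 4096, 4096>",
                         r"gemm<(\d+), 4096, 4096>", 48,
                         "<64, 64, 64>, <32, 32, 64>, <16, 8, 32>, 3")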
config_json->at(mnk_split_k_string); - } else { - std::cerr << "Can not find the config(split_k) for this gemm shape, please tune this shape: " << mnk_string <(config_json, mnk_string, regex_mnk_string, M, "<64, 64, 64>, <32, 32, 64>, <16, 8, 32>, 3"); + split_k = get_relative_best(config_json, mnk_split_k_string, regex_mnk_split_k_string, M, 1); if (gemm_config_map.find(best_config) == gemm_config_map.end()) { throw std::runtime_error("This config'kernel not be generate, please check generate_code_gemm_fused_kernels.py and re-generate."); @@ -591,7 +619,7 @@ def generate_dispatch_gemm_cu( # hard code for act_tag - file_name = "gpu/cutlass_kernels/fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.cu" + file_name = "gpu/cutlass_kernels/fp8_gemm_fused/autogen/fp8_fp8_gemm_scale_bias_act.cu" all_code = generate_dispatch_gemm_cu( inputs_type, outputs_type, diff --git a/paddlenlp/experimental/transformers/fused_transformer_layers.py b/paddlenlp/experimental/transformers/fused_transformer_layers.py index 7c1dadc43ce6..5f7def16c24a 100644 --- a/paddlenlp/experimental/transformers/fused_transformer_layers.py +++ b/paddlenlp/experimental/transformers/fused_transformer_layers.py @@ -238,7 +238,6 @@ def __init__( self.ffn1_1_weight_attrs = ffn1_1_weight_attrs self.ffn1_0_bias_attrs = ffn1_0_bias_attrs self.ffn1_1_bias_attrs = ffn1_1_bias_attrs - self.use_dynamic_cachekv_quant = cachekv_int8_type == "dynamic" self.ffn2_weight_attrs = ffn2_weight_attrs self.ffn2_weight_scale_attrs = ffn2_weight_scale_attrs @@ -2189,6 +2188,8 @@ def compute_attn( "none", # cache_quant_type self.use_neox_rotary_style, kwargs.get("max_input_length", -1), + 0.0, + 0.0, 0.0, # out_linear_in_scale kwargs.get("encoder_block_shape_q", 64), kwargs.get("decoder_block_shape_q", 16), @@ -2378,6 +2379,8 @@ def compute_attn( cache_quant_type_str, self.use_neox_rotary_style, kwargs.get("max_input_length", -1), + self.quant_max_bound, + self.quant_min_bound, self.act_scales["out_linear_in_scale"][i], kwargs.get("encoder_block_shape_q", 64), kwargs.get("decoder_block_shape_q", 16), @@ -2432,112 +2435,81 @@ def compute_attn( return out_linear_out -class FusedBlockMultiTransformerFP8(Layer): +class FusedBlockMultiTransformerFP8(FusedBlockMultiTransformer): def __init__(self, config: FusedMultiTransformerConfig): """""" - super().__init__() - self.config = config + super().__init__(config) self.act_scales = None self.weight_scales = None - assert config.embed_dim > 0, "Expected embed_dim to be greater than 0, " "but received {}".format( - config.embed_dim - ) - assert config.num_heads > 0, "Expected nhead to be greater than 0, " "but received {}".format(config.num_heads) - assert config.dim_feedforward > 0, "Expected dim_feedforward to be greater than 0, but received {}".format( - config.dim_feedforward - ) - - # self.normalize_before = normalize_before - self._dtype = self._helper.get_default_dtype() - self._epsilon = config.epsilon - self._residual_alpha = config.residual_alpha - self.nranks = config.nranks - self.norm_type = config.norm_type - if self.norm_type == "layernorm": - self.norm_func = fused_layer_norm - elif self.norm_type == "rmsnorm": - self.norm_func = fused_rms_norm - else: - raise NotImplementedError("Only support norm type of [layernorm, rmsnorm]") - self.use_neox_rotary_style = config.use_neox_rotary_style - self._norm_weight_dtype = "float32" if self.norm_type == "layernorm" else self._dtype - - self.activation = config.activation - - self.embed_dim = config.embed_dim - self.head_dim = config.embed_dim // config.num_heads - 
assert self.head_dim * config.num_heads == config.embed_dim, "embed_dim must be divisible by num_heads" - - # tensor model parallel - if config.nranks > 1: - assert config.ring_id != -1 - assert config.num_heads % config.nranks == 0 - assert config.dim_feedforward % config.nranks == 0 - assert config.moe_config.shared_expert_intermediate_size % config.nranks == 0 - self.num_heads = config.num_heads // config.nranks - self.kv_num_heads = config.kv_num_heads // config.nranks - self.dim_feedforward = config.dim_feedforward // config.nranks - shared_expert_intermediate_size = config.moe_config.shared_expert_intermediate_size // config.nranks - self.config.moe_config.shared_expert_intermediate_size = shared_expert_intermediate_size + self.quant_round_type = config.quant_round_type + self.quant_max_bound = config.quant_max_bound + self.quant_min_bound = config.quant_min_bound - self.num_layers = config.num_layers - assert self.num_layers > 0 - if isinstance(config.qkv_weight_attrs, (list, tuple)): - assert self.num_layers == len(config.qkv_weight_attrs) + self.ffn1_0_biases = [] + self.ffn1_1_biases = [] - self.weight_dtype = self._dtype - self.create_params_type = self.get_weight_create_dype() + self.qkv_out_scales = [] + self.linear_out_scales = [] + self.ffn1_0_out_scales = [] + self.ffn1_1_out_scales = [] + self.ffn2_out_scales = [] - self.ln_scales, self.ln_biases = [], [] - self.qkv_weights, self.qkv_biases = [], [] - self.linear_weights, self.linear_biases = [], [] - self.ffn_ln_scales, self.ffn_ln_biases = [], [] - self.ffn1_0_weights, self.ffn1_0_biases = [], [] - self.ffn1_1_weights, self.ffn1_1_biases = [], [] - self.ffn2_weights, self.ffn2_biases = [], [] - self.cache_k_scales, self.cache_v_scales = [], [] - self.cache_k_out_scales, self.cache_v_out_scales = [], [] + self.init_weight_shape(config) for i in range(self.num_layers): - ln_scale_attr = self.get_attr(config.ln_scale_attrs, i) - ln_bias_attr = self.get_attr(config.ln_bias_attrs, i) - qkv_weight_attr = self.get_attr(config.qkv_weight_attrs, i) + self.qkv_out_scales.append(-1.0) + self.linear_out_scales.append(-1.0) + self.ffn1_0_out_scales.append(-1.0) + self.ffn1_1_out_scales.append(-1.0) + self.ffn2_out_scales.append(-1.0) - qkv_bias_attr = self.get_attr(config.qkv_bias_attrs, i) - linear_weight_attr = self.get_attr(config.linear_weight_attrs, i) - linear_bias_attr = self.get_attr(config.linear_bias_attrs, i) - - ffn_ln_scale_attr = self.get_attr(config.ffn_ln_scale_attrs, i) - ffn_ln_bias_attr = self.get_attr(config.ffn_ln_bias_attrs, i) - ffn1_0_weight_attr = self.get_attr(config.ffn1_0_weight_attrs, i) - ffn1_1_weight_attr = self.get_attr(config.ffn1_1_weight_attrs, i) ffn1_0_bias_attr = self.get_attr(config.ffn1_0_bias_attrs, i) ffn1_1_bias_attr = self.get_attr(config.ffn1_1_bias_attrs, i) - ffn2_weight_attr = self.get_attr(config.ffn2_weight_attrs, i) - ffn2_bias_attr = self.get_attr(config.ffn2_bias_attrs, i) - cache_k_scale_attr = self.get_attr(config.cache_k_scale_attrs, i) - cache_v_scale_attr = self.get_attr(config.cache_v_scale_attrs, i) - cache_k_out_scale_attr = self.get_attr(config.cache_k_out_scale_attrs, i) - cache_v_out_scale_attr = self.get_attr(config.cache_v_out_scale_attrs, i) + ffn1_0_bias = None + if ffn1_0_bias_attr: + ffn1_0_bias = self.create_parameter( + shape=[self.dim_feedforward], + attr=ffn1_0_bias_attr, + dtype=self._dtype, + is_bias=True, + ) - ln_scale = self.create_parameter( - attr=ln_scale_attr, - shape=[config.embed_dim], - default_initializer=Constant(value=1.0), - 
dtype=self._norm_weight_dtype, - ) - ln_bias = None - if ln_bias_attr: - ln_bias = self.create_parameter( - attr=ln_bias_attr, - shape=[config.embed_dim], + ffn1_1_bias = None + if ffn1_1_bias_attr: + ffn1_1_bias = self.create_parameter( + shape=[self.dim_feedforward], + attr=ffn1_1_bias_attr, + dtype=self._dtype, is_bias=True, - dtype=self._norm_weight_dtype, ) - self.init_weight_shape(config) + # tensor model parallel + if config.nranks > 1: + # column parallel + _set_var_distributed(ffn1_0_bias) + _set_var_distributed(ffn1_1_bias) + + self.ffn1_0_biases.append(ffn1_0_bias) + self.ffn1_1_biases.append(ffn1_1_bias) + + self._add_parameter(ffn1_0_bias) + self._add_parameter(ffn1_1_bias) + + def init_weight(self): + self.qkv_weights = [] + self.linear_weights = [] + self.ffn1_0_weights = [] + self.ffn1_1_weights = [] + self.ffn2_weights = [] + + for i in range(self.num_layers): + qkv_weight_attr = self.get_attr(self.config.qkv_weight_attrs, i) + linear_weight_attr = self.get_attr(self.config.linear_weight_attrs, i) + ffn1_0_weight_attr = self.get_attr(self.config.ffn1_0_weight_attrs, i) + ffn1_1_weight_attr = self.get_attr(self.config.ffn1_1_weight_attrs, i) + ffn2_weight_attr = self.get_attr(self.config.ffn2_weight_attrs, i) qkv_weight = self.create_parameter( shape=self.qkv_weight_shape, @@ -2545,49 +2517,12 @@ def __init__(self, config: FusedMultiTransformerConfig): dtype=self.create_params_type, is_bias=False, ) - - qkv_bias = None - if qkv_bias_attr: - qkv_bias = self.create_parameter( - shape=[(self.num_heads + 2 * self.kv_num_heads) * self.head_dim], - attr=qkv_bias_attr, - dtype=self._dtype, - is_bias=True, - ) - linear_weight = self.create_parameter( shape=self.linear_weight_shape, attr=linear_weight_attr, dtype=self.create_params_type, is_bias=False, ) - - linear_bias = None - if linear_bias_attr: - linear_bias = self.create_parameter( - shape=[config.embed_dim], - attr=linear_bias_attr, - dtype=self._dtype, - is_bias=True, - ) - - ffn_ln_scale = self.create_parameter( - shape=[config.embed_dim], - attr=ffn_ln_scale_attr, - is_bias=False, - default_initializer=Constant(1.0), - dtype=self._norm_weight_dtype, - ) - - ffn_ln_bias = None - if ffn_ln_bias_attr: - ffn_ln_bias = self.create_parameter( - shape=[config.embed_dim], - attr=ffn_ln_bias_attr, - is_bias=True, - dtype=self._norm_weight_dtype, - ) - ffn1_0_weight = self.create_parameter( shape=self.ffn1_0_weight_shape, attr=ffn1_0_weight_attr, @@ -2600,25 +2535,6 @@ def __init__(self, config: FusedMultiTransformerConfig): dtype=self.create_params_type, is_bias=False, ) - - ffn1_0_bias = None - if ffn1_0_bias_attr: - ffn1_0_bias = self.create_parameter( - shape=[self.dim_feedforward], - attr=ffn1_0_bias_attr, - dtype=self._dtype, - is_bias=True, - ) - - ffn1_1_bias = None - if ffn1_1_bias_attr: - ffn1_1_bias = self.create_parameter( - shape=[self.dim_feedforward], - attr=ffn1_1_bias_attr, - dtype=self._dtype, - is_bias=True, - ) - ffn2_weight = self.create_parameter( shape=self.ffn2_weight_shape, attr=ffn2_weight_attr, @@ -2626,135 +2542,30 @@ def __init__(self, config: FusedMultiTransformerConfig): is_bias=False, ) - ffn2_bias = None - if ffn2_bias_attr: - ffn2_bias = self.create_parameter( - shape=[config.embed_dim], - attr=ffn2_bias_attr, - dtype=self._dtype, - is_bias=True, - ) - - cache_scale_dtype = "float32" - if self.config.append_attn: - cache_scale_dtype = paddle.get_default_dtype() - - cache_k_scale = None - if cache_k_scale_attr: - cache_k_scale = self.create_parameter( - shape=[self.kv_num_heads], - 
attr=cache_k_scale_attr, - dtype=cache_scale_dtype, - is_bias=False, - ) - - cache_v_scale = None - if cache_v_scale_attr: - cache_v_scale = self.create_parameter( - shape=[self.kv_num_heads], - attr=cache_v_scale_attr, - dtype=cache_scale_dtype, - is_bias=False, - ) - - cache_k_out_scale = None - if cache_k_out_scale_attr: - cache_k_out_scale = self.create_parameter( - shape=[self.kv_num_heads], - attr=cache_k_out_scale_attr, - dtype=cache_scale_dtype, - is_bias=False, - ) - - cache_v_out_scale = None - if cache_v_out_scale_attr: - cache_v_out_scale = self.create_parameter( - shape=[self.kv_num_heads], - attr=cache_v_out_scale_attr, - dtype=cache_scale_dtype, - is_bias=False, - ) - # tensor model parallel - if config.nranks > 1: + if self.config.nranks > 1: # column parallel _set_var_distributed(qkv_weight) - _set_var_distributed(qkv_bias) _set_var_distributed(ffn1_0_weight) _set_var_distributed(ffn1_1_weight) - _set_var_distributed(ffn1_0_bias) - _set_var_distributed(ffn1_1_bias) # row parallel _set_var_distributed(linear_weight) _set_var_distributed(ffn2_weight) - self.ln_scales.append(ln_scale) - self.ln_biases.append(ln_bias) self.qkv_weights.append(qkv_weight) - self.qkv_biases.append(qkv_bias) self.linear_weights.append(linear_weight) - self.linear_biases.append(linear_bias) - self.ffn_ln_scales.append(ffn_ln_scale) - self.ffn_ln_biases.append(ffn_ln_bias) self.ffn1_0_weights.append(ffn1_0_weight) self.ffn1_1_weights.append(ffn1_1_weight) - self.ffn1_0_biases.append(ffn1_0_bias) - self.ffn1_1_biases.append(ffn1_1_bias) - self.ffn2_weights.append(ffn2_weight) - self.ffn2_biases.append(ffn2_bias) - - self.cache_k_scales.append(cache_k_scale) - self.cache_v_scales.append(cache_v_scale) - self.cache_k_out_scales.append(cache_k_out_scale) - self.cache_v_out_scales.append(cache_v_out_scale) - - self._add_parameter(ln_scale) - self._add_parameter(ln_bias) - self._add_parameter(qkv_weight) - self._add_parameter(qkv_bias) - self._add_parameter(linear_weight) - self._add_parameter(linear_bias) - - self._add_parameter(ffn_ln_scale) - self._add_parameter(ffn_ln_bias) - self._add_parameter(ffn1_0_weight) - self._add_parameter(ffn1_1_weight) - self._add_parameter(ffn1_0_bias) - self._add_parameter(ffn1_1_bias) - self._add_parameter(ffn2_weight) - self._add_parameter(ffn2_bias) - - self._add_parameter(cache_k_scale) - self._add_parameter(cache_v_scale) - self._add_parameter(cache_k_out_scale) - self._add_parameter(cache_v_out_scale) - - self.dropout_rate = config.dropout_rate - - from paddle.incubate.nn.functional import fused_linear - - self.linear = fused_linear - - def get_attr(self, attrs, idx): - """ - For fake parameter - """ - if isinstance(attrs, (list, tuple)): - assert ( - len(attrs) == self.num_layers - ), f"length of attrs is {len(attrs)} is not equal to self.num_layers {self.num_layers}" - return attrs[idx] - return attrs - def _add_parameter(self, param): - """ - For fake parameter - """ - if param is None: - return - assert param.name not in self._parameters - self._parameters[param.name] = param + self.ffn2_weights.append(ffn2_weight) + + self._add_parameter(qkv_weight) + self._add_parameter(linear_weight) + + self._add_parameter(ffn1_0_weight) + self._add_parameter(ffn1_1_weight) + self._add_parameter(ffn2_weight) def init_weight_shape(self, config): """ @@ -2770,10 +2581,13 @@ def init_weight_shape(self, config): self.ffn1_1_weight_shape = [self.dim_feedforward, self.embed_dim] self.ffn2_weight_shape = [self.embed_dim, self.dim_feedforward] - def get_weight_create_dype(self): + 
def get_weight_create_dype(self, layer_name=None, layer_idx=None): """ For fake parameter """ + if layer_name is not None and layer_idx is not None: + if self.weight_scales[layer_name][layer_idx] == -1: + return self._dtype return "float8_e4m3fn" def compute_layernorm_before_qkv(self, src, i): @@ -2787,7 +2601,7 @@ def compute_layernorm_before_qkv(self, src, i): self.ln_biases[i], self._epsilon, begin_norm_axis=1, - quant_scale=self.act_scales.scale["qkv_in_scale"][i], # quant_in_scale + quant_scale=self.act_scales["qkv_in_scale"][i], # quant_in_scale quant_round_type=1, quant_max_bound=self.config.quant_max_bound, quant_min_bound=self.config.quant_min_bound, @@ -2815,22 +2629,13 @@ def compute_qkv_linear(self, ln_out, i): transpose_x=False, transpose_y=True, bias=self.qkv_biases[i], - scale=self.weight_scales.scale["qkv_weight_scale"][i] - / (self.act_scales.scale["qkv_in_scale"][i] * 448 * 448), + scale=self.weight_scales["qkv_weight_scale"][i] / (self.act_scales["qkv_in_scale"][i] * 448 * 448), output_dtype=self._dtype, act="identity", ) return qkv_out - def compute_qkv(self, src, residual_input, i): - """ - For fake parameter - """ - ln_out = self.compute_layernorm_before_qkv(src, i) - qkv_out = self.compute_qkv_linear(ln_out, i) - return qkv_out, residual_input - def compute_out_linear(self, fmha_out, i): """ For fake parameter @@ -2841,19 +2646,12 @@ def compute_out_linear(self, fmha_out, i): bias=None, transpose_x=False, transpose_y=True, - scale=self.weight_scales.scale["out_linear_weight_scale"][i] - / (self.act_scales.scale["out_linear_in_scale"][i] * 448 * 448), + scale=self.weight_scales["out_linear_weight_scale"][i] + / (self.act_scales["out_linear_in_scale"][i] * 448 * 448), output_dtype=self._dtype, act="identity", ) - def compute_max_len(self, seq_lens_encoder, seq_lens_decoder, cum_offsets): - if seq_lens_encoder is None or seq_lens_decoder is None or cum_offsets is None: - return None, None - return paddle.incubate.nn.functional.blha_get_max_len( - seq_lens_encoder, seq_lens_decoder, cum_offsets # cum_offsets.shape[0] used as bsz - ) - def compute_attn( self, time_step, @@ -2895,50 +2693,108 @@ def compute_attn( v_quant_scales = kwargs.get("v_quant_scales", None) k_dequant_scales = kwargs.get("k_dequant_scales", None) v_dequant_scales = kwargs.get("v_dequant_scales", None) + cache_k_zps = kwargs.get("cache_k_zp", None) + cache_v_zps = kwargs.get("cache_v_zp", None) - if not self.config.use_dynamic_cachekv_quant: + cache_quant_type_str = "none" + if self.config.cachekv_int8_type == "static": k_quant_scales = self.cache_k_scales v_quant_scales = self.cache_v_scales k_dequant_scales = self.cache_k_out_scales v_dequant_scales = self.cache_v_out_scales + cache_quant_type_str = "cache_int8" + + if self.config.append_attn: + from paddlenlp_ops import append_attention + + fmha_out = append_attention( + qkv_out, + caches[2 * i], + caches[2 * i + 1], + kwargs.get("seq_lens_encoder", None), + kwargs.get("seq_lens_decoder", None), + kwargs.get("seq_lens_this_time", None), + kwargs.get("padding_offsets", None), + kwargs.get("cum_offsets", None), + kwargs.get("block_tables", None), + kwargs.get("encoder_batch_ids", None), + kwargs.get("encoder_tile_ids_per_batch", None), + kwargs.get("encoder_num_blocks", None), + kwargs.get("kv_batch_ids", None), + kwargs.get("kv_tile_ids_per_batch", None), + kwargs.get("kv_num_blocks", None), + kwargs.get("decoder_batch_ids", None), + kwargs.get("decoder_tile_ids_per_batch", None), + kwargs.get("decoder_num_blocks", None), + 
kwargs.get("max_enc_len_this_time", None), + kwargs.get("max_dec_len_this_time", None), + kwargs.get("max_len_kv", None), + rotary_embs, + None, # attn_mask + None, # qkv_bias + None, # qkv_out_scales + k_quant_scales[i] if k_quant_scales is not None else None, + v_quant_scales[i] if v_quant_scales is not None else None, + k_dequant_scales[i] if k_dequant_scales is not None else None, + v_dequant_scales[i] if v_dequant_scales is not None else None, + cache_k_zps[i] if cache_k_zps is not None else None, + cache_v_zps[i] if cache_v_zps is not None else None, + None, # linear_shifts + None, # linear_smooths + self._fuse_kernel_compute_dtype, + cache_quant_type_str, + self.use_neox_rotary_style, + kwargs.get("max_input_length", -1), + self.quant_max_bound, + self.quant_min_bound, + self.act_scales["out_linear_in_scale"][i], + kwargs.get("encoder_block_shape_q", 64), + kwargs.get("decoder_block_shape_q", 16), + kwargs.get("max_partition_size", 32768), + kwargs.get("encoder_max_partition_size", 32768), + kwargs["speculate_max_draft_token_num"], # speculate_max_draft_token_num + True, # causal + False, # speculate_decoder + )[0] + else: + fmha_out = paddle.incubate.nn.functional.block_multihead_attention( + qkv_out, + caches[2 * i], + caches[2 * i + 1], + kwargs.get("seq_lens_encoder", None), + kwargs.get("seq_lens_decoder", None), + kwargs.get("seq_lens_this_time", None), + kwargs.get("padding_offsets", None), + kwargs.get("cum_offsets", None), + kwargs.get("cu_seqlens_q", None), + kwargs.get("cu_seqlens_k", None), + kwargs.get("block_tables", None), + pre_caches[2 * i] if pre_caches is not None else None, # pre_key_cache + pre_caches[2 * i + 1] if pre_caches is not None else None, # pre_value_cache + k_quant_scales[i] if k_quant_scales is not None else None, + v_quant_scales[i] if v_quant_scales is not None else None, + k_dequant_scales[i] if k_dequant_scales is not None else None, + v_dequant_scales[i] if v_dequant_scales is not None else None, + None, # qkv_out_scales + None, # qkv_bias + None, # out_shifts + None, # out_smooths + kwargs.get("max_enc_len_this_time", None), + kwargs.get("max_dec_len_this_time", None), + rotary_embs, + attn_mask, + kwargs.get("tgt_mask", None), # tgt_mask + kwargs.get("max_input_length", -1), + kwargs.get("block_size", 64), + self.use_neox_rotary_style, + self.config.use_dynamic_cachekv_quant, + quant_round_type=self.config.quant_round_type, + quant_max_bound=self.config.quant_max_bound, + quant_min_bound=self.config.quant_min_bound, + out_scale=self.act_scales.scale["out_linear_in_scale"][i], + rope_theta=self.config.rope_theta, + )[0] - fmha_out = paddle.incubate.nn.functional.block_multihead_attention( - qkv_out, - caches[2 * i], - caches[2 * i + 1], - kwargs.get("seq_lens_encoder", None), - kwargs.get("seq_lens_decoder", None), - kwargs.get("seq_lens_this_time", None), - kwargs.get("padding_offsets", None), - kwargs.get("cum_offsets", None), - kwargs.get("cu_seqlens_q", None), - kwargs.get("cu_seqlens_k", None), - kwargs.get("block_tables", None), - pre_caches[2 * i] if pre_caches is not None else None, # pre_key_cache - pre_caches[2 * i + 1] if pre_caches is not None else None, # pre_value_cache - k_quant_scales[i] if k_quant_scales is not None else None, - v_quant_scales[i] if v_quant_scales is not None else None, - k_dequant_scales[i] if k_dequant_scales is not None else None, - v_dequant_scales[i] if v_dequant_scales is not None else None, - None, # qkv_out_scales - None, # qkv_bias - None, # out_shifts - None, # out_smooths - 
kwargs.get("max_enc_len_this_time", None), - kwargs.get("max_dec_len_this_time", None), - rotary_embs, - attn_mask, - kwargs.get("tgt_mask", None), # tgt_mask - kwargs.get("max_input_length", -1), - kwargs.get("block_size", 64), - self.use_neox_rotary_style, - self.config.use_dynamic_cachekv_quant, - quant_round_type=self.config.quant_round_type, - quant_max_bound=self.config.quant_max_bound, - quant_min_bound=self.config.quant_min_bound, - out_scale=self.act_scales.scale["out_linear_in_scale"][i], - rope_theta=self.config.rope_theta, - )[0] out_linear_out = self.compute_out_linear(fmha_out, i) return out_linear_out @@ -2955,7 +2811,7 @@ def compute_ffn_layernorm(self, out_linear_out, residual_input, i): begin_norm_axis=1, bias=self.linear_biases[i], residual=residual_input, - quant_scale=self.act_scales.scale["ffn1_in_scale"][i], # quant_in_scale + quant_scale=self.act_scales["ffn1_in_scale"][i], # quant_in_scale quant_round_type=1, quant_max_bound=self.config.quant_max_bound, quant_min_bound=self.config.quant_min_bound, @@ -2977,11 +2833,11 @@ def compute_ffn1(self, tmp_out, i): transpose_y=True, bias0=self.ffn1_0_biases[i], bias1=self.ffn1_1_biases[i], - scale0=self.weight_scales.scale["ffn1_0_weight_scale"][i] - / (self.act_scales.scale["ffn1_in_scale"][i] * 448 * 448), - scale1=self.weight_scales.scale["ffn1_1_weight_scale"][i] - / (self.act_scales.scale["ffn1_in_scale"][i] * 448 * 448), - scale_out=self.act_scales.scale["ffn2_in_scale"][i] * 448, + scale0=self.weight_scales["ffn1_0_weight_scale"][i] + / (self.act_scales["ffn1_in_scale"][i] * 448 * 448), + scale1=self.weight_scales["ffn1_1_weight_scale"][i] + / (self.act_scales["ffn1_in_scale"][i] * 448 * 448), + scale_out=self.act_scales["ffn2_in_scale"][i] * 448, act="swiglu", ) return res @@ -2991,8 +2847,7 @@ def compute_ffn1(self, tmp_out, i): self.ffn1_0_weights[i], transpose_x=False, transpose_y=True, - scale=self.weight_scales.scale["ffn1_0_weight_scale"][i] - / (self.act_scales.scale["ffn1_in_scale"][i] * 448 * 448), + scale=self.weight_scales["ffn1_0_weight_scale"][i] / (self.act_scales["ffn1_in_scale"][i] * 448 * 448), bias=self.ffn1_0_biases[i], output_dtype=self._dtype, act="identity", @@ -3003,8 +2858,7 @@ def compute_ffn1(self, tmp_out, i): self.ffn1_1_weights[i], transpose_x=False, transpose_y=True, - scale=self.weight_scales.scale["ffn1_1_weight_scale"][i] - / (self.act_scales.scale["ffn1_in_scale"][i] * 448 * 448), + scale=self.weight_scales["ffn1_1_weight_scale"][i] / (self.act_scales["ffn1_in_scale"][i] * 448 * 448), bias=self.ffn1_1_biases[i], output_dtype=self._dtype, act="identity", @@ -3013,9 +2867,12 @@ def compute_ffn1(self, tmp_out, i): from paddle.incubate.nn.functional import swiglu tem = swiglu(paddle.cast(tem_0, "float32"), paddle.cast(tem_1, "float32")) - res = paddle.cast(tem * self.act_scales.scale["ffn2_in_scale"][i] * 448, "float8_e4m3fn") + res = paddle.cast(tem * self.act_scales["ffn2_in_scale"][i] * 448, "float8_e4m3fn") return res + def compute_activation(self, ffn1_out, i): + return ffn1_out + def compute_ffn2(self, ffn1_out, i): """ For fake parameter @@ -3026,8 +2883,7 @@ def compute_ffn2(self, ffn1_out, i): bias=None, transpose_x=False, transpose_y=True, - scale=self.weight_scales.scale["ffn2_weight_scale"][i] - / (self.act_scales.scale["ffn2_in_scale"][i] * 448 * 448), + scale=self.weight_scales["ffn2_weight_scale"][i] / (self.act_scales["ffn2_in_scale"][i] * 448 * 448), output_dtype=self._dtype, act="identity", ) @@ -3045,7 +2901,7 @@ def compute_bias_residual_layernorm(self, ffn2_out, 
residual_input, i, num_layer begin_norm_axis=1, bias=self.ffn2_biases[i], residual=residual_input, - quant_scale=self.act_scales.scale["qkv_in_scale"][i + 1], # quant_in_scale + quant_scale=self.act_scales["qkv_in_scale"][i + 1], # quant_in_scale quant_round_type=1, quant_max_bound=self.config.quant_max_bound, quant_min_bound=self.config.quant_min_bound, @@ -3062,141 +2918,3 @@ def compute_bias_residual_layernorm(self, ffn2_out, residual_input, i, num_layer residual=residual_input, )[0] return tmp_out, residual_input - - def pre_process(self, **kwargs): - """ - For fake parameter - """ - pass - - def post_process(self, **kwargs): - """ - For fake parameter - """ - multi_block_output = kwargs.get("multi_block_output", None) - cum_offsets = kwargs.get("cum_offsets", None) - seq_lens_encoder = kwargs.get("seq_lens_encoder", None) - seq_lens_decoder = kwargs.get("seq_lens_decoder", None) - max_input_length = kwargs.get("max_input_length", -1) - - out = rebuild_padding_v2(multi_block_output, cum_offsets, seq_lens_decoder, seq_lens_encoder, max_input_length) - - return out - - def forward( - self, - input_ids, - src, - cum_offsets=None, - padding_offset=None, - attn_mask=None, - caches=None, - pre_caches=None, - pre_caches_length=0, - rotary_embs=None, - rotary_emb_dims=0, - seq_lens=None, - time_step=None, - **kwargs, - ): - r""" - Applies multi transformer layers on the input. - Parameters: - src (Tensor): The input of Transformer layers. It is - a tensor with shape `[batch_size, sequence_length, d_model]`. - The data type should be float16 or float32. - attn_mask (Tensor, optional): A tensor used in multi-head attention - to prevents attention to some unwanted positions, usually the - paddings or the subsequent positions. It is a tensor with shape - `[batch_size, 1, sequence_length, sequence_length]`. It can be - None when nothing wanted or needed to be prevented attention to. - Default None. - caches (list(Tensor)|tuple(Tensor), optional): The cache structure - tensors for the inference generation model. It is only used for - inference and should be None for training. The shape is - `[2, batch_size, num_head, max_seq_len, head_dim]`. Default None. - pre_caches (list(Tensor)|tuple(Tensor), optional): The prefix caches - for the generation model. The shape is `[2, bsz, num\_head, cache\_len, head\_dim]`. - Default None. - rotary_embs (Tensor optional): The RoPE embs for the rotary computation. - The shape is `[2, bsz, 1, seq\_len, head\_dim]`. Default None. - rotary_emb_dims (int, optional): The rotary_emb_dims of rotary computation, - and it is 0 when rotary_embs is None, - 1 when rotary_embs is not None and pos_extra_ids is None, - 2 when rotary_embs and pos_extra_ids are both not None. Default 0. - seq_lens (Tensor optional): The sequence lengths of this batch. The shape is `[bsz]`. - Default None. - time_step (Tensor, optional): The time step tensor for the generation - model. Which used in decode stage, to represent the time step, - that is, the real seq_len of CacheKV. The shape is `[1]`, must be - in CPUPlace. Default None. - Returns: - Tensor|tuple: If `caches` is None, return a tensor that has - the same shape and data type with `src`, representing the output - of Transformer layers. If `caches` is not None, return the - tuple (output, caches), which output is the output of - Transformer layers, caches is inplace with input `caches`. 
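Every GEMM in the FP8 block above follows the same per-tensor convention: the activation is multiplied by its act_in_scale and by 448 (the float8_e4m3fn maximum, i.e. the new quant_max_bound) before the cast to float8_e4m3fn, and the fused GEMM receives weight_scale / (act_in_scale * 448 * 448) so that its output comes back already dequantized in the network dtype; the a8w8 path uses 127 in the same role. A reference-only sketch of that bookkeeping, derived from compute_qkv_linear and compute_ffn1 above rather than from the CUDA kernels, with the fp8 cast's saturation modeled by a clip:

import numpy as np

QUANT_MAX_BOUND = 448.0   # float8_e4m3fn max, passed as quant_max_bound
QUANT_MIN_BOUND = -448.0  # passed as quant_min_bound

def quant_activation(x, act_in_scale):
    # counterpart of `paddle.cast(x * act_in_scale * 448, "float8_e4m3fn")`
    return np.clip(x * act_in_scale * QUANT_MAX_BOUND, QUANT_MIN_BOUND, QUANT_MAX_BOUND)

def gemm_dequant_scale(weight_scale, act_in_scale):
    # the `scale=` argument handed to the fused FP8 GEMM
    return weight_scale / (act_in_scale * QUANT_MAX_BOUND * QUANT_MAX_BOUND)

With this convention the only per-GEMM state the layer needs is one activation scale and one weight scale, which is what the scale loaders in the modeling code below attach to the transformer block.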
- """ - self.pre_process(**kwargs) - kwargs["cum_offsets"] = cum_offsets - - if caches is not None: - assert len(caches) == len(self.qkv_weights) or len(caches) == 2 * len(self.qkv_weights) - - assert self.num_layers == len(self.qkv_weights) - - max_enc_len_this_time, max_dec_len_this_time = self.compute_max_len( - kwargs.get("seq_lens_encoder", None), kwargs.get("seq_lens_decoder", None), cum_offsets - ) - kwargs["max_enc_len_this_time"] = max_enc_len_this_time - kwargs["max_dec_len_this_time"] = max_dec_len_this_time - - residual_input = src - for i in range(self.num_layers): - qkv_out, residual_input = self.compute_qkv(src, residual_input, i) - - out_linear_out = self.compute_attn( - time_step, - qkv_out, - padding_offset, - seq_lens, - input_ids, - rotary_embs, - rotary_emb_dims, - caches, - pre_caches, - pre_caches_length, - attn_mask, - i, - **kwargs, - ) - # all_reduce - if self.nranks > 1: - dist.all_reduce(out_linear_out) - - # ffn layernorm - tmp_out, residual_input = self.compute_ffn_layernorm(out_linear_out, residual_input, i) - - # ffn1 matmul - ffn1_out = self.compute_ffn1(tmp_out, i) - - # ffn2 matmul - ffn2_out = self.compute_ffn2(ffn1_out, i) - - # all_reduce - if self.nranks > 1: - dist.all_reduce(ffn2_out) - - # norm + residual_add_bias - tmp_out, residual_input = self.compute_bias_residual_layernorm( - ffn2_out, residual_input, i, self.num_layers - ) - src = tmp_out - - kwargs["time_step"] = time_step - kwargs["multi_block_output"] = tmp_out - kwargs["seq_lens"] = seq_lens - kwargs["input_ids"] = input_ids - - out = self.post_process(**kwargs) - return out, caches diff --git a/paddlenlp/experimental/transformers/llama/modeling.py b/paddlenlp/experimental/transformers/llama/modeling.py index 8f163ea3e707..0c2dbd322994 100644 --- a/paddlenlp/experimental/transformers/llama/modeling.py +++ b/paddlenlp/experimental/transformers/llama/modeling.py @@ -13,19 +13,13 @@ # limitations under the License. 
from __future__ import annotations -import inspect import json -import logging import os from functools import partial import numpy as np import paddle from paddle import nn -from paddle.base import core -from paddle.base.executor import Executor, global_scope -from paddle.base.framework import _current_expected_place as _get_device -from paddle.base.framework import in_dygraph_mode from paddle.distributed import fleet from paddle.nn.quant import weight_quantize @@ -398,6 +392,7 @@ def __init__(self, config: LlamaConfig): self.hidden_size = config.hidden_size self.num_attention_heads = config.num_attention_heads self.num_key_value_heads = config.num_key_value_heads + self.head_size = self.hidden_size // self.num_attention_heads self.intermediate_size = config.intermediate_size self.num_layers = config.num_hidden_layers self.epsilon = config.rms_norm_eps @@ -453,6 +448,35 @@ def __init__(self, config: LlamaConfig): except: pass + qkv_weight_scale_attrs = None + out_proj_weight_scale_attrs = None + ffn1_weight_scale_attrs = None + ffn2_weight_scale_attrs = None + + qkv_out_scale_attrs = None + linear_out_scale_attrs = None + ffn1_out_scale_attrs = None + ffn2_out_scale_attrs = None + linear_shift_attrs = None + linear_smooth_attrs = None + ffn2_shift_attrs = None + ffn2_smooth_attrs = None + + ln_bias_attrs = None + qkv_bias_attrs = None + out_proj_bias_attrs = None + ffn_ln_bias_attrs = None + ffn1_bias_attrs = None + ffn2_bias_attrs = None + + ffn1_0_weight_attrs = None + ffn1_1_weight_attrs = None + ffn1_0_bias_attrs = None + ffn1_1_bias_attrs = None + + ffn1_weight_attrs = None + ffn2_weight_attrs = None + ln_scale_attrs = [paddle.ParamAttr(name="fusellama.{}.ln_scale".format(i)) for i in range(self.num_layers)] qkv_weight_attrs = [ paddle.ParamAttr( @@ -482,8 +506,6 @@ def __init__(self, config: LlamaConfig): ) for i in range(self.num_layers) ] - ffn1_0_bias_attrs = None - ffn1_1_bias_attrs = None else: ffn1_weight_attrs = [ paddle.ParamAttr( @@ -498,21 +520,6 @@ def __init__(self, config: LlamaConfig): for i in range(self.num_layers) ] - qkv_out_scale_attrs = None - linear_out_scale_attrs = None - ffn1_out_scale_attrs = None - ffn2_out_scale_attrs = None - linear_shift_attrs = None - linear_smooth_attrs = None - ffn2_shift_attrs = None - ffn2_smooth_attrs = None - ln_bias_attrs = None - qkv_bias_attrs = None - out_proj_bias_attrs = None - ffn_ln_bias_attrs = None - ffn1_bias_attrs = None - ffn2_bias_attrs = None - if "a8w8" in self.quant_type: qkv_out_scale_attrs = [ paddle.ParamAttr(name="fusellama.{}.qkv_out_scale".format(i)) for i in range(self.num_layers) @@ -600,91 +607,56 @@ def __init__(self, config: LlamaConfig): paddle.ParamAttr(name="fusellama.{}.cache_v_out_scale".format(i)) for i in range(self.num_layers) ] - if "fp8" in self.quant_type: - transformer_config = FusedMultiTransformerConfig( - embed_dim=self.hidden_size, - num_heads=self.num_attention_heads, - kv_num_heads=self.num_key_value_heads, - dim_feedforward=self.intermediate_size, - quant_type=self.quant_type, - activation="swiglu", - num_layers=config.num_hidden_layers, - nranks=config.tensor_parallel_degree, - ring_id=ring_id, - ln_scale_attrs=ln_scale_attrs, - ln_bias_attrs=ln_bias_attrs, - qkv_weight_attrs=qkv_weight_attrs, - qkv_bias_attrs=qkv_bias_attrs, - linear_weight_attrs=out_proj_weight_attrs, - linear_bias_attrs=out_proj_bias_attrs, - ffn_ln_scale_attrs=ffn_ln_scale_attrs, - ffn_ln_bias_attrs=ffn_ln_bias_attrs, - cache_k_scale_attrs=cache_k_scale_attrs, - cache_v_scale_attrs=cache_v_scale_attrs, - 
cache_k_out_scale_attrs=cache_k_out_scale_attrs, - cache_v_out_scale_attrs=cache_v_out_scale_attrs, - ffn1_0_weight_attrs=ffn1_0_weight_attrs, - ffn1_1_weight_attrs=ffn1_1_weight_attrs, - ffn1_0_bias_attrs=ffn1_0_bias_attrs, - ffn1_1_bias_attrs=ffn1_1_bias_attrs, - ffn2_weight_attrs=ffn2_weight_attrs, - ffn2_bias_attrs=ffn2_bias_attrs, - epsilon=self.epsilon, - rope_theta=self.rope_theta, - norm_type="rmsnorm", - use_neox_rotary_style=self.use_neox, - rank_id=config.tensor_parallel_rank, - append_attn=config.append_attn, - ) - - else: - transformer_config = FusedMultiTransformerConfig( - embed_dim=self.hidden_size, - num_heads=self.num_attention_heads, - kv_num_heads=self.num_key_value_heads, - dim_feedforward=self.intermediate_size, - quant_type=self.quant_type, - activation="swiglu", - num_layers=config.num_hidden_layers, - nranks=config.tensor_parallel_degree, - ring_id=ring_id, - ln_scale_attrs=ln_scale_attrs, - qkv_weight_attrs=qkv_weight_attrs, - qkv_weight_scale_attrs=qkv_weight_scale_attrs, - linear_weight_attrs=out_proj_weight_attrs, - linear_weight_scale_attrs=out_proj_weight_scale_attrs, - ffn_ln_scale_attrs=ffn_ln_scale_attrs, - ffn1_weight_attrs=ffn1_weight_attrs, - ffn1_weight_scale_attrs=ffn1_weight_scale_attrs, - ffn2_weight_attrs=ffn2_weight_attrs, - ffn2_weight_scale_attrs=ffn2_weight_scale_attrs, - qkv_out_scale_attrs=qkv_out_scale_attrs, - linear_out_scale_attrs=linear_out_scale_attrs, - ffn1_out_scale_attrs=ffn1_out_scale_attrs, - ffn2_out_scale_attrs=ffn2_out_scale_attrs, - linear_shift_attrs=linear_shift_attrs, - linear_smooth_attrs=linear_smooth_attrs, - ffn2_shift_attrs=ffn2_shift_attrs, - ffn2_smooth_attrs=ffn2_smooth_attrs, - ln_bias_attrs=ln_bias_attrs, - qkv_bias_attrs=qkv_bias_attrs, - linear_bias_attrs=out_proj_bias_attrs, - ffn_ln_bias_attrs=ffn_ln_bias_attrs, - ffn1_bias_attrs=ffn1_bias_attrs, - ffn2_bias_attrs=ffn2_bias_attrs, - cache_k_scale_attrs=cache_k_scale_attrs, - cache_v_scale_attrs=cache_v_scale_attrs, - cache_k_out_scale_attrs=cache_k_out_scale_attrs, - cache_v_out_scale_attrs=cache_v_out_scale_attrs, - epsilon=self.epsilon, - rope_theta=self.rope_theta, - norm_type="rmsnorm", - use_neox_rotary_style=self.use_neox, - cachekv_int8_type=config.cachekv_int8_type, - rank_id=config.tensor_parallel_rank, - trans_qkvw=(False if paddle.is_compiled_with_rocm() and "a8w8" in self.quant_type else True), - append_attn=config.append_attn, - ) + transformer_config = FusedMultiTransformerConfig( + embed_dim=self.hidden_size, + num_heads=self.num_attention_heads, + kv_num_heads=self.num_key_value_heads, + dim_feedforward=self.intermediate_size, + quant_type=self.quant_type, + activation="swiglu", + num_layers=config.num_hidden_layers, + nranks=config.tensor_parallel_degree, + ring_id=ring_id, + ln_scale_attrs=ln_scale_attrs, + qkv_weight_attrs=qkv_weight_attrs, + qkv_weight_scale_attrs=qkv_weight_scale_attrs, + linear_weight_attrs=out_proj_weight_attrs, + linear_weight_scale_attrs=out_proj_weight_scale_attrs, + ffn_ln_scale_attrs=ffn_ln_scale_attrs, + ffn1_weight_attrs=ffn1_weight_attrs, + ffn1_weight_scale_attrs=ffn1_weight_scale_attrs, + ffn1_0_weight_attrs=ffn1_0_weight_attrs, + ffn1_1_weight_attrs=ffn1_1_weight_attrs, + ffn2_weight_attrs=ffn2_weight_attrs, + ffn2_weight_scale_attrs=ffn2_weight_scale_attrs, + qkv_out_scale_attrs=qkv_out_scale_attrs, + linear_out_scale_attrs=linear_out_scale_attrs, + ffn1_out_scale_attrs=ffn1_out_scale_attrs, + ffn2_out_scale_attrs=ffn2_out_scale_attrs, + linear_shift_attrs=linear_shift_attrs, + 
linear_smooth_attrs=linear_smooth_attrs, + ffn2_shift_attrs=ffn2_shift_attrs, + ffn2_smooth_attrs=ffn2_smooth_attrs, + ln_bias_attrs=ln_bias_attrs, + qkv_bias_attrs=qkv_bias_attrs, + linear_bias_attrs=out_proj_bias_attrs, + ffn_ln_bias_attrs=ffn_ln_bias_attrs, + ffn1_bias_attrs=ffn1_bias_attrs, + ffn1_0_bias_attrs=ffn1_0_bias_attrs, + ffn1_1_bias_attrs=ffn1_1_bias_attrs, + ffn2_bias_attrs=ffn2_bias_attrs, + cache_k_scale_attrs=cache_k_scale_attrs, + cache_v_scale_attrs=cache_v_scale_attrs, + cache_k_out_scale_attrs=cache_k_out_scale_attrs, + cache_v_out_scale_attrs=cache_v_out_scale_attrs, + epsilon=self.epsilon, + norm_type="rmsnorm", + use_neox_rotary_style=self.use_neox, + cachekv_int8_type=config.cachekv_int8_type, + rank_id=config.tensor_parallel_rank, + trans_qkvw=(False if paddle.is_compiled_with_rocm() and "a8w8" in self.quant_type else True), + append_attn=config.append_attn, + ) self.set_transformer_block(transformer_config) self.norm = FusedLlamaRMSNorm(config) @@ -835,21 +807,52 @@ def forward( ) @paddle.no_grad() - def set_state_dict(self, state_dict): - if "a8w8" in self.quant_type: - current_work_dir = os.path.dirname(__file__) + def set_quant_scale(self): + current_work_dir = os.path.dirname(__file__) + if "fp8" in self.quant_type: + scale_map_file = f"{current_work_dir}/ptq_fp8_scales_map.json" + with open(scale_map_file) as json_file: + scale_map_dict = json.load(json_file) + act_scale_map_dict = scale_map_dict["act_scale"] + weight_scale_map_dict = scale_map_dict["weight_scale"] + cache_scale_map_dict = scale_map_dict["cachekv_scale"] + act_scale_json_path = resolve_file_path(self.quant_model_path, "act_scales.json") + weight_scale_json_path = resolve_file_path(self.quant_model_path, "weight_scales.json") + if self.config.tensor_parallel_degree > 1 and not self.config.single_card_ptq: + act_scale_json_path = resolve_file_path( + self.quant_model_path, f"act_scales_{self.config.tensor_parallel_rank}.json" + ) + weight_scale_json_path = resolve_file_path( + self.quant_model_path, f"weight_scales_{self.config.tensor_parallel_rank}.json" + ) + + act_scales = ActScalesLoader( + act_scale_json_path, act_scale_map_dict, num_of_layers=self.config.num_hidden_layers + ) + + weight_scales = PerTensorWeightScalesLoader( + weight_scale_json_path, + weight_scale_map_dict, + num_of_layers=self.config.num_hidden_layers, + ) + + for weight_name in weight_scales.scale: + weight_scales.scale[weight_name] = weight_scales.scale[weight_name].astype(np.float32) + for act_name in act_scales.scale: + act_scales.scale[act_name] = act_scales.scale[act_name].astype(np.float32) + self.transformer_block.weight_scales = weight_scales.scale + self.transformer_block.act_scales = act_scales.scale + elif "a8w8" in self.quant_type: scale_map_file = ( f"{current_work_dir}/ptq_scales_map.json" if not self.shift_smooth_all_linears else f"{current_work_dir}/ptq_scales_map_shift_smooth.json" ) - with open(scale_map_file) as json_file: scale_map_dict = json.load(json_file) act_scale_map_dict = scale_map_dict["act_scale"] weight_scale_map_dict = scale_map_dict["weight_scale"] cache_scale_map_dict = scale_map_dict["cachekv_scale"] - if not self.use_fake_parameter: act_scale_json_path = resolve_file_path(self.quant_model_path, "act_scales.json") weight_scale_json_path = resolve_file_path(self.quant_model_path, "weight_scales.json") @@ -886,11 +889,102 @@ def set_state_dict(self, state_dict): self.transformer_block.weight_scales = weight_scales_loader.scale self.transformer_block.act_scales = act_scale_loader.scale 
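The loop that follows turns the loaded per-channel weight scales into the dequantization out-scales the int8 kernels consume, and the conversion is identical for every projection: out_scale = weight_scale / (127 * 127 * act_in_scale), which undoes the int8 quantization of both the activation and the weight in a single multiply. A tiny illustration with made-up numbers (the real shapes come from the checkpoint's weight_scales.json):

import numpy as np

act_in_scale = 0.05                                        # per-tensor activation scale
weight_scale = np.full([3 * 32 * 128], 0.002, np.float32)  # e.g. fused qkv output channels
qkv_out_scale = weight_scale / (127.0 * 127.0 * act_in_scale)

Under tensor parallelism the same vector is additionally reshaped and sliced per rank, as the qkv and ffn1 branches of the loop show.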
+ for k, v in weight_scales_loader.scale.items(): + if "qkv_" in k: + for i_layer, weight_scale in enumerate(v): + if not np.all(weight_scale == -1): + tmp = paddle.to_tensor( + weight_scale + / ( + 127.0 * 127.0 * act_scale_loader.scale["qkv_in_scale"][i_layer] + ) # [3 * num_head * dim_head] + ).reshape([-1]) + if self.config.tensor_parallel_degree > 1 and self.config.single_card_ptq: + tmp = ( + tmp.reshape([3, self.num_attention_heads, self.head_size]) + .split(self.config.tensor_parallel_degree, axis=1)[ + self.config.tensor_parallel_rank + ] + .reshape([-1]) + ) + self.transformer_block.qkv_out_scales[i_layer].set_value(tmp) + elif "out_linear_" in k: + for i_layer, weight_scale in enumerate(v): + if not np.all(weight_scale == -1): + tmp = paddle.to_tensor( + weight_scale / (127.0 * 127.0 * act_scale_loader.scale["out_linear_in_scale"][i_layer]) + ) + self.transformer_block.linear_out_scales[i_layer].set_value(tmp) + elif "ffn1_weight_scale" in k: + for i_layer, weight_scale in enumerate(v): + if not np.all(weight_scale == -1): + tmp = paddle.to_tensor( + weight_scale / (127.0 * 127.0 * act_scale_loader.scale["ffn1_in_scale"][i_layer]) + ) + if self.config.tensor_parallel_degree > 1 and self.config.single_card_ptq: + tmp = paddle.split(tmp, self.config.tensor_parallel_degree * 2) + tmp = paddle.concat( + [ + tmp[self.config.tensor_parallel_rank], + tmp[self.config.tensor_parallel_rank + self.config.tensor_parallel_degree], + ], + axis=0, + ) + self.transformer_block.ffn1_out_scales[i_layer].set_value(tmp) + elif "ffn2" in k: + for i_layer, weight_scale in enumerate(v): + if not np.all(weight_scale == -1): + self.transformer_block.ffn2_out_scales[i_layer].set_value( + paddle.to_tensor( + weight_scale / (127.0 * 127.0 * act_scale_loader.scale["ffn2_in_scale"][i_layer]) + ) + ) + + if self.config.cachekv_int8_type == "static": + if not self.use_fake_parameter: + cache_scale_json_path = resolve_file_path(self.quant_model_path, "cachekv_scales.json") + if self.config.tensor_parallel_degree > 1 and not self.config.single_card_ptq: + cache_scale_json_path = resolve_file_path( + self.quant_model_path, f"cachekv_scales_{self.config.tensor_parallel_rank}.json" + ) + cache_scales_loader = CacheScaleLoader( + cache_scale_json_path, + cache_scale_map_dict, + num_of_layers=self.config.num_hidden_layers, + num_heads=self.num_attention_heads // self.config.tensor_parallel_degree, + num_key_value_heads=self.num_key_value_heads // self.config.tensor_parallel_degree, + ) + else: + cache_scales_loader = EmptyCacheScale( + cache_scale_map_dict, + num_of_layers=self.config.num_hidden_layers, + num_heads=self.num_attention_heads, + dim_heads=self.hidden_size // self.num_attention_heads, + is_channel_wise=False, + num_key_value_heads=self.num_key_value_heads, + mp_size=self.config.tensor_parallel_degree, + ) + + for k, v in cache_scales_loader.scale.items(): + for i_layer, weight_scale in enumerate(v): + if self.config.append_attn: + weight_scale = paddle.to_tensor(weight_scale).cast(paddle.get_default_dtype()) + else: + weight_scale = weight_scale.astype("float32") + if k == "cache_k_scale": + self.transformer_block.cache_k_scales[i_layer].set_value(weight_scale) + elif k == "cache_v_scale": + self.transformer_block.cache_v_scales[i_layer].set_value(weight_scale) + elif k == "cache_k_out_scale": + self.transformer_block.cache_k_out_scales[i_layer].set_value(weight_scale) + else: + self.transformer_block.cache_v_out_scales[i_layer].set_value(weight_scale) + + @paddle.no_grad() + def set_state_dict(self, 
state_dict): + self.set_quant_scale() self.transformer_block.init_weight() - unfused_state_dict = {} - head_size = self.hidden_size // self.num_attention_heads split_fn = split_param_func() - self.embed_tokens.weight.set_value( paddle.to_tensor(state_dict["llama.embed_tokens.weight"]).cast(self.embed_tokens.weight.dtype) ) @@ -900,29 +994,50 @@ def set_state_dict(self, state_dict): for idx in range(self.config.num_hidden_layers): logger.info(f"set state for layer {idx}") + self.transformer_block.ln_scales[idx].set_value( + paddle.to_tensor(state_dict["llama.layers.{}.input_layernorm.weight".format(idx)]).cast( + self.transformer_block.ln_scales[idx].dtype + ) + ) if "llama.layers.{}.self_attn.qkv_proj.weight".format(idx) in state_dict.keys(): - concated_qkv_weight = np.concatenate( - split_fn( - state_dict["llama.layers.{}.self_attn.qkv_proj.weight".format(idx)], - is_qkv=True, - num_heads=self.num_attention_heads // self.config.tensor_parallel_degree, - num_key_value_heads=self.num_key_value_heads // self.config.tensor_parallel_degree, - ), - axis=-1, - ).transpose(1, 0) + concated_qkv_weight = paddle.to_tensor( + np.concatenate( + split_fn( + state_dict["llama.layers.{}.self_attn.qkv_proj.weight".format(idx)], + is_qkv=True, + num_heads=self.num_attention_heads // self.config.tensor_parallel_degree, + num_key_value_heads=self.num_key_value_heads // self.config.tensor_parallel_degree, + ), + axis=-1, + ).transpose(1, 0) + ) else: unfused_state_dict = {} - unfused_state_dict["self_attn.q_proj.weight"] = state_dict[ - "llama.layers.{}.self_attn.q_proj.weight".format(idx) - ] - unfused_state_dict["self_attn.k_proj.weight"] = state_dict[ - "llama.layers.{}.self_attn.k_proj.weight".format(idx) - ] - unfused_state_dict["self_attn.v_proj.weight"] = state_dict[ - "llama.layers.{}.self_attn.v_proj.weight".format(idx) - ] - if paddle.is_compiled_with_rocm() and "a8w8" in self.quant_type: - concated_qkv_weight = np.concatenate( + unfused_state_dict["self_attn.q_proj.weight"] = paddle.to_tensor( + state_dict["llama.layers.{}.self_attn.q_proj.weight".format(idx)] + ) + unfused_state_dict["self_attn.k_proj.weight"] = paddle.to_tensor( + state_dict["llama.layers.{}.self_attn.k_proj.weight".format(idx)] + ) + unfused_state_dict["self_attn.v_proj.weight"] = paddle.to_tensor( + state_dict["llama.layers.{}.self_attn.v_proj.weight".format(idx)] + ) + if "fp8" in self.quant_type: + q_wgt_scale = self.transformer_block.weight_scales["q_weight_scale"][idx] + k_wgt_scale = self.transformer_block.weight_scales["k_weight_scale"][idx] + v_wgt_scale = self.transformer_block.weight_scales["v_weight_scale"][idx] + qkv_wgt_scale = self.transformer_block.weight_scales["qkv_weight_scale"][idx] + unfused_state_dict["self_attn.q_proj.weight"] = ( + unfused_state_dict["self_attn.q_proj.weight"].cast("float32") * q_wgt_scale / qkv_wgt_scale + ) + unfused_state_dict["self_attn.k_proj.weight"] = ( + unfused_state_dict["self_attn.k_proj.weight"].cast("float32") * k_wgt_scale / qkv_wgt_scale + ) + unfused_state_dict["self_attn.v_proj.weight"] = ( + unfused_state_dict["self_attn.v_proj.weight"].cast("float32") * v_wgt_scale / qkv_wgt_scale + ) + if paddle.is_compiled_with_rocm() and "a8w8" in self.quant_type and "fp8" not in self.quant_type: + concated_qkv_weight = paddle.concat( [ unfused_state_dict["self_attn.q_proj.weight"], unfused_state_dict["self_attn.k_proj.weight"], @@ -930,16 +1045,18 @@ def set_state_dict(self, state_dict): ], axis=-1, ).reshape( - self.hidden_size, - ( - self.num_attention_heads // 
self.config.tensor_parallel_degree - + 2 * self.num_key_value_heads // self.config.tensor_parallel_degree - ) - * (head_size), + [ + self.hidden_size, + ( + self.num_attention_heads // self.config.tensor_parallel_degree + + 2 * self.num_key_value_heads // self.config.tensor_parallel_degree + ) + * (self.head_size), + ] ) else: concated_qkv_weight = ( - np.concatenate( + paddle.concat( [ unfused_state_dict["self_attn.q_proj.weight"], unfused_state_dict["self_attn.k_proj.weight"], @@ -947,29 +1064,18 @@ def set_state_dict(self, state_dict): ], axis=-1, ) - .transpose(1, 0) + .transpose([1, 0]) .reshape( - ( - self.num_attention_heads // self.config.tensor_parallel_degree - + 2 * self.num_key_value_heads // self.config.tensor_parallel_degree - ) - * (head_size), - self.hidden_size, + [ + ( + self.num_attention_heads // self.config.tensor_parallel_degree + + 2 * self.num_key_value_heads // self.config.tensor_parallel_degree + ) + * (self.head_size), + self.hidden_size, + ] ) ) - if "llama.layers.{}.mlp.gate_up_fused_proj.weight".format(idx) in state_dict.keys(): - concated_ffn1_weight = np.concatenate( - split_fn(state_dict["llama.layers.{}.mlp.gate_up_fused_proj.weight".format(idx)]), axis=-1 - ) - else: - unfused_state_dict["mlp.gate_proj.weight"] = state_dict[ - "llama.layers.{}.mlp.gate_proj.weight".format(idx) - ] - unfused_state_dict["mlp.up_proj.weight"] = state_dict["llama.layers.{}.mlp.up_proj.weight".format(idx)] - concated_ffn1_weight = np.concatenate( - [unfused_state_dict["mlp.gate_proj.weight"], unfused_state_dict["mlp.up_proj.weight"]], axis=-1 - ) - qkv_weight_tensor = paddle.to_tensor(concated_qkv_weight).cast(paddle.get_default_dtype()) if self.use_weight_only: qkv_weight_tensor = paddle.transpose(qkv_weight_tensor, perm=[1, 0]) @@ -978,6 +1084,8 @@ def set_state_dict(self, state_dict): ) self.transformer_block.qkv_weights[idx].set_value(qkv_quanted_weight_tensor) self.transformer_block.qkv_weights_scale[idx].set_value(qkv_weight_scale_tensor) + elif "fp8" in self.quant_type: + self.transformer_block.qkv_weights[idx].copy_(paddle.cast(concated_qkv_weight, "float8_e4m3fn"), False) elif "a8w8" in self.quant_type and not self.transformer_block.skip_quant("qkv_weight_scale", idx): self.transformer_block.qkv_weights[idx].set_value( paddle.cast(paddle.to_tensor(concated_qkv_weight), "int8") @@ -994,6 +1102,16 @@ def set_state_dict(self, state_dict): ) self.transformer_block.linear_weights[idx].set_value(linear_quanted_weight_tensor) self.transformer_block.linear_weights_scale[idx].set_value(linear_weight_scale_tensor) + elif "fp8" in self.quant_type: + self.transformer_block.linear_weights[idx].copy_( + paddle.cast( + paddle.to_tensor(state_dict["llama.layers.{}.self_attn.o_proj.weight".format(idx)]).transpose( + (1, 0) + ), + "float8_e4m3fn", + ), + False, + ) elif "a8w8" in self.quant_type: w_dtype = ( paddle.get_default_dtype() @@ -1019,13 +1137,43 @@ def set_state_dict(self, state_dict): else: self.transformer_block.linear_weights[idx].set_value(linear_weight_tensor) + self.transformer_block.ffn_ln_scales[idx].set_value( + paddle.to_tensor(state_dict["llama.layers.{}.post_attention_layernorm.weight".format(idx)]).cast( + self.transformer_block.ffn_ln_scales[idx].dtype + ) + ) + + if "llama.layers.{}.mlp.gate_up_fused_proj.weight".format(idx) in state_dict.keys(): + concated_ffn1_weight = np.concatenate( + split_fn(state_dict["llama.layers.{}.mlp.gate_up_fused_proj.weight".format(idx)]), axis=-1 + ) + else: + unfused_state_dict["mlp.gate_proj.weight"] = state_dict[ + 
"llama.layers.{}.mlp.gate_proj.weight".format(idx) + ] + unfused_state_dict["mlp.up_proj.weight"] = state_dict["llama.layers.{}.mlp.up_proj.weight".format(idx)] + concated_ffn1_weight = np.concatenate( + [unfused_state_dict["mlp.gate_proj.weight"], unfused_state_dict["mlp.up_proj.weight"]], axis=-1 + ) ffn1_weight_tensor = paddle.to_tensor(concated_ffn1_weight).cast(paddle.get_default_dtype()) + if self.use_weight_only: ffn1_quanted_weight_tensor, ffn1_weight_scale_tensor = weight_quantize( ffn1_weight_tensor, algo=self.quant_algo ) self.transformer_block.ffn1_weights[idx].set_value(ffn1_quanted_weight_tensor) self.transformer_block.ffn1_weights_scale[idx].set_value(ffn1_weight_scale_tensor) + elif "fp8" in self.quant_type: + self.transformer_block.ffn1_0_weights[idx].copy_( + paddle.to_tensor(unfused_state_dict["mlp.gate_proj.weight"]) + .transpose((1, 0)) + .cast("float8_e4m3fn"), + False, + ) + self.transformer_block.ffn1_1_weights[idx].copy_( + paddle.to_tensor(unfused_state_dict["mlp.up_proj.weight"]).transpose((1, 0)).cast("float8_e4m3fn"), + False, + ) elif "a8w8" in self.quant_type: w_dtype = ( paddle.get_default_dtype() @@ -1052,6 +1200,13 @@ def set_state_dict(self, state_dict): ) self.transformer_block.ffn2_weights[idx].set_value(ffn2_quanted_weight_tensor) self.transformer_block.ffn2_weights_scale[idx].set_value(ffn2_weight_scale_tensor) + elif "fp8" in self.quant_type: + self.transformer_block.ffn2_weights[idx].copy_( + paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.weight".format(idx)]) + .transpose([1, 0]) + .cast("float8_e4m3fn"), + False, + ) elif "a8w8" in self.quant_type: w_dtype = ( paddle.get_default_dtype() @@ -1076,7 +1231,7 @@ def set_state_dict(self, state_dict): else: self.transformer_block.ffn2_weights[idx].set_value(ffn2_weight_tensor) - if "a8w8" in self.quant_type: + if "fp8" not in self.quant_type and "a8w8" in self.quant_type: if self.shift_smooth_all_linears: if self.use_fake_parameter: if "llama.layers.{}.self_attn.o_proj.shift_bias".format(idx) not in state_dict: @@ -1207,391 +1362,6 @@ def set_state_dict(self, state_dict): paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.layer.bias".format(idx)]) ) - self.transformer_block.ln_scales[idx].set_value( - paddle.to_tensor(state_dict["llama.layers.{}.input_layernorm.weight".format(idx)]).cast( - self.transformer_block.ln_scales[idx].dtype - ) - ) - - self.transformer_block.ffn_ln_scales[idx].set_value( - paddle.to_tensor(state_dict["llama.layers.{}.post_attention_layernorm.weight".format(idx)]).cast( - self.transformer_block.ffn_ln_scales[idx].dtype - ) - ) - - if "a8w8" in self.quant_type: - if self.config.cachekv_int8_type == "static": - if not self.use_fake_parameter: - cache_scale_json_path = resolve_file_path(self.quant_model_path, "cachekv_scales.json") - if self.config.tensor_parallel_degree > 1 and not self.config.single_card_ptq: - cache_scale_json_path = resolve_file_path( - self.quant_model_path, f"cachekv_scales_{self.config.tensor_parallel_rank}.json" - ) - cache_scales_loader = CacheScaleLoader( - cache_scale_json_path, - cache_scale_map_dict, - num_of_layers=self.config.num_hidden_layers, - num_heads=self.num_attention_heads // self.config.tensor_parallel_degree, - num_key_value_heads=self.num_key_value_heads // self.config.tensor_parallel_degree, - ) - else: - cache_scales_loader = EmptyCacheScale( - cache_scale_map_dict, - num_of_layers=self.config.num_hidden_layers, - num_heads=self.num_attention_heads, - dim_heads=self.hidden_size // self.num_attention_heads, - 
is_channel_wise=False, - num_key_value_heads=self.num_key_value_heads, - mp_size=self.config.tensor_parallel_degree, - ) - - for k, v in cache_scales_loader.scale.items(): - for i_layer, weight_scale in enumerate(v): - if self.config.append_attn: - weight_scale = paddle.to_tensor(weight_scale).cast(paddle.get_default_dtype()) - else: - weight_scale = weight_scale.astype("float32") - if k == "cache_k_scale": - self.transformer_block.cache_k_scales[i_layer].set_value(weight_scale) - elif k == "cache_v_scale": - self.transformer_block.cache_v_scales[i_layer].set_value(weight_scale) - elif k == "cache_k_out_scale": - self.transformer_block.cache_k_out_scales[i_layer].set_value(weight_scale) - else: - self.transformer_block.cache_v_out_scales[i_layer].set_value(weight_scale) - - for k, v in weight_scales_loader.scale.items(): - if "qkv_" in k: - for i_layer, weight_scale in enumerate(v): - if not np.all(weight_scale == -1): - tmp = paddle.to_tensor( - weight_scale - / ( - 127.0 * 127.0 * act_scale_loader.scale["qkv_in_scale"][i_layer] - ) # [3 * num_head * dim_head] - ).reshape([-1]) - - if self.config.tensor_parallel_degree > 1 and self.config.single_card_ptq: - tmp = ( - tmp.reshape([3, self.num_attention_heads, head_size]) - .split(self.config.tensor_parallel_degree, axis=1)[ - self.config.tensor_parallel_rank - ] - .reshape([-1]) - ) - self.transformer_block.qkv_out_scales[i_layer].set_value(tmp) - pass - elif "out_linear_" in k: - for i_layer, weight_scale in enumerate(v): - if not np.all(weight_scale == -1): - tmp = paddle.to_tensor( - weight_scale / (127.0 * 127.0 * act_scale_loader.scale["out_linear_in_scale"][i_layer]) - ) - self.transformer_block.linear_out_scales[i_layer].set_value(tmp) - elif "ffn1_weight_scale" in k: - for i_layer, weight_scale in enumerate(v): - if not np.all(weight_scale == -1): - tmp = paddle.to_tensor( - weight_scale / (127.0 * 127.0 * act_scale_loader.scale["ffn1_in_scale"][i_layer]) - ) - if self.config.tensor_parallel_degree > 1 and self.config.single_card_ptq: - tmp = paddle.split(tmp, self.config.tensor_parallel_degree * 2) - tmp = paddle.concat( - [ - tmp[self.config.tensor_parallel_rank], - tmp[self.config.tensor_parallel_rank + self.config.tensor_parallel_degree], - ], - axis=0, - ) - self.transformer_block.ffn1_out_scales[i_layer].set_value(tmp) - elif "ffn2" in k: - for i_layer, weight_scale in enumerate(v): - if not np.all(weight_scale == -1): - self.transformer_block.ffn2_out_scales[i_layer].set_value( - paddle.to_tensor( - weight_scale / (127.0 * 127.0 * act_scale_loader.scale["ffn2_in_scale"][i_layer]) - ) - ) - - def set_state_dict_fp8(self, state_dict: dict[str, np.ndarray | paddle.Tensor], use_structured_name=True): - """transpose qkv shape & cast dtype for layernorm - - Args: - state_dict (dict[str, np.ndarray | paddle.Tensor]): the state dict of model - use_structured_name (bool, optional): _description_. Defaults to True. 
- """ - current_work_dir = os.path.dirname(__file__) - scale_map_file = f"{current_work_dir}/ptq_fp8_scales_map.json" - with open(scale_map_file) as json_file: - scale_map_dict = json.load(json_file) - act_scale_map_dict = scale_map_dict["act_scale"] - weight_scale_map_dict = scale_map_dict["weight_scale"] - cache_scale_map_dict = scale_map_dict["cachekv_scale"] - act_scale_json_path = resolve_file_path(self.quant_model_path, "act_scales.json") - weight_scale_json_path = resolve_file_path(self.quant_model_path, "weight_scales.json") - if self.config.tensor_parallel_degree > 1 and not self.config.single_card_ptq: - act_scale_json_path = resolve_file_path( - self.quant_model_path, f"act_scales_{self.config.tensor_parallel_rank}.json" - ) - weight_scale_json_path = resolve_file_path( - self.quant_model_path, f"weight_scales_{self.config.tensor_parallel_rank}.json" - ) - - act_scales = ActScalesLoader( - act_scale_json_path, act_scale_map_dict, num_of_layers=self.config.num_hidden_layers - ) - - weight_scales = PerTensorWeightScalesLoader( - weight_scale_json_path, - weight_scale_map_dict, - num_of_layers=self.config.num_hidden_layers, - ) - - for weight_name in weight_scales.scale: - weight_scales.scale[weight_name] = weight_scales.scale[weight_name].astype(np.float32) - for act_name in act_scales.scale: - act_scales.scale[act_name] = act_scales.scale[act_name].astype(np.float32) - self.transformer_block.act_scales = act_scales - self.transformer_block.weight_scales = weight_scales - - if self.config.cachekv_int8_type == "static": - cache_scale_json_path = resolve_file_path(self.quant_model_path, "cachekv_scales.json") - if self.config.tensor_parallel_degree > 1 and not self.config.single_card_ptq: - cache_scale_json_path = resolve_file_path( - self.quant_model_path, f"cachekv_scales_{self.config.tensor_parallel_rank}.json" - ) - cache_scales_loader = CacheScaleLoader( - cache_scale_json_path, - cache_scale_map_dict, - num_of_layers=self.config.num_hidden_layers, - num_heads=self.num_attention_heads // self.config.tensor_parallel_degree, - num_key_value_heads=self.num_key_value_heads // self.config.tensor_parallel_degree, - ) - for k, v in cache_scales_loader.scale.items(): - for i_layer, weight_scale in enumerate(v): - if self.config.append_attn: - weight_scale = paddle.to_tensor(weight_scale).cast(paddle.get_default_dtype()) - else: - weight_scale = weight_scale.astype("float32") - if k == "cache_k_scale": - self.transformer_block.cache_k_scales[i_layer].set_value(weight_scale) - elif k == "cache_v_scale": - self.transformer_block.cache_v_scales[i_layer].set_value(weight_scale) - elif k == "cache_k_out_scale": - self.transformer_block.cache_k_out_scales[i_layer].set_value(weight_scale) - else: - self.transformer_block.cache_v_out_scales[i_layer].set_value(weight_scale) - unfused_state_dict = {} - head_size = self.hidden_size // self.num_attention_heads - split_fn = split_param_func() - - self.embed_tokens.weight.set_value( - paddle.to_tensor(state_dict["llama.embed_tokens.weight"]).cast(self.embed_tokens.weight.dtype) - ) - self.norm.weight.set_value(paddle.to_tensor(state_dict["llama.norm.weight"]).cast(self.norm.weight.dtype)) - - for key in state_dict.keys(): - state_dict[key] = paddle.to_tensor(state_dict[key]) - - for key in list(state_dict.keys()): - if "llama.layers" in key: - state_dict[key.replace("llama.layers", "transformer_block.fusellama")] = state_dict.pop(key) - - for idx in range(self.config.num_hidden_layers): - if 
"transformer_block.fusellama.{}.self_attn.qkv_proj.weight".format(idx) in list(state_dict.keys()): - concated_qkv_weight = paddle.concat( - split_fn( - state_dict["transformer_block.fusellama.{}.self_attn.qkv_proj.weight".format(idx)], - is_qkv=True, - num_heads=self.num_attention_heads // self.config.tensor_parallel_degree, - num_key_value_heads=self.num_key_value_heads // self.config.tensor_parallel_degree, - ), - axis=-1, - ).transpose([1, 0]) - else: - unfused_state_dict = {} - q_wgt_scale = self.transformer_block.weight_scales.scale["q_weight_scale"][idx] - k_wgt_scale = self.transformer_block.weight_scales.scale["k_weight_scale"][idx] - v_wgt_scale = self.transformer_block.weight_scales.scale["v_weight_scale"][idx] - qkv_wgt_scale = self.transformer_block.weight_scales.scale["qkv_weight_scale"][idx] - unfused_state_dict["self_attn.q_proj.weight"] = ( - state_dict["transformer_block.fusellama.{}.self_attn.q_proj.weight".format(idx)].cast("float32") - * q_wgt_scale - / qkv_wgt_scale - ) - unfused_state_dict["self_attn.k_proj.weight"] = ( - state_dict["transformer_block.fusellama.{}.self_attn.k_proj.weight".format(idx)].cast("float32") - * k_wgt_scale - / qkv_wgt_scale - ) - unfused_state_dict["self_attn.v_proj.weight"] = ( - state_dict["transformer_block.fusellama.{}.self_attn.v_proj.weight".format(idx)].cast("float32") - * v_wgt_scale - / qkv_wgt_scale - ) - concated_qkv_weight = ( - paddle.concat( - [ - unfused_state_dict["self_attn.q_proj.weight"], - unfused_state_dict["self_attn.k_proj.weight"], - unfused_state_dict["self_attn.v_proj.weight"], - ], - axis=-1, - ) - .transpose([1, 0]) - .reshape( - [ - ( - self.num_attention_heads // self.config.tensor_parallel_degree - + 2 * self.num_key_value_heads // self.config.tensor_parallel_degree - ) - * (head_size), - self.hidden_size, - ] - ) - ) - state_dict[ - "transformer_block.fusellama.{}.self_attn.qkv_proj.weight".format(idx) - ] = concated_qkv_weight - - for key in list(state_dict.keys()): - if key.endswith(".input_layernorm.weight"): - state_dict[key.replace(".input_layernorm.weight", ".ln_scale")] = state_dict.pop(key).cast( - self.transformer_block.ln_scales[idx].dtype - ) - elif key.endswith(".post_attention_layernorm.weight"): - state_dict[key.replace(".post_attention_layernorm.weight", ".ffn_ln_scale")] = state_dict.pop( - key - ).cast(self.transformer_block.ffn_ln_scales[idx].dtype) - elif key.endswith(".self_attn.qkv_proj.weight"): - state_dict[key.replace(".self_attn.qkv_proj.weight", ".qkv_weight")] = state_dict.pop(key).cast( - "float8_e4m3fn" - ) - elif key.endswith(".self_attn.qkv_proj.bias"): - state_dict[key.replace(".self_attn.qkv_proj.bias", ".qkv_bias")] = state_dict.pop(key).cast( - self.transformer_block.qkv_biases[idx].dtype - ) - elif key.endswith(".self_attn.o_proj.weight"): - state_dict[key.replace(".self_attn.o_proj.weight", ".out_proj_weight")] = ( - state_dict.pop(key).transpose([1, 0]).cast("float8_e4m3fn") - ) - elif key.endswith(".mlp.gate_proj.weight"): - state_dict[key.replace(".mlp.gate_proj.weight", ".ffn1_0_weight")] = ( - state_dict.pop(key).transpose([1, 0]).cast("float8_e4m3fn") - ) - elif key.endswith(".mlp.up_proj.weight"): - state_dict[key.replace(".mlp.up_proj.weight", ".ffn1_1_weight")] = ( - state_dict.pop(key).transpose([1, 0]).cast("float8_e4m3fn") - ) - elif key.endswith(".mlp.down_proj.weight"): - state_dict[key.replace(".mlp.down_proj.weight", ".ffn2_weight")] = ( - state_dict.pop(key).transpose([1, 0]).cast("float8_e4m3fn") - ) - - self.set_state_dict_to_params(state_dict, True) - - 
return self - - def set_state_dict_to_params(self, state_dict: dict[str, np.ndarray | paddle.Tensor], use_structured_name=True): - """ - set_state_dict_to_params - """ - if in_dygraph_mode: - for k, v in self.state_dict(use_hook=False).items(): - if k in state_dict: - v_new = state_dict.pop(k) - if v_new.shape != v.shape: - logger.warning( - f"key {k} has diff shape between " - + f"state_dict and model params: {v_new.shape} vs {v.shape}." - ) - v.copy_(v_new, False) - else: - logger.warning(f"key {k} is not found in state_dict.") - else: - # static mode code copy from nn.layers.Layer.set_state_dict - logger.warning("set_state_dict_to_params in static mode.") - missing_keys = [] - match_keys = set() - unexpected_keys = [] - - def _check_match(key, param): - state = state_dict.get(key, None) - if state is None: - missing_keys.append(key) - raise ValueError(f"{key} is not found in the provided dict.") - if isinstance(state, (dict, list)): - if len(state) != len(param): - missing_keys.append(key) - raise ValueError( - "{} receieves the length of {}, " - "but the expected shape is {}".format(key, len(state), len(param)) - ) - else: - match_keys.add(key) - return param, state - else: - state_shape = state.shape() if inspect.ismethod(state.shape) else state.shape - - if list(state_shape) != list(param.shape): - missing_keys.append(key) - raise ValueError( - "{} receives a shape {}, but the expected shape is {}.".format( - key, list(state_shape), list(param.shape) - ) - ) - match_keys.add(key) - return param, state - - matched_param_state = [] - for key, param in self._state_dict_impl(use_hook=False).items(): - key_name = key if use_structured_name else param.name - try: - match_res = _check_match(key_name, param) - matched_param_state.append(match_res) - except ValueError as err: - logging.warning(f"Skip loading for {key}. " + str(err)) - for key in state_dict.keys(): - if key not in match_keys: - unexpected_keys.append(key) - - def _set_var(var, ndarray): - t = global_scope().find_var(var.name).get_tensor() - p = t._place() - if p.is_cpu_place(): - place = core.CPUPlace() - elif p.is_cuda_pinned_place(): - place = core.CUDAPinnedPlace() - elif p.is_xpu_place(): - p = core.Place() - p.set_place(t._place()) - place = core.XPUPlace(p.xpu_device_id()) - else: - p = core.Place() - p.set_place(t._place()) - place = core.CUDAPlace(p.gpu_device_id()) - t.set(ndarray, place) - - try: - executor = Executor(_get_device())._default_executor - # restore parameter states - core._create_loaded_parameter( - [param for param, state in matched_param_state], - global_scope(), - executor, - ) - for param, state in matched_param_state: - _set_var(param, state) - except ValueError: - raise ValueError( - "This error might happens in dy2static, " - + "while calling 'set_state_dict' dynamicly in 'forward', " - + "which is not supported. " - + "If you only need call 'set_state_dict' once, " - + "move it to '__init__'." 
- ) - return self - @register_base_model class LlamaBlockInferenceModel(LlamaInferenceModel): @@ -1906,10 +1676,7 @@ def set_state_dict(self, state_dict): self.lm_head.weight.set_value( paddle.to_tensor(state_dict["lm_head.weight"]).cast(self.lm_head.weight.dtype) ) - if "fp8" in self.llama.quant_type: - self.llama.set_state_dict_fp8({k: state_dict[k] for k in state_dict.keys()}) - else: - self.llama.set_state_dict({k: state_dict[k] for k in state_dict.keys()}) + self.llama.set_state_dict({k: state_dict[k] for k in state_dict.keys()}) class LlamaForCausalLMBlockInferenceModel(GenerationBlockInferenceModel, LlamaPretrainedModel): @@ -2119,10 +1886,7 @@ def set_state_dict(self, state_dict): self.lm_head.weight.set_value( paddle.to_tensor(state_dict["lm_head.weight"]).cast(self.lm_head.weight.dtype) ) - if "fp8" in self.llama.quant_type: - self.llama.set_state_dict_fp8({k: state_dict[k] for k in state_dict.keys()}) - else: - self.llama.set_state_dict({k: state_dict[k] for k in state_dict.keys()}) + self.llama.set_state_dict({k: state_dict[k] for k in state_dict.keys()}) class LlamaForMiniGPT4InferenceModel(LlamaForCausalLMInferenceModel): diff --git a/paddlenlp/experimental/transformers/qwen2/modeling.py b/paddlenlp/experimental/transformers/qwen2/modeling.py index f945e244faf6..e4cb60515921 100644 --- a/paddlenlp/experimental/transformers/qwen2/modeling.py +++ b/paddlenlp/experimental/transformers/qwen2/modeling.py @@ -13,19 +13,13 @@ # limitations under the License. from __future__ import annotations -import inspect import json -import logging import os from functools import partial import numpy as np import paddle from paddle import nn -from paddle.base import core -from paddle.base.executor import Executor, global_scope -from paddle.base.framework import _current_expected_place as _get_device -from paddle.base.framework import in_dygraph_mode from paddle.distributed import fleet from paddle.nn.quant import weight_quantize @@ -98,6 +92,7 @@ def __init__(self, config: Qwen2Config): self.hidden_size = config.hidden_size self.num_attention_heads = config.num_attention_heads self.num_key_value_heads = config.num_key_value_heads + self.head_size = self.hidden_size // self.num_attention_heads self.intermediate_size = config.intermediate_size self.num_layers = config.num_hidden_layers self.rms_norm_eps = config.rms_norm_eps @@ -152,6 +147,35 @@ def __init__(self, config: Qwen2Config): except: pass + qkv_weight_scale_attrs = None + out_proj_weight_scale_attrs = None + ffn1_weight_scale_attrs = None + ffn2_weight_scale_attrs = None + + qkv_out_scale_attrs = None + linear_out_scale_attrs = None + ffn1_out_scale_attrs = None + ffn2_out_scale_attrs = None + linear_shift_attrs = None + linear_smooth_attrs = None + ffn2_shift_attrs = None + ffn2_smooth_attrs = None + + ln_bias_attrs = None + qkv_bias_attrs = None + out_proj_bias_attrs = None + ffn_ln_bias_attrs = None + ffn1_bias_attrs = None + ffn2_bias_attrs = None + + ffn1_0_weight_attrs = None + ffn1_1_weight_attrs = None + ffn1_0_bias_attrs = None + ffn1_1_bias_attrs = None + + ffn1_weight_attrs = None + ffn2_weight_attrs = None + ln_scale_attrs = [paddle.ParamAttr(name="fuseqwen2.{}.ln_scale".format(i)) for i in range(self.num_layers)] qkv_weight_attrs = [ paddle.ParamAttr( @@ -182,8 +206,6 @@ def __init__(self, config: Qwen2Config): ) for i in range(self.num_layers) ] - ffn1_0_bias_attrs = None - ffn1_1_bias_attrs = None else: ffn1_weight_attrs = [ paddle.ParamAttr( @@ -198,26 +220,6 @@ def __init__(self, config: Qwen2Config): for i in 
range(self.num_layers) ] - qkv_weight_scale_attrs = None - out_proj_weight_scale_attrs = None - ffn1_weight_scale_attrs = None - ffn2_weight_scale_attrs = None - - qkv_out_scale_attrs = None - linear_out_scale_attrs = None - ffn1_out_scale_attrs = None - ffn2_out_scale_attrs = None - linear_shift_attrs = None - linear_smooth_attrs = None - ffn2_shift_attrs = None - ffn2_smooth_attrs = None - - ln_bias_attrs = None - out_proj_bias_attrs = None - ffn_ln_bias_attrs = None - ffn1_bias_attrs = None - ffn2_bias_attrs = None - if "a8w8" in self.quant_type: qkv_out_scale_attrs = [ paddle.ParamAttr(name="fuseqwen2.{}.qkv_out_scale".format(i)) for i in range(self.num_layers) @@ -267,11 +269,6 @@ def __init__(self, config: Qwen2Config): paddle.ParamAttr(name="fuseqwen2.{}.ffn2_bias".format(i)) for i in range(self.num_layers) ] - qkv_weight_scale_attrs = None - out_proj_weight_scale_attrs = None - ffn1_weight_scale_attrs = None - ffn2_weight_scale_attrs = None - if self.use_weight_only: qkv_weight_scale_attrs = [ paddle.ParamAttr(name="fuseqwen2.{}.qkv_weight_scale".format(i)) for i in range(self.num_layers) @@ -304,91 +301,57 @@ def __init__(self, config: Qwen2Config): paddle.ParamAttr(name="fuseqwen2.{}.cache_v_out_scale".format(i)) for i in range(self.num_layers) ] - if "fp8" in self.quant_type: - transformer_config = FusedMultiTransformerConfig( - embed_dim=self.hidden_size, - num_heads=self.num_attention_heads, - kv_num_heads=self.num_key_value_heads, - dim_feedforward=self.intermediate_size, - quant_type=self.quant_type, - activation="swiglu", - num_layers=config.num_hidden_layers, - nranks=config.tensor_parallel_degree, - ring_id=ring_id, - ln_scale_attrs=ln_scale_attrs, - ln_bias_attrs=ln_bias_attrs, - qkv_weight_attrs=qkv_weight_attrs, - qkv_bias_attrs=qkv_bias_attrs, - linear_weight_attrs=out_proj_weight_attrs, - linear_bias_attrs=out_proj_bias_attrs, - ffn_ln_scale_attrs=ffn_ln_scale_attrs, - ffn_ln_bias_attrs=ffn_ln_bias_attrs, - cache_k_scale_attrs=cache_k_scale_attrs, - cache_v_scale_attrs=cache_v_scale_attrs, - cache_k_out_scale_attrs=cache_k_out_scale_attrs, - cache_v_out_scale_attrs=cache_v_out_scale_attrs, - ffn1_0_weight_attrs=ffn1_0_weight_attrs, - ffn1_1_weight_attrs=ffn1_1_weight_attrs, - ffn1_0_bias_attrs=ffn1_0_bias_attrs, - ffn1_1_bias_attrs=ffn1_1_bias_attrs, - ffn2_weight_attrs=ffn2_weight_attrs, - ffn2_bias_attrs=ffn2_bias_attrs, - epsilon=self.rms_norm_eps, - rope_theta=self.rope_theta, - norm_type="rmsnorm", - use_neox_rotary_style=self.use_neox, - rank_id=config.tensor_parallel_rank, - append_attn=config.append_attn, - ) - - else: - transformer_config = FusedMultiTransformerConfig( - embed_dim=self.hidden_size, - num_heads=self.num_attention_heads, - kv_num_heads=self.num_key_value_heads, - dim_feedforward=self.intermediate_size, - quant_type=self.quant_type, - activation="swiglu", - num_layers=config.num_hidden_layers, - nranks=config.tensor_parallel_degree, - ring_id=ring_id, - ln_scale_attrs=ln_scale_attrs, - qkv_weight_attrs=qkv_weight_attrs, - qkv_weight_scale_attrs=qkv_weight_scale_attrs, - linear_weight_attrs=out_proj_weight_attrs, - linear_weight_scale_attrs=out_proj_weight_scale_attrs, - ffn_ln_scale_attrs=ffn_ln_scale_attrs, - ffn1_weight_attrs=ffn1_weight_attrs, - ffn1_weight_scale_attrs=ffn1_weight_scale_attrs, - ffn2_weight_attrs=ffn2_weight_attrs, - ffn2_weight_scale_attrs=ffn2_weight_scale_attrs, - qkv_out_scale_attrs=qkv_out_scale_attrs, - linear_out_scale_attrs=linear_out_scale_attrs, - ffn1_out_scale_attrs=ffn1_out_scale_attrs, - 
ffn2_out_scale_attrs=ffn2_out_scale_attrs, - linear_shift_attrs=linear_shift_attrs, - linear_smooth_attrs=linear_smooth_attrs, - ffn2_shift_attrs=ffn2_shift_attrs, - ffn2_smooth_attrs=ffn2_smooth_attrs, - ln_bias_attrs=ln_bias_attrs, - qkv_bias_attrs=qkv_bias_attrs, - linear_bias_attrs=out_proj_bias_attrs, - ffn_ln_bias_attrs=ffn_ln_bias_attrs, - ffn1_bias_attrs=ffn1_bias_attrs, - ffn2_bias_attrs=ffn2_bias_attrs, - cache_k_scale_attrs=cache_k_scale_attrs, - cache_v_scale_attrs=cache_v_scale_attrs, - cache_k_out_scale_attrs=cache_k_out_scale_attrs, - cache_v_out_scale_attrs=cache_v_out_scale_attrs, - epsilon=self.rms_norm_eps, - rope_theta=self.rope_theta, - norm_type="rmsnorm", - use_neox_rotary_style=self.use_neox, - cachekv_int8_type=config.cachekv_int8_type, - rank_id=config.tensor_parallel_rank, - trans_qkvw=(False if paddle.is_compiled_with_rocm() and "a8w8" in self.quant_type else True), - append_attn=config.append_attn, - ) + transformer_config = FusedMultiTransformerConfig( + embed_dim=self.hidden_size, + num_heads=self.num_attention_heads, + kv_num_heads=self.num_key_value_heads, + dim_feedforward=self.intermediate_size, + quant_type=self.quant_type, + activation="swiglu", + num_layers=config.num_hidden_layers, + nranks=config.tensor_parallel_degree, + ring_id=ring_id, + ln_scale_attrs=ln_scale_attrs, + qkv_weight_attrs=qkv_weight_attrs, + qkv_weight_scale_attrs=qkv_weight_scale_attrs, + linear_weight_attrs=out_proj_weight_attrs, + linear_weight_scale_attrs=out_proj_weight_scale_attrs, + ffn_ln_scale_attrs=ffn_ln_scale_attrs, + ffn1_weight_attrs=ffn1_weight_attrs, + ffn1_weight_scale_attrs=ffn1_weight_scale_attrs, + ffn1_0_weight_attrs=ffn1_0_weight_attrs, + ffn1_1_weight_attrs=ffn1_1_weight_attrs, + ffn2_weight_attrs=ffn2_weight_attrs, + ffn2_weight_scale_attrs=ffn2_weight_scale_attrs, + qkv_out_scale_attrs=qkv_out_scale_attrs, + linear_out_scale_attrs=linear_out_scale_attrs, + ffn1_out_scale_attrs=ffn1_out_scale_attrs, + ffn2_out_scale_attrs=ffn2_out_scale_attrs, + linear_shift_attrs=linear_shift_attrs, + linear_smooth_attrs=linear_smooth_attrs, + ffn2_shift_attrs=ffn2_shift_attrs, + ffn2_smooth_attrs=ffn2_smooth_attrs, + ln_bias_attrs=ln_bias_attrs, + qkv_bias_attrs=qkv_bias_attrs, + linear_bias_attrs=out_proj_bias_attrs, + ffn_ln_bias_attrs=ffn_ln_bias_attrs, + ffn1_bias_attrs=ffn1_bias_attrs, + ffn1_0_bias_attrs=ffn1_0_bias_attrs, + ffn1_1_bias_attrs=ffn1_1_bias_attrs, + ffn2_bias_attrs=ffn2_bias_attrs, + cache_k_scale_attrs=cache_k_scale_attrs, + cache_v_scale_attrs=cache_v_scale_attrs, + cache_k_out_scale_attrs=cache_k_out_scale_attrs, + cache_v_out_scale_attrs=cache_v_out_scale_attrs, + epsilon=self.rms_norm_eps, + rope_theta=self.rope_theta, + norm_type="rmsnorm", + use_neox_rotary_style=self.use_neox, + cachekv_int8_type=config.cachekv_int8_type, + rank_id=config.tensor_parallel_rank, + trans_qkvw=(False if paddle.is_compiled_with_rocm() and "a8w8" in self.quant_type else True), + append_attn=config.append_attn, + ) self.set_transformer_block(transformer_config) @@ -412,9 +375,42 @@ def set_input_embeddings(self, value): self.embed_tokens = value @paddle.no_grad() - def set_state_dict(self, state_dict): - if "a8w8" in self.quant_type: - current_work_dir = os.path.dirname(__file__) + def set_quant_scale(self): + current_work_dir = os.path.dirname(__file__) + if "fp8" in self.quant_type: + scale_map_file = f"{current_work_dir}/ptq_fp8_scales_map.json" + with open(scale_map_file) as json_file: + scale_map_dict = json.load(json_file) + act_scale_map_dict = 
scale_map_dict["act_scale"] + weight_scale_map_dict = scale_map_dict["weight_scale"] + cache_scale_map_dict = scale_map_dict["cachekv_scale"] + act_scale_json_path = resolve_file_path(self.quant_model_path, "act_scales.json") + weight_scale_json_path = resolve_file_path(self.quant_model_path, "weight_scales.json") + if self.config.tensor_parallel_degree > 1 and not self.config.single_card_ptq: + act_scale_json_path = resolve_file_path( + self.quant_model_path, f"act_scales_{self.config.tensor_parallel_rank}.json" + ) + weight_scale_json_path = resolve_file_path( + self.quant_model_path, f"weight_scales_{self.config.tensor_parallel_rank}.json" + ) + + act_scales = ActScalesLoader( + act_scale_json_path, act_scale_map_dict, num_of_layers=self.config.num_hidden_layers + ) + + weight_scales = PerTensorWeightScalesLoader( + weight_scale_json_path, + weight_scale_map_dict, + num_of_layers=self.config.num_hidden_layers, + ) + + for weight_name in weight_scales.scale: + weight_scales.scale[weight_name] = weight_scales.scale[weight_name].astype(np.float32) + for act_name in act_scales.scale: + act_scales.scale[act_name] = act_scales.scale[act_name].astype(np.float32) + self.transformer_block.weight_scales = weight_scales.scale + self.transformer_block.act_scales = act_scales.scale + elif "a8w8" in self.quant_type: scale_map_file = ( f"{current_work_dir}/ptq_scales_map.json" if not self.shift_smooth_all_linears @@ -425,8 +421,6 @@ def set_state_dict(self, state_dict): act_scale_map_dict = scale_map_dict["act_scale"] weight_scale_map_dict = scale_map_dict["weight_scale"] cache_scale_map_dict = scale_map_dict["cachekv_scale"] - # TODO(RichardWooSJTU): support multi-cards - if not self.use_fake_parameter: act_scale_json_path = resolve_file_path(self.quant_model_path, "act_scales.json") weight_scale_json_path = resolve_file_path(self.quant_model_path, "weight_scales.json") @@ -463,8 +457,101 @@ def set_state_dict(self, state_dict): self.transformer_block.weight_scales = weight_scales_loader.scale self.transformer_block.act_scales = act_scale_loader.scale + for k, v in weight_scales_loader.scale.items(): + if "qkv_" in k: + for i_layer, weight_scale in enumerate(v): + if not np.all(weight_scale == -1): + tmp = paddle.to_tensor( + weight_scale + / ( + 127.0 * 127.0 * act_scale_loader.scale["qkv_in_scale"][i_layer] + ) # [3 * num_head * dim_head] + ).reshape([-1]) + if self.config.tensor_parallel_degree > 1 and self.config.single_card_ptq: + tmp = ( + tmp.reshape([3, self.num_attention_heads, self.head_size]) + .split(self.config.tensor_parallel_degree, axis=1)[ + self.config.tensor_parallel_rank + ] + .reshape([-1]) + ) + self.transformer_block.qkv_out_scales[i_layer].set_value(tmp) + elif "out_linear_" in k: + for i_layer, weight_scale in enumerate(v): + if not np.all(weight_scale == -1): + tmp = paddle.to_tensor( + weight_scale / (127.0 * 127.0 * act_scale_loader.scale["out_linear_in_scale"][i_layer]) + ) + self.transformer_block.linear_out_scales[i_layer].set_value(tmp) + elif "ffn1_weight_scale" in k: + for i_layer, weight_scale in enumerate(v): + if not np.all(weight_scale == -1): + tmp = paddle.to_tensor( + weight_scale / (127.0 * 127.0 * act_scale_loader.scale["ffn1_in_scale"][i_layer]) + ) + if self.config.tensor_parallel_degree > 1 and self.config.single_card_ptq: + tmp = paddle.split(tmp, self.config.tensor_parallel_degree * 2) + tmp = paddle.concat( + [ + tmp[self.config.tensor_parallel_rank], + tmp[self.config.tensor_parallel_rank + self.config.tensor_parallel_degree], + ], + axis=0, + ) + 
self.transformer_block.ffn1_out_scales[i_layer].set_value(tmp) + elif "ffn2" in k: + for i_layer, weight_scale in enumerate(v): + if not np.all(weight_scale == -1): + self.transformer_block.ffn2_out_scales[i_layer].set_value( + paddle.to_tensor( + weight_scale / (127.0 * 127.0 * act_scale_loader.scale["ffn2_in_scale"][i_layer]) + ) + ) + + if self.config.cachekv_int8_type == "static": + if not self.use_fake_parameter: + cache_scale_json_path = resolve_file_path(self.quant_model_path, "cachekv_scales.json") + if self.config.tensor_parallel_degree > 1 and not self.config.single_card_ptq: + cache_scale_json_path = resolve_file_path( + self.quant_model_path, f"cachekv_scales_{self.config.tensor_parallel_rank}.json" + ) + cache_scales_loader = CacheScaleLoader( + cache_scale_json_path, + cache_scale_map_dict, + num_of_layers=self.config.num_hidden_layers, + num_heads=self.num_attention_heads // self.config.tensor_parallel_degree, + num_key_value_heads=self.num_key_value_heads // self.config.tensor_parallel_degree, + ) + else: + cache_scales_loader = EmptyCacheScale( + cache_scale_map_dict, + num_of_layers=self.config.num_hidden_layers, + num_heads=self.num_attention_heads, + dim_heads=self.hidden_size // self.num_attention_heads, + is_channel_wise=False, + num_key_value_heads=self.num_key_value_heads, + mp_size=self.config.tensor_parallel_degree, + ) + + for k, v in cache_scales_loader.scale.items(): + for i_layer, weight_scale in enumerate(v): + if self.config.append_attn: + weight_scale = paddle.to_tensor(weight_scale).cast(paddle.get_default_dtype()) + else: + weight_scale = weight_scale.astype("float32") + if k == "cache_k_scale": + self.transformer_block.cache_k_scales[i_layer].set_value(weight_scale) + elif k == "cache_v_scale": + self.transformer_block.cache_v_scales[i_layer].set_value(weight_scale) + elif k == "cache_k_out_scale": + self.transformer_block.cache_k_out_scales[i_layer].set_value(weight_scale) + else: + self.transformer_block.cache_v_out_scales[i_layer].set_value(weight_scale) + + @paddle.no_grad() + def set_state_dict(self, state_dict): + self.set_quant_scale() self.transformer_block.init_weight() - head_size = self.hidden_size // self.num_attention_heads split_fn = split_param_func() self.embed_tokens.weight.set_value( paddle.to_tensor(state_dict["qwen2.embed_tokens.weight"]).cast(self.embed_tokens.weight.dtype) @@ -473,35 +560,52 @@ def set_state_dict(self, state_dict): for idx in range(self.num_layers): logger.info(f"set state for layer {idx}") - unfused_state_dict = {} + ln_scale = paddle.to_tensor(state_dict["qwen2.layers.{}.input_layernorm.weight".format(idx)]).cast( self.transformer_block.ln_scales[idx].dtype ) self.transformer_block.ln_scales[idx].set_value(ln_scale) if "qwen2.layers.{}.self_attn.qkv_proj.weight".format(idx) in state_dict.keys(): - concated_qkv_weight = np.concatenate( - split_fn( - state_dict["qwen2.layers.{}.self_attn.qkv_proj.weight".format(idx)], - is_qkv=True, - num_heads=self.num_attention_heads // self.config.tensor_parallel_degree, - num_key_value_heads=self.num_key_value_heads // self.config.tensor_parallel_degree, - ), - axis=-1, - ).transpose(1, 0) + concated_qkv_weight = paddle.to_tensor( + np.concatenate( + split_fn( + state_dict["qwen2.layers.{}.self_attn.qkv_proj.weight".format(idx)], + is_qkv=True, + num_heads=self.num_attention_heads // self.config.tensor_parallel_degree, + num_key_value_heads=self.num_key_value_heads // self.config.tensor_parallel_degree, + ), + axis=-1, + ).transpose(1, 0) + ) else: unfused_state_dict = {} - 
unfused_state_dict["qwen2.self_attn.q_proj.weight"] = state_dict[ - "qwen2.layers.{}.self_attn.q_proj.weight".format(idx) - ] - unfused_state_dict["qwen2.self_attn.k_proj.weight"] = state_dict[ - "qwen2.layers.{}.self_attn.k_proj.weight".format(idx) - ] - unfused_state_dict["qwen2.self_attn.v_proj.weight"] = state_dict[ - "qwen2.layers.{}.self_attn.v_proj.weight".format(idx) - ] - if paddle.is_compiled_with_rocm() and "a8w8" in self.quant_type: - concated_qkv_weight = np.concatenate( + unfused_state_dict["self_attn.q_proj.weight"] = paddle.to_tensor( + state_dict["qwen2.layers.{}.self_attn.q_proj.weight".format(idx)] + ) + unfused_state_dict["self_attn.k_proj.weight"] = paddle.to_tensor( + state_dict["qwen2.layers.{}.self_attn.k_proj.weight".format(idx)] + ) + unfused_state_dict["self_attn.v_proj.weight"] = paddle.to_tensor( + state_dict["qwen2.layers.{}.self_attn.v_proj.weight".format(idx)] + ) + if "fp8" in self.quant_type: + q_wgt_scale = self.transformer_block.weight_scales["q_weight_scale"][idx] + k_wgt_scale = self.transformer_block.weight_scales["k_weight_scale"][idx] + v_wgt_scale = self.transformer_block.weight_scales["v_weight_scale"][idx] + qkv_wgt_scale = self.transformer_block.weight_scales["qkv_weight_scale"][idx] + unfused_state_dict["self_attn.q_proj.weight"] = ( + unfused_state_dict["self_attn.q_proj.weight"].cast("float32") * q_wgt_scale / qkv_wgt_scale + ) + unfused_state_dict["self_attn.k_proj.weight"] = ( + unfused_state_dict["self_attn.k_proj.weight"].cast("float32") * k_wgt_scale / qkv_wgt_scale + ) + unfused_state_dict["self_attn.v_proj.weight"] = ( + unfused_state_dict["self_attn.v_proj.weight"].cast("float32") * v_wgt_scale / qkv_wgt_scale + ) + + if paddle.is_compiled_with_rocm() and "a8w8" in self.quant_type and "fp8" not in self.quant_type: + concated_qkv_weight = paddle.concat( [ unfused_state_dict["self_attn.q_proj.weight"], unfused_state_dict["self_attn.k_proj.weight"], @@ -509,34 +613,38 @@ def set_state_dict(self, state_dict): ], axis=-1, ).reshape( - self.hidden_size, - ( - self.num_attention_heads // self.config.tensor_parallel_degree - + 2 * self.num_key_value_heads // self.config.tensor_parallel_degree - ) - * (head_size), + [ + self.hidden_size, + ( + self.num_attention_heads // self.config.tensor_parallel_degree + + 2 * self.num_key_value_heads // self.config.tensor_parallel_degree + ) + * (self.head_size), + ] ) + else: concated_qkv_weight = ( - np.concatenate( + paddle.concat( [ - unfused_state_dict["qwen2.self_attn.q_proj.weight"], - unfused_state_dict["qwen2.self_attn.k_proj.weight"], - unfused_state_dict["qwen2.self_attn.v_proj.weight"], + unfused_state_dict["self_attn.q_proj.weight"], + unfused_state_dict["self_attn.k_proj.weight"], + unfused_state_dict["self_attn.v_proj.weight"], ], axis=-1, ) - .transpose(1, 0) + .transpose([1, 0]) .reshape( - ( - self.num_attention_heads // self.config.tensor_parallel_degree - + 2 * self.num_key_value_heads // self.config.tensor_parallel_degree - ) - * (head_size), - self.hidden_size, + [ + ( + self.num_attention_heads // self.config.tensor_parallel_degree + + 2 * self.num_key_value_heads // self.config.tensor_parallel_degree + ) + * (self.head_size), + self.hidden_size, + ] ) ) - qkv_weight = paddle.to_tensor(concated_qkv_weight).cast(paddle.get_default_dtype()) if self.use_weight_only: @@ -544,6 +652,8 @@ def set_state_dict(self, state_dict): qkv_quanted_weight, qkv_weight_scale = weight_quantize(qkv_weight, algo=self.quant_algo) self.transformer_block.qkv_weights[idx].set_value(qkv_quanted_weight) 
self.transformer_block.qkv_weights_scale[idx].set_value(qkv_weight_scale) + elif "fp8" in self.quant_type: + self.transformer_block.qkv_weights[idx].copy_(paddle.cast(concated_qkv_weight, "float8_e4m3fn"), False) elif "a8w8" in self.quant_type and not self.transformer_block.skip_quant("qkv_weight_scale", idx): self.transformer_block.qkv_weights[idx].set_value( paddle.cast(paddle.to_tensor(concated_qkv_weight), "int8") @@ -581,6 +691,16 @@ def set_state_dict(self, state_dict): linear_quanted_weight, linear_weight_scale = weight_quantize(linear_weight, algo=self.quant_algo) self.transformer_block.linear_weights[idx].set_value(linear_quanted_weight) self.transformer_block.linear_weights_scale[idx].set_value(linear_weight_scale) + elif "fp8" in self.quant_type: + self.transformer_block.linear_weights[idx].copy_( + paddle.cast( + paddle.to_tensor(state_dict["qwen2.layers.{}.self_attn.o_proj.weight".format(idx)]).transpose( + (1, 0) + ), + "float8_e4m3fn", + ), + False, + ) elif "a8w8" in self.quant_type: w_dtype = ( paddle.get_default_dtype() @@ -634,6 +754,17 @@ def set_state_dict(self, state_dict): ffn1_quanted_weight, ffn1_weight_scale = weight_quantize(ffn1_weight, algo=self.quant_algo) self.transformer_block.ffn1_weights[idx].set_value(ffn1_quanted_weight) self.transformer_block.ffn1_weights_scale[idx].set_value(ffn1_weight_scale) + elif "fp8" in self.quant_type: + self.transformer_block.ffn1_0_weights[idx].copy_( + paddle.to_tensor(unfused_state_dict["mlp.gate_proj.weight"]) + .transpose((1, 0)) + .cast("float8_e4m3fn"), + False, + ) + self.transformer_block.ffn1_1_weights[idx].copy_( + paddle.to_tensor(unfused_state_dict["mlp.up_proj.weight"]).transpose((1, 0)).cast("float8_e4m3fn"), + False, + ) elif "a8w8" in self.quant_type: w_dtype = ( paddle.get_default_dtype() @@ -658,6 +789,13 @@ def set_state_dict(self, state_dict): ffn2_quanted_weight, ffn2_weight_scale = weight_quantize(ffn2_weight, algo=self.quant_algo) self.transformer_block.ffn2_weights[idx].set_value(ffn2_quanted_weight) self.transformer_block.ffn2_weights_scale[idx].set_value(ffn2_weight_scale) + elif "fp8" in self.quant_type: + self.transformer_block.ffn2_weights[idx].copy_( + paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.down_proj.weight".format(idx)]) + .transpose([1, 0]) + .cast("float8_e4m3fn"), + False, + ) elif "a8w8" in self.quant_type: w_dtype = ( paddle.get_default_dtype() @@ -665,26 +803,15 @@ def set_state_dict(self, state_dict): else "int8" ) if paddle.is_compiled_with_rocm(): - self.transformer_block.ffn2_weights[idx].set_value( - paddle.cast( - paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.down_proj.weight".format(idx)]), w_dtype - ) - ) + self.transformer_block.ffn2_weights[idx].set_value(ffn2_weight.cast(w_dtype)) else: - self.transformer_block.ffn2_weights[idx].set_value( - paddle.cast( - paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.down_proj.weight".format(idx)]).transpose( - (1, 0) - ), - w_dtype, - ) - ) + self.transformer_block.ffn2_weights[idx].set_value(ffn2_weight.transpose([1, 0]).cast(w_dtype)) else: self.transformer_block.ffn2_weights[idx].set_value( ffn2_weight.cast(self.transformer_block.ffn2_weights[idx].dtype) ) - if "a8w8" in self.quant_type: + if "fp8" not in self.quant_type and "a8w8" in self.quant_type: if self.shift_smooth_all_linears: if self.use_fake_parameter: if "qwen2.layers.{}.self_attn.o_proj.shift_bias".format(idx) not in state_dict: @@ -812,411 +939,6 @@ def set_state_dict(self, state_dict): 
paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.down_proj.layer.bias".format(idx)]) ) - if "a8w8" in self.quant_type: - if self.config.cachekv_int8_type == "static": - if not self.use_fake_parameter: - cache_scale_json_path = resolve_file_path(self.quant_model_path, "cachekv_scales.json") - if self.config.tensor_parallel_degree > 1 and not self.config.single_card_ptq: - cache_scale_json_path = resolve_file_path( - self.quant_model_path, f"cachekv_scales_{self.config.tensor_parallel_rank}.json" - ) - cache_scales_loader = CacheScaleLoader( - cache_scale_json_path, - cache_scale_map_dict, - num_of_layers=self.config.num_hidden_layers, - num_heads=self.num_attention_heads // self.config.tensor_parallel_degree, - num_key_value_heads=self.num_key_value_heads // self.config.tensor_parallel_degree, - ) - else: - cache_scales_loader = EmptyCacheScale( - cache_scale_map_dict, - num_of_layers=self.config.num_hidden_layers, - num_heads=self.num_attention_heads, - dim_heads=self.hidden_size // self.num_attention_heads, - is_channel_wise=False, - num_key_value_heads=self.num_key_value_heads, - mp_size=self.config.tensor_parallel_degree, - ) - - for k, v in cache_scales_loader.scale.items(): - for i_layer, weight_scale in enumerate(v): - if self.config.append_attn: - weight_scale = paddle.to_tensor(weight_scale).cast(paddle.get_default_dtype()) - else: - weight_scale = weight_scale.astype("float32") - if k == "cache_k_scale": - self.transformer_block.cache_k_scales[i_layer].set_value(weight_scale) - elif k == "cache_v_scale": - self.transformer_block.cache_v_scales[i_layer].set_value(weight_scale) - elif k == "cache_k_out_scale": - self.transformer_block.cache_k_out_scales[i_layer].set_value(weight_scale) - else: - self.transformer_block.cache_v_out_scales[i_layer].set_value(weight_scale) - - for k, v in weight_scales_loader.scale.items(): - if "qkv_" in k: - for i_layer, weight_scale in enumerate(v): - if not np.all(weight_scale == -1): - tmp = paddle.to_tensor( - weight_scale - / ( - 127.0 * 127.0 * act_scale_loader.scale["qkv_in_scale"][i_layer] - ) # [3 * num_head * dim_head] - ).reshape([-1]) - if self.config.tensor_parallel_degree > 1 and self.config.single_card_ptq: - tmp = ( - tmp.reshape([3, self.num_attention_heads, head_size]) - .split(self.config.tensor_parallel_degree, axis=1)[ - self.config.tensor_parallel_rank - ] - .reshape([-1]) - ) - self.transformer_block.qkv_out_scales[i_layer].set_value(tmp) - pass - elif "out_linear_" in k: - for i_layer, weight_scale in enumerate(v): - if not np.all(weight_scale == -1): - tmp = paddle.to_tensor( - weight_scale / (127.0 * 127.0 * act_scale_loader.scale["out_linear_in_scale"][i_layer]) - ) - self.transformer_block.linear_out_scales[i_layer].set_value(tmp) - elif "ffn1_weight_scale" in k: - for i_layer, weight_scale in enumerate(v): - if not np.all(weight_scale == -1): - tmp = paddle.to_tensor( - weight_scale / (127.0 * 127.0 * act_scale_loader.scale["ffn1_in_scale"][i_layer]) - ) - if self.config.tensor_parallel_degree > 1 and self.config.single_card_ptq: - tmp = paddle.split(tmp, self.config.tensor_parallel_degree * 2) - tmp = paddle.concat( - [ - tmp[self.config.tensor_parallel_rank], - tmp[self.config.tensor_parallel_rank + self.config.tensor_parallel_degree], - ], - axis=0, - ) - self.transformer_block.ffn1_out_scales[i_layer].set_value(tmp) - elif "ffn2" in k: - for i_layer, weight_scale in enumerate(v): - if not np.all(weight_scale == -1): - self.transformer_block.ffn2_out_scales[i_layer].set_value( - paddle.to_tensor( - weight_scale / (127.0 
* 127.0 * act_scale_loader.scale["ffn2_in_scale"][i_layer]) - ) - ) - - def set_state_dict_fp8(self, state_dict: dict[str, np.ndarray | paddle.Tensor], use_structured_name=True): - """transpose qkv shape & cast dtype for layernorm - - Args: - state_dict (dict[str, np.ndarray | paddle.Tensor]): the state dict of model - use_structured_name (bool, optional): _description_. Defaults to True. - """ - current_work_dir = os.path.dirname(__file__) - scale_map_file = f"{current_work_dir}/ptq_fp8_scales_map.json" - with open(scale_map_file) as json_file: - scale_map_dict = json.load(json_file) - act_scale_map_dict = scale_map_dict["act_scale"] - weight_scale_map_dict = scale_map_dict["weight_scale"] - cache_scale_map_dict = scale_map_dict["cachekv_scale"] - act_scale_json_path = resolve_file_path(self.quant_model_path, "act_scales.json") - weight_scale_json_path = resolve_file_path(self.quant_model_path, "weight_scales.json") - if self.config.tensor_parallel_degree > 1 and not self.config.single_card_ptq: - act_scale_json_path = resolve_file_path( - self.quant_model_path, f"act_scales_{self.config.tensor_parallel_rank}.json" - ) - weight_scale_json_path = resolve_file_path( - self.quant_model_path, f"weight_scales_{self.config.tensor_parallel_rank}.json" - ) - - act_scales = ActScalesLoader( - act_scale_json_path, act_scale_map_dict, num_of_layers=self.config.num_hidden_layers - ) - - weight_scales = PerTensorWeightScalesLoader( - weight_scale_json_path, - weight_scale_map_dict, - num_of_layers=self.config.num_hidden_layers, - ) - - for weight_name in weight_scales.scale: - weight_scales.scale[weight_name] = weight_scales.scale[weight_name].astype(np.float32) - for act_name in act_scales.scale: - act_scales.scale[act_name] = act_scales.scale[act_name].astype(np.float32) - self.transformer_block.act_scales = act_scales - self.transformer_block.weight_scales = weight_scales - - unfused_state_dict = {} - head_size = self.hidden_size // self.num_attention_heads - split_fn = split_param_func() - if self.config.cachekv_int8_type == "static": - cache_scale_json_path = resolve_file_path(self.quant_model_path, "cachekv_scales.json") - if self.config.tensor_parallel_degree > 1 and not self.config.single_card_ptq: - cache_scale_json_path = resolve_file_path( - self.quant_model_path, f"cachekv_scales_{self.config.tensor_parallel_rank}.json" - ) - cache_scales_loader = CacheScaleLoader( - cache_scale_json_path, - cache_scale_map_dict, - num_of_layers=self.config.num_hidden_layers, - num_heads=self.num_attention_heads // self.config.tensor_parallel_degree, - num_key_value_heads=self.num_key_value_heads // self.config.tensor_parallel_degree, - ) - for k, v in cache_scales_loader.scale.items(): - for i_layer, weight_scale in enumerate(v): - if self.config.append_attn: - weight_scale = paddle.to_tensor(weight_scale).cast(paddle.get_default_dtype()) - else: - weight_scale = weight_scale.astype("float32") - if k == "cache_k_scale": - self.transformer_block.cache_k_scales[i_layer].set_value(weight_scale) - elif k == "cache_v_scale": - self.transformer_block.cache_v_scales[i_layer].set_value(weight_scale) - elif k == "cache_k_out_scale": - self.transformer_block.cache_k_out_scales[i_layer].set_value(weight_scale) - else: - self.transformer_block.cache_v_out_scales[i_layer].set_value(weight_scale) - unfused_state_dict = {} - head_size = self.hidden_size // self.num_attention_heads - split_fn = split_param_func() - - self.embed_tokens.weight.set_value( - 
paddle.to_tensor(state_dict["qwen2.embed_tokens.weight"]).cast(self.embed_tokens.weight.dtype) - ) - self.norm.weight.set_value(paddle.to_tensor(state_dict["qwen2.norm.weight"]).cast(self.norm.weight.dtype)) - - for key in state_dict.keys(): - state_dict[key] = paddle.to_tensor(state_dict[key]) - - for key in list(state_dict.keys()): - if "qwen2.layers" in key: - state_dict[key.replace("qwen2.layers", "transformer_block.fuseqwen2")] = state_dict.pop(key) - - for idx in range(self.config.num_hidden_layers): - if "transformer_block.fuseqwen2.{}.self_attn.qkv_proj.weight".format(idx) in list(state_dict.keys()): - concated_qkv_weight = paddle.concat( - split_fn( - state_dict["transformer_block.fuseqwen2.{}.self_attn.qkv_proj.weight".format(idx)], - is_qkv=True, - num_heads=self.num_attention_heads // self.config.tensor_parallel_degree, - num_key_value_heads=self.num_key_value_heads // self.config.tensor_parallel_degree, - ), - axis=-1, - ).transpose([1, 0]) - else: - unfused_state_dict = {} - q_wgt_scale = self.transformer_block.weight_scales.scale["q_weight_scale"][idx] - k_wgt_scale = self.transformer_block.weight_scales.scale["k_weight_scale"][idx] - v_wgt_scale = self.transformer_block.weight_scales.scale["v_weight_scale"][idx] - qkv_wgt_scale = self.transformer_block.weight_scales.scale["qkv_weight_scale"][idx] - unfused_state_dict["self_attn.q_proj.weight"] = ( - state_dict["transformer_block.fuseqwen2.{}.self_attn.q_proj.weight".format(idx)].cast("float32") - * q_wgt_scale - / qkv_wgt_scale - ) - unfused_state_dict["self_attn.k_proj.weight"] = ( - state_dict["transformer_block.fuseqwen2.{}.self_attn.k_proj.weight".format(idx)].cast("float32") - * k_wgt_scale - / qkv_wgt_scale - ) - unfused_state_dict["self_attn.v_proj.weight"] = ( - state_dict["transformer_block.fuseqwen2.{}.self_attn.v_proj.weight".format(idx)].cast("float32") - * v_wgt_scale - / qkv_wgt_scale - ) - concated_qkv_weight = ( - paddle.concat( - [ - unfused_state_dict["self_attn.q_proj.weight"], - unfused_state_dict["self_attn.k_proj.weight"], - unfused_state_dict["self_attn.v_proj.weight"], - ], - axis=-1, - ) - .transpose([1, 0]) - .reshape( - [ - ( - self.num_attention_heads // self.config.tensor_parallel_degree - + 2 * self.num_key_value_heads // self.config.tensor_parallel_degree - ) - * (head_size), - self.hidden_size, - ] - ) - ) - state_dict[ - "transformer_block.fuseqwen2.{}.self_attn.qkv_proj.weight".format(idx) - ] = concated_qkv_weight - if "transformer_block.fuseqwen2.{}.self_attn.qkv_proj.bias".format(idx) in list(state_dict.keys()): - concated_qkv_bias = paddle.concat( - split_fn( - state_dict["transformer_block.fuseqwen2.{}.self_attn.qkv_proj.bias".format(idx)], - is_qkv=True, - num_heads=self.num_attention_heads // self.config.tensor_parallel_degree, - num_key_value_heads=self.num_key_value_heads // self.config.tensor_parallel_degree, - ), - axis=-1, - ) - else: - unfused_state_dict = {} - unfused_state_dict["self_attn.q_proj.bias"] = state_dict[ - "transformer_block.fuseqwen2.{}.self_attn.q_proj.bias".format(idx) - ] - unfused_state_dict["self_attn.k_proj.bias"] = state_dict[ - "transformer_block.fuseqwen2.{}.self_attn.k_proj.bias".format(idx) - ] - unfused_state_dict["self_attn.v_proj.bias"] = state_dict[ - "transformer_block.fuseqwen2.{}.self_attn.v_proj.bias".format(idx) - ] - concated_qkv_bias = paddle.concat( - [ - unfused_state_dict["self_attn.q_proj.bias"], - unfused_state_dict["self_attn.k_proj.bias"], - unfused_state_dict["self_attn.v_proj.bias"], - ], - axis=-1, - ) - 
state_dict["transformer_block.fuseqwen2.{}.self_attn.qkv_proj.bias".format(idx)] = concated_qkv_bias - - for key in list(state_dict.keys()): - if key.endswith(".input_layernorm.weight"): - state_dict[key.replace(".input_layernorm.weight", ".ln_scale")] = state_dict.pop(key).cast( - self.transformer_block.ln_scales[idx].dtype - ) - elif key.endswith(".post_attention_layernorm.weight"): - state_dict[key.replace(".post_attention_layernorm.weight", ".ffn_ln_scale")] = state_dict.pop( - key - ).cast(self.transformer_block.ffn_ln_scales[idx].dtype) - elif key.endswith(".self_attn.qkv_proj.weight"): - state_dict[key.replace(".self_attn.qkv_proj.weight", ".qkv_weight")] = state_dict.pop(key).cast( - "float8_e4m3fn" - ) - elif key.endswith(".self_attn.qkv_proj.bias"): - state_dict[key.replace(".self_attn.qkv_proj.bias", ".qkv_bias")] = state_dict.pop(key).cast( - self.transformer_block.qkv_biases[idx].dtype - ) - elif key.endswith(".self_attn.o_proj.weight"): - state_dict[key.replace(".self_attn.o_proj.weight", ".out_proj_weight")] = ( - state_dict.pop(key).transpose([1, 0]).cast("float8_e4m3fn") - ) - elif key.endswith(".mlp.gate_proj.weight"): - state_dict[key.replace(".mlp.gate_proj.weight", ".ffn1_0_weight")] = ( - state_dict.pop(key).transpose([1, 0]).cast("float8_e4m3fn") - ) - elif key.endswith(".mlp.up_proj.weight"): - state_dict[key.replace(".mlp.up_proj.weight", ".ffn1_1_weight")] = ( - state_dict.pop(key).transpose([1, 0]).cast("float8_e4m3fn") - ) - elif key.endswith(".mlp.down_proj.weight"): - state_dict[key.replace(".mlp.down_proj.weight", ".ffn2_weight")] = ( - state_dict.pop(key).transpose([1, 0]).cast("float8_e4m3fn") - ) - - self.set_state_dict_to_params(state_dict, True) - - return self - - def set_state_dict_to_params(self, state_dict: dict[str, np.ndarray | paddle.Tensor], use_structured_name=True): - """ - set_state_dict_to_params - """ - if in_dygraph_mode: - for k, v in self.state_dict(use_hook=False).items(): - if k in state_dict: - v_new = state_dict.pop(k) - if v_new.shape != v.shape: - logger.warning( - f"key {k} has diff shape between " - + f"state_dict and model params: {v_new.shape} vs {v.shape}." 
- ) - v.copy_(v_new, False) - else: - logger.warning(f"key {k} is not found in state_dict.") - else: - # static mode code copy from nn.layers.Layer.set_state_dict - logger.warning("set_state_dict_to_params in static mode.") - missing_keys = [] - match_keys = set() - unexpected_keys = [] - - def _check_match(key, param): - state = state_dict.get(key, None) - if state is None: - missing_keys.append(key) - raise ValueError(f"{key} is not found in the provided dict.") - if isinstance(state, (dict, list)): - if len(state) != len(param): - missing_keys.append(key) - raise ValueError( - "{} receieves the length of {}, " - "but the expected shape is {}".format(key, len(state), len(param)) - ) - else: - match_keys.add(key) - return param, state - else: - state_shape = state.shape() if inspect.ismethod(state.shape) else state.shape - - if list(state_shape) != list(param.shape): - missing_keys.append(key) - raise ValueError( - "{} receives a shape {}, but the expected shape is {}.".format( - key, list(state_shape), list(param.shape) - ) - ) - match_keys.add(key) - return param, state - - matched_param_state = [] - for key, param in self._state_dict_impl(use_hook=False).items(): - key_name = key if use_structured_name else param.name - try: - match_res = _check_match(key_name, param) - matched_param_state.append(match_res) - except ValueError as err: - logging.warning(f"Skip loading for {key}. " + str(err)) - for key in state_dict.keys(): - if key not in match_keys: - unexpected_keys.append(key) - - def _set_var(var, ndarray): - t = global_scope().find_var(var.name).get_tensor() - p = t._place() - if p.is_cpu_place(): - place = core.CPUPlace() - elif p.is_cuda_pinned_place(): - place = core.CUDAPinnedPlace() - elif p.is_xpu_place(): - p = core.Place() - p.set_place(t._place()) - place = core.XPUPlace(p.xpu_device_id()) - else: - p = core.Place() - p.set_place(t._place()) - place = core.CUDAPlace(p.gpu_device_id()) - t.set(ndarray, place) - - try: - executor = Executor(_get_device())._default_executor - # restore parameter states - core._create_loaded_parameter( - [param for param, state in matched_param_state], - global_scope(), - executor, - ) - for param, state in matched_param_state: - _set_var(param, state) - except ValueError: - raise ValueError( - "This error might happens in dy2static, " - + "while calling 'set_state_dict' dynamicly in 'forward', " - + "which is not supported. " - + "If you only need call 'set_state_dict' once, " - + "move it to '__init__'." 
- ) - return self - def remove_padding(self, input_ids, seq_lens_this_time): cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time) token_num = paddle.sum(seq_lens_this_time) @@ -1504,10 +1226,7 @@ def set_state_dict(self, state_dict): if "lm_head.weight" in state_dict: lm_head_weight = paddle.to_tensor(state_dict["lm_head.weight"]).cast(self.lm_head.weight.dtype) self.lm_head.weight.set_value(lm_head_weight) - if "fp8" in self.qwen2.quant_type: - self.qwen2.set_state_dict_fp8({k: state_dict[k] for k in state_dict.keys()}) - else: - self.qwen2.set_state_dict({k: state_dict[k] for k in state_dict.keys()}) + self.qwen2.set_state_dict({k: state_dict[k] for k in state_dict.keys()}) @register_base_model @@ -1773,7 +1492,4 @@ def set_state_dict(self, state_dict): self.lm_head.weight.set_value( paddle.to_tensor(state_dict["lm_head.weight"]).cast(self.lm_head.weight.dtype) ) - if "fp8" in self.qwen2.quant_type: - self.qwen2.set_state_dict_fp8({k: state_dict[k] for k in state_dict.keys()}) - else: - self.qwen2.set_state_dict({k: state_dict[k] for k in state_dict.keys()}) + self.qwen2.set_state_dict({k: state_dict[k] for k in state_dict.keys()})
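
Illustrative sketch (not part of the diff above): every "fp8" branch added to set_state_dict follows the same load pattern — convert the checkpoint weight to a Paddle tensor, transpose it, cast it to float8_e4m3fn, and copy_ it in place into the pre-created parameter, while the per-tensor scales are loaded separately from the ptq_fp8 scale JSONs. The snippet below is a minimal, hypothetical sketch of that pattern (the weight shape and names are made up), assuming a Paddle build and device with float8_e4m3fn support.

    # Hypothetical sketch of the fp8 weight-load pattern used in set_state_dict;
    # not part of the patch. Assumes paddle supports the float8_e4m3fn dtype here.
    import numpy as np
    import paddle

    np_weight = np.random.randn(128, 256).astype("float32")  # stand-in checkpoint weight

    fp8_weight = (
        paddle.to_tensor(np_weight)
        .transpose([1, 0])        # transposed as the o_proj/gate/up/down weights are above
        .cast("float8_e4m3fn")    # per-tensor fp8 storage; scales come from the scale JSONs
    )

    # In the model this tensor is then written in place into the pre-created parameter,
    # e.g. self.transformer_block.linear_weights[idx].copy_(fp8_weight, False)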