diff --git a/csrc/README.md b/csrc/README.md index 21cf3e9bc90f..710ea28abd67 100644 --- a/csrc/README.md +++ b/csrc/README.md @@ -15,19 +15,6 @@ pip install -r requirements.txt python setup_cuda.py install ``` -### 手动安装 Cutlass 库 -1. 访问 Cutlass 仓库: [NVIDIA/cutlass](https://github.com/NVIDIA/cutlass) - -2. 拉取代码: - git clone -b v3.5.0 --single-branch https://github.com/NVIDIA/cutlass.git - -3. 将下载的 `cutlass` 目录放在 `csrc/third_party/cutlass`下 - -4. 重新编译 Cuda 算子 -```shell -python setup_cuda.py install -``` - ### FP8 GEMM 自动调优 确保 `cutlass` 库已经安装,然后执行以下命令进行自动调优。 diff --git a/csrc/cpu/src/stop_generation_multi_ends.cc b/csrc/cpu/src/stop_generation_multi_ends.cc index 73aaf688afea..cae2704f88a8 100644 --- a/csrc/cpu/src/stop_generation_multi_ends.cc +++ b/csrc/cpu/src/stop_generation_multi_ends.cc @@ -15,20 +15,9 @@ #include #include -#include "paddle/extension.h" +#include "helper.h" #include - -bool is_in_end(const int64_t id, const int64_t* end_ids, int length) { - bool flag = false; - for (int i = 0; i < length; i++) { - if (id == end_ids[i]) { - return true; - } - } - return flag; -} - void set_value_by_flags(const bool* stop_flags, const int64_t* end_ids, int64_t* topk_ids, diff --git a/csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_cuda_core_gemm.cu b/csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_cuda_core_gemm.cu new file mode 100644 index 000000000000..06295cd62207 --- /dev/null +++ b/csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_cuda_core_gemm.cu @@ -0,0 +1,191 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fp8_fp8_half_cuda_core_gemm.h" +#include "cutlass/numeric_conversion.h" + +template +__global__ void cudaCoreGemm(InputType const* __restrict__ act, + InputType const* __restrict__ weight, + OutputType const* __restrict__ bias, + OutputType* __restrict__ output, + int32_t m, + int32_t n, + int32_t k, + float alpha) { + using VecType = int4; + static constexpr int32_t kStepK = + static_cast(128 / (8 * sizeof(InputType))); + static constexpr int32_t kTileK = kStepK * BLOCK_SIZE; + auto tileIdM = static_cast(blockIdx.x * TILE_M); + auto tileIdN = static_cast(blockIdx.y * TILE_N); + auto tid = static_cast(threadIdx.x); + float tile_a[kStepK], tile_w[TILE_N * kStepK]; + float acc[TILE_M * TILE_N]; + + static_assert(kStepK % 4 == 0); + using Converter = cutlass::NumericArrayConverter; + using CvtSrcType = typename Converter::source_type; + using CvtResType = typename Converter::result_type; + + static constexpr int32_t kCvtCount = + static_cast(sizeof(VecType) / sizeof(CvtSrcType)); + +#pragma unroll + for (int32_t i = 0; i < TILE_M * TILE_N; ++i) { + acc[i] = 0; + } + act += tileIdM * k; + weight += tileIdN * k; + output += tileIdM * n + tileIdN; + if constexpr (UseBias) { + bias += tileIdN; + } + for (int32_t idxK = tid * kStepK; idxK < k; idxK += kTileK) { + for (int32_t i = 0; i < TILE_N; ++i) { + auto tile_w_quantized = + reinterpret_cast(weight + i * k + idxK)[0]; +#pragma unroll + for (int32_t cvtIdx = 0; cvtIdx < kCvtCount; ++cvtIdx) { + reinterpret_cast(tile_w)[i * kCvtCount + cvtIdx] = + Converter::convert( + reinterpret_cast(&tile_w_quantized)[cvtIdx]); + } + } +#pragma unroll + for (int32_t i = 0; i < TILE_M; ++i) { + auto tile_a_quantized = + reinterpret_cast(act + i * k + idxK)[0]; +#pragma unroll + for (int32_t cvtIdx = 0; cvtIdx < kCvtCount; ++cvtIdx) { + reinterpret_cast(tile_a)[cvtIdx] = Converter::convert( + reinterpret_cast(&tile_a_quantized)[cvtIdx]); + } +#pragma unroll + for (int32_t j = 0; j < TILE_N; ++j) { +#pragma unroll + for (int32_t l = 0; l < kStepK; ++l) { + acc[i * TILE_N + j] = + fma(tile_a[l], tile_w[j * kStepK + l], acc[i * TILE_N + j]); + } + } + } + } + + typedef cub::WarpReduce WarpReduce; + + static constexpr int32_t kWarpSize = 32; + static constexpr int32_t kWarpNum = BLOCK_SIZE / kWarpSize; + int32_t warpId = tid / kWarpSize, laneId = tid % kWarpSize; + __shared__ float shmem[TILE_M * TILE_N * kWarpNum]; + __shared__ typename WarpReduce::TempStorage tempStorage[kWarpNum]; +#pragma unroll + for (int32_t mi = 0; mi < TILE_M; ++mi) { +#pragma unroll + for (int32_t ni = 0; ni < TILE_N; ++ni) { + float val = WarpReduce(tempStorage[warpId]).Sum(acc[mi * TILE_N + ni]); + if (laneId == 0) { + shmem[mi * TILE_N + ni + warpId * TILE_M * TILE_N] = val; + } + } + } + + __syncthreads(); + for (int32_t ii = tid; ii < TILE_M * TILE_N; ii += BLOCK_SIZE) { + int32_t mid = ii / TILE_N, nid = ii % TILE_N; + float val = 0; +#pragma unroll + for (int32_t jj = 0; jj < kWarpNum; ++jj) { + val += shmem[jj * TILE_M * TILE_N + ii]; + } + + if constexpr (UseBias) { + output[mid * n + nid] = static_cast(val * alpha + (float)*(bias+nid)) ; + } else { + output[mid * n + nid] = static_cast(val * alpha); + } + } +} + +template +void cudaCoreGemmKernel(GemmParams const& params) { + dim3 block(BLOCK_SIZE); + dim3 grid(params.m / TILE_M, params.n / TILE_N); + // std::cout << "m" << params.m << " n" << params.n << " k " << params.k << std::endl; + + if (params.bias != nullptr) { + cudaCoreGemm + <<>>( + reinterpret_cast(params.act), + reinterpret_cast(params.weight), + reinterpret_cast(params.bias), + reinterpret_cast(params.output), + params.m, + params.n, + params.k, + params.alpha); + } else { + cudaCoreGemm + <<>>( + reinterpret_cast(params.act), + reinterpret_cast(params.weight), + reinterpret_cast(params.bias), + reinterpret_cast(params.output), + params.m, + params.n, + params.k, + params.alpha); + } +} + +template +bool cudaCoreGemmTemplateCaller(GemmParams const& params) { + constexpr int cudaCoreGemmTemplateMaxM = 16; + if (params.m == TILE_M) { + cudaCoreGemmKernel( + params); + return true; + } + if constexpr (TILE_M < cudaCoreGemmTemplateMaxM) { + return cudaCoreGemmTemplateCaller(params); + } + return false; +} + +template +bool cuda_core_gemm_launcher(GemmParams const& params) { + return cudaCoreGemmTemplateCaller(params); +} + +template bool cuda_core_gemm_launcher<__nv_fp8_e4m3, __nv_bfloat16>(GemmParams const&); +template bool cuda_core_gemm_launcher<__nv_fp8_e4m3, half>(GemmParams const&); +template bool cuda_core_gemm_launcher<__nv_fp8_e5m2, __nv_bfloat16>(GemmParams const&); +template bool cuda_core_gemm_launcher<__nv_fp8_e5m2, half>(GemmParams const&); \ No newline at end of file diff --git a/csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_cuda_core_gemm.h b/csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_cuda_core_gemm.h new file mode 100644 index 000000000000..31eab1943dac --- /dev/null +++ b/csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_cuda_core_gemm.h @@ -0,0 +1,37 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "fp8_common.h" // NOLINT + +typedef struct { + void const* act; + void const* weight; + void const* bias; + void* output; + int32_t m, n, k; + float alpha; + cudaStream_t stream; +} GemmParams; + +inline bool enable_cuda_core_fp8_gemm() { + static const char* enable_cuda_core_fp8_env = std::getenv("FLAGS_cuda_core_fp8_gemm"); + static const bool enable_cuda_core_fp8_gemm = + enable_cuda_core_fp8_env != nullptr && std::string(enable_cuda_core_fp8_env) == "1"; + return enable_cuda_core_fp8_gemm; +} + +template +bool cuda_core_gemm_launcher(GemmParams const& params); diff --git a/csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_gemm.cu b/csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_gemm.cu index 541d3689f1da..03b0c63204b0 100644 --- a/csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_gemm.cu +++ b/csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_gemm.cu @@ -16,6 +16,7 @@ #include "fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.h" #include "fp8_common.h" // NOLINT +#include "fp8_fp8_half_cuda_core_gemm.h" std::vector cutlass_fp8_fp8_half_gemm( const paddle::Tensor& x, @@ -116,31 +117,59 @@ std::vector cutlass_fp8_fp8_half_gemm( } } - GemmEpilogueAllParams params = { - x_ptr, - y_ptr, - out_ptr, - scale, - M, - N, - K, - lda, - ldb, - ldd, - batch_count, - place, - stream, - sm_version, - 0.01, // for leaky_relu - bias_data, - bias_dims, - fuse_gemm_config}; - if (sm_version == 89){ - fp8_fp8_gemm_scale_bias_act(params); - }else{ - fp8_fp8_gemm_scale_bias_act_sm90(params); - } - + if (M <=4 && trans_y && !trans_x && act == "noact" && enable_cuda_core_fp8_gemm()) { + GemmParams params = { + x_ptr, + y_ptr, + bias_data, + out_ptr, + M, + N, + K, + scale, + stream, + }; + + if (x.dtype() == phi::DataType::FLOAT8_E4M3FN) + { + if(output_dtype == "bfloat16") { + cuda_core_gemm_launcher<__nv_fp8_e4m3, __nv_bfloat16>(params); + + } else { + cuda_core_gemm_launcher<__nv_fp8_e4m3, half>(params); + } + } else { + if(output_dtype == "bfloat16") { + cuda_core_gemm_launcher<__nv_fp8_e5m2, __nv_bfloat16>(params); + } else { + cuda_core_gemm_launcher<__nv_fp8_e5m2, half>(params); + } + } + } else { + GemmEpilogueAllParams params = {x_ptr, + y_ptr, + out_ptr, + scale, + M, + N, + K, + lda, + ldb, + ldd, + batch_count, + place, + stream, + sm_version, + 0.01, // for leaky_relu + bias_data, + bias_dims, + fuse_gemm_config}; + if (sm_version == 89){ + fp8_fp8_gemm_scale_bias_act(params); + }else{ + fp8_fp8_gemm_scale_bias_act_sm90(params); + } + } return {out}; } diff --git a/csrc/gpu/get_output.cc b/csrc/gpu/get_output.cc index 87535e0a6362..e1ee4be35c75 100644 --- a/csrc/gpu/get_output.cc +++ b/csrc/gpu/get_output.cc @@ -1,11 +1,11 @@ // Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -17,53 +17,72 @@ #include #include #include + #include "paddle/extension.h" #define MAX_BSZ 512 +#define SPECULATE_MAX_BSZ 256 +#define MAX_DRAFT_TOKENS 6 -struct msgdata { - long mtype; - int mtext[MAX_BSZ + 2]; // stop_flag, bsz, tokens +template +struct MsgData { + long mtype; + std::array mtext; }; -void GetOutput(const paddle::Tensor& x, - int64_t rank_id, - bool wait_flag) { +template +void GetOutputFunc(MsgData& msg_rcv, // NOLINT + const paddle::Tensor& x, + int64_t rank_id, + bool wait_flag) { if (rank_id > 0) return; - static struct msgdata msg_rcv; - static key_t key = ftok("./", 1); static int msgid = msgget(key, IPC_CREAT | 0666); - int64_t *out_data = const_cast(x.data()); int ret = -1; - if (!wait_flag) { - ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, IPC_NOWAIT); - } else { - ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, 0); - } - if(ret == -1) - { + ret = msgrcv( + msgid, &msg_rcv, SIZE * sizeof(int), 0, wait_flag ? 0 : IPC_NOWAIT); + + int64_t* out_data = const_cast(x.data()); + + if (ret == -1) { // read none out_data[0] = -2; out_data[1] = 0; - return; - } - - int bsz = msg_rcv.mtext[1]; + return; + } - for (int64_t i = 0; i < bsz + 2; i++) { + for (int64_t i = 0; i < SIZE; i++) { out_data[i] = (int64_t)msg_rcv.mtext[i]; } + return; } +void GetOutput(const paddle::Tensor& x, + int64_t rank_id, + bool wait_flag, + bool speculative_decoding) { + if (!speculative_decoding) { + constexpr int SIZE = MAX_BSZ + 2; // stop_flag, bsz, tokens... + static struct MsgData msg_rcv; + GetOutputFunc(msg_rcv, x, rank_id, wait_flag); + } else { + constexpr int SIZE = SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + + SPECULATE_MAX_BSZ + + 2; // stop_flag, bsz, accept_num*bsz, tokens... + static struct MsgData specu_msg_rcv; + GetOutputFunc(specu_msg_rcv, x, rank_id, wait_flag); + } +} + PD_BUILD_OP(get_output) .Inputs({"x"}) .Attrs({"rank_id: int64_t", - "wait_flag: bool"}) + "wait_flag: bool", + "speculative_decoding: bool"}) .Outputs({"x_out"}) .SetInplaceMap({{"x", "x_out"}}) - .SetKernelFn(PD_KERNEL(GetOutput)); + .SetKernelFn(PD_KERNEL(GetOutput)); \ No newline at end of file diff --git a/csrc/gpu/get_padding_offset_v2.cu b/csrc/gpu/get_padding_offset_v2.cu index ab088e903b9a..46a118eeaad8 100644 --- a/csrc/gpu/get_padding_offset_v2.cu +++ b/csrc/gpu/get_padding_offset_v2.cu @@ -23,6 +23,9 @@ __global__ void GetPaddingOffsetV2Kernel(int *padding_offset, const int64_t *input_data, const int *cum_offsets, const int *seq_lens, + const int64_t *draft_tokens, + const int *seq_lens_encoder, + const int max_draft_tokens, const int max_seq_len) { // get padding offset of each batch const int bi = blockIdx.x; @@ -31,8 +34,18 @@ __global__ void GetPaddingOffsetV2Kernel(int *padding_offset, for (int i = ti; i < seq_lens[bi]; i += blockDim.x) { padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset; const int tgt_seq_id = bi * max_seq_len - cum_offset + i; - const int src_seq_id = bi * max_seq_len + i; - output_data[tgt_seq_id] = input_data[src_seq_id]; + if (draft_tokens == nullptr) { + const int src_seq_id = bi * max_seq_len + i; + output_data[tgt_seq_id] = input_data[src_seq_id]; + } else { // speculative decoding + if (seq_lens_encoder[bi] > 0) { + const int src_seq_id = bi * max_seq_len + i; + output_data[tgt_seq_id] = input_data[src_seq_id]; + } else { + const int src_seq_id = bi * max_draft_tokens + i; + output_data[tgt_seq_id] = draft_tokens[src_seq_id]; + } + } } if (ti == 0) { if (bi == 0) { @@ -50,7 +63,9 @@ __global__ void GetPaddingOffsetV2Kernel(int *padding_offset, std::vector GetPaddingOffsetV2(const paddle::Tensor& input_ids, const paddle::Tensor& cum_offsets, const paddle::Tensor& token_num, - const paddle::Tensor& seq_len) { + const paddle::Tensor& seq_len, + const paddle::optional& draft_tokens, + const paddle::optional& seq_lens_encoder) { auto cu_stream = input_ids.stream(); std::vector input_ids_shape = input_ids.shape(); const int bsz = seq_len.shape()[0]; @@ -65,23 +80,46 @@ std::vector GetPaddingOffsetV2(const paddle::Tensor& input_ids, auto cu_seqlens_q = GetEmptyTensor({bsz + 1}, paddle::DataType::INT32, input_ids.place()); auto cu_seqlens_k = GetEmptyTensor({bsz + 1}, paddle::DataType::INT32, input_ids.place()); - GetPaddingOffsetV2Kernel<<>>( - padding_offset.data(), - cum_offsets_out.data(), - cu_seqlens_q.data(), - cu_seqlens_k.data(), - x_remove_padding.data(), - input_ids.data(), - cum_offsets.data(), - seq_len.data(), - seq_length); + int max_draft_tokens = 0; + if (draft_tokens) { // speculative decoding + max_draft_tokens = draft_tokens.get().shape()[1]; + GetPaddingOffsetV2Kernel<<>>( + padding_offset.data(), + cum_offsets_out.data(), + cu_seqlens_q.data(), + cu_seqlens_k.data(), + x_remove_padding.data(), + input_ids.data(), + cum_offsets.data(), + seq_len.data(), + draft_tokens.get_ptr()->data(), + seq_lens_encoder.get_ptr()->data(), + max_draft_tokens, + seq_length); + } else { + GetPaddingOffsetV2Kernel<<>>( + padding_offset.data(), + cum_offsets_out.data(), + cu_seqlens_q.data(), + cu_seqlens_k.data(), + x_remove_padding.data(), + input_ids.data(), + cum_offsets.data(), + seq_len.data(), + nullptr, + nullptr, + max_draft_tokens, + seq_length); + } return {x_remove_padding, cum_offsets_out, padding_offset, cu_seqlens_q, cu_seqlens_k}; // , enc_token_num, dec_token_num}; } std::vector> GetPaddingOffsetV2InferShape(const std::vector& input_ids_shape, const std::vector& cum_offsets_shape, const std::vector& token_num_shape, - const std::vector& seq_len_shape) { + const std::vector& seq_len_shape, + const std::vector& draft_tokens_shape, + const std::vector& seq_lens_encoder_shape) { int64_t bsz = seq_len_shape[0]; int64_t seq_len = input_ids_shape[1]; return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}}; @@ -90,12 +128,14 @@ std::vector> GetPaddingOffsetV2InferShape(const std::vector std::vector GetPaddingOffsetV2InferDtype(const paddle::DataType& input_ids_dtype, const paddle::DataType& cum_offsets_dtype, const paddle::DataType& token_num_dtype, - const paddle::DataType& seq_len_dtype) { + const paddle::DataType& seq_len_dtype, + const paddle::DataType& draft_tokens_dtype, + const paddle::DataType& seq_lens_encoder_dtype) { return {input_ids_dtype, seq_len_dtype, seq_len_dtype, seq_len_dtype, seq_len_dtype}; } PD_BUILD_OP(get_padding_offset_v2) - .Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"}) + .Inputs({"input_ids", "cum_offsets", "token_num", "seq_len", paddle::Optional("draft_tokens"), paddle::Optional("seq_lens_encoder"),}) .Outputs({"x_remove_padding", "cum_offsets_out", "padding_offset", "cu_seqlens_q", "cu_seqlens_k"}) .SetKernelFn(PD_KERNEL(GetPaddingOffsetV2)) .SetInferShapeFn(PD_INFER_SHAPE(GetPaddingOffsetV2InferShape)) diff --git a/csrc/gpu/helper.h b/csrc/gpu/helper.h index dfca82cf1e6c..4e8aa488141a 100644 --- a/csrc/gpu/helper.h +++ b/csrc/gpu/helper.h @@ -211,3 +211,13 @@ inline paddle::Tensor GetEmptyTensor(const common::DDim& dims, const paddle::Dat dense_tensor.AllocateFrom(allocator, dtype, dense_tensor.numel() * phi::SizeOf(dtype)); return paddle::Tensor(std::make_shared(dense_tensor)); } + +__device__ inline bool is_in_end(const int64_t id, const int64_t *end_ids, int length) { + bool flag = false; + for (int i = 0; i < length; i++) { + if (id == end_ids[i]) { + return true; + } + } + return flag; +} diff --git a/csrc/gpu/rebuild_padding_v2.cu b/csrc/gpu/rebuild_padding_v2.cu index 6a4b83c103d3..0a167bbdc4b0 100644 --- a/csrc/gpu/rebuild_padding_v2.cu +++ b/csrc/gpu/rebuild_padding_v2.cu @@ -45,11 +45,43 @@ __global__ void RebuildPaddingV2Kernel(T *output_data, } } +template +__global__ void RebuildAppendPaddingKernel(T *output_data, + const T *input_data, + const int *cum_offset, + const int *seq_len_decoder, + const int *seq_len_encoder, + const int *output_padding_offset, + const int max_seq_len, + const int dim_embed, + const int64_t output_elem_nums) { + AlignedVector src_vec; + const int64_t global_idx = blockDim.x * blockIdx.x + threadIdx.x; + for (int64_t i = global_idx * VecSize; i < output_elem_nums; i += gridDim.x * blockDim.x * VecSize) { + const int out_token_id = i / dim_embed; + const int ori_token_id = out_token_id + output_padding_offset[out_token_id]; + const int bi = ori_token_id / max_seq_len; + int seq_id = 0; + + if (seq_len_decoder[bi] == 0 && seq_len_encoder[bi] == 0) continue; + else if (seq_len_encoder[bi] != 0) { + seq_id = seq_len_encoder[bi] - 1; + } + + const int input_token_id = ori_token_id - cum_offset[bi] + seq_id; + const int bias_idx = i % dim_embed; + + Load(&input_data[input_token_id * dim_embed + bias_idx], &src_vec); + Store(src_vec, &output_data[i]); + } +} + template std::vector rebuild_padding_v2(const paddle::Tensor& tmp_out, // [token_num, dim_embed] const paddle::Tensor& cum_offsets, // [bsz, 1] const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& seq_lens_encoder, + const paddle::optional& output_padding_offset, int max_input_length) { typedef PDTraits traits_; typedef typename traits_::DataType DataType_; @@ -60,21 +92,49 @@ std::vector rebuild_padding_v2(const paddle::Tensor& tmp_out, // const int token_num = tmp_out_shape[0]; const int dim_embed = tmp_out_shape[1]; const int bsz = cum_offsets.shape()[0]; - auto out = paddle::full({bsz, dim_embed}, 0, tmp_out.dtype(), tmp_out.place()); + + paddle::Tensor out; + if (output_padding_offset) { + int need_delete_token_num = 0; + auto seq_lens_encoder_cpu = seq_lens_encoder.copy_to(paddle::CPUPlace(), true); + for (int i = 0; i < bsz; ++i) { + if (seq_lens_encoder_cpu.data()[i] > 0) { + need_delete_token_num += seq_lens_encoder_cpu.data()[i] - 1; + } + } + out = paddle::full({token_num - need_delete_token_num, dim_embed}, 0, D, tmp_out.place()); + } else { + out = paddle::full({bsz, dim_embed}, 0, tmp_out.dtype(), tmp_out.place()); + } + constexpr int PackSize = VEC_16B / sizeof(DataType_); int elem_nums = out.numel(); int pack_num = elem_nums / PackSize; const int blocksize = 128; const int grid_size = (pack_num + blocksize - 1) / blocksize; - RebuildPaddingV2Kernel<<>>( - reinterpret_cast(out.data()), - reinterpret_cast(const_cast(tmp_out.data())), - cum_offsets.data(), - seq_lens_decoder.data(), - seq_lens_encoder.data(), - max_input_length, - dim_embed, - elem_nums); + if (output_padding_offset) { + RebuildAppendPaddingKernel<<>>( + reinterpret_cast(out.data()), + reinterpret_cast(tmp_out.data()), + cum_offsets.data(), + seq_lens_decoder.data(), + seq_lens_encoder.data(), + output_padding_offset.get_ptr()->data(), + max_input_length, + dim_embed, + elem_nums); + } else { + RebuildPaddingV2Kernel<<>>( + reinterpret_cast(out.data()), + reinterpret_cast(const_cast(tmp_out.data())), + cum_offsets.data(), + seq_lens_decoder.data(), + seq_lens_encoder.data(), + max_input_length, + dim_embed, + elem_nums); + } + return {out}; } @@ -82,6 +142,7 @@ std::vector RebuildPaddingV2(const paddle::Tensor& tmp_out, // [ const paddle::Tensor& cum_offsets, // [bsz, 1] const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& seq_lens_encoder, + const paddle::optional& output_padding_offset, int max_input_length) { switch (tmp_out.type()) { case paddle::DataType::BFLOAT16: { @@ -90,6 +151,7 @@ std::vector RebuildPaddingV2(const paddle::Tensor& tmp_out, // [ cum_offsets, seq_lens_decoder, seq_lens_encoder, + output_padding_offset, max_input_length ); } @@ -99,6 +161,7 @@ std::vector RebuildPaddingV2(const paddle::Tensor& tmp_out, // [ cum_offsets, seq_lens_decoder, seq_lens_encoder, + output_padding_offset, max_input_length ); } @@ -108,6 +171,7 @@ std::vector RebuildPaddingV2(const paddle::Tensor& tmp_out, // [ cum_offsets, seq_lens_decoder, seq_lens_encoder, + output_padding_offset, max_input_length ); } @@ -123,21 +187,31 @@ std::vector RebuildPaddingV2(const paddle::Tensor& tmp_out, // [ std::vector> RebuildPaddingV2InferShape(const std::vector& tmp_out_shape, const std::vector& cum_offsets_shape, const std::vector& seq_lens_decoder_shape, - const std::vector& seq_lens_encoder_shape) { - int64_t bsz = cum_offsets_shape[0]; - int64_t dim_embed = tmp_out_shape[1]; - return {{bsz, dim_embed}}; + const std::vector& seq_lens_encoder_shape, + const paddle::optional>& output_padding_offset_shape) { + // whether speculative decoding + if (output_padding_offset_shape) { + int64_t dim_embed = tmp_out_shape[1]; + std::vector dynamic_shape = {-1, dim_embed}; + + return {dynamic_shape}; + } else { + int64_t bsz = cum_offsets_shape[0]; + int64_t dim_embed = tmp_out_shape[1]; + return {{bsz, dim_embed}}; + } } std::vector RebuildPaddingV2InferDtype(const paddle::DataType& tmp_out_dtype, const paddle::DataType& cum_offsets_dtype, const paddle::DataType& seq_lens_decoder_dtype, - const paddle::DataType& seq_lens_encoder_dtype) { + const paddle::DataType& seq_lens_encoder_dtype, + const paddle::optional& output_padding_offset_dtype) { return {tmp_out_dtype}; } PD_BUILD_OP(rebuild_padding_v2) - .Inputs({"tmp_out", "cum_offsets", "seq_lens_decoder", "seq_lens_encoder"}) + .Inputs({"tmp_out", "cum_offsets", "seq_lens_decoder", "seq_lens_encoder", paddle::Optional("output_padding_offset")}) .Outputs({"out"}) .Attrs({"max_input_length: int"}) .SetKernelFn(PD_KERNEL(RebuildPaddingV2)) diff --git a/csrc/gpu/save_with_output_msg.cc b/csrc/gpu/save_with_output_msg.cc index 9123578e2568..eea9872a798e 100644 --- a/csrc/gpu/save_with_output_msg.cc +++ b/csrc/gpu/save_with_output_msg.cc @@ -1,11 +1,11 @@ // Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -17,44 +17,97 @@ #include #include #include + #include "paddle/extension.h" #define MAX_BSZ 512 +#define SPECULATE_MAX_BSZ 256 +#define MAX_DRAFT_TOKENS 6 -struct msgdata { - long mtype; - int mtext[MAX_BSZ + 2]; // stop_flag, bsz, tokens +template +struct MsgData { + long mtype; + std::array mtext; }; -void SaveOutMmsg(const paddle::Tensor& x, - const paddle::Tensor& not_need_stop, - int64_t rank_id) { - if (rank_id > 0) return; - auto x_cpu = x.copy_to(paddle::CPUPlace(), false); - int64_t *x_data = x_cpu.data(); - auto not_need_stop_cpu = not_need_stop.copy_to(paddle::CPUPlace(), false); - bool* not_need_stop_data = not_need_stop_cpu.data(); +template +void SaveOutMsgFunc(MsgData& msg_sed, // NOLINT + const paddle::Tensor& x, + const paddle::Tensor& not_need_stop, + const paddle::optional& accept_num, + int64_t rank_id) { + if (rank_id > 0) return; + auto x_cpu = x.copy_to(paddle::CPUPlace(), false); + int64_t* x_data = x_cpu.data(); + auto not_need_stop_cpu = not_need_stop.copy_to(paddle::CPUPlace(), false); + bool* not_need_stop_data = not_need_stop_cpu.data(); - static struct msgdata msg_sed; - static key_t key = ftok("./", 1); - static int msgid = msgget(key, IPC_CREAT | 0666); + static key_t key = ftok("./", 1); + static int msgid = msgget(key, IPC_CREAT | 0666); + int bsz = x.shape()[0]; + if (!accept_num) { msg_sed.mtype = 1; msg_sed.mtext[0] = not_need_stop_data[0] ? 1 : -1; - int bsz = x.shape()[0]; msg_sed.mtext[1] = bsz; for (int i = 2; i < bsz + 2; i++) { - msg_sed.mtext[i] = (int)x_data[i - 2]; + msg_sed.mtext[i] = (int)x_data[i - 2]; + } + if ((msgsnd(msgid, &msg_sed, SIZE * sizeof(int), 0)) == -1) { + // printf("full msg buffer\n"); } - if ((msgsnd(msgid, &msg_sed, (MAX_BSZ + 2) * 4, 0)) == -1) { - // printf("full msg buffer\n"); + } else { + auto accept_num_cpu = accept_num.get().copy_to(paddle::CPUPlace(), false); + int* accept_num_data = accept_num_cpu.data(); + + msg_sed.mtype = 1; + msg_sed.mtext[0] = not_need_stop_data[0] ? 1 : -1; + msg_sed.mtext[1] = bsz; + for (int i = 2; i < SPECULATE_MAX_BSZ + 2; i++) { + if (i - 2 >= bsz) { + msg_sed.mtext[i] = 0; + } else { + msg_sed.mtext[i] = (int)accept_num_data[i - 2]; + } } - return; + for (int i = SPECULATE_MAX_BSZ + 2; i < SIZE; i++) { + int token_id = i - SPECULATE_MAX_BSZ - 2; + int bid = token_id / MAX_DRAFT_TOKENS; + int local_token_id = token_id % MAX_DRAFT_TOKENS; + if (token_id / MAX_DRAFT_TOKENS >= bsz) { + msg_sed.mtext[i] = 0; + } else { + msg_sed.mtext[i] = x_data[bid * MAX_DRAFT_TOKENS + local_token_id]; + } + } + if ((msgsnd(msgid, &msg_sed, SIZE * sizeof(int), 0)) == -1) { + printf("full msg buffer\n"); + } + } + + return; +} + +void SaveOutMsg(const paddle::Tensor& x, + const paddle::Tensor& not_need_stop, + const paddle::optional& accept_num, + int64_t rank_id) { + if (!accept_num) { + constexpr int SIZE = MAX_BSZ + 2; // stop_flag, bsz, tokens... + static struct MsgData msg_sed; + SaveOutMsgFunc(msg_sed, x, not_need_stop, accept_num, rank_id); + } else { + constexpr int SIZE = SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + + SPECULATE_MAX_BSZ + + 2; // stop_flag, bsz, accept_num*bsz, tokens... + static struct MsgData specu_msg_sed; + SaveOutMsgFunc(specu_msg_sed, x, not_need_stop, accept_num, rank_id); + } } PD_BUILD_OP(save_output) - .Inputs({"x", "not_need_stop"}) + .Inputs({"x", "not_need_stop", paddle::Optional("accept_num")}) .Attrs({"rank_id: int64_t"}) .Outputs({"x_out"}) .SetInplaceMap({{"x", "x_out"}}) - .SetKernelFn(PD_KERNEL(SaveOutMmsg)); + .SetKernelFn(PD_KERNEL(SaveOutMsg)); \ No newline at end of file diff --git a/csrc/gpu/speculate_decoding_kernels/ngram_match.cc b/csrc/gpu/speculate_decoding_kernels/ngram_match.cc new file mode 100644 index 000000000000..3c19064b2f66 --- /dev/null +++ b/csrc/gpu/speculate_decoding_kernels/ngram_match.cc @@ -0,0 +1,234 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include "paddle/extension.h" + +int sum(const int *value, int num) { + int sum_value = 0; + for (int i = 0; i <= num; i++) { + sum_value += value[i]; + } + return sum_value; +} + +void find_candidate_pred_tokens(const int64_t *input_ids, + const int64_t *input_ids_len, + const int64_t *pre_ids, + const int64_t *step_idx, + const int *draft_token_num, + int64_t *draft_tokens, + int32_t *seq_lens_this_time, + int32_t *seq_lens_encoder, + int32_t *seq_lens_decoder, + int64_t input_ids_stride, + int64_t pre_ids_stride, + int64_t draft_tokens_stride, + const int real_batch_size, + int max_ngram_size = 3, + int max_draft_tokens = 10) { + int threshold = 128; + char *env_var = getenv("INFER_WITH_REFERENCE_TOKENUM_THRESHOLD"); + if (env_var) { + threshold = std::stoi(env_var); + } + bool is_insert = false; + for (int batch_idx = 0; batch_idx < real_batch_size; batch_idx++) { + if (seq_lens_encoder[batch_idx] > 0) { + is_insert = true; + } + } + for (int batch_idx = 0; batch_idx < real_batch_size; batch_idx++) { + max_draft_tokens = draft_token_num[batch_idx]; + // int local_draft_tokens = max_draft_tokens; + if (seq_lens_encoder[batch_idx] > 0) { + continue; + } else if (seq_lens_decoder[batch_idx] == 0) { + seq_lens_this_time[batch_idx] = 0; + continue; + } + const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride; + int64_t *cur_draft_tokens = draft_tokens + batch_idx * draft_tokens_stride; + const int64_t *cur_pre_ids = pre_ids + batch_idx * pre_ids_stride; + const int64_t cur_step_idx = step_idx[batch_idx]; + const int64_t cur_input_ids_len = input_ids_len[batch_idx]; + seq_lens_this_time[batch_idx] = 1; + if (!is_insert) { + auto sum_token_num = sum(seq_lens_this_time, batch_idx); + int left_min_token_num = real_batch_size - batch_idx; + + if (sum_token_num + max_draft_tokens + left_min_token_num > threshold) { + int tmp_max_draft_tokens = threshold - sum_token_num - left_min_token_num; + max_draft_tokens = tmp_max_draft_tokens < max_draft_tokens ? tmp_max_draft_tokens : max_draft_tokens; + } + + if (sum_token_num + left_min_token_num >= threshold - 1) { + continue; + } + } + + + for (int ngram_size = max_ngram_size; ngram_size > 0; --ngram_size) { + // Extract the last n tokens as our search ngram + if (cur_step_idx < ngram_size) { + continue; + } + const int64_t *ngram = cur_pre_ids + (cur_step_idx + 1 - ngram_size); +#ifdef _DEBUG + if (batch_idx == 0) { + for (int mm = 0; mm < ngram_size; mm++) { + printf("idx %d: %lld\n", mm, ngram[mm]); + } + } + printf("cur_input_ids_len %d\n", cur_input_ids_len); +#endif + // Iterate through sliding windows of size ngram_size + bool match_input = false; + for (int64_t i = 0; i <= cur_input_ids_len - ngram_size; ++i) { + // Check if the current window matches the ngram + bool match = true; + for (int j = 0; j < ngram_size; j++) { + if (ngram[j] != cur_input_ids[i + j]) { + match = false; + break; + } + } + if (match) { + int64_t start_idx = i + ngram_size; + int64_t end_idx = std::min(start_idx + max_draft_tokens, cur_input_ids_len); + if (start_idx >= end_idx) + continue; +#ifdef _DEBUG + printf("batch_idx:%d. ngram_size:%d. idx:%lld. \n", batch_idx, ngram_size, i); + printf("start:%d. end:%d.\n", start_idx, end_idx); +#endif + // Ensure we don't go beyond the length of input_ids and avoid self-match + // if (end_idx <= cur_input_ids_len && start_idx < cur_input_ids_len - ngram_size) { + // Return a pointer to the next num_pred_tokens + int64_t cur_draft_token_num = end_idx - start_idx; + + seq_lens_this_time[batch_idx] = cur_draft_token_num + 1; + memcpy(cur_draft_tokens + 1, cur_input_ids + start_idx, sizeof(int64_t) * cur_draft_token_num); + // To break the current batch_idx for-loop + ngram_size = 0; + match_input = true; + break; + // } + } + } + if (!match_input) { +#ifdef _DEBUG + printf("match_input is false so match output\n"); +#endif + for (int64_t i = 0; i <= cur_step_idx - ngram_size; ++i) { + // Check if the current window matches the ngram + bool match = true; +#ifdef _DEBUG + printf("search %d token in pre_ids\n", i); +#endif + for (int j = 0; j < ngram_size; j++) { + if (ngram[j] != cur_pre_ids[i + j]) { + match = false; + break; + } + } + + if (match) { +#ifdef _DEBUG + printf("%d token in pre_ids matched\n", i); +#endif + int64_t start_idx = i + ngram_size; + int64_t end_idx = std::min(start_idx + max_draft_tokens, cur_step_idx); + int64_t cur_draft_token_num = end_idx - start_idx; + if (start_idx >= end_idx) + continue; + +#ifdef _DEBUG + printf("cur_step_idx %d, start_idx %lld, end_idx %lld, cur_draft_token_num is %lld\n", + cur_step_idx, + start_idx, + end_idx, + cur_draft_token_num); +#endif + + seq_lens_this_time[batch_idx] = cur_draft_token_num + 1; + memcpy(cur_draft_tokens + 1, cur_pre_ids + start_idx, sizeof(int64_t) * cur_draft_token_num); + // To break the current batch_idx for-loop + ngram_size = 0; + break; + } + } + } + } + } +} + +void NgramMatch(const paddle::Tensor &input_ids, + const paddle::Tensor &input_ids_len, + const paddle::Tensor &pre_ids, + const paddle::Tensor &step_idx, + const paddle::Tensor &draft_token_num, + const paddle::Tensor &draft_tokens, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &seq_lens_decoder, + const int real_batch_size, + const int max_ngram_size, + const int max_draft_tokens) { + + auto input_ids_shape = input_ids.shape(); + const int64_t input_ids_stride = input_ids_shape[1]; + + auto pre_ids_shape = pre_ids.shape(); + const int64_t pre_ids_stride = pre_ids_shape[1]; + + auto draft_tokens_shape = draft_tokens.shape(); + const int64_t draft_tokens_stride = draft_tokens_shape[1]; + + find_candidate_pred_tokens(input_ids.data(), + input_ids_len.data(), + pre_ids.data(), + step_idx.data(), + draft_token_num.data(), + const_cast(draft_tokens.data()), + const_cast(seq_lens_this_time.data()), + const_cast(seq_lens_encoder.data()), + const_cast(seq_lens_decoder.data()), + input_ids_stride, + pre_ids_stride, + draft_tokens_stride, + real_batch_size, + max_ngram_size, + max_draft_tokens); +} + +PD_BUILD_OP(ngram_match) + .Inputs({"input_ids", + "input_ids_len", + "pre_ids", + "step_idx", + "draft_token_num", + "draft_tokens", + "seq_lens_this_time", + "seq_lens_encoder", + "seq_lens_decoder"}) + .Attrs({"real_batch_size: int", "max_ngram_size: int", "max_draft_tokens: int"}) + .Outputs({"draft_tokens_out", "seq_lens_this_time_out"}) + .SetKernelFn(PD_KERNEL(NgramMatch)) + .SetInplaceMap({{"draft_tokens", "draft_tokens_out"}, {"seq_lens_this_time", "seq_lens_this_time_out"}}); diff --git a/csrc/gpu/speculate_decoding_kernels/speculate_get_output_padding_offset.cu b/csrc/gpu/speculate_decoding_kernels/speculate_get_output_padding_offset.cu new file mode 100644 index 000000000000..9ae2befe6a25 --- /dev/null +++ b/csrc/gpu/speculate_decoding_kernels/speculate_get_output_padding_offset.cu @@ -0,0 +1,75 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/extension.h" + +__global__ void SpeculateGetOutputPaddingOffsetKernel( + int* output_padding_offset, + int* output_cum_offsets, + const int *output_cum_offsets_tmp, + const int *seq_lens_output, + const int max_seq_len) { + // get padding offset of each batch + const int bi = blockIdx.x; + const int ti = threadIdx.x; + int cum_offset = bi == 0 ? 0 : output_cum_offsets_tmp[bi - 1]; + for (int i = ti; i < seq_lens_output[bi]; i += blockDim.x) { + output_padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset; + } + if (ti == 0) { + output_cum_offsets[bi] = cum_offset; + } +} + +std::vector SpeculateGetOutputPaddingOffset(const paddle::Tensor& output_cum_offsets_tmp, + const paddle::Tensor& out_token_num, + const paddle::Tensor& seq_lens_output, + const int max_seq_len) { + auto cu_stream = output_cum_offsets_tmp.stream(); + std::vector output_cum_offsets_tmp_shape = output_cum_offsets_tmp.shape(); + const int bsz = output_cum_offsets_tmp_shape[0]; + auto cpu_out_token_num = out_token_num.copy_to(paddle::CPUPlace(), false); + + auto output_padding_offset = paddle::full({cpu_out_token_num}, 0, paddle::DataType::INT32, output_cum_offsets_tmp.place()); + auto output_cum_offsets = output_cum_offsets_tmp.copy_to(output_cum_offsets_tmp.place(), false); + + SpeculateGetOutputPaddingOffsetKernel<<>>(output_padding_offset.data(), + output_cum_offsets.data(), + output_cum_offsets_tmp.data(), + seq_lens_output.data(), + max_seq_len); + + return {output_padding_offset, output_cum_offsets}; +} + +std::vector> SpeculateGetOutputPaddingOffsetInferShape(const std::vector& output_cum_offsets_tmp_shape, + const std::vector& out_token_num_shape, + const std::vector& seq_lens_output_shape) { + int64_t bsz = output_cum_offsets_tmp_shape[0]; + return {{-1}, {bsz}}; +} + +std::vector SpeculateGetOutputPaddingOffsetInferDtype(const paddle::DataType& output_cum_offsets_tmp_dtype, + const paddle::DataType& out_token_num_dtype, + const paddle::DataType& seq_lens_output_dtype) { + return {output_cum_offsets_tmp_dtype, output_cum_offsets_tmp_dtype}; +} + +PD_BUILD_OP(speculate_get_output_padding_offset) + .Inputs({"output_cum_offsets_tmp", "out_token_num", "seq_lens_output"}) + .Outputs({"output_padding_offset", "output_cum_offsets"}) + .Attrs({"max_seq_len: int"}) + .SetKernelFn(PD_KERNEL(SpeculateGetOutputPaddingOffset)) + .SetInferShapeFn(PD_INFER_SHAPE(SpeculateGetOutputPaddingOffsetInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(SpeculateGetOutputPaddingOffsetInferDtype)); \ No newline at end of file diff --git a/csrc/gpu/speculate_decoding_kernels/speculate_get_seq_lens_output.cu b/csrc/gpu/speculate_decoding_kernels/speculate_get_seq_lens_output.cu new file mode 100644 index 000000000000..b370c2bd1982 --- /dev/null +++ b/csrc/gpu/speculate_decoding_kernels/speculate_get_seq_lens_output.cu @@ -0,0 +1,72 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/extension.h" + +__global__ void SpeculateGetSeqLensOutputKernel( + int* seq_lens_output, + const int *seq_lens_this_time, + const int *seq_lens_encoder, + const int *seq_lens_decoder, + const int real_bsz) { + for (int bid = threadIdx.x; bid < real_bsz; bid += blockDim.x) { + if (seq_lens_this_time[bid] == 0) { + continue; + } else if (seq_lens_this_time[bid] == 1) { + seq_lens_output[bid] = 1; + } else if (seq_lens_encoder[bid] != 0) { + seq_lens_output[bid] = 1; + } else { + seq_lens_output[bid] = seq_lens_this_time[bid]; + } + } +} + +std::vector SpeculateGetSeqLensOutput(const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder) { + auto cu_stream = seq_lens_this_time.stream(); + std::vector seq_lens_this_time_shape = seq_lens_this_time.shape(); + const int bsz = seq_lens_this_time_shape[0]; + + auto seq_lens_output = paddle::full({bsz}, 0, paddle::DataType::INT32, seq_lens_this_time.place()); + + SpeculateGetSeqLensOutputKernel<<<1, 256, 0, cu_stream>>>(seq_lens_output.data(), + seq_lens_this_time.data(), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + bsz); + + return {seq_lens_output}; +} + +std::vector> SpeculateGetSeqLensOutputInferShape(const std::vector& seq_lens_this_time_shape, + const std::vector& seq_lens_encoder_shape, + const std::vector& seq_lens_decoder_shape) { + int64_t bsz = seq_lens_this_time_shape[0]; + return {{bsz}}; +} + +std::vector SpeculateGetSeqLensOutputInferDtype(const paddle::DataType& seq_lens_this_time_dtype, + const paddle::DataType& seq_lens_encoder_dtype, + const paddle::DataType& seq_lens_decoder_dtype) { + return {seq_lens_this_time_dtype}; +} + +PD_BUILD_OP(speculate_get_seq_lens_output) + .Inputs({"seq_lens_this_time", "seq_lens_encoder", "seq_lens_decoder"}) + .Outputs({"seq_lens_output"}) + .SetKernelFn(PD_KERNEL(SpeculateGetSeqLensOutput)) + .SetInferShapeFn(PD_INFER_SHAPE(SpeculateGetSeqLensOutputInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(SpeculateGetSeqLensOutputInferDtype)); \ No newline at end of file diff --git a/csrc/gpu/speculate_decoding_kernels/speculate_set_value_by_flags.cu b/csrc/gpu/speculate_decoding_kernels/speculate_set_value_by_flags.cu new file mode 100644 index 000000000000..d03ec99ae5b5 --- /dev/null +++ b/csrc/gpu/speculate_decoding_kernels/speculate_set_value_by_flags.cu @@ -0,0 +1,73 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/extension.h" + +__global__ void speculate_set_value_by_flag_and_id(int64_t *pre_ids_all, + const int64_t *accept_tokens, + const int *accept_num, + const bool *stop_flags, + const int *seq_lens_encoder, + const int *seq_lens_decoder, + const int64_t *step_idx, + int bs, + int length, + int max_draft_tokens) { + int tid = threadIdx.x; + if (tid < bs && !stop_flags[tid]) { + int64_t *pre_ids_all_now = pre_ids_all + tid * length; + const int64_t *accept_tokens_now = accept_tokens + tid * max_draft_tokens; + const int seq_len_dec = seq_lens_decoder[tid]; + const int seq_len_enc = seq_lens_encoder[tid]; + if (seq_len_dec == 0 && seq_len_enc == 0) return; // stoped + if (step_idx[tid] >= 0) { + for (int i = 0; i < accept_num[tid]; i++) { + pre_ids_all_now[step_idx[tid] - i] = accept_tokens_now[accept_num[tid] - 1 - i]; + } + } + } +} + +void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor& pre_ids_all, + const paddle::Tensor& accept_tokens, + const paddle::Tensor& accept_num, + const paddle::Tensor& stop_flags, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& step_idx ) { + auto cu_stream = stop_flags.stream(); + std::vector pre_ids_all_shape = pre_ids_all.shape(); + + int bs = seq_lens_this_time.shape()[0]; + int length = pre_ids_all_shape[1]; + int max_draft_tokens = accept_tokens.shape()[1]; + int block_size = (bs + 32 - 1) / 32 * 32; + speculate_set_value_by_flag_and_id<<<1, block_size, 0, cu_stream>>>(const_cast(pre_ids_all.data()), + accept_tokens.data(), + accept_num.data(), + stop_flags.data(), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + step_idx.data(), + bs, + length, + max_draft_tokens); +} + +PD_BUILD_OP(speculate_set_value_by_flags_and_idx) + .Inputs({"pre_ids_all", "accept_tokens", "accept_num", "stop_flags", "seq_lens_this_time", "seq_lens_encoder", "seq_lens_decoder", "step_idx"}) + .Outputs({"pre_ids_all_out"}) + .SetInplaceMap({{"pre_ids_all", "pre_ids_all_out"}}) + .SetKernelFn(PD_KERNEL(SpeculateSetValueByFlagsAndIdx)); \ No newline at end of file diff --git a/csrc/gpu/speculate_decoding_kernels/speculate_step.cu b/csrc/gpu/speculate_decoding_kernels/speculate_step.cu new file mode 100644 index 000000000000..8ef2c477adcc --- /dev/null +++ b/csrc/gpu/speculate_decoding_kernels/speculate_step.cu @@ -0,0 +1,396 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "helper.h" + + +__global__ void speculate_free_and_dispatch_block(bool *stop_flags, + int *seq_lens_this_time, + int *seq_lens_decoder, + int *block_tables, + int *encoder_block_lens, + bool *is_block_step, + int *step_block_list, // [bsz] + int *step_len, + int *recover_block_list, + int *recover_len, + int *need_block_list, + int *need_block_len, + int *used_list_len, + int *free_list, + int *free_list_len, + const int bsz, + const int block_size, + const int block_num_per_seq, + const int max_decoder_block_num, + const int max_draft_tokens) { + typedef cub::BlockReduce, 256> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + const int tid = threadIdx.x; + if (tid < bsz) { + int *block_table_now = block_tables + tid * block_num_per_seq; + if (stop_flags[tid] && !is_block_step[tid]) { + // 回收block块 + const int encoder_block_len = encoder_block_lens[tid]; + const int decoder_used_len = used_list_len[tid]; + if (decoder_used_len > 0) { + const int ori_free_list_len = atomicAdd(free_list_len, decoder_used_len); +#ifdef DEBUG_STEP + printf("free block seq_id: %d, free block num: %d, encoder_block_len: %d, ori_free_list_len: %d\n", + tid, + decoder_used_len, + encoder_block_len, + ori_free_list_len); +#endif + for (int i = 0; i < decoder_used_len; i++) { + free_list[ori_free_list_len + i] = block_table_now[encoder_block_len + i]; + block_table_now[encoder_block_len + i] = -1; + } + encoder_block_lens[tid] = 0; + used_list_len[tid] = 0; + } + } else if (seq_lens_this_time[tid] != 0 && + block_table_now[(seq_lens_decoder[tid] + max_draft_tokens + 1) / block_size] == -1) { + // 统计需要分配block的位置和总数 + const int ori_need_block_len = atomicAdd(need_block_len, 1); + need_block_list[ori_need_block_len] = tid; +#ifdef DEBUG_STEP + printf("seq_id: %d need block\n", tid); +#endif + } + } + __syncthreads(); + + while (need_block_len[0] > free_list_len[0]) { +#ifdef DEBUG_STEP + if (tid == 0) { + printf("need_block_len: %d, free_list_len: %d\n", need_block_len[0], free_list_len[0]); + } +#endif + // 调度block,根据used_list_len从大到小回收block,直到满足need_block_len,已解码到最后一个block的query不参与调度(马上就结束) + const int used_block_num = + tid < bsz && !is_block_step[tid] + ? used_list_len[tid] + : 0; + cub::KeyValuePair kv_pair = {tid, used_block_num}; + kv_pair = BlockReduce(temp_storage).Reduce(kv_pair, cub::ArgMax()); + + if (tid == 0) { + const int encoder_block_len = encoder_block_lens[kv_pair.key]; +#ifdef DEBUG_STEP + printf("max_id: %d, max_num: %d, encoder_block_len: %d\n", + kv_pair.key, + kv_pair.value, + encoder_block_len); +#endif + int *block_table_now = block_tables + kv_pair.key * block_num_per_seq; + for (int i = 0; i < kv_pair.value; i++) { + free_list[free_list_len[0] + i] = block_table_now[encoder_block_len + i]; + block_table_now[encoder_block_len + i] = -1; + } + step_block_list[step_len[0]] = kv_pair.key; + step_len[0] += 1; + free_list_len[0] += kv_pair.value; + stop_flags[kv_pair.key] = true; + is_block_step[kv_pair.key] = true; + seq_lens_this_time[kv_pair.key] = 0; + seq_lens_decoder[kv_pair.key] = 0; + } + __syncthreads(); + } + + // 为需要block的位置分配block,每个位置分配一个block + if (tid < need_block_len[0]) { + const int need_block_id = need_block_list[tid]; + if (!stop_flags[need_block_id]) { + // 如果需要的位置正好是上一步中被释放的位置,不做处理 + used_list_len[need_block_id] += 1; + const int ori_free_list_len = atomicSub(free_list_len, 1); + int *block_table_now = block_tables + need_block_id * block_num_per_seq; +#ifdef DEBUG_STEP + printf("need_block_id %d\n", need_block_id); + printf("ori_free_list_len %d\n", ori_free_list_len); + printf("max_draft_tokens %d\n", max_draft_tokens); + printf("seq_lens_decoder[need_block_id] %d\n", seq_lens_decoder[need_block_id]); + printf("free_list[ori_free_list_len - 1] %d\n", free_list[ori_free_list_len - 1]); +#endif + block_table_now[(seq_lens_decoder[need_block_id] + max_draft_tokens + 1) / block_size] = + free_list[ori_free_list_len - 1]; + } + need_block_list[tid] = -1; + } + __syncthreads(); + + // 计算可以复原的query id + if (tid == 0) { + int ori_free_list_len = free_list_len[0]; + int ori_step_len = step_len[0]; + if (ori_step_len > 0) { + int ori_step_block_id = step_block_list[ori_step_len - 1]; + int tmp_used_len = used_list_len[ori_step_block_id]; + // 比之前调度时多分配一个block,防止马上恢复刚调度的query(比如回收的seq_id在need_block_list中) + int used_len = tmp_used_len < max_decoder_block_num ? tmp_used_len + 1 : tmp_used_len; + while (ori_step_len > 0 && ori_free_list_len >= used_len) { +#ifdef DEBUG_STEP + printf("recover seq_id: %d, free_list_len: %d, used_list_len: %d\n", + ori_step_block_id, ori_free_list_len, used_len); +#endif + recover_block_list[recover_len[0]] = ori_step_block_id; + is_block_step[ori_step_block_id] = false; + used_list_len[ori_step_block_id] = used_len; + ori_free_list_len -= used_len; + step_block_list[ori_step_len - 1] = -1; + step_len[0] -= 1; + recover_len[0] += 1; + ori_step_len = step_len[0]; + if (ori_step_len > 0) { + ori_step_block_id = step_block_list[ori_step_len - 1]; + tmp_used_len = used_list_len[ori_step_block_id]; + used_len = tmp_used_len < max_decoder_block_num ? tmp_used_len + 1 : tmp_used_len; + } + } + } + need_block_len[0] = 0; + } +} + +// 根据上一步计算出的可以复原的query_id进行状态恢复 +__global__ void speculate_recover_block(int *recover_block_list, // [bsz] + int *recover_len, + bool *stop_flags, + int *seq_lens_this_time, + int *ori_seq_lens_encoder, + int *seq_lens_encoder, + int *seq_lens_decoder, + int *block_tables, + int *free_list, + int *free_list_len, + int64_t *input_ids, + int64_t *pre_ids, + int64_t *step_idx, + int *encoder_block_lens, + int *used_list_len, + const int64_t *next_tokens, + const int bsz, + const int block_num_per_seq, + const int length, + const int pre_id_length, + const int64_t first_token_ids) { + const int bid = blockIdx.x; + const int tid = threadIdx.x; + __shared__ int ori_free_list_len; + if (bid < recover_len[0]) { + const int recover_id = recover_block_list[bid]; + const int ori_seq_len_encoder = ori_seq_lens_encoder[recover_id]; + const int step_idx_now = step_idx[recover_id]; + const int seq_len = ori_seq_len_encoder + step_idx_now; + const int encoder_block_len = encoder_block_lens[recover_id]; + const int decoder_used_len = used_list_len[recover_id]; + int *block_table_now = block_tables + recover_id * block_num_per_seq; + int64_t *input_ids_now = input_ids + recover_id * length; + int64_t *pre_ids_now = pre_ids + recover_id * pre_id_length; + if (tid == 0) { + seq_lens_this_time[recover_id] = seq_len; + seq_lens_encoder[recover_id] = seq_len; + stop_flags[recover_id] = false; + input_ids_now[ori_seq_len_encoder + step_idx_now - 1] = next_tokens[recover_id]; // next tokens + input_ids_now[0] = first_token_ids; // set first prompt token + const int ori_free_list_len_tid0 = atomicSub(free_list_len, decoder_used_len); + ori_free_list_len = ori_free_list_len_tid0; +#ifdef DEBUG_STEP + printf("seq_id: %d, ori_seq_len_encoder: %d, step_idx_now: %d, seq_len: %d, ori_free_list_len_tid0: %d, " + "ori_free_list_len: %d\n", + recover_id, + ori_seq_len_encoder, + step_idx_now, + seq_len, + ori_free_list_len_tid0, + ori_free_list_len); +#endif + } + __syncthreads(); + // 恢复block table + for (int i = tid; i < decoder_used_len; i += blockDim.x) { + block_table_now[encoder_block_len + i] = free_list[ori_free_list_len - i - 1]; + } + // 恢复input_ids + for (int i = tid; i < step_idx_now - 1; i += blockDim.x) { + input_ids_now[ori_seq_len_encoder + i] = pre_ids_now[i + 1]; + } + } + + if (bid == 0 && tid == 0) { + recover_len[0] = 0; + } +} + +void SpeculateStepPaddle(const paddle::Tensor &stop_flags, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &ori_seq_lens_encoder, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &seq_lens_decoder, + const paddle::Tensor &block_tables, // [bsz, block_num_per_seq] + const paddle::Tensor &encoder_block_lens, + const paddle::Tensor &is_block_step, + const paddle::Tensor &step_block_list, + const paddle::Tensor &step_lens, + const paddle::Tensor &recover_block_list, + const paddle::Tensor &recover_lens, + const paddle::Tensor &need_block_list, + const paddle::Tensor &need_block_len, + const paddle::Tensor &used_list_len, + const paddle::Tensor &free_list, + const paddle::Tensor &free_list_len, + const paddle::Tensor &input_ids, + const paddle::Tensor &pre_ids, + const paddle::Tensor &step_idx, + const paddle::Tensor &next_tokens, + const int block_size, + const int encoder_decoder_block_num, + const int64_t first_token_ids, + const int max_draft_tokens) { + auto cu_stream = seq_lens_this_time.stream(); + const int bsz = seq_lens_this_time.shape()[0]; + const int block_num_per_seq = block_tables.shape()[1]; + const int length = input_ids.shape()[1]; + const int pre_id_length = pre_ids.shape()[1]; + constexpr int BlockSize = 256; // bsz <= 256 + const int max_decoder_block_num = pre_id_length / block_size; + // const int max_decoder_block_num = 2048 / block_size - encoder_decoder_block_num; +#ifdef DEBUG_STEP + printf("bsz: %d, block_num_per_seq: %d, length: %d, max_decoder_block_num: %d\n", + bsz, + block_num_per_seq, + length, + max_decoder_block_num); +#endif + speculate_free_and_dispatch_block<<<1, BlockSize, 0, cu_stream>>>( + const_cast(stop_flags.data()), + const_cast(seq_lens_this_time.data()), + const_cast(seq_lens_decoder.data()), + const_cast(block_tables.data()), + const_cast(encoder_block_lens.data()), + const_cast(is_block_step.data()), + const_cast(step_block_list.data()), + const_cast(step_lens.data()), + const_cast(recover_block_list.data()), + const_cast(recover_lens.data()), + const_cast(need_block_list.data()), + const_cast(need_block_len.data()), + const_cast(used_list_len.data()), + const_cast(free_list.data()), + const_cast(free_list_len.data()), + bsz, + block_size, + block_num_per_seq, + max_decoder_block_num, + max_draft_tokens); +#ifdef DEBUG_STEP + cudaDeviceSynchronize(); +#endif + auto cpu_recover_lens = recover_lens.copy_to(paddle::CPUPlace(), false); + const int grid_size = cpu_recover_lens.data()[0]; +#ifdef DEBUG_STEP + printf("grid_size2 %d\n", grid_size); +#endif + if (grid_size > 0) { + speculate_recover_block<<>>( + const_cast(recover_block_list.data()), + const_cast(recover_lens.data()), + const_cast(stop_flags.data()), + const_cast(seq_lens_this_time.data()), + const_cast(ori_seq_lens_encoder.data()), + const_cast(seq_lens_encoder.data()), + const_cast(seq_lens_decoder.data()), + const_cast(block_tables.data()), + const_cast(free_list.data()), + const_cast(free_list_len.data()), + const_cast(input_ids.data()), + const_cast(pre_ids.data()), + const_cast(step_idx.data()), + const_cast(encoder_block_lens.data()), + const_cast(used_list_len.data()), + next_tokens.data(), + bsz, + block_num_per_seq, + length, + pre_id_length, + first_token_ids); +#ifdef DEBUG_STEP + cudaDeviceSynchronize(); +#endif + } +} + +PD_BUILD_OP(speculate_step_paddle) + .Inputs({"stop_flags", + "seq_lens_this_time", + "ori_seq_lens_encoder", + "seq_lens_encoder", + "seq_lens_decoder", + "block_tables", + "encoder_block_lens", + "is_block_step", + "step_block_list", + "step_lens", + "recover_block_list", + "recover_lens", + "need_block_list", + "need_block_len", + "used_list_len", + "free_list", + "free_list_len", + "input_ids", + "pre_ids", + "step_idx", + "next_tokens"}) + .Attrs({"block_size: int", + "encoder_decoder_block_num: int", + "first_token_id: int64_t", + "max_draft_tokens: int"}) + .Outputs({"stop_flags_out", + "seq_lens_this_time_out", + "seq_lens_encoder_out", + "seq_lens_decoder_out", + "block_tables_out", + "encoder_block_lens_out", + "is_block_step_out", + "step_block_list_out", + "step_lens_out", + "recover_block_list_out", + "recover_lens_out", + "need_block_list_out", + "need_block_len_out", + "used_list_len_out", + "free_list_out", + "free_list_len_out", + "input_ids_out"}) + .SetInplaceMap({{"stop_flags", "stop_flags_out"}, + {"seq_lens_this_time", "seq_lens_this_time_out"}, + {"seq_lens_encoder", "seq_lens_encoder_out"}, + {"seq_lens_decoder", "seq_lens_decoder_out"}, + {"block_tables", "block_tables_out"}, + {"encoder_block_lens", "encoder_block_lens_out"}, + {"is_block_step", "is_block_step_out"}, + {"step_block_list", "step_block_list_out"}, + {"step_lens", "step_lens_out"}, + {"recover_block_list", "recover_block_list_out"}, + {"recover_lens", "recover_lens_out"}, + {"need_block_list", "need_block_list_out"}, + {"need_block_len", "need_block_len_out"}, + {"used_list_len", "used_list_len_out"}, + {"free_list", "free_list_out"}, + {"free_list_len", "free_list_len_out"}, + {"input_ids", "input_ids_out"}}) + .SetKernelFn(PD_KERNEL(SpeculateStepPaddle)); \ No newline at end of file diff --git a/csrc/gpu/speculate_decoding_kernels/speculate_verify_and_update.cu b/csrc/gpu/speculate_decoding_kernels/speculate_verify_and_update.cu new file mode 100644 index 000000000000..aac0a0c9fdac --- /dev/null +++ b/csrc/gpu/speculate_decoding_kernels/speculate_verify_and_update.cu @@ -0,0 +1,451 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "helper.h" +#include +#include +#include + +__device__ bool is_in(const int64_t* candidates, const int64_t draft, const int candidate_len) { + for (int i = 0; i < candidate_len; i++) { + if (draft == candidates[i]) { + return true; + } + } + return false; +} + +static uint64_t seed = 0; +static uint64_t offset = 0; + +__device__ int64_t topp_sampling_kernel(const int64_t* candidate_ids, + const float* candidate_scores, + curandState_t* dev_curand_states, + const int candidate_len, + const float topp) { + + const int tid = threadIdx.x; + + float sum_scores = 0.0f; + float rand_top_p = curand_uniform(dev_curand_states + tid) * topp; + for (int i = 0; i < candidate_len; i++) { + sum_scores += candidate_scores[i]; + if (rand_top_p <= sum_scores) { + return candidate_ids[i]; + } + } + return candidate_ids[0]; +} + +__global__ void setup_kernel(curandState_t* state, + const uint64_t seed, + const uint64_t offset, + const int bs, + const bool need_batch_random) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + for (int i = idx; i < bs; i += gridDim.x * blockDim.x) { + if (need_batch_random) { + curand_init(seed, i, offset, &state[i]); + } else { + curand_init(seed, 0, offset, &state[i]); + } + } +} + +template +__global__ void speculate_verify_and_update_kernel(int64_t* accept_tokens, + int* accept_num, + int64_t* step_idx, + int* seq_lens_encoder, + int* seq_lens_decoder, + bool* stop_flags, + bool* not_need_stop, + int64_t* draft_tokens, + int* actual_draft_token_nums, + curandState_t* dev_curand_states, + const float* topp, + const int* seq_lens_this_time, + const int64_t* verify_tokens, + const float* verify_scores, + const int64_t* max_dec_len, + const int64_t* end_tokens, + const bool* is_block_step, + const int* output_cum_offsets, + const int* actual_candidate_len, + const int real_bsz, + const int max_draft_tokens, + const int end_length, + const int max_seq_len, + const int max_candidate_len, + const int verify_window) { + const int bid = threadIdx.x; + // start token's id of bid batch + const int start_token_id = bid * max_seq_len - output_cum_offsets[bid]; + // verify and set stop flags + int accept_num_now = 1; + int stop_flag_now_int = 0; + + if (!(is_block_step[bid] || bid >= real_bsz)) { + + if (stop_flags[bid]) { + stop_flag_now_int = 1; + } else { // Here the prefill stage also goes in, but since the draft tokens are zero in prefill stage, it goes straight to the final sampling stage. + auto* verify_tokens_now = verify_tokens + start_token_id * max_candidate_len; + auto* draft_tokens_now = draft_tokens + bid * max_draft_tokens; + auto* actual_candidate_len_now = actual_candidate_len + start_token_id; + + int i = 0; + for (; i < seq_lens_this_time[bid] - 1; i++) { + if (USE_TOPK) { + if (verify_tokens_now[i * max_candidate_len] == draft_tokens_now[i + 1]) { + accept_num_now++; + step_idx[bid]++; + auto accept_token = draft_tokens_now[i + 1]; + accept_tokens[bid * max_draft_tokens + i] = accept_token; + if (is_in_end(accept_token, end_tokens, end_length) || step_idx[bid] >= max_dec_len[bid]) { + stop_flags[bid] = true; + stop_flag_now_int = 1; + if (step_idx[bid] >= max_dec_len[bid]) + accept_tokens[bid * max_draft_tokens + i] = end_tokens[0]; + break; + } + } else { + break; + } + } else { + auto actual_candidate_len_value = actual_candidate_len_now[i] > max_candidate_len + ? max_candidate_len + : actual_candidate_len_now[i]; + if (is_in(verify_tokens_now + i * max_candidate_len, + draft_tokens_now[i + 1], + actual_candidate_len_value)) { + // Top P verify + accept_num_now++; + step_idx[bid]++; + auto accept_token = draft_tokens_now[i + 1]; + accept_tokens[bid * max_draft_tokens + i] = accept_token; + if (is_in_end(accept_token, end_tokens, end_length) || step_idx[bid] >= max_dec_len[bid]) { + stop_flags[bid] = true; + stop_flag_now_int = 1; + if (step_idx[bid] >= max_dec_len[bid]) + accept_tokens[bid * max_draft_tokens + i] = end_tokens[0]; + break; + } + } else { + // TopK verify + int ii = i; + if (max_candidate_len >= 2 && + verify_tokens_now[ii * max_candidate_len + 1] == draft_tokens_now[ii + 1]) { // top-2 + int j = 0; + ii += 1; + for (; j < verify_window && ii < seq_lens_this_time[bid] - 1; j++, ii++) { + if (verify_tokens_now[ii * max_candidate_len] != draft_tokens_now[ii + 1]) { + break; + } + } + if (j >= verify_window) { // accept all + accept_num_now += verify_window + 1; + step_idx[bid] += verify_window + 1; + for (; i < ii; i++) { + auto accept_token = draft_tokens_now[i + 1]; + accept_tokens[bid * max_draft_tokens + i] = accept_token; + if (is_in_end(accept_token, end_tokens, end_length) || + step_idx[bid] >= max_dec_len[bid]) { + stop_flags[bid] = true; + stop_flag_now_int = 1; + if (step_idx[bid] >= max_dec_len[bid]) + accept_tokens[bid * max_draft_tokens + i] = end_tokens[0]; + break; + } + } + } + } + break; + } + } + } + + if (!stop_flag_now_int) { + int64_t accept_token; + const float* verify_scores_now = verify_scores + start_token_id * max_candidate_len; + if (ENABLE_TOPP) { + auto actual_candidate_len_value = actual_candidate_len_now[i] > max_candidate_len + ? max_candidate_len + : actual_candidate_len_now[i]; + accept_token = topp_sampling_kernel(verify_tokens_now + i * max_candidate_len, + verify_scores_now + i * max_candidate_len, + dev_curand_states, + actual_candidate_len_value, + topp[bid]); + } else { + accept_token = verify_tokens_now[i * max_candidate_len]; + } + accept_tokens[bid * max_draft_tokens + i] = accept_token; + if (is_in_end(accept_token, end_tokens, end_length) || step_idx[bid] >= max_dec_len[bid]) { + stop_flags[bid] = true; + stop_flag_now_int = 1; + if (step_idx[bid] >= max_dec_len[bid]) + accept_tokens[bid * max_draft_tokens + i] = end_tokens[0]; + } + step_idx[bid]++; + } + + seq_lens_decoder[bid] += accept_num_now; + + // For append mode, determine whether to reduce the number of draft tokens depending on whether they are received or not. + if (seq_lens_this_time[bid] > 1 && seq_lens_encoder[bid] == 0) { + auto current_actual_draft_token_num = actual_draft_token_nums[bid]; + if (accept_num_now - 1 == current_actual_draft_token_num) { + if (current_actual_draft_token_num + 2 <= max_draft_tokens - 1) { + actual_draft_token_nums[bid] = current_actual_draft_token_num + 2; + } else if (current_actual_draft_token_num + 1 <= max_draft_tokens - 1) { + actual_draft_token_nums[bid] = current_actual_draft_token_num + 1; + } else { + actual_draft_token_nums[bid] = max_draft_tokens - 1; + } + } else { + actual_draft_token_nums[bid] = + actual_draft_token_nums[bid] - 1 >= 1 ? actual_draft_token_nums[bid] - 1 : 1; + } + } + + if (seq_lens_encoder[bid] != 0) { + seq_lens_decoder[bid] = seq_lens_encoder[bid]; + seq_lens_encoder[bid] = 0; + } + + accept_num[bid] = accept_num_now; + draft_tokens[bid * max_draft_tokens] = accept_tokens[bid * max_draft_tokens + accept_num_now - 1]; + } + } + if (stop_flag_now_int) { + seq_lens_decoder[bid] = 0; + } + + __syncthreads(); + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int); + + if (threadIdx.x == 0) { + not_need_stop[0] = stop_sum < real_bsz; + } +} + +void SpeculateVerifyAndUpdate(const paddle::Tensor& accept_tokens, + const paddle::Tensor& accept_num, + const paddle::Tensor& step_idx, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& stop_flags, + const paddle::Tensor& not_need_stop, + const paddle::Tensor& draft_tokens, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& verify_tokens, + const paddle::Tensor& verify_scores, + const paddle::Tensor& max_dec_len, + const paddle::Tensor& end_tokens, + const paddle::Tensor& is_block_step, + const paddle::Tensor& output_cum_offsets, + const paddle::Tensor& actual_candidate_len, + const paddle::Tensor& actual_draft_token_nums, + const paddle::Tensor& topp, + int max_seq_len, + int verify_window, + bool enable_topp) { + auto bsz = accept_tokens.shape()[0]; + int real_bsz = seq_lens_this_time.shape()[0]; + auto max_draft_tokens = draft_tokens.shape()[1]; + auto end_length = end_tokens.shape()[0]; + auto max_candidate_len = verify_tokens.shape()[1]; + + constexpr int BlockSize = 512; + + curandState_t* dev_curand_states; + cudaMalloc(&dev_curand_states, sizeof(curandState_t) * bsz); + setup_kernel<<<1, BlockSize, 0, accept_tokens.stream()>>>(dev_curand_states, seed, offset, bsz, true); + seed++; + offset++; + + auto err = cudaDeviceSynchronize(); + if (err != 0) { + printf("err %d\n", err); + } + + err = cudaGetLastError(); + + if (err != 0) { + printf("err %d\n", err); + } + + bool use_topk = false; + char* env_var = getenv("SPECULATE_VERIFY_USE_TOPK"); + if (env_var) { + use_topk = (bool)std::stoi(env_var); + } + if (use_topk) { + if (enable_topp) { + speculate_verify_and_update_kernel + <<<1, BlockSize, 0, accept_tokens.stream()>>>(const_cast(accept_tokens.data()), + const_cast(accept_num.data()), + const_cast(step_idx.data()), + const_cast(seq_lens_encoder.data()), + const_cast(seq_lens_decoder.data()), + const_cast(stop_flags.data()), + const_cast(not_need_stop.data()), + const_cast(draft_tokens.data()), + const_cast(actual_draft_token_nums.data()), + dev_curand_states, + topp.data(), + seq_lens_this_time.data(), + verify_tokens.data(), + verify_scores.data(), + max_dec_len.data(), + end_tokens.data(), + is_block_step.data(), + output_cum_offsets.data(), + actual_candidate_len.data(), + real_bsz, + max_draft_tokens, + end_length, + max_seq_len, + max_candidate_len, + verify_window); + } else { + speculate_verify_and_update_kernel + <<<1, BlockSize, 0, accept_tokens.stream()>>>(const_cast(accept_tokens.data()), + const_cast(accept_num.data()), + const_cast(step_idx.data()), + const_cast(seq_lens_encoder.data()), + const_cast(seq_lens_decoder.data()), + const_cast(stop_flags.data()), + const_cast(not_need_stop.data()), + const_cast(draft_tokens.data()), + const_cast(actual_draft_token_nums.data()), + dev_curand_states, + topp.data(), + seq_lens_this_time.data(), + verify_tokens.data(), + verify_scores.data(), + max_dec_len.data(), + end_tokens.data(), + is_block_step.data(), + output_cum_offsets.data(), + actual_candidate_len.data(), + real_bsz, + max_draft_tokens, + end_length, + max_seq_len, + max_candidate_len, + verify_window); + } + } else { + if (enable_topp) { + speculate_verify_and_update_kernel + <<<1, BlockSize, 0, accept_tokens.stream()>>>(const_cast(accept_tokens.data()), + const_cast(accept_num.data()), + const_cast(step_idx.data()), + const_cast(seq_lens_encoder.data()), + const_cast(seq_lens_decoder.data()), + const_cast(stop_flags.data()), + const_cast(not_need_stop.data()), + const_cast(draft_tokens.data()), + const_cast(actual_draft_token_nums.data()), + dev_curand_states, + topp.data(), + seq_lens_this_time.data(), + verify_tokens.data(), + verify_scores.data(), + max_dec_len.data(), + end_tokens.data(), + is_block_step.data(), + output_cum_offsets.data(), + actual_candidate_len.data(), + real_bsz, + max_draft_tokens, + end_length, + max_seq_len, + max_candidate_len, + verify_window); + } else { + speculate_verify_and_update_kernel + <<<1, BlockSize, 0, accept_tokens.stream()>>>(const_cast(accept_tokens.data()), + const_cast(accept_num.data()), + const_cast(step_idx.data()), + const_cast(seq_lens_encoder.data()), + const_cast(seq_lens_decoder.data()), + const_cast(stop_flags.data()), + const_cast(not_need_stop.data()), + const_cast(draft_tokens.data()), + const_cast(actual_draft_token_nums.data()), + dev_curand_states, + topp.data(), + seq_lens_this_time.data(), + verify_tokens.data(), + verify_scores.data(), + max_dec_len.data(), + end_tokens.data(), + is_block_step.data(), + output_cum_offsets.data(), + actual_candidate_len.data(), + real_bsz, + max_draft_tokens, + end_length, + max_seq_len, + max_candidate_len, + verify_window); + } + } + + cudaFree(dev_curand_states); +} + +PD_BUILD_OP(speculate_verify_and_update) + .Inputs({"accept_tokens", + "accept_num", + "step_idx", + "seq_lens_encoder", + "seq_lens_decoder", + "stop_flags", + "not_need_stop", + "draft_tokens", + "seq_lens_this_time", + "verify_tokens", + "verify_scores", + "max_dec_len", + "end_tokens", + "is_block_step", + "output_cum_offsets", + "actual_candidate_len", + "actual_draft_token_nums", + "topp"}) + .Outputs({"accept_tokens_out", + "accept_num_out", + "step_idx_out", + "seq_lens_encoder_out", + "seq_lens_decoder_out", + "stop_flags_out", + "not_need_stop_out", + "draft_tokens_out"}) + .Attrs({"max_seq_len: int", "verify_window: int", "enable_topp: bool"}) + .SetInplaceMap({{"accept_tokens", "accept_tokens_out"}, + {"accept_num", "accept_num_out"}, + {"step_idx", "step_idx_out"}, + {"seq_lens_encoder", "seq_lens_encoder_out"}, + {"seq_lens_decoder", "seq_lens_decoder_out"}, + {"stop_flags", "stop_flags_out"}, + {"not_need_stop", "not_need_stop_out"}, + {"draft_tokens", "draft_tokens_out"}}) + .SetKernelFn(PD_KERNEL(SpeculateVerifyAndUpdate)); \ No newline at end of file diff --git a/csrc/gpu/speculate_decoding_kernels/top_p_candidates.cu b/csrc/gpu/speculate_decoding_kernels/top_p_candidates.cu new file mode 100644 index 000000000000..9976ccabe38e --- /dev/null +++ b/csrc/gpu/speculate_decoding_kernels/top_p_candidates.cu @@ -0,0 +1,619 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "helper.h" + +#define WARP_SIZE 32 + +template +__forceinline__ __device__ T +CudaShuffleDownSync(unsigned mask, T val, int delta, int width = warpSize) { + return __shfl_down_sync(mask, val, static_cast(delta), width); +} + +template <> +__forceinline__ __device__ phi::dtype::float16 CudaShuffleDownSync( + unsigned mask, phi::dtype::float16 val, int delta, int width) { + return paddle::float16(__shfl_down_sync( + mask, val.to_half(), static_cast(delta), width)); +} + +template <> +__forceinline__ __device__ phi::dtype::bfloat16 CudaShuffleDownSync( + unsigned mask, phi::dtype::bfloat16 val, int delta, int width) { + return paddle::bfloat16(__shfl_down_sync( + mask, val.to_nv_bfloat16(), static_cast(delta), width)); +} + +struct BlockPrefixCallbackOp { + // Running prefix + float running_total; + // Constructor + __device__ BlockPrefixCallbackOp(float running_total) + : running_total(running_total) {} + // Callback operator to be entered by the first warp of threads in the block. + // Thread-0 is responsible for returning a value for seeding the block-wide + // scan. + __device__ float operator()(float block_aggregate) { + float old_prefix = running_total; + running_total += block_aggregate; + return old_prefix; + } +}; + +#define FINAL_MASK 0xFFFFFFFF + +#define FIXED_BLOCK_DIM_BASE(dim, ...) \ + case (dim): { \ + constexpr auto kBlockDim = (dim); \ + __VA_ARGS__; \ + } break + +#define FIXED_BLOCK_DIM(...) \ + FIXED_BLOCK_DIM_BASE(1024, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_BASE(512, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__) + + + #define FIXED_TOPK_BASE(topk, ...) \ + case (topk): { \ + constexpr auto kTopK = topk; \ + __VA_ARGS__; \ + } break + + #define FIXED_TOPK(...) \ + FIXED_TOPK_BASE(2, ##__VA_ARGS__); \ + FIXED_TOPK_BASE(3, ##__VA_ARGS__); \ + FIXED_TOPK_BASE(4, ##__VA_ARGS__); \ + FIXED_TOPK_BASE(5, ##__VA_ARGS__); \ + FIXED_TOPK_BASE(8, ##__VA_ARGS__); \ + FIXED_TOPK_BASE(10, ##__VA_ARGS__) + + +struct SegmentOffsetIter { + explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {} + + __host__ __device__ __forceinline__ int operator()(int idx) const { + return idx * num_cols_; + } + + int num_cols_; +}; + +inline int div_up(int a, int n) { return (a + n - 1) / n; } + +int GetBlockSize(int vocab_size) { + if (vocab_size > 512) { + return 1024; + } else if (vocab_size > 256) { + return 512; + } else if (vocab_size > 128) { + return 256; + } else if (vocab_size > 64) { + return 128; + } else { + return 64; + } +} + +template +__global__ void FillIndex(T* indices, T num_rows, T num_cols) { + int col_id = threadIdx.x; + int row_id = blockIdx.x; + + for (T j = row_id; j < num_rows; j += gridDim.x) { + for (T i = col_id; i < num_cols; i += blockDim.x) { + indices[j * num_cols + i] = i; + } + } +} + +__global__ void SetCountIter(int* count_iter, int num) { + int tid = threadIdx.x; + int bid = blockIdx.x; + int idx = bid * blockDim.x + tid; + for (int i = idx; i < num; i += gridDim.x * blockDim.x) { + count_iter[i] = i; + } +} + + +template +__global__ void top_p_candidates_kernel(T* sorted_probs, + int64_t* sorted_id, + T* out_val, + int64_t* out_id, + int* actual_candidates_lens, + const int vocab_size, + const float topp, + const int candidates_len) { + __shared__ int stop_shared; + __shared__ float rand_p; + const int tid = threadIdx.x; + const int bid = blockIdx.x; + constexpr int NUM_WARPS = BLOCK_SIZE / 32; + const int lane_id = tid % 32; + const int warp_id = tid / 32; + + typedef cub::BlockScan BlockScan; + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockScan::TempStorage temp_storage; + __shared__ typename BlockReduce::TempStorage temp_storage_reduce; + __shared__ uint32_t selected_shared[NUM_WARPS]; + + if (lane_id == 0) { + selected_shared[warp_id] = 0; + } + + + // Initialize running total + BlockPrefixCallbackOp prefix_op(0); + + __syncthreads(); + + int offset = bid * vocab_size; + int end = ((vocab_size + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE; + int i_activate = 0; + float thread_offset = 0; + for (int i = tid; i < end; i += BLOCK_SIZE) { + float thread_count = + (i < vocab_size) ? static_cast(sorted_probs[offset + i]) : 0.f; + + BlockScan(temp_storage) + .InclusiveSum(thread_count, thread_offset, prefix_op); + + if (i < candidates_len) { + out_id[bid * candidates_len + i] = sorted_id[offset + i]; + out_val[bid * candidates_len + i] = sorted_probs[offset + i]; + } + + uint32_t activate_mask = __ballot_sync(FINAL_MASK, topp <= thread_offset); + i_activate = i; + if (activate_mask != 0 || i >= candidates_len) { + if (lane_id == 0) { + atomicAdd(&stop_shared, 1); + selected_shared[warp_id] = activate_mask; + } + } + __syncthreads(); + if (stop_shared > 0) { + break; + } + } + __syncthreads(); + bool skip = (selected_shared[warp_id] > 0) ? false : true; + for (int i = 0; i < warp_id; i++) { + if (selected_shared[i] != 0) { + // If the previous has stopped, skip the current warp + skip = true; + } + } + if (!skip) { + int active_lane_id = + WARP_SIZE - __popc(selected_shared[warp_id]); // first not 0 + if (lane_id == active_lane_id) { + actual_candidates_lens[bid] = i_activate + 1; + } + } + __syncthreads(); + if (tid == 0) { + // printf("actual_candidates_lens[%d] %d\n", bid, actual_candidates_lens[bid]); + if (actual_candidates_lens[bid] == 0) { + actual_candidates_lens[bid] = candidates_len; + } + } +} + + +template +struct Pair { + __device__ __forceinline__ Pair() {} + __device__ __forceinline__ Pair(T value, int id) : v(value), id(id) {} + + __device__ __forceinline__ void set(T value, int id) { + this->v = value; + this->id = id; + } + + __device__ __forceinline__ void operator=(const Pair& in) { + v = in.v; + id = in.id; + } + + __device__ __forceinline__ bool operator<(const T value) const { + return (static_cast(v) < static_cast(value)); + } + + __device__ __forceinline__ bool operator>(const T value) const { + return (static_cast(v) > static_cast(value)); + } + __device__ __forceinline__ bool operator<(const Pair& in) const { + return (static_cast(v) < static_cast(in.v)) || + ((static_cast(v) == static_cast(in.v)) && + (id > in.id)); + } + + __device__ __forceinline__ bool operator>(const Pair& in) const { + return (static_cast(v) > static_cast(in.v)) || + ((static_cast(v) == static_cast(in.v)) && + (id < in.id)); + } + + T v; + int id; +}; + +template +__device__ __forceinline__ void AddTo(Pair topk[], + const Pair& p, + int beam_size) { + for (int k = beam_size - 2; k >= 0; k--) { + if (topk[k] < p) { + topk[k + 1] = topk[k]; + } else { + topk[k + 1] = p; + return; + } + } + topk[0] = p; +} + + + +template +__device__ __forceinline__ void GetTopK( + Pair topk[], const T* src, int idx, int dim, int beam_size) { + while (idx < dim) { + if (topk[beam_size - 1] < src[idx]) { + Pair tmp(src[idx], idx); + AddTo(topk, tmp, beam_size); + } + idx += BlockSize; + } +} + +template +__device__ __forceinline__ void GetTopK(Pair topk[], + const T* src, + int idx, + int dim, + const Pair& max, + int beam_size) { + while (idx < dim) { + if (topk[beam_size - 1] < src[idx]) { + Pair tmp(src[idx], idx); + if (tmp < max) { + AddTo(topk, tmp, beam_size); + } + } + idx += BlockSize; + } +} + +template +__device__ __forceinline__ void ThreadGetTopK(Pair topk[], + int* beam, + int beam_size, + const T* src, + bool* firstStep, + bool* is_empty, + Pair* max, + int dim, + const int tid) { + if (*beam > 0) { + int length = (*beam) < beam_size ? *beam : beam_size; + if (*firstStep) { + *firstStep = false; + GetTopK(topk, src, tid, dim, length); + } else { + for (int k = 0; k < MaxLength; k++) { + if (k < MaxLength - (*beam)) { + topk[k] = topk[k + *beam]; + } else { + topk[k].set(std::numeric_limits::min(), -1); + } + } + if (!(*is_empty)) { + GetTopK( + topk + MaxLength - *beam, src, tid, dim, *max, length); + } + } + + *max = topk[MaxLength - 1]; + if ((*max).id == -1) *is_empty = true; + *beam = 0; + } +} + + +template +__forceinline__ __device__ Pair WarpReduce(Pair input) { +#pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) { + T tmp_val = + CudaShuffleDownSync(FINAL_MASK, input.v, offset); + int tmp_id = + CudaShuffleDownSync(FINAL_MASK, input.id, offset); + if (static_cast(input.v) < static_cast(tmp_val)) { + input.v = tmp_val; + input.id = tmp_id; + } + } + return input; +} + + + +template +__device__ __forceinline__ void BlockReduce(Pair shared_max[], + Pair topk[], + Pair beam_max[], + int* beam, + int* k, + int* count, + const int tid, + const int wid, + const int lane) { + while (true) { + __syncthreads(); + Pair input_now = topk[0]; + input_now = WarpReduce(input_now); + + if (lane == 0) { + shared_max[wid] = input_now; + } + __syncthreads(); + input_now = (tid < BlockSize / 32) + ? shared_max[lane] + : Pair(std::numeric_limits::min(), -1); + if (wid == 0) { + input_now = WarpReduce(input_now); + if (lane == 0) shared_max[0] = input_now; + } + __syncthreads(); + if (tid == 0) { + beam_max[*count] = shared_max[0]; + (*count)++; + } + int tid_max = shared_max[0].id % BlockSize; + if (tid == tid_max) { + (*beam)++; + } + if (--(*k) == 0) break; + __syncthreads(); + + if (tid == tid_max) { + if (*beam < MaxLength) { + topk[0] = topk[*beam]; + } + } + + if (MaxLength < 5) { + if (*beam >= MaxLength) break; + } else { + unsigned mask = 0u; + mask = __ballot_sync(FINAL_MASK, true); + if (tid_max / 32 == wid) { + if (__shfl_down_sync(FINAL_MASK, *beam, tid_max % 32, 32) == MaxLength) + break; + } + } + } +} + +template +__global__ void KeMatrixTopPBeamTopKFt(const T* src, + const T* top_ps, + const int* output_padding_offset, + int64_t* out_id, // [max_cadidate_len, 1] + T* out_val, // [max_cadidate_len, 1] + int* actual_candidates_lens, + int vocab_size, + const int max_cadidate_len, + const int max_seq_len) { + const int tid = threadIdx.x; + const int wid = tid / 32; + const int lane = tid % 32; + const int token_id = blockIdx.x; + const int ori_token_id = token_id + output_padding_offset[token_id]; + const int bid = ori_token_id / max_seq_len; + + int top_num = TopPBeamTopK; + float top_p_value = static_cast(top_ps[bid]); + + __shared__ Pair shared_max[BlockSize / 32]; + __shared__ Pair beam_max[TopPBeamTopK]; + + Pair topk[MaxLength]; + int beam = MaxLength; + Pair max; + bool is_empty = false; + bool firststep = true; + __shared__ int count; + + if (tid == 0) { + count = 0; + } + + for (int j = 0; j < MaxLength; j++) { + topk[j].set(std::numeric_limits::min(), -1); + } + + while (top_num) { + ThreadGetTopK(topk, + &beam, + TopPBeamTopK, + src + token_id * vocab_size, + &firststep, + &is_empty, + &max, + vocab_size, + tid); + BlockReduce( + shared_max, topk, beam_max, &beam, &top_num, &count, tid, wid, lane); + } + if (tid == 0) { + float sum_prob = 0.0f; + bool flag = false; + for(int i = 0; i < TopPBeamTopK; i++) { + out_id[token_id * max_cadidate_len + i] = static_cast(beam_max[i].id); + out_val[token_id * max_cadidate_len + i] = beam_max[i].v; + float val = static_cast(beam_max[i].v); + sum_prob += val; + + if(sum_prob >= top_p_value) { + actual_candidates_lens[token_id] = i + 1; + break; + } + } + } +} + + +template +void DispatchTopK(const T* src, + const T* top_ps, + const int* output_padding_offset, + int64_t* out_id, // topk id + T* out_val, // topk val + int* actual_candidates_lens_data, + const int vocab_size, + const int token_num, + const int cadidate_len, + const int max_seq_len, + cudaStream_t& stream) { + int BlockSize = GetBlockSize(vocab_size); + switch (cadidate_len) { + FIXED_TOPK( + switch (BlockSize) { + FIXED_BLOCK_DIM( + KeMatrixTopPBeamTopKFt + <<>>( + src, + top_ps, + output_padding_offset, + out_id, + out_val, + actual_candidates_lens_data, + vocab_size, + cadidate_len, + max_seq_len) + ); + default: + PD_THROW("the input data shape has error in the topp_beam_topk kernel."); + } + ); + default: + PD_THROW("the input topk is not implemented."); + } +} + + + +template +std::vector LaunchTopPCandidates(const paddle::Tensor& probs, // [token_num, vocab_size] + const paddle::Tensor& top_p, // [token_num] + const paddle::Tensor& output_padding_offset, + const int candidates_len, + const int max_seq_len) { + typedef PDTraits traits_; + typedef typename traits_::DataType DataType_; + typedef typename traits_::data_t data_t; + + std::vector input_shape = probs.shape(); + const int token_num = input_shape[0]; + const int vocab_size = input_shape[1]; + + auto verify_scores = paddle::full({token_num, candidates_len}, 0, D, probs.place()); + auto verify_tokens = paddle::full({token_num, candidates_len}, 0, paddle::DataType::INT64, probs.place()); + auto actual_candidate_lens = paddle::full({token_num}, 0, paddle::DataType::INT32, probs.place()); + + auto stream = probs.stream(); + + constexpr int TopKMaxLength = 2; + DispatchTopK( + reinterpret_cast(probs.data()), + reinterpret_cast(top_p.data()), + output_padding_offset.data(), + verify_tokens.data(), + reinterpret_cast(verify_scores.data()), + actual_candidate_lens.data(), + vocab_size, + token_num, + candidates_len, + max_seq_len, + stream + ); + + return {verify_scores, verify_tokens, actual_candidate_lens}; + +} + + +std::vector DispatchTopPCandidatesWithDtype(const paddle::Tensor& probs, + const paddle::Tensor& top_p, + const paddle::Tensor& output_padding_offset, + int candidates_len, + int max_seq_len) { + switch (probs.type()) { + case paddle::DataType::BFLOAT16: + return LaunchTopPCandidates(probs, top_p, output_padding_offset, candidates_len, max_seq_len); + break; + case paddle::DataType::FLOAT16: + return LaunchTopPCandidates(probs, top_p, output_padding_offset, candidates_len, max_seq_len); + break; + case paddle::DataType::FLOAT32: + return LaunchTopPCandidates(probs, top_p, output_padding_offset, candidates_len, max_seq_len); + break; + default: + PD_THROW( + "NOT supported data type. " + "Only bfloat16, float16 and float32 are supported. "); + break; + } +} + + +std::vector TopPCandidates(const paddle::Tensor& probs, + const paddle::Tensor& top_p, + const paddle::Tensor& output_padding_offset, + int candidates_len, + int max_seq_len) { + return DispatchTopPCandidatesWithDtype(probs, top_p, output_padding_offset, candidates_len, max_seq_len); +} + +std::vector> TopPCandidatesInferShape(const std::vector& probs_shape, + const std::vector& top_p_shape, + const std::vector& output_padding_offset_shape, + int max_candidates_len) { + int token_num = probs_shape[0]; + return {{token_num, max_candidates_len}, {token_num, max_candidates_len}, {token_num}}; +} + +std::vector TopPCandidatesInferDtype(const paddle::DataType& probs_dtype, + const paddle::DataType& top_p_dtype, + const paddle::DataType& output_padding_offset_dtype) { + return {probs_dtype, paddle::DataType::INT64, paddle::DataType::INT32}; +} + +PD_BUILD_OP(top_p_candidates) + .Inputs({"probs","top_p", "output_padding_offset"}) + .Outputs({"verify_scores", "verify_tokens", "actual_candidate_lens"}) + .Attrs({"candidates_len: int", "max_seq_len: int"}) + .SetKernelFn(PD_KERNEL(TopPCandidates)) + .SetInferShapeFn(PD_INFER_SHAPE(TopPCandidatesInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(TopPCandidatesInferDtype)); \ No newline at end of file diff --git a/csrc/gpu/stop_generation_multi_ends.cu b/csrc/gpu/stop_generation_multi_ends.cu index 7be2c6cf3cd1..b74ac028f28e 100644 --- a/csrc/gpu/stop_generation_multi_ends.cu +++ b/csrc/gpu/stop_generation_multi_ends.cu @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/extension.h" +#include "helper.h" #include #include #include @@ -32,16 +32,6 @@ void set_flags_multi_ends(char *str_flags, bool *res, int length) { } } -__device__ bool is_in_end(const int64_t id, const int64_t *end_ids, int length) { - bool flag = false; - for (int i = 0; i < length; i++) { - if (id == end_ids[i]) { - return true; - } - } - return flag; -} - __global__ void set_value_by_flags(const bool *stop_flags, const int64_t *end_ids, int64_t *topk_ids, bool *stop_flags_out, const int bs, int end_length) { int tid = threadIdx.x; if (tid < bs) { diff --git a/csrc/gpu/unittest/test_get_padding_offset_v2.py b/csrc/gpu/unittest/test_get_padding_offset_v2.py index 836ff8b0fb53..aabaa7e971b4 100644 --- a/csrc/gpu/unittest/test_get_padding_offset_v2.py +++ b/csrc/gpu/unittest/test_get_padding_offset_v2.py @@ -38,6 +38,8 @@ def test_get_padding_offset_v2(self): paddle.to_tensor(cum_offset), paddle.to_tensor(token_num), paddle.to_tensor(seq_lens), + None, # draft_tokens + None, # seq_lens_encoder ) print("input_ids is :\n", input_ids) diff --git a/csrc/setup_cuda.py b/csrc/setup_cuda.py index c05acaf371fb..38f585f12198 100644 --- a/csrc/setup_cuda.py +++ b/csrc/setup_cuda.py @@ -112,8 +112,10 @@ def get_gencode_flags(): "./gpu/sample_kernels/top_p_sampling_reject.cu", "./gpu/update_inputs_v2.cu", "./gpu/set_preids_token_penalty_multi_scores.cu", + "./gpu/speculate_decoding_kernels/ngram_match.cc", ] sources += find_end_files("./gpu/append_attn/template_instantiation", ".cu") +sources += find_end_files("./gpu/speculate_decoding_kernels", ".cu") nvcc_compile_args = gencode_flags update_git_submodule() @@ -151,6 +153,7 @@ def get_gencode_flags(): sources += find_end_files(fp8_auto_gen_directory, ".cu") sources += [ "gpu/fp8_gemm_with_cutlass/fp8_fp8_half_gemm.cu", + "gpu/fp8_gemm_with_cutlass/fp8_fp8_half_cuda_core_gemm.cu", "gpu/fp8_gemm_with_cutlass/fp8_fp8_fp8_dual_gemm.cu", ] diff --git a/llm/README.md b/llm/README.md index 307deec0f8f6..7c297f752aea 100644 --- a/llm/README.md +++ b/llm/README.md @@ -228,16 +228,16 @@ python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" ./alignment/dpo ```shell # PTQ 量化启动命令参考 -python run_finetune.py ./config/llama/ptq_argument.json +python run_quantization.py ./config/llama/ptq_argument.json # GPTQ 量化启动命令参考 -python run_finetune.py ./config/llama/ptq_argument.json +python run_quantization.py ./config/llama/gptq_argument.json # W8A8C8(INT)量化启动命令参考 -python run_finetune.py ./config/llama/ptq_c8_argument.json +python run_quantization.py ./config/llama/ptq_c8_argument.json # W8A8(FP8)量化启动命令参考 -python run_finetune.py ./config/llama/fp8_ptq_argument.json +python run_quantization.py ./config/llama/fp8_ptq_argument.json ``` 更多技术细节和模型量化使用详见[量化文档](./docs/quantization.md)。 diff --git a/llm/config/llama/reft_argument.json b/llm/config/llama/reft_argument.json new file mode 100644 index 000000000000..f1678a29a792 --- /dev/null +++ b/llm/config/llama/reft_argument.json @@ -0,0 +1,33 @@ +{ + "model_name_or_path": "meta-llama/Llama-2-7b", + "dataset_name_or_path": "./data", + "output_dir": "./checkpoints/reft_ckpts", + "per_device_train_batch_size": 2, + "gradient_accumulation_steps": 2, + "per_device_eval_batch_size": 8, + "eval_accumulation_steps":16, + "num_train_epochs": 2, + "learning_rate": 3e-04, + "warmup_ratio":0.01, + "logging_steps": 1, + "remove_unused_columns":false, + "evaluation_strategy": "no", + "metric_for_best_model": "no", + "save_strategy": "epoch", + "src_length": 1024, + "max_length": 512, + "bf16": true, + "fp16_opt_level": "O2", + "do_train": true, + "do_eval": false, + "disable_tqdm": true, + "load_best_model_at_end": false, + "eval_with_do_generation": false, + "save_total_limit": 1, + "tensor_parallel_degree": 1, + "pipeline_parallel_degree": 1, + "zero_padding": false, + "recompute": false, + "reft": true, + "intervention_type": "TinyIntervention" +} diff --git a/llm/config/qwen/AdvertiseGen/wfp8afp8_ptq_argument.json b/llm/config/qwen/AdvertiseGen/wfp8afp8_ptq_argument.json index ff261704b8cc..1bc207128856 100644 --- a/llm/config/qwen/AdvertiseGen/wfp8afp8_ptq_argument.json +++ b/llm/config/qwen/AdvertiseGen/wfp8afp8_ptq_argument.json @@ -17,6 +17,5 @@ "unified_checkpoint": false, "smooth": false, "weight_quant_method": "abs_max", - "act_quant_method": "abs_max", - "skip_list_names": ["down_proj"] + "act_quant_method": "abs_max" } \ No newline at end of file diff --git a/llm/config/qwen2moe/lora_argument.json b/llm/config/qwen2moe/lora_argument.json new file mode 100644 index 000000000000..47e7adb14ecd --- /dev/null +++ b/llm/config/qwen2moe/lora_argument.json @@ -0,0 +1,34 @@ +{ + "model_name_or_path": "Qwen/Qwen2-57B-A14B", + "dataset_name_or_path": "./data", + "output_dir": "./checkpoints/lora_ckpts", + "per_device_train_batch_size": 4, + "gradient_accumulation_steps": 4, + "per_device_eval_batch_size": 8, + "eval_accumulation_steps":16, + "num_train_epochs": 3, + "learning_rate": 3e-04, + "warmup_steps": 30, + "logging_steps": 1, + "evaluation_strategy": "epoch", + "save_strategy": "epoch", + "src_length": 1024, + "max_length": 2048, + "bf16": true, + "fp16_opt_level": "O2", + "do_train": true, + "do_eval": true, + "disable_tqdm": true, + "load_best_model_at_end": true, + "eval_with_do_generation": false, + "metric_for_best_model": "accuracy", + "recompute": true, + "save_total_limit": 1, + "tensor_parallel_degree": 1, + "pipeline_parallel_degree": 1, + "lora": true, + "unified_checkpoint": true, + "zero_padding": false, + "use_flash_attention": true, + "pissa": false + } diff --git a/llm/config/qwen2moe/pretrain_argument.json b/llm/config/qwen2moe/pretrain_argument.json new file mode 100644 index 000000000000..f3115a64b648 --- /dev/null +++ b/llm/config/qwen2moe/pretrain_argument.json @@ -0,0 +1,40 @@ +{ + "model_name_or_path": "Qwen/Qwen2-57B-A14B", + "tokenizer_name_or_path": "Qwen/Qwen2-57B-A14B", + "input_dir": "./data", + "output_dir": "./checkpoints/pretrain_ckpts", + "per_device_train_batch_size": 2, + "gradient_accumulation_steps": 1, + "per_device_eval_batch_size": 2, + "tensor_parallel_degree": 1, + "pipeline_parallel_degree": 1, + "sharding": "stage2", + "virtual_pp_degree": 1, + "sequence_parallel": 0, + "use_flash_attention": true, + "use_fused_rms_norm": true, + "max_seq_length": 4096, + "learning_rate": 3e-05, + "min_learning_rate": 3e-06, + "warmup_steps": 30, + "logging_steps": 1, + "max_steps": 10000, + "save_steps": 5000, + "eval_steps": 1000, + "weight_decay": 0.01, + "bf16": true, + "fp16_opt_level": "O2", + "warmup_ratio": 0.01, + "max_grad_norm": 1.0, + "dataloader_num_workers": 1, + "continue_training": 1, + "do_train": true, + "do_eval": true, + "do_predict": true, + "disable_tqdm": true, + "recompute": true, + "distributed_dataloader": 1, + "recompute_granularity": "full", + "unified_checkpoint": true, + "save_total_limit": 2 + } diff --git a/llm/config/qwen2moe/sft_argument.json b/llm/config/qwen2moe/sft_argument.json new file mode 100644 index 000000000000..c964137f2264 --- /dev/null +++ b/llm/config/qwen2moe/sft_argument.json @@ -0,0 +1,33 @@ +{ + "model_name_or_path": "Qwen/Qwen2-57B-A14B", + "dataset_name_or_path": "./data", + "output_dir": "./checkpoints/sft_ckpts", + "per_device_train_batch_size": 1, + "gradient_accumulation_steps": 4, + "per_device_eval_batch_size": 8, + "eval_accumulation_steps":16, + "num_train_epochs": 3, + "learning_rate": 3e-05, + "warmup_steps": 30, + "logging_steps": 1, + "evaluation_strategy": "epoch", + "save_strategy": "epoch", + "src_length": 1024, + "max_length": 2048, + "bf16": true, + "fp16_opt_level": "O2", + "do_train": true, + "do_eval": true, + "disable_tqdm": true, + "load_best_model_at_end": true, + "eval_with_do_generation": false, + "metric_for_best_model": "accuracy", + "recompute": true, + "save_total_limit": 1, + "tensor_parallel_degree": 1, + "pipeline_parallel_degree": 1, + "sharding": "stage2", + "zero_padding": false, + "unified_checkpoint": true, + "use_flash_attention": true + } diff --git a/llm/docs/finetune.md b/llm/docs/finetune.md index 6a17e2031447..a7ab57c8e313 100644 --- a/llm/docs/finetune.md +++ b/llm/docs/finetune.md @@ -130,6 +130,11 @@ python merge_lora_params.py \ - `neftune_noise_alpha`: NEFT alpha 参数,默认为5.0。 - `vera`: 是否开启 VeRA 微调策略,默认为 False。 - `vera_rank`: VeRA 算法中 rank(秩)的值,默认为8。 +- `use_long_sequence_strategies`: 是否使用长序列扩展策略,默认为 False。 +- `strategy_type`: 长序列扩展策略的类型,默认为 None。 +- `strategy_name`: 长序列扩展策略的具体名称,默认为 None。 +- `rope_scaling_factor`: 应用 RoPE 扩展策略时的缩放因子。 +- `lora_use_mixer`: 是否开启 MosLoRA 策略。   数据参数(DataArgument)
@@ -140,6 +145,8 @@ python merge_lora_params.py \ - `src_length`: 模型输入上下文最大 token 长度,默认为1024。 - `max_length`:模型输入(上下文+生成内容)的最大 token 长度, 默认为2048。当`zero_padding`设为 True 的时候,同时也为 Zero Padding 数据流模型训练输入最大长度,通常建议设为模型允许输入最大长度,同时`per_device_train_batch_size`设为1,使用`gradient_accumulation_steps`控制 batch size。 - `lazy`:设置为 False 则使用`MapDataset`,设置为 True 则使用`IterDataset`,默认为 False。对于数据量较大的时候建议设为 True,`IterDataset`可以避免一次性将所有数据读入内存,注意需要设置`max_steps`并且`evaluation_strategy`和`save_strategy`设为`steps` +- `autoregressive`: 是否使用自回归生成,即训练数据为无监督数据,默认为 False。 +- `use_pose_convert`: 是否使用 PoSE 算法的数据处理,默认为 False。
@@ -176,3 +183,17 @@ python merge_lora_params.py \ - `sharding_parallel_degree`: 表示分组参数切片的数据并行大小. 默认值1, 表示不启用分组参数切片的数据并行。 - `sharding`:是否使用 Paddle 的 Sharding 数据并行功能,用户的参数。支持 sharding `stage1`, `stage2` or `stage3`。其中`stage2``stage3`可以和`offload`组合使用。 + + + +  表征微调(ReFT)参数(ReftArgument)
+ +- `model_name_or_path`: 预训练模型名称或者本地的模型路径,用于热启模型和分词器,默认为 None。每个模型**支持模型权重**详见各模型目录。 +- `layers`: 干预模型的那些层,默认为 all, 干预所有层。 +- `position`: 干预哪些位置的 token,默认为 f7, 干预前7个 token。 +- `intervention_type`: 干预网络的类型,默认为 LoReftIntervention。 +- `rank`: 干预网络的低秩,默认为 8。 +- `act_fn`: 干预网络中的激活函数,默认为 linear。 +- `add_bias`: 干预网络中是否添加偏置,默认为 False。 +- `dropout`: 干预网络中的 Dropout rate,默认为 0.00。 +
diff --git a/llm/docs/predict/best_practices.md b/llm/docs/predict/best_practices.md index fdedbaf2ea61..3c5c41ff5866 100644 --- a/llm/docs/predict/best_practices.md +++ b/llm/docs/predict/best_practices.md @@ -13,6 +13,10 @@ PaddleNLP 提供了多种环境变量,用于优化推理性能和资源使用 - `FLAGS_use_cutlass_device_best_config_path`: 在 `FLAGS_CUTLASS_FP8_GEMM` 设为 True 的前提下,使用该环境变量来指定离线调优出的 fp8 gemm 配置文件。配置文件可以通过`PaddleNLP/csrc/utils/tune_cutlass_fp8_*.py`产出,该脚本会自动搜索当前输入大小下提供的最优 gemm 配置并将结果记录下来,默认产出文件为`fp8_fuse_gemm_config.json`。不同 NVIDIA GPU 和 CUDA 版本需要分别调优,SM89架构 GPU 增加 dual_gemm 调优,具体可参考`dual_gemm.py`。可选值:`tune`,开启调优;空值或`default`,使用默认配置;任意值,优先使用配置文件中的参数,若无则使用默认配置。 +- `FLAGS_cuda_core_int8_gemm`:是否开启小 Batch Int8 Gemm 优化,默认值不开启。设为1可开启,推理 A8W8模型时,平均性能会加速约40%-55%,适用于 SM>=70的显卡。 + +- `FLAGS_cuda_core_fp8_gemm`:是否开启小 Batch FP8 Gemm 优化,默认值不开启。设为1可开启,推理 FP8模型时,平均性能会加速约30%左右,适用于 SM>=89的显卡。 + **GQA 优化** - `FLAGS_use_xqa_optim`:gpa 是否开启 xqa 优化,默认值为0,表示不开启。gqa 模型(如 llama3/3.1、qwen2)设为1性能会更好。 diff --git a/llm/docs/quantization.md b/llm/docs/quantization.md index 9a507c6d1815..92353b8f3fba 100644 --- a/llm/docs/quantization.md +++ b/llm/docs/quantization.md @@ -67,31 +67,31 @@ python prepare_data_for_ptq.py ### 2.3 PTQ 量化 ```shell -python run_finetune.py ./config/llama/ptq_argument.json +python run_quantization.py ./config/llama/ptq_argument.json ``` ### 2.4 GPTQ 量化 ```shell -python run_finetune.py ./config/llama/gptq_argument.json +python run_quantization.py ./config/llama/gptq_argument.json ``` ### 2.5 AWQ 量化 ```shell -python run_finetune.py ./config/llama/awq_argument.json +python run_quantization.py ./config/llama/awq_argument.json ``` ### 2.6 W8A8C8(INT8)量化 ```shell -python run_finetune.py ./config/llama/ptq_c8_argument.json +python run_quantization.py ./config/llama/ptq_c8_argument.json ``` ### 2.7 W8A8(FP8)量化 ```shell -python run_finetune.py ./config/llama/fp8_ptq_argument.json +python run_quantization.py ./config/llama/fp8_ptq_argument.json ``` ### 2.8 量化参数介绍 diff --git a/llm/experimental/layers/cache_kv.py b/llm/experimental/layers/cache_kv.py deleted file mode 100644 index e159ae8f5096..000000000000 --- a/llm/experimental/layers/cache_kv.py +++ /dev/null @@ -1,182 +0,0 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import paddle -from paddle import ParamAttr -from paddle.nn import Layer -from paddle.nn.initializer import Constant -from paddle.nn.quant.format import ConvertibleQuantedLayer - - -class CacheKVMatMul(Layer): - def __init__(self): - super().__init__() - - def forward(self, x, y, transpose_x=False, transpose_y=False, name=None): - return paddle.matmul(x, y, transpose_x, transpose_y, name) - - -class QuantizedCacheKVMatMul(ConvertibleQuantedLayer): - def __init__(self, layer: Layer, q_config): - super().__init__() - # For FakeQuant - self.activation_quanter = None - self.weight_quanter = None - if q_config.activation is not None: - self.activation_quanter = q_config.activation._instance(layer) - - def forward(self, x, y, transpose_x=False, transpose_y=False, name=None): - # qdq - if self.activation_quanter is not None: - y = self.activation_quanter(y) - return paddle.matmul(x, y, transpose_x, transpose_y, name) - - def weights_to_quanters(self): - return [("weight", "weight_quanter")] - - def activation_quanters(self): - return ["activation_quanter"] - - -class ShiftSmoothCacheKVMatMul(Layer): - """ - The computational logic of ShiftSmoothCacheKVMatMul is the same as CacheKVMatMul. - The only difference is that its inputs are shift. - """ - - def __init__(self): - super().__init__() - self.sequence_parallel = False - self.dtype = None - - def forward( - self, - x, - y, - transpose_x=False, - transpose_y=False, - perm_x=None, - perm_y=None, - use_smooth_x=False, - use_smooth_out=False, - name=None, - sequence_parallel=False, - ): - self.sequence_parallel = sequence_parallel - # smooth - smooth_x, smooth_y = self._smooth(x, y, use_smooth_x) - # transpose - if perm_x is not None: - smooth_x = paddle.transpose(smooth_x, perm=perm_x) - if perm_y is not None: - smooth_y = paddle.transpose(smooth_y, perm=perm_y) - # matmul output - out = paddle.matmul(smooth_x, smooth_y, transpose_x, transpose_y, name) - if not use_smooth_out: - return out - else: - # combine heads - if self.sequence_parallel: - out = paddle.transpose(out, perm=[2, 0, 1, 3]) - else: - out = paddle.transpose(out, perm=[0, 2, 1, 3]) - return paddle.multiply(out, self.smooth_weight) - - def _smooth(self, x, y, use_smooth_x): - # For ShiftSmooth - smooth_shape = [1] - self.dtype = y.dtype - if not hasattr(self, "smooth_weight"): - self.smooth_weight = self.create_parameter( - shape=smooth_shape, attr=ParamAttr(initializer=Constant(value=1.0)), dtype=self.dtype - ) - smooth_y = y - smooth_y = paddle.divide(smooth_y, self.smooth_weight) - - if use_smooth_x: - smooth_x = x - x = paddle.multiply(smooth_x, self.smooth_weight) - return x, smooth_y - - def convert_weight(self, smooth_weight=None): - if smooth_weight is not None: - self.smooth_weight.set_value(smooth_weight.squeeze().cast(self.dtype)) - - -class QuantizedShiftSmoothCacheKVMatMul(ConvertibleQuantedLayer): - """ - The computational logic of QuantizedShiftSmoothCacheKVMatMul is the same as RowParallelLinear. - The only difference is that its inputs are shift. - """ - - def __init__(self, layer: Layer, q_config): - super().__init__() - - # For FakeQuant - self.weight_quanter = None - self.activation_quanter = None - self.smooth_weight = layer.smooth_weight - if q_config.activation is not None: - self.activation_quanter = q_config.activation._instance(layer) - - def forward( - self, - x, - y, - transpose_x=False, - transpose_y=False, - perm_x=None, - perm_y=None, - use_smooth_x=False, - use_smooth_out=False, - name=None, - sequence_parallel=False, - ): - # smooth - smooth_x, smooth_y = self._smooth(x, y, use_smooth_x) - # qdq - if self.activation_quanter is not None: - smooth_y = self.activation_quanter(smooth_y) - # transpose - if perm_x is not None: - smooth_x = paddle.transpose(smooth_x, perm=perm_x) - if perm_y is not None: - smooth_y = paddle.transpose(smooth_y, perm=perm_y) - # matmul output - out = paddle.matmul(smooth_x, smooth_y, transpose_x, transpose_y, name) - if not use_smooth_out: - return out - else: - # combine heads - if sequence_parallel: - out = paddle.transpose(out, perm=[2, 0, 1, 3]) - else: - out = paddle.transpose(out, perm=[0, 2, 1, 3]) - return paddle.multiply(out, self.smooth_weight) - - def _smooth(self, x, y, use_smooth_x): - # For ShiftSmooth - self.dtype = y.dtype - smooth_y = y - smooth_y = paddle.divide(smooth_y, self.smooth_weight) - - if use_smooth_x: - smooth_x = x - x = paddle.multiply(smooth_x, self.smooth_weight) - return x, smooth_y - - def weights_to_quanters(self): - return [("weight", "weight_quanter")] - - def activation_quanters(self): - return ["activation_quanter"] diff --git a/llm/experimental/layers/custom_attention.py b/llm/experimental/layers/custom_attention.py deleted file mode 100644 index c40c815b3346..000000000000 --- a/llm/experimental/layers/custom_attention.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Custome Attention Layer for quantization. -""" -# import paddle -import paddle.tensor as tensor -from paddle.nn import Layer -from paddle.nn.quant.format import ConvertibleQuantedLayer - - -class QuantizedCustomAttentionLayer(ConvertibleQuantedLayer): - """ - Quantized Custom Attention Layer. - """ - - def __init__(self, layer: Layer, q_config=None): - """ - Initialize the QuantizeWrapper class. - - Args: - layer (Layer): The layer to be quantized. - q_config (QuantConfig, optional): The quantization configuration. Defaults to None. - """ - super().__init__() - # hard code: get activation quanter from weight - self.activation_quanter_k = q_config.weight._instance(layer) - self.activation_quanter_v = q_config.activation._instance(layer) - self.layer = layer - self.enable_fake_quant = False - self.quant_info = None - layer_name = self.layer.full_name() - self.layer_id = int(layer_name.split("_")[-1]) - self.kv_losses = {} - - def forward( - self, - q, - config, - k, - v, - attention_mask, - output_attentions, - # alibi, - # attn_mask_startend_row_indices, - # sequence_parallel, - **kwargs - ): - """forward""" - if self.enable_fake_quant: - self.collect_kv_quant_policy(q, k, v, **kwargs) - perm = [0, 2, 1, 3] # [1, 2, 0, 3] if self.sequence_parallel else [0, 2, 1, 3] - tmp_k = tensor.transpose(x=k, perm=perm) - tmp_v = tensor.transpose(x=v, perm=perm) - if self.activation_quanter_k is not None: - tmp_k = self.activation_quanter_k(tmp_k) - if self.activation_quanter_v is not None: - tmp_v = self.activation_quanter_v(tmp_v) - k = tensor.transpose(x=tmp_k, perm=perm) - v = tensor.transpose(x=tmp_v, perm=perm) - return self.layer( - q, - config, - k, - v, - attention_mask, - output_attentions, - # alibi, - # attn_mask_startend_row_indices, - # sequence_parallel, - **kwargs, - ) - - def weights_to_quanters(self): - """weights to quanters""" - return [] - - def activation_quanters(self): - """activation to quanters""" - return ["activation_quanter_k", "activation_quanter_v"] diff --git a/llm/experimental/observer/abs_max.py b/llm/experimental/observer/abs_max.py deleted file mode 100644 index 9d30db49cba3..000000000000 --- a/llm/experimental/observer/abs_max.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import paddle -from paddle.quantization.factory import ObserverFactory - -from .uniform import UniformObserver - - -class AbsmaxObserver(ObserverFactory): - r""" - It collects maximum absolute values of target tensor. - Args: - bit_length(int, optional): Number of bits to represent an quantized integer in binary. - dtype(str, optional): The data type of input tensor. - name (str, optional): This parameter is used by developers to print debugging information. \ - For details, please refer to :ref:`api_guide_Name`. Default is None. - Examples: - .. code-block:: python - from paddle.quantization import QuantConfig - from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver - quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.99) - q_config = QuantConfig(activation=quanter, weight=quanter) - """ - - def __init__(self, quant_bits=8): - super(AbsmaxObserver, self).__init__(quant_bits=quant_bits) - - def _get_class(self): - return AbsmaxObserverLayer - - -class AbsmaxObserverLayer(UniformObserver): - def __init__( - self, - layer, - quant_bits=8, - ): - super(AbsmaxObserverLayer, self).__init__(quant_bits=quant_bits) - self._quant_bits = quant_bits - self._layer = layer - self._scale = None - self._zero_point = None - self._min = None - self._max = paddle.to_tensor(1e-7, dtype="float32") - self.step = 0 - - def forward(self, inputs): - """Calculate forward pass.""" - self._min, self._max = self.cal_min_max(inputs) - return inputs - - def cal_min_max(self, inputs): - abs_max_val = paddle.max(paddle.abs(inputs.cast("float32"))) - abs_max_val = paddle.maximum(abs_max_val, self._max) - return 0, abs_max_val - - def cal_thresholds(self): - """Compute thresholds for MAX function.""" - if self._scale is not None: - self._zero_point = 0 - return - self._scale, self._zero_point = self.cal_scales_zero_points() - - def min_value(self) -> float: - return self._min - - def max_value(self) -> float: - return self._max - - def bit_length(self): - """Return the bit length of quantized data.""" - return self._quant_bits - - def quant_axis(self): - """Return quantization axis.""" - return -1 - - def scales(self): - """Return output scales.""" - if self._scale is None: - self.cal_thresholds() - return self._scale - - def zero_points(self): - """Return output zero points.""" - if self._zero_point is None: - self.cal_thresholds() - return self._zero_point diff --git a/llm/experimental/observer/abs_max_headwise.py b/llm/experimental/observer/abs_max_headwise.py deleted file mode 100644 index 500fbfa1ff55..000000000000 --- a/llm/experimental/observer/abs_max_headwise.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import paddle -from experimental.observer.channel_wise import ChannelWiseObserver -from paddle.quantization.factory import ObserverFactory - - -class AbsMaxHeadwiseObserver(ObserverFactory): - r""" - It collects channel-wise maximum absolute values of target weights. - Args: - bit_length(int, optional): Number of bits to represent an quantized integer in binary. - dtype(str, optional): The data type of input tensor. - name (str, optional): This parameter is used by developers to print debugging information. \ - For details, please refer to :ref:`api_guide_Name`. Default is None. - Examples: - .. code-block:: python - from paddle.quantization import QuantConfig - from paddle.quantization.quanters import AbsMaxHeadwiseObserver - quanter = AbsMaxHeadwiseObserver() - q_config = QuantConfig(activation=None, weight=quanter) - """ - - def __init__(self, quant_bits=8, quant_axis=None): - super(AbsMaxHeadwiseObserver, self).__init__(quant_bits=quant_bits, quant_axis=quant_axis) - - def _get_class(self): - return AbsMaxHeadwiseObserverLayer - - -class AbsMaxHeadwiseObserverLayer(ChannelWiseObserver): - def __init__(self, layer, quant_bits=8, quant_axis=None): - super(AbsMaxHeadwiseObserverLayer, self).__init__( - layer, quant_bits=quant_bits, sign=True, symmetric=True, quant_axis=quant_axis - ) - self.quant_bits = quant_bits - self.calibration_loss = float("inf") - self.qmin, self.qmax = self.qmin_qmax - self._layer = layer - self._max = None - self._scale = None - self._zero_point = None - - def forward(self, inputs): - self._max = self._cal_abs_max(inputs) - return inputs - - def _cal_abs_max(self, inputs): - reduce_axis = tuple([i for i in range(len(inputs.shape)) if i != self.quant_axis()]) - abs_max_values = paddle.max(paddle.abs(inputs), axis=reduce_axis).cast("float32") - abs_max_values = paddle.where(abs_max_values == np.float32(0.0), np.float32(1e-8), abs_max_values) - - if self._max is not None: - abs_max_values = paddle.maximum(abs_max_values, self._max) - - return abs_max_values - - def min_value(self) -> float: - return 0.0 - - def max_value(self) -> float: - return self._max - - def cal_thresholds(self): - """Compute thresholds for MAX function.""" - self._scale = self._max - self._zero_point = paddle.zeros_like(self._scale) - - def scales(self): - """Return output scales.""" - if self._scale is None: - self.cal_thresholds() - return self._scale - - def zero_points(self): - """Return output zero points.""" - if self._zero_point is None: - self.cal_thresholds() - return self._zero_point diff --git a/llm/experimental/observer/avg.py b/llm/experimental/observer/avg.py deleted file mode 100644 index c38b3ec45c78..000000000000 --- a/llm/experimental/observer/avg.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import paddle -from paddle.quantization.factory import ObserverFactory - -from .uniform import UniformObserver - - -class AVGObserver(ObserverFactory): - r""" - It collects maximum absolute values of target tensor. - Args: - bit_length(int, optional): Number of bits to represent an quantized integer in binary. - dtype(str, optional): The data type of input tensor. - name (str, optional): This parameter is used by developers to print debugging information. \ - For details, please refer to :ref:`api_guide_Name`. Default is None. - Examples: - .. code-block:: python - from paddle.quantization import QuantConfig - from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver - quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.99) - q_config = QuantConfig(activation=quanter, weight=quanter) - """ - - def __init__(self, quant_bits=8): - super(AVGObserver, self).__init__(quant_bits=quant_bits) - - def _get_class(self): - return AVGObserverLayer - - -class AVGObserverLayer(UniformObserver): - def __init__( - self, - layer, - quant_bits=8, - ): - super(AVGObserverLayer, self).__init__(quant_bits=quant_bits) - self._quant_bits = quant_bits - self._avg_list = [] - - def forward(self, inputs): - """Calculate forward pass.""" - self._scale = None - self._zero_point = None - self._min = None - self._max = None - self._avg_min, self._avg_max = self.cal_min_max(inputs) - self._avg_list.append(self._avg_max) - - return inputs - - def cal_min_max(self, inputs): - abs_avg_value = paddle.abs(inputs.reshape((inputs.shape[0], -1))) - abs_avg_value = float(paddle.mean(paddle.max(abs_avg_value, axis=(1)))) - return 0, abs_avg_value - - def cal_thresholds(self): - """Compute thresholds for MAX function.""" - if self._scale is not None: - self._zero_point = 0 - return - self._min, self._max = self._avg_min, paddle.mean(paddle.to_tensor(self._avg_list)) - self._scale, self._zero_point = self.cal_scales_zero_points() - - def min_value(self) -> float: - return self._min - - def max_value(self) -> float: - return self._max - - def bit_length(self): - """Return the bit length of quantized data.""" - return self._quant_bits - - def quant_axis(self): - """Return quantization axis.""" - return -1 - - def scales(self): - """Return output scales.""" - if self._scale is None: - self.cal_thresholds() - return self._scale - - def zero_points(self): - """Return output zero points.""" - if self._zero_point is None: - self.cal_thresholds() - return self._zero_point diff --git a/llm/experimental/observer/avg_headwise.py b/llm/experimental/observer/avg_headwise.py deleted file mode 100644 index a25fbd770019..000000000000 --- a/llm/experimental/observer/avg_headwise.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import paddle -from paddle.quantization.factory import ObserverFactory - -from .abs_max_headwise import AbsMaxHeadwiseObserverLayer - - -class AvgHeadwiseObserver(ObserverFactory): - r""" - It collects channel-wise maximum absolute values of target weights. - Args: - bit_length(int, optional): Number of bits to represent an quantized integer in binary. - dtype(str, optional): The data type of input tensor. - name (str, optional): This parameter is used by developers to print debugging information. \ - For details, please refer to :ref:`api_guide_Name`. Default is None. - Examples: - .. code-block:: python - from paddle.quantization import QuantConfig - from paddle.quantization.quanters import AbsMaxHeadwiseObserver - quanter = AbsMaxHeadwiseObserver() - q_config = QuantConfig(activation=None, weight=quanter) - """ - - def __init__(self, quant_bits=8, quant_axis=None, moving_avg=False): - super(AvgHeadwiseObserver, self).__init__(quant_bits=quant_bits, quant_axis=quant_axis, moving_avg=moving_avg) - - def _get_class(self): - return AvgHeadwiseObserverLayer - - -class AvgHeadwiseObserverLayer(AbsMaxHeadwiseObserverLayer): - def __init__(self, layer, quant_bits=8, quant_axis=None, moving_avg=True): - super(AvgHeadwiseObserverLayer, self).__init__(layer, quant_bits=quant_bits, quant_axis=quant_axis) - self.quant_bits = quant_bits - self._qmin, self._qmax = self.qmin_qmax - self._max = None - self._scale = None - self._zero_point = None - if quant_axis is not None: - self._channel_axis = quant_axis - self._current_iters = 0 - self._range_update_factor_min = 0.001 - self._moving_avg = moving_avg - self.observer_enabled = True - - def forward(self, inputs, quant_axis=None): - if self.observer_enabled: - if quant_axis is not None: - self._channel_axis = quant_axis - self._max = self._cal_abs_max(inputs) - return inputs - - def _cal_abs_max(self, inputs): - self._current_iters += 1 - reduce_axis = tuple([i for i in range(len(inputs.shape)) if i != self.quant_axis()]) - abs_max_values = paddle.max(paddle.abs(inputs), axis=reduce_axis).cast("float32") - abs_max_values = paddle.where(abs_max_values == np.float32(0.0), np.float32(1e-8), abs_max_values) - if self._max is not None: - if self._moving_avg: - # exponential moving average update - update_factor = 1.0 / self._current_iters - update_factor = max(update_factor, self._range_update_factor_min) - abs_max_values = self._max * (1 - update_factor) + abs_max_values * update_factor - else: - # normal average - abs_max_values = (self._max * (self._current_iters - 1) + abs_max_values) / self._current_iters - return abs_max_values - - def min_value(self) -> float: - return 0.0 - - def max_value(self) -> float: - return self._max - - def cal_thresholds(self): - """Compute thresholds for MAX function.""" - if self._scale is not None: - self._zero_point = paddle.zeros_like(self._scale) - return - self._scale = self._max - self._zero_point = paddle.zeros_like(self._scale) - - def scales(self): - """Return output scales.""" - self.cal_thresholds() - return self._scale - - def zero_points(self): - """Return output zero points.""" - self.cal_thresholds() - return self._zero_point diff --git a/llm/experimental/observer/channel_wise.py b/llm/experimental/observer/channel_wise.py deleted file mode 100644 index 883a74a8f9b0..000000000000 --- a/llm/experimental/observer/channel_wise.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Dict - -import paddle -from experimental.layers.cache_kv import CacheKVMatMul -from paddleslim.quant.observers.uniform import UniformObserver - -CHANNEL_AXIS: Dict[type, int] = { - paddle.nn.Conv2D: 0, - paddle.nn.Linear: 1, - paddle.distributed.fleet.meta_parallel.ColumnParallelLinear: 1, - paddle.distributed.fleet.meta_parallel.RowParallelLinear: 1, - CacheKVMatMul: 1, -} - - -class ChannelWiseObserver(UniformObserver): - def __init__(self, layer, quant_bits=8, sign=True, symmetric=True, quant_axis=None): - super(ChannelWiseObserver, self).__init__( - quant_bits=quant_bits, - sign=sign, - symmetric=symmetric, - ) - if quant_axis is not None: - self._channel_axis = quant_axis - else: - assert type(layer) in CHANNEL_AXIS, "Unsupported layer type: {}".format(type(layer)) - self._channel_axis = CHANNEL_AXIS[type(layer)] - self._quant_bits = quant_bits - - def quant_axis(self): - """Return quantization axis.""" - return self._channel_axis - - def bit_length(self): - """Return the bit length of quantized data.""" - return self._quant_bits diff --git a/llm/experimental/observer/uniform.py b/llm/experimental/observer/uniform.py deleted file mode 100644 index 6c8882f5142f..000000000000 --- a/llm/experimental/observer/uniform.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import abc -from typing import Tuple - -import numpy as np -from paddle.quantization.base_observer import BaseObserver - - -class UniformObserver(BaseObserver): - """This is the base class for a uniform quantization observer, which provides - common functions for calculating the scale and zero-point used in uniform quantization. - Uniform quantization maps floating point values to integers, where the scale determines - the step size of the quantizer and the floating point zero is mapped to the zero-point, - an integer value ensuring that zero is quantized without error. - - Args: - quant_bits (int): The number of bits for quantization. - sign (bool): Whether the quantized integer includes a sign. - symmetric (bool): Whether it is symmetric quantization. the quantization is symmetric. - In symmetric quantization, the range of floating point values is relaxed to be symmetric - around zero and the zero-point is always 0. - - """ - - def __init__( - self, - quant_bits=8, - sign=True, - symmetric=True, - ): - super(UniformObserver, self).__init__() - self._quant_bits = quant_bits - self._sign = sign - self._symmetric = symmetric - - self._min = None - self._max = None - self._qmin = None - self._qmax = None - - self._scale = None - self._zero_point = None - - @property - def qmin_qmax(self): - """Calculate the range of the quantized integer based on the specified - quant_bits, sign, and symmetric properties.""" - if isinstance(self._quant_bits, tuple): - if self._quant_bits[0] == 4 and self._quant_bits[1] == 3 and len(self._quant_bits) == 2: - self._qmin = -448.0 - self._qmax = 448.0 - elif self._quant_bits[0] == 5 and self._quant_bits[1] == 2 and len(self._quant_bits) == 2: - self._qmin = -57344.0 - self._qmax = 57344.0 - else: - raise NotImplementedError( - "Currently, only float8_e4m3 and float8_e5m2 formats are supported. Please set quant_bits to (4,3) or (5,2) for the corresponding format." - ) - else: - if self._sign: - self._qmin = -(2 ** (self.bit_length() - 1)) - self._qmax = 2 ** (self.bit_length() - 1) - 1 - else: - self._qmin = 0 - self._qmax = 2 ** self.bit_length() - return self._qmin, self._qmax - - @abc.abstractmethod - def min_value(self) -> float: - """The minimum value of floating-point numbers.""" - raise NotImplementedError( - "Please implement the abstract method to get the The minimum value of floating-point numbers." - ) - - @abc.abstractmethod - def max_value(self) -> float: - """The maximum value of floating-point numbers.""" - raise NotImplementedError( - "Please implement the abstract method to get the the maximum value value of floating-point numbers." - ) - - def cal_scales_zero_points(self) -> Tuple[float, float]: - """Calculate the scales and zero points based on the min_value and max_value.""" - assert self.min_value() is not None and self.max_value() is not None - _qmin, _qmax = self.qmin_qmax - # For one-sided distributions, the range (_min , _max ) is relaxed to include zero. - # It is important to ensure that common operations like zero padding do not cause quantization errors. - _min = min(self.min_value(), 0.0) - _max = max(self.max_value(), 0.0) - - if self._symmetric: - self._scale = max(-_min, _max) - if self._sign: - self._zero_point = 0 - else: - self._zero_point = (_qmax + _qmin) / 2 - else: - self._scale = (_max - _min) / float(_qmax - _qmin) - self._zero_point = _qmin - round(_min / self._scale) - self._zero_point = np.clip(self._zero_point, _qmin, _qmax) - return self._scale, self._zero_point diff --git a/llm/predict/export_model.py b/llm/predict/export_model.py index cb09d65174b0..774b69091c7a 100644 --- a/llm/predict/export_model.py +++ b/llm/predict/export_model.py @@ -18,12 +18,11 @@ import paddle from paddle.distributed import fleet +from predict.predictor import ModelArgument, PredictorArgument, create_predictor from paddlenlp.trainer import PdArgumentParser from paddlenlp.trl import llm_utils -from .predictor import ModelArgument, PredictorArgument, create_predictor - @dataclass class ExportArgument: @@ -68,6 +67,7 @@ def main(): "dtype": predictor_args.dtype, "export_precache": predictor_args.export_precache, "cachekv_int8_type": predictor_args.cachekv_int8_type, + "speculate_method": predictor_args.speculate_method, }, ) add_inference_args_to_config(predictor.model.config, predictor_args) diff --git a/llm/predict/predictor.py b/llm/predict/predictor.py index 4294b7b60ab3..8e8770f138d8 100644 --- a/llm/predict/predictor.py +++ b/llm/predict/predictor.py @@ -27,6 +27,7 @@ from paddle.base.framework import in_cinn_mode, in_pir_executor_mode, use_pir_api from paddle.distributed import fleet +from paddlenlp.experimental.transformers import InferenceWithReferenceProposer from paddlenlp.generation import GenerationConfig, TextIteratorStreamer from paddlenlp.peft import LoRAConfig, LoRAModel, PrefixConfig, PrefixModelForCausalLM from paddlenlp.taskflow.utils import static_mode_guard @@ -44,12 +45,10 @@ PretrainedTokenizer, ) from paddlenlp.trl import llm_utils +from paddlenlp.utils.env import MAX_BSZ, MAX_DRAFT_TOKENS, SPECULATE_MAX_BSZ from paddlenlp.utils.import_utils import is_paddlenlp_ops_available from paddlenlp.utils.log import logger -# Note(@RochardWooSJTU): MAX_BSZ must be the same as definition in get_output / save_output -MAX_BSZ = 512 - @dataclass class PredictorArgument: @@ -138,8 +137,25 @@ class PredictorArgument: total_max_length: int = field( default=4096, metadata={"help": "Super parameter. Maximum sequence length(encoder+decoder)."} ) + speculate_method: str = field( + default=None, + metadata={ + "help": "speculate method, it should be one of ['None', 'autoregressive', 'inference_with_reference']" + }, + ) + speculate_max_draft_token_num: int = field( + default=1, + metadata={"help": "the max length of draft tokens for speculate method."}, + ) + speculate_max_ngram_size: int = field(default=1, metadata={"help": "the max ngram size of speculate method."}) + speculate_verify_window: int = field( + default=2, metadata={"help": "the max length of verify window for speculate method."} + ) + speculate_max_candidate_len: int = field(default=5, metadata={"help": "the max length of candidate tokens."}) def __post_init__(self): + if self.speculate_method is not None: + self.append_attn = True if self.append_attn: self.block_attn = True assert ( @@ -946,6 +962,29 @@ def _preprocess(self, input_text: list[str]): ) self.model_inputs["next_tokens"] = paddle.full(shape=[self.config.batch_size, 1], fill_value=-1, dtype="int64") + # speculative decoding related parameters + if self.config.speculate_method is not None: + self.model_inputs["accept_tokens"] = paddle.full( + shape=[self.config.batch_size, self.config.speculate_max_draft_token_num + 1], + fill_value=0, + dtype="int64", + ) + self.model_inputs["accept_num"] = paddle.full(shape=[self.config.batch_size], fill_value=0, dtype="int32") + self.model_inputs["draft_tokens"] = paddle.full( + shape=[self.config.batch_size, self.config.speculate_max_draft_token_num + 1], + fill_value=0, + dtype="int64", + ) + self.model_inputs["actual_draft_token_num"] = paddle.full( + shape=[self.config.batch_size], fill_value=self.config.speculate_max_draft_token_num, dtype="int32" + ) + + self.proposer.input_ids_cpu = self.model_inputs["input_ids"].to("cpu", blocking=False) + for bid in range(self.config.batch_size): + self.model_inputs["pre_ids"][bid, 0] = self.model_inputs["input_ids"][bid][ + seq_lens[bid] - 1 + ] # get the last token before padding of this batch + if self.config.mode == "static": for k, v in self.model_inputs.items(): v.name = k @@ -977,6 +1016,17 @@ def __init__( self.model_inputs["cache_kvs"] = self.cache_kvs + # init speculate components + if config.speculate_method == "inference_with_reference": + self.proposer = InferenceWithReferenceProposer( + config.speculate_max_draft_token_num, + config.speculate_max_ngram_size, + config.batch_size, + config.max_length, + ) + else: + self.proposer = None + @paddle.no_grad() def _infer(self, inputs: dict[str, paddle.Tensor]): self.model.generate( @@ -990,18 +1040,35 @@ def predict(self, input_texts: list[str], return_tokens=False): result_queue = mp.Queue() tensor_queue = mp.Queue() done_event = mp.Event() + + # whether speculative decoding + if self.proposer is None: + read_res_func = llm_utils.read_res + output_tensor_shape = [MAX_BSZ + 2, 1] + else: + read_res_func = llm_utils.speculate_read_res + output_tensor_shape = [SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2, 1] + read_res_process = mp.Process( - target=llm_utils.read_res, args=[self.model_name_or_path, tensor_queue, result_queue, done_event] + target=read_res_func, args=[self.model_name_or_path, tensor_queue, result_queue, done_event] ) if self.tensor_parallel_rank == 0: read_res_process.start() - output_tensor = paddle.full(shape=[MAX_BSZ + 2, 1], fill_value=2, dtype="int64").cpu() + output_tensor = paddle.full(shape=output_tensor_shape, fill_value=2, dtype="int64").cpu() + tensor_queue.put(output_tensor) if self.tensor_parallel_rank == 0: done_event.wait() s_time = time.time() while self.model_inputs["not_need_stop"]: + # whether speculative decoding + if self.proposer is not None: + self.proposer.run( + self.model_inputs, + real_batch_size=self.batch_size, + seq_lens_this_time=self.model_inputs["seq_lens_this_time"], + ) self._infer(self.model_inputs) logger.info(f"running spend {time.time() - s_time}") @@ -1055,6 +1122,17 @@ def __init__( self.model_inputs["k_dequant_scales_" + str(i)] = self.k_dequant_scales[i] self.model_inputs["v_dequant_scales_" + str(i)] = self.v_dequant_scales[i] + # init speculate components + if config.speculate_method == "inference_with_reference": + self.proposer = InferenceWithReferenceProposer( + config.speculate_max_draft_token_num, + config.speculate_max_ngram_size, + config.batch_size, + config.max_length, + ) + else: + self.proposer = None + def _create_predictor(self, predictor_args: PredictorArgument): if not is_paddlenlp_ops_available(): raise ValueError( @@ -1120,18 +1198,34 @@ def predict(self, input_texts: list[str], return_tokens=False): tensor_queue = mp.Queue() done_event = mp.Event() + # whether speculative decoding + if self.proposer is None: + read_res_func = llm_utils.read_res + output_tensor_shape = [MAX_BSZ + 2, 1] + else: + read_res_func = llm_utils.speculate_read_res + output_tensor_shape = [SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2, 1] + read_res_process = mp.Process( - target=llm_utils.read_res, args=[self.model_name_or_path, tensor_queue, result_queue, done_event] + target=read_res_func, args=[self.model_name_or_path, tensor_queue, result_queue, done_event] ) - if self.tensor_parallel_rank == 0: read_res_process.start() - output_tensor = paddle.full(shape=[MAX_BSZ + 2, 1], fill_value=2, dtype="int64").cpu() + + output_tensor = paddle.full(shape=output_tensor_shape, fill_value=2, dtype="int64").cpu() + tensor_queue.put(output_tensor) if self.tensor_parallel_rank == 0: done_event.wait() s_time = time.time() while self.model_inputs["not_need_stop"]: + # whether speculative decoding + if self.proposer is not None: + self.proposer.run( + self.model_inputs, + real_batch_size=self.batch_size, + seq_lens_this_time=self.model_inputs["seq_lens_this_time"], + ) self.predictor.run(list(self.model_inputs.values())) logger.info(f"running spend {time.time() - s_time}") diff --git a/llm/predict/reft_predictor.py b/llm/predict/reft_predictor.py new file mode 100644 index 000000000000..398b96b7b48e --- /dev/null +++ b/llm/predict/reft_predictor.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import json +import os +from functools import partial +from types import SimpleNamespace + +import paddle +from utils.data import convert_example_for_reft + +from paddlenlp.datasets import load_dataset +from paddlenlp.peft.reft import ReFTModel, do_predict +from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer + +device = "gpu" if paddle.is_compiled_with_cuda() else "cpu" + + +def get_intervention_info(reft_config_file): + with open(os.path.join(reft_config_file, "config.json"), "r") as f: + intervention_info = json.load(f) + intervention_info["num_interventions"] = len(intervention_info["representations"]) + return intervention_info + + +def reft_predict(predictor_args): + intervention_info = get_intervention_info(predictor_args.reft_path) + tokenizer = AutoTokenizer.from_pretrained( + predictor_args.model_name_or_path, + padding_side="right", + ) + tokenizer.pad_token_id = tokenizer.eos_token_id + dev_ds = load_dataset( + "json", + data_files=os.path.join(predictor_args.dataset_name_or_path, "dev.json"), + )[0] + trans_func = partial( + convert_example_for_reft, + tokenizer=tokenizer, + data_args=SimpleNamespace( + **{ + "max_length": predictor_args.max_length, + "src_length": predictor_args.src_length, + "autoregressive": False, + } + ), + positions=intervention_info["position"], + num_interventions=intervention_info["num_interventions"], + ) + + dev_ds = dev_ds.map(partial(trans_func, is_test=True, zero_padding=False, flash_mask=False)) + + model = AutoModelForCausalLM.from_pretrained(predictor_args.model_name_or_path, dtype=paddle.bfloat16) + reft_model = ReFTModel.from_pretrained(predictor_args.reft_path, model) + do_predict( + intervenable=reft_model, + tokenizer=tokenizer, + eval_dataset=dev_ds, + batch_size=predictor_args.batch_size, + predict_path=predictor_args.output_file, + num_beams=predictor_args.num_beams, + max_length=predictor_args.max_length, + ) + + +def get_pred_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("--model_name_or_path", type=str, help="The base model name or path") + parser.add_argument("--reft_path", type=str, help="The reft model path") + parser.add_argument("--output_file", type=str, help="The output file path") + parser.add_argument("--batch_size", type=int, help="The batch size in prediction") + parser.add_argument("--dataset_name_or_path", type=str, help="The dataset name or path") + parser.add_argument("--max_length", type=int, default=1024, help="The maximum length of input sequences") + parser.add_argument("--src_length", type=int, default=512, help="The source sequence length") + parser.add_argument("--num_beams", type=int, default=4, help="The maximum length of input sequences") + return parser.parse_args() + + +def main(): + predictor_args = get_pred_parser() + reft_predict(predictor_args) + + +if __name__ == "__main__": + main() diff --git a/llm/run_finetune.py b/llm/run_finetune.py index 6711a328854d..3eb70dbab6ae 100644 --- a/llm/run_finetune.py +++ b/llm/run_finetune.py @@ -11,7 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# import inspect import json +import logging import os import sys from functools import partial @@ -21,10 +23,10 @@ DataArgument, GenerateArgument, ModelArgument, - QuantArgument, + ReftArgument, TrainingArguments, ) -from utils.data import get_convert_example +from utils.data import convert_example_for_reft, get_convert_example from paddlenlp.data import DataCollatorForSeq2Seq from paddlenlp.datasets import ( @@ -41,7 +43,13 @@ VeRAConfig, VeRAModel, ) -from paddlenlp.trainer import PdArgumentParser, get_last_checkpoint +from paddlenlp.peft.reft import ( + ReFTConfig, + ReftDataCollator, + ReFTModel, + intervention_mapping, +) +from paddlenlp.trainer import PdArgumentParser, get_last_checkpoint, set_seed from paddlenlp.trainer.trainer_callback import TrainerState from paddlenlp.transformers import ( AutoConfig, @@ -75,24 +83,19 @@ def main(): - parser = PdArgumentParser((GenerateArgument, QuantArgument, ModelArgument, DataArgument, TrainingArguments)) + parser = PdArgumentParser((GenerateArgument, ModelArgument, ReftArgument, DataArgument, TrainingArguments)) if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"): - gen_args, quant_args, model_args, data_args, training_args = parser.parse_json_file_and_cmd_lines() + gen_args, model_args, reft_args, data_args, training_args = parser.parse_json_file_and_cmd_lines() else: - gen_args, quant_args, model_args, data_args, training_args = parser.parse_args_into_dataclasses() + gen_args, model_args, reft_args, data_args, training_args = parser.parse_args_into_dataclasses() training_args.print_config(model_args, "Model") training_args.print_config(data_args, "Data") - training_args.print_config(quant_args, "Quant") training_args.print_config(gen_args, "Generation") - if sum([quant_args.do_ptq, quant_args.do_qat, quant_args.do_gptq, training_args.do_train]) > 1: - raise ValueError( - "--do_train, --do_ptq, --do_gptq and --do_qat cannot work at the same time. Please choose only one at a time" - ) - # Setup GPU & distributed training paddle.set_device(training_args.device) + set_seed(seed=training_args.seed) logger.warning( f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, " + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16 or training_args.bf16}" @@ -160,6 +163,22 @@ def main(): model_config.seq_length = data_args.max_length + # Config for model useing long sequence strategy + if model_args.use_long_sequence_strategies: + data_args.scaled_max_length = int(data_args.max_length * model_args.rope_scaling_factor) + model_config.use_long_sequence_strategies = True + model_config.long_sequence_strategy_type = model_args.strategy_type + model_config.long_sequence_strategy_name = model_args.strategy_name + model_config.rope_scaling_factor = model_args.rope_scaling_factor + model_config.long_sequence_init_args = { + "dim": int(model_config.hidden_size / model_config.num_attention_heads), + "max_position_embeddings": data_args.scaled_max_length, # extended context window + "base": model_config.rope_theta, + "scaling_factor": model_args.rope_scaling_factor, + } + if model_args.strategy_name == "YaRNScalingRotaryEmbedding": + model_config.long_sequence_init_args["original_max_position_embeddings"] = data_args.max_length + logger.info(f"Final model config: {model_config}") model_class = AutoModelForCausalLM @@ -210,6 +229,15 @@ def neft_post_hook(module, input, output): ) # Load tokenizer & dataset tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, from_aistudio=model_args.from_aistudio) + if model_args.reft: + # reft requires padding side right + tokenizer.padding_side = "right" + layers = reft_args.layers + if reft_args.layers != "all": + layers = [int(l) for l in layers.split(";")] + else: + layers = [l for l in range(model_config.num_hidden_layers)] + logging.info("Using ReFT with layers: ", layers) # init chat_template for tokenizer init_chat_template(tokenizer, model_args.model_name_or_path, data_args.chat_template) @@ -222,12 +250,10 @@ def neft_post_hook(module, input, output): if data_args.dataset_name_or_path is None: raise ValueError(f"Please specific dataset name or path (got {data_args.dataset_name_or_path})") - elif ( - os.path.exists(os.path.join(data_args.dataset_name_or_path, "train.json")) - or os.path.exists(os.path.join(data_args.dataset_name_or_path, "dev.json")) - or os.path.exists(os.path.join(data_args.dataset_name_or_path, "quant.json")) + elif os.path.exists(os.path.join(data_args.dataset_name_or_path, "train.json")) or os.path.exists( + os.path.join(data_args.dataset_name_or_path, "dev.json") ): - if training_args.do_train or quant_args.do_qat: + if training_args.do_train: train_ds = load_dataset( "json", data_files=os.path.join(data_args.dataset_name_or_path, "train.json"), @@ -243,36 +269,13 @@ def neft_post_hook(module, input, output): )[0] else: dev_ds = None - if quant_args.do_ptq or quant_args.do_gptq or quant_args.load_quant_model: - if os.path.exists(os.path.join(data_args.dataset_name_or_path, "quant.json")): - ptq_ds = load_dataset( - "json", - data_files=os.path.join(data_args.dataset_name_or_path, "quant.json"), - lazy=data_args.lazy, - )[0] - elif os.path.exists(os.path.join(data_args.dataset_name_or_path, "train.json")): - ptq_ds = load_dataset( - "json", - data_files=os.path.join(data_args.dataset_name_or_path, "train.json"), - lazy=data_args.lazy, - )[0] - logger.info( - f"Not found quant.json in {data_args.dataset_name_or_path}. Set train dataset as PTQ calibration dataset." - ) - else: - raise ValueError( - f"Quant strategy requires quant.json or train.json in {data_args.dataset_name_or_path}" - ) - else: - ptq_ds = None - elif ( - os.path.exists(os.path.join(data_args.dataset_name_or_path, "train")) - or os.path.exists(os.path.join(data_args.dataset_name_or_path, "dev")) - or os.path.exists(os.path.join(data_args.dataset_name_or_path, "quant")) + + elif os.path.exists(os.path.join(data_args.dataset_name_or_path, "train")) or os.path.exists( + os.path.join(data_args.dataset_name_or_path, "dev") ): import glob - if training_args.do_train or quant_args.do_qat: + if training_args.do_train: train_ds = load_dataset( "json", data_files=glob.glob(os.path.join(data_args.dataset_name_or_path, "train", "*.json")), @@ -288,28 +291,9 @@ def neft_post_hook(module, input, output): )[0] else: dev_ds = None - if quant_args.do_ptq or quant_args.do_gptq or quant_args.load_quant_model: - if os.path.exists(os.path.join(data_args.dataset_name_or_path, "quant")): - ptq_ds = load_dataset( - "json", - data_files=glob.glob(os.path.join(data_args.dataset_name_or_path, "quant", "*.json")), - lazy=data_args.lazy, - )[0] - elif os.path.exists(os.path.join(data_args.dataset_name_or_path, "train")): - ptq_ds = load_dataset( - "json", - data_files=glob.glob(os.path.join(data_args.dataset_name_or_path, "train", "*.json")), - lazy=data_args.lazy, - )[0] - logger.info( - f"Not found quant.json in {data_args.dataset_name_or_path}. Set train dataset as PTQ calibration dataset." - ) - else: - raise ValueError(f"Quant strategy requires quant or train folder in {data_args.dataset_name_or_path}") - else: - ptq_ds = None + else: - if training_args.do_train or quant_args.do_qat: + if training_args.do_train: train_ds = load_dataset(data_args.dataset_name_or_path, splits=["train"])[0] else: train_ds = None @@ -317,11 +301,7 @@ def neft_post_hook(module, input, output): dev_ds = load_dataset(data_args.dataset_name_or_path, splits=["dev"])[0] else: dev_ds = None - if quant_args.do_ptq or quant_args.do_gptq or quant_args.load_quant_model: - ptq_ds = load_dataset(data_args.dataset_name_or_path, splits=["train"])[0] - logger.info("Set train dataset as PTQ calibration dataset.") - else: - ptq_ds = None + # TODO(ZHUI & sijunhe): Temporary implementation. Generalize this logic and move to Trainer later. if training_args.resume_from_checkpoint is not None and data_args.lazy: logger.info( @@ -347,6 +327,14 @@ def neft_post_hook(module, input, output): from utils.data import convert_example_common trans_func = partial(convert_example_common, tokenizer=tokenizer, data_args=data_args) + elif model_args.reft: + trans_func = partial( + convert_example_for_reft, + tokenizer=tokenizer, + data_args=data_args, + positions=reft_args.position, + num_interventions=len(layers), + ) else: trans_func = partial(get_convert_example(model), tokenizer=tokenizer, data_args=data_args) @@ -357,13 +345,7 @@ def neft_post_hook(module, input, output): if train_ds is not None else None ) - ptq_ds = ( - ptq_ds.map( - partial(trans_func, is_test=False, zero_padding=data_args.zero_padding, flash_mask=model_args.flash_mask) - ) - if ptq_ds is not None - else None - ) + eval_zero_padding = data_args.zero_padding if data_args.zero_padding and data_args.eval_with_do_generation: logger.warning( @@ -398,16 +380,6 @@ def neft_post_hook(module, input, output): if train_ds is not None else None ) - ptq_ds = ( - intoken_dataset( - ptq_ds, - tokenizer=tokenizer, - max_length=data_args.max_length, - greedy_zero_padding=data_args.greedy_zero_padding, - ) - if ptq_ds is not None - else None - ) if eval_zero_padding: dev_ds = ( @@ -464,9 +436,9 @@ def neft_post_hook(module, input, output): merge_weights=False, tensor_parallel_degree=training_args.tensor_parallel_degree, dtype=dtype, - do_qat=quant_args.do_qat, base_model_name_or_path=model_args.model_name_or_path, use_quick_lora=model_args.use_quick_lora, + lora_use_mixer=model_args.lora_use_mixer, ) model = LoRAModel(model, lora_config) else: @@ -474,6 +446,35 @@ def neft_post_hook(module, input, output): model.print_trainable_parameters() + if model_args.reft: + intervention_dtype = dtype + intervention_params = { + "embed_dim": model_config.hidden_size, + "low_rank_dimension": reft_args.rank, + "dropout": reft_args.dropout, + "dtype": intervention_dtype, + "act_fn": reft_args.act_fn, + "device": "gpu", + "add_bias": reft_args.add_bias, + } + representations = [ + { + "layer": l, + "component": "block_output", + "low_rank_dimension": reft_args.rank, + "intervention": intervention_mapping[reft_args.intervention_type](**intervention_params), + } + for l in layers + ] + reft_config = ReFTConfig( + representations=representations, intervention_params=intervention_params, position=reft_args.position + ) + # get reft model + model = ReFTModel(reft_config, model) + # disable origianl model gradients + model.disable_model_gradients() + model.print_trainable_parameters() + def compute_metrics_do_generation(eval_preds): rouge1 = Rouge1() rouge2 = Rouge2() @@ -541,6 +542,15 @@ def compute_metrics_do_generation(eval_preds): else: metrics = compute_metrics + data_collator_fn = DataCollatorForSeq2Seq( + tokenizer=tokenizer, + max_length=max_length, + padding=padding, + max_label_length=max_length, + return_tensors="np", + return_attention_mask=not model_args.flash_mask, + pad_to_multiple_of=data_args.pad_to_multiple_of, + ) trainer = SFTTrainer( model=model, args=training_args, @@ -548,15 +558,7 @@ def compute_metrics_do_generation(eval_preds): eval_dataset=dev_ds, tokenizer=tokenizer, compute_metrics=metrics, - data_collator=DataCollatorForSeq2Seq( - tokenizer=tokenizer, - max_length=max_length, - padding=padding, - max_label_length=max_length, - return_tensors="np", - return_attention_mask=not model_args.flash_mask, - pad_to_multiple_of=data_args.pad_to_multiple_of, - ), + data_collator=data_collator_fn if not model_args.reft else ReftDataCollator(data_collator=data_collator_fn), do_generation=data_args.eval_with_do_generation, callbacks=[ZeroPaddingIterDatasetCallback()] if isinstance(train_ds, ZeroPaddingIterableDataset) else None, gen_args=gen_args, @@ -613,65 +615,6 @@ def compute_metrics_do_generation(eval_preds): trainer.save_metrics("train", train_result.metrics) trainer.save_state() - # QAT - if quant_args.do_qat: - from utils.quant import create_qat_model - - trainer.model = create_qat_model(quant_args, trainer.model, dtype) - train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) - trainer.save_model(merge_tensor_parallel=training_args.tensor_parallel_degree > 1) - trainer.log_metrics("qat", train_result.metrics) - trainer.save_metrics("qat", train_result.metrics) - trainer.save_state() - - # PTQ - if quant_args.do_ptq: - if isinstance(model, LoRAModel): - raise NotImplementedError( - "PTQ strategy not supported for LoRA model. Please merge lora parameters to pretrain model first." - ) - from utils.quant import ( - apply_autoclip, - apply_ptq, - apply_shift, - apply_smooth, - get_ptq_model_config, - ) - - trainer.model.eval() - trainer.model.config.quantization_config.quant_type = quant_args.quant_type - trainer.model.config.quantization_config.smooth = quant_args.smooth - trainer.model.config.quantization_config.shift = quant_args.shift - trainer.model.config.quantization_config.shift_smooth_all_linears = ( - quant_args.smooth_all_linears or quant_args.shift_all_linears - ) - ptq_dataloader = trainer.get_ptq_dataloader(ptq_ds) - if quant_args.shift or quant_args.smooth: - ptq_model_config = get_ptq_model_config(trainer.model) - - if quant_args.shift: - apply_shift(quant_args, trainer, ptq_dataloader, ptq_model_config) - - if quant_args.smooth: - apply_smooth(quant_args, trainer, ptq_dataloader, ptq_model_config) - - if quant_args.auto_clip: - apply_autoclip(quant_args, trainer, ptq_dataloader) - - apply_ptq(quant_args, trainer, ptq_dataloader) - trainer.save_model(merge_tensor_parallel=training_args.tensor_parallel_degree > 1) - - if quant_args.do_gptq: - if isinstance(model, LoRAModel): - raise NotImplementedError( - "PTQ strategy not supported for LoRA model. Please merge lora parameters to pretrain model first." - ) - from utils.quant import apply_gptq - - ptq_dataloader = trainer.get_ptq_dataloader(ptq_ds) - apply_gptq(quant_args, trainer, ptq_dataloader) - trainer.save_model(merge_tensor_parallel=training_args.tensor_parallel_degree > 1) - # Evaluation test set if training_args.do_predict: test_ds = load_dataset( @@ -690,43 +633,9 @@ def compute_metrics_do_generation(eval_preds): eval_result = trainer.predict(test_ds).metrics trainer.log_metrics("test", eval_result) - if quant_args.load_quant_model and not quant_args.do_ptq: - if isinstance(model, LoRAModel): - raise NotImplementedError( - "PTQ strategy not supported for LoRA model. Please merge lora parameters to pretrain model first." - ) - from utils.quant import ( - apply_autoclip, - apply_ptq, - apply_shift, - apply_smooth, - get_ptq_model_config, - load_quant_model, - ) - - trainer.model.eval() - trainer.model.config.quantization_config.quant_type = quant_args.quant_type - trainer.model.config.quantization_config.smooth = quant_args.smooth - trainer.model.config.quantization_config.shift = quant_args.shift - trainer.model.config.quantization_config.shift_smooth_all_linears = ( - quant_args.smooth_all_linears or quant_args.shift_all_linears - ) - ptq_dataloader = trainer.get_ptq_dataloader(ptq_ds) - if quant_args.shift or quant_args.smooth: - ptq_model_config = get_ptq_model_config(trainer.model) - - if quant_args.shift: - apply_shift(quant_args, trainer, ptq_dataloader, ptq_model_config) - - if quant_args.smooth: - apply_smooth(quant_args, trainer, ptq_dataloader, ptq_model_config) - - load_quant_model(trainer.model, quant_args, training_args.output_dir) - # Evaluation dev set if training_args.do_eval: - - logger.info("*** Evaluate result after train/ptq/qat/ etc.***") + logger.info("*** Evaluate result after train ***") eval_result = trainer.evaluate(dev_ds) trainer.log_metrics("eval", eval_result) diff --git a/llm/run_quantization.py b/llm/run_quantization.py new file mode 100644 index 000000000000..e4f36d11fb88 --- /dev/null +++ b/llm/run_quantization.py @@ -0,0 +1,579 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import os +import sys +from functools import partial + +import paddle +from utils.argument import ( + DataArgument, + GenerateArgument, + ModelArgument, + QuantArgument, + TrainingArguments, +) +from utils.data import get_convert_example + +from paddlenlp.data import DataCollatorForSeq2Seq +from paddlenlp.datasets import ( + ZeroPaddingIterableDataset, + ZeroPaddingMapDataset, + load_dataset, +) +from paddlenlp.metrics import BLEU, Rouge1, Rouge2, RougeL +from paddlenlp.peft import LoRAModel +from paddlenlp.trainer import PdArgumentParser +from paddlenlp.trainer.trainer_callback import TrainerState +from paddlenlp.transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoModelForCausalLMPipe, + AutoTokenizer, + Llama3Tokenizer, + LlamaForCausalLM, + LlamaForCausalLMPipe, + LlamaTokenizer, + Qwen2ForCausalLM, + Qwen2ForCausalLMPipe, + register_sequence_parallel_allreduce_hooks, +) +from paddlenlp.transformers.configuration_utils import LlmMetaConfig +from paddlenlp.trl import SFTTrainer +from paddlenlp.trl.llm_utils import ( + ZeroPaddingIterDatasetCallback, + compute_metrics, + init_chat_template, +) +from paddlenlp.utils.log import logger +from paddlenlp.utils.tools import get_env_device + +# Fine-tune Environment Variables to support sharding stage1 overlap optimization. +os.environ["USE_CASUAL_MASK"] = "False" + +flash_mask_support_list = [LlamaForCausalLM, LlamaForCausalLMPipe, Qwen2ForCausalLM, Qwen2ForCausalLMPipe] + + +def main(): + parser = PdArgumentParser((GenerateArgument, QuantArgument, ModelArgument, DataArgument, TrainingArguments)) + if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"): + gen_args, quant_args, model_args, data_args, training_args = parser.parse_json_file_and_cmd_lines() + else: + gen_args, quant_args, model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + training_args.print_config(model_args, "Model") + training_args.print_config(data_args, "Data") + training_args.print_config(quant_args, "Quant") + training_args.print_config(gen_args, "Generation") + + if sum([quant_args.do_ptq, quant_args.do_qat, quant_args.do_gptq]) > 1: + raise ValueError( + "--do_ptq, --do_gptq and --do_qat cannot work at the same time. Please choose only one at a time" + ) + + # Setup GPU & distributed training + paddle.set_device(training_args.device) + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, " + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16 or training_args.bf16}" + ) + + if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1: + try: + from paddle_xpu.layers.nn.linear import LinearConfig # noqa: F401 + + LinearConfig.enable_accumulate_steps_opt() + LinearConfig.set_accumulate_steps(training_args.gradient_accumulation_steps) + except ImportError: + # It's OK, not use accumulate_steps optimization + pass + + # Load model + if training_args.fp16_opt_level == "O2": + if training_args.fp16: + dtype = "float16" + elif training_args.bf16: + dtype = "bfloat16" + else: + raise ValueError("Please specific dtype: --fp16 or --bf16") + else: + dtype = "float32" + quantization_config = dict( + weight_quantize_algo=model_args.weight_quantize_algo, + weight_blocksize=model_args.weight_blocksize, + weight_double_quant=model_args.weight_double_quant, + weight_double_quant_block_size=model_args.weight_double_quant_block_size, + ) + + model_config = AutoConfig.from_pretrained( + model_args.model_name_or_path, + dtype=dtype, + from_aistudio=model_args.from_aistudio, + quantization_config=quantization_config, + ) + + LlmMetaConfig.set_llm_config(model_config, training_args) + model_config.use_fast_layer_norm = model_args.use_fast_layer_norm + + # Config for model using dropout, such as GPT. + if hasattr(model_config, "hidden_dropout_prob"): + model_config.hidden_dropout_prob = model_args.hidden_dropout_prob + if hasattr(model_config, "attention_probs_dropout_prob"): + model_config.attention_probs_dropout_prob = model_args.attention_probs_dropout_prob + if hasattr(model_config, "ignore_index"): + model_config.ignore_index = -100 + + if model_args.fuse_attention_qkv is not None: + model_config.fuse_attention_qkv = model_args.fuse_attention_qkv + if model_args.fuse_attention_ffn is not None: + model_config.fuse_attention_ffn = model_args.fuse_attention_ffn + + model_config.seq_length = data_args.max_length + + logger.info(f"Final model config: {model_config}") + + model_class = AutoModelForCausalLM + if training_args.pipeline_parallel_degree > 1: + if data_args.eval_with_do_generation and training_args.do_eval: + raise ValueError("Plese set eval_with_do_generation to false in pipeline parallel mode.") + + model_class = AutoModelForCausalLMPipe + + if model_args.continue_training and not training_args.autotuner_benchmark: + model = model_class.from_pretrained( + model_args.model_name_or_path, + config=model_config, + from_aistudio=model_args.from_aistudio, + ) + else: + # NOTE(gongenlei): new add autotuner_benchmark + model = model_class.from_config(model_config, dtype=dtype) + + if model_args.flash_mask and (not data_args.zero_padding or not model.config.use_flash_attention): + logger.warning("`flash_mask` must use with zero padding and flash attention.") + data_args.zero_padding = True + model.config.use_flash_attention = True + + if model_args.flash_mask and not any(isinstance(model, cls) for cls in flash_mask_support_list): + raise NotImplementedError(f"{model.__class__} not support flash mask.") + + if training_args.sequence_parallel: + register_sequence_parallel_allreduce_hooks( + model, training_args.gradient_accumulation_steps, training_args.fuse_sequence_parallel_allreduce + ) + # Load tokenizer & dataset + tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, from_aistudio=model_args.from_aistudio) + # init chat_template for tokenizer + init_chat_template(tokenizer, model_args.model_name_or_path, data_args.chat_template) + + # if using chat_template, data_args.eval_with_do_generation must be false + if tokenizer.chat_template is not None: + data_args.eval_with_do_generation = False + + if isinstance(tokenizer, LlamaTokenizer) or isinstance(tokenizer, Llama3Tokenizer): + tokenizer.pad_token_id = tokenizer.eos_token_id + + if data_args.dataset_name_or_path is None: + raise ValueError(f"Please specific dataset name or path (got {data_args.dataset_name_or_path})") + elif ( + os.path.exists(os.path.join(data_args.dataset_name_or_path, "train.json")) + or os.path.exists(os.path.join(data_args.dataset_name_or_path, "dev.json")) + or os.path.exists(os.path.join(data_args.dataset_name_or_path, "quant.json")) + ): + if quant_args.do_qat: + train_ds = load_dataset( + "json", + data_files=os.path.join(data_args.dataset_name_or_path, "train.json"), + lazy=data_args.lazy, + )[0] + else: + train_ds = None + if training_args.do_eval: + dev_ds = load_dataset( + "json", + data_files=os.path.join(data_args.dataset_name_or_path, "dev.json"), + lazy=data_args.lazy, + )[0] + else: + dev_ds = None + + if os.path.exists(os.path.join(data_args.dataset_name_or_path, "quant.json")): + ptq_ds = load_dataset( + "json", + data_files=os.path.join(data_args.dataset_name_or_path, "quant.json"), + lazy=data_args.lazy, + )[0] + elif os.path.exists(os.path.join(data_args.dataset_name_or_path, "train.json")): + ptq_ds = load_dataset( + "json", + data_files=os.path.join(data_args.dataset_name_or_path, "train.json"), + lazy=data_args.lazy, + )[0] + logger.info( + f"Not found quant.json in {data_args.dataset_name_or_path}. Set train dataset as PTQ calibration dataset." + ) + else: + raise ValueError(f"Quant strategy requires quant.json or train.json in {data_args.dataset_name_or_path}") + + elif ( + os.path.exists(os.path.join(data_args.dataset_name_or_path, "train")) + or os.path.exists(os.path.join(data_args.dataset_name_or_path, "dev")) + or os.path.exists(os.path.join(data_args.dataset_name_or_path, "quant")) + ): + import glob + + if quant_args.do_qat: + train_ds = load_dataset( + "json", + data_files=glob.glob(os.path.join(data_args.dataset_name_or_path, "train", "*.json")), + lazy=data_args.lazy, + )[0] + else: + train_ds = None + if training_args.do_eval: + dev_ds = load_dataset( + "json", + data_files=glob.glob(os.path.join(data_args.dataset_name_or_path, "dev", "*.json")), + lazy=data_args.lazy, + )[0] + else: + dev_ds = None + + if os.path.exists(os.path.join(data_args.dataset_name_or_path, "quant")): + ptq_ds = load_dataset( + "json", + data_files=glob.glob(os.path.join(data_args.dataset_name_or_path, "quant", "*.json")), + lazy=data_args.lazy, + )[0] + elif os.path.exists(os.path.join(data_args.dataset_name_or_path, "train")): + ptq_ds = load_dataset( + "json", + data_files=glob.glob(os.path.join(data_args.dataset_name_or_path, "train", "*.json")), + lazy=data_args.lazy, + )[0] + logger.info( + f"Not found quant.json in {data_args.dataset_name_or_path}. Set train dataset as PTQ calibration dataset." + ) + else: + raise ValueError(f"Quant strategy requires quant or train folder in {data_args.dataset_name_or_path}") + + else: + if quant_args.do_qat: + train_ds = load_dataset(data_args.dataset_name_or_path, splits=["train"])[0] + else: + train_ds = None + if training_args.do_eval: + dev_ds = load_dataset(data_args.dataset_name_or_path, splits=["dev"])[0] + else: + dev_ds = None + if quant_args.do_ptq or quant_args.do_gptq or quant_args.load_quant_model: + ptq_ds = load_dataset(data_args.dataset_name_or_path, splits=["train"])[0] + logger.info("Set train dataset as PTQ calibration dataset.") + else: + ptq_ds = None + # TODO(ZHUI & sijunhe): Temporary implementation. Generalize this logic and move to Trainer later. + if training_args.resume_from_checkpoint is not None and data_args.lazy: + logger.info( + f"Loading from '{training_args.resume_from_checkpoint}' with `lazy=True`, manually skipping dataset and setting `ignore_data_skip` to True." + ) + training_args.ignore_data_skip = True + state = TrainerState.load_from_json(os.path.join(training_args.resume_from_checkpoint, "trainer_state.json")) + if state.trial_params is not None and "zero_padding_global_step" in state.trial_params: + consumed_samples = state.trial_params["zero_padding_global_step"] + else: + consumed_samples = ( + state.global_step + * training_args.per_device_train_batch_size + * training_args.gradient_accumulation_steps + * training_args.dataset_world_size + ) + logger.info( + f"Skipping the first {consumed_samples} samples to warmup the dataset from checkpoint '{training_args.resume_from_checkpoint}'." + ) + train_ds = train_ds.skip(consumed_samples) + + if training_args.pipeline_parallel_degree > 1: + from utils.data import convert_example_common + + trans_func = partial(convert_example_common, tokenizer=tokenizer, data_args=data_args) + else: + trans_func = partial(get_convert_example(model), tokenizer=tokenizer, data_args=data_args) + + train_ds = ( + train_ds.map( + partial(trans_func, is_test=False, zero_padding=data_args.zero_padding, flash_mask=model_args.flash_mask) + ) + if train_ds is not None + else None + ) + ptq_ds = ( + ptq_ds.map( + partial(trans_func, is_test=False, zero_padding=data_args.zero_padding, flash_mask=model_args.flash_mask) + ) + if ptq_ds is not None + else None + ) + eval_zero_padding = data_args.zero_padding + if data_args.zero_padding and data_args.eval_with_do_generation: + logger.warning( + "`zero_padding` conflicts with `eval_with_do_generation`. Setting zero_padding to False for the eval_dataset." + ) + eval_zero_padding = False + dev_ds = ( + dev_ds.map( + partial( + trans_func, + is_test=data_args.eval_with_do_generation, + zero_padding=eval_zero_padding, + flash_mask=model_args.flash_mask, + ) + ) + if dev_ds is not None + else None + ) + if data_args.zero_padding: + if data_args.lazy: + intoken_dataset = ZeroPaddingIterableDataset + else: + intoken_dataset = ZeroPaddingMapDataset + logger.info("Creating Zero Padding Data Stream. This may take a few minutes.") + train_ds = ( + intoken_dataset( + train_ds, + tokenizer=tokenizer, + max_length=data_args.max_length, + greedy_zero_padding=data_args.greedy_zero_padding, + ) + if train_ds is not None + else None + ) + ptq_ds = ( + intoken_dataset( + ptq_ds, + tokenizer=tokenizer, + max_length=data_args.max_length, + greedy_zero_padding=data_args.greedy_zero_padding, + ) + if ptq_ds is not None + else None + ) + + if eval_zero_padding: + dev_ds = ( + intoken_dataset( + dev_ds, + tokenizer=tokenizer, + max_length=data_args.max_length, + ) + if dev_ds is not None + else None + ) + + def compute_metrics_do_generation(eval_preds): + rouge1 = Rouge1() + rouge2 = Rouge2() + rougel = RougeL() + bleu4 = BLEU(n_size=4) + + predictions = [x[x != -100].tolist() for x in eval_preds.predictions] + references = [x[x != -100].tolist() for x in eval_preds.label_ids] + + predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True, clean_up_tokenization_spaces=False) + references = tokenizer.batch_decode(references, skip_special_tokens=True, clean_up_tokenization_spaces=False) + if data_args.save_generation_output: + with open(os.path.join(training_args.output_dir, "generated_output.json"), "w", encoding="utf-8") as f: + for pred, ref in zip(predictions, references): + out = {"output": pred, "tgt": ref} + f.write(json.dumps(out, ensure_ascii=False) + "\n") + + # for pred in predictions: + rouge1_score = rouge1.score(predictions, references) + rouge2_score = rouge2.score(predictions, references) + for pred, ref in zip(predictions, references): + rougel.add_inst(pred, [ref]) + bleu4.add_inst(pred, [ref]) + return { + "rouge1": rouge1_score, + "rouge2": rouge2_score, + "rougel": rougel.score(), + "bleu4": bleu4.score(), + } + + # Create trainer + + if ( + training_args.pipeline_parallel_degree > 1 + or training_args.sequence_parallel + or training_args.autotuner_benchmark + or data_args.zero_padding + or data_args.pad_to_max_length + ): + # NOTE(gongenlei): new add autotuner_benchmark + max_length = data_args.max_length + padding = "max_length" + else: + max_length = None + padding = True + + if training_args.pipeline_parallel_degree > 1: + metrics = None + elif data_args.eval_with_do_generation: + metrics = compute_metrics_do_generation + else: + metrics = compute_metrics + + trainer = SFTTrainer( + model=model, + args=training_args, + train_dataset=train_ds, + eval_dataset=dev_ds, + tokenizer=tokenizer, + compute_metrics=metrics, + data_collator=DataCollatorForSeq2Seq( + tokenizer=tokenizer, + max_length=max_length, + padding=padding, + max_label_length=max_length, + return_tensors="np", + return_attention_mask=not model_args.flash_mask, + pad_to_multiple_of=data_args.pad_to_multiple_of, + ), + do_generation=data_args.eval_with_do_generation, + callbacks=[ZeroPaddingIterDatasetCallback()] if isinstance(train_ds, ZeroPaddingIterableDataset) else None, + gen_args=gen_args, + data_args=data_args, + ) + trainable_parameters = [p for p in model.parameters() if not p.stop_gradient] + trainer.set_optimizer_grouped_parameters(trainable_parameters) + + # QAT + if quant_args.do_qat: + from utils.quant import create_qat_model + + trainer.model = create_qat_model(quant_args, trainer.model, dtype) + train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) + trainer.save_model(merge_tensor_parallel=training_args.tensor_parallel_degree > 1) + trainer.log_metrics("qat", train_result.metrics) + trainer.save_metrics("qat", train_result.metrics) + trainer.save_state() + + # PTQ + if quant_args.do_ptq: + if isinstance(model, LoRAModel): + raise NotImplementedError( + "PTQ strategy not supported for LoRA model. Please merge lora parameters to pretrain model first." + ) + from utils.quant import ( + apply_autoclip, + apply_ptq, + apply_shift, + apply_smooth, + get_ptq_model_config, + ) + + trainer.model.eval() + trainer.model.config.quantization_config.quant_type = quant_args.quant_type + trainer.model.config.quantization_config.smooth = quant_args.smooth + trainer.model.config.quantization_config.shift = quant_args.shift + trainer.model.config.quantization_config.shift_smooth_all_linears = ( + quant_args.smooth_all_linears or quant_args.shift_all_linears + ) + ptq_dataloader = trainer.get_ptq_dataloader(ptq_ds) + if quant_args.shift or quant_args.smooth: + ptq_model_config = get_ptq_model_config(trainer.model) + + if quant_args.shift: + apply_shift(quant_args, trainer, ptq_dataloader, ptq_model_config) + + if quant_args.smooth: + apply_smooth(quant_args, trainer, ptq_dataloader, ptq_model_config) + + if quant_args.auto_clip: + apply_autoclip(quant_args, trainer, ptq_dataloader) + + apply_ptq(quant_args, trainer, ptq_dataloader) + trainer.save_model(merge_tensor_parallel=training_args.tensor_parallel_degree > 1) + + if quant_args.do_gptq: + if isinstance(model, LoRAModel): + raise NotImplementedError( + "PTQ strategy not supported for LoRA model. Please merge lora parameters to pretrain model first." + ) + from utils.quant import apply_gptq + + ptq_dataloader = trainer.get_ptq_dataloader(ptq_ds) + apply_gptq(quant_args, trainer, ptq_dataloader) + trainer.save_model(merge_tensor_parallel=training_args.tensor_parallel_degree > 1) + + # Evaluation test set + if training_args.do_predict: + test_ds = load_dataset( + "json", + data_files=os.path.join(data_args.dataset_name_or_path, "test.json"), + lazy=data_args.lazy, + )[0] + + test_ds = test_ds.map(partial(trans_func, is_test=data_args.eval_with_do_generation)) + if eval_zero_padding: + test_ds = intoken_dataset( + test_ds, + tokenizer=tokenizer, + max_length=data_args.max_length, + ) + eval_result = trainer.predict(test_ds).metrics + trainer.log_metrics("test", eval_result) + + if quant_args.load_quant_model and not quant_args.do_ptq: + if isinstance(model, LoRAModel): + raise NotImplementedError( + "PTQ strategy not supported for LoRA model. Please merge lora parameters to pretrain model first." + ) + from utils.quant import ( + apply_autoclip, + apply_ptq, + apply_shift, + apply_smooth, + get_ptq_model_config, + load_quant_model, + ) + + trainer.model.eval() + trainer.model.config.quantization_config.quant_type = quant_args.quant_type + trainer.model.config.quantization_config.smooth = quant_args.smooth + trainer.model.config.quantization_config.shift = quant_args.shift + trainer.model.config.quantization_config.shift_smooth_all_linears = ( + quant_args.smooth_all_linears or quant_args.shift_all_linears + ) + ptq_dataloader = trainer.get_ptq_dataloader(ptq_ds) + if quant_args.shift or quant_args.smooth: + ptq_model_config = get_ptq_model_config(trainer.model) + + if quant_args.shift: + apply_shift(quant_args, trainer, ptq_dataloader, ptq_model_config) + + if quant_args.smooth: + apply_smooth(quant_args, trainer, ptq_dataloader, ptq_model_config) + + load_quant_model(trainer.model, quant_args, training_args.output_dir) + + # Evaluation dev set + if training_args.do_eval: + + logger.info("*** Evaluate result after ptq/qat/ etc.***") + eval_result = trainer.evaluate(dev_ds) + trainer.log_metrics("eval", eval_result) + + +if __name__ == "__main__": + main() diff --git a/llm/tools/merge_lora_params.py b/llm/tools/merge_lora_params.py index 136a7cee34dc..5496150daaf8 100644 --- a/llm/tools/merge_lora_params.py +++ b/llm/tools/merge_lora_params.py @@ -86,16 +86,23 @@ def lora_process(name, lora_config, state_dict, device, lora_state_dict=None): return weight = state_dict.pop(name + ".weight") + lora_use_mixer = lora_config.lora_use_mixer if lora_state_dict is None: lora_A = state_dict.pop(name + ".lora_A") lora_B = state_dict.pop(name + ".lora_B") + if lora_use_mixer: + lora_AB = state_dict.pop(name + ".lora_AB") else: lora_A = lora_state_dict.pop(name + ".lora_A") lora_B = lora_state_dict.pop(name + ".lora_B") + if lora_use_mixer: + lora_AB = lora_state_dict.pop(name + ".lora_AB") if device != "cpu": weight = weight.to(target_device) lora_A = lora_A.to(target_device) lora_B = lora_B.to(target_device) + if lora_use_mixer: + lora_AB = lora_AB.to(target_device) if not lora_config.rslora: scaling = lora_config.lora_alpha / lora_config.r else: @@ -105,9 +112,16 @@ def lora_process(name, lora_config, state_dict, device, lora_state_dict=None): weight = weight.astype("float32") lora_A = lora_A.astype("float32") lora_B = lora_B.astype("float32") - out = (weight + lora_A @ lora_B * scaling).astype("bfloat16") + if lora_use_mixer: + lora_AB = lora_AB.astype(lora_config.dtype) + out = (weight + lora_A @ lora_AB @ lora_B * scaling).astype(lora_config.dtype) + else: + out = (weight + lora_A @ lora_B * scaling).astype(lora_config.dtype) else: - out = (weight + lora_A @ lora_B * scaling).cpu() + if lora_use_mixer: + out = (weight + lora_A @ lora_AB @ lora_B * scaling).cpu() + else: + out = (weight + lora_A @ lora_B * scaling).cpu() state_dict[name + ".weight"] = out diff --git a/llm/utils/argument.py b/llm/utils/argument.py index 006d21aaec90..58bead70df63 100644 --- a/llm/utils/argument.py +++ b/llm/utils/argument.py @@ -136,6 +136,12 @@ class DataArgument: default=False, metadata={"help": "Pad the input sequence to `max_length`."}, ) + autoregressive: bool = field( + default=False, + metadata={"help": "Whether to use autoregressive mode."}, + ) + # Pose ralated parameters + use_pose_convert: bool = field(default=False, metadata={"help": "Whether to use PoSE data conversion function"}) def __post_init__(self): if self.task_name_or_path is not None: @@ -209,6 +215,9 @@ class ModelArgument: rslora: bool = field(default=False, metadata={"help": "Whether to use RsLoRA"}) lora_plus_scale: float = field(default=1.0, metadata={"help": "Lora B scale in LoRA+ technique"}) pissa: bool = field(default=False, metadata={"help": "Whether to use Pissa: https://arxiv.org/pdf/2404.02948.pdf"}) + lora_use_mixer: bool = field( + default=False, metadata={"help": "Whether to use MosLoRA: https://arxiv.org/pdf/2406.11909"} + ) # vera related parameters vera: bool = field(default=False, metadata={"help": "Whether to use vera technique"}) @@ -219,6 +228,9 @@ class ModelArgument: prefix_path: str = field(default=None, metadata={"help": "Initialize prefix state dict."}) num_prefix_tokens: int = field(default=128, metadata={"help": "Number of prefix tokens"}) + # reft related parameter + reft: bool = field(default=False, metadata={"help": "Whether using reft method"}) + from_aistudio: bool = field(default=False, metadata={"help": "Whether to load model from aistudio"}) save_to_aistudio: bool = field(default=False, metadata={"help": "Whether to save model to aistudio"}) aistudio_repo_id: str = field(default=None, metadata={"help": "The id of aistudio repo"}) @@ -229,6 +241,25 @@ class ModelArgument: neftune_noise_alpha: float = field(default=5.0, metadata={"help": "NEFT noise alpha"}) flash_mask: bool = field(default=False, metadata={"help": "Whether to use flash_mask in flash attention."}) + # long sequence strategy + use_long_sequence_strategies: bool = field( + default=False, metadata={"help": "Whether to use long sequence strategy"} + ) + rope_scaling_factor: float = field(default=1.0, metadata={"help": "Rope extension scaling factor"}) + strategy_type: str = field(default=None, metadata={"help": "Long sequence strategy type"}) + strategy_name: str = field(default=None, metadata={"help": "Long sequence strategy name"}) + + +@dataclass +class ReftArgument: + layers: str = field(default="all", metadata={"help": "Layer configuration for the model."}) + position: str = field(default="f7+l7", metadata={"help": "Position parameter for model."}) + intervention_type: str = field(default="LoreftIntervention", metadata={"help": "Type of intervention."}) + rank: int = field(default=8, metadata={"help": "Rank parameter for model."}) + act_fn: str = field(default="linear", metadata={"help": "Activation function."}) + add_bias: bool = field(default=False, metadata={"help": "Flag indicating whether to add bias."}) + dropout: float = field(default=0.0, metadata={"help": "Dropout rate."}) + @dataclass class QuantArgument: @@ -300,7 +331,7 @@ class QuantArgument: ) shift_step: int = field(default=32, metadata={"help": "Sample steps when shift"}) - # Pre-quant methos Smooth related parameters + # Pre-quant methods Smooth related parameters smooth: bool = field(default=False, metadata={"help": "Whether to use Smooth"}) smooth_all_linears: bool = field(default=False, metadata={"help": "Whether to smooth all linears"}) smooth_sampler: str = field( diff --git a/llm/utils/data.py b/llm/utils/data.py index 4c76ed2e7cb1..db9d417743d0 100644 --- a/llm/utils/data.py +++ b/llm/utils/data.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import random + import numpy as np from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM @@ -69,6 +71,27 @@ class DataFormatError(ValueError): pass +def tokenize_unsupervised_example(tokenizer, example, data_args, is_test=True, zero_padding=False, flash_mask=False): + if "src" in example: + source = example["src"][0] if isinstance(example["src"], list) else example["src"] + else: + raise DataFormatError( + f"Example format is wrong, please check: {example} or rewrite tokenize_example in data.py " + ) + tokenized_source = tokenizer( + source, + truncation=False, + padding=True, + max_length=data_args.scaled_max_length, + add_special_tokens=True, + ) + + if data_args.use_pose_convert: + tokenized_source = get_example_pose(tokenized_source, tokenizer, data_args) + + return tokenized_source + + def tokenize_example(tokenizer, example, data_args): if "src" in example and "tgt" in example: source = example["src"][0] if isinstance(example["src"], list) else example["src"] @@ -177,33 +200,114 @@ def tokenize_rounds_example(tokenizer, example, data_args, **kwargs): def convert_example_common(example, tokenizer, data_args, is_test=True, zero_padding=False, flash_mask=False): - if tokenizer.chat_template is not None: - return convert_rounds_example_common(example, tokenizer, data_args, is_test, zero_padding, flash_mask) - - tokenized_source, tokenized_target_input_ids = tokenize_example(tokenizer, example, data_args) - - if is_test: - return { - **tokenized_source, - "labels": tokenized_target_input_ids, - } - else: - input_ids = tokenized_source["input_ids"] + tokenized_target_input_ids - source_length = len(tokenized_source["input_ids"]) - labels = [-100] * source_length + input_ids[source_length:] - # shift input_ids and labels - input_ids, labels = input_ids[:-1], labels[1:] - seq_length = len(input_ids) + if data_args.autoregressive: + tokenized_source = tokenize_unsupervised_example( + tokenizer, example, data_args, is_test=True, zero_padding=False, flash_mask=False + ) + input_ids = tokenized_source["input_ids"] + if "labels" in tokenized_source: + labels = tokenized_source["labels"] + else: + labels = input_ids + input_ids = input_ids[:-1] + [tokenizer.eos_token_id] + labels = labels[1:] + [-100] features = {"input_ids": input_ids, "labels": labels} if "position_ids" in tokenized_source: - features["position_ids"] = list(range(seq_length)) - if zero_padding: - if flash_mask: - features["attn_mask_startend_row_indices"] = [seq_length] * seq_length + features["position_ids"] = tokenized_source["position_ids"] + else: + if tokenizer.chat_template is not None: + return convert_rounds_example_common(example, tokenizer, data_args, is_test, zero_padding, flash_mask) + else: + tokenized_source, tokenized_target_input_ids = tokenize_example(tokenizer, example, data_args) + + if is_test: + return { + **tokenized_source, + "labels": tokenized_target_input_ids, + } else: - features["attention_mask"] = np.tri(seq_length, seq_length, dtype=bool) + input_ids = tokenized_source["input_ids"] + tokenized_target_input_ids + source_length = len(tokenized_source["input_ids"]) + labels = [-100] * source_length + input_ids[source_length:] + # shift input_ids and labels + input_ids, labels = input_ids[:-1], labels[1:] + seq_length = len(input_ids) + features = {"input_ids": input_ids, "labels": labels} + if "position_ids" in tokenized_source: + features["position_ids"] = list(range(seq_length)) + # maybe change here to suit flash_mask with longlora + if zero_padding: + if flash_mask: + features["attn_mask_startend_row_indices"] = [seq_length] * seq_length + else: + features["attention_mask"] = np.tri(seq_length, seq_length, dtype=bool) + return features - return features + +def parse_positions(positions: str): + # parse position + first_n, last_n = 0, 0 + if "+" in positions: + first_n = int(positions.split("+")[0].strip("f")) + last_n = int(positions.split("+")[1].strip("l")) + else: + if "f" in positions: + first_n = int(positions.strip("f")) + elif "l" in positions: + last_n = int(positions.strip("l")) + return first_n, last_n + + +# layers * intervention tokens +def get_intervention_locations(positions, last_position, num_interventions): + """ + This function generates the intervention locations. + """ + _first_n, _last_n = parse_positions(positions) + + first_n = min(last_position // 2, _first_n) + last_n = min(last_position // 2, _last_n) + + pad_amount = (_first_n - first_n) + (_last_n - last_n) + pad_position = -1 + + position_list = ( + [i for i in range(first_n)] + + [i for i in range(last_position - last_n, last_position)] + + [pad_position for _ in range(pad_amount)] + ) + intervention_locations = [position_list] * num_interventions + + return intervention_locations + + +def get_src_last_position(labels): + for i in range(len(labels) - 1, -1, -1): + if labels[i] == -100: + return i + 2 + + +# reft +def convert_example_for_reft( + example, + tokenizer, + data_args, + is_test=True, + zero_padding=False, + flash_mask=False, + positions="f7+l7", + num_interventions=32, +): + features = convert_example_common(example, tokenizer, data_args, is_test, zero_padding, flash_mask) + # src的最后一个位置 + if not is_test: + last_position = get_src_last_position(features["labels"]) + else: + last_position = len(features["input_ids"]) + # add positons + intervention_locations = get_intervention_locations(positions, last_position, num_interventions) + features["intervention_locations"] = intervention_locations + return features def convert_rounds_example_common(example, tokenizer, data_args, is_test=True, zero_padding=False, flash_mask=False): @@ -289,3 +393,28 @@ def convert_example_chatglm(example, tokenizer, data_args, is_test=True, zero_pa features["position_ids"] = np.stack([position_ids, block_position_ids], axis=0) return features + + +def get_example_pose(tokenized_source, tokenizer, data_args): + + ids = tokenized_source["input_ids"] + len_chunk = min(len(ids), data_args.max_length) + if len(tokenized_source["input_ids"]) <= data_args.max_length: + tokenized_source["input_ids"] += [tokenizer.eos_token_id] + + len_input = len(ids) + + lt1 = 0 # chunk1 start pos + rt1 = random.randint(1, (len_chunk) // 2) # chunk1 end pos + + rt2 = random.randint(lt1 + len_chunk, len_input - 1) # chunk2 end pos + lt2 = rt2 - (len_chunk - (rt1 - lt1)) # chunk2 start pos + chunked_ids = ids[lt1:rt1] + ids[lt2:rt2] + labels = ids[lt1 + 1 : rt1 + 1] + ids[lt2 + 1 : rt2 + 1] + + pos_ids = range(len(chunked_ids)) + pos_ids = [x + lt1 if i < rt1 - lt1 else x + (lt2 - (rt1 - lt1)) for i, x in enumerate(pos_ids)] + + features = {"input_ids": chunked_ids, "labels": labels, "position_ids": pos_ids} + + return features diff --git a/llm/utils/quant.py b/llm/utils/quant.py index 8d5cee2f4d74..c2459994ffb3 100644 --- a/llm/utils/quant.py +++ b/llm/utils/quant.py @@ -15,12 +15,6 @@ import os import paddle -from experimental.layers.custom_attention import QuantizedCustomAttentionLayer -from experimental.observer.abs_max import AbsmaxObserver -from experimental.observer.abs_max_headwise import AbsMaxHeadwiseObserver -from experimental.observer.avg import AVGObserver -from experimental.observer.avg_headwise import AvgHeadwiseObserver -from experimental.observer.channel_wise import ChannelWiseObserver from paddle import nn from paddle.distributed.fleet.meta_parallel import ( ColumnParallelLinear, @@ -44,10 +38,16 @@ QuantizedColumnParallelLinear, QuantizedRowParallelLinear, ) +from paddleslim.quant.layers.custom_attention import QuantizedCustomAttentionLayer from paddleslim.quant.observers import ( AbsMaxChannelWiseWeightObserver, GroupWiseWeightObserver, ) +from paddleslim.quant.observers.abs_max import AbsmaxObserver +from paddleslim.quant.observers.abs_max_headwise import AbsMaxHeadwiseObserver +from paddleslim.quant.observers.avg import AVGObserver +from paddleslim.quant.observers.avg_headwise import AvgHeadwiseObserver +from paddleslim.quant.observers.channel_wise import ChannelWiseObserver from paddlenlp.peft import PrefixModelForCausalLM from paddlenlp.peft.lora import ( diff --git a/paddlenlp/experimental/transformers/__init__.py b/paddlenlp/experimental/transformers/__init__.py index 3bd15a024e0b..090a934e3769 100644 --- a/paddlenlp/experimental/transformers/__init__.py +++ b/paddlenlp/experimental/transformers/__init__.py @@ -20,6 +20,7 @@ from .llama import * from .mixtral import * from .opt import * +from .proposers import * from .qwen import * from .qwen2 import * from .qwen2_moe import * diff --git a/paddlenlp/experimental/transformers/bloom/modeling.py b/paddlenlp/experimental/transformers/bloom/modeling.py index ecdae7b000e8..2d1218802449 100644 --- a/paddlenlp/experimental/transformers/bloom/modeling.py +++ b/paddlenlp/experimental/transformers/bloom/modeling.py @@ -594,13 +594,13 @@ def set_transformer_block(self, transformer_config): else: self.transformer_block = FusedBlockMultiTransformer(transformer_config) - def remove_padding(self, input_ids, seq_lens_this_time): + def remove_padding(self, input_ids, seq_lens_this_time, draft_tokens=None, seq_lens_encoder=None): cum_offsets_now = paddle.cumsum(self.max_seq_len - seq_lens_this_time) token_num = paddle.sum(seq_lens_this_time) from paddlenlp_ops import get_padding_offset_v2 ids_remove_padding, cum_offsets, padding_offset, cu_seqlens_q, cu_seqlens_k = get_padding_offset_v2( - input_ids, cum_offsets_now, token_num, seq_lens_this_time + input_ids, cum_offsets_now, token_num, seq_lens_this_time, draft_tokens, seq_lens_encoder ) return ids_remove_padding, padding_offset, cum_offsets, cu_seqlens_q, cu_seqlens_k diff --git a/paddlenlp/experimental/transformers/chatglm_v2/modeling.py b/paddlenlp/experimental/transformers/chatglm_v2/modeling.py index f268742a4a00..23f9d3dddb24 100644 --- a/paddlenlp/experimental/transformers/chatglm_v2/modeling.py +++ b/paddlenlp/experimental/transformers/chatglm_v2/modeling.py @@ -387,13 +387,13 @@ def set_transformer_block(self, transformer_config): else: self.transformer_block = FusedBlockMultiTransformer(transformer_config) - def remove_padding(self, input_ids, seq_lens_this_time): + def remove_padding(self, input_ids, seq_lens_this_time, draft_tokens=None, seq_lens_encoder=None): cum_offsets_now = paddle.cumsum(self.max_seq_len - seq_lens_this_time) token_num = paddle.sum(seq_lens_this_time) from paddlenlp_ops import get_padding_offset_v2 ids_remove_padding, cum_offsets, padding_offset, cu_seqlens_q, cu_seqlens_k = get_padding_offset_v2( - input_ids, cum_offsets_now, token_num, seq_lens_this_time + input_ids, cum_offsets_now, token_num, seq_lens_this_time, draft_tokens, seq_lens_encoder ) return ids_remove_padding, padding_offset, cum_offsets, cu_seqlens_q, cu_seqlens_k diff --git a/paddlenlp/experimental/transformers/fused_transformer_layers.py b/paddlenlp/experimental/transformers/fused_transformer_layers.py index bac51b6ffbf0..33b7282bebcc 100644 --- a/paddlenlp/experimental/transformers/fused_transformer_layers.py +++ b/paddlenlp/experimental/transformers/fused_transformer_layers.py @@ -144,6 +144,12 @@ class AvxConfig: cache_dtype: str = "fp16" +@dataclass +class SpeculateConfig: + speculate_max_draft_token_num: int = (1,) + speculate_method: str = None + + class FusedMultiTransformerConfig: def __init__( self, @@ -208,6 +214,7 @@ def __init__( append_attn=False, moe_config=MoeConfig(), avx_config=AvxConfig(), + speculate_config=SpeculateConfig(), ): self.embed_dim = embed_dim self.num_heads = num_heads @@ -285,6 +292,7 @@ def __init__( self.moe_config = moe_config self.avx_config = avx_config + self.speculate_config = speculate_config class FusedMultiTransformerBase(Layer): @@ -1036,7 +1044,6 @@ def forward( kwargs["decoder_block_shape_q"] = 16 kwargs["max_partition_size"] = 32768 kwargs["encoder_max_partition_size"] = 32768 - kwargs["speculate_max_draft_token_num"] = 5 from paddlenlp_ops import get_block_shape_and_split_kv_block @@ -1061,7 +1068,7 @@ def forward( kwargs.get("decoder_block_shape_q", 16), self.num_heads // self.kv_num_heads, kwargs.get("block_size", 64), - kwargs["speculate_max_draft_token_num"], + self.config.speculate_config.speculate_max_draft_token_num, ) residual_input = src @@ -2202,9 +2209,9 @@ def compute_attn( kwargs.get("decoder_block_shape_q", 16), kwargs.get("max_partition_size", 32768), kwargs.get("encoder_max_partition_size", 32768), - kwargs["speculate_max_draft_token_num"], # speculate_max_draft_token_num + self.config.speculate_config.speculate_max_draft_token_num, True, # causal - False, # speculate_decoder + self.config.speculate_config.speculate_method is not None, # speculate_decoder )[0] else: if paddle.is_compiled_with_xpu(): @@ -2299,8 +2306,15 @@ def post_process(self, **kwargs): seq_lens_encoder = kwargs.get("seq_lens_encoder", None) seq_lens_decoder = kwargs.get("seq_lens_decoder", None) max_input_length = kwargs.get("max_input_length", -1) - - out = rebuild_padding_v2(multi_block_output, cum_offsets, seq_lens_decoder, seq_lens_encoder, max_input_length) + output_padding_offset = kwargs.get("output_padding_offset", None) # only used in speculative decoding + out = rebuild_padding_v2( + multi_block_output, + cum_offsets, + seq_lens_decoder, + seq_lens_encoder, + output_padding_offset, + max_input_length, + ) return out @@ -2393,9 +2407,9 @@ def compute_attn( kwargs.get("decoder_block_shape_q", 16), kwargs.get("max_partition_size", 32768), kwargs.get("encoder_max_partition_size", 32768), - kwargs["speculate_max_draft_token_num"], # speculate_max_draft_token_num + self.config.speculate_config.speculate_max_draft_token_num, True, # causal - False, # speculate_decoder + self.config.speculate_config.speculate_method is not None, # speculate_decoder )[0] else: fmha_out = paddle.incubate.nn.functional.block_multihead_attention( @@ -2444,7 +2458,6 @@ def compute_attn( class FusedBlockMultiTransformerFP8(FusedBlockMultiTransformer): def __init__(self, config: FusedMultiTransformerConfig): - """""" super().__init__(config) self.act_scales = None self.weight_scales = None @@ -2759,7 +2772,7 @@ def compute_attn( kwargs.get("decoder_block_shape_q", 16), kwargs.get("max_partition_size", 32768), kwargs.get("encoder_max_partition_size", 32768), - kwargs["speculate_max_draft_token_num"], # speculate_max_draft_token_num + self.config.speculate_config.speculate_max_draft_token_num, True, # causal False, # speculate_decoder )[0] diff --git a/paddlenlp/experimental/transformers/generation_utils.py b/paddlenlp/experimental/transformers/generation_utils.py index db428b6cd80a..6dc2f1e20d0e 100644 --- a/paddlenlp/experimental/transformers/generation_utils.py +++ b/paddlenlp/experimental/transformers/generation_utils.py @@ -530,6 +530,15 @@ def to_static(self, output_path: str, config: dict): cache_v_dequant_scales, tgt_mask_spec, ] + if config.get("speculate_method", None) is not None: + speculate_spec = [ + paddle.static.InputSpec(shape=[None, None], dtype="int64", name="draft_tokens"), + paddle.static.InputSpec(shape=[None, None], dtype="int64", name="accept_tokens"), + paddle.static.InputSpec(shape=[None], dtype="int32", name="accept_num"), + paddle.static.InputSpec(shape=[None], dtype="int32", name="actual_draft_token_num"), + ] + input_spec.extend(speculate_spec) + model = paddle.jit.to_static(self.generate, input_spec=input_spec) paddle.jit.save( model, output_path, skip_prune_program=True @@ -579,6 +588,10 @@ def generate( k_dequant_scales=None, v_dequant_scales=None, tgt_mask=None, + draft_tokens=None, + accept_tokens=None, + accept_num=None, + actual_draft_token_num=None, **model_kwargs, ): @@ -609,6 +622,11 @@ def generate( model_kwargs["is_block_step"] = is_block_step model_kwargs["src_mask"] = src_mask model_kwargs["tgt_mask"] = tgt_mask + # speculate decoding related parameters + model_kwargs["draft_tokens"] = draft_tokens + model_kwargs["accept_tokens"] = accept_tokens + model_kwargs["accept_num"] = accept_num + model_kwargs["actual_draft_token_num"] = actual_draft_token_num ret = self.sample( eos_token_id, @@ -700,7 +718,12 @@ def _post_process_( ) from paddlenlp_ops import save_output - save_output(next_tokens, model_kwargs["not_need_stop"], self.config.tensor_parallel_rank) + save_output( + next_tokens, + model_kwargs["not_need_stop"], + model_kwargs.get("accept_tokens", None), # only initialized in speculative decoding + self.config.tensor_parallel_rank, + ) return next_tokens # encoder diff --git a/paddlenlp/experimental/transformers/llama/modeling.py b/paddlenlp/experimental/transformers/llama/modeling.py index d706948cf221..e72081499e8d 100644 --- a/paddlenlp/experimental/transformers/llama/modeling.py +++ b/paddlenlp/experimental/transformers/llama/modeling.py @@ -19,9 +19,18 @@ import numpy as np import paddle +import paddle.nn.functional as F from paddle import nn from paddle.distributed import fleet from paddle.nn.quant import weight_quantize +from paddlenlp_ops import ( + save_output, + speculate_get_output_padding_offset, + speculate_get_seq_lens_output, + speculate_set_value_by_flags_and_idx, + speculate_verify_and_update, + top_p_candidates, +) from paddlenlp.experimental.model_utils import ( ActScalesLoader, @@ -40,6 +49,7 @@ FusedMultiTransformerBase, FusedMultiTransformerConfig, FusedMultiTransformerWeightOnly, + SpeculateConfig, ) from paddlenlp.experimental.transformers.generation_utils import ( GenerationAvxInferenceModel, @@ -71,6 +81,7 @@ "LlamaForCausalLMInferenceModel", "LlamaForCausalLMAvxInferenceModel", "LlamaForCausalLMBlockInferenceModel", + "LlamaForCausalLMSpeculateInferenceModel", "LlamaForMiniGPT4InferenceModel", ] @@ -608,6 +619,12 @@ def __init__(self, config: LlamaConfig): paddle.ParamAttr(name="fusellama.{}.cache_v_out_scale".format(i)) for i in range(self.num_layers) ] + speculate_config = SpeculateConfig( + speculate_method=config.speculate_method if hasattr(config, "speculate_method") else None, + speculate_max_draft_token_num=config.speculate_max_draft_token_num + if hasattr(config, "speculate_max_draft_token_num") + else 1, + ) transformer_config = FusedMultiTransformerConfig( embed_dim=self.hidden_size, num_heads=self.num_attention_heads, @@ -657,6 +674,7 @@ def __init__(self, config: LlamaConfig): rank_id=config.tensor_parallel_rank, trans_qkvw=(False if paddle.is_compiled_with_rocm() and "a8w8" in self.quant_type else True), append_attn=config.append_attn, + speculate_config=speculate_config, ) self.set_transformer_block(transformer_config) @@ -1382,13 +1400,13 @@ def set_transformer_block(self, transformer_config): else: self.transformer_block = FusedBlockMultiTransformer(transformer_config) - def remove_padding(self, input_ids, seq_lens_this_time): + def remove_padding(self, input_ids, seq_lens_this_time, draft_tokens=None, seq_lens_encoder=None): cum_offsets_now = paddle.cumsum(self.max_seq_len - seq_lens_this_time) token_num = paddle.sum(seq_lens_this_time) from paddlenlp_ops import get_padding_offset_v2 ids_remove_padding, cum_offsets, padding_offset, cu_seqlens_q, cu_seqlens_k = get_padding_offset_v2( - input_ids, cum_offsets_now, token_num, seq_lens_this_time + input_ids, cum_offsets_now, token_num, seq_lens_this_time, draft_tokens, seq_lens_encoder ) return ids_remove_padding, padding_offset, cum_offsets, cu_seqlens_q, cu_seqlens_k @@ -1438,6 +1456,54 @@ def forward( ) +@register_base_model +class LlamaSpeculateInferenceModel(LlamaBlockInferenceModel): + def forward( + self, + input_ids=None, + attention_mask=None, + inputs_embeds=None, + caches=None, + pre_caches=None, + output_attentions=False, + output_hidden_states=None, + return_dict=False, + draft_tokens=None, + **kwargs, + ): + seq_lens_this_time = kwargs.get("seq_lens_this_time", None) + seq_lens_encoder = kwargs.get("seq_lens_encoder", None) + rope_emb = kwargs.get("rope_emb", None) + ids_remove_padding, padding_offset, cum_offsets, cu_seqlens_q, cu_seqlens_k = self.remove_padding( + input_ids, seq_lens_this_time, draft_tokens, seq_lens_encoder + ) + kwargs["cu_seqlens_q"] = cu_seqlens_q + kwargs["cu_seqlens_k"] = cu_seqlens_k + kwargs["padding_offsets"] = padding_offset + kwargs["max_input_length"] = self.max_seq_len + + inputs_embeds = self.embed_tokens(ids_remove_padding) + with dy2st_nocheck_guard_context(): + hidden_states, _ = self.transformer_block( + input_ids=input_ids, + src=inputs_embeds, + cum_offsets=cum_offsets, + attn_mask=attention_mask, + caches=caches, + pre_caches=pre_caches, + rotary_embs=rope_emb, + **kwargs, + ) + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=None, + hidden_states=None, + attentions=None, + ) + + class LlamaForCausalLMAvxInferenceModel(GenerationAvxInferenceModel, LlamaPretrainedModel): _keys_to_ignore_on_load_missing = [r"lm_head.weight"] @@ -1911,6 +1977,186 @@ def set_state_dict(self, state_dict): self.llama.set_state_dict({k: state_dict[k] for k in state_dict.keys()}) +class LlamaForCausalLMSpeculateInferenceModel(LlamaForCausalLMBlockInferenceModel): + def __init__(self, config): + LlamaPretrainedModel.__init__(self, config) + self.max_seq_len = config.max_seq_len + self.max_candidate_len = config.speculate_max_candidate_len + self.verify_window = config.speculate_verify_window + self.llama = LlamaSpeculateInferenceModel(config) + self.lm_head = LlamaLMHead(config) + + def prepare_inputs_for_generation(self, **kwargs): + model_inputs = super().prepare_inputs_for_generation(**kwargs) + draft_tokens = kwargs["draft_tokens"] + model_inputs["draft_tokens"] = draft_tokens + output_padding_offset = kwargs["output_padding_offset"] + model_inputs["output_padding_offset"] = output_padding_offset + + return model_inputs + + def get_output_padding_offset(self, seq_lens_this_time, seq_lens_encoder, seq_lens_decoder): + """ + In the senerio of speculate decoding, the length of output token after rebuild_padding is no longer bsz. + So we need to calculate the output_padding_offset after rebuild_padding. + """ + seq_lens_output = speculate_get_seq_lens_output(seq_lens_this_time, seq_lens_encoder, seq_lens_decoder) + out_token_num = paddle.sum(seq_lens_output) + output_cum_offsets_tmp = paddle.cumsum(self.max_seq_len - seq_lens_output) + output_padding_offset, output_cum_offsets = speculate_get_output_padding_offset( + output_cum_offsets_tmp, out_token_num, seq_lens_output, self.max_seq_len + ) + return output_padding_offset, output_cum_offsets + + def forward( + self, + input_ids, + src_mask=None, + pre_caches=None, + caches=None, + seq_lens_this_time=None, + seq_lens_encoder=None, + seq_lens_decoder=None, + rope_emb=None, + block_tables=None, + k_quant_scales=None, + v_quant_scales=None, + k_dequant_scales=None, + v_dequant_scales=None, + draft_tokens=None, + output_padding_offset=None, + ): + outputs = self.llama( + input_ids, + src_mask=src_mask, + caches=caches, + rope_emb=rope_emb, + block_tables=block_tables, + pre_caches=pre_caches, + seq_lens_this_time=seq_lens_this_time, + seq_lens_encoder=seq_lens_encoder, + seq_lens_decoder=seq_lens_decoder, + k_quant_scales=k_quant_scales, + v_quant_scales=v_quant_scales, + k_dequant_scales=k_dequant_scales, + v_dequant_scales=v_dequant_scales, + draft_tokens=draft_tokens, + output_padding_offset=output_padding_offset, + ) + + hidden_states = outputs[0] + logits = self.lm_head( + hidden_states, + tensor_parallel_output=False, + ) + + return logits + + def sample( + self, + eos_token_id, + top_k, + top_p, + penalty_score, + frequency_score, + presence_score, + temperature=None, + min_tokens_to_keep=1, + **model_kwargs + ): + def _forward_(**args): + model_inputs = self.prepare_inputs_for_generation(**args) + return self(**model_inputs) + + def _post_process_( + outputs, + top_k, + top_p, + penalty_score, + frequency_score, + presence_score, + temperature, + model_kwargs, + ): + logits = paddle.cast(outputs, paddle.float32) + + # TODO(Wanglongzhi2001): get_token_penalty_multi_scores_v2 don't support seqlen > 1 + + # sample + probs = F.softmax(logits) + verify_scores, verify_tokens, actual_candidate_len = top_p_candidates( + probs, top_p, model_kwargs["output_padding_offset"], self.max_candidate_len, self.max_seq_len + ) # [token_num, max_candidate_len] + + # Speculate Verify And Update + speculate_verify_and_update( + model_kwargs["accept_tokens"], + model_kwargs["accept_num"], + model_kwargs["step_idx"], + model_kwargs["seq_lens_encoder"], + model_kwargs["seq_lens_decoder"], + model_kwargs["stop_flags"], + model_kwargs["not_need_stop"], + model_kwargs[ + "draft_tokens" + ], # Both input and output, need to write the last 1 token accepted to position 0. + model_kwargs["seq_lens_this_time"], + verify_tokens, + verify_scores, + model_kwargs["max_dec_len"], + eos_token_id, + model_kwargs["is_block_step"], + model_kwargs["output_cum_offsets"], + actual_candidate_len, + model_kwargs["actual_draft_token_num"], + top_p, + self.max_seq_len, + self.verify_window, + True, # enable_topp + ) + + # Since the output token length is not bsz anymore, we need to change the token_num + # in the msg queue and write accept_num tokens into the msg queue. + save_output( + model_kwargs["accept_tokens"], + model_kwargs["not_need_stop"], + model_kwargs["accept_num"], + self.config.tensor_parallel_rank, + ) + + # If seq_lens_decoder is 0 (means stop), accept_num should be set to 0 + model_kwargs["accept_num"][model_kwargs["seq_lens_decoder"] == 0] = 0 + + # Update pre_ids through accept tokens + speculate_set_value_by_flags_and_idx( + model_kwargs["pre_ids"], + model_kwargs["accept_tokens"], + model_kwargs["accept_num"], + model_kwargs["stop_flags"], + model_kwargs["seq_lens_this_time"], + model_kwargs["seq_lens_encoder"], + model_kwargs["seq_lens_decoder"], + model_kwargs["step_idx"], + ) + + # # Prepare output padding offset + output_padding_offset, output_cum_offsets = self.get_output_padding_offset( + model_kwargs["seq_lens_this_time"], model_kwargs["seq_lens_encoder"], model_kwargs["seq_lens_decoder"] + ) + model_kwargs["output_padding_offset"] = output_padding_offset + model_kwargs["output_cum_offsets"] = output_cum_offsets + + # LLM + outputs = _forward_(**model_kwargs) + + # Post-process + _post_process_( + outputs, top_k, top_p, penalty_score, frequency_score, presence_score, temperature, model_kwargs + ) + + return top_p + + class LlamaForMiniGPT4InferenceModel(LlamaForCausalLMInferenceModel): """ This class is 99% like LlamaForCausalLMInferenceModel. diff --git a/paddlenlp/experimental/transformers/mixtral/modeling.py b/paddlenlp/experimental/transformers/mixtral/modeling.py index b7c40b761108..7b79c9fc1b5a 100644 --- a/paddlenlp/experimental/transformers/mixtral/modeling.py +++ b/paddlenlp/experimental/transformers/mixtral/modeling.py @@ -1086,13 +1086,13 @@ def set_transformer_block(self, transformer_config): else: self.transformer_block = FusedBlockMultiTransformer(transformer_config) - def remove_padding(self, input_ids, seq_lens_this_time): + def remove_padding(self, input_ids, seq_lens_this_time, draft_tokens=None, seq_lens_encoder=None): cum_offsets_now = paddle.cumsum(self.max_seq_len - seq_lens_this_time) token_num = paddle.sum(seq_lens_this_time) from paddlenlp_ops import get_padding_offset_v2 ids_remove_padding, cum_offsets, padding_offset, cu_seqlens_q, cu_seqlens_k = get_padding_offset_v2( - input_ids, cum_offsets_now, token_num, seq_lens_this_time + input_ids, cum_offsets_now, token_num, seq_lens_this_time, draft_tokens, seq_lens_encoder ) return ids_remove_padding, padding_offset, cum_offsets, cu_seqlens_q, cu_seqlens_k diff --git a/paddlenlp/experimental/transformers/proposers.py b/paddlenlp/experimental/transformers/proposers.py new file mode 100644 index 000000000000..75362f0cc037 --- /dev/null +++ b/paddlenlp/experimental/transformers/proposers.py @@ -0,0 +1,92 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from abc import ABC, abstractmethod + +import paddle +from paddlenlp_ops import ngram_match + + +class Proposer(ABC): + """ + Abstract base class for all proposers that can be used in the speculative decoding framework. + The subclasses of this class must implement the run method to get the draft tokens that are + generated by the proposer. + """ + + def __init__(self, **kwargs): + pass + + @abstractmethod + def run(self, model_inputs: dict[str, paddle.Tensor], **kargs): + """ + Get the draft tokens that are generated by the proposer. + """ + raise NotImplementedError() + + +class InferenceWithReferenceProposer(Proposer): + """ + InferenceWithReference(https://arxiv.org/pdf/2304.04487) is one of the speculative decoding method. + It match tokens in the input and output as draft tokens. + """ + + def __init__(self, max_draft_token_num: int, max_ngram_size: int, max_batch_size: int, max_seq_len: int, **kwargs): + """ + Args: + max_draft_token_num (int): + Maximum number of tokens a proposer can generate at one time. + The hyperparameter of k in the paper. + max_ngram_size (int): + The maximum size of the window used to match inputs and outputs. + The hyperparameter of n in the paper. + max_batch_size (int): + The maximum batch size. + max_seq_len (int): + The maximum sequence length. + """ + super().__init__() + self.max_ngram_size = max_ngram_size + self.input_ids_len = paddle.zeros(shape=[max_batch_size, 1], dtype="int64").cpu() + self.input_ids_cpu = paddle.zeros(shape=[max_batch_size, max_seq_len], dtype="int64").cpu() + self.max_batch_size = max_batch_size + self.max_draft_token_num = max_draft_token_num + + def run(self, model_inputs: dict[str, paddle.Tensor], **kargs): + """ + Use ngram_match to get draft tokens from the input and output. + """ + draft_tokens = model_inputs["draft_tokens"].cpu() + seq_lens_this_time = kargs["seq_lens_this_time"].cpu() + seq_lens_encoder = model_inputs["seq_lens_encoder"].cpu() + seq_lens_decoder = model_inputs["seq_lens_decoder"].cpu() + ngram_match( + self.input_ids_cpu, + self.input_ids_len.cpu(), + model_inputs["pre_ids"].cpu(), + model_inputs["step_idx"].cpu(), + model_inputs["actual_draft_token_num"].cpu(), + draft_tokens, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + kargs["real_batch_size"], + self.max_ngram_size, + self.max_draft_token_num, + ) + + model_inputs["draft_tokens"][:] = draft_tokens.cuda() + model_inputs["seq_lens_encoder"][:] = seq_lens_encoder.cuda() + kargs["seq_lens_this_time"][:] = seq_lens_this_time.cuda() diff --git a/paddlenlp/experimental/transformers/qwen2/modeling.py b/paddlenlp/experimental/transformers/qwen2/modeling.py index e4cb60515921..a0ae895633ae 100644 --- a/paddlenlp/experimental/transformers/qwen2/modeling.py +++ b/paddlenlp/experimental/transformers/qwen2/modeling.py @@ -1247,13 +1247,13 @@ def set_transformer_block(self, transformer_config): else: self.transformer_block = FusedBlockMultiTransformer(transformer_config) - def remove_padding(self, input_ids, seq_lens_this_time): + def remove_padding(self, input_ids, seq_lens_this_time, draft_tokens=None, seq_lens_encoder=None): cum_offsets_now = paddle.cumsum(self.max_seq_len - seq_lens_this_time) token_num = paddle.sum(seq_lens_this_time) from paddlenlp_ops import get_padding_offset_v2 ids_remove_padding, cum_offsets, padding_offset, cu_seqlens_q, cu_seqlens_k = get_padding_offset_v2( - input_ids, cum_offsets_now, token_num, seq_lens_this_time + input_ids, cum_offsets_now, token_num, seq_lens_this_time, draft_tokens, seq_lens_encoder ) return ids_remove_padding, padding_offset, cum_offsets, cu_seqlens_q, cu_seqlens_k diff --git a/paddlenlp/experimental/transformers/qwen2_moe/modeling.py b/paddlenlp/experimental/transformers/qwen2_moe/modeling.py index baafc0d41b5c..b5519d36e8f2 100644 --- a/paddlenlp/experimental/transformers/qwen2_moe/modeling.py +++ b/paddlenlp/experimental/transformers/qwen2_moe/modeling.py @@ -774,13 +774,13 @@ def set_transformer_block(self, transformer_config): else: self.transformer_block = FusedBlockMultiTransformer(transformer_config) - def remove_padding(self, input_ids, seq_lens_this_time): + def remove_padding(self, input_ids, seq_lens_this_time, draft_tokens=None, seq_lens_encoder=None): cum_offsets_now = paddle.cumsum(self.max_seq_len - seq_lens_this_time) token_num = paddle.sum(seq_lens_this_time) from paddlenlp_ops import get_padding_offset_v2 ids_remove_padding, cum_offsets, padding_offset, cu_seqlens_q, cu_seqlens_k = get_padding_offset_v2( - input_ids, cum_offsets_now, token_num, seq_lens_this_time + input_ids, cum_offsets_now, token_num, seq_lens_this_time, draft_tokens, seq_lens_encoder ) return ids_remove_padding, padding_offset, cum_offsets, cu_seqlens_q, cu_seqlens_k diff --git a/paddlenlp/ops/distributed/parallel.py b/paddlenlp/ops/distributed/parallel.py index a0d93359efb6..6abd91717b7c 100644 --- a/paddlenlp/ops/distributed/parallel.py +++ b/paddlenlp/ops/distributed/parallel.py @@ -17,6 +17,7 @@ try: from paddle.distributed.fleet import fleet + from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker except Exception: import warnings @@ -88,8 +89,16 @@ def __init__(self, num_embeddings, embedding_dim, rank, world_size, weight_attr= self._weight_attr = weight_attr self._name = name - self.weight = self.create_parameter(attr=self._weight_attr, shape=self._size, dtype=self._dtype, is_bias=False) - self.weight.is_distributed = True + if self.is_mp and paddle.in_dynamic_mode(): + with get_rng_state_tracker().rng_state(): + self.weight = self.create_parameter( + attr=self._weight_attr, shape=self._size, dtype=self._dtype, is_bias=False + ) + else: + self.weight = self.create_parameter( + attr=self._weight_attr, shape=self._size, dtype=self._dtype, is_bias=False + ) + self.weight.is_distributed = True if self.is_mp else False startup_block = paddle.static.default_startup_program().global_block() main_block = paddle.static.default_main_program().global_block() diff --git a/paddlenlp/peft/__init__.py b/paddlenlp/peft/__init__.py index bf290397ec2e..434d68ddaa93 100644 --- a/paddlenlp/peft/__init__.py +++ b/paddlenlp/peft/__init__.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. + from .lora import LoRAConfig, LoRAModel from .prefix import PrefixConfig, PrefixModelForCausalLM +from .reft import ReFTModel from .vera import VeRAConfig, VeRAModel diff --git a/paddlenlp/peft/lora/lora_config.py b/paddlenlp/peft/lora/lora_config.py index 40b59e5c1a17..e670b1c818b6 100644 --- a/paddlenlp/peft/lora/lora_config.py +++ b/paddlenlp/peft/lora/lora_config.py @@ -86,6 +86,10 @@ class LoRAConfig: "help": "Whether to use quick lora, The use of Quick LoRa will only take effect when lora_dropout is set to 0." }, ) + lora_use_mixer: bool = field( + default=False, + metadata={"help": "Whether to use mos lora."}, + ) def __post_init__(self): if self.use_quick_lora and self.lora_dropout > 0: diff --git a/paddlenlp/peft/lora/lora_layers.py b/paddlenlp/peft/lora/lora_layers.py index afb8cb744766..60be3faea7a8 100644 --- a/paddlenlp/peft/lora/lora_layers.py +++ b/paddlenlp/peft/lora/lora_layers.py @@ -64,6 +64,7 @@ def __init__( rslora: bool = False, lora_plus_scale: float = 1.0, pissa: bool = False, + lora_use_mixer: bool = False, **kwargs ): nn.Linear.__init__(self, in_features, out_features, **kwargs) @@ -79,6 +80,7 @@ def __init__( # Mark the weight as unmerged self.merged = False self.pissa = pissa + self.lora_use_mixer = lora_use_mixer # Actual trainable parameters self.lora_A = self.create_parameter( @@ -87,6 +89,15 @@ def __init__( is_bias=False, default_initializer=nn.initializer.KaimingUniform(negative_slope=math.sqrt(5), nonlinearity="leaky_relu"), ) + if self.lora_use_mixer: + self.lora_AB = self.create_parameter( + shape=[r, r], + dtype=self._dtype, + is_bias=False, + default_initializer=nn.initializer.KaimingUniform( + negative_slope=math.sqrt(5), nonlinearity="leaky_relu" + ), + ) self.lora_B = self.create_parameter( shape=[r, out_features], dtype=self._dtype, @@ -135,13 +146,19 @@ def pissa_init(self, rank): def merge(self): if not self.merged: - new_weight = self.weight + self.lora_A @ self.lora_B * self.scaling + if self.lora_use_mixer: + new_weight = self.weight + self.lora_A @ self.lora_AB @ self.lora_B * self.scaling + else: + new_weight = self.weight + self.lora_A @ self.lora_B * self.scaling self.weight.set_value(new_weight) self.merged = True def unmerge(self): if self.merged: - new_weight = self.weight - self.lora_A @ self.lora_B * self.scaling + if self.lora_use_mixer: + new_weight = self.weight - self.lora_A @ self.lora_AB @ self.lora_B * self.scaling + else: + new_weight = self.weight - self.lora_A @ self.lora_B * self.scaling self.weight.set_value(new_weight) self.merged = False @@ -156,7 +173,10 @@ def forward(self, input: paddle.Tensor, *args, **kwargs): result = quick_lora(input, self.lora_A, self.lora_B, self.weight, self.bias, self.scaling) else: result = F.linear(x=input, weight=self.weight, bias=self.bias, name=self.name) - result += (self.lora_dropout(input) @ self.lora_A @ self.lora_B) * self.scaling + if self.lora_use_mixer: + result += (self.lora_dropout(input) @ self.lora_A @ self.lora_AB @ self.lora_B) * self.scaling + else: + result += (self.lora_dropout(input) @ self.lora_A @ self.lora_B) * self.scaling return result def extra_repr(self): diff --git a/paddlenlp/peft/lora/lora_model.py b/paddlenlp/peft/lora/lora_model.py index 4f619307f5cd..3f0453b7bc35 100644 --- a/paddlenlp/peft/lora/lora_model.py +++ b/paddlenlp/peft/lora/lora_model.py @@ -154,11 +154,14 @@ def __init__(self, model, lora_config: LoRAConfig) -> None: if issubclass(type(self.model), PipelineLayer): self.is_pipelinemodel = True self.model._single_to_pp_mapping = None + if (self.lora_config.tensor_parallel_degree > 1 or self.is_pipelinemodel) and self.lora_config.lora_use_mixer: + raise NotImplementedError("lora_use_mixer is not supported in tensor parallel mode.") if self.lora_config.tensor_parallel_degree != self.model.config.tensor_parallel_degree: self.lora_config.tensor_parallel_degree = self.model.config.tensor_parallel_degree logger.warning( f"Reset tensor_parallel_degree of lora_config to {self.model.config.tensor_parallel_degree}." ) + self.forward = self.model.forward logger.info("Mark only lora and trainable_module as trainable.") @@ -262,7 +265,9 @@ def from_pretrained(cls, model, lora_path, **kwargs): pre_tensor_parallel_split = True tp_actions = lora_model._get_tensor_parallel_convert_actions(loaded_keys, is_split=True) state_dict = load_state_dict( - shard_file, tp_actions if pre_tensor_parallel_split else None, expected_keys + shard_file, + tp_actions if pre_tensor_parallel_split else None, + expected_keys, ) error_msgs += _load_state_dict_into_model(lora_model.model, state_dict, "") del state_dict @@ -443,6 +448,7 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora) pissa=lora_config.pissa, bias_attr=False if module.bias is None else None, use_quick_lora=lora_config.use_quick_lora, + lora_use_mixer=lora_config.lora_use_mixer, ) if isinstance(module, nn.Conv2D): lora_module = LoRAConv2D( diff --git a/paddlenlp/peft/prefix/prefix_model.py b/paddlenlp/peft/prefix/prefix_model.py index 29a34442280c..25d25a354b47 100644 --- a/paddlenlp/peft/prefix/prefix_model.py +++ b/paddlenlp/peft/prefix/prefix_model.py @@ -333,7 +333,9 @@ def from_pretrained( pre_tensor_parallel_split = True tp_actions = prefix_model._get_tensor_parallel_convert_actions(is_split=True) state_dict = load_state_dict( - shard_file, tp_actions if pre_tensor_parallel_split else None, expected_keys + shard_file, + tp_actions if pre_tensor_parallel_split else None, + expected_keys, ) error_msgs += _load_state_dict_into_model(prefix_model.prefix_encoder, state_dict, "") del state_dict diff --git a/paddlenlp/peft/reft/__init__.py b/paddlenlp/peft/reft/__init__.py new file mode 100644 index 000000000000..0f8b4f51d4b9 --- /dev/null +++ b/paddlenlp/peft/reft/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .interventions import ( + LoreftIntervention, + LowRankRotateLayer, + TinyIntervention, + intervention_mapping, +) +from .modeling_utils import ReftDataCollator +from .predict import do_predict +from .reft_config import ReFTConfig +from .reft_model import ReFTModel diff --git a/paddlenlp/peft/reft/interventions.py b/paddlenlp/peft/reft/interventions.py new file mode 100644 index 000000000000..030a90cd00d1 --- /dev/null +++ b/paddlenlp/peft/reft/interventions.py @@ -0,0 +1,148 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import paddle +import paddle.nn as nn +from paddle import ParamAttr + + +def linear_act(x): + return x + + +ACT2FN = { + "linear": linear_act, + "relu": nn.ReLU(), +} + + +# A linear transformation with orthogonal initialization. +class LowRankRotateLayer(nn.Layer): + def __init__(self, n, m): + super().__init__() + self.weight = self.create_parameter( + shape=[n, m], + attr=paddle.ParamAttr(initializer=paddle.nn.initializer.Orthogonal()), + is_bias=False, + ) + + def forward(self, x): + return paddle.matmul(x.astype(self.weight.dtype), self.weight) + + +# existing methods LoReFT(h) = h + R^T(Wh + b − Rh) +class LoreftIntervention(nn.Layer): + def __init__(self, **kwargs): + super(LoreftIntervention, self).__init__() + rotate_layer = LowRankRotateLayer(kwargs["embed_dim"], kwargs["low_rank_dimension"]) + self.rotate_layer = rotate_layer + self.learned_source = nn.Linear( + kwargs["embed_dim"], + kwargs["low_rank_dimension"], + weight_attr=ParamAttr(initializer=nn.initializer.Orthogonal()), + ) + self.data_type = kwargs["dtype"] + self.learned_source = self.learned_source.astype(self.data_type) + self.dropout = nn.Dropout(kwargs["dropout"] if "dropout" in kwargs else 0.0) + self.act_fn = ( + ACT2FN["linear"] if "act_fn" not in kwargs or kwargs["act_fn"] is None else ACT2FN[kwargs["act_fn"]] + ) + + def forward( + self, + base, + ): + rotated_base = self.rotate_layer(base) + output = base + paddle.matmul( + ( + self.act_fn( + self.learned_source( + base, + ) + ) + - rotated_base + ), + self.rotate_layer.weight.T, + ) + return self.dropout(output.astype(base.dtype)) + + def load_state_dict(self, state_dict, *args, **kwargs): + self.learned_source.weight.data = state_dict["learned_source.weight"].astype(self.data_type) + self.learned_source.bias.data = state_dict["learned_source.bias"].astype(self.data_type) + overload_w = state_dict["rotate_layer.weight"].astype(self.data_type) + overload_w_width = overload_w.shape[-1] + with paddle.no_grad(): + self.rotate_layer.weight[:, :overload_w_width] = paddle.to_tensor(overload_w) + return + + +# our proposed method +class TinyIntervention(nn.Layer): + def __init__(self, **kwargs): + super(TinyIntervention, self).__init__() + self.rank = kwargs["low_rank_dimension"] + self.hidden_size = kwargs["embed_dim"] + dropout = 0.0 + if dropout > 0.0: + self.dropout = nn.Dropout(p=dropout) + else: + self.dropout = lambda x: x + self.scaling = 1 + # Actual trainable parameters + self.param_A = self.create_parameter( + shape=[self.hidden_size, self.rank], + dtype=self._dtype, + is_bias=False, + default_initializer=nn.initializer.KaimingUniform(negative_slope=math.sqrt(5), nonlinearity="leaky_relu"), + ) + self.param_B = self.create_parameter( + shape=[self.rank, self.hidden_size], + dtype=self._dtype, + is_bias=False, + default_initializer=nn.initializer.Constant(value=0.0), + ) + self.param_a = self.create_parameter( + shape=[self.rank], + dtype=self._dtype, + is_bias=False, + default_initializer=nn.initializer.Constant(value=1), + ) + self.param_b = self.create_parameter( + shape=[self.hidden_size], + dtype=self._dtype, + is_bias=False, + default_initializer=nn.initializer.Constant(value=1), + ) + self.param_A.stop_gradient = False + self.param_B.stop_gradient = False + + def forward( + self, + base, + ): + diag_b = paddle.diag(self.param_b) + diag_a = paddle.diag(self.param_a) + result = (self.dropout(base) @ self.param_A @ diag_a @ self.param_B @ diag_b) * self.scaling + return self.dropout(base + result.astype(base.dtype)) + + def load_state_dict(self, state_dict): + self.param_A.set_value(state_dict["param_A"]) + self.param_B.set_value(state_dict["param_B"]) + self.param_a.set_value(state_dict["param_a"]) + self.param_b.set_value(state_dict["param_b"]) + + +intervention_mapping = {"LoreftIntervention": LoreftIntervention, "TinyIntervention": TinyIntervention} diff --git a/paddlenlp/peft/reft/modeling_utils.py b/paddlenlp/peft/reft/modeling_utils.py new file mode 100644 index 000000000000..8cc71f657295 --- /dev/null +++ b/paddlenlp/peft/reft/modeling_utils.py @@ -0,0 +1,175 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +import logging +import os +import random +from dataclasses import dataclass +from typing import Dict, Sequence + +import numpy as np +import paddle +from paddle import nn + + +def getattr_for_paddle_module(model, parameter_name): + """Recursively fetch the model based on the name.""" + current_module = model + for param in parameter_name.split("."): + if "[" in param: + current_module = getattr(current_module, param.split("[")[0])[int(param.split("[")[-1].strip("]"))] + else: + current_module = getattr(current_module, param) + return current_module + + +def get_module_hook(model, representation) -> nn.Layer: + """Render the intervening module with a hook.""" + hook_type = "register_forward_post_hook" + parameter_name = f'llama.layers[{representation["layer"]}]' + module = getattr_for_paddle_module(model, parameter_name) + module_hook = getattr(module, hook_type) + return module_hook + + +class HandlerList: + """General class to set hooks and set off hooks.""" + + def __init__(self, handlers): + self.handlers = handlers + + def __len__(self): + return len(self.handlers) + + def remove(self): + for handler in self.handlers: + handler.remove() + + def extend(self, new_handlers): + self.handlers.extend(new_handlers.handlers) + return self + + +# gather hidden states on intervention locations +def gather_neurons(tensor_input, unit_locations_as_list): + unit_locations = paddle.to_tensor(unit_locations_as_list, place=tensor_input.place) + tensor_output = paddle.take_along_axis( + tensor_input, + axis=1, + indices=unit_locations.reshape([*unit_locations.shape, *(1,) * (len(tensor_input.shape) - 2)]).expand( + [-1, -1, *tensor_input.shape[2:]] + ), + ) + return tensor_output + + +# Replace selected neurons in `tensor_input` by `replacing_tensor_input`. +def scatter_neurons( + tensor_input, + replacing_tensor_input, + unit_locations_as_list, +): + unit_locations = paddle.to_tensor( + unit_locations_as_list, + place=tensor_input.place, + ) + + # [1,1,4096] + meta_component = paddle.arange(tensor_input.shape[-1]).unsqueeze(axis=0).unsqueeze(axis=0) + + start_index, end_index = ( + meta_component.min().tolist(), + meta_component.max().tolist() + 1, + ) + # 4096 + # last_dim = meta_component.shape[-1] + # 0, 1, 2, ..., batch_size-1 + _batch_idx = paddle.arange(tensor_input.shape[0]).unsqueeze(1) + tensor_input[_batch_idx, unit_locations, start_index:end_index] = replacing_tensor_input + return tensor_input + + +# do intervention +def do_intervention( + base_representation, + intervention, +): + """Do the actual intervention.""" + # base_representation: 从隐藏状态抽取出的对应token的隐藏状态 f7+l7: batch_size, 14, hidden_size + # intervention: 干预的模型 + # flatten + # original_base_shape = base_representation.shape + # if len(original_base_shape) == 2 or intervention.keep_last_dim: + # base_representation_f = base_representation + # intervened_representation = intervention( + # base_representation_f, + # ) + intervened_representation = intervention( + base_representation, + ) + return intervened_representation + + +# Introducing corresponding classes based on strings +def get_type_from_string(type_str): + """Help function to convert string to type""" + # Remove from the string + type_str = type_str.replace("", "") + + # Split the string into module and class name + module_name, class_name = type_str.rsplit(".", 1) + + # Import the module + if not module_name.startswith("paddlenlp"): + module_name = f"paddlenlp.peft.reft.{module_name}" + module = importlib.import_module(module_name) + + # Get the class + cls = getattr(module, class_name) + + return cls + + +def create_directory(path): + """Create directory if not exist""" + if not os.path.exists(path): + os.makedirs(path) + logging.info(f"Directory '{path}' created successfully.") + else: + logging.info(f"Directory '{path}' already exists.") + + +def set_seed(seed: int): + random.seed(seed) + np.random.seed(seed) + paddle.seed(seed) + + +def count_parameters(model): + """Count parameters of a model that require gradients""" + return int(sum(p.numel() for p in model.parameters() if not p.stop_gradient)) + + +@dataclass +class ReftDataCollator(object): + """Collate examples for ReFT.""" + + def __init__(self, data_collator): + self.data_collator = data_collator + + def __call__(self, instances: Sequence[Dict]) -> Dict[str, paddle.Tensor]: + batch_inputs = self.data_collator(instances) + max_seq_length = batch_inputs["input_ids"].shape[-1] + batch_inputs["intervention_locations"] = batch_inputs["intervention_locations"][..., :max_seq_length] + return batch_inputs diff --git a/paddlenlp/peft/reft/predict.py b/paddlenlp/peft/reft/predict.py new file mode 100644 index 000000000000..62118eb013e5 --- /dev/null +++ b/paddlenlp/peft/reft/predict.py @@ -0,0 +1,133 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging + +import paddle +from paddle.io import DataLoader, Dataset +from tqdm import tqdm + +from paddlenlp.data import DataCollatorForSeq2Seq +from paddlenlp.transformers import AutoTokenizer + +from .modeling_utils import ReftDataCollator + +device = "gpu" if paddle.is_compiled_with_cuda() else "cpu" + + +def make_data_collator(tokenizer, model, max_length): + data_collator_fn = DataCollatorForSeq2Seq( + tokenizer=tokenizer, + model=model, + label_pad_token_id=-100, + padding="longest", + max_length=max_length, + ) + return ReftDataCollator(data_collator=data_collator_fn) + + +def make_dataloader( + dataset: Dataset, batch_size: int, collate_fn: DataCollatorForSeq2Seq, shuffle: bool +) -> DataLoader: + return DataLoader(dataset, shuffle=shuffle, batch_size=batch_size, collate_fn=collate_fn) + + +def do_predict( + intervenable, + tokenizer: AutoTokenizer, + eval_dataset: Dataset, + batch_size: int = 4, + data_collator=None, + greedy_decoding=True, + temperature=None, + top_p=None, + top_k=None, + max_new_tokens=32, + do_sample=False, + predict_path=None, + num_beams=4, + max_length=2048, +): + # switch the tokenizer mode first for generation tasks + tokenizer.padding_side = "left" # switch padding side for collator + if greedy_decoding: + num_beams = 1 + data_collator = make_data_collator(tokenizer, intervenable.model, max_length) + eval_dataloader = make_dataloader(eval_dataset, batch_size, data_collator, shuffle=False) + generations = [] + eval_iterator = tqdm(eval_dataloader, position=0, leave=True) + with paddle.no_grad(): + for step, inputs in enumerate(eval_iterator): + for k, v in inputs.items(): + if v is not None and isinstance(v, paddle.Tensor): + inputs[k] = v.to(device) + + # [layers, batch_size, positions] + intervention_locations = paddle.transpose(inputs["intervention_locations"], perm=[1, 0, 2]) + # get left padding count, [batch_size], and add to locations + left_padding = (inputs["input_ids"] == tokenizer.bos_token_id).nonzero(as_tuple=True)[1] + + if left_padding.numel() > 0: + left_padding = left_padding.reshape([1, -1, 1]).to(device) # [1, batch_size, 1] + intervention_locations += left_padding + # intervention_locations -= 1 # offset for the sink padding + else: + logging.info("Warning: No BOS token found, skipping left padding adjustment.") + + # repeat each batch by num_beams times in intervention locations + # -> [layers, batch_size * num_beams, positions] + intervention_locations = intervention_locations.repeat_interleave(num_beams, axis=1).tolist() + + # set generation args depending on task + generation_args = { + "base": { + "input_ids": inputs["input_ids"], + "attention_mask": inputs["attention_mask"], + }, + "unit_locations": intervention_locations, + "intervene_on_prompt": True, + "eos_token_id": tokenizer.eos_token_id, + "early_stopping": True, + "max_new_tokens": max_new_tokens, + "do_sample": do_sample, + } + # override generation args if necessary + if temperature is not None: + generation_args["temperature"] = temperature + if top_p is not None: + generation_args["top_p"] = top_p + if top_k is not None: + generation_args["top_k"] = top_k + + # generate with intervention on prompt + _, steered_response = intervenable.generate(**generation_args) + # detokenize in batch + actual_preds = tokenizer.batch_decode(steered_response[0], skip_special_tokens=True) + + for inputs_id, label, pred in zip(inputs["input_ids"], inputs["labels"], actual_preds): + filtered_labels = label[label != -100] + generations += [ + { + "src": tokenizer.decode(inputs_id, skip_special_tokens=True), + "trg": tokenizer.decode(filtered_labels, skip_special_tokens=True), + "pred": pred, + } + ] + + if predict_path is not None: + with open(predict_path, "w") as json_file: + json.dump(generations, json_file, indent=4, ensure_ascii=False) + + return generations diff --git a/paddlenlp/peft/reft/reft_config.py b/paddlenlp/peft/reft/reft_config.py new file mode 100644 index 000000000000..fda6d092c7cb --- /dev/null +++ b/paddlenlp/peft/reft/reft_config.py @@ -0,0 +1,85 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os + +from .modeling_utils import get_type_from_string + + +class ReFTConfig: + def __init__( + self, + representations, + intervention_params=None, + position=None, + intervention_types=None, + sorted_keys=None, + intervention_dimensions=None, + **kwargs, + ): + if not isinstance(representations, list): + representations = [representations] + + self.representations = representations + self.intervention_types = intervention_types + overwrite_intervention_types = [] + for reprs in self.representations: + if reprs["intervention"] is not None: + overwrite_intervention_types += [type(reprs["intervention"])] + + self.intervention_types = overwrite_intervention_types + self.sorted_keys = sorted_keys + self.intervention_dimensions = intervention_dimensions + self.intervention_params = intervention_params + self.position = position + + def to_dict(self): + return { + "representations": self.representations, + "intervention_types": self.intervention_types, + "sorted_keys": self.sorted_keys, + } + + @staticmethod + def from_pretrained(load_directory): + saved_config = json.load(open(os.path.join(load_directory, "config.json"), "r")) + for representation, intervention_type in zip( + saved_config["representations"], saved_config["intervention_types"] + ): + representation["intervention"] = get_type_from_string(intervention_type)( + **saved_config["intervention_params"] + ) + reft_config = ReFTConfig( + representations=saved_config["representations"], + intervention_params=saved_config["intervention_params"], + ) + return reft_config + + def save_pretrained(self, save_directory): + config_dict = {} + config_dict["representations"] = [ + { + "layer": repr["layer"], + "component": repr["component"], + "low_rank_dimension": repr["low_rank_dimension"], + } + for repr in self.representations + ] + + config_dict["intervention_params"] = self.intervention_params + config_dict["intervention_types"] = [repr(intervention_type) for intervention_type in self.intervention_types] + config_dict["position"] = self.position + with open(os.path.join(save_directory, "config.json"), "w") as f: + json.dump(config_dict, f, indent=4) diff --git a/paddlenlp/peft/reft/reft_model.py b/paddlenlp/peft/reft/reft_model.py new file mode 100644 index 000000000000..df866e8287a6 --- /dev/null +++ b/paddlenlp/peft/reft/reft_model.py @@ -0,0 +1,365 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import json +import logging +import os +import types +from typing import List, Optional + +import paddle +from paddle import nn + +from .modeling_utils import ( + HandlerList, + count_parameters, + create_directory, + do_intervention, + gather_neurons, + get_module_hook, + scatter_neurons, +) +from .reft_config import ReFTConfig + + +class ReFTModel(nn.Layer): + """ + config: ReFTConfig + """ + + def __init__(self, config, model, **kwargs): + super().__init__() + self.config = config + self.intervention_types = config.intervention_types + self.representations = {} + self.interventions = {} + _original_key_order = [] + # for generate + self._key_setter_call_counter = {} + for i, representation in enumerate(config.representations): + _key = f'layer.{representation["layer"]}' + if representation["intervention"] is not None: + intervention = representation["intervention"] + + module_hook = get_module_hook(model, representation) + self.representations[_key] = representation + self.interventions[_key] = (intervention, module_hook) + _original_key_order += [_key] + + # usually, it's a one time call per + # hook unless model generates. + self._key_setter_call_counter[_key] = 0 + + self.sorted_keys = _original_key_order + self.model = model + self.model_config = model.config + self.disable_model_gradients() + self.trainable_model_parameters = {} + + def forward( + self, + **base, + ): + unit_locations = base["intervention_locations"].transpose([1, 0, 2]).tolist() + self._reset_hook_count() + try: + # intervene, register hook after decoder block + set_handlers_to_remove = self._wait_for_forward_with_intervention(unit_locations) + # run intervened forward + del base["intervention_locations"] + counterfactual_outputs = self.model(**base) + set_handlers_to_remove.remove() + except Exception as e: + raise e + self._reset_hook_count() + return counterfactual_outputs + + def generate( + self, + base, + unit_locations: Optional[List] = None, + intervene_on_prompt: bool = False, + output_original_output: Optional[bool] = False, + **kwargs, + ): + self._reset_hook_count() + self._intervene_on_prompt = intervene_on_prompt + base_outputs = None + if output_original_output or True: + # returning un-intervened output + base_outputs = self.model.generate(**base, **kwargs) + set_handlers_to_remove = None + try: + # intervene, register hook after decoder block + set_handlers_to_remove = self._wait_for_forward_with_intervention( + unit_locations, + ) + # run intervened generate + counterfactual_outputs = self.model.generate(**base, **kwargs) + set_handlers_to_remove.remove() + except Exception as e: + raise e + self._reset_hook_count() + return base_outputs, counterfactual_outputs + + def _wait_for_forward_with_intervention( + self, + unit_locations, + ): + all_set_handlers = HandlerList([]) + for key_id, key in enumerate(self.sorted_keys): + set_handlers = self._intervention_setter(key, unit_locations[key_id]) + all_set_handlers.extend(set_handlers) + return all_set_handlers + + def _intervention_setter( + self, + key, + unit_locations_base, + ) -> HandlerList: + """ + Create a list of setter handlers that will set activations + """ + handlers = [] + intervention, module_hook = self.interventions[key] + + def hook_callback( + model, + inputs, + outputs, + ): + is_prompt = self._key_setter_call_counter[key] == 0 + if is_prompt: + self._key_setter_call_counter[key] += 1 + if not is_prompt: + return + + selected_output = self._gather_intervention_output(outputs, key, unit_locations_base) + + if not isinstance(self.interventions[key][0], types.FunctionType): + intervened_representation = do_intervention( + selected_output, + intervention, + ) + if intervened_representation is None: + return + + if isinstance(outputs, tuple): + _ = self._scatter_intervention_output( + outputs[0], + intervened_representation, + key, + unit_locations_base, + ) + else: + _ = self._scatter_intervention_output( + outputs, + intervened_representation, + key, + unit_locations_base, + ) + + handlers.append( + module_hook( + hook_callback, + ) + ) + + return HandlerList(handlers) + + def _gather_intervention_output(self, output, representations_key, unit_locations) -> paddle.Tensor: + """ + Gather intervening activations from the output based on indices + """ + if isinstance(output, tuple): + original_output = output[0].clone() + else: + original_output = output.clone() + if unit_locations is None: + return original_output + + # gather based on intervention locations + selected_output = gather_neurons( + original_output, + unit_locations, + ) + return selected_output + + def _scatter_intervention_output( + self, + output, + intervened_representation, + representations_key, + unit_locations, + ) -> paddle.Tensor: + """ + Scatter in the intervened activations in the output + """ + # data structure casting + if isinstance(output, tuple): + original_output = output[0] + else: + original_output = output + # for non-sequence-based models, we simply replace + # all the activations. + if unit_locations is None: + original_output[:] = intervened_representation[:] + return original_output + + # component = self.representations[representations_key].component + # unit = self.representations[representations_key].unit + + # scatter in-place + _ = scatter_neurons( + original_output, + intervened_representation, + unit_locations, + ) + + return original_output + + def save_pretrained(self, save_directory, **kwargs): + create_directory(save_directory) + saving_config = copy.deepcopy(self.config) + saving_config.sorted_keys = self.sorted_keys + saving_config.intervention_types = [] + saving_config.intervention_dimensions = [] + + for k, v in self.interventions.items(): + intervention = v[0] + saving_config.intervention_types += [(type(intervention))] + binary_filename = f"intkey_{k}.bin" + # save intervention binary file + logging.info(f"Saving trainable intervention to {binary_filename}.") + paddle.save( + intervention.state_dict(), + os.path.join(save_directory, binary_filename), + ) + + saving_config.save_pretrained(save_directory) + + @staticmethod + def from_pretrained( + load_directory, + model, + ): + """ + Load interventions from disk + """ + reft_config = ReFTConfig.from_pretrained( + load_directory=load_directory, + ) + intervenable = ReFTModel(reft_config, model) + intervenable.disable_model_gradients() + + # load binary files + for i, (k, v) in enumerate(intervenable.interventions.items()): + intervention = v[0] + binary_filename = f"intkey_{k}.bin" + saved_state_dict = paddle.load(os.path.join(load_directory, binary_filename)) + intervention.load_state_dict(saved_state_dict) + return intervenable + + def train(self): + self.model.train() + + def eval(self): + self.model.eval() + + def count_parameters(self, include_model=False): + total_parameters = 0 + for k, v in self.interventions.items(): + total_parameters += count_parameters(v[0]) + if include_model: + total_parameters += sum(p.numel() for p in self.model.parameters() if p.requires_grad) + return total_parameters + + def print_trainable_parameters(self): + trainable_intervention_parameters = 0 + for k, v in self.interventions.items(): + trainable_intervention_parameters += count_parameters(v[0]) + + trainable_model_parameters = int(sum(p.numel() for p in self.model.parameters() if not p.stop_gradient)) + + all_model_parameters = int(sum(p.numel() for p in self.model.parameters())) + + total_trainable_parameters = trainable_intervention_parameters + trainable_model_parameters + + logging.info("trainable_intervention_parameters:", trainable_intervention_parameters) + logging.info("trainable_model_parameters:", trainable_model_parameters) + logging.info("all_model_parameters:", all_model_parameters) + logging.info("total_trainable_parameters:", total_trainable_parameters) + logging.info( + f"trainable intervention params: {trainable_intervention_parameters:,d} || trainable model params: {trainable_model_parameters:,d}\n" + f"model params: {all_model_parameters:,d} || trainable%: {100 * total_trainable_parameters / all_model_parameters}" + ) + + def _reset_hook_count(self): + """ + Reset the hook count before any generate call + """ + self._key_setter_call_counter = dict.fromkeys(self._key_setter_call_counter, 0) + + def __str__(self): + attr_dict = { + "model_type": str(self.model_type), + "intervention_types": str(self.intervention_types), + "alignabls": self.sorted_keys, + } + return json.dumps(attr_dict, indent=4) + + def get_trainable_parameters(self): + """ + Return trainable params as key value pairs + """ + ret_params = [] + for k, v in self.interventions.items(): + ret_params += [p for p in v[0].parameters()] + for p in self.model.parameters(): + if p.requires_grad: + ret_params += [p] + return ret_params + + def named_parameters(self, recurse=True, include_sublayers=True): + """ + The above, but for HuggingFace. + """ + ret_params = [] + for k, v in self.interventions.items(): + ret_params += [(k + "." + n, p) for n, p in v[0].named_parameters()] + for n, p in self.model.named_parameters(): + if not p.stop_gradient: + ret_params += [("model." + n, p)] + return ret_params + + def enable_model_gradients(self): + """ + Enable gradient in the model + """ + # Unfreeze all model weights + self.model.train() + for param in self.model.parameters(): + param.stop_gradient = False + self.model_has_grad = True + + def disable_model_gradients(self): + """ + Disable gradient in the model + """ + # Freeze all model weights + self.model.eval() + for param in self.model.parameters(): + param.stop_gradient = True + self.model_has_grad = False diff --git a/paddlenlp/quantization/checkpoint_quantization_utils.py b/paddlenlp/quantization/checkpoint_quantization_utils.py new file mode 100644 index 000000000000..8541107427df --- /dev/null +++ b/paddlenlp/quantization/checkpoint_quantization_utils.py @@ -0,0 +1,364 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +import paddle + + +def cal_ratio(m, v, eps=1e-8): + """ + cal part adam update ratio. + Args: + m (`paddle.Tensor`): + moment in Adam optimizer. + v (`paddle.Tensor`): + variance in Adam optimizer. + eps (`int`): + epsilon in Adam optimizer. + """ + return 1 / (np.sqrt(v) + eps) + + +def group_wise_quant_dequant( + inputs, + mins=None, + maxs=None, + quant_bits=4, + group_size=32, + quant=True, + tp_rank=-1, + tp_degree=1, + use_pd=False, + symmetry=False, +): + """ + group-wise quantization (support symmetry, asymmetry). + Args: + inputs (`paddle.Tensor`): + The tensor to quantize. + mins (`paddle.Tensor`): + Min scales tensor in asymmetry quantization. + maxs (`paddle.Tensor`): + Max scales tensor in asymmetry quantization, or Abs max tensor in symmetry quantization. + quant_bits (`int`): + Quantization bits. + group_size (`int`): + Group size of group-wise quantization. + quant (`bool`): + True when quantization, False in dequantization. + tp_rank (`int`): + Tensor parallel rank. + tp_degree (`int`): + Tensor parallel world size. + use_pd (`bool`): + Whether to use paddle caculation. If False will use numpy. + symmetry (`bool`): + Whether to use symmetry quantization. + """ + + qmax = (1 << (quant_bits)) - 1 + qmin = 0 + shape = inputs.shape + + if quant: + inputs_processed = inputs.reshape([shape[0] // group_size, group_size, shape[1]]) + if symmetry: + bnt = (1 << (quant_bits - 1)) - 1 + scales = np.max(np.abs(inputs_processed), axis=1) + new_scales = np.repeat(scales, repeats=group_size, axis=0) + quant_tensor = np.clip(np.round(inputs / new_scales * bnt), -bnt - 1, bnt) + return quant_tensor.astype("int8"), scales + + # scales: [shape[0] // group_size, shape[1]] + maxs = np.max(inputs_processed, axis=1) + mins = np.min(inputs_processed, axis=1) + scales = maxs - mins + # new_scales: [shape[0], shape[1]] + new_scales = np.repeat(scales, repeats=group_size, axis=0) + new_mins = np.repeat(mins, repeats=group_size, axis=0) + # add eps to avoid devide zero + quant_tensor = np.clip(np.round((inputs - new_mins) / (new_scales) * qmax), qmin, qmax) + quant_tensor = np.nan_to_num(quant_tensor) + return quant_tensor.astype("uint8"), mins, maxs + else: + if symmetry: + scales = mins + bnt = (1 << (quant_bits - 1)) - 1 + if use_pd: + new_scales = paddle.repeat_interleave(scales, group_size, 0) + else: + new_scales = np.repeat(scales, repeats=group_size, axis=0) + + if tp_rank == -1: + dequant_tensor = inputs.astype("float32") * new_scales / bnt + elif len(new_scales.shape) == 0 or inputs.shape[-1] == new_scales.shape[-1]: + # input tensor was row parallel in tp. + dequant_tensor = ( + inputs.astype("float32") + * new_scales[ + tp_rank * new_scales.shape[0] // tp_degree : (tp_rank + 1) * new_scales.shape[0] // tp_degree + ] + / bnt + ) + else: + # input tensor was column parallel in tp. + dequant_tensor = ( + inputs.astype("float32") + * new_scales[ + :, + tp_rank + * new_scales.shape[-1] + // tp_degree : (tp_rank + 1) + * new_scales.shape[-1] + // tp_degree, + ] + / bnt + ) + return dequant_tensor + + scales = maxs - mins + if use_pd: + new_scales = paddle.repeat_interleave(scales, group_size, 0) + new_mins = paddle.repeat_interleave(mins, group_size, 0) + else: + new_scales = np.repeat(scales, repeats=group_size, axis=0) + new_mins = np.repeat(mins, repeats=group_size, axis=0) + + if tp_rank == -1: + dequant_tensor = (inputs.astype("float32") / qmax * new_scales) + new_mins + elif len(new_scales.shape) == 0 or inputs.shape[-1] == new_scales.shape[-1]: + # input tensor was row parallel in tp. + dequant_tensor = ( + inputs.astype("float32") + / qmax + * new_scales[ + tp_rank * new_scales.shape[0] // tp_degree : (tp_rank + 1) * new_scales.shape[0] // tp_degree + ] + ) + new_mins[tp_rank * new_mins.shape[0] // tp_degree : (tp_rank + 1) * new_mins.shape[0] // tp_degree] + else: + # input tensor was column parallel in tp. + dequant_tensor = ( + inputs.astype("float32") + / qmax + * new_scales[ + :, tp_rank * new_scales.shape[-1] // tp_degree : (tp_rank + 1) * new_scales.shape[-1] // tp_degree + ] + ) + new_mins[ + :, tp_rank * new_mins.shape[-1] // tp_degree : (tp_rank + 1) * new_mins.shape[-1] // tp_degree + ] + return dequant_tensor + + +def merge_int4(x, y): + """ + merge 2 signed int4 to 1 int8 + Args: + x (`numpy.array`): + 4bits signed int x. + y (`numpy.array`): + 4bits signed int y. + """ + int4_high = x << 4 + int4_low = y & 0x0F + final = int4_high | int4_low + return final.astype("int8") + + +def split_int8(final): + """ + split an int8 to 2 int4 elems + Args: + final (`numpy.array`): + 8bits signed int. + """ + int4_high = final >> 4 + int4_low = final & 0x0F + + int4_high = np.where(int4_high > 8, int4_high - 16, int4_high) + + high_tensor = paddle.Tensor(int4_high) + low_tensor = paddle.Tensor(int4_low) + + return high_tensor, low_tensor + + +def cal_abs_min_max_channel(inputs, quant_axis=1): + """ + channel-wise min max scales calculation + Args: + inputs (`numpy.array`): + input tensor for quantization. + quant_axis (`int`): + dimension where calulating inputs' abs min and max scales on. + """ + eps = 1e-8 + reduce_axis = tuple([i for i in range(len(inputs.shape)) if i != quant_axis]) + abs_max_values = np.max(inputs, axis=reduce_axis) + abs_min_values = np.min(inputs, axis=reduce_axis) + abs_max_values = np.where( + abs_max_values == np.array(0, dtype=inputs.dtype), np.array(eps, dtype=inputs.dtype), abs_max_values + ) + abs_min_values = np.where( + abs_min_values == np.array(0, dtype=inputs.dtype), np.array(eps, dtype=inputs.dtype), abs_min_values + ) + return abs_max_values, abs_min_values + + +def asymmetry_qdq_weight( + x, quant_bit=8, quant_axis=-1, mins=None, maxs=None, dequant=False, tp_rank=-1, tp_degree=1, use_pd=False +): + """ + channel-wise asymmetry quantization + Args: + x (`paddle.Tensor`): + The tensor to quantize. + quant_bits (`int`): + Quantization bits. + quant_axis (`int`): + Scales caculation axis. + mins (`paddle.Tensor`): + Min scales tensor in asymmetry quantization. + maxs (`paddle.Tensor`): + Max scales tensor in asymmetry quantization. + dequant (`bool`): + True when dequantization, False in quantization. + tp_rank (`int`): + Model parallel rank. + tp_degree (`int`): + Model parallel world size. + use_pd (`bool`): + Whether to use paddle caculation. If False will use numpy. + """ + + if mins is None: + maxs, mins = cal_abs_min_max_channel(x) + bnt = (1 << (quant_bit)) - 1 + scales = maxs - mins + if not dequant: + # quant + quant_x = np.clip(np.round((x - mins) / scales * bnt), 0, bnt) + return quant_x.astype(np.uint8), mins, maxs + else: + quant_x = x + # dequant + if not use_pd: + if len(scales.shape) == 0 or quant_x.shape[-1] == scales.shape[-1]: + # input tensor was row parallel in tp. + qdq_x = (quant_x / bnt * scales) + mins + else: + # input tensor was column parallel in tp. + qdq_x = ( + quant_x + / bnt + * scales[tp_rank * scales.shape[0] // tp_degree : (tp_rank + 1) * scales.shape[0] // tp_degree] + ) + mins[tp_rank * mins.shape[0] // tp_degree : (tp_rank + 1) * mins.shape[0] // tp_degree] + return qdq_x.astype(np.float32), scales + else: + if len(scales.shape) == 0 or quant_x.shape[-1] == scales.shape[-1]: + # input tensor was row parallel in tp. + qdq_x = (quant_x / bnt * scales.unsqueeze(0).expand(quant_x.shape)) + mins + else: + # input tensor was column parallel in tp. + qdq_x = ( + quant_x + / bnt + * scales[tp_rank * scales.shape[0] // tp_degree : (tp_rank + 1) * scales.shape[0] // tp_degree] + .unsqueeze(0) + .expand(quant_x.shape) + ) + mins[tp_rank * mins.shape[0] // tp_degree : (tp_rank + 1) * mins.shape[0] // tp_degree] + return qdq_x.astype(paddle.float32), scales + + +def cal_abs_max_channel(inputs, quant_axis=1): + """ + channel-wise abs max calculation + Args: + inputs (`numpy.array`): + input tensor for quantization. + quant_axis (`int`): + dimension where calulating inputs' abs max scales on. + """ + epsilon = 1e-8 + reduce_axis = tuple([i for i in range(len(inputs.shape)) if i != quant_axis]) + abs_max_values = np.max(np.abs(inputs), axis=reduce_axis) + # maybe all elements are zero in one group, + # so set the scales from those group to an actual number + # from divide 0. + abs_max_values = np.where( + abs_max_values == np.array(0, dtype=inputs.dtype), np.array(epsilon, dtype=inputs.dtype), abs_max_values + ) + return abs_max_values + + +def qdq_weight(x, quant_bit=8, quant_axis=-1, scales=None, dequant=False, tp_rank=-1, tp_degree=1, use_pd=False): + """ + channel-wise symmetry quantization + Args: + x (`paddle.Tensor`): + The tensor to quantize. + quant_bits (`int`): + Quantization bits. + quant_axis (`int`): + Scales caculation axis. + scales (`paddle.Tensor`): + Abs max scales tensor in symmetry quantization. + dequant (`bool`): + True when dequantization, False in quantization. + tp_rank (`int`): + Model parallel rank. + tp_degree (`int`): + Model parallel world size. + use_pd (`bool`): + Whether to use paddle caculation. If False will use numpy. + """ + + if scales is None: + scales = cal_abs_max_channel(x) + bnt = (1 << (quant_bit - 1)) - 1 + if not dequant: + # quant + quant_x = np.clip(np.round(x / scales * bnt), -bnt - 1, bnt) + return quant_x.astype(np.int8), scales + else: + quant_x = x + # dequant + if not use_pd: + if len(scales.shape) == 0 or quant_x.shape[-1] == scales.shape[-1]: + # input tensor was row parallel in tp. + qdq_x = quant_x / bnt * scales + else: + # input tensor was column parallel in tp. + qdq_x = ( + quant_x + / bnt + * scales[tp_rank * scales.shape[0] // tp_degree : (tp_rank + 1) * scales.shape[0] // tp_degree] + ) + # fp32 , int8, int, fp32 or fp64 + return qdq_x.astype(np.float32), scales + else: + if len(scales.shape) == 0 or quant_x.shape[-1] == scales.shape[-1]: + # input tensor was row parallel in tp. + qdq_x = quant_x / bnt * scales.unsqueeze(0).expand(quant_x.shape) + else: + # input tensor was column parallel in tp. + qdq_x = ( + quant_x + / bnt + * scales[tp_rank * scales.shape[0] // tp_degree : (tp_rank + 1) * scales.shape[0] // tp_degree] + .unsqueeze(0) + .expand(quant_x.shape) + ) + # fp32 , int8, int, fp32 or fp64 + return qdq_x.astype(paddle.float32), scales diff --git a/paddlenlp/quantization/unified_checkpoint_quantization.py b/paddlenlp/quantization/unified_checkpoint_quantization.py new file mode 100644 index 000000000000..1f1c3ad0c8a1 --- /dev/null +++ b/paddlenlp/quantization/unified_checkpoint_quantization.py @@ -0,0 +1,209 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle.distributed import fleet + +from paddlenlp.quantization.checkpoint_quantization_utils import ( + asymmetry_qdq_weight, + cal_ratio, + group_wise_quant_dequant, + merge_int4, + qdq_weight, + split_int8, +) +from paddlenlp.utils.env import ( + ASYMMETRY_QUANT_SCALE_MAX, + ASYMMETRY_QUANT_SCALE_MIN, + MOMENT1_KEYNAME, + MOMENT2_KEYNAME, + SYMMETRY_QUANT_SCALE, +) +from paddlenlp.utils.log import logger + + +def dequant_unified_optimizer(state_dict, ckpt_quant_stage, scale_dict): + """ + dequantize unified optimizer state dict. + Args: + state_dict (`dict`): + unified checkpoint optimizer state dict. + ckpt_quant_stage (`str`): + checkpoint quantization stage, chosen in ["O0", "O1", "O2"]. + scale_dict (`int`): + compression checkpoint scale dict. + """ + tp_rank, tp_degree = -1, 1 + if paddle.distributed.get_world_size() > 1: + hcg = fleet.get_hybrid_communicate_group() + tp_group = hcg.get_model_parallel_group() + tp_rank, tp_degree = tp_group.rank, tp_group.nranks + + if ckpt_quant_stage == "O1": + # set eps + eps = 1e-8 + for quant_key in state_dict.keys(): + is_moment1 = MOMENT1_KEYNAME in quant_key + is_moment2 = MOMENT2_KEYNAME in quant_key + if is_moment1: + # dequant m1 + scale_key = quant_key + SYMMETRY_QUANT_SCALE + weight = state_dict[quant_key] + scales = scale_dict[scale_key] + weight, _ = qdq_weight( + weight, + scales=scales, + quant_bit=8, + dequant=True, + tp_rank=tp_rank, + tp_degree=tp_degree, + use_pd=True, + ) + state_dict[quant_key] = weight + elif is_moment2: + # dequant ratio + weight = state_dict[quant_key] + min_scale_key = quant_key + ASYMMETRY_QUANT_SCALE_MIN + max_scale_key = quant_key + ASYMMETRY_QUANT_SCALE_MAX + mins, maxs = scale_dict[min_scale_key], scale_dict[max_scale_key] + weight, _ = asymmetry_qdq_weight( + weight, + mins=mins, + maxs=maxs, + quant_bit=8, + dequant=True, + tp_rank=tp_rank, + tp_degree=tp_degree, + use_pd=True, + ) + # cal m2 + weight = paddle.square(1.0 / weight - eps) + state_dict[quant_key] = weight + elif ckpt_quant_stage == "O2": + # set eps + eps = 1e-8 + m1_state_dict = {} + for quant_key in state_dict.keys(): + # not all optimizer weights in O2 stage were quantized to int8, + # the norm-like weights were still remain in float32. + if state_dict[quant_key].dtype != paddle.int8: + logger.info(f"{quant_key} skip.") + continue + # split int8 + weight = state_dict[quant_key] + m1_quant, ratio_quant = split_int8(weight.numpy()) + # dequant ratio + ratio_min_scale_key = quant_key + ASYMMETRY_QUANT_SCALE_MIN + ratio_max_scale_key = quant_key + ASYMMETRY_QUANT_SCALE_MAX + m1_scale_key = quant_key[: -len(MOMENT2_KEYNAME)] + MOMENT1_KEYNAME + SYMMETRY_QUANT_SCALE + m1_scales = scale_dict[m1_scale_key] + ratio_mins, ratio_maxs = scale_dict[ratio_min_scale_key], scale_dict[ratio_max_scale_key] + m1_weight = group_wise_quant_dequant( + m1_quant, + mins=m1_scales, + maxs=None, + quant_bits=4, + quant=False, + tp_rank=tp_rank, + tp_degree=tp_degree, + use_pd=True, + symmetry=True, + ) + ratio_weight = group_wise_quant_dequant( + ratio_quant, + mins=ratio_mins, + maxs=ratio_maxs, + quant_bits=4, + quant=False, + tp_rank=tp_rank, + tp_degree=tp_degree, + use_pd=True, + ) + + ratio_weight = paddle.square(1.0 / ratio_weight - eps) + state_dict[quant_key] = ratio_weight + m1_state_dict[quant_key[: -len(MOMENT2_KEYNAME)] + MOMENT1_KEYNAME] = m1_weight + state_dict.update(m1_state_dict) + + return state_dict + + +def quant_unified_optimizer(state_dict, state_dict_type, ckpt_quant_stage, async_save=False): + """ + quantize unified optimizer state dict. + Args: + state_dict (`dict`): + unified checkpoint optimizer state dict. + state_dict_type (`str`): + state_dict type, chosen in ["model_weight", "master_weight", "optimizer_weight"]. + ckpt_quant_stage (`str`): + checkpoint quantization stage, chosen in ["O0", "O1", "O2"]. + async_save (`bool`): + whether use async_save. + """ + quant = False + if ckpt_quant_stage != "O0": + quant = True + del_key = [] + if quant and state_dict_type == "optimizer_weight": + scales_dict = {} + opt_keys = state_dict.keys() + for k in opt_keys: + momentum1 = k.endswith(MOMENT1_KEYNAME) + momentum2 = k.endswith(MOMENT2_KEYNAME) + + quant_weight = None + + if ckpt_quant_stage == "O1": + # m1: wint8, 1/(sqrt(m2)+eps): wint8 + if momentum2: + # m1: m1_quant_weight, m2: ratio + m1_key = k.split("/")[0] + "/" + MOMENT1_KEYNAME + ratio = cal_ratio(state_dict[m1_key], state_dict[k]) + m1_quant, scales = qdq_weight(state_dict[m1_key], quant_bit=8) + quant_weight, mins, maxs = asymmetry_qdq_weight(ratio, quant_bit=8) + state_dict[m1_key] = m1_quant + scales_dict[m1_key + SYMMETRY_QUANT_SCALE] = scales + scales_dict[k + ASYMMETRY_QUANT_SCALE_MIN] = mins + scales_dict[k + ASYMMETRY_QUANT_SCALE_MAX] = maxs + elif not momentum1: + quant_weight = state_dict[k] + elif ckpt_quant_stage == "O2": + # m1: bw-wint4, 1/(sqrt(m2)+eps): bw-wint4 + if momentum2: + # skip norm-like parameters + if len(state_dict[k].shape) < 2: + continue + # m1: m1_quant_weight, m2: ratio + m1_key = k.split("/")[0] + "/" + MOMENT1_KEYNAME + ratio = cal_ratio(state_dict[m1_key], state_dict[k]) + m1_quant, m1_scales = group_wise_quant_dequant(state_dict[m1_key], quant_bits=4, symmetry=True) + quant_weight, r_mins, r_maxs = group_wise_quant_dequant(ratio, quant_bits=4) + quant_weight = merge_int4(m1_quant, quant_weight) + scales_dict[m1_key + SYMMETRY_QUANT_SCALE] = m1_scales + scales_dict[k + ASYMMETRY_QUANT_SCALE_MIN] = r_mins + scales_dict[k + ASYMMETRY_QUANT_SCALE_MAX] = r_maxs + del_key.append(m1_key) + elif not momentum1: + quant_weight = state_dict[k] + + if quant_weight is not None: + state_dict[k] = quant_weight + + for k in del_key: + state_dict.pop(k, None) + + state_dict.update(scales_dict) + + return state_dict diff --git a/paddlenlp/trainer/auto_trainer.py b/paddlenlp/trainer/auto_trainer.py index 8b027168a21d..2ebaea1d5699 100644 --- a/paddlenlp/trainer/auto_trainer.py +++ b/paddlenlp/trainer/auto_trainer.py @@ -117,18 +117,25 @@ def _wrap_for_dist_loader(self, train_dataloader): def _wrap_for_auto(self, model, train_dataloader): logger.info("Wrapping model for auto paralle") dist_loader = self._wrap_for_dist_loader(train_dataloader) + sharding_parallel_mesh_dimension = self.args.sharding_parallel_mesh_dimension if ShardingOption.SHARD_OP in self.args.sharding: self.optimizer = dist.shard_optimizer( - self.optimizer, dist.ShardingStage1(), self.args.gradient_accumulation_steps + self.optimizer, + dist.ShardingStage1(sharding_mesh_dim=sharding_parallel_mesh_dimension), + self.args.gradient_accumulation_steps, ) elif ShardingOption.SHARD_GRAD_OP in self.args.sharding: self.optimizer = dist.shard_optimizer( - self.optimizer, dist.ShardingStage2(), self.args.gradient_accumulation_steps + self.optimizer, + dist.ShardingStage2(sharding_mesh_dim=sharding_parallel_mesh_dimension), + self.args.gradient_accumulation_steps, ) elif ShardingOption.FULL_SHARD in self.args.sharding: self.optimizer = dist.shard_optimizer( - self.optimizer, dist.ShardingStage3(), self.args.gradient_accumulation_steps + self.optimizer, + dist.ShardingStage3(sharding_mesh_dim=sharding_parallel_mesh_dimension), + self.args.gradient_accumulation_steps, ) else: self.optimizer = dist.shard_optimizer(self.optimizer, None, self.args.gradient_accumulation_steps) diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index 82926e945dce..76b5e7ce9ee0 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -81,7 +81,7 @@ default_data_collator, init_dataloader_comm_group, ) -from ..peft import LoRAModel, PrefixModelForCausalLM, VeRAModel +from ..peft import LoRAModel, PrefixModelForCausalLM, ReFTModel, VeRAModel try: from ..quantization.quantization_linear import QuantizationLinear @@ -418,6 +418,7 @@ def _save_ckpt_func(state_dict, path, signal_path=None): isinstance(self.model, LoRAModel) or isinstance(self.model, PrefixModelForCausalLM) or isinstance(self.model, VeRAModel) + or isinstance(self.model, ReFTModel) ): if self.args.unified_checkpoint and "skip_save_model_weight" in self.args.unified_checkpoint_config: self.args.unified_checkpoint_config.remove("skip_save_model_weight") @@ -563,6 +564,9 @@ def _load_from_peft_checkpoint(self, resume_from_checkpoint=None): convert_tp = True elif isinstance(self.model, VeRAModel): weights_file = os.path.join(resume_from_checkpoint, VERA_WEIGHTS_NAME) + elif isinstance(self.model, ReFTModel): + self.model.from_pretrained(resume_from_checkpoint, self.model.model) + return if self.args.dataset_rank == 0: logger.info(f"Loading model from {resume_from_checkpoint} .") @@ -621,6 +625,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None): isinstance(self.model, LoRAModel) or isinstance(self.model, PrefixModelForCausalLM) or isinstance(self.model, VeRAModel) + or isinstance(self.model, ReFTModel) ): self._load_from_peft_checkpoint(resume_from_checkpoint) self.runtime_timer.stop() @@ -1588,20 +1593,13 @@ def _get_eval_sampler(self, eval_dataset: Dataset): drop_last=False, ) else: - drop_last = False - if self.args.pipeline_parallel_degree > 1: - drop_last = True - logger.warning( - "In parallel mode, the batch_size is strictly checked. set DistributedBatchSampler drop_last=True." - ) - return DistributedBatchSampler( eval_dataset, num_replicas=self.args.dataset_world_size, rank=self.args.dataset_rank, batch_size=self.args.per_device_eval_batch_size, shuffle=False, - drop_last=drop_last, + drop_last=False, ) def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: @@ -2077,6 +2075,7 @@ def get_expected_keys(inputs, keys): if self.args.amp_master_grad: mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype) # return value has no use model = fleet.distributed_model(model) + if self.args.amp_master_grad: self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer) self.optimizer = fleet.distributed_optimizer(self.optimizer) @@ -2702,6 +2701,7 @@ def _save( "world_size": world_size, "ignore_save_lr_and_optim": self.args.ignore_save_lr_and_optim, "skip_save_model_weight": "skip_save_model_weight" in self.args.unified_checkpoint_config, + "remove_master_weight": "remove_master_weight" in self.args.unified_checkpoint_config, } if os.path.exists( os.path.join(self.args.output_signal_dir, "async_save_info.json") @@ -2736,6 +2736,7 @@ def _save( isinstance(self.model, LoRAModel) or isinstance(self.model, PrefixModelForCausalLM) or isinstance(self.model, VeRAModel) + or isinstance(self.model, ReFTModel) ): self.model.save_pretrained( output_dir, diff --git a/paddlenlp/trainer/trainer_utils.py b/paddlenlp/trainer/trainer_utils.py index 30488e960f14..0fc54d52f74d 100644 --- a/paddlenlp/trainer/trainer_utils.py +++ b/paddlenlp/trainer/trainer_utils.py @@ -240,13 +240,15 @@ class TrainOutput(NamedTuple): _re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$") -def _check_checkpoint_files(folder_path, world_size, ignore_save_lr_and_optim, skip_save_model_weight): +def _check_checkpoint_files( + folder_path, world_size, ignore_save_lr_and_optim, skip_save_model_weight, remove_master_weight +): files = os.listdir(folder_path) model_weight_files = [f for f in files if f.startswith(".model_weight")] a = len(model_weight_files) == world_size if not ignore_save_lr_and_optim: b = True - if not skip_save_model_weight: + if not skip_save_model_weight or not remove_master_weight: master_weight_file = [f for f in files if f.startswith(".master_weight")] b = len(master_weight_file) == world_size optimizer_file = [f for f in files if f.startswith(".optimizer_weight")] @@ -282,8 +284,13 @@ def get_last_checkpoint(folder, signal_folder=None, uc_async_save=False): pre_world_size = saving_info.get("world_size", 1) ignore_save_lr_and_optim = saving_info.get("ignore_save_lr_and_optim", False) skip_save_model_weight = saving_info.get("skip_save_model_weight", False) + remove_master_weight = saving_info.get("remove_master_weight", False) if _check_checkpoint_files( - current_signal_path, pre_world_size, ignore_save_lr_and_optim, skip_save_model_weight + current_signal_path, + pre_world_size, + ignore_save_lr_and_optim, + skip_save_model_weight, + remove_master_weight, ): return current_path return diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index 25cf62309983..f6ca63065947 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -227,6 +227,9 @@ class TrainingArguments: Sharding parameter in certain cards group. For example, aussume we use 2 machines each with 8 cards, then set sharding_parallel_degree=8, sharding will only communication inside machine. default -1 means sharding parameters between all workers. + sharding_parallel_mesh_dimension (`str`, *optional*, defaults to `dp`) + Specifies the name of the dimension in a multi-dimensional parallelism mesh that is responsible for sharding. + default `dp` for default parallelism mesh. tensor_parallel_degree (`int`, *optional*, defaults to `-1`) Tensor parallelism is parallel technique proposed in (https://arxiv.org/pdf/2104.04473.pdf see 2.3 Tensor Model Parallelism). This technique splits one transformer layer into multi-cards (For examples, tensor_parallel_degree=4, will split a layer to 4-parts) @@ -562,6 +565,15 @@ class TrainingArguments: ) }, ) + sharding_parallel_mesh_dimension: str = field( + default="dp", + metadata={ + "help": ( + "Specifies the name of the dimension in a multi-dimensional parallelism mesh that is responsible for sharding. " + "default `dp` for default parallelism mesh. " + ) + }, + ) sharding_comm_buffer_size_MB: int = field( default=-1, metadata={ @@ -858,11 +870,16 @@ class TrainingArguments: "- skip_save_model_weight: do not save model weights when the masters weight exist\n" "- master_weight_compatible: 1. if the master weights exist, only load when needed\n" " 2. if master weights does not exist, convert model weights to master weights when needed\n" + "- remove_master_weight: same with `master_weight_compatible`, use in checkpoint quantization.\n" "- async_save: enable asynchronous saving checkpoints to disk\n" "- enable_all_options: enable all optimization configurations\n" ) }, ) + ckpt_quant_stage: str = field( + default="O0", + metadata={"help": "checkpoint quantization stage."}, + ) ignore_load_lr_and_optim: Optional[bool] = field( default=False, metadata={"help": "whether to ignore load optimizer and scheduler."}, @@ -883,6 +900,14 @@ class TrainingArguments: default=False, metadata={"help": "Enable MoE (Mixture of Experts) expert parallel training"}, ) + expert_max_capacity: Optional[int] = field( + default=pow(2, 32), + metadata={"help": "Enable MoE (Mixture of Experts) expert max token capacity"}, + ) + expert_min_capacity: Optional[int] = field( + default=1, + metadata={"help": "Enable MoE (Mixture of Experts) expert min token capacity"}, + ) release_grads: Optional[bool] = field( default=False, metadata={"help": "Whether to release gradients during training. Default is `False`."} ) @@ -1560,6 +1585,8 @@ def is_segment_parallel_supported(): sharding.stage = 2 elif ShardingOption.FULL_SHARD in self.sharding: sharding.stage = 3 + if self.sharding_comm_buffer_size_MB > 0: + sharding.comm_buffer_size_MB = int(self.sharding_comm_buffer_size_MB) sharding_parallel_config = split_parallel_config(self.sharding_parallel_config) for x in sharding_parallel_config: @@ -1568,6 +1595,7 @@ def is_segment_parallel_supported(): "enable_stage1_tensor_fusion", "enable_stage1_overlap", "enable_stage2_overlap", + "enable_release_grads", ]: raise ValueError( f"Found unknown pipeline mode config {x}, " f"accpet config is reduce_overlap." @@ -1582,6 +1610,9 @@ def is_segment_parallel_supported(): if "enable_stage1_tensor_fusion" in sharding_parallel_config: sharding.grad_bucket_size_numel = 210355872 + if "enable_release_grads" in sharding_parallel_config: + sharding.release_gradients = True + if self.bf16 or self.fp16: amp = strategy.amp amp.enable = True @@ -1660,6 +1691,7 @@ def is_segment_parallel_supported(): if x not in [ "skip_save_model_weight", "master_weight_compatible", + "remove_master_weight", "async_save", "enable_all_options", "ignore_merge_optimizer", diff --git a/paddlenlp/trainer/unified_checkpoint/async_handler.py b/paddlenlp/trainer/unified_checkpoint/async_handler.py index 942ea41508bf..ffe098808c2f 100644 --- a/paddlenlp/trainer/unified_checkpoint/async_handler.py +++ b/paddlenlp/trainer/unified_checkpoint/async_handler.py @@ -27,6 +27,10 @@ if is_safetensors_available(): from safetensors.numpy import save_file as safe_save_file +from paddlenlp.quantization.unified_checkpoint_quantization import ( + quant_unified_optimizer, +) + from .shared_memory_utils import ( _read_state_dict_from_shm, _traverse_copy_to_shm, @@ -69,12 +73,14 @@ def __init__(self, args): self._shared_save_optimizer_flag = multiprocessing.Array("i", 1) def _file_save_async_or_sync( - self, state_dict, path, signal_path=None, is_sync=True, state_dict_type="model_weight" + self, state_dict, path, signal_path=None, is_sync=True, state_dict_type="model_weight", ckpt_quant_stage="O0" ): if is_sync: for k in list(state_dict.keys()): if isinstance(state_dict[k], paddle.Tensor): state_dict[k] = state_dict.pop(k).cpu().numpy() + + state_dict = quant_unified_optimizer(state_dict, state_dict_type, ckpt_quant_stage) safe_save_file(state_dict, path, metadata={"format": "np"}) else: if len(state_dict.keys()) == 0: @@ -155,6 +161,7 @@ def _file_save_async_or_sync( self._lock, state_dict_type, self.global_rank, + ckpt_quant_stage, ), ) self._process_optimizer_weight.start() @@ -185,6 +192,7 @@ def _save_file_async_in_process( lock, state_dict_type, global_rank, + ckpt_quant_stage="O0", ): shm = shared_memory.SharedMemory(name=shm_name) while True: @@ -198,6 +206,9 @@ def _save_file_async_in_process( signal_path = shared_save_signal_path[:].decode("utf-8").rstrip("\x00") logger.info(f"Start to async save {path}") state_dict = _read_state_dict_from_shm(meta_dict, shm) # numpy array + state_dict = quant_unified_optimizer( + state_dict, state_dict_type, ckpt_quant_stage, async_save=True + ) # ckpt quantization safe_save_file(state_dict, path, {"format": "np"}) del state_dict saved_signal_path = os.path.join(signal_path, f".{state_dict_type}.done.{global_rank}") diff --git a/paddlenlp/trainer/unified_checkpoint/load_local.py b/paddlenlp/trainer/unified_checkpoint/load_local.py index 459eff7185d1..d1565c7dd933 100644 --- a/paddlenlp/trainer/unified_checkpoint/load_local.py +++ b/paddlenlp/trainer/unified_checkpoint/load_local.py @@ -14,6 +14,7 @@ """Unfied checkpoint locally loading functions.""" import gc +import json import os import paddle @@ -183,6 +184,13 @@ def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoin if len(resolved_archive_file) > 1: resolved_archive_file = tqdm(resolved_archive_file, desc="Loading optimizer shards") + with open(os.path.join(resume_from_checkpoint, index_filename), "r") as f: + index = json.loads(f.read()) + + ckpt_quant_stage = "O0" + if "ckpt_quant_stage" in index: + ckpt_quant_stage = index["ckpt_quant_stage"] + # update has_master_weights and index_filename_master_weights # 1. if the master weight exists, only has_master_weights is set True and loaded when needed # 2. if master weight does not exist, convert model weight to master weight when needed @@ -204,7 +212,9 @@ def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoin if len(resolved_archive_file_mw) > 1: resolved_archive_file_mw = tqdm(resolved_archive_file_mw, desc="Loading master weights shards") - def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys, is_master_weights=False): + def load_resolved_archive_file( + resolved_archive_file, sharded_metadata, expected_keys, is_master_weights=False, ckpt_quant_stage="O0" + ): returned_state_dict = {} # load optimizer for shard_file in resolved_archive_file: @@ -227,10 +237,22 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected tp_actions = mapping_optimizer_tp_actions(tp_actions, expected_keys) # Here we use expected_keys to optimize weights loading for pipeline model. Only works for safetensors - state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="expected") + state_dict = load_state_dict( + shard_file, + tp_actions, + expected_keys, + device="expected", + ckpt_quant_stage=ckpt_quant_stage, + ) else: # for pipeline model, we don't need to use tp_actions - state_dict = load_state_dict(shard_file, None, expected_keys, device="expected") + state_dict = load_state_dict( + shard_file, + None, + expected_keys, + device="expected", + ckpt_quant_stage=ckpt_quant_stage, + ) returned_state_dict.update(state_dict) # force memory release @@ -238,7 +260,9 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected gc.collect() return returned_state_dict - state_dict_optim = load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys) + state_dict_optim = load_resolved_archive_file( + resolved_archive_file, sharded_metadata, expected_keys, ckpt_quant_stage=ckpt_quant_stage + ) if has_master_weights: state_dict_master_weight = load_resolved_archive_file( resolved_archive_file_mw, sharded_metadata_mw, expected_keys_mw, is_master_weights=True @@ -246,9 +270,10 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected # rename optimizer param for key in list(state_dict_optim.keys()): key_name = key.split("/") - static_name = struct2static_name_mappings[key_name[0]] + model_weight_key = key_name[0] + static_name = struct2static_name_mappings[model_weight_key] if has_master_weights: - if model_state_dict[key_name[0]].dtype != paddle.float32: + if model_state_dict[model_weight_key].dtype != paddle.float32: key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) else: key_name = "_".join([static_name, key_name[1]]) @@ -257,6 +282,12 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected returned_optim_state_dict[key_name] = state_dict_optim.pop(key) returned_optim_state_dict[key_name].name = key_name + # master weight cast (only in remove_master_weight) + if has_master_weights and state_dict_master_weight[model_weight_key].dtype != paddle.float32: + state_dict_master_weight[model_weight_key] = paddle.cast( + state_dict_master_weight[model_weight_key], dtype=paddle.float32 + ) + if has_master_weights: for key in list(state_dict_master_weight.keys()): static_name = struct2static_name_mappings[key] diff --git a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py index 0190529a84e3..72543e038e6a 100644 --- a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py +++ b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py @@ -164,6 +164,7 @@ def save_unified_checkpoint(self, model, optimizer, output_dir, signal_dir=None) "world_size": world_size, "ignore_save_lr_and_optim": self.args.ignore_save_lr_and_optim, "skip_save_model_weight": "skip_save_model_weight" in self.args.unified_checkpoint_config, + "remove_master_weight": "remove_master_weight" in self.args.unified_checkpoint_config, } paddle.save(save_info, os.path.join(save_directory, ".saving_info")) @@ -210,6 +211,7 @@ def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, outp static_name, type_name = generate_base_static_name(key) new_name = static2struct_name_mappings[static_name] + "/" + type_name optim_state_dict[new_name] = optim_state_dict.pop(key) + if master_weights is not None: for key in list(master_weights.keys()): master_weights[static2struct_name_mappings[key]] = master_weights.pop(key) @@ -237,6 +239,15 @@ def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, outp optimizer_name = _add_variant(SAFE_OPTIMIZER_NAME, self.args.optimizer_name_suffix) master_weights_name = _add_variant(SAFE_MASTER_WEIGHTS_NAME, self.args.optimizer_name_suffix) + # save opt index json if checkpoint quantization is on. + if self.args.ckpt_quant_stage != "O0": + sharded_optim_index = {"ckpt_quant_stage": self.args.ckpt_quant_stage} + optimizer_index_name = SAFE_OPTIMIZER_INDEX_NAME + path = os.path.join(output_dir, optimizer_index_name) + if self.args.should_save: + with open(path, "w") as f: + json.dump(sharded_optim_index, f, indent=4) + is_sync_save = True if "async_save" in self.args.unified_checkpoint_config: is_sync_save = False @@ -246,16 +257,18 @@ def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, outp signal_path=signal_dir, is_sync=is_sync_save, state_dict_type="optimizer_weight", + ckpt_quant_stage=self.args.ckpt_quant_stage, ) - self.async_handler._file_save_async_or_sync( - master_weights, - path=os.path.join(output_dir, master_weights_name), - signal_path=signal_dir, - is_sync=is_sync_save, - state_dict_type="master_weight", - ) + if master_weights is not None: + self.async_handler._file_save_async_or_sync( + master_weights, + path=os.path.join(output_dir, master_weights_name), + signal_path=signal_dir, + is_sync=is_sync_save, + state_dict_type="master_weight", + ) - def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint): + def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint, ckpt_quant_stage="O0"): # init and get optimizer LR_Scheduler returned_optim_state_dict = nested_copy(optimizer.state_dict()) @@ -263,19 +276,25 @@ def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint): master_weights_name = _add_variant(SAFE_MASTER_WEIGHTS_NAME, self.args.optimizer_name_suffix) optimizer_path = os.path.join(resume_from_checkpoint, optimizer_name) master_weights_path = os.path.join(resume_from_checkpoint, master_weights_name) - has_master_weights = True if os.path.isfile(master_weights_path) else False + # no quantization & no master weight represent O1 AMP strategy. + is_amp_o1 = True if not os.path.isfile(master_weights_path) and ckpt_quant_stage == "O0" else False model_state_dict = get_expected_state_dict(model) struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()} # get optimizer param mappings - optimizer_state_dict = load_state_dict(optimizer_path, None, None, device="expected") - if has_master_weights: + optimizer_state_dict = load_state_dict( + optimizer_path, None, None, device="expected", ckpt_quant_stage=ckpt_quant_stage + ) + master_weights = {} + # normal AMP O2 + if not is_amp_o1 and os.path.isfile(master_weights_path): master_weights = load_state_dict(master_weights_path, None, None, device="expected") # rename and move to paddle.Tensor for key in list(optimizer_state_dict.keys()): key_name = key.split("/") + model_weight_key = key_name[0] static_name = struct2static_name_mappings[key_name[0]] - if has_master_weights: + if not is_amp_o1: if model_state_dict[key_name[0]].dtype != paddle.float32: key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) else: @@ -285,7 +304,13 @@ def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint): returned_optim_state_dict[key_name] = optimizer_state_dict.pop(key) returned_optim_state_dict[key_name].name = key_name - if has_master_weights: + # master weight cast (only in AMP O2 + remove_master_weight) + if not is_amp_o1 and not os.path.isfile(master_weights_path): + master_weights[model_weight_key] = paddle.cast( + model_state_dict[model_weight_key], dtype=paddle.float32 + ) + + if not is_amp_o1: returned_optim_state_dict["master_weights"] = {} for key in list(master_weights.keys()): static_name = struct2static_name_mappings[key] @@ -320,6 +345,10 @@ def save_unified_optimizer(self, model, optimizer, output_dir, signal_dir): if "LR_Scheduler" in optim_state_dict.keys(): optim_state_dict.pop("LR_Scheduler") + if UnifiedCheckpointOption.REMOVE_MASTER_WEIGHT.value in self.args.unified_checkpoint_config: + logger.info("Skip master weight saving.") + master_weights = None + if "ignore_merge_optimizer" in self.args.unified_checkpoint_config: self.save_non_merge_optimizer(model, optim_state_dict, master_weights, output_dir, signal_dir) return @@ -350,6 +379,7 @@ def save_unified_optimizer(self, model, optimizer, output_dir, signal_dir): signal_path=signal_dir, is_sync=is_sync_save, state_dict_type="optimizer_weight", + ckpt_quant_stage=self.args.ckpt_quant_stage, ) if master_weight_state_dict is not None: self.async_handler._file_save_async_or_sync( @@ -391,16 +421,26 @@ def load_unified_optimizer(self, model, optimizer, resume_from_checkpoint): optim_state_dict = load_single_card_optimizer(model, optimizer, resume_from_checkpoint) return optim_state_dict + index = {} has_merge_optimizer_safetensors = distributed_isfile( os.path.join(resume_from_checkpoint, SAFE_OPTIMIZER_INDEX_NAME) ) + if has_merge_optimizer_safetensors: + with open(os.path.join(resume_from_checkpoint, SAFE_OPTIMIZER_INDEX_NAME), "r") as f: + index = json.loads(f.read()) + + ckpt_quant_stage = "O0" + if "ckpt_quant_stage" in index: + ckpt_quant_stage = index["ckpt_quant_stage"] + # If not having merge optimizer, then load non-merge optimizer. - if not has_merge_optimizer_safetensors: + if "weight_map" not in index: if self.args.data_parallel_rank == 0 or self.args.use_expert_parallel: returned_optim_state_dict = self.load_non_merge_optimizer( model, optimizer, resume_from_checkpoint, + ckpt_quant_stage=ckpt_quant_stage, ) return returned_optim_state_dict else: @@ -445,7 +485,7 @@ def unified_checkpoint_into_shards( assert hasattr(model_to_save, "config") state_dict = get_expected_state_dict(model_to_save) - all_filter_keys = filter_params(model_to_save, state_dict) + all_filter_keys = filter_params(model_to_save, state_dict, args) config_to_save = copy.deepcopy(model_to_save.config) @@ -534,6 +574,7 @@ def unified_optimizer_into_shards( static_name, type_name = generate_base_static_name(key) new_name = static2struct_name_mappings[static_name] + "/" + type_name optim_state_dict[new_name] = optim_state_dict.pop(key) + if master_weights is not None: for key in list(master_weights.keys()): master_weights[static2struct_name_mappings[key]] = master_weights.pop(key) @@ -541,8 +582,8 @@ def unified_optimizer_into_shards( # filter optimizer param if master_weights is not None: - filter_master_keys = filter_params(model, master_weights, is_optimizer=True) - filter_optim_keys = filter_params(model, optim_state_dict, is_optimizer=True) + filter_master_keys = filter_params(model, master_weights, args, is_optimizer=True) + filter_optim_keys = filter_params(model, optim_state_dict, args, is_optimizer=True) tp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group() tp_size = tp_group.nranks @@ -605,6 +646,10 @@ def unified_optimizer_into_shards( use_expert_parallel=args.use_expert_parallel, ) sharded_optim_index = get_sharded_index(index_optimizer_filelist, total_optim_size_list) + + if args.should_save and args.ckpt_quant_stage in ["O1", "O2"]: + sharded_optim_index["ckpt_quant_stage"] = args.ckpt_quant_stage + if master_weights is not None: index_master_weight_filelist, total_master_weight_size_list = gather_sharded_object( index_master_weight_file, diff --git a/paddlenlp/trainer/unified_checkpoint/utils.py b/paddlenlp/trainer/unified_checkpoint/utils.py index 9bd9fdcc65b7..58e425ca987d 100644 --- a/paddlenlp/trainer/unified_checkpoint/utils.py +++ b/paddlenlp/trainer/unified_checkpoint/utils.py @@ -32,6 +32,10 @@ from paddlenlp.transformers.utils import dtype_byte_size from paddlenlp.utils.distributed import distributed_allgather, distributed_gather from paddlenlp.utils.env import ( + BETA1_KEYNAME, + BETA2_KEYNAME, + MOMENT1_KEYNAME, + MOMENT2_KEYNAME, PADDLE_MASTER_WEIGHTS_INDEX_NAME, PADDLE_PEFT_WEIGHTS_INDEX_NAME, PADDLE_WEIGHTS_INDEX_NAME, @@ -72,6 +76,7 @@ class UnifiedCheckpointOption(ExplicitEnum): SKIP_SAVE_MODEL_WEIGHT = "skip_save_model_weight" MASTER_WEIGHT_COMPATIBLE = "master_weight_compatible" + REMOVE_MASTER_WEIGHT = "remove_master_weight" ASYNC_SAVE = "async_save" IGNORE_MERGE_OPTIMIZER = "ignore_merge_optimizer" @@ -96,7 +101,10 @@ def is_need_master_weight(optimizer, is_fp16_or_bp16): def update_master_weight_status(args, optimizer, has_master_weight, safe_serialization): if is_need_master_weight(optimizer, is_fp16_or_bp16=(args.fp16 or args.bf16)): if not has_master_weight: - if UnifiedCheckpointOption.MASTER_WEIGHT_COMPATIBLE.value in args.unified_checkpoint_config: + if ( + UnifiedCheckpointOption.REMOVE_MASTER_WEIGHT.value in args.unified_checkpoint_config + or UnifiedCheckpointOption.MASTER_WEIGHT_COMPATIBLE.value in args.unified_checkpoint_config + ): index_filename_master_weights = ( PADDLE_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_WEIGHTS_INDEX_NAME ) @@ -108,7 +116,8 @@ def update_master_weight_status(args, optimizer, has_master_weight, safe_seriali else: raise ValueError( "Can't find a valid unified master weight checkpoint," - f"add '{UnifiedCheckpointOption.MASTER_WEIGHT_COMPATIBLE.value}' into 'unified_checkpoint_config' to " + f"add '{UnifiedCheckpointOption.MASTER_WEIGHT_COMPATIBLE.value}'" + f" or '{UnifiedCheckpointOption.REMOVE_MASTER_WEIGHT.value}' into 'unified_checkpoint_config' to " "load model checkpoint as master weight" ) else: @@ -463,7 +472,7 @@ def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys, return state_dict_to_save -def filter_params(model_to_save, state_dict, is_optimizer=False): +def filter_params(model_to_save, state_dict, args, is_optimizer=False): """ Group according to the size of the tensor, aiming to make the weight size stored on each device as equal as possible. @@ -479,16 +488,34 @@ def filter_params(model_to_save, state_dict, is_optimizer=False): return [list(state_dict.keys())] filter_tensor_list = [[] for _ in range(tp_size)] + is_master_weights = False if tp_rank == 0: + quant = False + if args.ckpt_quant_stage != "O0": + quant = True tensor_bytes_dict = {} model_state_dict = get_expected_state_dict(model_to_save) for (k, v) in state_dict.items(): - model_v = model_state_dict[k.split("/")[0]] if is_optimizer else v - if hasattr(model_v, "is_distributed") and model_v.is_distributed: - tensor_bytes_dict[k] = v.numel().item() * tp_size * dtype_byte_size(v.dtype) + # master weight has same key as model weight + if not is_master_weights and k in model_state_dict: + is_master_weights = True + + weight_key = k.split("/")[0] + model_v = model_state_dict[weight_key] if is_optimizer else v + if not quant or not is_optimizer: + if hasattr(model_v, "is_distributed") and model_v.is_distributed: + tensor_bytes_dict[k] = v.numel().item() * tp_size * dtype_byte_size(v.dtype) + else: + tensor_bytes_dict[k] = v.numel().item() * dtype_byte_size(v.dtype) else: - tensor_bytes_dict[k] = v.numel().item() * dtype_byte_size(v.dtype) + if weight_key not in tensor_bytes_dict: + tensor_bytes_dict[weight_key] = 0 + + if hasattr(model_v, "is_distributed") and model_v.is_distributed: + tensor_bytes_dict[weight_key] += v.numel().item() * tp_size * dtype_byte_size(v.dtype) + else: + tensor_bytes_dict[weight_key] += v.numel().item() * dtype_byte_size(v.dtype) filter_tensor_list = [] current_block = [] @@ -509,7 +536,14 @@ def filter_params(model_to_save, state_dict, is_optimizer=False): current_block = [] current_block_size = 0 - current_block.append(key) + if not quant or not is_optimizer or is_master_weights: + current_block.append(key) + else: + current_block.append(key + "/" + MOMENT1_KEYNAME) + current_block.append(key + "/" + MOMENT2_KEYNAME) + current_block.append(key + "/" + BETA1_KEYNAME) + current_block.append(key + "/" + BETA2_KEYNAME) + current_block_size += weight_size total_size += weight_size diff --git a/paddlenlp/transformers/__init__.py b/paddlenlp/transformers/__init__.py index ab7510e0897e..32de7553992f 100644 --- a/paddlenlp/transformers/__init__.py +++ b/paddlenlp/transformers/__init__.py @@ -32,6 +32,8 @@ from .attention_utils import create_bigbird_rand_mask_idx_list from .sequence_parallel_utils import AllGatherVarlenOp, sequence_parallel_sparse_mask_labels from .tensor_parallel_utils import parallel_matmul, parallel_linear, fused_head_and_loss_fn +from .moe_gate import * +from .moe_layer import * try: from paddle.distributed.fleet.utils.sequence_parallel_utils import ( diff --git a/paddlenlp/transformers/auto/modeling.py b/paddlenlp/transformers/auto/modeling.py index c080e44793e8..3649a09199b9 100644 --- a/paddlenlp/transformers/auto/modeling.py +++ b/paddlenlp/transformers/auto/modeling.py @@ -799,8 +799,6 @@ class AutoInferenceModelForCausalLM(_BaseAutoModelClass): AutoInferenceModelForCausalLM. """ - _name_mapping = get_name_mapping("ForCausalLM") - @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): """ @@ -832,13 +830,16 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): ) else: # Check whether the model use block attention - attn_type = "Block" if predictor_args.block_attn else "" + if predictor_args.block_attn and predictor_args.speculate_method is None: + attn_type = "Block" + elif predictor_args.speculate_method is not None: + attn_type = "Speculate" + else: + attn_type = "" model_name = f"{config.architectures[0]}{attn_type}" # Import the InferenceModel - import_class = importlib.import_module( - f"paddlenlp.experimental.transformers.{cls._name_mapping[config.architectures[0]]}.modeling" - ) + import_class = importlib.import_module(f"paddlenlp.experimental.transformers.{config.model_type}.modeling") model_class_name = f"{model_name}InferenceModel" model_class = getattr(import_class, model_class_name) diff --git a/paddlenlp/transformers/auto/tokenizer.py b/paddlenlp/transformers/auto/tokenizer.py index 6fd8b5fcf2b0..315f40341aad 100644 --- a/paddlenlp/transformers/auto/tokenizer.py +++ b/paddlenlp/transformers/auto/tokenizer.py @@ -178,7 +178,12 @@ def tokenizer_class_from_name(class_name: str): return getattr(module, class_name) except AttributeError: - raise ValueError(f"Tokenizer class {class_name} is not currently imported.") + try: + module = importlib.import_module(f".{module_name}.tokenizer_fast", "paddlenlp.transformers") + + return getattr(module, class_name) + except AttributeError: + raise ValueError(f"Tokenizer class {class_name} is not currently imported.") for config, tokenizers in TOKENIZER_MAPPING._extra_content.items(): for tokenizer in tokenizers: diff --git a/paddlenlp/transformers/chatglm_v2/modeling.py b/paddlenlp/transformers/chatglm_v2/modeling.py index 2d2af39c2b30..d8eb469e119b 100644 --- a/paddlenlp/transformers/chatglm_v2/modeling.py +++ b/paddlenlp/transformers/chatglm_v2/modeling.py @@ -98,7 +98,7 @@ class RotaryEmbedding(nn.Layer): def __init__(self, dim, original_impl=False): super().__init__() self.default_dtype = paddle.get_default_dtype() - inv_freq = 1.0 / (10000 ** (paddle.arange(0, dim, 2, dtype="float32") / dim)) + inv_freq = 1.0 / (10000 ** (paddle.arange(0, dim, 2, dtype=self.default_dtype) / dim)) self.register_buffer("inv_freq", inv_freq) self.dim = dim self.original_impl = original_impl @@ -113,16 +113,16 @@ def forward_impl(self, seq_len: int, n_elem: int, base: int = 10000): theta = 1.0 / (base ** (paddle.arange(0, n_elem, 2, dtype="float32") / n_elem)) # Create position indexes `[0, 1, ..., seq_len - 1]` - seq_idx = paddle.arange(0, seq_len, dtype=theta.dtype) + seq_idx = paddle.arange(0, seq_len, dtype="float32") # Calculate the product of position index and $\theta_i$ - idx_theta = paddle.outer(seq_idx, theta).astype(self.default_dtype) + idx_theta = paddle.outer(seq_idx, theta).astype("float32") cache = paddle.stack([paddle.cos(idx_theta), paddle.sin(idx_theta)], axis=-1) # this is to mimic the behaviour of complex32, else we will get different results - if self.default_dtype in (paddle.float16, paddle.bfloat16, paddle.int8): - cache = cache.astype(self.default_dtype) + if self.default_dtype in ("float16", "bfloat16", "int8"): + cache = cache.astype("bfloat16") if self.default_dtype == "bfloat16" else cache.astype("float16") # cache = cache.bfloat16() if dtype == paddle.bfloat16 else cache.astype("float16") return cache diff --git a/paddlenlp/transformers/gemma/tokenizer.py b/paddlenlp/transformers/gemma/tokenizer.py index 54a6413d4f2a..4cee34e6187c 100644 --- a/paddlenlp/transformers/gemma/tokenizer.py +++ b/paddlenlp/transformers/gemma/tokenizer.py @@ -14,6 +14,7 @@ # limitations under the License. import os +import re from shutil import copyfile from typing import Any, Dict, List, Optional, Tuple @@ -310,3 +311,33 @@ def create_token_type_ids_from_sequences( output += [1] * len(bos_token_id + token_ids_1 + eos_token_id) return output + + def _extract_non_learnable_parts(self, origin_msg: List[Dict[str, str]], split_s: List[str]): + regex_pattern = "|".join(map(re.escape, split_s)) + rendered_messages = self.chat_template.render( + messages=origin_msg, add_generation_prompt=False, **self.special_tokens_map + ) + pattern = re.compile(r"(?:%s)" % regex_pattern) + split_positions = [match.span() for match in pattern.finditer(rendered_messages)] + + filtered_positions = [] + for start, end in split_positions: + # Find the last occurrence of '' before the split index + last_start = rendered_messages.rfind("", 0, start) + if last_start == -1: + continue # Skip if '' is not found + model_start = last_start + len("") + + # Get the text following 'model_start' and check if it starts with 'model' + following_text = rendered_messages[model_start:].lstrip() + if following_text.startswith("model"): + filtered_positions.append((start, end)) + non_learnable_parts = [] + last_end = 0 + for start, end in filtered_positions: + non_learnable_parts.append(rendered_messages[last_end:start]) + last_end = end + remaining_part = rendered_messages[last_end:] + if remaining_part: + non_learnable_parts.append(remaining_part) + return non_learnable_parts diff --git a/paddlenlp/transformers/gpt/modeling_auto.py b/paddlenlp/transformers/gpt/modeling_auto.py index 719d4ca4a37b..8583d22f0728 100644 --- a/paddlenlp/transformers/gpt/modeling_auto.py +++ b/paddlenlp/transformers/gpt/modeling_auto.py @@ -30,10 +30,10 @@ from paddle.distributed import fleet from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker from paddle.distributed.fleet.utils import recompute +from paddle.utils import try_import try: from paddle.distributed.fleet.utils.sequence_parallel_utils import ( - ScatterOp, mark_as_sequence_parallel_parameter, ) except: @@ -41,7 +41,10 @@ from ...utils.converter import StateDictNameMapping from .. import PretrainedModel, register_base_model -from ..model_outputs import BaseModelOutputWithPastAndCrossAttentions +from ..model_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, +) from .configuration import GPT_PRETRAINED_INIT_CONFIGURATION, GPTConfig try: @@ -61,6 +64,7 @@ "GPTForCausalLMAuto", "GPTEmbeddingsAuto", "GPTDecoderLayerAuto", + "GPTLayerNorm", ] @@ -92,6 +96,29 @@ def seed_guard_context(name=None): return contextlib.nullcontext() +def fast_layer_norm(input, weight, bias, eps): + fast_ln_lib = try_import("fast_ln") + return fast_ln_lib.fast_ln(input, weight, bias, eps)[0] + + +class GPTLayerNorm(nn.LayerNorm): + def __init__(self, config, normalized_shape, epsilon=1e-05, weight_attr=None, bias_attr=None, name=None): + super().__init__( + normalized_shape=normalized_shape, epsilon=epsilon, weight_attr=weight_attr, bias_attr=bias_attr + ) + self.config = config + self._check_normalized_shape(self._normalized_shape) + + def _check_normalized_shape(self, normalized_shape): + if isinstance(normalized_shape, (list, tuple)): + assert len(normalized_shape) == 1 + + def forward(self, input): + if self.config.use_fast_layer_norm: + return fast_layer_norm(input, self.weight, self.bias, self._epsilon) + return super().forward(input) + + def _make_causal_mask(input_ids_shape, past_key_values_length): """ Make causal mask used for self-attention @@ -152,6 +179,12 @@ def __init__(self, config, ipp=None): if self.config.fuse_attention_qkv: self.qkv_proj = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias_attr=True) + self.qkv_proj.weight = dist.shard_tensor( + self.qkv_proj.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(1)] + ) + self.qkv_proj.bias = dist.shard_tensor( + self.qkv_proj.bias, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(0)] + ) else: self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias_attr=True) self.k_proj = nn.Linear(config.hidden_size, config.hidden_size, bias_attr=True) @@ -261,8 +294,8 @@ def _core_attention(self, q, k, v, attention_mask=None, output_attentions=False) # softmax_mask_fuse_upper_triangle is not supported sif paddle is not compiled with cuda/rocm if not paddle.is_compiled_with_cuda(): attention_mask = get_triangle_upper_mask(product, attention_mask) - if attention_mask is not None: + attention_mask = dist.reshard(attention_mask, get_mesh(self.ipp), [dist.Replicate(), dist.Replicate()]) product = product + attention_mask.astype(product.dtype) weights = F.softmax(product) else: @@ -351,8 +384,8 @@ def __init__(self, config, decoder_layers, norm=None, hidden_size=None): self.config = config self.layers = decoder_layers - self.norm = nn.LayerNorm(config.hidden_size, epsilon=1e-5, bias_attr=True) + self.norm = GPTLayerNorm(config, config.hidden_size, epsilon=1e-5) if config.sequence_parallel: mark_as_sequence_parallel_parameter(self.norm.weight) mark_as_sequence_parallel_parameter(self.norm.bias) @@ -418,8 +451,11 @@ def forward( for i, decoder_layer in enumerate(self.layers): if decoder_layer.ipp is not None and pre_ipp != decoder_layer.ipp: output = dist.reshard(output, get_mesh(decoder_layer.ipp), [dist.Shard(0), dist.Replicate()]) + attention_mask = dist.reshard( + attention_mask, get_mesh(decoder_layer.ipp), [dist.Replicate(), dist.Replicate()] + ) has_gradient = not output.stop_gradient - if self.enable_recompute and has_gradient and self.config.recompute_granularity == "full_attn": + if self.enable_recompute and has_gradient and self.config.recompute_granularity == "full": outputs = self.recompute_training( layer_module=decoder_layer, hidden_states=output, @@ -489,17 +525,17 @@ def __init__(self, config: GPTConfig, ipp=None): self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size, bias_attr=True) self.linear1.weight = dist.shard_tensor(self.linear1.weight, get_mesh(ipp), [dist.Replicate(), dist.Shard(1)]) + self.linear1.bias = dist.shard_tensor(self.linear1.bias, get_mesh(ipp), [dist.Replicate(), dist.Shard(0)]) self.linear2.weight = dist.shard_tensor(self.linear2.weight, get_mesh(ipp), [dist.Replicate(), dist.Shard(0)]) - - self.norm1 = nn.LayerNorm(config.hidden_size, epsilon=1e-5, bias_attr=True) - self.norm2 = nn.LayerNorm(config.hidden_size, epsilon=1e-5, bias_attr=True) + # fix : change nn.LayerNorm(config.hidden_size, epsilon=1e-5, bias_attr=True) to GPTLayerNorm() + self.norm1 = GPTLayerNorm(config, config.hidden_size, epsilon=1e-5, bias_attr=True) + self.norm2 = GPTLayerNorm(config, config.hidden_size, epsilon=1e-5, bias_attr=True) if config.sequence_parallel: mark_as_sequence_parallel_parameter(self.norm1.weight) mark_as_sequence_parallel_parameter(self.norm1.bias) mark_as_sequence_parallel_parameter(self.norm2.weight) mark_as_sequence_parallel_parameter(self.norm2.bias) - if config.use_fused_dropout_add: self.fused_dropout_add1 = FusedDropoutAdd(config.attention_probs_dropout_prob, mode="upscale_in_train") self.fused_dropout_add2 = FusedDropoutAdd(config.hidden_dropout_prob, mode="upscale_in_train") @@ -571,7 +607,8 @@ def forward( # hidden_states => [bs * seq_len / n, embed_dim] with seed_guard_context(current_seed): if not self.config.use_fused_dropout_add: - act = self.activation(self.linear1(hidden_states), approximate=True) + l_1 = self.linear1(hidden_states) + act = self.activation(l_1, approximate=True) l_2 = self.linear2(act) hidden_states = residual + self.dropout2(l_2) else: @@ -629,31 +666,29 @@ def forward(self, input_ids, position_ids=None, inputs_embeddings=None): if position_ids is not None and inputs_embeddings is not None: raise ValueError("You cannot specify both `inputs_embeddings` and `position_ids`)") - # if input_ids is not None: - # input_shape = input_ids.shape - # inputs_embeddings = self.word_embeddings(input_ids) - - if input_ids is not None: - input_shape = input_ids.shape - inputs_embeddings = self.word_embeddings(input_ids) - else: - input_shape = inputs_embeddings.shape[:-1] - - if position_ids is None: - ones = paddle.ones(input_shape, dtype="int64") - seq_length = paddle.cumsum(ones, axis=-1) - position_ids = seq_length - ones + with paddle.amp.auto_cast(False): + if input_ids is not None: + input_shape = input_ids.shape + inputs_embeddings = self.word_embeddings(input_ids) + else: + input_shape = inputs_embeddings.shape[:-1] - position_embeddings = self.position_embeddings(position_ids) + if position_ids is None: + ones = paddle.ones(input_shape, dtype="int64") + seq_length = paddle.cumsum(ones, axis=-1) + position_ids = seq_length - ones + position_embeddings = self.position_embeddings(position_ids) embeddings = inputs_embeddings + position_embeddings + # exit() if self.config.sequence_parallel: + # embeddings = dist.shard_tensor(embeddings,get_mesh(),[dist.Replicate(),dist.Replicate()]) bs, seq_len, hidden_size = embeddings.shape # [bs, seq_len, dim] -> [bs * seq_len, dim] embeddings = paddle.reshape_(embeddings, [bs * seq_len, hidden_size]) # [bs * seq_len / n, dim] (n is mp parallelism) - embeddings = ScatterOp.apply(embeddings) - + # embeddings = ScatterOp.apply(embeddings) + embeddings = dist.reshard(embeddings, get_mesh(), [dist.Replicate(), dist.Shard(0)]) # Use a ternary operator for a more concise assignment of current_seed current_seed = "local_seed" if self.config.sequence_parallel else "global_seed" # The 'with' block ensures the correct seed context is used @@ -785,7 +820,6 @@ def _get_name_mappings(cls, config: GPTConfig) -> list[StateDictNameMapping]: ] model_mappings.extend(layer_mappings) - # downstream mappings if "GPT2Model" not in config.architectures: for mapping in model_mappings: @@ -876,7 +910,7 @@ def __init__(self, config: GPTConfig): self.bias = paddle.tril( paddle.ones([1, 1, config.max_position_embeddings, config.max_position_embeddings], dtype="int64") ) - + self.bias = dist.shard_tensor(self.bias, get_mesh(), [dist.Replicate(), dist.Replicate()]) self.embeddings = GPTEmbeddingsAuto(config) decoder_layers = nn.LayerList() @@ -1029,7 +1063,6 @@ def forward( else: raise ValueError("You have to specify either input_ids or inputs_embeds") # input_shape => bs, seq_len - if past_key_values is None: past_key_values = tuple([None] * len(self.decoder.layers)) @@ -1061,7 +1094,6 @@ def forward( attention_mask = (1.0 - (attention_mask & causal_mask)) * -1e4 else: attention_mask = (1.0 - causal_mask) * -1e4 - # The tensor returned by triu not in static graph. attention_mask.stop_gradient = True @@ -1116,6 +1148,8 @@ def forward(self, prediction_scores, masked_lm_labels, loss_mask=None): """ with paddle.amp.auto_cast(False): + if len(prediction_scores.shape) < len(masked_lm_labels.unsqueeze(2).shape): + prediction_scores = paddle.unsqueeze_(prediction_scores, 0) masked_lm_loss = self.loss_func(prediction_scores.astype("float32"), masked_lm_labels.unsqueeze(2)) masked_lm_loss = paddle.masked_select(masked_lm_loss, masked_lm_loss > 0).astype("float32") loss = paddle.mean(masked_lm_loss) @@ -1155,6 +1189,11 @@ def __init__(self, config: GPTConfig, embedding_weights=None, ipp=None): self.weight.split_axis = 0 def forward(self, hidden_states, tensor_parallel_output=None): + + if self.config.sequence_parallel: + hidden_states = dist.reshard(hidden_states, get_mesh(self.ipp), [dist.Replicate(), dist.Replicate()]) + hidden_states = paddle.reshape(hidden_states, [-1, self.config.seq_length, self.config.hidden_size]) + if tensor_parallel_output is None: tensor_parallel_output = self.config.tensor_parallel_output @@ -1173,6 +1212,8 @@ class GPTForCausalLMAuto(GPTPretrainedModelAuto): """ + _tied_weights_keys = ["lm_head.weight", "lm_head.decoder.weight"] + def __init__(self, config: GPTConfig): super(GPTForCausalLMAuto, self).__init__(config) self.gpt = GPTModelAuto(config) @@ -1259,25 +1300,23 @@ def forward( else: hidden_states = outputs[0] logits = self.lm_head(hidden_states) - return logits + loss = None + if labels is not None: + loss = self.criterion(logits, labels) - # NOTE: The following code failed to run from dynamic to static mode - # loss = None - # if labels is not None: - # loss = self.criterion(logits, labels) - # if not return_dict: - # if isinstance(outputs, input_type): - # return (loss, logits) if loss is not None else logits - # outputs = (logits,) + outputs[1:] - # return ((loss,) + outputs) if loss is not None else outputs - # return CausalLMOutputWithCrossAttentions( - # loss=loss, - # logits=logits, - # past_key_values=outputs.past_key_values, - # hidden_states=outputs.hidden_states, - # attentions=outputs.attentions, - # cross_attentions=outputs.cross_attentions, - # ) + if not return_dict: + if isinstance(outputs, input_type): + return (loss, logits) if loss is not None else logits + outputs = (logits,) + outputs[1:] + return ((loss,) + outputs) if loss is not None else outputs + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) def prepare_fast_entry(self, kwargs): from paddlenlp.ops import FasterGPT diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index 7611fd961ab6..17a7517e6f05 100755 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -21,6 +21,7 @@ from functools import partial from typing import Optional, Tuple +import numpy as np import paddle import paddle.distributed.fleet.meta_parallel as mpu import paddle.nn.functional as F @@ -100,14 +101,14 @@ def swiglu(x, y=None): def _get_interleave(n): def _get_interleave_power_of_2(n): - start = 2 ** (-(2 ** -(math.log2(n) - 3))) + start = 2 ** (-(2 ** -(np.log2(n) - 3))) ratio = start return [start * ratio**i for i in range(n)] - if math.log2(n).is_integer(): + if np.log2(n).is_integer(): return _get_interleave_power_of_2(n) else: - closest_power_of_2 = 2 ** math.floor(math.log2(n)) + closest_power_of_2 = int(2 ** np.floor(np.log2(n))) return ( _get_interleave_power_of_2(closest_power_of_2) + _get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] @@ -1545,8 +1546,9 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values expanded_attn_mask = expanded_attn_mask.astype("float32") expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype) elif get_env_device() in ["xpu", "gcu"]: + min_val = paddle.finfo(dtype).min if get_env_device() == "gcu" else -1e37 # mask value for xpu x = paddle.to_tensor(0.0, dtype=dtype) - y = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype) + y = paddle.to_tensor(min_val, dtype=dtype) expanded_attn_mask = expanded_attn_mask.astype(dtype) expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype) else: diff --git a/paddlenlp/transformers/llama/modeling_auto.py b/paddlenlp/transformers/llama/modeling_auto.py index c194906178d0..de2ae508a397 100644 --- a/paddlenlp/transformers/llama/modeling_auto.py +++ b/paddlenlp/transformers/llama/modeling_auto.py @@ -142,18 +142,17 @@ def scaled_dot_product_attention( ) else: if alibi is not None: - alibi = alibi.reshape([bsz, num_heads, 1, -1]) attention_mask = attention_mask.cast(alibi.dtype) + alibi attn_output = F.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=attention_mask, - is_causal=attention_mask is None, + is_causal=attention_mask is None and query_states.shape[1] != 1, ) attn_weights = None - attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads]) + attn_output = attn_output.reshape([bsz, q_len, head_dim * query_states.shape[-2]]) return (attn_output, attn_weights) if output_attentions else attn_output else: # [ bz, seqlen, nhead, head_dim] -> [bs, nhead, seq_len, head_dim] @@ -161,14 +160,11 @@ def scaled_dot_product_attention( # merge with the next tranpose key_states = paddle.transpose(key_states, [0, 2, 1, 3]) value_states = paddle.transpose(value_states, [0, 2, 1, 3]) - # matmul and devide by sqrt(head_dim) attn_weights = paddle.matmul(query_states / math.sqrt(head_dim), key_states.transpose([0, 1, 3, 2])) # then add alibi bias if alibi is not None: - alibi = alibi.reshape([bsz, num_heads, 1, -1]) attn_weights = attn_weights + alibi - if list(attn_weights.shape) != [bsz, num_heads, q_len, kv_seq_len]: raise ValueError( f"Attention weights should be of shape {(bsz, num_heads, q_len, kv_seq_len)}, but is" @@ -199,7 +195,7 @@ def scaled_dot_product_attention( class LlamaRMSNormAuto(nn.Layer): - def __init__(self, config): + def __init__(self, config, ipp): super().__init__() self.hidden_size = config.hidden_size self.weight = paddle.create_parameter( @@ -207,6 +203,12 @@ def __init__(self, config): dtype=paddle.get_default_dtype(), default_initializer=nn.initializer.Constant(1.0), ) + self.ipp = ipp + self.weight = dist.shard_tensor( + self.weight, + get_mesh(self.ipp), + [dist.Replicate(), dist.Replicate()], + ) self.variance_epsilon = config.rms_norm_eps self.config = config @@ -516,7 +518,12 @@ def forward( if (paddle_version != 0.0) and (paddle_version <= 2.6): key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - + attention_mask = ( + dist.reshard(attention_mask, get_mesh(self.ipp), [dist.Shard(0), dist.Replicate()]) + if attention_mask is not None + else None + ) + alibi = dist.reshard(alibi, get_mesh(self.ipp), [dist.Shard(0), dist.Shard(1)]) if alibi is not None else None has_gradient = not (query_states.stop_gradient and key_states.stop_gradient and value_states.stop_gradient) if ( self.enable_recompute @@ -588,8 +595,8 @@ def __init__(self, config, layerwise_recompute: bool = False, ipp: Optional[int] self.hidden_size = config.hidden_size self.self_attn = LlamaAttentionAuto(config, layerwise_recompute, ipp) self.mlp = LlamaMLPAuto(config, ipp) - self.input_layernorm = LlamaRMSNormAuto(config) - self.post_attention_layernorm = LlamaRMSNormAuto(config) + self.input_layernorm = LlamaRMSNormAuto(config, ipp) + self.post_attention_layernorm = LlamaRMSNormAuto(config, ipp) # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True # Enable_recompute defaults to False and is controlled by Trainer self.enable_recompute = False @@ -620,7 +627,6 @@ def forward( (see `cache`). cache (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states """ - # [bs, seq_len, embed_dim] or [seq_len / n, bs, embed_dim] (if sequence_parallel) residual = hidden_states @@ -852,7 +858,6 @@ def __init__(self, config: LlamaConfig): self.hidden_size = config.hidden_size self.recompute_granularity = config.recompute_granularity self.no_recompute_layers = config.no_recompute_layers if config.no_recompute_layers is not None else [] - # Recompute defaults to False and is controlled by Trainer self.enable_recompute = False self.embed_tokens = nn.Embedding( @@ -860,10 +865,15 @@ def __init__(self, config: LlamaConfig): self.hidden_size, ) + embedding_placements = ( + [dist.Replicate(), dist.Shard(1)] + if self.config.tensor_parallel_degree > 1 + else [dist.Replicate(), dist.Replicate()] + ) self.embed_tokens.weight = dist.shard_tensor( self.embed_tokens.weight, get_mesh(), - [dist.Replicate(), dist.Shard(1)], + embedding_placements, ) def get_layer_pp_info(layer_index): @@ -885,7 +895,7 @@ def get_layer_pp_info(layer_index): self.next_pp_stage_indexes.append(i) self.layers = nn.LayerList(decoder_layers) - self.norm = LlamaRMSNormAuto(config) + self.norm = LlamaRMSNormAuto(config, pp_stage_id) self.gradient_checkpointing = False @@ -964,7 +974,8 @@ def forward( seq_length_with_past += cache_length if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) + with paddle.amp.auto_cast(False): + inputs_embeds = self.embed_tokens(input_ids) if self.config.sequence_parallel: # [B, S, H] -> [S, B, H] @@ -979,20 +990,22 @@ def forward( global_mesh, [dist.Replicate() for _ in range(len(global_mesh._shape))], ) - # embed positions if not self.config.use_flash_attention and attention_mask is None: # [bs, seq_len] attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool) if self.config.alibi: + if attention_mask is None: + attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool) + alibi_place = [dist.Replicate() for _ in range(len(global_mesh._shape))] alibi = build_alibi_tensor(attention_mask, self.config.num_attention_heads, dtype=inputs_embeds.dtype) - alibi = alibi.reshape([batch_size * self.config.num_attention_heads, 1, seq_length_with_past]) + alibi = dist.shard_tensor(alibi, global_mesh, alibi_place) else: alibi = None - - if self.config.use_flash_attention: + if self.config.use_flash_attention and not self.config.alibi: # attention_mask in flash_attn is always None for pretrain + # atttenton_mask is used in scaled_dot_product_attention with alibi_tensor attention_mask = None else: attention_mask = self._prepare_decoder_attention_mask( @@ -1003,7 +1016,6 @@ def forward( global_mesh, [dist.Replicate() for _ in range(len(global_mesh._shape))], ) - hidden_states = inputs_embeds hidden_states = dist.reshard(hidden_states, get_mesh(), self.placements) @@ -1011,7 +1023,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = () if use_cache else None - for idx, (decoder_layer) in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) @@ -1040,7 +1051,14 @@ def forward( if attention_mask is not None else None ) - + if alibi is not None: + pp_mesh = get_mesh(ipp) + alibi_place = [dist.Replicate() for _ in range(len(pp_mesh._shape))] + alibi = dist.reshard( + alibi, + pp_mesh, + alibi_place, + ) if idx in self.next_pp_stage_indexes: hidden_states = dist.reshard( hidden_states, diff --git a/paddlenlp/transformers/llama/tokenizer.py b/paddlenlp/transformers/llama/tokenizer.py index be688206e2ad..8260e2b1239f 100644 --- a/paddlenlp/transformers/llama/tokenizer.py +++ b/paddlenlp/transformers/llama/tokenizer.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import os from shutil import copyfile @@ -247,7 +248,7 @@ def create_token_type_ids_from_sequences( import base64 import unicodedata -from typing import Collection, List, Optional, Set, Tuple +from typing import Collection, Set from ...utils.import_utils import is_tiktoken_available from .. import PretrainedTokenizer @@ -289,9 +290,10 @@ def __init__( vocab_file, errors="replace", padding_side="left", + add_bos_token=True, + add_eos_token=False, **kwargs, ): - super().__init__(**kwargs) if not is_tiktoken_available(): raise ValueError("tiktoken is not installed, please install it use: pip install tiktoken") @@ -320,6 +322,9 @@ def __init__( self.tokenizer = enc # type: tiktoken.Encoding + self.add_bos_token = add_bos_token + self.add_eos_token = add_eos_token + self.bod_id = self.special_tokens[BEGINOFTEXT] self.eod_id = self.special_tokens[ENDOFTEXT] self.start_header_id = self.special_tokens[IMSTART] @@ -331,11 +336,19 @@ def __init__( if "eos_token_id" in kwargs: self.eos_token_id = kwargs["eos_token_id"] + self.bos_token = BEGINOFTEXT + self.eos_token = ENDOFTEXT + self.bos_token_id = self.bod_id + self.eos_token_id = self.eod_id + self.pad_token = self.convert_ids_to_tokens(self.eos_token_id) + + super().__init__(pad_token=self.pad_token, **kwargs) + def __len__(self) -> int: return self.tokenizer.n_vocab def get_vocab(self) -> Dict[bytes, int]: - return self.mergeable_ranks + return {**self.mergeable_ranks, **self.special_tokens} def convert_tokens_to_ids(self, tokens: Union[bytes, str, List[Union[bytes, str]]]) -> List[int]: ids = [] @@ -351,13 +364,44 @@ def convert_tokens_to_ids(self, tokens: Union[bytes, str, List[Union[bytes, str] ids.append(self.mergeable_ranks.get(token)) return ids + def convert_ids_to_tokens(self, ids, skip_special_tokens=False): + if isinstance(ids, int): + return self.decoder[ids] + tokens = [] + for index in ids: + index = int(index) + if skip_special_tokens and index >= len(self.mergeable_ranks): + continue + if index in self.decoder: + tokens.append(self.decoder[index]) + return tokens + def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int: if not special_tokens and new_tokens: raise ValueError("Adding regular tokens is not supported") for token in new_tokens: surface_form = token.content if isinstance(token, AddedToken) else token if surface_form not in SPECIAL_TOKENS: - raise ValueError("Adding unknown special tokens is not supported") + logger.info(f"adding a special token '{surface_form}'.") + token_id = len(self.mergeable_ranks) + len(self.special_tokens) + self.special_tokens[surface_form] = token_id + self.decoder[token_id] = surface_form + + import tiktoken as tk + + tiktoken = tk + enc = tiktoken.Encoding( + "Llama3", + pat_str=PAT_STR, + mergeable_ranks=self.mergeable_ranks, + special_tokens=self.special_tokens, + ) + assert ( + len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab + ), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding" + + self.tokenizer = enc # type: tiktoken.Encoding + return 0 def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]: @@ -432,28 +476,16 @@ def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str: def vocab_size(self): return self.tokenizer.n_vocab - def _convert_id_to_token(self, index: int) -> Union[bytes, str]: - """Converts an id to a token, special tokens included""" - if index in self.decoder: - return self.decoder[index] - raise ValueError("unknown ids") - - def _convert_token_to_id(self, token: Union[bytes, str]) -> int: - """Converts a token to an id using the vocab, special tokens included""" - if token in self.special_tokens: - return self.special_tokens[token] - if token in self.mergeable_ranks: - return self.mergeable_ranks[token] - raise ValueError("unknown token") - - def _tokenize(self, text: str, **kwargs): - """ - Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based - vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces). + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + bos_token_id = [self.bod_id] if self.add_bos_token else [] + eos_token_id = [self.eod_id] if self.add_eos_token else [] - Do NOT take care of added tokens. - """ - raise NotImplementedError + output = bos_token_id + token_ids_0 + eos_token_id + + if token_ids_1 is not None: + output = output + bos_token_id + token_ids_1 + eos_token_id + + return output def _decode( self, @@ -465,5 +497,5 @@ def _decode( if isinstance(token_ids, int): token_ids = [token_ids] if skip_special_tokens: - token_ids = [i for i in token_ids if i < self.eod_id] + token_ids = [i for i in token_ids if i <= len(self.mergeable_ranks)] return self.tokenizer.decode(token_ids, errors=errors or self.errors) diff --git a/paddlenlp/transformers/long_sequence_strategies/embedding_strategies.py b/paddlenlp/transformers/long_sequence_strategies/embedding_strategies.py index 6e9291e0d951..675500332144 100755 --- a/paddlenlp/transformers/long_sequence_strategies/embedding_strategies.py +++ b/paddlenlp/transformers/long_sequence_strategies/embedding_strategies.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math + import paddle from paddle import nn @@ -20,6 +22,7 @@ "LinearScalingRotaryEmbedding", "NTKScalingRotaryEmbedding", "DynamicNTKScalingRotaryEmbedding", + "YaRNScalingRotaryEmbedding", ] @@ -120,3 +123,101 @@ def forward(self, seq_len=None, ntk_alpha=None): self._scale_cos_sin(seq_len=seq_len, ntk_alpha=ntk_alpha) return self.cos_cached[:, :], self.sin_cached[:, :] + + +class YaRNScalingRotaryEmbedding(nn.Layer): + """RotaryEmbedding extended with YaRN scaling.""" + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + scaling_factor=1, + original_max_position_embeddings=2048, + extrapolation_factor=1, + attn_factor=1, + beta_fast=32, + beta_slow=1, + ): + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.scaling_factor = scaling_factor # scaling_factor + self.original_max_position_embeddings = original_max_position_embeddings + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + + self.yarn() + + self._set_cos_sin_cache(seq_len=self.max_position_embeddings) + + def _set_cos_sin_cache(self, seq_len): + self.max_seq_len_cached = seq_len + # [seq_len] + t = paddle.arange(seq_len, dtype=paddle.float32) + # [seq_len, dim/2] + with paddle.amp.auto_cast(enable=False): + freqs = paddle.outer(t.astype(self.inv_freq.dtype), self.inv_freq) + # [seq_len, dim] + emb = paddle.concat([freqs, freqs], axis=-1) + self.cos_cached = emb.cos()[:, :] * self.mscale + self.sin_cached = emb.sin()[:, :] * self.mscale + + def _scale_cos_sin(self, seq_len): + self.max_seq_len_cached = seq_len + + t = paddle.arange(self.max_seq_len_cached, dtype=self.inv_freq.dtype) + freqs = paddle.einsum("i,j->ij", t, self.inv_freq) + emb = paddle.concat((freqs, freqs), axis=-1) + + self.cos_cached = emb.cos()[:, :] * self.mscale + self.sin_cached = emb.sin()[:, :] * self.mscale + + def forward(self, seq_len=None, ntk_alpha=None): + if seq_len > self.max_seq_len_cached: + self._scale_cos_sin(seq_len=seq_len) + + return self.cos_cached[:, :], self.sin_cached[:, :] + + def yarn(self): + inv_freq = 1.0 / (self.base ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype=paddle.float32) / self.dim)) + + low, high = self._yarn_find_correction_range( + self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings + ) + inv_freq_mask = ( + 1 - paddle.cast(self._yarn_linear_ramp_mask(low, high, self.dim // 2), dtype=paddle.float32) + ) * self.extrapolation_factor + + inv_freq = inv_freq / ((1 - inv_freq_mask) * self.scaling_factor + inv_freq_mask) + self.register_buffer("inv_freq", inv_freq) + self.mscale = self._yarn_get_mscale(self.scaling_factor) * self.attn_factor + + @classmethod + def _yarn_find_correction_dim(cls, num_rotations, dim, base=10000, max_position_embeddings=2048): + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) + + @classmethod + def _yarn_find_correction_range(cls, low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): + low = math.floor(cls._yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil(cls._yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + @classmethod + def _yarn_linear_ramp_mask(cls, low, high, dim): + if low == high: + high += 0.001 # Prevent singularity + + linear_func = (paddle.arange(dim, dtype=paddle.float32) - low) / (high - low) + ramp_func = paddle.clip(linear_func, 0, 1) + return ramp_func + + @classmethod + def _yarn_get_mscale(cls, scaling_factor=1): + if scaling_factor <= 1: + return 1.0 + return 0.1 * math.log(scaling_factor) + 1.0 diff --git a/paddlenlp/transformers/model_utils.py b/paddlenlp/transformers/model_utils.py index c4e7d2786307..fcc207a2e0bc 100644 --- a/paddlenlp/transformers/model_utils.py +++ b/paddlenlp/transformers/model_utils.py @@ -54,6 +54,8 @@ from tqdm.auto import tqdm from paddlenlp.utils.env import ( + ASYMMETRY_QUANT_SCALE_MAX, + ASYMMETRY_QUANT_SCALE_MIN, CONFIG_NAME, LEGACY_CONFIG_NAME, PADDLE_WEIGHTS_INDEX_NAME, @@ -64,10 +66,12 @@ SAFE_PEFT_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, + SYMMETRY_QUANT_SCALE, ) from paddlenlp.utils.log import logger from ..generation import GenerationConfig, GenerationMixin +from ..quantization.unified_checkpoint_quantization import dequant_unified_optimizer from ..utils import device_guard from ..utils.download import resolve_file_path from .configuration_utils import PretrainedConfig @@ -362,10 +366,21 @@ def _load_part_state_dict( """ part_state_dict = {} + scale_dict = {} with safe_open(checkpoint_file, framework="np") as f: for key in keys: + # 1. non-merge ckpt loading dont have filter key. + # 2. merge ckpt will skip quant scale by `fliter_dict_keys` + if ( + key.endswith(SYMMETRY_QUANT_SCALE) + or key.endswith(ASYMMETRY_QUANT_SCALE_MIN) + or key.endswith(ASYMMETRY_QUANT_SCALE_MAX) + ): + continue + if fliter_dict_keys is not None and key not in fliter_dict_keys: continue + py_safe_slice_ = f.get_slice(key) if key in tensor_parallel_split_mapping: weight = tensor_parallel_split_mapping[key](py_safe_slice_) @@ -376,15 +391,31 @@ def _load_part_state_dict( weight = paddle.Tensor(weight, zero_copy=True) weight = weight._copy_to(paddle.framework._current_expected_place(), False) part_state_dict[key] = weight - return part_state_dict + for key in keys: + if ( + key.endswith(SYMMETRY_QUANT_SCALE) + or key.endswith(ASYMMETRY_QUANT_SCALE_MIN) + or key.endswith(ASYMMETRY_QUANT_SCALE_MAX) + ): + scale = f.get_tensor(key) + with device_guard(): + scale = paddle.Tensor(scale, zero_copy=True) + scale = scale._copy_to(paddle.framework._current_expected_place(), False) + scale_dict[key] = scale + return part_state_dict, scale_dict def load_state_dict( - checkpoint_file: Union[str, os.PathLike], tensor_parallel_split_mapping=None, fliter_dict_keys=None, device="cpu" + checkpoint_file: Union[str, os.PathLike], + tensor_parallel_split_mapping=None, + fliter_dict_keys=None, + device="cpu", + ckpt_quant_stage="O0", ): """ Reads a PaddlePaddle checkpoint file, returning properly formatted errors if they arise. """ + if tensor_parallel_split_mapping is None: tensor_parallel_split_mapping = {} @@ -404,10 +435,9 @@ def load_state_dict( raise ValueError("Currently unsupport paddle weights file, use numpy instead.") if metadata.get("format", "np") == "np": thread_num = int(os.environ.get("LOAD_STATE_DICT_THREAD_NUM", "1")) - state_dict = {} if thread_num <= 1: with safe_open(checkpoint_file, framework="np") as f: - state_dict = _load_part_state_dict( + state_dict, scale_dict = _load_part_state_dict( list(f.keys()), checkpoint_file, tensor_parallel_split_mapping, @@ -431,14 +461,20 @@ def load_state_dict( for keys in keys_groups } for future in concurrent.futures.as_completed(future_to_key): - result = future.result() - state_dict.update(result) + state_dict, scale_dict = future.result() + state_dict.update(state_dict) + scale_dict.update(scale_dict) if device == "cpu": for k in list(state_dict.keys()): with device_guard(): state_dict[k] = paddle.Tensor(state_dict.pop(k), zero_copy=True) + if len(scale_dict) != 0: + if ckpt_quant_stage == "O0": + raise ValueError('optimizer weight has quantization scales but `ckpt_quant_stage` is set to "O0"') + state_dict = dequant_unified_optimizer(state_dict, ckpt_quant_stage, scale_dict) + return state_dict state_dict = paddlenlp_load(checkpoint_file, map_location="cpu") @@ -1156,6 +1192,13 @@ def set_inference_config(cls, config, predictor_args, **kwargs): config.block_size = predictor_args.block_size config.max_seq_len = predictor_args.total_max_length + if predictor_args.speculate_method is not None: + config.speculate_method = predictor_args.speculate_method + config.speculate_max_draft_token_num = predictor_args.speculate_max_draft_token_num + config.speculate_max_ngram_size = predictor_args.speculate_max_ngram_size + config.speculate_verify_window = predictor_args.speculate_verify_window + config.speculate_max_candidate_len = predictor_args.speculate_max_candidate_len + @classmethod def confirm_inference_model(cls, predictor_args, **kwargs): """ @@ -2102,7 +2145,9 @@ def _fuse_or_split_keys( if config.quantization_config.is_weight_quantize(): filter_dict_keys = None state_dict = load_state_dict( - shard_file, tp_actions if pre_tensor_parallel_split else None, filter_dict_keys + shard_file, + tp_actions if pre_tensor_parallel_split else None, + filter_dict_keys, ) # convert for fusing or splitting weights diff --git a/paddlenlp/transformers/moe_gate.py b/paddlenlp/transformers/moe_gate.py new file mode 100644 index 000000000000..8118ba60f7ac --- /dev/null +++ b/paddlenlp/transformers/moe_gate.py @@ -0,0 +1,474 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from typing import Tuple + +import paddle +import paddle.distributed as dist +import paddle.nn as nn +import paddle.nn.functional as F + +from ..utils.log import logger + + +class MoEGateMixin: + def gate_score_func(self, logits: paddle.Tensor) -> paddle.Tensor: + # [..., hidden_dim] -> [..., num_experts] + with paddle.amp.auto_cast(False): + scoring_func = getattr(self, "scoring_func", None) + if scoring_func == "softmax": + scores = F.softmax(logits.cast("float32"), axis=-1) + elif scoring_func == "sigmoid": + scores = F.sigmoid(logits.cast("float32")) + elif scoring_func == "tanh": + scores = F.tanh(logits.cast("float32")) + elif scoring_func == "relu": + scores = F.relu(logits.cast("float32")) + elif scoring_func == "gelu": + scores = F.gelu(logits.cast("float32")) + elif scoring_func == "leaky_relu": + scores = F.leaky_relu(logits.cast("float32")) + else: + logger.warning_once( + f"insupportable scoring function for MoE gating: {scoring_func}, use softmax instead" + ) + scores = F.softmax(logits.cast("float32"), axis=-1) + return scores + + def gumbel_rsample(self, logits: paddle.Tensor) -> paddle.Tensor: + gumbel = paddle.distribution.gumbel.Gumbel(0, 1) + return gumbel.rsample(logits.shape) + + def uniform_sample(self, logits: paddle.Tensor) -> paddle.Tensor: + uniform = paddle.distribution.uniform.Uniform(0, 1) + return uniform.sample(logits.shape) + + @paddle.no_grad() + def _one_hot_to_float(self, x, num_classes): + if x.dtype not in (paddle.int32, paddle.int64): + x = paddle.cast(x, paddle.int64) + return F.one_hot(x, num_classes=num_classes).cast(paddle.get_default_dtype()) + + @paddle.no_grad() + def _one_hot_to_int64(self, x, num_classes): + if x.dtype not in (paddle.int32, paddle.int64): + x = paddle.cast(x, paddle.int64) + return F.one_hot(x, num_classes=num_classes).cast(paddle.int64) + + @paddle.no_grad() + def _capacity( + self, gates: paddle.Tensor, capacity_factor: float, max_capacity: int, min_capacity: int + ) -> paddle.Tensor: + """Calculate the capacity for each expert based on the gates and capacity factor. + + Args: + gates (paddle.Tensor): A tensor of shape [num_tokens, num_experts] representing the probability distribution + over experts for each token. + capacity_factor (float): A scalar float value representing the capacity factor for each expert. + min_capacity (int): A scalar integer value representing the minimum capacity for each expert. + + Returns: + int: A tensor value representing the calculated capacity for each expert. + """ + assert gates.ndim == 2, f"gates should be 2D, but got {gates.ndim}, {gates.shape}" + # gates has shape of SE + num_tokens = gates.shape[0] + num_experts = gates.shape[1] + capacity = int((num_tokens // num_experts) * capacity_factor) + if capacity < min_capacity: + capacity = min_capacity + if capacity > max_capacity: + capacity = max_capacity + assert capacity > 0, f"requires capacity > 0, capacity_factor: {capacity_factor}, input_shape: {gates.shape}" + + return capacity + + def _cal_aux_loss(self, gates, mask): + """ + Calculate auxiliary loss + + Args: + gates (paddle.Tensor): Represents the output probability of each expert. The shape is [batch_size, num_experts] + mask (paddle.Tensor): Represents whether each sample belongs to a certain expert. The shape is [batch_size, num_experts] + + Returns: + paddle.Tensor: The value of auxiliary loss. + + """ + me = paddle.mean(gates, axis=0) + ce = paddle.mean(mask.cast("float32"), axis=0) + if self.global_aux_loss: + me_list, ce_list = [], [] + dist.all_gather(me_list, me, group=self.group) + dist.all_gather(ce_list, ce, group=self.group) + + me_list[self.rank] = me + ce_list[self.rank] = ce + me = paddle.stack(me_list).mean(0) + ce = paddle.stack(ce_list).mean(0) + aux_loss = paddle.sum(me * ce) * float(self.num_experts) + return aux_loss + + def _cal_z_loss(self, logits) -> paddle.Tensor: + """ + Calculate the z loss. + + Args: + logits (paddle.Tensor): Model output. The shape is [batch_size, num_experts]. + + Returns: + paddle.Tensor: The z loss value. + """ + l_zloss = logits.exp().sum(1).log().square().mean() + return l_zloss + + def _cal_orthogonal_loss(self) -> paddle.Tensor: + """Gate weight orthogonal loss. + + Returns: + Paddle.Tensor: orthogonal loss + """ + weight = F.normalize(self.weight, axis=0) + orthogonal_loss = paddle.mean(paddle.square(paddle.matmul(weight.T, weight) - paddle.eye(self.num_experts))) + return orthogonal_loss + + +class PretrainedMoEGate(nn.Layer, MoEGateMixin): + def __init__(self, config, num_experts, expert_hidden_size, **kwargs): + super(PretrainedMoEGate, self).__init__() + + self.config = config + + self.num_experts = num_experts + self.expert_hidden_size = expert_hidden_size + + # force keep in float32 when using amp + self._cast_to_low_precision = False + + self.capacity_factor = kwargs.pop("capacity_factor", 1.0) + self.eval_capacity_factor = kwargs.pop("eval_capacity_factor", 1.0) + self.min_capacity = kwargs.pop("min_capacity", 1.0) + self.max_capacity = kwargs.pop("max_capacity", pow(2, 32)) + + self.group = kwargs.pop("group", None) + self.global_aux_loss = kwargs.pop("global_aux_loss", False) + if self.global_aux_loss: + assert self.group is not None, "group is required when global_aux_loss is True" + self.rank = dist.get_rank(self.group) + + self.expert_drop = kwargs.pop("expert_drop", False) + self.noisy_gate_policy = kwargs.pop("noisy_gate_policy", None) + self.drop_tokens = kwargs.pop("drop_tokens", True) + self.use_rts = kwargs.pop("use_rts", True) + self.top2_2nd_expert_sampling = kwargs.pop("top2_2nd_expert_sampling", True) + + self.drop_policy = kwargs.pop("drop_policy", "probs") + self.top_k = kwargs.pop("top_k", 2) + self.norm_topk_prob = kwargs.pop("norm_topk_prob", False) + + def _priority(self, topk_idx: paddle.Tensor, capacity: int) -> paddle.Tensor: + """_summary_ + The priority is the cumulative sum of the expert indices. + + This method is used in hunyuan model + Args: + topk_idx (paddle.Tensor): [batch_size * seq_len, topk] + + Returns: + paddle.Tensor: cumsum locations + """ + # Make num_selected_experts the leading axis to ensure that top-1 choices + # have priority over top-2 choices, which have priority over top-3 choices, + # etc. + expert_index = paddle.transpose(topk_idx, [1, 0]) # [topk, B*S] + # Shape: [num_selected_experts * tokens_per_group] + expert_index = expert_index.reshape([-1]) + + # Create mask out of indices. + # Shape: [tokens_per_group * num_selected_experts, num_experts]. + expert_mask = F.one_hot(expert_index, self.num_experts).cast(paddle.int32) + + # Experts have a fixed capacity that we cannot exceed. A token's priority + # within the expert's buffer is given by the masked, cumulative capacity of + # its target expert. + # Shape: [tokens_per_group * num_selected_experts, num_experts]. + token_priority = paddle.cumsum(expert_mask, axis=0) * expert_mask - 1 + # Shape: [num_selected_experts, tokens_per_group, num_experts]. + token_priority = token_priority.reshape((self.top_k, -1, self.num_experts)) + # Shape: [tokens_per_group, num_selected_experts, num_experts]. + token_priority = paddle.transpose(token_priority, [1, 0, 2]) + # For each token, across all selected experts, select the only non-negative + # (unmasked) priority. Now, for group G routing to expert E, token T has + # non-negative priority (i.e. token_priority[G,T,E] >= 0) if and only if E + # is its targeted expert. + # Shape: [tokens_per_group, num_experts]. + token_priority = paddle.max(token_priority, axis=1) + + # Token T can only be routed to expert E if its priority is positive and + # less than the expert capacity. One-hot matrix will ignore indices outside + # the range [0, expert_capacity). + # Shape: [tokens_per_group, num_experts, expert_capacity]. + valid_mask = paddle.logical_and(token_priority >= 0, token_priority < capacity) + token_priority = paddle.masked_fill(token_priority, ~valid_mask, 0) + dispatch_mask = F.one_hot(token_priority, capacity).cast(paddle.int32) + valid_mask = valid_mask.unsqueeze(-1).expand(valid_mask.shape + [capacity]) + dispatch_mask = paddle.masked_fill(dispatch_mask, ~valid_mask, 0) + + return dispatch_mask + + def topk_naive(self, scores: paddle.Tensor, k: int) -> Tuple[paddle.Tensor, paddle.Tensor]: + """_summary_ + + Args: + scores (paddle.Tensor): [bsz*seq_len, n_experts] + k (int): select the top k experts + + Returns: + Tuple[paddle.Tensor, paddle.Tensor]: topk_weight, topk_idx + topk_weight: [bsz*seq_len, k] + topk_idx: [bsz*seq_len, k] + """ + topk_weight, topk_idx = paddle.topk(scores, k=k, axis=-1, sorted=False) + return topk_weight, topk_idx + + def topk_group( + self, scores: paddle.Tensor, k: int, n_group: int, topk_group: int + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + """_summary_ + + Args: + scores (paddle.Tensor): [bsz*seq_len, n_experts] + k (int): select the top k experts in each group + n_groups (int): the number of groups for all experts + topk_group (int): the number of groups selected + + Returns: + Tuple[paddle.Tensor, paddle.Tensor]: topk_weight, topk_idx + topk_weight: [bsz*seq_len, k] + topk_idx: [bsz*seq_len, k] + + Note: the group size is normal greater than the number of k + """ + bsz_seq_len, n_experts = scores.shape + assert n_experts % n_group == 0, "n_experts must be divisible by n_groups" + + group_scores = scores.reshape([0, n_group, -1]).max(axis=-1) # [n, n_group] + group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=False)[1] # [n, top_k_group] + group_mask = paddle.zeros_like(group_scores).put_along_axis(group_idx, paddle.to_tensor(1.0), axis=-1) # fmt:skip + score_mask = ( + group_mask.unsqueeze(-1).expand([bsz_seq_len, n_group, n_experts // n_group]).reshape([bsz_seq_len, -1]) + ) # [n, e] + tmp_scores = scores * score_mask # [n, e] + topk_weight, topk_idx = paddle.topk(tmp_scores, k=k, axis=-1, sorted=False) + + return topk_weight, topk_idx + + def top1gating( + self, + logits: paddle.Tensor, + used_token: paddle.Tensor = None, + ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """Implements Top1Gating on logits.""" + if self.noisy_gate_policy == "RSample": + logits += self.gumbel_rsample(logits.shape) + + gates = self.gate_score_func(logits=logits) + capacity = self._capacity(gates, self.capacity_factor, self.max_capacity, self.min_capacity) + + # Create a mask for 1st's expert per token + # noisy gating + # Only save the position of the maximum value + indices1_s = paddle.argmax(logits if self.noisy_gate_policy == "RSample" else gates, axis=1) + # Convert the position of the maximum value to a one-hot vector [s, e] + mask1 = self._one_hot_to_float(indices1_s, num_classes=self.num_experts) + + # mask only used tokens + if used_token is not None: + mask1 = paddle.einsum( + "s,se->se", used_token, mask1 + ) # Element-wise multiply used_token with mask1 to obtain a new mask1 + + # gating decisions + exp_counts = paddle.sum(mask1, axis=0) # Calculate the number of tokens for each expert + + # if we don't want to drop any tokens + if not self.drop_tokens: + new_capacity = paddle.max(exp_counts) # Calculate the number of tokens for each expert + # Communicate across expert processes to pick the maximum capacity. + if self.group is not None: + dist.all_reduce( + new_capacity, op=dist.ReduceOp.MAX, group=self.group + ) # Calculate the maximum value among expert processes + # Make sure the capacity value does not exceed the number of tokens. + capacity = int(min(new_capacity, paddle.tensor(mask1.size(0)))) + + l_aux = self._cal_aux_loss(gates, mask1) + l_zloss = self._cal_z_loss(logits) + + # Random Token Selection + if self.use_rts: + mask1_rand = mask1 * self.uniform_sample(mask1) + else: + mask1_rand = mask1 + + assert ( + logits.shape[0] >= self.min_capacity + ), "No. of tokens (batch-size) should be greater than min_capacity. Either set min_capacity to 0 or increase your batch size." + + _, top_idx = paddle.topk(mask1_rand, k=capacity, axis=0) # Select top_capacity tokens + + new_mask1 = mask1 * paddle.zeros_like(mask1).put_along_axis(top_idx, paddle.to_tensor(1.0), axis=0) + mask1 = new_mask1 + + # Compute locations in capacity buffer + locations1 = paddle.cumsum(mask1, axis=0) - 1 # Compute the position of each token in mask1 + + # Store the capacity location for each token + locations1_s = paddle.sum(locations1 * mask1, axis=1).cast(paddle.int64) + + # Normalize gate probabilities + mask1_float = mask1.cast(paddle.float32) + gates = gates / gates * mask1_float + + locations1_sc = self._one_hot_to_float(locations1_s, capacity) + combine_weights = paddle.einsum("se,sc->sec", gates, locations1_sc) + dispatch_mask = combine_weights.cast(paddle.bool).detach() + + return capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss + + def top2gating( + self, + logits: paddle.Tensor, + ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + # everything is in fp32 in this function + gates = self.gate_score_func(logits=logits) + + # Create a mask for 1st's expert per token. + indices1_s = paddle.argmax(gates, axis=1) # [S, 1] + mask1 = self._one_hot_to_int64(indices1_s, self.num_experts) # [S, E] + + if self.top2_2nd_expert_sampling: + # Create a mask for 2nd's expert per token using Gumbel-max trick. + # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/ + logits += self.gumbel_rsample(logits) + + # Replace top-expert with min value + logits_except1 = logits.masked_fill(mask1.cast(paddle.bool), float("-inf")) # [S, E] + indices2_s = paddle.argmax(logits_except1, axis=1) # [S, 1] + mask2 = self._one_hot_to_int64(indices2_s, self.num_experts) # [S, E] + + # Note: mask1 and mask2 can be combined to form a single mask. + # mask = paddle.concat([mask1, mask2], axis=0) + # locations = paddle.cumsum(mask, axis=0) - 1 + # locations1, locations2 = locations.split(2, axis=0) + # Compute locations in capacity buffer. + locations1 = paddle.cumsum(mask1, axis=0) - 1 # [S, E] + locations2 = paddle.cumsum(mask2, axis=0) - 1 # [S, E] + # Update 2nd's location by accounting for locations of 1st. + locations2 += paddle.sum(mask1, axis=0, keepdim=True) + + l_aux = self._cal_aux_loss(gates, mask1) + l_zloss = self._cal_z_loss(logits) + + # gating decisions + exp_counts = paddle.sum(mask1 + mask2, axis=0) + if self.drop_tokens: + # Calculate configured capacity and remove locations outside capacity from mask + capacity = self._capacity(gates, self.capacity_factor, self.max_capacity, self.min_capacity) + # Remove locations outside capacity from mask. + mask1 *= (locations1 < capacity).cast(paddle.int64) + mask2 *= (locations2 < capacity).cast(paddle.int64) + else: + # Do not drop tokens - set capacity according to current expert assignments + new_capacity = paddle.max(exp_counts) + if self.group is not None: + dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=self.group) + capacity = int(new_capacity) + + # Store the capacity location for each token. + locations1_s = paddle.sum(locations1 * mask1, axis=1) + locations2_s = paddle.sum(locations2 * mask2, axis=1) + + # Normalize gate probabilities + mask1_float = mask1.cast(paddle.float32) + mask2_float = mask2.cast(paddle.float32) + gates1_s = paddle.einsum("se,se->s", gates, mask1_float) + gates2_s = paddle.einsum("se,se->s", gates, mask2_float) + denom_s = gates1_s + gates2_s + # Avoid divide-by-zero + denom_s = paddle.clip(denom_s, min=paddle.finfo(denom_s.dtype).eps) + gates1_s /= denom_s + gates2_s /= denom_s + + # Calculate combine_weights and dispatch_mask + gates1 = paddle.einsum("s,se->se", gates1_s, mask1_float) + gates2 = paddle.einsum("s,se->se", gates2_s, mask2_float) + locations1_sc = self._one_hot_to_float(locations1_s, capacity) + locations2_sc = self._one_hot_to_float(locations2_s, capacity) + combine1_sec = paddle.einsum("se,sc->sec", gates1, locations1_sc) + combine2_sec = paddle.einsum("se,sc->sec", gates2, locations2_sc) + combine_weights = combine1_sec + combine2_sec + dispatch_mask = combine_weights.cast(paddle.bool) + + return capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss + + def topkgating( + self, + gates: paddle.Tensor, + ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """Implements TopKGating on logits.""" + l_zloss = self._cal_z_loss(gates) + + # get topk gates + top_gate, top_idx = paddle.topk(gates, k=self.top_k, axis=1) + # get topk mask + mask = paddle.zeros_like(gates).put_along_axis(top_idx, paddle.to_tensor(1.0), axis=1) + l_aux = self._cal_aux_loss(gates, mask) + + exp_counts = paddle.sum(mask.cast(paddle.int64), axis=0) + + if self.drop_tokens: + # Calculate configured capacity and remove locations outside capacity from mask + capacity = self._capacity(gates, self.capacity_factor * self.top_k, self.max_capacity, self.min_capacity) + + # update mask and locations by capacity + if self.drop_policy == "probs": + topk_masked_gates = paddle.zeros_like(gates).put_along_axis(top_idx, top_gate, axis=1) + capacity_probs, capacity_indices = paddle.topk(topk_masked_gates, k=capacity, axis=0, sorted=False) + token_priority = self._priority(capacity_indices, capacity) + + elif self.drop_policy == "position": + token_priority = self._priority(top_idx, capacity) + else: + raise ValueError(f"Invalid drop_policy: {self.drop_policy}") + else: + # Do not drop tokens - set capacity according to current expert assignments + local_capacity = paddle.max(exp_counts) + if self.group is not None: + dist.all_reduce(local_capacity, op=dist.ReduceOp.MAX, group=self.group) + capacity = int(local_capacity) + token_priority = self._priority(top_idx, capacity) + + # normalize gates + gates_masked = gates * mask + gates_s = paddle.sum(gates_masked, axis=-1, keepdim=True) + denom_s = paddle.clip(gates_s, min=paddle.finfo(gates_masked.dtype).eps) + if self.norm_topk_prob: + gates_masked = gates_masked / denom_s + + combine_weights = paddle.einsum("se,sec->sec", gates_masked, token_priority.cast(paddle.get_default_dtype())) + dispatch_mask = combine_weights.cast(paddle.bool) + + return capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss diff --git a/paddlenlp/transformers/moe_layer.py b/paddlenlp/transformers/moe_layer.py new file mode 100644 index 000000000000..56369c6c3b92 --- /dev/null +++ b/paddlenlp/transformers/moe_layer.py @@ -0,0 +1,270 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) Microsoft Corporation. +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from typing import Any, Tuple + +import paddle +import paddle.distributed as dist +from paddle import Tensor, nn +from paddle.distributed.communication import stream +from paddle.distributed.communication.group import Group + +from .moe_gate import PretrainedMoEGate + + +def dispatching(x, dispatch_mask, scatter_index, num_experts, capacity): + """ + Rearranges the input tensor `x` based on gate results, truncates it according to the specified capacity, and performs padding. + + Args: + x (Tensor)[Seq, Dim]: The input tensor. + dispatch_mask (List[Tensor[Seq, 1], Tensor[Seq, 1]]): A list of dispatch masks. + scatter_index (Union[List[Tensor[Seq,], Tensor[Seq]], Tensor[Seq, 2]]): A list or tensor representing scatter indices. + num_experts (int): The number of experts. + capacity (int): The capacity size. + + Returns: + Tensor [Expert*Capacity, Dim]: The output tensor after dispatching. + """ + output = None + orig_dtype = x.dtype + if isinstance(scatter_index, paddle.Tensor): + scatter_index = scatter_index.unbind(1) + for i_scatter_index, i_dispatch_mask in zip(scatter_index, dispatch_mask): + init_output = paddle.zeros([num_experts * capacity, x.shape[-1]], dtype="float32") + updates = x * i_dispatch_mask.cast(x.dtype) + if output is None: + output = paddle.scatter( + init_output, + i_scatter_index, + updates, + overwrite=False, + ) + else: + output = output + paddle.scatter( + init_output, + i_scatter_index, + updates, + overwrite=False, + ) + if output.dtype != orig_dtype: + output = output.cast(orig_dtype) + return output + + +def combining(x, combine_weights, scatter_index): + """ + Performs combination and aggregation operations on the input matrix. + + Args: + x: Tensor[num_experts * capacity, dim] - The input matrix to be processed, where the last dimension represents the number of features. + combine_weights: Union[List[Tensor[seq, 1], Tensor[seq, 1]], Tensor[seq, 2, 1]] - A list or tensor containing combination weights for each feature. + scatter_index: Union[List[Tensor[seq], Tensor[seq]], Tensor[seq, 2]] - A tuple of indices indicating which elements are to be aggregated, where the first element is the row index and the second element is the column index. + + Returns: + Tensor: The output matrix after combination and aggregation, with a shape of [n, dim * num_features], where n is the number of samples in the input matrix. + """ + + dim = x.shape[-1] + if isinstance(scatter_index, (list, tuple)): + scatter_index = paddle.concat([i.unsqueeze([-1]) for i in scatter_index], -1) + scatter_index = scatter_index.reshape([-1]) + num_k = len(combine_weights) if isinstance(combine_weights, (list, tuple)) else combine_weights.shape[-1] + x = paddle.gather(x, scatter_index).reshape([-1, num_k, dim]) # [seq,2,dim] + if isinstance(combine_weights, (list, tuple)): + combine_weights = paddle.concat(combine_weights, -1).unsqueeze([1]) + return paddle.matmul(combine_weights, x).squeeze(1) # [seq,1,2] @ [seq,2,dim] -> [seq,1,dim] + + +class _AllToAll(paddle.autograd.PyLayer): + @staticmethod + def forward( + ctx: Any, + input: Tensor, + group: Group, + ) -> Tensor: # type: ignore + """ + All-to-all communication in the group. + + Args: + ctx (Any): Context object. + input (Tensor): Input tensor. + group (Group): The group object. + + Returns: + Tensor: Output tensor. + """ + + ctx.group = group + # return input + if dist.get_world_size(group) <= 1: + return input + output = paddle.empty_like(input) + stream.alltoall_single(output, input, None, None, group, True, True) + return output + + @staticmethod + def backward(ctx: Any, *grad_output: Tensor) -> Tuple[Tensor]: + """ + Aggregates gradient information from all input tensors into a single tensor. + + Args: + ctx (Any): The context object used to store information that needs to be passed. + *grad_output (Tensor): A list of input tensors whose gradients are to be aggregated. + + Returns: + Tuple[Tensor]: A tuple containing a tensor that holds the gradients of all input tensors. + + """ + # return grad_output + return _AllToAll.apply(*grad_output, ctx.group) + + +class MoELayer(nn.Layer): + def __init__( + self, + config, + moe_num_experts: int, + expert_class: nn.Layer, + expert_kwargs: dict, + gate: PretrainedMoEGate, + capacity: int = 1.0, + moe_group: str = "data", + all_to_all_dropout=0.0, + ): + super().__init__() + + self.config = config + + self.moe_num_experts = moe_num_experts + self.capacity = capacity + + if dist.get_world_size() > 1 and moe_group == "data": + self.moe_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group() + self.moe_rank = dist.get_rank(self.moe_group) + self.moe_rank = 0 if self.moe_rank < 0 else self.moe_rank + self.expert_parallel_degree = dist.get_world_size(self.moe_group) + self.expert_parallel_degree = 1 if self.expert_parallel_degree < 0 else self.expert_parallel_degree + self.moe_num_experts_per_device = self._parse_moe_expert_parallel( + self.moe_num_experts, self.expert_parallel_degree + ) + else: + # when moe_group is dummy, we don't need to use all_to_all + self.moe_group = None + self.moe_rank = 0 + self.expert_parallel_degree = 1 + self.moe_num_experts_per_device = self.moe_num_experts + + self.all_to_all_dropout = all_to_all_dropout + self.enable_recompute = False + + self.experts = nn.LayerList([]) + for i in range(self.moe_num_experts): + if i // self.moe_num_experts_per_device == self.moe_rank: + self.experts.append(expert_class(expert_kwargs)) + else: + self.experts.append(None) + + self.gate = gate + self.gate.group = self.moe_group + + def _parse_moe_expert_parallel(self, moe_num_experts, expert_parallel_degree): + assert ( + moe_num_experts >= expert_parallel_degree + ), f"expert moe_num_experts={moe_num_experts} >= moe_world_size={expert_parallel_degree}" + assert ( + moe_num_experts % expert_parallel_degree == 0 + ), f"expert moe_num_experts={moe_num_experts} % moe_world_size={expert_parallel_degree} == 0" + moe_num_experts_per_device = moe_num_experts // expert_parallel_degree + return moe_num_experts_per_device + + def _post_init(self): + for p in self.gate.parameters(): + p.is_gate = True + + for k in self.experts: + if k is not None: + for p in k.parameters(): + p.expert = not self.is_dummy_moe + p.no_sync = not self.is_dummy_moe + # logger.info(f"expert param={p.name}, no-sync={p.no_sync}") + + def expert_forward(self, dispatched_input): + true_experts = self.experts[ + self.moe_rank * self.moe_num_experts_per_device : (self.moe_rank + 1) * self.moe_num_experts_per_device + ] + expert_outputs = [] + chunks = dispatched_input.unbind(1) + assert len(chunks) == len(true_experts), (len(chunks), len(true_experts)) + for chunk, expert in zip(chunks, true_experts): + chunk = chunk.contiguous() + expert_outputs += [expert(chunk)] + expert_output = paddle.stack(expert_outputs, axis=1) # [ecm] + return expert_output + + def forward( + self, + hidden_state: paddle.Tensor, + used_token: paddle.Tensor = None, + ): + """_summary_ + + Args: + input (_type_): _description_ + used_token + + Returns: + _type_: _description_ + """ + # Implement Algorithm 2 from GShard paper. + batch_size, seq_len, d_model = hidden_state.shape + + # Initial implementation -> Reshape into S tokens by dropping sequence dimension. + # Reshape into G groups so that each group can distribute tokens equally + # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1 + reshaped_input = hidden_state.reshape([-1, d_model]) + + capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss = self.gate(reshaped_input) + + # self.l_aux : + # combine_weights : sec + # dispatch_mask : sec + # self.exp_counts : + dispatched_input = paddle.einsum("sec,sm->ecm", paddle.cast(dispatch_mask, hidden_state.dtype), reshaped_input) + + if self.expert_parallel_degree > 1: + dispatched_input = _AllToAll.apply(dispatched_input, self.moe_group) + + # Re-shape after all-to-all: ecm -> gecm + dispatched_input = dispatched_input.reshape( + [self.expert_parallel_degree, self.moe_num_experts_per_device, -1, d_model] + ) + expert_output = self.expert_forward(dispatched_input) + # Re-shape before drop_tokens: gecm -> ecm + expert_output = expert_output.reshape( + [self.expert_parallel_degree * self.moe_num_experts_per_device, -1, d_model] + ) + + if self.expert_parallel_degree > 1: + expert_output = _AllToAll.apply(expert_output, self.moe_group) + + # combine withe expert weights + combined_output = paddle.einsum("sec,ecm->sm", combine_weights.cast(hidden_state[0].dtype), expert_output) + + a = combined_output.reshape(hidden_state.shape) + + return a, l_aux, l_zloss diff --git a/paddlenlp/transformers/qwen2/modeling.py b/paddlenlp/transformers/qwen2/modeling.py index 0f9fb994539c..ced5c6ed7052 100644 --- a/paddlenlp/transformers/qwen2/modeling.py +++ b/paddlenlp/transformers/qwen2/modeling.py @@ -1175,8 +1175,16 @@ def forward(self, prediction_scores, masked_lm_labels): masked_lm_loss = self.loss_func(prediction_scores.astype("float32"), masked_lm_labels.unsqueeze(2)) # skip ignore_index which loss == 0 - masked_lm_loss = masked_lm_loss[masked_lm_loss > 0] - loss = paddle.mean(masked_lm_loss) + # masked_lm_loss = masked_lm_loss[masked_lm_loss > 0] + # loss = paddle.mean(masked_lm_loss) + binary_sequence = paddle.where( + masked_lm_loss > 0, paddle.ones_like(masked_lm_loss), paddle.zeros_like(masked_lm_loss) + ) + count = paddle.sum(binary_sequence) + if count == 0: + loss = paddle.sum(masked_lm_loss * binary_sequence) + else: + loss = paddle.sum(masked_lm_loss * binary_sequence) / count return loss diff --git a/paddlenlp/transformers/qwen2/tokenizer.py b/paddlenlp/transformers/qwen2/tokenizer.py index 0e489bced151..a670b1a5d129 100644 --- a/paddlenlp/transformers/qwen2/tokenizer.py +++ b/paddlenlp/transformers/qwen2/tokenizer.py @@ -13,12 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Tokenization classes for Qwen2.""" +from __future__ import annotations import json import os import unicodedata from functools import lru_cache -from typing import Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import regex as re @@ -338,3 +339,110 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = def prepare_for_tokenization(self, text, **kwargs): text = unicodedata.normalize("NFC", text) return (text, kwargs) + + def _encode_chat_inputs( + self, + conversations: List[List[str, str]], + context_data: Dict[str, Any] = {}, + system: str = None, + add_generation_prompt=True, + ): + result = {} + + # Some template do not support system msg, so we need to check it first. + if system: + try: + self.chat_template.render(messages={"role": "system", "content": system}) + except Exception as e: + raise ValueError("System is not supported in this tokenizer.", e) + + # convert list msg to role dict msg + conversation_dict = [] + origin_msg = [] + for round in conversations: + round_role = [ + {"role": "user", "content": round[0]}, + {"role": "assistant", "content": round[1]}, + ] + origin_msg.extend(round_role) + conversation_dict.append(round_role) + + # Get system string in ChatTemplate + # ChatTemplate contains three parts: system, user, and assistant. + # However, the system string cannot be obtained directly with the chat_template.render() function. + # Thus, three steps are needed to extract the system string. + # Step 1: Obtain the combined system and user string in the first round. + # Step 2: Obtain the special system string. + # Step 3: Obtain the special combined system and user string in the first round. + # Then, user string = (special system and user string) - (special system string) + # And, system string = (initial system and user string) - (user string) + + assert len(conversation_dict) > 0, "conversations is empty" + + def replace_first_occurrence(original_string, to_find, to_replace): + index = original_string.find(to_find) + if index == -1: # to_find not found in original_string + return original_string + else: + return original_string[:index] + to_replace + original_string[index + len(to_find) :] + + if system: + system_str = self.chat_template.render([system]) + else: + # get system and user str + round0_str = self.chat_template.render( + messages=conversation_dict[0][:1], add_generation_prompt=False, **self.special_tokens_map + ) + # get special system str + round0_only_system_str = self.chat_template.render( + messages=[{"role": "system", "content": ""}], add_generation_prompt=False, **self.special_tokens_map + ) + # get special system and user str + round0_system_user_str = self.chat_template.render( + messages=[{"role": "system", "content": ""}] + conversation_dict[0][:1], + add_generation_prompt=False, + **self.special_tokens_map, + ) + + # get user str = {special system and user str} - {special system str} + user_str = replace_first_occurrence(round0_system_user_str, round0_only_system_str, "") + # get system str = { system and user str} - {user str} + system_str = round0_str.replace(user_str, "") + + no_ans = [] + ans = [] + for conv in conversation_dict: + roundi = [system] + conv if system else conv + roundi_str = self.chat_template.render( + messages=roundi, add_generation_prompt=False, **self.special_tokens_map + ) + + roundi_no_ans = [system] + [conv[0]] if system else [conv[0]] + roundi_no_ans_str = self.chat_template.render( + messages=roundi_no_ans, add_generation_prompt=add_generation_prompt, **self.special_tokens_map + ) + + roundi_ans_str = roundi_str[len(roundi_no_ans_str) :] + ans.append(roundi_ans_str) + + roundi_no_ans_no_system_str = replace_first_occurrence(roundi_no_ans_str, system_str, "") + assert ( + roundi_no_ans_str == system_str + roundi_no_ans_no_system_str + ), f"the src string contains system str: {system_str}" + no_ans.append(roundi_no_ans_no_system_str) + + # the first round is special, we need to add system_str + no_ans[0] = system_str + no_ans[0] + + conversation_ids = [] + for i in range(len(no_ans)): + conversation_ids.append( + self.batch_encode( + [no_ans[i], ans[i]], + add_special_tokens=False, + padding=False, + )["input_ids"] + ) + + result["conversations"] = conversation_ids + return result diff --git a/paddlenlp/transformers/qwen2_moe/modeling.py b/paddlenlp/transformers/qwen2_moe/modeling.py index 46a3bb885a60..18507c1d5dc7 100644 --- a/paddlenlp/transformers/qwen2_moe/modeling.py +++ b/paddlenlp/transformers/qwen2_moe/modeling.py @@ -34,6 +34,8 @@ from ..conversion_utils import StateDictNameMapping, init_name_mappings from ..model_outputs import MoECausalLMOutputWithPast, MoEModelOutputWithPast from ..model_utils import PretrainedModel, register_base_model +from ..moe_gate import PretrainedMoEGate +from ..moe_layer import MoELayer from .configuration import Qwen2MoeConfig try: @@ -52,7 +54,7 @@ try: from paddle.nn.functional.flash_attention import flash_attention -except: +except ImportError: flash_attention = None __all__ = [ @@ -683,68 +685,69 @@ def forward( return outputs -class Qwen2MoeSparseMoEBlock(nn.Layer): - def __init__(self, config: Qwen2MoeConfig): - super().__init__() - self.num_experts = config.num_experts - self.top_k = config.num_experts_per_tok - self.norm_topk_prob = config.norm_topk_prob - - self.gate = nn.Linear(config.hidden_size, self.num_experts, bias_attr=False) - self.experts = nn.LayerList([Qwen2MoeMLP(config) for _ in range(self.num_experts)]) - - self.shared_expert = Qwen2MoeMLP(config, is_shared=True) - self.shared_expert_gate = nn.Linear(config.hidden_size, 1, bias_attr=False) +class Qwen2MoeGate(PretrainedMoEGate): + def __init__(self, config, num_experts, expert_hidden_size, **kwargs): + super().__init__(config, num_experts, expert_hidden_size, **kwargs) + # [hidden_size, n_expert] + self.weight = paddle.create_parameter( + shape=[expert_hidden_size, num_experts], + dtype=paddle.get_default_dtype(), + is_bias=False, + default_initializer=nn.initializer.Constant(1.0), + ) def forward(self, hidden_states): - batch_size, seq_len, hidden_dim = hidden_states.shape - hidden_states = hidden_states.reshape([-1, hidden_dim]) - # router_logits: [batch_size * seq_len, num_experts] - router_logits = self.gate(hidden_states) + """ + Args: + hidden_states (_type_): [batch_size * seq_len, hidden_size] + """ + _, h_dim = hidden_states.shape + + # compute gating score + logits = F.linear(hidden_states, self.weight, None) with paddle.amp.auto_cast(False): - routing_weights = F.softmax(router_logits.astype("float32"), axis=1) - routing_weights, selected_experts = paddle.topk(routing_weights, self.top_k, axis=-1) - if self.norm_topk_prob: # Note: Mixtral is set norm as default, Qwen2Moe is set to no norm - routing_weights /= routing_weights.sum(axis=-1, keepdim=True) - # we cast back to input dtype - routing_weights = routing_weights.astype(hidden_states.dtype) - - final_hidden_states = paddle.zeros( - [batch_size * seq_len, hidden_dim], - dtype=hidden_states.dtype, - ) + scores = self.gate_score_func(logits=logits) + scores = scores.cast(paddle.get_default_dtype()) + + capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss = self.topkgating(scores) - # One hot encode the selected experts to create an expert mask - # this will be used to easily index which expert is going to be sollicitated. - # shape: [num_experts, top_k, batch_size * seq_len] - expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).transpose([2, 1, 0]) + return capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss - # Loop over all available experts in the model and perform the computation on each expert. - for expert_id in range(self.num_experts): - expert_layer = self.experts[expert_id] - idx, top_x = paddle.where(expert_mask[expert_id]) - if top_x.shape[0] == 0: - continue +class Qwen2MoeSparseMoEBlock(MoELayer): + def __init__(self, config: Qwen2MoeConfig): + gate = Qwen2MoeGate( + config, + config.num_experts, + config.hidden_size, + top_k=config.num_experts_per_tok, + drop_tokens=False, + ) - current_state = paddle.gather(hidden_states, top_x.squeeze()) - current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx] + super().__init__( + config, + moe_num_experts=config.num_experts, + expert_class=Qwen2MoeMLP, + expert_kwargs=config, + gate=gate, + capacity=2.0, + ) - top_x = top_x.squeeze() - if top_x.shape == []: - top_x = paddle.to_tensor([top_x.item()]) - final_hidden_states = paddle.index_add_( - final_hidden_states, top_x, 0, current_hidden_states.astype(hidden_states.dtype) - ) + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + + self.shared_expert = Qwen2MoeMLP(config, is_shared=True) + self.shared_expert_gate = nn.Linear(config.hidden_size, 1, bias_attr=False) + + def forward(self, hidden_states): + final_hidden_states, l_aux, l_zloss = super().forward(hidden_states) shared_expert_output = self.shared_expert(hidden_states) shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output - final_hidden_states = final_hidden_states + shared_expert_output - final_hidden_states = final_hidden_states.reshape([batch_size, seq_len, hidden_dim]) - return final_hidden_states, router_logits + return final_hidden_states, l_aux class Qwen2MoeDecoderLayer(nn.Layer): diff --git a/paddlenlp/transformers/yuan/tokenizer.py b/paddlenlp/transformers/yuan/tokenizer.py index 03472368afc6..b56e6aa39ac5 100644 --- a/paddlenlp/transformers/yuan/tokenizer.py +++ b/paddlenlp/transformers/yuan/tokenizer.py @@ -16,8 +16,9 @@ """Tokenization class for Yuan2.0 model""" import os +import re from shutil import copyfile -from typing import List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import sentencepiece as spm @@ -200,3 +201,82 @@ def create_token_type_ids_from_sequences( if token_ids_1 is None: return len(token_ids_0 + eos) * [0] return len(token_ids_0 + eos + token_ids_1 + eos) * [0] + + def _encode_chat_inputs( + self, + conversations: List[Tuple[str, str]], + context_data: Dict[str, Any] = {}, + system: str = None, + add_generation_prompt=True, + ): + result = {} + + # Some template do not support system msg, so we need to check it first. + if system: + try: + self.chat_template.render(messages={"role": "system", "content": system}) + except Exception as e: + raise ValueError("System is not supported in this tokenizer.", e) + + # convert list msg to role dict msg + conversation_dict = [] + origin_msg = [] + for round in conversations: + round_role = [ + {"role": "user", "content": round[0]}, + {"role": "assistant", "content": round[1]}, + ] + origin_msg.extend(round_role) + conversation_dict.append(round_role) + ans = [] + + # get answer in single round, then compile the chat entirely and split by single round ans + # attention: answer should include end token! + for conv in conversation_dict: + roundi = [system] + conv if system else conv + roundi_str = self.chat_template.render( + messages=roundi, add_generation_prompt=False, **self.special_tokens_map + ) + roundi_no_ans = [system] + [conv[0]] if system else [conv[0]] + roundi_no_ans_str = self.chat_template.render( + messages=roundi_no_ans, add_generation_prompt=add_generation_prompt, **self.special_tokens_map + ) + + ans_roundi = roundi_str[len(roundi_no_ans_str) - len("") + len("") : -len("")] + ans.append(ans_roundi) + for idx, _ in enumerate(ans): + ans[idx] += "" if idx != len(ans) - 1 else "" + + non_learnable_parts = self._extract_non_learnable_parts(origin_msg, ans) + assert len(non_learnable_parts) == len(ans) + + conversation_ids = [] + for i in range(len(non_learnable_parts)): + conversation_ids.append( + self.batch_encode( + [non_learnable_parts[i], ans[i]], + add_special_tokens=False, + padding=False, + )["input_ids"] + ) + + result["conversations"] = conversation_ids + return result + + def _extract_non_learnable_parts(self, origin_msg: List[Dict[str, str]], split_s: List[str]): + """Split the entire chat by specified words. Extract the non-learnable parts.""" + # distingish and replace the special words in original string to an uncompiled form: Like | -> \| + split_s_with_front_token = split_s.copy() + for idx, _ in enumerate(split_s): + split_s_with_front_token[idx] = "" + split_s_with_front_token[idx] + regex_pattern = "|".join(map(re.escape, split_s_with_front_token)) + # splited by replaced specified words + non_learnable_parts = re.split( + r"(?:%s)" % regex_pattern, + self.chat_template.render(messages=origin_msg, add_generation_prompt=False, **self.special_tokens_map), + ) + if non_learnable_parts[-1] == "": + non_learnable_parts.pop() + for idx, _ in enumerate(non_learnable_parts): + non_learnable_parts[idx] = non_learnable_parts[idx] + "" + return non_learnable_parts diff --git a/paddlenlp/trl/llm_utils.py b/paddlenlp/trl/llm_utils.py index 4ee23f2822ee..cbe94f2f15ed 100644 --- a/paddlenlp/trl/llm_utils.py +++ b/paddlenlp/trl/llm_utils.py @@ -633,7 +633,7 @@ def read_res(model_name_or_path: str, tensor_queue: mp.Queue, result_queue: mp.Q from paddlenlp_ops import get_output while True: - get_output(output_tensor, 0, True) + get_output(output_tensor, 0, True, False) # wait_flag # speculative_decoding if int(output_tensor[0, 0]) == -2: # read none continue bsz = int(output_tensor[1, 0]) @@ -650,6 +650,50 @@ def read_res(model_name_or_path: str, tensor_queue: mp.Queue, result_queue: mp.Q logger.info("Finish read result message") +def speculate_read_res(model_name_or_path: str, tensor_queue: mp.Queue, result_queue: mp.Queue, done_event: mp.Event): + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + paddle.device.set_device("cpu") + paddle.disable_static() + outputs = [] + from paddlenlp.utils.env import MAX_DRAFT_TOKENS, SPECULATE_MAX_BSZ + + for _ in range(SPECULATE_MAX_BSZ): + outputs.append([]) + output_tensor = tensor_queue.get(timeout=1) + done_event.set() + logger.info("Start speculate read result message") + logger.info(f"Current path is {os.getcwd()}") + + from paddlenlp_ops import get_output + + while True: + get_output(output_tensor, 0, True, True) # wait_flag # speculative_decoding + if int(output_tensor[0, 0]) == -2: # read none + continue + bsz = int(output_tensor[1]) + accept_num = output_tensor[2 : bsz + 2].numpy() + for bi in range(bsz): + output_numpy = output_tensor[ + 2 + + SPECULATE_MAX_BSZ + + bi * MAX_DRAFT_TOKENS : 2 + + SPECULATE_MAX_BSZ + + bi * MAX_DRAFT_TOKENS + + int(accept_num[bi]), + 0, + ].numpy() + output_numpy[output_numpy == -1] = tokenizer.eos_token_id + outputs[bi].extend(output_numpy.tolist()) + if int(output_tensor[0, 0]) == -1: + break + + seqs = tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False) + for i, (out, seq) in enumerate(zip(outputs, seqs)): + result_queue.put([i, out, seq]) + + logger.info("Finish read result message") + + def get_rotary_position_embedding(position_ids, head_dim, rope_theta=10000.0, rope_scaling: dict = None): """ Pre-calculate rotary position embedding for position_ids. diff --git a/paddlenlp/utils/env.py b/paddlenlp/utils/env.py index d1fbbb1a60ba..249989616497 100644 --- a/paddlenlp/utils/env.py +++ b/paddlenlp/utils/env.py @@ -111,3 +111,18 @@ def _get_bool_env(env_key: str, default_value: str) -> bool: SAFE_PEFT_WEIGHTS_NAME = "peft_model.safetensors" SAFE_PEFT_WEIGHTS_INDEX_NAME = "peft_model.safetensors.index.json" + +# Checkpoint quantization +MOMENT1_KEYNAME = "moment1_0" +MOMENT2_KEYNAME = "moment2_0" +BETA1_KEYNAME = "beta1_pow_acc_0" +BETA2_KEYNAME = "beta2_pow_acc_0" +SYMMETRY_QUANT_SCALE = "@scales" +ASYMMETRY_QUANT_SCALE_MIN = "@min_scales" +ASYMMETRY_QUANT_SCALE_MAX = "@max_scales" + +# LLM Inference related environment variables +# Note(@Wanglongzhi2001): MAX_BSZ, SPECULATE_MAX_BSZ, MAX_DRAFT_TOKENS must be the same as definition in get_output / save_output +MAX_BSZ = 512 +SPECULATE_MAX_BSZ = 256 +MAX_DRAFT_TOKENS = 6 diff --git a/pyproject.toml b/pyproject.toml index c6ac5bc3ba58..54d319a8c85f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ testpaths = [ "tests/generation", "tests/layers", "tests/metrics", + "tests/pose", "tests/ops", "tests/trainer", "tests/transformers", diff --git a/scripts/distribute/ci_case_auto.sh b/scripts/distribute/ci_case_auto.sh index 4f634470511c..7c11fbad457e 100755 --- a/scripts/distribute/ci_case_auto.sh +++ b/scripts/distribute/ci_case_auto.sh @@ -68,57 +68,104 @@ function track_case_status() { return 0 } +function restore_func() { + fun_list=$1 + cd ${log_path} || { echo "Failed to enter log_path: $log_path"; return 1; } + if [ -e "functions.txt" ]; then + rm "functions.txt" + echo "Deleted existing functions.txt" + fi + for function in ${fun_list[@]};do + echo "$function" >> functions.txt + done +} + + + # NOTE: Please place the new tests as much as possible after the existing tests function llama_case_list_auto() { - # The test name must have "llama_" as a prefix, which will - # be used for tracking the execution status of the case. - llama_dygraph_auto_bs8_fp32_DP2 - llama_dygraph_auto_bs8_fp32_DP2-MP2 - llama_dygraph_auto_bs8_fp32_DP2-MP2-PP2 - llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2 - llama_dy2st_auto_bs4_bf16_DP1-MP1-PP4-SD2-VPP3_split_bw - llama_dy2st_auto_bs4_bf16_DP1-MP1-PP4-SD2 - - llama_align_dygraph_dy2st_auto_bs2_bf16_DP2-MP1-PP1 - llama_pir_auto_fuse_ffn_attention_qkv_MP2 - llama_convert_hybrid_ckpt_to_auto_parallel_bs2_fp32_DP2-MP1-PP1 - llama_align_dygraph_dy2st_pir_auto_bs2_bf16_DP2-MP2-PP1-SP - llama_align_dygraph_dy2st_pir_auto_bs2_bf16_DP2-MP2-PP2-SP - llama_align_dygraph_dy2st_pir_auto_grad_merge_bs2_fp32_DP1-MP1-PP1 - llama_align_dy2st_fthenb_and_vpp_auto_bs2_fp32_DP1-MP1-PP4 - llama_align_dygraph_dy2st_pir_auto_pp_bs2_bf16_DP1-MP1-PP4 - - track_case_status $FUNCNAME "llama_" + fun_list=( + # The test name must have "llama_" as a prefix, which will + # be used for tracking the execution status of the case. + llama_dygraph_auto_bs8_fp32_DP2 + llama_dygraph_auto_bs8_fp32_DP2-MP2 + llama_dygraph_auto_bs8_fp32_DP2-MP2-PP2 + llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2 + llama_dy2st_auto_bs4_bf16_DP1-MP1-PP4-SD2-VPP3_split_bw + llama_dy2st_auto_bs4_bf16_DP1-MP1-PP4-SD2 + llama_align_dygraph_dy2st_auto_bs2_bf16_DP2-MP1-PP1 + llama_pir_auto_fuse_ffn_attention_qkv_MP2 + llama_convert_hybrid_ckpt_to_auto_parallel_bs2_fp32_DP2-MP1-PP1 + llama_align_dygraph_dy2st_pir_auto_bs2_bf16_DP2-MP2-PP1-SP + llama_align_dygraph_dy2st_pir_auto_bs2_bf16_DP2-MP2-PP2-SP + llama_align_dygraph_dy2st_pir_auto_grad_merge_bs2_fp32_DP1-MP1-PP1 + llama_align_dy2st_fthenb_and_vpp_auto_bs2_fp32_DP1-MP1-PP4 + llama_align_dygraph_dy2st_pir_auto_pp_bs2_bf16_DP1-MP1-PP4 + ) + if [ $1 = "prepare_case" ]; then + restore_func $fun_list + elif [ $1 = "exec_case" ]; then + for fun in "${fun_list[@]}"; do + eval "$fun" + done + track_case_status $FUNCNAME "llama_" + else + echo -e "\033[31m ---- Invalid status $1 \033[0m" + return 1 + fi } + function llm_gpt_case_list_auto() { - # The test name must have "llm_gpt_dygraph_auto_" as a prefix, - # which will be used for tracking the execution status of the case. - llm_gpt_dygraph_auto_bs8_fp32_DP2 - llm_gpt_dygraph_auto_bs8_fp32_DP2-MP2 - llm_gpt_dygraph_auto_bs8_fp32_DP2-MP2-PP2 - llm_gpt_dygraph_auto_bs8_fp16_DP2-MP2-PP2 - - track_case_status $FUNCNAME "llm_gpt_dygraph_auto_" + fun_list=( + # The test name must have "llm_gpt_dygraph_auto_" as a prefix, + # which will be used for tracking the execution status of the case. + llm_gpt_dygraph_auto_bs8_fp32_DP2 + llm_gpt_dygraph_auto_bs8_fp32_DP2-MP2 + llm_gpt_dygraph_auto_bs8_fp32_DP2-MP2-PP2 + llm_gpt_dygraph_auto_bs8_fp16_DP2-MP2-PP2 + llm_gpt_pir_auto_bs4_TP2 + llm_gpt_pir_auto_bs4_TP2_PP2 + ) + if [ $1 = "prepare_case" ]; then + restore_func $fun_list + elif [ $1 = "exec_case" ]; then + for fun in "${fun_list[@]}"; do + eval "$fun" + done + track_case_status $FUNCNAME "llm_gpt" + else + echo -e "\033[31m ---- Invalid status $1 \033[0m" + return 1 + fi } function llm_qwen_case_list_auto() { - # The test name must have "llm_qwen_dygraph_auto_" as a prefix, - # which will be used for tracking the execution status of the case. - llm_qwen_dygraph_auto_bs1_fp32_DP2 - llm_qwen_dygraph_auto_bs1_fp32_DP2-MP2 - llm_qwen_dygraph_auto_bs1_fp32_DP2-MP2-PP2 - llm_qwen_dygraph_auto_bs1_bf16_DP2-MP2-PP2 - - track_case_status $FUNCNAME "llm_qwen_dygraph_auto_" + fun_list=( + # The test name must have "llm_qwen_dygraph_auto_" as a prefix, + # which will be used for tracking the execution status of the case. + llm_qwen_dygraph_auto_bs1_fp32_DP2 + llm_qwen_dygraph_auto_bs1_fp32_DP2-MP2 + llm_qwen_dygraph_auto_bs1_fp32_DP2-MP2-PP2 + llm_qwen_dygraph_auto_bs1_bf16_DP2-MP2-PP2 + ) + if [ $1 = "prepare_case" ]; then + restore_func $fun_list + elif [ $1 = "exec_case" ]; then + for fun in "${fun_list[@]}"; do + eval "$fun" + done + track_case_status $FUNCNAME "llm_qwen_dygraph_auto_" + else + echo -e "\033[31m ---- Invalid status $1 \033[0m" + return 1 + fi } ############ case start ############ function llama_dygraph_auto_bs8_fp32_DP2() { echo "=========== $FUNCNAME run begin ===========" - export_env - cd ${llama_case_path} export PYTHONPATH=$root_path/:$PYTHONPATH export FLAGS_call_stack_level=3 export NVIDIA_TF32_OVERRIDE=0 @@ -177,7 +224,7 @@ function llama_dygraph_auto_bs8_fp32_DP2() { ips=-1 mem=-1 echo "result: loss=$loss ips=$ips mem=$mem" - loss_base=9.51876831 + loss_base=9.4992733 if [ $IS_A100 -ne 0 ];then loss_base=9.53084087 fi @@ -189,8 +236,6 @@ function llama_dygraph_auto_bs8_fp32_DP2() { function llama_dygraph_auto_bs8_fp32_DP2-MP2() { echo "=========== $FUNCNAME run begin ===========" - export_env - cd ${llama_case_path} export PYTHONPATH=$root_path/:$PYTHONPATH export FLAGS_call_stack_level=3 export NVIDIA_TF32_OVERRIDE=0 @@ -261,8 +306,6 @@ function llama_dygraph_auto_bs8_fp32_DP2-MP2() { function llama_dygraph_auto_bs8_fp32_DP2-MP2-PP2() { echo "=========== $FUNCNAME run begin ===========" - export_env - cd ${llama_case_path} export PYTHONPATH=$root_path/:$PYTHONPATH export FLAGS_call_stack_level=3 export NVIDIA_TF32_OVERRIDE=0 @@ -333,8 +376,6 @@ function llama_dygraph_auto_bs8_fp32_DP2-MP2-PP2() { function llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2() { echo "=========== $FUNCNAME run begin ===========" - export_env - cd ${llama_case_path} export PYTHONPATH=$root_path/:$PYTHONPATH export FLAGS_call_stack_level=3 export NVIDIA_TF32_OVERRIDE=0 @@ -409,8 +450,6 @@ function llama_dy2st_auto_bs4_bf16_DP1-MP1-PP4-SD2() { echo IS_A100 is $IS_A100 if [ $IS_A100 -ne 0 ]; then echo "=========== $FUNCNAME run begin ===========" - export_env - cd ${llama_case_path} export PYTHONPATH=$root_path/:$PYTHONPATH export FLAGS_call_stack_level=3 export NVIDIA_TF32_OVERRIDE=0 @@ -510,8 +549,6 @@ function llama_dy2st_auto_bs4_bf16_DP1-MP1-PP4-SD2-VPP3_split_bw() { echo IS_A100 is $IS_A100 if [ $IS_A100 -ne 0 ]; then echo "=========== $FUNCNAME run begin ===========" - export_env - cd ${llama_case_path} export PYTHONPATH=$root_path/:$PYTHONPATH export FLAGS_call_stack_level=3 export NVIDIA_TF32_OVERRIDE=0 @@ -609,8 +646,6 @@ function llama_dy2st_auto_bs4_bf16_DP1-MP1-PP4-SD2-VPP3_split_bw() { function llama_align_dygraph_dy2st_pir_auto_bs2_bf16_DP2-MP2-PP1-SP() { echo "=========== $FUNCNAME run begin ===========" - export_env - cd ${llama_case_path} export PYTHONPATH=$root_path/:$PYTHONPATH export PYTHONPATH=/paddle/Paddle/build_gpu/python/:$PYTHONPATH export FLAGS_call_stack_level=3 @@ -705,8 +740,6 @@ function llama_align_dygraph_dy2st_pir_auto_bs2_bf16_DP2-MP2-PP1-SP() { function llama_pir_auto_fuse_ffn_attention_qkv_MP2() { echo "=========== $FUNCNAME run begin ===========" - export_env - cd ${llama_case_path} export PYTHONPATH=$root_path/:$PYTHONPATH export FLAGS_call_stack_level=3 export FLAGS_max_inplace_grad_add=100 @@ -803,8 +836,6 @@ function llama_pir_auto_fuse_ffn_attention_qkv_MP2() { function llama_align_dygraph_dy2st_pir_auto_bs2_bf16_DP2-MP2-PP2-SP() { echo "=========== $FUNCNAME run begin ===========" - export_env - cd ${llama_case_path} export PYTHONPATH=$root_path/:$PYTHONPATH export PYTHONPATH=/paddle/Paddle/build_gpu/python/:$PYTHONPATH export FLAGS_call_stack_level=3 @@ -900,8 +931,6 @@ function llama_align_dygraph_dy2st_pir_auto_bs2_bf16_DP2-MP2-PP2-SP() { function llama_align_dygraph_dy2st_auto_bs2_bf16_DP2-MP1-PP1() { echo "=========== $FUNCNAME run begin ===========" - export_env - cd ${llama_case_path} export PYTHONPATH=$root_path/:$PYTHONPATH export FLAGS_call_stack_level=3 export NVIDIA_TF32_OVERRIDE=0 @@ -981,7 +1010,7 @@ function llama_align_dygraph_dy2st_auto_bs2_bf16_DP2-MP1-PP1() { ips=-1 mem=-1 echo "result: to_static=$to_static loss=$loss ips=$ips mem=$mem" - loss_base=9.97198105 + loss_base=9.99302673 if [ $IS_A100 -ne 0 ];then loss_base=10.18783569 fi @@ -994,8 +1023,6 @@ function llama_align_dygraph_dy2st_auto_bs2_bf16_DP2-MP1-PP1() { function llama_align_dygraph_dy2st_pir_auto_grad_merge_bs2_fp32_DP1-MP1-PP1() { echo "=========== $FUNCNAME run begin ===========" - export_env - cd ${llama_case_path} export PYTHONPATH=$root_path/:$PYTHONPATH export FLAGS_call_stack_level=3 export NVIDIA_TF32_OVERRIDE=0 @@ -1095,119 +1122,126 @@ function llama_align_dygraph_dy2st_pir_auto_grad_merge_bs2_fp32_DP1-MP1-PP1() { function llama_align_dy2st_fthenb_and_vpp_auto_bs2_fp32_DP1-MP1-PP4() { echo "=========== $FUNCNAME run begin ===========" - export_env - cd ${llama_case_path} export PYTHONPATH=$root_path/:$PYTHONPATH - # Only A100 support this case. - if [ $IS_A100 -ne 0 ]; then - export FLAGS_call_stack_level=3 - export NVIDIA_TF32_OVERRIDE=0 - export FLAGS_max_inplace_grad_add=3 + export FLAGS_call_stack_level=3 + export NVIDIA_TF32_OVERRIDE=0 + export FLAGS_max_inplace_grad_add=3 - task_name="llama_align_dy2st_fthenb_and_vpp_auto_bs2_fp32_DP1_MP1_PP4" - case_out_dir="output/$task_name" - case_log_dir="output/$task_name""_log" - loss1=0 - loss2=0 - use_pir=1 - - max_step=10 - to_static=1 - - for pp_mode in "1F1B" "VPP"; do - export FLAGS_enable_pir_api=${use_pir} - export FLAGS_enable_pir_in_executor=${use_pir} - rm -rf $case_out_dir - rm -rf $case_log_dir - rm -rf ${log_path}/$FUNCNAME - if [ "$pp_mode" == "FThenB" ]; then - vpp_degree=1 - else - vpp_degree=2 - fi + task_name="llama_align_dy2st_fthenb_and_vpp_auto_bs2_fp32_DP1_MP1_PP4" + case_out_dir="output/$task_name" + case_log_dir="output/$task_name""_log" + loss1=0 + loss2=0 + use_pir=1 + + max_step=10 + to_static=1 + loss1_array=() + loss2_array=() - python -u -m paddle.distributed.launch \ - --gpus "0,1,2,3" \ - --log_dir $case_log_dir \ - run_pretrain_auto.py \ - --model_type "llama" \ - --model_name_or_path "facebook/llama-7b" \ - --tokenizer_name_or_path "facebook/llama-7b" \ - --input_dir "./data" \ - --output_dir $case_out_dir \ - --split 949,50,1 \ - --weight_decay 0.01 \ - --warmup_ratio 0.01 \ - --warmup_steps 30 \ - --max_grad_norm 0.0 \ - --learning_rate 3e-05 \ - --min_learning_rate 3e-06 \ - --max_steps $max_step \ - --logging_steps 1 \ - --eval_steps 1000 \ - --save_steps 50000 \ - --continue_training 0 \ - --do_train true \ - --do_eval false \ - --do_predict false \ - --disable_tqdm true \ - --skip_profile_timer true \ - --save_total_limit 2 \ - --device gpu \ - --disable_tqdm true \ - --dataloader_num_workers 1 \ - --distributed_dataloader 0 \ - --enable_auto_parallel 1 \ - --per_device_train_batch_size 1 \ - --gradient_accumulation_steps 4 \ - --per_device_eval_batch_size 2 \ - --recompute false \ - --recompute_use_reentrant true \ - --recompute_granularity full \ - --fp16 0 \ - --fp16_opt_level "O2" \ - --fuse_attention_ffn true \ - --fuse_attention_qkv true \ - --fuse_sequence_parallel_allreduce false \ - --use_flash_attention 0 \ - --use_fused_rope false \ - --use_fused_rms_norm 0 \ - --max_seq_length 2048 \ - --hidden_size 1024 \ - --sep_parallel_degree 1 \ - --sequence_parallel false \ - --pipeline_parallel_degree 4 \ - --sharding_parallel_degree 1 \ - --tensor_parallel_degree 1 \ - --sharding "" \ - --to_static ${to_static} \ - --num_hidden_layers 8 \ - --data_parallel_config "gradient_sync_after_accumulate" \ - --pipeline_schedule_mode $pp_mode \ - --virtual_pp_degree $vpp_degree \ - >>${log_path}/$FUNCNAME 2>&1 - - loss=$(grep "global_step: 10," "$case_log_dir/workerlog.0" | grep -oP '(?<=loss: )\d+(\.\d+)?' | awk -F ',' '{print $1}') - if [ "$pp_mode" == "1F1B" ]; then - loss1=($loss) + for pp_mode in "FThenB" "VPP"; do + export FLAGS_enable_pir_api=${use_pir} + export FLAGS_enable_pir_in_executor=${use_pir} + rm -rf $case_out_dir + rm -rf $case_log_dir + rm -rf ${log_path}/$FUNCNAME + if [ "$pp_mode" == "FThenB" ]; then + vpp_degree=1 + else + vpp_degree=2 + fi + + python -u -m paddle.distributed.launch \ + --gpus "0,1,2,3" \ + --log_dir $case_log_dir \ + run_pretrain_auto.py \ + --model_type "llama" \ + --model_name_or_path "facebook/llama-7b" \ + --tokenizer_name_or_path "facebook/llama-7b" \ + --input_dir "./data" \ + --output_dir $case_out_dir \ + --split 949,50,1 \ + --weight_decay 0.01 \ + --warmup_ratio 0.01 \ + --warmup_steps 30 \ + --max_grad_norm 0.0 \ + --learning_rate 3e-05 \ + --min_learning_rate 3e-06 \ + --max_steps $max_step \ + --logging_steps 1 \ + --eval_steps 1000 \ + --save_steps 50000 \ + --continue_training 0 \ + --do_train true \ + --do_eval false \ + --do_predict false \ + --disable_tqdm true \ + --skip_profile_timer true \ + --save_total_limit 2 \ + --device gpu \ + --disable_tqdm true \ + --dataloader_num_workers 1 \ + --distributed_dataloader 0 \ + --enable_auto_parallel 1 \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 4 \ + --per_device_eval_batch_size 2 \ + --recompute false \ + --recompute_use_reentrant true \ + --recompute_granularity full \ + --fp16 0 \ + --fp16_opt_level "O2" \ + --fuse_attention_ffn true \ + --fuse_attention_qkv true \ + --fuse_sequence_parallel_allreduce false \ + --use_flash_attention 0 \ + --use_fused_rope false \ + --use_fused_rms_norm 0 \ + --max_seq_length 2048 \ + --hidden_size 1024 \ + --sep_parallel_degree 1 \ + --sequence_parallel false \ + --pipeline_parallel_degree 4 \ + --sharding_parallel_degree 1 \ + --tensor_parallel_degree 1 \ + --sharding "" \ + --to_static ${to_static} \ + --num_hidden_layers 8 \ + --data_parallel_config "gradient_sync_after_accumulate" \ + --pipeline_schedule_mode $pp_mode \ + --virtual_pp_degree $vpp_degree \ + >>${log_path}/$FUNCNAME 2>&1 + + for step in $(seq 1 $max_step); do + loss=$(grep "global_step: $step," "$case_log_dir/workerlog.0" | grep -oP '(?<=loss: )\d+(\.\d+)?' | awk -F ',' '{print $1}') + if [ "$pp_mode" == "FThenB" ]; then + loss1_array+=($loss) else - loss2=($loss) + loss2_array+=($loss) fi - echo "result: $pp_mode loss=$loss" done - ips=-1 - mem=-1 - ips_base=-1 - mem_base=-1 - check_result $FUNCNAME ${loss1} ${loss2} ${ips_base} ${ips} ${mem_base} ${mem} - fi + + loss=$(grep "global_step: 10," "$case_log_dir/workerlog.0" | grep -oP '(?<=loss: )\d+(\.\d+)?' | awk -F ',' '{print $1}') + if [ "$pp_mode" == "FThenB" ]; then + loss1=($loss) + else + loss2=($loss) + fi + echo "result: $pp_mode loss=$loss" + done + ips=-1 + mem=-1 + ips_base=-1 + mem_base=-1 + for step in $(seq 1 $max_step); do + echo "step=$step fthenb loss: ${loss1_array[$step-1]}, vpp loss: ${loss2_array[$step-1]}" + done + check_result $FUNCNAME ${loss1} ${loss2} ${ips_base} ${ips} ${mem_base} ${mem} echo "=========== $FUNCNAME run end ===========" } function llama_align_dygraph_dy2st_pir_auto_pp_bs2_bf16_DP1-MP1-PP4() { echo "=========== $FUNCNAME run begin ===========" - export_env - cd ${llama_case_path} export PYTHONPATH=$root_path/:$PYTHONPATH export FLAGS_call_stack_level=3 export NVIDIA_TF32_OVERRIDE=0 @@ -1315,8 +1349,6 @@ function llama_align_dygraph_dy2st_pir_auto_pp_bs2_bf16_DP1-MP1-PP4() { function llama_convert_hybrid_ckpt_to_auto_parallel_bs2_fp32_DP2-MP1-PP1() { echo "=========== $FUNCNAME run begin ===========" - export_env - cd ${llama_case_path} export PYTHONPATH=$root_path/:$PYTHONPATH export FLAGS_call_stack_level=3 export NVIDIA_TF32_OVERRIDE=0 @@ -1476,8 +1508,6 @@ function llama_convert_hybrid_ckpt_to_auto_parallel_bs2_fp32_DP2-MP1-PP1() { function llm_gpt_dygraph_auto_bs8_fp32_DP2() { echo "=========== $FUNCNAME run begin ===========" - export_env - cd ${gpt_case_path} export PYTHONPATH=$root_path/:$PYTHONPATH export FLAGS_call_stack_level=3 export NVIDIA_TF32_OVERRIDE=0 @@ -1548,8 +1578,6 @@ function llm_gpt_dygraph_auto_bs8_fp32_DP2() { function llm_gpt_dygraph_auto_bs8_fp32_DP2-MP2() { echo "=========== $FUNCNAME run begin ===========" - export_env - cd ${gpt_case_path} export PYTHONPATH=$root_path/:$PYTHONPATH export FLAGS_call_stack_level=3 export NVIDIA_TF32_OVERRIDE=0 @@ -1622,8 +1650,6 @@ function llm_gpt_dygraph_auto_bs8_fp32_DP2-MP2() { function llm_gpt_dygraph_auto_bs8_fp32_DP2-MP2-PP2() { echo "=========== $FUNCNAME run begin ===========" - export_env - cd ${gpt_case_path} export PYTHONPATH=$root_path/:$PYTHONPATH export FLAGS_call_stack_level=3 export NVIDIA_TF32_OVERRIDE=0 @@ -1697,8 +1723,6 @@ function llm_gpt_dygraph_auto_bs8_fp32_DP2-MP2-PP2() { function llm_gpt_dygraph_auto_bs8_fp16_DP2-MP2-PP2() { echo "=========== $FUNCNAME run begin ===========" - export_env - cd ${gpt_case_path} export PYTHONPATH=$root_path/:$PYTHONPATH export FLAGS_call_stack_level=3 export NVIDIA_TF32_OVERRIDE=0 @@ -1760,7 +1784,7 @@ function llm_gpt_dygraph_auto_bs8_fp16_DP2-MP2-PP2() { mem=-1 echo "result: loss=$loss ips=$ips mem=$mem loss_md5=$loss_md5" # loss_base=10.58456802 # note: need to debug - loss_base=10.59941483 + loss_base=10.59941673 ips_base=-1 mem_base=-1 if [ $IS_A100 -ne 0 ];then @@ -1771,9 +1795,129 @@ function llm_gpt_dygraph_auto_bs8_fp16_DP2-MP2-PP2() { echo "=========== $FUNCNAME run end ===========" } +function llm_gpt_pir_auto_bs4_TP2(){ + echo "=========== $FUNCNAME run begin ===========" + export PYTHONPATH=$root_path/:$PYTHONPATH + export FLAGS_call_stack_level=3 + export NVIDIA_TF32_OVERRIDE=0 + + cd ${llm_gpt_case_path} + + task_name="gpt3_auto_bs4_tp2" + case_out_dir="output/$task_name" + case_log_dir="output/$task_name""_log" + rm -rf $case_out_dir + rm -rf $case_log_dir + + python -u -m paddle.distributed.launch --gpus "0,1" \ + --log_dir $case_log_dir \ + run_pretrain_auto.py \ + --model_name_or_path gpt3-13B-en \ + --tokenizer_name_or_path gpt3-13B-en \ + --input_dir "$gpt_data_path/data" \ + --output_dir "output/$task_name" \ + --split 949,50,1 \ + --max_seq_length 1024 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --sharding "" \ + --tensor_parallel_degree 2 \ + --pipeline_parallel_degree 1 \ + --sequence_parallel 0 \ + --fuse_attention_qkv 0 \ + --use_flash_attention 0 \ + --scale_loss 1024 \ + --learning_rate 0.00001 \ + --min_learning_rate 0.000005 \ + --max_steps 10 \ + --save_steps 50000 \ + --weight_decay 0.01 \ + --warmup_ratio 0.01 \ + --max_grad_norm 1.0 \ + --logging_steps 1\ + --continue_training 0\ + --dataloader_num_workers 1 \ + --eval_steps 100000 \ + --report_to "visualdl" \ + --disable_tqdm true \ + --recompute 0 \ + --gradient_accumulation_steps 4 \ + --do_train \ + --do_eval \ + --device "gpu" \ + --model_type "gpt" \ + --enable_auto_parallel 1 \ + --to_static 1 \ + --fp16 0 \ + --fp16_opt_level "O2" \ + --num_hidden_layers 4 \ + --intermediate_size 1024 \ + >>${log_path}/$FUNCNAME 2>&1 + echo "=========== $FUNCNAME run end ===========" +} + +function llm_gpt_pir_auto_bs4_TP2_PP2(){ + echo "=========== $FUNCNAME run begin ===========" + export PYTHONPATH=$root_path/:$PYTHONPATH + export FLAGS_call_stack_level=3 + export NVIDIA_TF32_OVERRIDE=0 + + cd ${llm_gpt_case_path} + + task_name="gpt3_auto_bs4_tp2_pp2" + case_out_dir="output/$task_name" + case_log_dir="output/$task_name""_log" + rm -rf $case_out_dir + rm -rf $case_log_dir + + python -u -m paddle.distributed.launch --gpus "0,1,2,3" \ + --log_dir $case_log_dir \ + run_pretrain_auto.py \ + --model_name_or_path gpt3-13B-en \ + --tokenizer_name_or_path gpt3-13B-en \ + --input_dir "$gpt_data_path/data" \ + --output_dir "output/$task_name" \ + --split 949,50,1 \ + --max_seq_length 1024 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --sharding "" \ + --tensor_parallel_degree 2 \ + --pipeline_parallel_degree 2 \ + --sequence_parallel 0 \ + --fuse_attention_qkv 0 \ + --use_flash_attention 0 \ + --scale_loss 1024 \ + --learning_rate 0.00001 \ + --min_learning_rate 0.000005 \ + --max_steps 10 \ + --save_steps 50000 \ + --weight_decay 0.01 \ + --warmup_ratio 0.01 \ + --max_grad_norm 1.0 \ + --logging_steps 1\ + --continue_training 0\ + --dataloader_num_workers 1 \ + --eval_steps 100000 \ + --report_to "visualdl" \ + --disable_tqdm true \ + --recompute 0 \ + --gradient_accumulation_steps 4 \ + --do_train \ + --do_eval \ + --device "gpu" \ + --model_type "gpt" \ + --enable_auto_parallel 1 \ + --to_static 1 \ + --fp16 0 \ + --fp16_opt_level "O2" \ + --num_hidden_layers 4 \ + --intermediate_size 1024 \ + >>${log_path}/$FUNCNAME 2>&1 + echo "=========== $FUNCNAME run end ===========" +} + function llm_qwen_dygraph_auto_bs1_fp32_DP2() { - export_env - cd ${gpt_case_path} set -x config_json="pretrain_argument_for_ci_auto_dp2.json" @@ -1867,8 +2011,6 @@ EOF } function llm_qwen_dygraph_auto_bs1_fp32_DP2-MP2() { - export_env - cd ${gpt_case_path} set -x config_json="pretrain_argument_for_ci_auto_dp2_mp2.json" @@ -1962,8 +2104,6 @@ EOF } function llm_qwen_dygraph_auto_bs1_fp32_DP2-MP2-PP2() { - export_env - cd ${gpt_case_path} set -x config_json="pretrain_argument_for_ci_auto_dp2_mp2_pp2.json" @@ -2057,8 +2197,6 @@ EOF } function llm_qwen_dygraph_auto_bs1_bf16_DP2-MP2-PP2() { - export_env - cd ${gpt_case_path} set -x config_json="pretrain_argument_for_ci_auto_dp2_mp2_pp2.json" @@ -2306,48 +2444,6 @@ function before_hook_for_llama() { fi } -function restore_func() { - fun_list=$1 - cd ${log_path} || { echo "Failed to enter log_path: $log_path"; return 1; } - if [ -e "functions.txt" ]; then - rm "functions.txt" - echo "Deleted existing functions.txt" - fi - for function in ${fun_list[@]};do - echo "$function" >> functions.txt - done -} - -function restore_llama_case_list_auto_func() { - fun_list=( - llama_dygraph_auto_bs8_fp32_DP2 - llama_dygraph_auto_bs8_fp32_DP2-MP2 - llama_dygraph_auto_bs8_fp32_DP2-MP2-PP2 - llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2 - llama_dy2st_auto_bs4_bf16_DP1-MP1-PP4-SD2-VPP3_split_bw - llama_dy2st_auto_bs4_bf16_DP1-MP1-PP4-SD2 - llama_align_dygraph_dy2st_auto_bs2_bf16_DP2-MP1-PP1 - llama_pir_auto_fuse_ffn_attention_qkv_MP2 - llama_convert_hybrid_ckpt_to_auto_parallel_bs2_fp32_DP2-MP1-PP1 - llama_align_dygraph_dy2st_pir_auto_bs2_bf16_DP2-MP2-PP1-SP - llama_align_dygraph_dy2st_pir_auto_bs2_bf16_DP2-MP2-PP2-SP - llama_align_dygraph_dy2st_pir_auto_grad_merge_bs2_fp32_DP1-MP1-PP1 - llama_align_dy2st_fthenb_and_vpp_auto_bs2_fp32_DP1-MP1-PP4 - llama_align_dygraph_dy2st_pir_auto_pp_bs2_bf16_DP1-MP1-PP4 - ) - restore_func $fun_list -} - -function restore_llm_gpt_case_list_auto_func() { - fun_list=( - llm_gpt_dygraph_auto_bs8_fp32_DP2 - llm_gpt_dygraph_auto_bs8_fp32_DP2-MP2 - llm_gpt_dygraph_auto_bs8_fp32_DP2-MP2-PP2 - llm_gpt_dygraph_auto_bs8_fp16_DP2-MP2-PP2 - ) - restore_func $fun_list -} - @@ -2356,20 +2452,26 @@ if [[ $status = "prepare_case" ]];then export FLAGS_install_deps=$3 export FLAGS_download_data=$4 if [[ $2 = "llama_case_list_auto" ]];then - before_hook_for_llama - restore_llama_case_list_auto_func + before_hook_for_llama + llama_case_list_auto prepare_case elif [[ $2 = "llm_gpt_case_list_auto" ]];then before_hook_for_gpt - restore_llm_gpt_case_list_auto_func + llm_gpt_case_list_auto prepare_case else echo -e "\033[31m ---- Invalid exec_case $2 \033[0m" fi elif [[ $status = "exec_case" ]];then export FLAGS_install_deps=$3 export FLAGS_download_data=$4 + export_env + if [[ $2 =~ "gpt" ]];then + cd ${gpt_case_path} + elif [[ $2 =~ "llama" ]];then + cd ${llama_case_path} + fi $2 else - echo -e "\033[31m ---- Invalid status $status \033[0m" + echo -e "\033[31m ---- Start executing $status \033[0m" export exec_case=$1 export FLAGS_install_deps=$2 export FLAGS_download_data=$3 @@ -2382,5 +2484,5 @@ else else echo -e "\033[31m ---- Invalid exec_case $exec_case \033[0m" fi - $1 + $1 exec_case fi diff --git a/scripts/distribute/ci_case_dy.sh b/scripts/distribute/ci_case_dy.sh index 2d9eaddd5758..13329a52b32c 100644 --- a/scripts/distribute/ci_case_dy.sh +++ b/scripts/distribute/ci_case_dy.sh @@ -57,47 +57,80 @@ function track_case_status() { return 0 } -function gpt_case_list_dygraph(){ - # The test name must have "gpt_" as a prefix, which will - # be used for tracking the execution status of the case. - gpt_preprocess_data - gpt_345M_single - gpt_1.3B_dp - gpt_6.7B_stage2_dp2_sharding4 - gpt_6.7B_stage3_dp2_sharding4 - gpt_6.7B_stage2_sharding8 - gpt_175B_DP1_MP4_PP2 - gpt_175B_DP1_MP4_PP2_sp - gpt_175B_DP1_MP8_PP1 - gpt_175B_DP1_MP8_PP1_sp - gpt_175B_DP1_MP1_PP8 - gpt_generation_345M_single - gpt_generation_345M_hybrid - gpt_345M_mp8_qat - # gpt_export_345M_mp1 - # gpt_export_345M_mp2 - # gpt_export_qat_345M - # gpt_inference_345M_single - # gpt_inference_345M_dp8 - gpt_345M_single_finetune - gpt_eval_WikiText - gpt_eval_LAMBADA - - track_case_status $FUNCNAME "gpt_" +function restore_func() { + fun_list=$1 + cd ${log_path} || { echo "Failed to enter log_path: $log_path"; return 1; } + if [ -e "functions.txt" ]; then + rm "functions.txt" + echo "Deleted existing functions.txt" + fi + for function in ${fun_list[@]};do + echo "$function" >> functions.txt + done } -function llm_gpt_case_list_dygraph() { - # The test name must have "llm_gpt_" as a prefix, which will - # be used for tracking the execution status of the case. - llm_gpt_recompute_bs32_bf16_MP2-SD4-stage1 +function gpt_case_list_dygraph() { + fun_list=( + # The test name must have "gpt_" as a prefix, which will + # be used for tracking the execution status of the case. + gpt_preprocess_data + gpt_345M_single + gpt_1.3B_dp + gpt_6.7B_stage2_dp2_sharding4 + gpt_6.7B_stage3_dp2_sharding4 + gpt_6.7B_stage2_sharding8 + gpt_175B_DP1_MP4_PP2 + gpt_175B_DP1_MP4_PP2_sp + gpt_175B_DP1_MP8_PP1 + gpt_175B_DP1_MP8_PP1_sp + gpt_175B_DP1_MP1_PP8 + gpt_generation_345M_single + gpt_generation_345M_hybrid + gpt_345M_mp8_qat + # gpt_export_345M_mp1 + # gpt_export_345M_mp2 + # gpt_export_qat_345M + # gpt_inference_345M_single + # gpt_inference_345M_dp8 + gpt_345M_single_finetune + gpt_eval_WikiText + gpt_eval_LAMBADA + ) + if [ $1 = "prepare_case" ]; then + restore_func $fun_list + elif [ $1 = "exec_case" ]; then + for fun in "${fun_list[@]}"; do + eval "$fun" + done + track_case_status $FUNCNAME "gpt_" + else + echo -e "\033[31m ---- Invalid status $1 \033[0m" + return 1 + fi +} - track_case_status $FUNCNAME "llm_gpt_" +function llm_gpt_case_list_dygraph() { + fun_list=( + # The test name must have "llm_gpt_" as a prefix, which will + # be used for tracking the execution status of the case. + llm_gpt_recompute_bs32_bf16_MP2-SD4-stage1 + ) + if [ $1 = "prepare_case" ]; then + restore_func $fun_list + elif [ $1 = "exec_case" ]; then + for fun in "${fun_list[@]}"; do + eval "$fun" + done + track_case_status $FUNCNAME "llm_gpt_" + else + echo -e "\033[31m ---- Invalid status $1 \033[0m" + return 1 + fi } ############ case start ############ function gpt_preprocess_data() { echo "=========== $FUNCNAME run begin ===========" - cd ${gpt_case_path} rm -rf log python ppfleetx/data/data_tools/gpt/raw_trans_to_json.py \ --input_path ./dataset/wikitext_103_en \ @@ -119,7 +152,6 @@ function gpt_preprocess_data() { function gpt_345M_single() { echo "=========== $FUNCNAME run begin ===========" - cd ${gpt_case_path} rm -rf log python tools/train.py \ -c ppfleetx/configs/nlp/gpt/pretrain_gpt_345M_single_card.yaml \ @@ -133,7 +165,6 @@ function gpt_345M_single() { function gpt_1.3B_dp() { echo "=========== $FUNCNAME run begin ===========" - cd ${gpt_case_path} rm -rf log python -m paddle.distributed.launch --devices "0,1,2,3,4,5,6,7" tools/train.py\ -c ppfleetx/configs/nlp/gpt/pretrain_gpt_1.3B_dp8.yaml \ @@ -147,7 +178,6 @@ function gpt_1.3B_dp() { function gpt_6.7B_stage2_dp2_sharding4() { echo "=========== $FUNCNAME run begin ===========" - cd ${gpt_case_path} rm -rf log python -m paddle.distributed.launch --devices "0,1,2,3,4,5,6,7" \ tools/train.py -c ppfleetx/configs/nlp/gpt/pretrain_gpt_6.7B_sharding16.yaml \ @@ -164,7 +194,6 @@ function gpt_6.7B_stage2_dp2_sharding4() { function gpt_6.7B_stage3_dp2_sharding4() { echo "=========== $FUNCNAME run begin ===========" - cd ${gpt_case_path} rm -rf log python -m paddle.distributed.launch --devices "0,1,2,3,4,5,6,7" \ tools/train.py -c ppfleetx/configs/nlp/gpt/pretrain_gpt_6.7B_sharding16.yaml \ @@ -181,7 +210,6 @@ function gpt_6.7B_stage3_dp2_sharding4() { function gpt_6.7B_stage2_sharding8() { echo "=========== $FUNCNAME run begin ===========" - cd ${gpt_case_path} rm -rf log python -m paddle.distributed.launch --devices "0,1,2,3,4,5,6,7" \ tools/train.py -c ppfleetx/configs/nlp/gpt/pretrain_gpt_6.7B_sharding16.yaml \ @@ -198,7 +226,6 @@ function gpt_6.7B_stage2_sharding8() { function gpt_175B_DP1_MP4_PP2() { echo "=========== $FUNCNAME run begin ===========" - cd ${gpt_case_path} rm -rf log python -m paddle.distributed.launch --devices "0,1,2,3,4,5,6,7" tools/train.py\ -c ppfleetx/configs/nlp/gpt/pretrain_gpt_175B_mp8_pp16.yaml \ @@ -215,7 +242,6 @@ function gpt_175B_DP1_MP4_PP2() { function gpt_175B_DP1_MP4_PP2_sp() { echo "=========== $FUNCNAME run begin ===========" - cd ${gpt_case_path} rm -rf log python -m paddle.distributed.launch --devices "0,1,2,3,4,5,6,7" tools/train.py\ -c ppfleetx/configs/nlp/gpt/pretrain_gpt_175B_mp8_pp16.yaml \ @@ -231,7 +257,6 @@ function gpt_175B_DP1_MP4_PP2_sp() { function gpt_175B_DP1_MP8_PP1() { echo "=========== $FUNCNAME run begin ===========" - cd ${gpt_case_path} rm -rf log python -m paddle.distributed.launch --devices "0,1,2,3,4,5,6,7" tools/train.py\ -c ppfleetx/configs/nlp/gpt/pretrain_gpt_175B_mp8_pp16.yaml \ @@ -248,7 +273,6 @@ function gpt_175B_DP1_MP8_PP1() { function gpt_175B_DP1_MP8_PP1_sp() { echo "=========== $FUNCNAME run begin ===========" - cd ${gpt_case_path} rm -rf log python -m paddle.distributed.launch --devices "0,1,2,3,4,5,6,7" tools/train.py\ -c ppfleetx/configs/nlp/gpt/pretrain_gpt_175B_mp8_pp16.yaml \ @@ -263,8 +287,6 @@ function gpt_175B_DP1_MP8_PP1_sp() { } function gpt_175B_DP1_MP1_PP8() { - echo "=========== $FUNCNAME run begin ===========" - cd ${gpt_case_path} rm -rf log python -m paddle.distributed.launch --devices "0,1,2,3,4,5,6,7" tools/train.py\ -c ppfleetx/configs/nlp/gpt/pretrain_gpt_175B_mp8_pp16.yaml \ @@ -283,7 +305,6 @@ function gpt_175B_DP1_MP1_PP8() { function gpt_345M_mp8_qat() { echo "=========== $FUNCNAME run begin ===========" - cd ${gpt_case_path} rm -rf log python -m paddle.distributed.launch --devices "0,1,2,3,4,5,6,7" tools/train.py\ -c ppfleetx/configs/nlp/gpt/qat_gpt_345M_mp8.yaml \ @@ -297,7 +318,6 @@ function gpt_345M_mp8_qat() { function gpt_generation_345M_single() { echo "=========== $FUNCNAME run begin ===========" - cd ${gpt_case_path} rm -rf log python tasks/gpt/generation.py \ -c ppfleetx/configs/nlp/gpt/generation_gpt_345M_single_card.yaml \ @@ -309,7 +329,6 @@ function gpt_generation_345M_single() { function gpt_generation_345M_hybrid() { echo "=========== $FUNCNAME run begin ===========" - cd ${gpt_case_path} rm -rf log python -m paddle.distributed.launch --devices "0" tasks/gpt/generation.py \ -c ppfleetx/configs/nlp/gpt/generation_gpt_345M_dp8.yaml \ @@ -321,7 +340,6 @@ function gpt_generation_345M_hybrid() { function gpt_export_345M_mp1() { echo "=========== $FUNCNAME run begin ===========" - cd ${gpt_case_path} log_dir=log_export rm -rf $log_dir rm -rf output @@ -343,7 +361,6 @@ function gpt_export_345M_mp1() { function gpt_export_345M_mp2() { echo "=========== $FUNCNAME run begin ===========" - cd ${gpt_case_path} log_dir=log_export rm -rf $log_dir rm -rf output @@ -366,7 +383,6 @@ function gpt_export_345M_mp2() { function gpt_export_qat_345M() { echo "=========== $FUNCNAME run begin ===========" - cd ${gpt_case_path} log_dir=log_export rm -rf $log_dir rm -rf output @@ -386,7 +402,6 @@ function gpt_export_qat_345M() { function gpt_inference_345M_single() { echo "=========== $FUNCNAME run begin ===========" - cd ${gpt_case_path} rm -rf log rm -rf output python tools/export.py \ @@ -402,7 +417,6 @@ function gpt_inference_345M_single() { function gpt_inference_345M_dp8() { echo "=========== $FUNCNAME run begin ===========" - cd ${gpt_case_path} rm -rf log rm -rf output python -m paddle.distributed.launch --devices "0" tools/export.py \ @@ -419,7 +433,6 @@ function gpt_inference_345M_dp8() { function gpt_345M_single_finetune() { echo "=========== $FUNCNAME run begin ===========" - cd ${gpt_case_path} rm -rf log python ./tools/train.py \ -c ./ppfleetx/configs/nlp/gpt/finetune_gpt_345M_single_card_glue.yaml \ @@ -437,7 +450,6 @@ function gpt_345M_single_finetune() { function gpt_eval_WikiText() { echo "=========== $FUNCNAME run begin ===========" - cd ${gpt_case_path} rm -rf log python ./tools/eval.py \ -c ./ppfleetx/configs/nlp/gpt/eval_gpt_345M_single_card.yaml \ @@ -453,7 +465,6 @@ function gpt_eval_WikiText() { function gpt_eval_LAMBADA() { echo "=========== $FUNCNAME run begin ===========" - cd ${gpt_case_path} rm -rf log python ./tools/eval.py \ -c ./ppfleetx/configs/nlp/gpt/eval_gpt_345M_single_card.yaml \ @@ -469,7 +480,6 @@ function gpt_eval_LAMBADA() { function llm_gpt_recompute_bs32_bf16_MP2-SD4-stage1() { echo "=========== $FUNCNAME run begin ===========" - cd ${llm_gpt_case_path} export FLAGS_cudnn_deterministic=1 export FLAGS_embedding_deterministic=1 export PYTHONPATH=$root_path/:$PYTHONPATH @@ -705,53 +715,6 @@ function before_hook_for_llm_gpt() { fi } -function restore_func() { - fun_list=$1 - cd ${log_path} || { echo "Failed to enter log_path: $log_path"; return 1; } - if [ -e "functions.txt" ]; then - rm "functions.txt" - echo "Deleted existing functions.txt" - fi - for function in ${fun_list[@]};do - echo "$function" >> functions.txt - done -} - -function restore_gpt_case_list_dygraph_func() { - fun_list=( - gpt_preprocess_data - gpt_345M_single - gpt_1.3B_dp - gpt_6.7B_stage2_dp2_sharding4 - gpt_6.7B_stage3_dp2_sharding4 - gpt_6.7B_stage2_sharding8 - gpt_175B_DP1_MP4_PP2 - gpt_175B_DP1_MP4_PP2_sp - gpt_175B_DP1_MP8_PP1 - gpt_175B_DP1_MP8_PP1_sp - gpt_175B_DP1_MP1_PP8 - gpt_generation_345M_single - gpt_generation_345M_hybrid - gpt_345M_mp8_qat - # gpt_export_345M_mp1 - # gpt_export_345M_mp2 - # gpt_export_qat_345M - # gpt_inference_345M_single - # gpt_inference_345M_dp8 - gpt_345M_single_finetune - gpt_eval_WikiText - gpt_eval_LAMBADA - ) - restore_func $fun_list -} - -function restore_llm_gpt_case_list_dygraph_func() { - fun_list=( - llm_gpt_recompute_bs32_bf16_MP2-SD4-stage1 - ) - restore_func $fun_list -} - export status=$1 if [[ $status = "prepare_case" ]];then @@ -759,20 +722,24 @@ if [[ $status = "prepare_case" ]];then export FLAGS_download_data=$4 if [[ $2 = "gpt_case_list_dygraph" ]];then before_hook_for_gpt - restore_gpt_case_list_dygraph_func + gpt_case_list_dygraph prepare_case elif [[ $2 = "llm_gpt_case_list_dygraph" ]];then before_hook_for_llm_gpt - restore_llm_gpt_case_list_dygraph_func + llm_gpt_case_list_dygraph prepare_case else echo -e "\033[31m ---- Invalid exec_case $2 \033[0m" fi elif [[ $status = "exec_case" ]];then export FLAGS_install_deps=$3 export FLAGS_download_data=$4 + if [[ $2 =~ "llm_gpt" ]];then + cd ${llm_gpt_case_path} + elif [[ $2 =~ "gpt" ]];then + cd ${gpt_case_path} + fi $2 else echo -e "\033[31m ---- Start executing $1 \033[0m" - export exec_case=$1 export FLAGS_install_deps=$2 export FLAGS_download_data=$3 @@ -787,6 +754,6 @@ else echo -e "\033[31m ---- Invalid exec_case $exec_case \033[0m" fi - $1 + $1 exec_case fi diff --git a/scripts/distribute/run_ci.sh b/scripts/distribute/run_ci.sh index d0b9884f32cc..722173caba56 100644 --- a/scripts/distribute/run_ci.sh +++ b/scripts/distribute/run_ci.sh @@ -20,6 +20,12 @@ mkdir -p /workspace/case_logs export log_path=/workspace/case_logs export case_list=() +galobal_total_count=0 +galobal_success_count=0 +galobal_exit_250_arr=() +galobal_runtime_fail_arr=() +galobal_verification_fail_arr=() + target_lists_for_gpt=( "slm/model_zoo/gpt-3" "llm/auto_parallel/gpt-3" @@ -84,58 +90,58 @@ IS_A100=$(is_a100) #################################### get_diff_TO_case(){ -cd ${nlp_dir} -if [ $IS_A100 -ne 0 ];then - for file_name in `git diff --numstat upstream/${AGILE_COMPILE_BRANCH} |awk '{print $NF}'`;do - arr_file_name=(${file_name//// }) - dir1=${arr_file_name[0]} - dir2=${arr_file_name[1]} - dir3=${arr_file_name[2]} - dir4=${arr_file_name[3]} - file_item=$dir1/$dir2/$dir3/$dir4 - echo "file_name:"${file_name}, "path:"${file_item} - if [ ! -f ${file_name} ];then # 针对pr删掉文件 - continue - elif [[ ${file_name##*.} == "md" ]] || [[ ${file_name##*.} == "rst" ]] || [[ ${dir1} == "docs" ]];then - continue - else - for ((i=0; i<${#target_lists_for_gpt[@]}; i++)); do - if [[ ! ${dir3} =~ "benchmarks" ]] && [[ ${file_item} == *${target_lists_for_gpt[i]}* ]];then - case_list[${#case_list[*]}]=gpt-3_auto - case_list[${#case_list[*]}]=gpt-3_dygraph - fi - done - for ((i=0; i<${#target_lists_for_llama[@]}; i++)); do - if [[ ${file_item} == *${target_lists_for_llama[i]}* ]];then - case_list[${#case_list[*]}]=llama_auto - fi - done - fi - done -else - case_list[${#case_list[*]}]=gpt-3_auto - case_list[${#case_list[*]}]=llama_auto - for file_name in `git diff --numstat upstream/${AGILE_COMPILE_BRANCH} |awk '{print $NF}'`;do - arr_file_name=(${file_name//// }) - dir1=${arr_file_name[0]} - dir2=${arr_file_name[1]} - dir3=${arr_file_name[2]} - dir4=${arr_file_name[3]} - file_item=$dir1/$dir2/$dir3/$dir4 - echo "file_name:"${file_name}, "path:"${file_item} - if [ ! -f ${file_name} ];then # 针对pr删掉文件 - continue - elif [[ ${file_name##*.} == "md" ]] || [[ ${file_name##*.} == "rst" ]] || [[ ${dir1} == "docs" ]];then - continue - else - for ((i=0; i<${#target_lists_for_gpt[@]}; i++)); do - if [[ ! ${dir3} =~ "benchmarks" ]] && [[ ${file_item} == *${target_lists_for_gpt[i]}* ]];then - case_list[${#case_list[*]}]=gpt-3_dygraph - fi - done - fi - done -fi + cd ${nlp_dir} + if [ $IS_A100 -ne 0 ];then + for file_name in `git diff --numstat upstream/${AGILE_COMPILE_BRANCH} |awk '{print $NF}'`;do + arr_file_name=(${file_name//// }) + dir1=${arr_file_name[0]} + dir2=${arr_file_name[1]} + dir3=${arr_file_name[2]} + dir4=${arr_file_name[3]} + file_item=$dir1/$dir2/$dir3/$dir4 + echo "file_name:"${file_name}, "path:"${file_item} + if [ ! -f ${file_name} ];then # 针对pr删掉文件 + continue + elif [[ ${file_name##*.} == "md" ]] || [[ ${file_name##*.} == "rst" ]] || [[ ${dir1} == "docs" ]];then + continue + else + for ((i=0; i<${#target_lists_for_gpt[@]}; i++)); do + if [[ ! ${dir3} =~ "benchmarks" ]] && [[ ${file_item} == *${target_lists_for_gpt[i]}* ]];then + case_list[${#case_list[*]}]=gpt-3_auto + case_list[${#case_list[*]}]=gpt-3_dygraph + fi + done + for ((i=0; i<${#target_lists_for_llama[@]}; i++)); do + if [[ ${file_item} == *${target_lists_for_llama[i]}* ]];then + case_list[${#case_list[*]}]=llama_auto + fi + done + fi + done + else + case_list[${#case_list[*]}]=gpt-3_auto + case_list[${#case_list[*]}]=llama_auto + for file_name in `git diff --numstat upstream/${AGILE_COMPILE_BRANCH} |awk '{print $NF}'`;do + arr_file_name=(${file_name//// }) + dir1=${arr_file_name[0]} + dir2=${arr_file_name[1]} + dir3=${arr_file_name[2]} + dir4=${arr_file_name[3]} + file_item=$dir1/$dir2/$dir3/$dir4 + echo "file_name:"${file_name}, "path:"${file_item} + if [ ! -f ${file_name} ];then # 针对pr删掉文件 + continue + elif [[ ${file_name##*.} == "md" ]] || [[ ${file_name##*.} == "rst" ]] || [[ ${dir1} == "docs" ]];then + continue + else + for ((i=0; i<${#target_lists_for_gpt[@]}; i++)); do + if [[ ! ${dir3} =~ "benchmarks" ]] && [[ ${file_item} == *${target_lists_for_gpt[i]}* ]];then + case_list[${#case_list[*]}]=gpt-3_dygraph + fi + done + fi + done + fi } #################################### function contain_case(){ @@ -157,31 +163,36 @@ function execute_func_list(){ exit_250_count=0 while IFS= read -r func_name; do let total_count++ - excute_num=1 + let galobal_total_count++ + execute_num=1 while true; do bash $1 exec_case $func_name $FLAGS_install_deps $FLAGS_download_data result=$? if [ $result -eq 0 ]; then echo -e "\033[32m test success!" let success_count++ + let galobal_success_count++ elif [ $result -eq 2 ]; then echo -e "\033[31m verification failed!" let verification_fail_count++ + galobal_verification_fail_arr+=("$func_name") elif [ $result -eq 250 ]; then - if [ $excute_num -eq 1 ]; then + if [ $execute_num -eq 1 ]; then echo -e "\033[31m fist time execute failed, try again!" - let excute_num++ + let execute_num++ continue else echo -e "\033[31m second time execute failed, exit!" let exit_250_count++ + galobal_exit_250_arr+=("$func_name") fi else echo "test failed!" - mv ${log_path}/$func_name ${log_path}/$func_name_FAIL.log + mv ${log_path}/$func_name ${log_path}/${func_name}_FAIL.log echo -e "\033[31m ${log_path}/$func_name_FAIL \033" - tail -15 ${log_path}/$func_name_FAIL.log - let runtime_fail_count++ + tail -15 ${log_path}/${func_name}_FAIL.log + let runtime_fail_count++ + galobal_runtime_fail_arr+=("$func_name") fi break done @@ -193,36 +204,7 @@ function execute_func_list(){ echo -e "\033[31m $(printf '\t') verification fail tests : $verification_fail_count \033" echo -e "\033[31m $(printf '\t') exit 250 tests(intermittent issue) : $exit_250_count \033" } -#################################### -function track_case_status() { - local case_name="$1" - local prefix="$2" - local original_path - - original_path=$(pwd) - cd ${log_path} || { echo "Failed to enter log_path: $log_path"; return 1; } - - total_count=$(ls -1 "$prefix"* 2>/dev/null | grep -Ev 'result\.log|functions\.txt' | wc -l) - run_fail_count=$(ls -1 "$prefix"*_FAIL* 2>/dev/null | wc -l) - loss_fail_count=$(grep 'check failed! ' result.log | awk -v prefix="$prefix" '{if ($2 ~ "^" prefix) print $2}'| wc -l) - - echo -e "\033[31m ---- $case_name total tests : $total_count \033" - if [ $run_fail_count -eq 0 ] && [ $loss_fail_count -eq 0 ]; then - echo -e "\033[32m ---- all cases Success \033" - else - if [[ $run_fail_count -ne 0 ]] ; then - echo -e "\033[31m ---- $case_name runtime failed test : $run_fail_count \033" - ls -1 "$prefix"*_FAIL* 2>/dev/null | awk -v OFS="\t" '{print "\t" $0 "(failed)"}' - fi - if [[ $loss_fail_count -ne 0 ]] ; then - echo -e "\033[31m ---- $case_name verification failed test : $loss_fail_count \033" - grep 'check failed! ' result.log | awk -v prefix="$prefix" 'BEGIN {OFS="\t"} {if ($2 ~ "^" prefix) print "\t" $2 "(failed)"}' - fi - return 2 - fi - cd "$original_path" || { echo "Failed to return to original path: $original_path"; return 1; } - return 0 -} + #################################### get_diff_TO_case # 获取待执行case列表 case_list=($(awk -v RS=' ' '!a[$1]++' <<< ${case_list[*]})) # 去重并将结果存储回原列表 @@ -270,8 +252,28 @@ if [[ ${#case_list[*]} -ne 0 ]];then fi echo -e "\033[31m ---- end run case \033" - track_case_status $FUNCNAME "" - EXCODE=$? + echo -e "\033[31m ---- total tests : $galobal_total_count \033" + if [ ${#galobal_exit_250_arr[@]} -ne 0 ]; then + echo -e "\033[32m ---- exit 250 test : ${#galobal_exit_250_arr[@]} \033" + for case in "${galobal_exit_250_arr[@]}"; do + echo -e "\t$case(exit 250)" + done + fi + + if [ ${#galobal_runtime_fail_arr[@]} -eq 0 ] && [ ${#galobal_verification_fail_arr[@]} -eq 0 ]; then + echo -e "\033[32m ---- all cases Success \033" + EXCODE=0 + else + echo -e "\033[32m ---- runtime failed test : ${#galobal_runtime_fail_arr[@]} \033" + for case in "${galobal_runtime_fail_arr[@]}"; do + echo -e "\t$case(failed)" + done + echo -e "\033[32m ---- verification failed test : ${#galobal_verification_fail_arr[@]} \033" + for case in "${galobal_verification_fail_arr[@]}"; do + echo -e "\t$case(failed)" + done + EXCODE=1 + fi else echo -e "\033[32m Changed Not CI case, Skips \033" EXCODE=0 diff --git a/tests/fixtures/llm/finetune.yaml b/tests/fixtures/llm/finetune.yaml index 7e79f9b441a8..abe9aad5d39e 100644 --- a/tests/fixtures/llm/finetune.yaml +++ b/tests/fixtures/llm/finetune.yaml @@ -63,4 +63,53 @@ inference-infer: dtype: float16 batch_size: 2 decode_strategy: greedy_search - max_length: 20 \ No newline at end of file + max_length: 20 + +ckpt_quant: + base: + dataset_name_or_path: "./data" + per_device_train_batch_size: 4 + gradient_accumulation_steps: 4 + per_device_eval_batch_size: 8 + eval_accumulation_steps: 16 + num_train_epochs: 3 + learning_rate: 3e-05 + warmup_steps: 30 + logging_steps: 1 + evaluation_strategy: "steps" + save_strategy: "steps" + save_steps: 1 + max_steps: 1 + src_length: 1024 + max_length: 2048 + fp16: true + fp16_opt_level: "O2" + ckpt_quant_stage: "O1" + do_train: true + do_eval: true + use_flash_attention: true + unified_checkpoint: true + unified_checkpoint_config: "async_save remove_master_weight" + disable_tqdm: true + load_best_model_at_end: true + eval_with_do_generation: false + metric_for_best_model: "accuracy" + recompute: true + tensor_parallel_degree: 2 + pipeline_parallel_degree: 1 + + default: + llama: + model_name_or_path: __internal_testing__/tiny-random-llama + chatglm: + model_name_or_path: __internal_testing__/tiny-fused-chatglm + chatglm2: + model_name_or_path: __internal_testing__/tiny-fused-chatglm2 + bloom: + model_name_or_path: __internal_testing__/tiny-fused-bloom + qwen: + model_name_or_path: __internal_testing__/tiny-fused-qwen + baichuan: + model_name_or_path: __internal_testing__/tiny-fused-baichuan + qwen2: + model_name_or_path: __internal_testing__/tiny-random-qwen2 diff --git a/tests/fixtures/llm/finetune_pose.yaml b/tests/fixtures/llm/finetune_pose.yaml new file mode 100644 index 000000000000..24af2a366bda --- /dev/null +++ b/tests/fixtures/llm/finetune_pose.yaml @@ -0,0 +1,71 @@ +finetune: + base: + dataset_name_or_path: "./data" + per_device_train_batch_size: 4 + gradient_accumulation_steps: 4 + per_device_eval_batch_size: 8 + eval_accumulation_steps: 16 + num_train_epochs: 3 + learning_rate: 3e-05 + warmup_steps: 30 + logging_steps: 1 + evaluation_strategy: "epoch" + save_strategy: "epoch" + src_length: 1024 + max_length: 2048 + use_long_sequence_strategies: true + rope_scaling_factor: 8 + strategy_type: "embedding_strategies" + strategy_name: "YaRNScalingRotaryEmbedding" + autoregressive: true + use_pose_convert: true + fp16: true + fp16_opt_level: "O2" + do_train: true + do_eval: true + use_flash_attention: false + disable_tqdm: true + load_best_model_at_end: true + eval_with_do_generation: false + metric_for_best_model: "accuracy" + recompute: true + save_total_limit: 1 + tensor_parallel_degree: 1 + pipeline_parallel_degree: 1 + ignore_save_lr_and_optim: 1 + + default: + llama: + model_name_or_path: __internal_testing__/tiny-random-llama + # chatglm: + # model_name_or_path: __internal_testing__/tiny-fused-chatglm + # chatglm2: + # model_name_or_path: __internal_testing__/tiny-fused-chatglm2 + # bloom: + # model_name_or_path: __internal_testing__/tiny-fused-bloom + # qwen: + # model_name_or_path: __internal_testing__/tiny-fused-qwen + # baichuan: + # model_name_or_path: __internal_testing__/tiny-fused-baichuan + # qwen2: + # model_name_or_path: __internal_testing__/tiny-random-qwen2 + +inference-predict: + default: + mode: dynamic + max_length: 20 + batch_size: 2 + decode_strategy: greedy_search + dtype: float16 + +inference-to-static: + default: + dtype: float16 + +inference-infer: + default: + mode: static + dtype: float16 + batch_size: 2 + decode_strategy: greedy_search + max_length: 20 \ No newline at end of file diff --git a/tests/fixtures/llm/mos_lora.yaml b/tests/fixtures/llm/mos_lora.yaml new file mode 100644 index 000000000000..908a7bae5508 --- /dev/null +++ b/tests/fixtures/llm/mos_lora.yaml @@ -0,0 +1,113 @@ +lora: + base: + dataset_name_or_path: "./data" + per_device_train_batch_size: 4 + gradient_accumulation_steps: 4 + per_device_eval_batch_size: 8 + eval_accumulation_steps: 16 + num_train_epochs: 3 + learning_rate: 3e-04 + warmup_steps: 30 + logging_steps: 1 + evaluation_strategy: "epoch" + save_strategy: "epoch" + src_length: 1024 + max_length: 2048 + fp16: true + fp16_opt_level: "O2" + do_train: true + do_eval: true + disable_tqdm: true + load_best_model_at_end: true + eval_with_do_generation: false + metric_for_best_model: "accuracy" + recompute: true + save_total_limit: 1 + tensor_parallel_degree: 1 + pipeline_parallel_degree: 1 + lora: true + lora_use_mixer: true + + default: + llama: + model_name_or_path: __internal_testing__/tiny-random-llama + chatglm: + model_name_or_path: __internal_testing__/tiny-fused-chatglm + chatglm2: + model_name_or_path: __internal_testing__/tiny-fused-chatglm2 + bloom: + model_name_or_path: __internal_testing__/tiny-fused-bloom + qwen: + model_name_or_path: __internal_testing__/tiny-fused-qwen + qwen2: + model_name_or_path: __internal_testing__/tiny-random-qwen2 + qwen2moe: + model_name_or_path: __internal_testing__/tiny-random-qwen2moe + baichuan: + model_name_or_path: __internal_testing__/tiny-fused-baichuan + +rslora_plus: + base: + dataset_name_or_path: "./data" + per_device_train_batch_size: 4 + gradient_accumulation_steps: 4 + per_device_eval_batch_size: 8 + eval_accumulation_steps: 16 + num_train_epochs: 3 + learning_rate: 3e-04 + warmup_steps: 30 + logging_steps: 1 + evaluation_strategy: "epoch" + save_strategy: "epoch" + src_length: 1024 + max_length: 2048 + fp16: true + fp16_opt_level: "O2" + do_train: true + do_eval: true + disable_tqdm: true + load_best_model_at_end: true + eval_with_do_generation: false + metric_for_best_model: "accuracy" + recompute: true + save_total_limit: 1 + tensor_parallel_degree: 1 + pipeline_parallel_degree: 1 + lora: true + lora_plus_scale: 4 + rslora: true + + default: + llama: + model_name_or_path: __internal_testing__/tiny-random-llama + chatglm: + model_name_or_path: __internal_testing__/tiny-fused-chatglm + chatglm2: + model_name_or_path: __internal_testing__/tiny-fused-chatglm2 + bloom: + model_name_or_path: __internal_testing__/tiny-fused-bloom + qwen: + model_name_or_path: __internal_testing__/tiny-fused-qwen + baichuan: + model_name_or_path: __internal_testing__/tiny-fused-baichuan + +inference-predict: + default: + mode: dynamic + max_length: 20 + batch_size: 2 + decode_strategy: greedy_search + dtype: float16 + +inference-to-static: + default: + dtype: float16 + max_length: 20 + +inference-infer: + default: + mode: static + dtype: float16 + batch_size: 2 + decode_strategy: greedy_search + max_length: 20 \ No newline at end of file diff --git a/tests/fixtures/llm/reft.yaml b/tests/fixtures/llm/reft.yaml new file mode 100644 index 000000000000..4b276fc4271d --- /dev/null +++ b/tests/fixtures/llm/reft.yaml @@ -0,0 +1,68 @@ +reft: + base: + dataset_name_or_path: "./data" + per_device_train_batch_size: 2 + gradient_accumulation_steps: 2 + per_device_eval_batch_size: 4 + num_train_epochs: 2 + learning_rate: 3e-04 + warmup_ratio: 0.01 + logging_steps: 1 + remove_unused_columns: false + evaluation_strategy: "no" + metric_for_best_model: "no" + save_strategy: "epoch" + src_length: 1024 + max_length: 512 + autoregressive: false + bf16: true + fp16_opt_level: "O2" + do_train: true + do_eval: false + disable_tqdm: true + load_best_model_at_end: false + eval_with_do_generation: true + save_total_limit: 1 + tensor_parallel_degree: 1 + pipeline_parallel_degree: 1 + zero_padding: false + reft: true + + default: + llama: + model_name_or_path: __internal_testing__/tiny-random-llama + chatglm: + model_name_or_path: __internal_testing__/tiny-fused-chatglm + chatglm2: + model_name_or_path: __internal_testing__/tiny-fused-chatglm2 + bloom: + model_name_or_path: __internal_testing__/tiny-fused-bloom + qwen: + model_name_or_path: __internal_testing__/tiny-fused-qwen + qwen2: + model_name_or_path: __internal_testing__/tiny-random-qwen2 + qwen2moe: + model_name_or_path: __internal_testing__/tiny-random-qwen2moe + baichuan: + model_name_or_path: __internal_testing__/tiny-fused-baichuan + + +inference-predict: + default: + mode: dynamic + max_length: 20 + batch_size: 2 + decode_strategy: greedy_search + dtype: float16 + +inference-to-static: + default: + dtype: float16 + +inference-infer: + default: + mode: static + dtype: float16 + batch_size: 2 + decode_strategy: greedy_search + max_length: 20 \ No newline at end of file diff --git a/tests/llm/test_finetune.py b/tests/llm/test_finetune.py index d1fda6e67b5f..76c88a3c6a94 100644 --- a/tests/llm/test_finetune.py +++ b/tests/llm/test_finetune.py @@ -18,6 +18,7 @@ from parameterized import parameterized_class +from tests.parallel_launch import TestMultipleGpus from tests.testing_utils import argv_context_guard, load_test_config from .testing_utils import LLMTest @@ -63,3 +64,38 @@ def test_finetune(self): self.run_predictor({"inference_model": True}) self.run_predictor({"inference_model": False}) + + +@parameterized_class( + ["model_dir"], + [ + ["llama"], + ], +) +class CkptQuantTest(LLMTest, TestMultipleGpus): + config_path: str = "./tests/fixtures/llm/finetune.yaml" + model_dir: str = None + + def setUp(self) -> None: + LLMTest.setUp(self) + + sys.path.insert(0, self.model_dir) + self.run_sft = "llm/run_finetune.py" + + def tearDown(self) -> None: + LLMTest.tearDown(self) + + def test_ckpt_quant(self): + finetune_config = load_test_config(self.config_path, "ckpt_quant", self.model_dir) + + finetune_config["dataset_name_or_path"] = self.data_dir + finetune_config["output_dir"] = self.output_dir + + self.runfirst(finetune_config) + self.rerun(finetune_config) + + def runfirst(self, train_args): + self.run_n1c2(self.run_sft, **train_args) + + def rerun(self, train_args): + self.run_n1c2(self.run_sft, **train_args) diff --git a/tests/llm/test_mos_lora.py b/tests/llm/test_mos_lora.py new file mode 100644 index 000000000000..f4847f3ab929 --- /dev/null +++ b/tests/llm/test_mos_lora.py @@ -0,0 +1,162 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import os +import sys +import unittest + +import paddle +from parameterized import parameterized_class + +from tests.testing_utils import argv_context_guard, load_test_config + +from .testing_utils import LLMTest + + +@parameterized_class( + ["model_dir"], + [ + ["llama"], + # ["chatglm"], @skip("Skip and wait to fix.") + # ["chatglm2"], @skip("Skip and wait to fix.") + # ["bloom"], @skip("Skip and wait to fix.") + ["qwen"], + ["baichuan"], + ], +) +class MosLoraTest(LLMTest, unittest.TestCase): + config_path: str = "./tests/fixtures/llm/mos_lora.yaml" + model_dir: str = None + + def setUp(self) -> None: + LLMTest.setUp(self) + + self.model_codes_dir = os.path.join(self.root_path, self.model_dir) + sys.path.insert(0, self.model_codes_dir) + + def tearDown(self) -> None: + LLMTest.tearDown(self) + sys.path.remove(self.model_codes_dir) + + def test_mos_lora(self): + self.disable_static() + paddle.set_default_dtype("float32") + + lora_config = load_test_config(self.config_path, "lora", self.model_dir) + lora_config["output_dir"] = self.output_dir + lora_config["dataset_name_or_path"] = self.data_dir + # use_quick_lora + lora_config["use_quick_lora"] = True + + with argv_context_guard(lora_config): + from run_finetune import main + + main() + + # merge weights + merge_lora_weights_config = { + "lora_path": lora_config["output_dir"], + "model_name_or_path": lora_config["model_name_or_path"], + "output_path": lora_config["output_dir"], + } + with argv_context_guard(merge_lora_weights_config): + from tools.merge_lora_params import merge + + merge() + + # TODO(wj-Mcat): disable chatglm2 test temporarily + if self.model_dir not in ["qwen", "baichuan", "chatglm2", "llama"]: + self.run_predictor({"inference_model": True}) + + self.run_predictor({"inference_model": False}) + + +# @parameterized_class( +# ["model_dir"], +# [ +# ["llama"], +# ["qwen"], +# ], +# ) +# class LoraChatTemplateTest(LLMTest, unittest.TestCase): +# config_path: str = "./tests/fixtures/llm/lora.yaml" +# model_dir: str = None + +# def setUp(self) -> None: +# LLMTest.setUp(self) + +# self.model_codes_dir = os.path.join(self.root_path, self.model_dir) +# sys.path.insert(0, self.model_codes_dir) + +# self.rounds_data_dir = tempfile.mkdtemp() +# shutil.copyfile( +# os.path.join(self.data_dir, "train.json"), +# os.path.join(self.rounds_data_dir, "train.json"), +# ) +# shutil.copyfile( +# os.path.join(self.data_dir, "dev.json"), +# os.path.join(self.rounds_data_dir, "dev.json"), +# ) +# self.create_multi_turns_data(os.path.join(self.rounds_data_dir, "train.json")) +# self.create_multi_turns_data(os.path.join(self.rounds_data_dir, "dev.json")) + +# def create_multi_turns_data(self, file: str): +# result = [] +# with open(file, "r", encoding="utf-8") as f: +# for line in f: +# data = json.loads(line) +# data["src"] = [data["src"]] * 3 +# data["tgt"] = [data["tgt"]] * 3 +# result.append(data) + +# with open(file, "w", encoding="utf-8") as f: +# for data in result: +# line = json.dumps(line) +# f.write(line + "\n") + +# def tearDown(self) -> None: +# LLMTest.tearDown(self) +# sys.path.remove(self.model_codes_dir) + +# def test_lora(self): +# self.disable_static() +# paddle.set_default_dtype("float32") + +# lora_config = load_test_config(self.config_path, "lora", self.model_dir) + +# lora_config["dataset_name_or_path"] = self.rounds_data_dir +# lora_config["chat_template"] = "./tests/fixtures/chat_template.json" +# lora_config["output_dir"] = self.output_dir + +# with argv_context_guard(lora_config): +# from run_finetune import main + +# main() + +# # merge weights +# merge_lora_weights_config = { +# "model_name_or_path": lora_config["model_name_or_path"], +# "lora_path": lora_config["output_dir"], +# "merge_model_path": lora_config["output_dir"], +# } +# with argv_context_guard(merge_lora_weights_config): +# from tools.merge_lora_params import merge + +# merge() + +# if self.model_dir not in ["chatglm2", "qwen", "baichuan"]: +# self.run_predictor({"inference_model": True}) + +# self.run_predictor({"inference_model": False}) diff --git a/tests/llm/test_ptq.py b/tests/llm/test_ptq.py index dfbe2417a500..3d7a727bea89 100644 --- a/tests/llm/test_ptq.py +++ b/tests/llm/test_ptq.py @@ -47,7 +47,7 @@ def test_ptq(self): finetune_config["output_dir"] = self.output_dir with argv_context_guard(finetune_config): - from run_finetune import main + from run_quantization import main main() @@ -60,7 +60,7 @@ def test_blha(self): finetune_config["output_dir"] = self.output_dir with argv_context_guard(finetune_config): - from run_finetune import main + from run_quantization import main main() @@ -75,7 +75,7 @@ def test_ptq_smooth(self): finetune_config["smooth"] = True with argv_context_guard(finetune_config): - from run_finetune import main + from run_quantization import main main() @@ -91,7 +91,7 @@ def test_ptq_shift(self): finetune_config["shift"] = True with argv_context_guard(finetune_config): - from run_finetune import main + from run_quantization import main main() diff --git a/tests/llm/test_reft.py b/tests/llm/test_reft.py new file mode 100644 index 000000000000..34b572027741 --- /dev/null +++ b/tests/llm/test_reft.py @@ -0,0 +1,72 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import os +import sys +import unittest + +import paddle +from parameterized import parameterized_class + +from tests.testing_utils import argv_context_guard, load_test_config + +from .testing_utils import LLMTest + + +@parameterized_class( + ["model_dir"], + [ + ["llama"], + ["baichuan"], + ], +) +class ReftTest(LLMTest, unittest.TestCase): + config_path: str = "./tests/fixtures/llm/reft.yaml" + model_dir: str = None + + def setUp(self) -> None: + LLMTest.setUp(self) + + self.model_codes_dir = os.path.join(self.root_path, self.model_dir) + sys.path.insert(0, self.model_codes_dir) + + def tearDown(self) -> None: + LLMTest.tearDown(self) + sys.path.remove(self.model_codes_dir) + + def test_reft(self): + self.disable_static() + paddle.set_default_dtype("float32") + + reft_config = load_test_config(self.config_path, "reft", self.model_dir) + reft_config["output_dir"] = self.output_dir + reft_config["dataset_name_or_path"] = self.data_dir + + with argv_context_guard(reft_config): + from run_finetune import main + + main() + + perdict_params = { + "model_name_or_path": reft_config["model_name_or_path"], + "reft_path": self.output_dir, + "dataset_name_or_path": self.data_dir, + "batch_size": 8, + } + self.run_reft_predictor(perdict_params) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/llm/test_speculate_decoding.py b/tests/llm/test_speculate_decoding.py new file mode 100644 index 000000000000..05d4652caf88 --- /dev/null +++ b/tests/llm/test_speculate_decoding.py @@ -0,0 +1,85 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import unittest + +import paddle + +from .testing_utils import LLMTest, argv_context_guard + + +class SpeculatePredictorTest(LLMTest, unittest.TestCase): + model_name_or_path: str = "__internal_testing__/tiny-random-llama-hd128" + + def setUp(self) -> None: + super().setUp() + paddle.set_default_dtype("bfloat16") + self.config_params = { + "model_name_or_path": self.model_name_or_path, + "mode": "dynamic", + "dtype": "bfloat16", + "max_length": 48, + "inference_model": 1, + "speculate_method": None, + } + + def run_speculate_predictor(self, speculate_params): + """ + base speculative decoding forward test. + """ + predict_config = self.config_params + predict_config.update(speculate_params) + + # dynamic forward + self.disable_static() + with argv_context_guard(predict_config): + from predict.predictor import predict + + predict() + + # to static + self.disable_static() + predict_config["output_path"] = self.output_dir + with argv_context_guard(predict_config): + from predict.export_model import main + + main() + + # static forward + self.disable_static() + + predict_config["mode"] = "static" + predict_config["model_name_or_path"] = self.output_dir + + predict_config.pop("output_path") + with argv_context_guard(predict_config): + from predict.predictor import predict + + predict() + + def test_inference_with_reference(self): + """ + test inference with reference method. + """ + speculate_params = { + "speculate_method": "inference_with_reference", + "speculate_max_draft_token_num": 5, + "speculate_max_ngram_size": 2, + } + self.run_speculate_predictor(speculate_params) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/llm/testing_utils.py b/tests/llm/testing_utils.py index 3684ec243576..1b3117b7971c 100644 --- a/tests/llm/testing_utils.py +++ b/tests/llm/testing_utils.py @@ -108,3 +108,10 @@ def run_predictor(self, config_params=None): for predict_item, infer_item in zip(predict_result, infer_result): self.assertEqual(predict_item, infer_item) + + def run_reft_predictor(self, predict_config=None): + predict_config["output_file"] = os.path.join(self.output_dir, "predict.json") + with argv_context_guard(predict_config): + from predict.reft_predictor import main + + main() diff --git a/tests/peft/test_mos_lora.py b/tests/peft/test_mos_lora.py new file mode 100644 index 000000000000..d2030469dd2b --- /dev/null +++ b/tests/peft/test_mos_lora.py @@ -0,0 +1,215 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import os +import re +import unittest +from tempfile import TemporaryDirectory + +import numpy as np +import paddle +from parameterized import parameterized + +from paddlenlp.peft.lora import LoRAConfig, LoRALinear, LoRAModel +from paddlenlp.transformers import AutoModel, BertModel + + +class TestMosLoraLayer(unittest.TestCase): + def test_r_raise_exception(self): + with self.assertRaises(ValueError): + LoRALinear(in_features=16, out_features=8, r=0, lora_dropout=0.1, lora_alpha=8, lora_use_mixer=True) + + def test_forward(self): + lora_layer = LoRALinear( + in_features=16, out_features=8, r=4, lora_dropout=0.1, lora_alpha=8, lora_use_mixer=True + ) + x = paddle.randn([2, 4, 16], "float32") + output = lora_layer(x) + self.assertFalse(lora_layer.lora_A.stop_gradient) + self.assertFalse(lora_layer.lora_B.stop_gradient) + self.assertTrue(lora_layer.weight.stop_gradient) + self.assertFalse(lora_layer.bias.stop_gradient) + self.assertEqual(output.shape, [2, 4, 8]) + + def test_train_eval(self): + x = paddle.randn([2, 4, 16], "float32") + lora_layer = LoRALinear(in_features=16, out_features=8, r=4, lora_use_mixer=True) + lora_layer.train() + train_result = lora_layer(x) + train_weight = copy.deepcopy(lora_layer.weight) # deep copy since this is a pointer + lora_layer.eval() + eval_result = lora_layer(x) + eval_weight = lora_layer.weight + self.assertTrue(paddle.allclose(train_result, eval_result)) + self.assertTrue(paddle.allclose(train_weight, eval_weight)) + + def test_save_load(self): + with TemporaryDirectory() as tempdir: + lora_layer = LoRALinear(in_features=16, out_features=8, r=4, lora_use_mixer=True) + weights_path = os.path.join(tempdir, "model.pdparams") + paddle.save(lora_layer.state_dict(), weights_path) + new_lora_layer = LoRALinear(in_features=16, out_features=8, r=4, lora_use_mixer=True) + state_dict = paddle.load(weights_path) + new_lora_layer.set_dict(state_dict) + x = paddle.randn([2, 4, 16], "float32") + self.assertTrue(paddle.allclose(new_lora_layer(x), lora_layer(x))) + + def test_load_regular_linear(self): + with TemporaryDirectory() as tempdir: + regular_linear = paddle.nn.Linear(in_features=16, out_features=8) + weights_path = os.path.join(tempdir, "model.pdparams") + paddle.save(regular_linear.state_dict(), weights_path) + state_dict = paddle.load(weights_path) + # should be identical to regular linear + lora_layer_r8 = LoRALinear(in_features=16, out_features=8, r=8, lora_use_mixer=True) + lora_layer_r4 = LoRALinear(in_features=16, out_features=8, r=4, lora_use_mixer=True) + lora_layer_r8.set_dict(state_dict) + lora_layer_r4.set_dict(state_dict) + x = paddle.randn([2, 4, 16], "float32") + self.assertTrue(paddle.allclose(lora_layer_r8(x), regular_linear(x))) + self.assertTrue(paddle.allclose(lora_layer_r4(x), regular_linear(x))) + + def test_merge(self): + lora_layer_r8 = LoRALinear(in_features=16, out_features=8, r=8, lora_use_mixer=True) + lora_layer_r8.merge() + + def test_unmerge(self): + lora_layer_r8 = LoRALinear(in_features=16, out_features=8, r=8, lora_use_mixer=True) + lora_layer_r8.merged = True + lora_layer_r8.unmerge() + lora_layer_r8 = LoRALinear(in_features=16, out_features=8, r=8) + lora_layer_r8.merged = True + lora_layer_r8.unmerge() + + +class TestMosLoraModel(unittest.TestCase): + def test_lora_model_restore(self): + lora_config = LoRAConfig( + target_modules=[".*q_proj.*", ".*v_proj.*"], + r=4, + lora_alpha=8, + enable_lora_list=[None, [True, False]], + head_dim=2, + lora_use_mixer=True, + ) + model = AutoModel.from_pretrained("__internal_testing__/tiny-random-bert") + input_ids = paddle.to_tensor(np.random.randint(100, 200, [1, 20])) + model.eval() + original_results_1 = model(input_ids) + lora_model = LoRAModel(model, lora_config) + restored_model = lora_model.restore_original_model() + restored_model.eval() + original_results_2 = restored_model(input_ids) + self.assertIsNotNone(original_results_1) + self.assertIsNotNone(original_results_2) + self.assertIsInstance(restored_model, BertModel) + self.assertTrue(paddle.allclose(original_results_1[0], original_results_2[0])) + + def test_parallel_support(self): + lora_config = LoRAConfig( + target_modules=[".*q_proj.*", ".*v_proj.*"], + r=4, + lora_alpha=8, + enable_lora_list=[None, [True, False]], + head_dim=2, + lora_use_mixer=True, + tensor_parallel_degree=2, + ) + model = AutoModel.from_pretrained("__internal_testing__/tiny-random-bert") + model.eval() + with self.assertRaises(NotImplementedError): + LoRAModel(model, lora_config) + + @parameterized.expand([(None,), ("all",), ("lora",)]) + def test_lora_model_constructor(self, bias): + lora_config = LoRAConfig( + target_modules=[".*q_proj.*", ".*v_proj.*"], + r=4, + lora_alpha=8, + enable_lora_list=[None, [True, False]], + trainable_bias=bias, + head_dim=2, + lora_use_mixer=True, + ) + # turn off plm dropout for to test train vs test + model = AutoModel.from_pretrained( + "__internal_testing__/tiny-random-bert", hidden_dropout_prob=0, attention_probs_dropout_prob=0 + ) + lora_model = LoRAModel(model, lora_config) + lora_model.mark_only_lora_as_trainable() + for name, weight in lora_model.state_dict().items(): + if any([re.fullmatch(target_module, name) for target_module in lora_config.target_modules]): + if "lora" in name: + self.assertFalse(weight.stop_gradient) + elif "bias" in name and bias in ["lora", "all"]: + self.assertFalse(weight.stop_gradient) + else: + self.assertTrue(weight.stop_gradient) + else: + if "bias" in name and bias == "all": + self.assertFalse(weight.stop_gradient) + else: + self.assertTrue(weight.stop_gradient) + input_ids = paddle.to_tensor(np.random.randint(100, 200, [1, 20])) + lora_model.train() + train_forward_results = lora_model(input_ids) + self.assertIsNotNone(train_forward_results) + lora_model.eval() + eval_forward_results = lora_model(input_ids) + self.assertIsNotNone(eval_forward_results) + self.assertTrue(paddle.allclose(train_forward_results[0], eval_forward_results[0])) + + def test_lora_model_save_load(self): + with TemporaryDirectory() as tempdir: + input_ids = paddle.to_tensor(np.random.randint(100, 200, [1, 20])) + lora_config = LoRAConfig( + target_modules=[".*q_proj.*", ".*v_proj.*"], r=4, lora_alpha=8, lora_use_mixer=True + ) + model = AutoModel.from_pretrained("__internal_testing__/tiny-random-bert") + lora_model = LoRAModel(model, lora_config) + lora_model.eval() + original_results = lora_model(input_ids) + lora_model.save_pretrained(tempdir) + + loaded_lora_model = LoRAModel.from_pretrained(model, tempdir) + loaded_lora_model.eval() + loaded_results = loaded_lora_model(input_ids) + self.assertTrue(paddle.allclose(original_results[0], loaded_results[0])) + + config_loaded_lora_model = LoRAModel.from_pretrained(model, tempdir, lora_config=lora_config) + config_loaded_lora_model.eval() + config_loaded_results = config_loaded_lora_model(input_ids) + self.assertTrue(paddle.allclose(original_results[0], config_loaded_results[0])) + + def test_lora_module_raise_exception(self): + lora_config = LoRAConfig( + target_modules=[".*norm1.*"], r=4, lora_alpha=8, enable_lora_list=None, lora_use_mixer=True + ) + model = AutoModel.from_pretrained("__internal_testing__/tiny-random-bert") + with self.assertRaises(ValueError): + LoRAModel(model, lora_config) + + +class TestMosLoRAConfig(unittest.TestCase): + def test_save_load(self): + with TemporaryDirectory() as tempdir: + lora_config = LoRAConfig() + lora_config.save_pretrained(tempdir) + loaded_lora_config = LoRAConfig.from_pretrained(tempdir) + self.assertEqual(lora_config, loaded_lora_config) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/peft/test_reft.py b/tests/peft/test_reft.py new file mode 100644 index 000000000000..8d405370947d --- /dev/null +++ b/tests/peft/test_reft.py @@ -0,0 +1,347 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import unittest +from functools import partial +from tempfile import TemporaryDirectory +from types import SimpleNamespace + +import paddle + +from llm.utils.data import convert_example_for_reft +from paddlenlp.data import DataCollatorForSeq2Seq +from paddlenlp.datasets import load_dataset +from paddlenlp.peft.reft import ( + LoreftIntervention, + LowRankRotateLayer, + ReFTConfig, + ReftDataCollator, + ReFTModel, + TinyIntervention, + do_predict, +) +from paddlenlp.peft.reft.modeling_utils import ( + count_parameters, + get_type_from_string, + set_seed, +) +from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer +from paddlenlp.trl import SFTTrainer + + +class TestReftDataCollator(unittest.TestCase): + def test_call(self): + model_name = "__internal_testing__/tiny-random-llama" + tokenizer = AutoTokenizer.from_pretrained( + model_name, + model_max_length=512, + padding_side="right", + ) + tokenizer.pad_token_id = tokenizer.eos_token_id + model = AutoModelForCausalLM.from_pretrained(model_name) + data_collator = DataCollatorForSeq2Seq( + tokenizer=tokenizer, model=model, label_pad_token_id=-100, padding="longest" + ) + reft_data_collator = ReftDataCollator(data_collator) + instances = [ + { + "input_ids": paddle.to_tensor([[1, 2, 3], [4, 5, 6]]), + "intervention_locations": paddle.to_tensor([[0, 1, 0], [1, 0, 1]]), + }, + { + "input_ids": paddle.to_tensor([[7, 8, 9], [10, 11, 12]]), + "intervention_locations": paddle.to_tensor([[1, 0, 1], [0, 1, 0]]), + }, + ] + + batch_inputs = reft_data_collator(instances) + + self.assertIn("input_ids", batch_inputs) + self.assertIn("intervention_locations", batch_inputs) + self.assertIsInstance(batch_inputs["input_ids"], paddle.Tensor) + self.assertIsInstance(batch_inputs["intervention_locations"], paddle.Tensor) + + +class TestBasicUtils(unittest.TestCase): + def test_get_type_from_string(self): + class_str = "paddlenlp.peft.reft.LoreftIntervention" + cls = get_type_from_string(class_str) + self.assertIsInstance(cls, type(LoreftIntervention)) + + def test_set_seed(self): + set_seed(42) + set_seed(66) + + def test_count_param(self): + model = AutoModelForCausalLM.from_pretrained("__internal_testing__/tiny-random-llama") + count_parameters(model) + + +class TestReftConfig(unittest.TestCase): + def test_reft_config(self): + layers = [0, 1, 2] + representations = [ + { + "layer": l, + "component": "block_output", + "low_rank_dimension": 4, + "intervention": LoreftIntervention( + embed_dim=768, + low_rank_dimension=4, + dropout=0.00, + dtype="float32", + act_fn="linear", + device="gpu", + add_bias=False, + ), + } + for l in layers + ] + reft_config = ReFTConfig(representations=representations) + reft_config.__str__() + + +class TestLoReftIntervention(unittest.TestCase): + def setUp(self): + self.kwargs = { + "embed_dim": 64, + "low_rank_dimension": 4, + "dtype": paddle.float32, + "dropout": 0.1, + "act_fn": "linear", + } + + def test_initialization(self): + intervention = LoreftIntervention(**self.kwargs) + self.assertIsInstance(intervention.rotate_layer, LowRankRotateLayer) + self.assertIsInstance(intervention.learned_source, paddle.nn.Linear) + self.assertEqual(intervention.dropout.p, self.kwargs["dropout"]) + + def test_forward(self): + base = paddle.randn([10, self.kwargs["embed_dim"]]) + intervention = LoreftIntervention(**self.kwargs) + output = intervention.forward(base) + self.assertEqual(output.shape, base.shape) + self.assertEqual(output.dtype, self.kwargs["dtype"]) + + def test_load_state_dict(self): + model = LoreftIntervention(**self.kwargs) + state_dict = { + "learned_source.weight": paddle.randn([64, 4]), + "learned_source.bias": paddle.zeros([4]), + "rotate_layer.weight": paddle.randn([64, 4]), + } + model.load_state_dict(state_dict) + self.assertTrue(paddle.allclose(model.learned_source.weight.data, state_dict["learned_source.weight"])) + self.assertTrue(paddle.allclose(model.learned_source.bias.data, state_dict["learned_source.bias"])) + self.assertTrue( + paddle.allclose( + model.rotate_layer.weight[:, : state_dict["rotate_layer.weight"].shape[-1]], + state_dict["rotate_layer.weight"], + ) + ) + + +class TestTinyIntervention(unittest.TestCase): + def setUp(self): + self.kwargs = { + "embed_dim": 768, + "low_rank_dimension": 4, + "dtype": paddle.float32, + "dropout": 0.1, + "act_fn": "relu", + } + + def test_initialization(self): + intervention = TinyIntervention(**self.kwargs) + self.assertEqual(intervention.rank, self.kwargs["low_rank_dimension"]) + self.assertEqual(intervention.hidden_size, self.kwargs["embed_dim"]) + self.assertEqual(intervention.param_A.shape, [self.kwargs["embed_dim"], self.kwargs["low_rank_dimension"]]) + self.assertEqual(intervention.param_B.shape, [self.kwargs["low_rank_dimension"], self.kwargs["embed_dim"]]) + self.assertEqual(intervention.param_a.shape, [self.kwargs["low_rank_dimension"]]) + self.assertEqual(intervention.param_b.shape, [self.kwargs["embed_dim"]]) + + def test_forward(self): + base = paddle.randn([10, self.kwargs["embed_dim"]]) + intervention = TinyIntervention(**self.kwargs) + output = intervention.forward(base) + self.assertEqual(output.shape, base.shape) + self.assertEqual(output.dtype, self.kwargs["dtype"]) + + def test_load_state_dict(self): + model = TinyIntervention(**self.kwargs) + state_dict = { + "param_A": paddle.randn([768, 4]), + "param_B": paddle.randn([4, 768]), + "param_a": paddle.randn([4]), + "param_b": paddle.randn([768]), + } + model.load_state_dict(state_dict) + self.assertTrue(paddle.allclose(model.param_A, state_dict["param_A"])) + self.assertTrue(paddle.allclose(model.param_B, state_dict["param_B"])) + self.assertTrue(paddle.allclose(model.param_a, state_dict["param_a"])) + self.assertTrue(paddle.allclose(model.param_b, state_dict["param_b"])) + + +class TestReftModel(unittest.TestCase): + def test_get_reft_model(self): + model = AutoModelForCausalLM.from_pretrained("__internal_testing__/tiny-random-llama") + layers = [0] + representations = [ + { + "layer": l, + "component": "block_output", + "low_rank_dimension": 4, + "intervention": LoreftIntervention( + embed_dim=768, + low_rank_dimension=4, + dropout=0.00, + dtype="float32", + act_fn="linear", + device="gpu", + add_bias=False, + ), + } + for l in layers + ] + reft_config = ReFTConfig(representations=representations) + reft_model = ReFTModel(reft_config, model) + reft_model.print_trainable_parameters() + self.assertTrue(type(reft_model), ReFTModel) + + def test_reft_model_forward(self): + model = AutoModelForCausalLM.from_pretrained("__internal_testing__/tiny-random-llama") + + layers = [0] + representations = [ + { + "layer": l, + "component": "block_output", + "low_rank_dimension": 4, + "intervention": LoreftIntervention( + embed_dim=768, + low_rank_dimension=4, + dropout=0.00, + dtype="float32", + act_fn="linear", + device="gpu", + add_bias=False, + ), + } + for l in layers + ] + reft_config = ReFTConfig(representations=representations) + reft_model = ReFTModel(reft_config, model) + reft_model.print_trainable_parameters() + outputs = reft_model.model(**{"input_ids": paddle.randint(low=1, high=100, shape=(5, 10))}) + self.assertTrue(outputs[0].shape, [5, 10, 32000]) + + +class TestReFTModelPredict(unittest.TestCase): + def test_reft_model_predict(self): + tokenizer = AutoTokenizer.from_pretrained("__internal_testing__/tiny-random-llama") + tokenizer.pad_token_id = tokenizer.eos_token_id + train_ds = load_dataset( + "json", + data_files=os.path.join("./tests/fixtures/llm/data", "train.json"), + lazy=False, + )[0] + dev_ds = load_dataset( + "json", + data_files=os.path.join("./tests/fixtures/llm/data", "dev.json"), + lazy=False, + )[0] + + trans_func = partial( + convert_example_for_reft, + tokenizer=tokenizer, + data_args=SimpleNamespace(**{"max_length": 64, "src_length": 32, "autoregressive": False}), + positions="f7", + num_interventions=1, + ) + + train_ds = train_ds.map( + partial( + trans_func, + is_test=False, + zero_padding=False, + flash_mask=False, + ) + ) + dev_ds = dev_ds.map( + partial( + trans_func, + is_test=False, + zero_padding=False, + flash_mask=False, + ) + ) + + model = AutoModelForCausalLM.from_pretrained("__internal_testing__/tiny-random-llama") + layers = [0] + representations = [ + { + "layer": l, + "component": "block_output", + "low_rank_dimension": 4, + "intervention": LoreftIntervention( + embed_dim=768, + low_rank_dimension=4, + dropout=0.00, + dtype="float32", + act_fn="linear", + device="gpu", + add_bias=False, + ), + } + for l in layers + ] + reft_config = ReFTConfig(representations=representations) + reft_model = ReFTModel(reft_config, model) + reft_model.disable_model_gradients() + reft_model.model.train() + reft_model.print_trainable_parameters() + data_collator_fn = DataCollatorForSeq2Seq( + tokenizer=tokenizer, model=model, label_pad_token_id=-100, padding="longest" + ) + data_collator = ReftDataCollator(data_collator=data_collator_fn) + trainer = SFTTrainer( + model=reft_model, + tokenizer=tokenizer, + train_dataset=train_ds, + data_collator=data_collator, + eval_dataset=None, + compute_metrics=None, + gen_args=None, + data_args=None, + do_generation=False, + ) + trainer.train() + + with TemporaryDirectory() as tempdir: + reft_model.save_pretrained(tempdir) + # 预测 + do_predict( + intervenable=reft_model, + tokenizer=tokenizer, + eval_dataset=dev_ds, + batch_size=1, + predict_path=f"{tempdir}/pred_result.json", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/pose/__init__.py b/tests/pose/__init__.py new file mode 100644 index 000000000000..595add0aed9e --- /dev/null +++ b/tests/pose/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/pose/test_long_sequence_strategies_yarn.py b/tests/pose/test_long_sequence_strategies_yarn.py new file mode 100644 index 000000000000..d2d066932704 --- /dev/null +++ b/tests/pose/test_long_sequence_strategies_yarn.py @@ -0,0 +1,4918 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +# import os +import sys +import unittest + +import numpy as np +import paddle +from parameterized import parameterized_class + +from paddlenlp.transformers import AutoConfig, AutoModelForCausalLM + +from .testing_utils import LLMTest + +all_inputs = [ + # llama-7b + [ + [ + 1, + 910, + 3461, + 8128, + 3239, + 2472, + 322, + 5626, + 363, + 11559, + 373, + 2473, + 6360, + 9580, + 545, + 358, + 313, + 17870, + 29925, + 29897, + 14974, + 6360, + 9580, + 545, + 358, + 313, + 17870, + 29925, + 29897, + 322, + 2908, + 15649, + 8078, + 292, + 313, + 29933, + 5371, + 29897, + 526, + 4266, + 8078, + 292, + 7208, + 12903, + 393, + 11559, + 3635, + 1169, + 278, + 10317, + 310, + 5282, + 1947, + 313, + 3970, + 29928, + 29897, + 304, + ] + ], + # qwen-7b + [ + [ + 1986, + 1895, + 5707, + 4004, + 1995, + 323, + 4714, + 369, + 7992, + 389, + 7299, + 3157, + 52578, + 320, + 44, + 9954, + 8, + 323, + 2504, + 3695, + 20358, + 3157, + 52578, + 320, + 44, + 9954, + 8, + 323, + 2504, + 3695, + 59406, + 320, + 66755, + 8, + 525, + 3281, + 59406, + 23783, + 429, + 7992, + 28690, + 279, + 5887, + 315, + 16373, + 320, + 35, + 2069, + 8, + 311, + 990, + 369, + 264, + 7199, + 1372, + 315, + 9055, + 23390, + ] + ], + # chatglm3-6b + [ + [ + 64790, + 64792, + 666, + 1284, + 2736, + 4467, + 1097, + 293, + 2326, + 332, + 4168, + 331, + 5332, + 2475, + 23355, + 359, + 26594, + 30947, + 30945, + 293, + 15903, + 2475, + 23355, + 359, + 26594, + 30947, + 30945, + 293, + 3579, + 2505, + 26317, + 359, + 54223, + 30945, + 383, + 1720, + 26317, + 11972, + 343, + 4168, + 15125, + 267, + 2902, + 290, + 10196, + 359, + 30952, + 3809, + 30945, + 289, + 792, + 332, + 260, + 3666, + 1276, + 290, + 5735, + 10625, + ] + ], + # chatglm-6b + [ + [ + 200, + 647, + 986, + 1186, + 320, + 102, + 953, + 108, + 2355, + 111, + 1297, + 626, + 26020, + 19, + 10806, + 266, + 14, + 102, + 130001, + 130004, + 6723, + 626, + 26020, + 19, + 10806, + 266, + 14, + 102, + 1204, + 1784, + 27817, + 19, + 27798, + 14, + 118, + 972, + 27817, + 2055, + 109, + 2355, + 9187, + 100, + 1334, + 101, + 7319, + 19, + 9220, + 234, + 14, + 103, + 179, + 108, + 104, + 1132, + 277, + 101, + 2576, + 6225, + ] + ], + # bloom + [ + [ + 55, + 75, + 76, + 86, + 210, + 85, + 72, + 83, + 82, + 85, + 87, + 210, + 83, + 85, + 82, + 89, + 76, + 71, + 72, + 86, + 48, + 88, + 79, + 87, + 76, + 92, + 72, + 68, + 85, + 210, + 83, + 85, + 82, + 70, + 88, + 85, + 72, + 80, + 72, + 81, + 87, + 210, + 11, + 48, + 60, + 51, + 12, + 210, + 68, + 81, + 71, + 210, + 69, + 79, + 82, + 70, + 78, + 210, + ] + ], +] +all_position_ids = [ + # llama-7b + [ + [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + ] + ], + # qwen07b + [ + [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + ] + ], + # chatglm3-6b + [ + [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + ] + ], + # chatglm-6b + [ + [ + [ + 18, + 18, + 18, + 18, + 18, + 18, + 18, + 18, + 18, + 18, + 18, + 18, + 18, + 18, + 18, + 18, + 18, + 18, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + ], + ] + ], + # bloom + [ + [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + ] + ], +] +all_attention_mask = [ + # llama + [ + [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + ] + ], + # qwen + [ + [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + ] + ], + # chatglm3-6b + [ + [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + ] + ], + # chatglm-6b + [ + [ + [ + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + ], + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + ], + ] + ] + ], + # bloom + [ + [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + ] + ], +] +all_labels = [ + # llama + [ + [ + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + 14974, + 6360, + 9580, + 545, + 358, + 313, + 17870, + 29925, + 29897, + 322, + 2908, + 15649, + 8078, + 292, + 313, + 29933, + 5371, + 29897, + 526, + 4266, + 8078, + 292, + 7208, + 12903, + 393, + 11559, + 3635, + 1169, + 278, + 10317, + 310, + 5282, + 1947, + 313, + 3970, + 29928, + 29897, + 304, + 671, + ] + ], + # qwen + [ + [ + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + 20358, + 3157, + 52578, + 320, + 44, + 9954, + 8, + 323, + 2504, + 3695, + 59406, + 320, + 66755, + 8, + 525, + 3281, + 59406, + 23783, + 429, + 7992, + 28690, + 279, + 5887, + 315, + 16373, + 320, + 35, + 2069, + 8, + 311, + 990, + 369, + 264, + 7199, + 1372, + 315, + 9055, + 23390, + 7468, + ] + ], + # chatglm3-6b + [ + [ + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + 15903, + 2475, + 23355, + 359, + 26594, + 30947, + 30945, + 293, + 3579, + 2505, + 26317, + 359, + 54223, + 30945, + 383, + 1720, + 26317, + 11972, + 343, + 4168, + 15125, + 267, + 2902, + 290, + 10196, + 359, + 30952, + 3809, + 30945, + 289, + 792, + 332, + 260, + 3666, + 1276, + 290, + 5735, + 10625, + 3181, + ] + ], + # chatglm-6b + [ + [ + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + 130004, + 6723, + 626, + 26020, + 19, + 10806, + 266, + 14, + 102, + 1204, + 1784, + 27817, + 19, + 27798, + 14, + 118, + 972, + 27817, + 2055, + 109, + 2355, + 9187, + 100, + 1334, + 101, + 7319, + 19, + 9220, + 234, + 14, + 103, + 179, + 108, + 104, + 1132, + 277, + 101, + 2576, + 6225, + 1785, + ] + ], + # bloom + [ + [ + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + 48, + 88, + 79, + 87, + 76, + 92, + 72, + 68, + 85, + 210, + 83, + 85, + 82, + 70, + 88, + 85, + 72, + 80, + 72, + 81, + 87, + 210, + 11, + 48, + 60, + 51, + 12, + 210, + 68, + 81, + 71, + 210, + 69, + 79, + 82, + 70, + 78, + 210, + 69, + ] + ], +] + +all_ppl = [ + # llama + 31361.590644223128, + 31361.590644223128, + 31362.757106912533, + 31361.62055298091, + # qwen + 155909.83795939674, + 155939.57823718787, + 155917.27249705535, + 155909.83795939674, + # chatglm3-6b + 64415.31959719674, + 64454.8934643284, + 64416.60966606845, + 64420.172847651804, + # chatglm-6b + 130540.64669131214, + 130573.01895270264, + 130539.15278071642, + 130538.4058318297, + # llama-alibi + 31369.517462860927, + # bloom-alibi + 251106.84487228873, +] + + +@parameterized_class( + [ + "model_name_or_path", + "strategy_type", + "strategy_name", + "inputs", + "positin_ids", + "labels", + "attention_mask", + "ppl", + ], + [ + [ + "__internal_testing__/micro-random-llama", + "embedding_strategies", + "YaRNScalingRotaryEmbedding", + all_inputs[0], + all_position_ids[0], + all_labels[0], + all_attention_mask[0], + all_ppl[3], + ], + # [ + # "__internal_testing__/tiny-new-random-qwen-7b", + # "embedding_strategies", + # "YaRNScalingRotaryEmbedding", + # all_inputs[1], + # all_position_ids[1], + # all_labels[1], + # all_attention_mask[1], + # all_ppl[7], + # ], + # [ + # "__internal_testing__/tiny-new-random-chatglm3-6b", + # "embedding_strategies", + # "YaRNScalingRotaryEmbedding", + # all_inputs[2], + # all_position_ids[2], + # all_labels[2], + # all_attention_mask[2], + # all_ppl[11], + # ], + ], +) +class TestLongSequenceStrategiesTest(LLMTest, unittest.TestCase): + config_path: str = "./tests/fixtures/llm/predictor.yaml" + root_path = "" + + def setUp(self) -> None: + super().setUp() + sys.path.insert(0, "./llm") + + def disable_static(self): + paddle.utils.unique_name.switch() + paddle.disable_static() + + def get_model(self, model_name_or_path): + model_config = AutoConfig.from_pretrained(model_name_or_path) + if self.strategy_type == "embedding_strategies": + model_config.alibi = False + else: + model_config.alibi = True + model_config.use_long_sequence_strategies = True + model_config.long_sequence_strategy_type = self.strategy_type + model_config.long_sequence_strategy_name = self.strategy_name + max_position_embeddings = 10 if self.strategy_name == "DynamicNTKScalingRotaryEmbedding" else 2048 + scaling_factor = 4 + model_config.long_sequence_init_args = { + "dim": int(model_config.hidden_size / model_config.num_attention_heads), + "max_position_embeddings": int(max_position_embeddings * scaling_factor), + "base": 10000, + "scaling_factor": scaling_factor, + "original_max_position_embeddings": max_position_embeddings, + } + # model_config.long_sequence_init_args = { + # "dim": int(model_config.hidden_size / model_config.num_attention_heads), + # "max_position_embeddings": max_position_embeddings , + # "base": 10000, + # "scaling_factor": scaling_factor, + # "original_max_position_embeddings": 10, + # } + if "chatglm" in model_name_or_path: + model_config.long_sequence_init_args["position_encoding_2d"] = True + model = AutoModelForCausalLM.from_pretrained(model_name_or_path, config=model_config, dtype="float32") + return model + + def test_long_sequence_strategies(self): + input_ids = paddle.to_tensor(self.inputs, dtype=paddle.int64) + position_ids = paddle.to_tensor(self.positin_ids, dtype=paddle.int64) + attention_mask = paddle.to_tensor(self.attention_mask, dtype=paddle.int64) + labels = paddle.to_tensor(self.labels, dtype=paddle.int64) + ppl = self.ppl + inputs = { + "input_ids": input_ids, + "position_ids": position_ids, + "labels": labels, + "attention_mask": attention_mask, + } + model = self.get_model(self.model_name_or_path) + + output = model(**inputs) + self.assertTrue( + np.allclose( + np.exp(output[0].item()), + ppl, + rtol=1e-2, + ) + ) diff --git a/tests/pose/testing_utils.py b/tests/pose/testing_utils.py new file mode 100644 index 000000000000..3684ec243576 --- /dev/null +++ b/tests/pose/testing_utils.py @@ -0,0 +1,110 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import json +import os +import shutil +import sys +import tempfile + +import paddle + +from tests.testing_utils import argv_context_guard, load_test_config + + +class LLMTest: + config_path: str = None + data_dir = "./tests/fixtures/llm/data/" + + def setUp(self) -> None: + self.root_path = "./llm" + self.output_dir = tempfile.mkdtemp() + self.inference_output_dir = tempfile.mkdtemp() + sys.path.insert(0, self.root_path) + self.disable_static() + paddle.set_default_dtype("float32") + + def tearDown(self) -> None: + sys.path.remove(self.root_path) + shutil.rmtree(self.output_dir) + shutil.rmtree(self.inference_output_dir) + self.disable_static() + paddle.device.cuda.empty_cache() + + def disable_static(self): + paddle.utils.unique_name.switch() + paddle.disable_static() + + def _read_result(self, file): + result = [] + # read output field from json file + with open(file, "r", encoding="utf-8") as f: + for line in f: + data = json.loads(line) + result.append(data["output"]) + return result + + def run_predictor(self, config_params=None): + if config_params is None: + config_params = {} + + # to avoid the same parameter + self.disable_static() + predict_config = load_test_config(self.config_path, "inference-predict") + predict_config["output_file"] = os.path.join(self.output_dir, "predict.json") + predict_config["model_name_or_path"] = self.output_dir + predict_config.update(config_params) + + with argv_context_guard(predict_config): + from predict.predictor import predict + + predict() + + # prefix_tuning dynamic graph do not support to_static + if not predict_config["inference_model"]: + return + + # to static + self.disable_static() + config = load_test_config(self.config_path, "inference-to-static") + config["output_path"] = self.inference_output_dir + config["model_name_or_path"] = self.output_dir + config.update(config_params) + with argv_context_guard(config): + from predict.export_model import main + + main() + + # inference + self.disable_static() + config = load_test_config(self.config_path, "inference-infer") + config["model_name_or_path"] = self.inference_output_dir + config["output_file"] = os.path.join(self.inference_output_dir, "infer.json") + + config_params.pop("model_name_or_path", None) + config.update(config_params) + with argv_context_guard(config): + from predict.predictor import predict + + predict() + + self.disable_static() + + predict_result = self._read_result(predict_config["output_file"]) + infer_result = self._read_result(config["output_file"]) + assert len(predict_result) == len(infer_result) + + for predict_item, infer_item in zip(predict_result, infer_result): + self.assertEqual(predict_item, infer_item) diff --git a/tests/test_tipc/static/auto_parallel/llama2/benchmark_common/run_benchmark.sh b/tests/test_tipc/static/auto_parallel/llama2/benchmark_common/run_benchmark.sh index 03b2175d9465..b74b3a9df2e2 100644 --- a/tests/test_tipc/static/auto_parallel/llama2/benchmark_common/run_benchmark.sh +++ b/tests/test_tipc/static/auto_parallel/llama2/benchmark_common/run_benchmark.sh @@ -57,6 +57,57 @@ function _set_params(){ OUTPUT_PATH=${run_log_path}/output } +# 循环监控文件写入状态和进程状态 +monitor_log_file() { + local log_file="$1" # 获取日志文件路径 + local training_pid="$2" # 获取训练进程的 PID + local no_update_duration=0 # 初始化无更新时长计数 + local last_size=0 + + echo "开始监控进程 $training_pid 和日志文件 $log_file..." + + while true; do + sleep 5 # 每隔 5 秒检查一次日志文件 + + # 判断日志文件是否存在 + if [ ! -f "$log_file" ]; then + echo "日志文件 $log_file 不存在,检查进程状态..." + # 如果日志文件不存在,直接判断进程是否结束 + if ! ps -p $training_pid > /dev/null; then + echo "进程 $training_pid 已经结束。" + break + fi + continue # 如果文件不存在,跳过后续逻辑,继续循环 + fi + + # 获取当前日志文件的大小 + new_size=$(stat -c %s "$log_file") + + if [ "$last_size" -eq "$new_size" ]; then + # 文件大小未变化,增加无更新时长计数 + no_update_duration=$((no_update_duration + 5)) + + if [ "$no_update_duration" -ge 180 ]; then + echo "文件在过去的 3 分钟内没有继续写入,准备杀掉进程 $training_pid." + kill -9 $training_pid # 杀掉进程 + echo "进程 $training_pid 已经被杀掉。" + break + fi + else + # 文件大小有变化,重置无更新时长计数 + echo "文件仍在写入..." + no_update_duration=0 + last_size=$new_size + fi + + # 如果训练进程已经结束,退出监控 + if ! ps -p $training_pid > /dev/null; then + echo "进程 $training_pid 已经结束。" + break + fi + done +} + function _train(){ batch_size=${per_device_train_batch_size} # 如果模型跑多卡单进程时,请在_train函数中计算出多卡需要的bs @@ -140,9 +191,25 @@ function _train(){ rm -rf mylog && rm -rf checkpoints echo "train_cmd: ${train_cmd} log_file: ${log_file}" - timeout 40m ${train_cmd} > ${log_file} 2>&1 + timeout 40m ${train_cmd} > ${log_file} 2>&1 & + training_pid=$! # 获取后台进程的 PID + + # 监控进程和日志的更新状态 + monitor_log_file "$log_file" "$training_pid" & + monitor_log_file_pid=$! # 获取日志监控进程的 PID + + # 等待训练进程完成 + wait $training_pid + exit_code=$? + + # 获取训练进程的退出码 + echo "训练进程 $training_pid 的退出码是 $exit_code" + + # 清理后台日志监控进程 + kill $monitor_log_file_pid + - if [ $? -ne 0 ];then + if [ ${exit_code} -ne 0 ];then echo -e "${model_name}, FAIL" else echo -e "${model_name}, SUCCESS" @@ -158,6 +225,9 @@ function _train(){ export FLAGS_selected_gpus="0,1,2,3,4,5,6,7" export NCCL_IB_DISABLE=0 export PYTHONPATH=$(dirname "$PWD"):$PYTHONPATH +# https://github.com/PaddlePaddle/Paddle/pull/69410 合入影响 +# 如不设置参数为1,则默认选择不带tensor fusion的sharding stage1版本 +export FLAGS_enable_sharding_stage1_tensor_fusion=1 export CUDA_DEVICE_MAX_CONNECTIONS=1 export PARALLEL_CROSS_ENTROPY=true diff --git a/tests/transformers/auto/test_tokenizer.py b/tests/transformers/auto/test_tokenizer.py index 54c568113023..1e47267f91a3 100644 --- a/tests/transformers/auto/test_tokenizer.py +++ b/tests/transformers/auto/test_tokenizer.py @@ -48,6 +48,10 @@ def test_from_pretrained_cache_dir(self): # check against double appending model_name in cache_dir self.assertFalse(os.path.exists(os.path.join(tempdir, model_name, model_name))) + def test_from_pretrained_tokenizer_fast(self): + tokenizer = AutoTokenizer.from_pretrained("intfloat/e5-base-v2", use_fast=True) + self.assertIsInstance(tokenizer, BertTokenizerFast) + def test_new_tokenizer_registration(self): try: AutoConfig.register("custom", CustomConfig) diff --git a/tests/transformers/gemma/test_tokenizer.py b/tests/transformers/gemma/test_tokenizer.py index e8527c40ee4b..17c792f7a447 100644 --- a/tests/transformers/gemma/test_tokenizer.py +++ b/tests/transformers/gemma/test_tokenizer.py @@ -223,3 +223,29 @@ def test_add_special_tokens(self): self.assertEqual(encoded, input_encoded + special_token_id) decoded = tokenizer.decode(encoded, skip_special_tokens=True) self.assertTrue(special_token not in decoded) + + def test_extract_non_learnable_parts(self): + models_with_templates = ["google/gemma-2b-it", "google/gemma-7b-it"] + dummy_conversastions = [ + ["Q.", "A."], + ["Q.A.", "A."], + ["Q?", "A!"], + ] + decode_outputs = [ + ["user\nQ.\nmodel\n", "A.\n"], + ["user\nQ.A.\nmodel\n", "A.\n"], + ["user\nQ?\nmodel\n", "A!\n"], + ] + context_data = {} + context_data["is_training"] = True + for model_id in models_with_templates: + tokenizer = GemmaTokenizer.from_pretrained(model_id) + if tokenizer.chat_template is None: + continue + conversation_result: list[tuple[list[int], list[int]]] = tokenizer.encode_chat_inputs( + dummy_conversastions, + context_data=context_data, + ) + for idx, round in enumerate(conversation_result["conversations"]): + self.assertEquals(tokenizer.decode(round[0]), decode_outputs[idx][0]) + self.assertEquals(tokenizer.decode(round[1]), decode_outputs[idx][1]) diff --git a/tests/transformers/yuan/__init__.py b/tests/transformers/yuan/__init__.py new file mode 100644 index 000000000000..fd05a9208165 --- /dev/null +++ b/tests/transformers/yuan/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/transformers/yuan/test_tokenizer.py b/tests/transformers/yuan/test_tokenizer.py new file mode 100644 index 000000000000..4a479fe3336a --- /dev/null +++ b/tests/transformers/yuan/test_tokenizer.py @@ -0,0 +1,49 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from paddlenlp.transformers import YuanTokenizer + + +class YuanTokenizationTest(unittest.TestCase): + def test_extract_non_learnable_parts(self): + models_with_templates = [ + "IEITYuan/Yuan2-2B", + "IEITYuan/Yuan2-51B", + "IEITYuan/Yuan2-102B", + ] + dummy_conversastions = [ + ["Q.", "A."], + ["Q.A.", "A."], + ["Q?", "A!"], + ] + decode_outputs = [ + ["Q.", "A."], + ["Q.A.", "A."], + ["Q?", " A!"], # notify there is an extra space + ] + context_data = {} + context_data["is_training"] = True + for model_id in models_with_templates: + tokenizer = YuanTokenizer.from_pretrained(model_id) + if tokenizer.chat_template is None: + continue + conversation_result: list[tuple[list[int], list[int]]] = tokenizer.encode_chat_inputs( + dummy_conversastions, + context_data=context_data, + ) + for idx, round in enumerate(conversation_result["conversations"]): + self.assertEquals(tokenizer.decode(round[0]), decode_outputs[idx][0]) + self.assertEquals(tokenizer.decode(round[1]), decode_outputs[idx][1])