From 4aeade2db8db012dfc7d60d64d5ac7f697591088 Mon Sep 17 00:00:00 2001
From: ckl117
Date: Fri, 8 Nov 2024 09:01:30 +0000
Subject: [PATCH] cutlass 3.x gemm on sm90

---
 .../fuse_gemm_noact_template_3x.h   | 144 ++++++++++
 .../fp8_fp8_half_gemm_sm90.cu       | 255 ++++++++++++++++++
 .../generic_gemm_kernel_noact_3x.cu |  66 +++++
 csrc/setup_cuda.py                  |  17 +-
 4 files changed, 478 insertions(+), 4 deletions(-)
 create mode 100644 csrc/gpu/cutlass_kernels/fp8_gemm_fused/fuse_gemm_noact_template_3x.h
 create mode 100644 csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_gemm_sm90.cu
 create mode 100644 csrc/gpu/fp8_gemm_with_cutlass/generic_gemm_kernel_noact_3x.cu

diff --git a/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fuse_gemm_noact_template_3x.h b/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fuse_gemm_noact_template_3x.h
new file mode 100644
index 000000000000..984a6696534e
--- /dev/null
+++ b/csrc/gpu/cutlass_kernels/fp8_gemm_fused/fuse_gemm_noact_template_3x.h
@@ -0,0 +1,144 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+
+#include "fp8_common.h"
+
+#include "cutlass/cutlass.h"
+#include "cutlass/epilogue/collective/default_epilogue.hpp"
+#include "cutlass/epilogue/thread/linear_combination.h"
+#include "cutlass/gemm/dispatch_policy.hpp"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+#include "cutlass/gemm/kernel/tile_scheduler.hpp"
+
+// Default tile/cluster shapes are representative SM90 choices; the dispatch
+// code passes explicit shapes for each instantiation.
+template <
+    typename InputType = phi::dtype::float8_e4m3fn,
+    typename OutType = phi::dtype::float16,
+    typename TileShape = cute::Shape<cute::_128, cute::_128, cute::_128>,
+    typename ClusterShape = cute::Shape<cute::_1, cute::_2, cute::_1>,
+    typename KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum,
+    typename EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized
+>
+bool dispatch_fuse_gemm_noact_sm90(GemmEpilogueAllParams params) {
+  using ElementA = typename std::conditional_t<
+      std::is_same_v<InputType, phi::dtype::float8_e4m3fn>,
+      cutlass::float_e4m3_t,
+      cutlass::float_e5m2_t>;
+  using ElementB = ElementA;
+  using ElementD = typename std::conditional_t<
+      std::is_same_v<OutType, phi::dtype::bfloat16>,
+      cutlass::bfloat16_t,
+      cutlass::half_t>;
+  using ElementC = ElementD;
+
+  using LayoutA = cutlass::layout::RowMajor;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutC = cutlass::layout::RowMajor;
+  using LayoutD = cutlass::layout::RowMajor;
+
+  using ElementAccumulator = float;
+  using ElementCompute = float;
+  using ElementScalar = float;
+
+  // 16B alignment lets us use TMA
+  static constexpr int AlignmentA = 16 / sizeof(ElementA);
+  static constexpr int AlignmentB = 16 / sizeof(ElementB);
+  static constexpr int AlignmentC = 16 / sizeof(ElementC);
+  static constexpr int AlignmentD = 16 / sizeof(ElementD);
+
+  static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
+
+  using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<
+      ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>;
+
+  using CollectiveEpilogue = typename
+      cutlass::epilogue::collective::CollectiveBuilder<
+          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+          TileShape,
+          ClusterShape,
+          cutlass::epilogue::collective::EpilogueTileAuto,
+          ElementAccumulator, ElementCompute,
+          ElementC, LayoutC, AlignmentC,
+          ElementD, LayoutD, AlignmentD,
+          EpilogueSchedule,
+          DefaultOperation
+      >::CollectiveOp;
+
+  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+      ElementA, LayoutA, AlignmentA,
+      ElementB, LayoutB, AlignmentB,
+      ElementAccumulator,
+      TileShape,
+      ClusterShape,
+      cutlass::gemm::collective::StageCountAutoCarveout<
+          static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
+      KernelSchedule
+  >::CollectiveOp;
+
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      cute::Shape<int, int, int, int>,  // problem shape (M, N, K, L)
+      CollectiveMainloop,
+      CollectiveEpilogue,
+      cutlass::gemm::PersistentScheduler
+  >;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+
+  using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
+
+  using StrideA = typename Gemm::GemmKernel::StrideA;
+  using StrideB = typename Gemm::GemmKernel::StrideB;
+  using StrideC = typename Gemm::GemmKernel::StrideC;
+  using StrideD = typename Gemm::GemmKernel::StrideD;
+
+  // Stride initialization
+  StrideA stride_A{params.lda, cute::Int<1>{}, 0};
+  StrideB stride_B{params.ldb, cute::Int<1>{}, 0};
+  StrideC stride_C{0, cute::Int<1>{}, 0};  // broadcast bias along rows
+  StrideD stride_D{params.ldd, cute::Int<1>{}, 0};
+
+  auto a_ptr = reinterpret_cast<ElementA*>(const_cast<void*>(params.A));
+  auto b_ptr = reinterpret_cast<ElementB*>(const_cast<void*>(params.B));
+  auto c_ptr = reinterpret_cast<ElementC*>(const_cast<void*>(params.bias));
+  auto d_ptr = reinterpret_cast<ElementD*>(params.D);
+
+  ProblemShapeType problem_size = ProblemShapeType{params.M, params.N, params.K, 1};
+  typename Gemm::Arguments arguments{
+      cutlass::gemm::GemmUniversalMode::kGemm,
+      problem_size,
+      {a_ptr, stride_A, b_ptr, stride_B},
+      {{},  // epilogue.thread
+       c_ptr, stride_C, d_ptr, stride_D}
+  };
+
+  arguments.epilogue.thread.alpha = params.scale;
+  // Only accumulate the C operand (bias) when a bias pointer was provided.
+  arguments.epilogue.thread.beta = params.bias ? 1.0f : 0.0f;
+
+  Gemm gemm_op;
+
+  size_t workspace_size = Gemm::get_workspace_size(arguments);
+  phi::Allocator* allocator = paddle::GetAllocator(params.place);
+  auto workspace = allocator->Allocate(workspace_size);
+
+  cutlass::Status status = gemm_op.can_implement(arguments);
+  if (status != cutlass::Status::kSuccess) {
+    std::cerr << "Gemm::can_implement() failed" << std::endl;
+    return false;
+  }
+  status = gemm_op(arguments, workspace->ptr(), params.stream);
+  if (status != cutlass::Status::kSuccess) {
+    std::cerr << "Gemm::run() failed" << std::endl;
+    return false;
+  }
+  return true;
+}
diff --git a/csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_gemm_sm90.cu b/csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_gemm_sm90.cu
new file mode 100644
index 000000000000..09f5702dd0c2
--- /dev/null
+++ b/csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_gemm_sm90.cu
@@ -0,0 +1,255 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <iostream>
+
+#include "fp8_gemm_fused/fuse_gemm_noact_template_3x.h"
+#include "fp8_common.h"
+
+std::vector<paddle::Tensor> cutlass_fp8_fp8_half_gemm_sm90(
+    const paddle::Tensor& x,
+    const paddle::Tensor& y,
+    const paddle::optional<paddle::Tensor>& bias,
+    bool trans_x,
+    bool trans_y,
+    float scale,  // only supports per-tensor quantization
+    std::string output_dtype,
+    std::string activation_type) {
+  paddle::Tensor out;
+  void* out_ptr = nullptr;
+  const void* x_ptr = nullptr;
+  const void* y_ptr = nullptr;
+
+  auto place = x.place();
+  cudaStream_t stream = x.stream();
+  int64_t device_id = place.GetDeviceId();
+  int sm_version = GetGPUComputeCapability(device_id);
+
+  int rank = x.dims().size();
+  int M = 0;
+  int K = 0;
+  int N = 0;
+  int ldd = 0;
+
+  int lda = x.dims()[rank - 1];
+  int ldb = y.dims()[rank - 1];
+
+  if (!trans_x) {
+    M = x.dims()[rank - 2];
+    K = x.dims()[rank - 1];
+  } else {
+    M = x.dims()[rank - 1];
+    K = x.dims()[rank - 2];
+  }
+  if (!trans_y) {
+    N = y.dims()[rank - 1];
+    ldd = y.dims()[rank - 1];
+  } else {
+    N = y.dims()[rank - 2];
+    ldd = y.dims()[rank - 2];
+  }
+
+  int batch_count = 1;
+  for (int i = 0; i < rank - 2; ++i) {
+    batch_count *= x.dims()[i];
+  }
+
+  std::string input_dtype = "";
+  if (x.dtype() == phi::DataType::FLOAT8_E4M3FN) {
+    input_dtype = "float8_e4m3fn";
+    x_ptr = reinterpret_cast<const void*>(x.data<phi::dtype::float8_e4m3fn>());
+    y_ptr = reinterpret_cast<const void*>(y.data<phi::dtype::float8_e4m3fn>());
+  } else if (x.dtype() == phi::DataType::FLOAT8_E5M2) {
+    input_dtype = "float8_e5m2";
+    x_ptr = reinterpret_cast<const void*>(x.data<phi::dtype::float8_e5m2>());
+    y_ptr = reinterpret_cast<const void*>(y.data<phi::dtype::float8_e5m2>());
+  } else {
+    PADDLE_THROW(phi::errors::Fatal(
+        "fp8_fp8_half_gemm_fused_sm90 only supports e4m3 and e5m2 input"));
+  }
+
+  std::vector<int64_t> out_shape = x.shape();
+  out_shape[rank - 1] = N;
+  out_shape[rank - 2] = M;
+
+  if (output_dtype == "bfloat16") {
+    out = paddle::empty(out_shape, paddle::DataType::BFLOAT16, x.place());
+    out_ptr = reinterpret_cast<void*>(out.data<phi::dtype::bfloat16>());
+  } else if (output_dtype == "float16") {
+    out = paddle::empty(out_shape, paddle::DataType::FLOAT16, x.place());
+    out_ptr = reinterpret_cast<void*>(out.data<phi::dtype::float16>());
+  } else {
+    PADDLE_THROW(phi::errors::Fatal(
+        "fp8_fp8_half_gemm_fused only supports bfloat16 and float16 output"));
+  }
+
+  std::string isbias = bias ? "true" : "false";
+  std::string act = (activation_type == "" || activation_type == "identity")
"noact" + : activation_type; + + std::string fuse_gemm_config = + input_dtype + "_" + output_dtype + "_" + isbias + "_" + act; + + void* bias_data = nullptr; + std::vector bias_dims{}; + if (bias) { + bias_dims = common::vectorize(bias.get().dims()); + if (output_dtype == "bfloat16") { + bias_data = reinterpret_cast(const_cast( + bias.get().data())); + } else { + bias_data = reinterpret_cast(const_cast( + bias.get().data())); + } + } + + GemmEpilogueAllParams params = { + x_ptr, + y_ptr, + out_ptr, + scale, + M, + N, + K, + lda, + ldb, + ldd, + batch_count, + place, + stream, + sm_version, + 0.01, // for leaky_relu + bias_data, + bias_dims, + fuse_gemm_config}; + + if (x.dtype() == phi::DataType::FLOAT8_E4M3FN){ + if (output_dtype == "float16") { + dispatch_fuse_gemm_noact_sm90, + cute::Shape, + cutlass::gemm::collective::KernelScheduleAuto, + cutlass::epilogue::collective::EpilogueScheduleAuto + >(params); + dispatch_fuse_gemm_noact_sm90, + cute::Shape, + cutlass::gemm::collective::KernelScheduleAuto, + cutlass::epilogue::collective::EpilogueScheduleAuto + >(params); + dispatch_fuse_gemm_noact_sm90, + cute::Shape, + cutlass::gemm::collective::KernelScheduleAuto, + cutlass::epilogue::collective::EpilogueScheduleAuto + >(params); + + dispatch_fuse_gemm_noact_sm90, + cute::Shape, + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum, + cutlass::epilogue::TmaWarpSpecialized + >(params); + dispatch_fuse_gemm_noact_sm90, + cute::Shape, + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum, + cutlass::epilogue::TmaWarpSpecialized + >(params); + dispatch_fuse_gemm_noact_sm90, + cute::Shape, + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum, + cutlass::epilogue::TmaWarpSpecialized + >(params); + } + } + // else{ + // if (output_dtype == "bfloat16") { + // dispatch_fuse_gemm_noact_3x(params); + // }else{ + // dispatch_fuse_gemm_noact_3x(params); + // } + // } + + return {out}; +} + +std::vector> CutlassFp8Fp8HalfGemmSm90FusedInferShape( + const std::vector& x_shape, + const std::vector& y_shape, + const paddle::optional>& bias_shape, + bool trans_x, + bool trans_y){ + PADDLE_ENFORCE_EQ(x_shape.size(), + y_shape.size(), + phi::errors::InvalidArgument( + "The rank of input X and Y should be equal, but received X's rank is %d, Y's rank is %d.", + x_shape.size(), + y_shape.size())); + + int rank = x_shape.size(); + int M = 0; + int N = 0; + + if (!trans_x) { + M = x_shape[rank - 2]; + + } else { + M = x_shape[rank - 1]; + } + if (!trans_y) { + N = y_shape[rank - 1]; + } else { + N = y_shape[rank - 2]; + } + std::vector out_shape = x_shape; + out_shape[rank - 1] = N; + out_shape[rank - 2] = M; + return {out_shape}; +} + +std::vector CutlassFp8Fp8HalfGemmSm90FusedInferDtype( + const paddle::DataType& x_type, + const paddle::DataType& y_type, + const paddle::optional& bias_type, + bool trans_x, + bool trans_y, + float scale, // only support per-tensor quantization + std::string output_dtype) { + paddle::DataType data_type; + if (output_dtype == "bfloat16") + data_type = paddle::DataType::BFLOAT16; + else if (output_dtype == "float16") + data_type = paddle::DataType::FLOAT16; + else + PD_THROW( + "fp8_fp8_half_gemm_fused_sm90 only support bfloat16 and float16 output"); + return {data_type}; +} + +PD_BUILD_OP(cutlass_fp8_fp8_half_gemm_sm90_fused) + .Inputs({"x", "y", paddle::Optional("bias")}) + .Attrs({"transpose_x: bool", + "transpose_y: bool", + "scale: float", + "output_dtype: std::string", + "act: std::string"}) + .Outputs({"out"}) + 
+    .SetKernelFn(PD_KERNEL(cutlass_fp8_fp8_half_gemm_sm90))
+    .SetInferShapeFn(PD_INFER_SHAPE(CutlassFp8Fp8HalfGemmSm90FusedInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(CutlassFp8Fp8HalfGemmSm90FusedInferDtype));
diff --git a/csrc/gpu/fp8_gemm_with_cutlass/generic_gemm_kernel_noact_3x.cu b/csrc/gpu/fp8_gemm_with_cutlass/generic_gemm_kernel_noact_3x.cu
new file mode 100644
index 000000000000..04f38e177fad
--- /dev/null
+++ b/csrc/gpu/fp8_gemm_with_cutlass/generic_gemm_kernel_noact_3x.cu
@@ -0,0 +1,66 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fp8_gemm_fused/fuse_gemm_noact_template_3x.h"
+
+// Explicit instantiations of the SM90 no-activation GEMM dispatcher, kept in
+// a separate translation unit; the tile/cluster shapes mirror the ones used
+// in fp8_fp8_half_gemm_sm90.cu.
+template bool dispatch_fuse_gemm_noact_sm90<phi::dtype::float8_e4m3fn,
+    phi::dtype::float16,
+    cute::Shape<cute::_64, cute::_64, cute::_128>,
+    cute::Shape<cute::_1, cute::_1, cute::_1>,
+    cutlass::gemm::collective::KernelScheduleAuto,
+    cutlass::epilogue::collective::EpilogueScheduleAuto
+    >(GemmEpilogueAllParams);
+
+template bool dispatch_fuse_gemm_noact_sm90<phi::dtype::float8_e4m3fn,
+    phi::dtype::float16,
+    cute::Shape<cute::_64, cute::_128, cute::_128>,
+    cute::Shape<cute::_2, cute::_1, cute::_1>,
+    cutlass::gemm::collective::KernelScheduleAuto,
+    cutlass::epilogue::collective::EpilogueScheduleAuto
+    >(GemmEpilogueAllParams);
+
+template bool dispatch_fuse_gemm_noact_sm90<phi::dtype::float8_e4m3fn,
+    phi::dtype::float16,
+    cute::Shape<cute::_128, cute::_128, cute::_128>,
+    cute::Shape<cute::_1, cute::_2, cute::_1>,
+    cutlass::gemm::collective::KernelScheduleAuto,
+    cutlass::epilogue::collective::EpilogueScheduleAuto
+    >(GemmEpilogueAllParams);
+
+template bool dispatch_fuse_gemm_noact_sm90<phi::dtype::float8_e4m3fn,
+    phi::dtype::float16,
+    cute::Shape<cute::_64, cute::_128, cute::_128>,
+    cute::Shape<cute::_1, cute::_1, cute::_1>,
+    cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum,
+    cutlass::epilogue::TmaWarpSpecialized
+    >(GemmEpilogueAllParams);
+
+template bool dispatch_fuse_gemm_noact_sm90<phi::dtype::float8_e4m3fn,
+    phi::dtype::float16,
+    cute::Shape<cute::_64, cute::_128, cute::_128>,
+    cute::Shape<cute::_2, cute::_1, cute::_1>,
+    cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum,
+    cutlass::epilogue::TmaWarpSpecialized
+    >(GemmEpilogueAllParams);
+
+template bool dispatch_fuse_gemm_noact_sm90<phi::dtype::float8_e4m3fn,
+    phi::dtype::float16,
+    cute::Shape<cute::_64, cute::_128, cute::_128>,
+    cute::Shape<cute::_1, cute::_2, cute::_1>,
+    cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum,
+    cutlass::epilogue::TmaWarpSpecialized
+    >(GemmEpilogueAllParams);
\ No newline at end of file
diff --git a/csrc/setup_cuda.py b/csrc/setup_cuda.py
index 3978f8c17de8..7272c5de13a7 100644
--- a/csrc/setup_cuda.py
+++ b/csrc/setup_cuda.py
@@ -66,6 +66,8 @@ def get_gencode_flags():
     if not strtobool(os.getenv("FLAG_LLM_PDC", "False")):
         prop = paddle.device.cuda.get_device_properties()
         cc = prop.major * 10 + prop.minor
+        if cc == 90:
+            cc = f"{cc}a"
         return ["-gencode", "arch=compute_{0},code=sm_{0}".format(cc)]
     else:
         # support more cuda archs
@@ -123,7 +125,7 @@ def get_gencode_flags():
 if not os.path.exists(cutlass_dir) or not os.listdir(cutlass_dir):
     if not os.path.exists(cutlass_dir):
         os.makedirs(cutlass_dir)
-    clone_git_repo("v3.5.0", "https://github.com/NVIDIA/cutlass.git", cutlass_dir)
+    clone_git_repo("v3.5.1", "https://github.com/NVIDIA/cutlass.git", cutlass_dir)
 
 json_dir = "third_party/nlohmann_json"
 if not os.path.exists(json_dir) or not os.listdir(json_dir):
@@ -153,15 +155,22 @@ def get_gencode_flags():
 if cc >= 80:
     sources += ["gpu/int8_gemm_with_cutlass/gemm_dequant.cu"]
 
-if cc >= 89 and cuda_version >= 12.4:
-    os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py")
-    os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")
+if cc == 89 and cuda_version == 12.4:
+    os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py --cuda_arch 89")
+    os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py --cuda_arch 89")
     sources += find_end_files("gpu/cutlass_kernels/fp8_gemm_fused/autogen", ".cu")
     sources += [
         "gpu/fp8_gemm_with_cutlass/fp8_fp8_half_gemm.cu",
        "gpu/fp8_gemm_with_cutlass/fp8_fp8_fp8_dual_gemm.cu",
     ]
 
+if cc >= 90 and cuda_version >= 12.0:
+    sources += [
+        "gpu/fp8_gemm_with_cutlass/fp8_fp8_half_gemm_sm90.cu",
+        "gpu/fp8_gemm_with_cutlass/generic_gemm_kernel_noact_3x.cu",
+    ]
+
 setup(
     name="paddlenlp_ops",
     ext_modules=CUDAExtension(
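
For reference, a minimal sketch of how the newly registered op might be invoked from Python once the extension is built. The `paddlenlp_ops` import path comes from the package name in setup_cuda.py; the float8 casts and the positional argument order (inputs first, then attributes in registration order) are assumptions based on Paddle's custom-op convention and a Paddle build with float8_e4m3fn support, not something this patch specifies.

    # Hypothetical usage sketch -- not part of the patch.
    import paddle
    from paddlenlp_ops import cutlass_fp8_fp8_half_gemm_sm90_fused

    M, K, N = 256, 512, 1024
    # Per-tensor-quantized FP8 inputs; y is stored as [N, K] and passed with
    # transpose_y=True to match the kernel's RowMajor-A / ColumnMajor-B layout.
    x = paddle.rand([M, K], dtype="float32").astype("float8_e4m3fn")
    y = paddle.rand([N, K], dtype="float32").astype("float8_e4m3fn")
    bias = paddle.rand([N], dtype="float32").astype("float16")

    out = cutlass_fp8_fp8_half_gemm_sm90_fused(
        x, y, bias,
        False,       # transpose_x
        True,        # transpose_y
        0.5,         # scale: per-tensor dequantization factor applied as alpha
        "float16",   # output_dtype
        "identity",  # act: mapped to the "noact" epilogue
    )
    print(out.shape)  # [M, N]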