Skip to content

Commit

Permalink
Merge branch 'PaddlePaddle:develop' into open_vpp_test
Browse files Browse the repository at this point in the history
  • Loading branch information
AndSonder authored Nov 25, 2024
2 parents 1f2af2d + 8fd33a9 commit 3fd9551
Show file tree
Hide file tree
Showing 49 changed files with 3,754 additions and 191 deletions.
13 changes: 1 addition & 12 deletions csrc/cpu/src/stop_generation_multi_ends.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,9 @@
#include <stdlib.h>
#include <string.h>

#include "paddle/extension.h"
#include "helper.h"
#include <stdio.h>


// Returns true iff `id` matches one of the first `length` terminator ids
// in `end_ids` (linear scan; `length` may be 0).
bool is_in_end(const int64_t id, const int64_t* end_ids, int length) {
  for (int i = 0; i < length; i++) {
    if (id == end_ids[i]) {
      return true;
    }
  }
  // The former `flag` local was always false by this point; return directly.
  return false;
}

void set_value_by_flags(const bool* stop_flags,
const int64_t* end_ids,
int64_t* topk_ids,
Expand Down
71 changes: 45 additions & 26 deletions csrc/gpu/get_output.cc
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//
// http://www.apache.org/licenses/LICENSE-2.0
//
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
Expand All @@ -17,53 +17,72 @@
#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>

#include "paddle/extension.h"

#define MAX_BSZ 512
#define SPECULATE_MAX_BSZ 256
#define MAX_DRAFT_TOKENS 6

// IPC buffer for System V msgrcv: mandatory `long` type header followed by
// a fixed-size int payload.
//   Normal decoding payload:      stop_flag, bsz, tokens...
//   Speculative decoding payload: stop_flag, bsz, accept_num*bsz, tokens...
// (Replaces the superseded non-template `msgdata` whose residue lines were
// interleaved here by the diff rendering.)
template <int SIZE>
struct MsgData {
  long mtype;                   // System V message type; must be > 0 when sent
  std::array<int, SIZE> mtext;  // contiguous payload read/written by msgrcv/msgsnd
};

void GetOutput(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag) {
template <int SIZE>
void GetOutputFunc(MsgData<SIZE>& msg_rcv, // NOLINT
const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag) {
if (rank_id > 0) return;

static struct msgdata msg_rcv;

static key_t key = ftok("./", 1);

static int msgid = msgget(key, IPC_CREAT | 0666);

int64_t *out_data = const_cast<int64_t*>(x.data<int64_t>());
int ret = -1;
if (!wait_flag) {
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, IPC_NOWAIT);
} else {
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, 0);
}
if(ret == -1)
{
ret = msgrcv(
msgid, &msg_rcv, SIZE * sizeof(int), 0, wait_flag ? 0 : IPC_NOWAIT);

int64_t* out_data = const_cast<int64_t*>(x.data<int64_t>());

if (ret == -1) {
// read none
out_data[0] = -2;
out_data[1] = 0;
return;
}

int bsz = msg_rcv.mtext[1];
return;
}

for (int64_t i = 0; i < bsz + 2; i++) {
for (int64_t i = 0; i < SIZE; i++) {
out_data[i] = (int64_t)msg_rcv.mtext[i];
}

return;
}

// Dispatch to the templated receiver with the payload size that matches the
// decoding mode. Each branch keeps its own static buffer so the storage is
// allocated once per mode and reused across calls.
void GetOutput(const paddle::Tensor& x,
               int64_t rank_id,
               bool wait_flag,
               bool speculative_decoding) {
  if (speculative_decoding) {
    // stop_flag, bsz, accept_num*bsz, tokens...
    constexpr int kSpecuSize =
        SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2;
    static MsgData<kSpecuSize> specu_msg_rcv;
    GetOutputFunc<kSpecuSize>(specu_msg_rcv, x, rank_id, wait_flag);
  } else {
    // stop_flag, bsz, tokens...
    constexpr int kPlainSize = MAX_BSZ + 2;
    static MsgData<kPlainSize> msg_rcv;
    GetOutputFunc<kPlainSize>(msg_rcv, x, rank_id, wait_flag);
  }
}

// Operator registration for `get_output`: reads one queue message into `x`
// in place. (Removes the diff residue: the pre-change attrs line that lacked
// `speculative_decoding` and the duplicated SetKernelFn line.)
PD_BUILD_OP(get_output)
    .Inputs({"x"})
    .Attrs({"rank_id: int64_t",
            "wait_flag: bool",
            "speculative_decoding: bool"})
    .Outputs({"x_out"})
    .SetInplaceMap({{"x", "x_out"}})
    .SetKernelFn(PD_KERNEL(GetOutput));
72 changes: 56 additions & 16 deletions csrc/gpu/get_padding_offset_v2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ __global__ void GetPaddingOffsetV2Kernel(int *padding_offset,
const int64_t *input_data,
const int *cum_offsets,
const int *seq_lens,
const int64_t *draft_tokens,
const int *seq_lens_encoder,
const int max_draft_tokens,
const int max_seq_len) {
// get padding offset of each batch
const int bi = blockIdx.x;
Expand All @@ -31,8 +34,18 @@ __global__ void GetPaddingOffsetV2Kernel(int *padding_offset,
for (int i = ti; i < seq_lens[bi]; i += blockDim.x) {
padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset;
const int tgt_seq_id = bi * max_seq_len - cum_offset + i;
const int src_seq_id = bi * max_seq_len + i;
output_data[tgt_seq_id] = input_data[src_seq_id];
if (draft_tokens == nullptr) {
const int src_seq_id = bi * max_seq_len + i;
output_data[tgt_seq_id] = input_data[src_seq_id];
} else { // speculative decoding
if (seq_lens_encoder[bi] > 0) {
const int src_seq_id = bi * max_seq_len + i;
output_data[tgt_seq_id] = input_data[src_seq_id];
} else {
const int src_seq_id = bi * max_draft_tokens + i;
output_data[tgt_seq_id] = draft_tokens[src_seq_id];
}
}
}
if (ti == 0) {
if (bi == 0) {
Expand All @@ -50,7 +63,9 @@ __global__ void GetPaddingOffsetV2Kernel(int *padding_offset,
std::vector<paddle::Tensor> GetPaddingOffsetV2(const paddle::Tensor& input_ids,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& token_num,
const paddle::Tensor& seq_len) {
const paddle::Tensor& seq_len,
const paddle::optional<paddle::Tensor>& draft_tokens,
const paddle::optional<paddle::Tensor>& seq_lens_encoder) {
auto cu_stream = input_ids.stream();
std::vector<int64_t> input_ids_shape = input_ids.shape();
const int bsz = seq_len.shape()[0];
Expand All @@ -65,23 +80,46 @@ std::vector<paddle::Tensor> GetPaddingOffsetV2(const paddle::Tensor& input_ids,
auto cu_seqlens_q = GetEmptyTensor({bsz + 1}, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_k = GetEmptyTensor({bsz + 1}, paddle::DataType::INT32, input_ids.place());

GetPaddingOffsetV2Kernel<<<bsz, 128, 0, cu_stream>>>(
padding_offset.data<int>(),
cum_offsets_out.data<int>(),
cu_seqlens_q.data<int>(),
cu_seqlens_k.data<int>(),
x_remove_padding.data<int64_t>(),
input_ids.data<int64_t>(),
cum_offsets.data<int>(),
seq_len.data<int>(),
seq_length);
int max_draft_tokens = 0;
if (draft_tokens) { // speculative decoding
max_draft_tokens = draft_tokens.get().shape()[1];
GetPaddingOffsetV2Kernel<<<bsz, 128, 0, cu_stream>>>(
padding_offset.data<int>(),
cum_offsets_out.data<int>(),
cu_seqlens_q.data<int>(),
cu_seqlens_k.data<int>(),
x_remove_padding.data<int64_t>(),
input_ids.data<int64_t>(),
cum_offsets.data<int>(),
seq_len.data<int>(),
draft_tokens.get_ptr()->data<int64_t>(),
seq_lens_encoder.get_ptr()->data<int>(),
max_draft_tokens,
seq_length);
} else {
GetPaddingOffsetV2Kernel<<<bsz, 128, 0, cu_stream>>>(
padding_offset.data<int>(),
cum_offsets_out.data<int>(),
cu_seqlens_q.data<int>(),
cu_seqlens_k.data<int>(),
x_remove_padding.data<int64_t>(),
input_ids.data<int64_t>(),
cum_offsets.data<int>(),
seq_len.data<int>(),
nullptr,
nullptr,
max_draft_tokens,
seq_length);
}
return {x_remove_padding, cum_offsets_out, padding_offset, cu_seqlens_q, cu_seqlens_k}; // , enc_token_num, dec_token_num};
}

// Static shape inference for get_padding_offset_v2.
// Returns shapes for: x_remove_padding (dynamic), cum_offsets_out (bsz),
// padding_offset (dynamic), cu_seqlens_q (bsz+1), cu_seqlens_k (bsz+1).
// Only seq_len_shape is consulted; the other parameters exist to mirror the
// operator's input list (including the optional speculative-decoding inputs).
std::vector<std::vector<int64_t>> GetPaddingOffsetV2InferShape(
    const std::vector<int64_t>& input_ids_shape,
    const std::vector<int64_t>& cum_offsets_shape,
    const std::vector<int64_t>& token_num_shape,
    const std::vector<int64_t>& seq_len_shape,
    const std::vector<int64_t>& draft_tokens_shape,
    const std::vector<int64_t>& seq_lens_encoder_shape) {
  const int64_t bsz = seq_len_shape[0];
  // NOTE: the unused `seq_len` local (input_ids_shape[1]) was dropped; it was
  // never read and would index out of bounds for rank-1 input shapes.
  return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}};
}
Expand All @@ -90,12 +128,14 @@ std::vector<std::vector<int64_t>> GetPaddingOffsetV2InferShape(const std::vector
// Static dtype inference for get_padding_offset_v2.
// x_remove_padding keeps the input_ids dtype (int64); the four offset/len
// outputs share seq_len's dtype (int32). The remaining parameters mirror the
// operator's input list and are intentionally unused.
// (Removes the diff residue: the pre-change 4-parameter signature line.)
std::vector<paddle::DataType> GetPaddingOffsetV2InferDtype(
    const paddle::DataType& input_ids_dtype,
    const paddle::DataType& cum_offsets_dtype,
    const paddle::DataType& token_num_dtype,
    const paddle::DataType& seq_len_dtype,
    const paddle::DataType& draft_tokens_dtype,
    const paddle::DataType& seq_lens_encoder_dtype) {
  return {input_ids_dtype,
          seq_len_dtype,
          seq_len_dtype,
          seq_len_dtype,
          seq_len_dtype};
}

PD_BUILD_OP(get_padding_offset_v2)
.Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"})
.Inputs({"input_ids", "cum_offsets", "token_num", "seq_len", paddle::Optional("draft_tokens"), paddle::Optional("seq_lens_encoder"),})
.Outputs({"x_remove_padding", "cum_offsets_out", "padding_offset", "cu_seqlens_q", "cu_seqlens_k"})
.SetKernelFn(PD_KERNEL(GetPaddingOffsetV2))
.SetInferShapeFn(PD_INFER_SHAPE(GetPaddingOffsetV2InferShape))
Expand Down
10 changes: 10 additions & 0 deletions csrc/gpu/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,13 @@ inline paddle::Tensor GetEmptyTensor(const common::DDim& dims, const paddle::Dat
dense_tensor.AllocateFrom(allocator, dtype, dense_tensor.numel() * phi::SizeOf(dtype));
return paddle::Tensor(std::make_shared<phi::DenseTensor>(dense_tensor));
}

// Device-side membership test: true iff `id` equals one of the first
// `length` entries of `end_ids` (the end-of-generation token table).
__device__ inline bool is_in_end(const int64_t id, const int64_t *end_ids, int length) {
  for (int i = 0; i < length; i++) {
    if (id == end_ids[i]) {
      return true;
    }
  }
  // The former `flag` local could never be true here; return false directly.
  return false;
}
Loading

0 comments on commit 3fd9551

Please sign in to comment.