From 6c75c6da6231ae7a95b221fa29dbcad06ac42a2c Mon Sep 17 00:00:00 2001
From: tianhaodongbd
Date: Tue, 26 Nov 2024 17:01:49 +0800
Subject: [PATCH] Adapting npu for FusedHeadAndCrossEntropy

---
 .../transformers/tensor_parallel_utils.py     | 291 ++++++++++++------
 1 file changed, 200 insertions(+), 91 deletions(-)

diff --git a/paddlenlp/transformers/tensor_parallel_utils.py b/paddlenlp/transformers/tensor_parallel_utils.py
index 679345ff78c1..a391fa9ea9c0 100644
--- a/paddlenlp/transformers/tensor_parallel_utils.py
+++ b/paddlenlp/transformers/tensor_parallel_utils.py
@@ -243,107 +243,216 @@ def forward(
         else:
             grad_hidden_states = None
 
-        # initialize outputs
-        token_loss = paddle.empty((n_tokens,), dtype=hidden_states.dtype)
-
-        # blockwise calculations
-        for i in range(0, n_tokens, loop_chunk_size):
-            token_start_idx = i
-            token_end_idx = min(i + loop_chunk_size, n_tokens)
-            hidden_states_chunk = hidden_states[token_start_idx:token_end_idx]
-            labels_chunk = labels[token_start_idx:token_end_idx]
-
-            # logits calculations
-            logits_chunk_cast = paddle.matmul(
-                hidden_states_chunk,
-                lm_head_weight_cast,
-                transpose_y=transpose_y,
-            )
-            if lm_head_bias is not None:
-                logits_chunk_cast += lm_head_bias_cast
-            if tensor_parallel_degree > 1 and not tensor_parallel_output:
-                logits_chunk_cast_lst = []
-                dist.all_gather(
-                    logits_chunk_cast_lst,
-                    logits_chunk_cast,
-                    group=model_parallel_group,
-                )
-                logits_chunk_cast = paddle.concat(logits_chunk_cast_lst, axis=-1)
-            logits_chunk = logits_chunk_cast.astype("float32")
-
-            # log softmax
-            max_logits = paddle.max(logits_chunk, axis=-1, keepdim=True)
-            if tensor_parallel_degree > 1 and tensor_parallel_output:
-                dist.all_reduce(max_logits, op=dist.ReduceOp.MAX, group=model_parallel_group)
-            normalized_logits = logits_chunk - max_logits
-            exp_logits = paddle.exp(normalized_logits)
-            sum_exp_logits = paddle.sum(exp_logits, axis=-1, keepdim=True)
-            if tensor_parallel_degree > 1 and tensor_parallel_output:
-                dist.all_reduce(
-                    sum_exp_logits,
-                    op=dist.ReduceOp.SUM,
-                    group=model_parallel_group,
-                )
-            log_sum_exp_logits = paddle.log(sum_exp_logits)
-
-            # cross entropy
-            labels_one_hot = labels_chunk.unsqueeze(1) == indices
-            label_logits = paddle.sum(
-                paddle.where(
-                    labels_one_hot,
-                    normalized_logits,
-                    paddle.zeros_like(normalized_logits),
-                ),
-                axis=-1,
-                keepdim=True,
-            )
-            if tensor_parallel_degree > 1 and tensor_parallel_output:
-                dist.all_reduce(
-                    label_logits,
-                    op=dist.ReduceOp.SUM,
-                    group=model_parallel_group,
-                )
-            token_loss_chunk = (log_sum_exp_logits - label_logits).squeeze(1) / divisor
-            cond = loss_mask[token_start_idx:token_end_idx].astype("bool")
-            token_loss_chunk = paddle.where(cond, token_loss_chunk, paddle.zeros_like(token_loss_chunk))
-            token_loss[token_start_idx:token_end_idx] = token_loss_chunk * loss_mask[token_start_idx:token_end_idx]
-
-            # gradients calculations
-            if not return_token_loss:
-                if tensor_parallel_degree > 1 and not tensor_parallel_output:
-                    exp_logits = exp_logits.split(model_parallel_group.nranks, axis=-1)[model_parallel_group.rank]
-                    labels_one_hot = labels_one_hot.split(model_parallel_group.nranks, axis=-1)[
-                        model_parallel_group.rank
-                    ]
-                grad_logits_chunk = (exp_logits / sum_exp_logits - labels_one_hot.astype("float32")) / divisor
-                grad_logits_chunk = grad_logits_chunk.astype(dtype)
-                grad_logits_chunk = paddle.where(
-                    cond.unsqueeze(1),
-                    grad_logits_chunk,
-                    paddle.zeros_like(grad_logits_chunk),
-                )
-
-                if grad_hidden_states is not None:
-                    grad_hidden_states[token_start_idx:token_end_idx] = paddle.matmul(
-                        grad_logits_chunk,
-                        lm_head_weight_cast,
-                        transpose_y=not transpose_y,
-                    )
-                if grad_lm_head_weight is not None:
-                    if transpose_y:
-                        grad_lm_head_weight += paddle.matmul(
-                            grad_logits_chunk,
-                            hidden_states_chunk,
-                            transpose_x=True,
-                        )
-                    else:
-                        grad_lm_head_weight += paddle.matmul(
-                            hidden_states_chunk,
-                            grad_logits_chunk,
-                            transpose_x=True,
-                        )
-                if grad_lm_head_bias is not None:
-                    grad_lm_head_bias += grad_logits_chunk.astype("float32").sum(axis=0).astype(dtype)
+        if get_env_device() == "npu":
+            token_loss_list = []
+            grad_hidden_states_list = []
+
+            token_idx_section = [loop_chunk_size for _ in range(0, n_tokens, loop_chunk_size)]
+            token_idx_section[-1] = -1
+            hidden_states_chunk = hidden_states.split(token_idx_section, axis=0)
+            labels_chunk = labels.split(token_idx_section, axis=0)
+            loss_mask_chunk = loss_mask.split(token_idx_section, axis=0)
+
+            for i in range(len(token_idx_section)):
+                # logits calculations
+                logits_chunk_cast = paddle.matmul(
+                    hidden_states_chunk[i],
+                    lm_head_weight_cast,
+                    transpose_y=transpose_y,
+                )
+                if lm_head_bias is not None:
+                    logits_chunk_cast += lm_head_bias_cast
+                if tensor_parallel_degree > 1 and not tensor_parallel_output:
+                    logits_chunk_cast_lst = []
+                    dist.all_gather(
+                        logits_chunk_cast_lst,
+                        logits_chunk_cast,
+                        group=model_parallel_group,
+                    )
+                    logits_chunk_cast = paddle.concat(logits_chunk_cast_lst, axis=-1)
+                logits_chunk = logits_chunk_cast.astype("float32")
+
+                # log softmax
+                max_logits = paddle.max(logits_chunk, axis=-1, keepdim=True)
+                if tensor_parallel_degree > 1 and tensor_parallel_output:
+                    dist.all_reduce(max_logits, op=dist.ReduceOp.MAX, group=model_parallel_group)
+                normalized_logits = logits_chunk - max_logits
+                exp_logits = paddle.exp(normalized_logits)
+                sum_exp_logits = paddle.sum(exp_logits, axis=-1, keepdim=True)
+                if tensor_parallel_degree > 1 and tensor_parallel_output:
+                    dist.all_reduce(
+                        sum_exp_logits,
+                        op=dist.ReduceOp.SUM,
+                        group=model_parallel_group,
+                    )
+                log_sum_exp_logits = paddle.log(sum_exp_logits)
+
+                # cross entropy
+                labels_one_hot = labels_chunk[i].unsqueeze(1) == indices
+                label_logits = paddle.sum(
+                    paddle.where(
+                        labels_one_hot,
+                        normalized_logits,
+                        paddle.zeros_like(normalized_logits),
+                    ),
+                    axis=-1,
+                    keepdim=True,
+                )
+                if tensor_parallel_degree > 1 and tensor_parallel_output:
+                    dist.all_reduce(
+                        label_logits,
+                        op=dist.ReduceOp.SUM,
+                        group=model_parallel_group,
+                    )
+                token_loss_chunk = (log_sum_exp_logits - label_logits).squeeze(1) / divisor
+                cond = loss_mask_chunk[i].astype("bool")
+                token_loss_chunk = paddle.where(cond, token_loss_chunk, paddle.zeros_like(token_loss_chunk))
+                token_loss_list.append((token_loss_chunk * loss_mask_chunk[i]))
+
+                # gradients calculations
+                if not return_token_loss:
+                    if tensor_parallel_degree > 1 and not tensor_parallel_output:
+                        exp_logits = exp_logits.split(model_parallel_group.nranks, axis=-1)[model_parallel_group.rank]
+                        labels_one_hot = labels_one_hot.split(model_parallel_group.nranks, axis=-1)[
+                            model_parallel_group.rank
+                        ]
+                    grad_logits_chunk = (exp_logits / sum_exp_logits - labels_one_hot.astype("float32")) / divisor
+                    grad_logits_chunk = grad_logits_chunk.astype(dtype)
+                    grad_logits_chunk = paddle.where(
+                        cond.unsqueeze(1),
+                        grad_logits_chunk,
+                        paddle.zeros_like(grad_logits_chunk),
+                    )
+
+                    if grad_hidden_states is not None:
+                        grad_hidden_states_list.append(
+                            paddle.matmul(
+                                grad_logits_chunk,
+                                lm_head_weight_cast,
+                                transpose_y=not transpose_y,
+                            )
+                        )
+                    if grad_lm_head_weight is not None:
+                        if transpose_y:
+                            grad_lm_head_weight += paddle.matmul(
+                                grad_logits_chunk,
+                                hidden_states_chunk[i],
+                                transpose_x=True,
+                            )
+                        else:
+                            grad_lm_head_weight += paddle.matmul(
+                                hidden_states_chunk[i],
+                                grad_logits_chunk,
+                                transpose_x=True,
+                            )
+                    if grad_lm_head_bias is not None:
+                        grad_lm_head_bias += grad_logits_chunk.astype("float32").sum(axis=0).astype(dtype)
+
+            token_loss = paddle.concat(token_loss_list, axis=0)
+            if grad_hidden_states is not None:
+                grad_hidden_states = paddle.concat(grad_hidden_states_list, axis=0)
+        else:
+            # initialize outputs
+            token_loss = paddle.empty((n_tokens,), dtype=hidden_states.dtype)
+
+            # blockwise calculations
+            for i in range(0, n_tokens, loop_chunk_size):
+                token_start_idx = i
+                token_end_idx = min(i + loop_chunk_size, n_tokens)
+                hidden_states_chunk = hidden_states[token_start_idx:token_end_idx]
+                labels_chunk = labels[token_start_idx:token_end_idx]
+
+                # logits calculations
+                logits_chunk_cast = paddle.matmul(
+                    hidden_states_chunk,
+                    lm_head_weight_cast,
+                    transpose_y=transpose_y,
+                )
+                if lm_head_bias is not None:
+                    logits_chunk_cast += lm_head_bias_cast
+                if tensor_parallel_degree > 1 and not tensor_parallel_output:
+                    logits_chunk_cast_lst = []
+                    dist.all_gather(
+                        logits_chunk_cast_lst,
+                        logits_chunk_cast,
+                        group=model_parallel_group,
+                    )
+                    logits_chunk_cast = paddle.concat(logits_chunk_cast_lst, axis=-1)
+                logits_chunk = logits_chunk_cast.astype("float32")
+
+                # log softmax
+                max_logits = paddle.max(logits_chunk, axis=-1, keepdim=True)
+                if tensor_parallel_degree > 1 and tensor_parallel_output:
+                    dist.all_reduce(max_logits, op=dist.ReduceOp.MAX, group=model_parallel_group)
+                normalized_logits = logits_chunk - max_logits
+                exp_logits = paddle.exp(normalized_logits)
+                sum_exp_logits = paddle.sum(exp_logits, axis=-1, keepdim=True)
+                if tensor_parallel_degree > 1 and tensor_parallel_output:
+                    dist.all_reduce(
+                        sum_exp_logits,
+                        op=dist.ReduceOp.SUM,
+                        group=model_parallel_group,
+                    )
+                log_sum_exp_logits = paddle.log(sum_exp_logits)
+
+                # cross entropy
+                labels_one_hot = labels_chunk.unsqueeze(1) == indices
+                label_logits = paddle.sum(
+                    paddle.where(
+                        labels_one_hot,
+                        normalized_logits,
+                        paddle.zeros_like(normalized_logits),
+                    ),
+                    axis=-1,
+                    keepdim=True,
+                )
+                if tensor_parallel_degree > 1 and tensor_parallel_output:
+                    dist.all_reduce(
+                        label_logits,
+                        op=dist.ReduceOp.SUM,
+                        group=model_parallel_group,
+                    )
+                token_loss_chunk = (log_sum_exp_logits - label_logits).squeeze(1) / divisor
+                cond = loss_mask[token_start_idx:token_end_idx].astype("bool")
+                token_loss_chunk = paddle.where(cond, token_loss_chunk, paddle.zeros_like(token_loss_chunk))
+                token_loss[token_start_idx:token_end_idx] = token_loss_chunk * loss_mask[token_start_idx:token_end_idx]
+
+                # gradients calculations
+                if not return_token_loss:
+                    if tensor_parallel_degree > 1 and not tensor_parallel_output:
+                        exp_logits = exp_logits.split(model_parallel_group.nranks, axis=-1)[model_parallel_group.rank]
+                        labels_one_hot = labels_one_hot.split(model_parallel_group.nranks, axis=-1)[
+                            model_parallel_group.rank
+                        ]
+                    grad_logits_chunk = (exp_logits / sum_exp_logits - labels_one_hot.astype("float32")) / divisor
+                    grad_logits_chunk = grad_logits_chunk.astype(dtype)
+                    grad_logits_chunk = paddle.where(
+                        cond.unsqueeze(1),
+                        grad_logits_chunk,
+                        paddle.zeros_like(grad_logits_chunk),
+                    )
+
+                    if grad_hidden_states is not None:
+                        grad_hidden_states[token_start_idx:token_end_idx] = paddle.matmul(
+                            grad_logits_chunk,
+                            lm_head_weight_cast,
+                            transpose_y=not transpose_y,
+                        )
+                    if grad_lm_head_weight is not None:
+                        if transpose_y:
+                            grad_lm_head_weight += paddle.matmul(
+                                grad_logits_chunk,
+                                hidden_states_chunk,
+                                transpose_x=True,
+                            )
+                        else:
+                            grad_lm_head_weight += paddle.matmul(
+                                hidden_states_chunk,
+                                grad_logits_chunk,
+                                transpose_x=True,
+                            )
+                    if grad_lm_head_bias is not None:
+                        grad_lm_head_bias += grad_logits_chunk.astype("float32").sum(axis=0).astype(dtype)
 
         if return_token_loss:
             loss = token_loss.reshape(original_shape[:-1])
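
Note on the NPU path: instead of slicing each chunk out of hidden_states, labels, and loss_mask inside the loop, it splits them once up front and collects per-chunk results in Python lists. A minimal standalone sketch of that sectioning trick, assuming toy sizes and no tensor parallelism (this is an illustration, not the patched function itself):

    import paddle

    n_tokens, hidden_size, loop_chunk_size = 10, 8, 4
    hidden_states = paddle.randn([n_tokens, hidden_size])

    # One section per chunk; the last entry is -1 so split() infers the remainder,
    # which replaces the min(i + loop_chunk_size, n_tokens) bound of the slicing loop.
    token_idx_section = [loop_chunk_size for _ in range(0, n_tokens, loop_chunk_size)]
    token_idx_section[-1] = -1
    chunks = hidden_states.split(token_idx_section, axis=0)

    # Each chunk matches what the non-NPU branch would slice out.
    for i, start in enumerate(range(0, n_tokens, loop_chunk_size)):
        end = min(start + loop_chunk_size, n_tokens)
        assert paddle.allclose(chunks[i], hidden_states[start:end])

The per-chunk losses and hidden-state gradients are then reassembled with paddle.concat after the loop, rather than written into preallocated tensors by slice assignment as in the non-NPU branch.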
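
For reference, the quantity each iteration accumulates, (log_sum_exp_logits - label_logits) / divisor, is the standard per-token log-softmax cross entropy. A small single-process sketch, assuming toy shapes and ignoring the divisor, loss masking, and tensor parallelism, with paddle.take_along_axis standing in for the one-hot paddle.where gather used in the patch:

    import paddle
    import paddle.nn.functional as F

    logits = paddle.randn([6, 32], dtype="float32")
    labels = paddle.randint(0, 32, shape=[6])

    # Numerically stable log-softmax pieces, mirroring the chunked loop.
    max_logits = paddle.max(logits, axis=-1, keepdim=True)
    normalized_logits = logits - max_logits
    log_sum_exp_logits = paddle.log(paddle.sum(paddle.exp(normalized_logits), axis=-1, keepdim=True))
    label_logits = paddle.take_along_axis(normalized_logits, labels.unsqueeze(1), axis=1)
    token_loss = (log_sum_exp_logits - label_logits).squeeze(1)

    # Should agree with the framework's own per-token cross entropy.
    reference = F.cross_entropy(logits, labels, reduction="none").flatten()
    assert paddle.allclose(token_loss, reference, atol=1e-5)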