Adapting npu for FusedHeadAndCrossEntropy
tianhaodongbd committed Nov 26, 2024
1 parent d68a385 commit 6c75c6d
291 changes: 200 additions & 91 deletions paddlenlp/transformers/tensor_parallel_utils.py
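For orientation before the diff: FusedHeadAndCrossEntropy fuses the LM-head matmul with a blockwise softmax cross entropy so the full [n_tokens, vocab_size] logits tensor is never materialized at once. Below is a minimal single-device sketch of that idea, without tensor parallelism, loss masking, or the divisor bookkeeping; all names and shapes are illustrative, not the library's API.

import paddle


def chunked_lm_head_cross_entropy(hidden_states, lm_head_weight, labels, chunk_size=1024):
    """Per-token cross entropy computed one chunk of tokens at a time."""
    n_tokens = hidden_states.shape[0]
    token_loss = paddle.empty((n_tokens,), dtype="float32")
    for start in range(0, n_tokens, chunk_size):
        end = min(start + chunk_size, n_tokens)
        # [chunk, vocab] logits for this chunk only
        logits = paddle.matmul(hidden_states[start:end], lm_head_weight).astype("float32")
        # numerically stable log-softmax: logsumexp(z) = log(sum(exp(z - max))) + max
        max_logits = paddle.max(logits, axis=-1, keepdim=True)
        normalized = logits - max_logits
        log_sum_exp = paddle.log(paddle.sum(paddle.exp(normalized), axis=-1, keepdim=True))
        # pick out the (shifted) logit of the target label for each token
        label_logits = paddle.take_along_axis(normalized, labels[start:end].unsqueeze(1), axis=1)
        token_loss[start:end] = (log_sum_exp - label_logits).squeeze(1)
    return token_loss


# usage: hidden_states [n_tokens, hidden], lm_head_weight [hidden, vocab], labels [n_tokens]
hs = paddle.randn([10, 16], dtype="float32")
w = paddle.randn([16, 32], dtype="float32")
y = paddle.randint(0, 32, [10], dtype="int64")
print(chunked_lm_head_cross_entropy(hs, w, y, chunk_size=4).shape)  # [10]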
@@ -243,107 +243,216 @@ def forward(
        else:
            grad_hidden_states = None

        if get_env_device() == "npu":
            token_loss_list = []
            grad_hidden_states_list = []

            token_idx_section = [loop_chunk_size for _ in range(0, n_tokens, loop_chunk_size)]
            token_idx_section[-1] = -1
            hidden_states_chunk = hidden_states.split(token_idx_section, axis=0)
            labels_chunk = labels.split(token_idx_section, axis=0)
            loss_mask_chunk = loss_mask.split(token_idx_section, axis=0)

            for i in range(len(token_idx_section)):
                # logits calculations
                logits_chunk_cast = paddle.matmul(
                    hidden_states_chunk[i],
                    lm_head_weight_cast,
                    transpose_y=transpose_y,
                )
                if lm_head_bias is not None:
                    logits_chunk_cast += lm_head_bias_cast
                if tensor_parallel_degree > 1 and not tensor_parallel_output:
                    logits_chunk_cast_lst = []
                    dist.all_gather(
                        logits_chunk_cast_lst,
                        logits_chunk_cast,
                        group=model_parallel_group,
                    )
                    logits_chunk_cast = paddle.concat(logits_chunk_cast_lst, axis=-1)
                logits_chunk = logits_chunk_cast.astype("float32")

                # log softmax
                max_logits = paddle.max(logits_chunk, axis=-1, keepdim=True)
                if tensor_parallel_degree > 1 and tensor_parallel_output:
                    dist.all_reduce(max_logits, op=dist.ReduceOp.MAX, group=model_parallel_group)
                normalized_logits = logits_chunk - max_logits
                exp_logits = paddle.exp(normalized_logits)
                sum_exp_logits = paddle.sum(exp_logits, axis=-1, keepdim=True)
                if tensor_parallel_degree > 1 and tensor_parallel_output:
                    dist.all_reduce(
                        sum_exp_logits,
                        op=dist.ReduceOp.SUM,
                        group=model_parallel_group,
                    )
                log_sum_exp_logits = paddle.log(sum_exp_logits)

                # cross entropy
                labels_one_hot = labels_chunk[i].unsqueeze(1) == indices
                label_logits = paddle.sum(
                    paddle.where(
                        labels_one_hot,
                        normalized_logits,
                        paddle.zeros_like(normalized_logits),
                    ),
                    axis=-1,
                    keepdim=True,
                )
                if tensor_parallel_degree > 1 and tensor_parallel_output:
                    dist.all_reduce(
                        label_logits,
                        op=dist.ReduceOp.SUM,
                        group=model_parallel_group,
                    )
                token_loss_chunk = (log_sum_exp_logits - label_logits).squeeze(1) / divisor
                cond = loss_mask_chunk[i].astype("bool")
                token_loss_chunk = paddle.where(cond, token_loss_chunk, paddle.zeros_like(token_loss_chunk))
                token_loss_list.append(token_loss_chunk * loss_mask_chunk[i])

                # gradients calculations
                if not return_token_loss:
                    if tensor_parallel_degree > 1 and not tensor_parallel_output:
                        exp_logits = exp_logits.split(model_parallel_group.nranks, axis=-1)[model_parallel_group.rank]
                        labels_one_hot = labels_one_hot.split(model_parallel_group.nranks, axis=-1)[
                            model_parallel_group.rank
                        ]
                    grad_logits_chunk = (exp_logits / sum_exp_logits - labels_one_hot.astype("float32")) / divisor
                    grad_logits_chunk = grad_logits_chunk.astype(dtype)
                    grad_logits_chunk = paddle.where(
                        cond.unsqueeze(1),
                        grad_logits_chunk,
                        paddle.zeros_like(grad_logits_chunk),
                    )

                    if grad_hidden_states is not None:
                        grad_hidden_states_list.append(
                            paddle.matmul(
                                grad_logits_chunk,
                                lm_head_weight_cast,
                                transpose_y=not transpose_y,
                            )
                        )
                    if grad_lm_head_weight is not None:
                        if transpose_y:
                            grad_lm_head_weight += paddle.matmul(
                                grad_logits_chunk,
                                hidden_states_chunk[i],
                                transpose_x=True,
                            )
                        else:
                            grad_lm_head_weight += paddle.matmul(
                                hidden_states_chunk[i],
                                grad_logits_chunk,
                                transpose_x=True,
                            )
                    if grad_lm_head_bias is not None:
                        grad_lm_head_bias += grad_logits_chunk.astype("float32").sum(axis=0).astype(dtype)

            token_loss = paddle.concat(token_loss_list, axis=0)
            if grad_hidden_states is not None:
                grad_hidden_states = paddle.concat(grad_hidden_states_list, axis=0)
        else:
            # initialize outputs
            token_loss = paddle.empty((n_tokens,), dtype=hidden_states.dtype)

            # blockwise calculations
            for i in range(0, n_tokens, loop_chunk_size):
                token_start_idx = i
                token_end_idx = min(i + loop_chunk_size, n_tokens)
                hidden_states_chunk = hidden_states[token_start_idx:token_end_idx]
                labels_chunk = labels[token_start_idx:token_end_idx]

                # logits calculations
                logits_chunk_cast = paddle.matmul(
                    hidden_states_chunk,
                    lm_head_weight_cast,
                    transpose_y=transpose_y,
                )
                if lm_head_bias is not None:
                    logits_chunk_cast += lm_head_bias_cast
                if tensor_parallel_degree > 1 and not tensor_parallel_output:
                    logits_chunk_cast_lst = []
                    dist.all_gather(
                        logits_chunk_cast_lst,
                        logits_chunk_cast,
                        group=model_parallel_group,
                    )
                    logits_chunk_cast = paddle.concat(logits_chunk_cast_lst, axis=-1)
                logits_chunk = logits_chunk_cast.astype("float32")

                # log softmax
                max_logits = paddle.max(logits_chunk, axis=-1, keepdim=True)
                if tensor_parallel_degree > 1 and tensor_parallel_output:
                    dist.all_reduce(max_logits, op=dist.ReduceOp.MAX, group=model_parallel_group)
                normalized_logits = logits_chunk - max_logits
                exp_logits = paddle.exp(normalized_logits)
                sum_exp_logits = paddle.sum(exp_logits, axis=-1, keepdim=True)
                if tensor_parallel_degree > 1 and tensor_parallel_output:
                    dist.all_reduce(
                        sum_exp_logits,
                        op=dist.ReduceOp.SUM,
                        group=model_parallel_group,
                    )
                log_sum_exp_logits = paddle.log(sum_exp_logits)

                # cross entropy
                labels_one_hot = labels_chunk.unsqueeze(1) == indices
                label_logits = paddle.sum(
                    paddle.where(
                        labels_one_hot,
                        normalized_logits,
                        paddle.zeros_like(normalized_logits),
                    ),
                    axis=-1,
                    keepdim=True,
                )
                if tensor_parallel_degree > 1 and tensor_parallel_output:
                    dist.all_reduce(
                        label_logits,
                        op=dist.ReduceOp.SUM,
                        group=model_parallel_group,
                    )
                token_loss_chunk = (log_sum_exp_logits - label_logits).squeeze(1) / divisor
                cond = loss_mask[token_start_idx:token_end_idx].astype("bool")
                token_loss_chunk = paddle.where(cond, token_loss_chunk, paddle.zeros_like(token_loss_chunk))
                token_loss[token_start_idx:token_end_idx] = token_loss_chunk * loss_mask[token_start_idx:token_end_idx]

                # gradients calculations
                if not return_token_loss:
                    if tensor_parallel_degree > 1 and not tensor_parallel_output:
                        exp_logits = exp_logits.split(model_parallel_group.nranks, axis=-1)[model_parallel_group.rank]
                        labels_one_hot = labels_one_hot.split(model_parallel_group.nranks, axis=-1)[
                            model_parallel_group.rank
                        ]
                    grad_logits_chunk = (exp_logits / sum_exp_logits - labels_one_hot.astype("float32")) / divisor
                    grad_logits_chunk = grad_logits_chunk.astype(dtype)
                    grad_logits_chunk = paddle.where(
                        cond.unsqueeze(1),
                        grad_logits_chunk,
                        paddle.zeros_like(grad_logits_chunk),
                    )

                    if grad_hidden_states is not None:
                        grad_hidden_states[token_start_idx:token_end_idx] = paddle.matmul(
                            grad_logits_chunk,
                            lm_head_weight_cast,
                            transpose_y=not transpose_y,
                        )
                    if grad_lm_head_weight is not None:
                        if transpose_y:
                            grad_lm_head_weight += paddle.matmul(
                                grad_logits_chunk,
                                hidden_states_chunk,
                                transpose_x=True,
                            )
                        else:
                            grad_lm_head_weight += paddle.matmul(
                                hidden_states_chunk,
                                grad_logits_chunk,
                                transpose_x=True,
                            )
                    if grad_lm_head_bias is not None:
                        grad_lm_head_bias += grad_logits_chunk.astype("float32").sum(axis=0).astype(dtype)

        if return_token_loss:
            loss = token_loss.reshape(original_shape[:-1])
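The core of the NPU adaptation above is that per-iteration slicing (hidden_states[token_start_idx:token_end_idx]) is replaced by a single split call per tensor, driven by a section list whose last entry is -1 so paddle.split infers the remainder when n_tokens is not a multiple of loop_chunk_size. A small standalone illustration of just that section trick, with toy sizes chosen here for demonstration:

import paddle

n_tokens, loop_chunk_size = 10, 4
x = paddle.arange(n_tokens)

# one section per chunk; the final -1 lets split() infer the leftover length
token_idx_section = [loop_chunk_size for _ in range(0, n_tokens, loop_chunk_size)]
token_idx_section[-1] = -1            # here: [4, 4, -1] -> sections of 4, 4, 2
chunks = x.split(token_idx_section, axis=0)

print([c.shape[0] for c in chunks])   # [4, 4, 2]
print(paddle.concat(chunks, axis=0))  # concatenating the chunks round-trips to the original order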

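For completeness, the gradient block in both branches implements d(loss)/d(logits) = softmax(logits) − one_hot(labels), scaled by 1/divisor and masked, then chained through the head matmul. Below is a minimal single-device sketch of that backward path, assuming a [hidden, vocab] weight (i.e. the transpose_y=False layout) with no loss mask or tensor parallelism; the names are illustrative, not the library's API.

import paddle
import paddle.nn.functional as F


def chunked_lm_head_backward(hidden_states, lm_head_weight, labels, divisor, chunk_size=1024):
    """Chunked grads for hidden states and the LM head weight."""
    n_tokens, vocab_size = hidden_states.shape[0], lm_head_weight.shape[1]
    grad_hidden_states = paddle.zeros_like(hidden_states)
    grad_lm_head_weight = paddle.zeros_like(lm_head_weight)
    for start in range(0, n_tokens, chunk_size):
        end = min(start + chunk_size, n_tokens)
        hs = hidden_states[start:end]
        logits = paddle.matmul(hs, lm_head_weight).astype("float32")
        # d(loss)/d(logits) = softmax - one_hot, scaled by 1/divisor
        probs = F.softmax(logits, axis=-1)
        one_hot = F.one_hot(labels[start:end], vocab_size)
        grad_logits = ((probs - one_hot) / divisor).astype(hidden_states.dtype)
        # chain rule through logits = hs @ W
        grad_hidden_states[start:end] = paddle.matmul(grad_logits, lm_head_weight, transpose_y=True)
        # weight grad is accumulated chunk by chunk, mirroring the += in the diff
        grad_lm_head_weight += paddle.matmul(hs, grad_logits, transpose_x=True)
    return grad_hidden_states, grad_lm_head_weight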