-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[js/webgpu] support FlashAttention-2 for attention operator #22915
base: main
Are you sure you want to change the base?
Conversation
1eb8d05
to
1aaa47e
Compare
The FlashAttention-2 algorithm is based on the paper https://tridao.me/publications/flash2/flash2.pdf. |
/azp run ONNX Runtime Web CI Pipeline,Windows GPU CI Pipeline,Linux Android Emulator QNN CI Pipeline |
/azp run Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline,Linux OpenVINO CI Pipeline,Linux QNN CI Pipeline,MacOS CI Pipeline,Windows ARM64 QNN CI Pipeline,Windows CPU CI Pipeline |
Azure Pipelines successfully started running 1 pipeline(s). |
/azp run Windows GPU TensorRT CI Pipeline,onnxruntime-binary-size-checks-ci-pipeline,orttraining-linux-ci-pipeline,orttraining-linux-gpu-ci-pipeline,orttraining-ortmodule-distributed,Windows x64 QNN CI Pipeline,Big Models |
Azure Pipelines could not run because the pipeline triggers exclude this branch/path. |
/azp run Windows GPU CUDA CI Pipeline,Windows GPU DML CI Pipeline,Windows GPU Doc Gen CI Pipeline |
Azure Pipelines could not run because the pipeline triggers exclude this branch/path. |
Azure Pipelines successfully started running 1 pipeline(s). |
1aaa47e
to
c936365
Compare
Sorry for the error, I wanted to keep the names as the paper, but the variables‘ names mismatched rules. I had already modified the names. Thanks. |
Q_i[local_id.y][u32(${workgroupSize[0]} * tile) + local_id.x] = Q[offset + local_id.y * uniforms.d + u32(${workgroupSize[0]} * tile) + local_id.x]; | ||
} | ||
|
||
for (var j = 0; j < ${tC}; j++) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Turn tC to uniform or add it to hint?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. I also noticed this issue. I added it to uniform.
context.inputs[4] === undefined && | ||
context.inputs[5] === undefined | ||
) { | ||
return applyFlashAttentionV2(context, q, k, v, params, attributes); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any existing case cover this branch?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Currently, I also tested it on sd2.1. From the conditions, we know that the input data size is large, so I am not sure it is reasonable to add an unit test here.
@xhcao my github name is jchen10. |
Description
Motivation and Context