Add Optional Activation node to NodeUnit #22888
base: main
You can commit the suggested changes from lintrunner.
  const Node& target_node_;
  const Node* p_activation_node_;  // Optional activation node for the QDQ group, nullptr if not present.
  const std::vector<const Node*> q_nodes_;  // q-nodes for this NodeUnit. not necessarily all outputs
  const Node& target_node_;
  const Node* p_activation_node_;  // Optional activation node for the QDQ group, nullptr if not present.
  const std::vector<const Node*> q_nodes_;  // q-nodes for this NodeUnit. not necessarily all outputs
@@ -87,6 +89,7 @@ class NodeUnit {
  ProviderType GetExecutionProviderType() const noexcept;

  const Node& GetNode() const noexcept { return target_node_; }
  const Node* GetActivationNode() const noexcept { return p_activation_node_; }
We should have a comment explaining what the 'activation' node is and how it is expected to be used.
IIUC:
- If you're using the QDQ node unit for the quantized version of the target node, the activation node can be ignored, as the values of the Q node make it redundant.
- So it's not really a 'fusion' per se: we're not combining the values of the Clip/Relu with the Q, we're ignoring the Clip/Relu.
- If you are falling back to higher precision and dropping the DQ/Q nodes, you need to keep both the target node and the activation node, if present.
If that's correct, I'd almost be inclined to call it something like redundant_clip_node (given Relu is a form of Clip).
Also, as the OpenVINO EP (IIRC) does the fallback to higher precision, does it need an update to be aware of the activation node in the NodeUnit?
How would I combine the values of Clip/Relu with Q? Is there a formula I can follow to adjust the scale/zp values in the Q?
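As a rough sketch of the arithmetic behind the redundancy argument (not code from this PR): assuming the standard QuantizeLinear definition, q = saturate(round(x / scale) + zp) with scale > 0, a Q node can only represent real values in [scale * (qmin - zp), scale * (qmax - zp)]. A Clip whose bounds enclose that range, or a Relu when that range already starts at or above zero, should not change the quantized result. The names QRange, RepresentableRange, IsClipRedundant and IsReluRedundant are hypothetical and are not ONNX Runtime APIs.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Hypothetical helper type for illustration; not part of ONNX Runtime.
struct QRange {
  float lo;  // smallest real value the Q node can represent
  float hi;  // largest real value the Q node can represent
};

// Real-valued range covered by a quantized type with integer bounds
// [qmin, qmax], given a scalar scale and zero point.
inline QRange RepresentableRange(float scale, int32_t zp, int32_t qmin, int32_t qmax) {
  assert(scale > 0.0f && qmin < qmax);
  return {scale * static_cast<float>(qmin - zp), scale * static_cast<float>(qmax - zp)};
}

// Clip(clip_min, clip_max) followed by Q is assumed redundant when the clip
// interval encloses everything Q can represent: values outside the interval
// would saturate in Q anyway.
inline bool IsClipRedundant(float clip_min, float clip_max, float scale, int32_t zp,
                            int32_t qmin, int32_t qmax) {
  const QRange r = RepresentableRange(scale, zp, qmin, qmax);
  return clip_min <= r.lo && clip_max >= r.hi;
}

// Relu is treated as Clip(0, +inf).
inline bool IsReluRedundant(float scale, int32_t zp, int32_t qmin, int32_t qmax) {
  return IsClipRedundant(0.0f, std::numeric_limits<float>::infinity(), scale, zp, qmin, qmax);
}
```

Under these assumptions, a uint8 Q with zp == 0 makes a preceding Relu redundant because the representable range starts at 0, while zp == 128 does not. When the check fails, the Clip/Relu still changes values and cannot simply be ignored.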
  }
  return true;
}
bool GetQSalarScaleZp(const GraphViewer& graph_viewer, const Node& q_node, float& scale, int32_t& zp,
bool GetQSalarScaleZp(const GraphViewer& graph_viewer, const Node& q_node, float& scale, int32_t& zp,
bool GetQScalarScaleZp(const GraphViewer& graph_viewer, const Node& q_node, float& scale, int32_t& zp,
It would be good to refactor some of the utils here, as there seems to be a fair bit of duplication.
E.g. a general-purpose helper that reads the scale and zp values (scalar or otherwise) and returns a bool indicating whether they're scalar. Many of the utils here could use that helper.
                      int32_t& data_type) {
  assert(q_node.OpType() == QOpName);
  const auto& q_input_defs = q_node.InputDefs();
  if (q_input_defs.size() != 3 || !q_input_defs[2]->Exists()) {
zp is optional and defaults to zero, so do we need to require 3 inputs here?
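To illustrate the point about the optional zero point, here is a minimal standalone sketch of a reader that accepts a missing zp and also reports whether the parameters are scalar. It deliberately avoids ONNX Runtime types: QuantParams and ReadQuantParams are hypothetical names, and the nullable pointer stands in for a Q input that does not exist.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical container for the values read from a Q node; illustration only.
struct QuantParams {
  std::vector<float> scales;
  std::vector<int32_t> zero_points;
  bool is_scalar = false;  // true when there is exactly one scale/zp pair
};

// scales is required. zero_points == nullptr models an absent optional input,
// in which case every zero point defaults to 0.
inline std::optional<QuantParams> ReadQuantParams(const std::vector<float>* scales,
                                                  const std::vector<int32_t>* zero_points) {
  if (scales == nullptr || scales->empty()) {
    return std::nullopt;  // scale is mandatory
  }
  QuantParams p;
  p.scales = *scales;
  if (zero_points == nullptr) {
    p.zero_points.assign(p.scales.size(), 0);  // default zero point
  } else if (zero_points->size() == p.scales.size()) {
    p.zero_points = *zero_points;
  } else {
    return std::nullopt;  // mismatched scale/zp shapes
  }
  assert(p.zero_points.size() == p.scales.size());
  p.is_scalar = p.scales.size() == 1;
  return p;
}
```

A caller that only supports scalar parameters could check is_scalar on the result, rather than rejecting any node that has fewer than three inputs.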
@@ -49,7 +49,7 @@ std::vector<const Node*> FindQDQNodes(const GraphViewer& graph_viewer, const Nod
 }
 }  // namespace

-bool NodeGroupSelector::CheckQDQNodes(const GraphViewer& graph_viewer, const Node& node,
+bool NodeGroupSelector::CheckQDQNodes(const GraphViewer& graph_viewer, const Node& node, const Node* p_activation_node,
nit: we don't use Hungarian notation anywhere else, so could we use activation_node instead of p_activation_node?
  if (p_activation_node) {
    return false;
  }
If the activation node is made redundant by the Q, what's the reason we can't create a QDQ node unit for this sort of operator?
No description provided.