Skip to content

Commit 9f801d3

Browse files
authored
Update activations.h
Add comment for potential future work
1 parent 667f22b commit 9f801d3

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

onnxruntime/contrib_ops/cpu/activations.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ class QuickGelu : public OpKernel {
7070
int64_t elem_count = input->Shape().Size();
7171
constexpr int64_t length_per_task = 4096; // this number comes from FastGelu.
7272
int64_t task_count = (elem_count + length_per_task - 1) / length_per_task;
73-
7473
concurrency::ThreadPool::TryBatchParallelFor(
7574
tp, static_cast<int32_t>(task_count),
7675
[&](ptrdiff_t task_idx) {
@@ -80,7 +79,11 @@ class QuickGelu : public OpKernel {
8079
int64_t count = std::min(length_per_task, elem_count - start);
8180

8281
if (alpha_ != 1.0f) {
83-
// TODO: Vectorize this compute
82+
// TODO: Consider vectorizing this scalar multiplication.
83+
// It needs exposing a new API in MLAS to take in a scalar
84+
// that will be used in the elementwise multiplcation.
85+
// Estimate the cost-benefit tradeoff before proceeding
86+
// with that optimization.
8487
for (int64_t i = 0; i < count; i++) {
8588
p_output[i] = p_input[i] * alpha_;
8689
}

0 commit comments

Comments
 (0)