File tree Expand file tree Collapse file tree 1 file changed +5
-2
lines changed
onnxruntime/contrib_ops/cpu Expand file tree Collapse file tree 1 file changed +5
-2
lines changed Original file line number Diff line number Diff line change @@ -70,7 +70,6 @@ class QuickGelu : public OpKernel {
7070 int64_t elem_count = input->Shape ().Size ();
7171 constexpr int64_t length_per_task = 4096 ; // this number comes from FastGelu.
7272 int64_t task_count = (elem_count + length_per_task - 1 ) / length_per_task;
73-
7473 concurrency::ThreadPool::TryBatchParallelFor (
7574 tp, static_cast <int32_t >(task_count),
7675 [&](ptrdiff_t task_idx) {
@@ -80,7 +79,11 @@ class QuickGelu : public OpKernel {
8079 int64_t count = std::min (length_per_task, elem_count - start);
8180
8281 if (alpha_ != 1 .0f ) {
83- // TODO: Vectorize this compute
82+ // TODO: Consider vectorizing this scalar multiplication.
83+ // It needs exposing a new API in MLAS to take in a scalar
84+ // that will be used in the elementwise multiplication.
85+ // Estimate the cost-benefit tradeoff before proceeding
86+ // with that optimization.
8487 for (int64_t i = 0 ; i < count; i++) {
8588 p_output[i] = p_input[i] * alpha_;
8689 }
You can’t perform that action at this time.
0 commit comments