Skip to content
Open
22 changes: 16 additions & 6 deletions onnxruntime/contrib_ops/cpu/activations.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,27 @@
const T* p_input = input_data + start;
T* p_output = output_data + start;
int64_t count = std::min(length_per_task, elem_count - start);
for (int64_t i = 0; i < count; i++) {
p_output[i] = p_input[i] * alpha_;
}

MlasComputeLogistic(p_output, p_output, onnxruntime::narrow<size_t>(count));
if (alpha_ != 1.0f) {
// TODO: Consider vectorizing this scalar multiplication.

Check warning on line 82 in onnxruntime/contrib_ops/cpu/activations.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2] Raw Output: onnxruntime/contrib_ops/cpu/activations.h:82: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
// It needs exposing a new API in MLAS to take in a scalar
// that will be used in the elementwise multiplication.
// Estimate the cost-benefit tradeoff before proceeding
// with that optimization.
for (int64_t i = 0; i < count; i++) {
p_output[i] = p_input[i] * alpha_;
}

for (int64_t i = 0; i < count; i++) {
p_output[i] = p_input[i] * p_output[i];
MlasComputeLogistic(p_output, p_output, onnxruntime::narrow<size_t>(count));
} else {
// SILU activation - this needs no `alpha_` scaling as `alpha_` will be 1.0f
MlasComputeLogistic(p_input, p_output, onnxruntime::narrow<size_t>(count));
}

MlasEltwiseMul<float>(p_input, p_output, p_output, onnxruntime::narrow<size_t>(count));
},
0);

return Status::OK();
}

Expand Down
10 changes: 10 additions & 0 deletions onnxruntime/core/mlas/inc/mlas.h
Original file line number Diff line number Diff line change
Expand Up @@ -1126,6 +1126,16 @@ MlasEltwiseAdd(
size_t N
);

//
// Elementwise multiplication of two buffers.
//
// Declaration only; the behavior is supplied by explicit specializations.
// The float specialization writes output[i] = left[i] * right[i] for each
// i in [0, N).
//
// left    - first input buffer (read-only).
// right   - second input buffer (read-only).
// output  - destination buffer.
// N       - number of elements to process.
//
// NOTE(review): specializations for types other than float are presumably
// expected to follow the same contract - confirm against eltwise.cpp.
//
template <typename T>
void
MLASCALL
MlasEltwiseMul(
const T* left,
const T* right,
T* output,
size_t N
);

template<typename T>
void
MLASCALL
Expand Down
32 changes: 32 additions & 0 deletions onnxruntime/core/mlas/lib/eltwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,38 @@ MlasEltwiseAdd<float>(
}
}

//
// float specialization of MlasEltwiseMul.
//
// Writes output[i] = left[i] * right[i] for every i in [0, N). Elements are
// processed front to back: first in groups of four through the
// MLAS_FLOAT32X4 helpers, then one at a time for the remaining (< 4)
// elements. Within each group both inputs are loaded before the result is
// stored.
//
template <>
void
MLASCALL
MlasEltwiseMul<float>(
    const float* left,
    const float* right,
    float* output,
    size_t N
    )
{
    constexpr size_t kVectorWidth = 4;

    size_t remaining = N;

    // Vector body: consume four floats per iteration while a full group is left.
    for (; remaining >= kVectorWidth; remaining -= kVectorWidth) {
        const MLAS_FLOAT32X4 Lhs = MlasLoadFloat32x4(left);
        const MLAS_FLOAT32X4 Rhs = MlasLoadFloat32x4(right);

        MlasStoreFloat32x4(output, MlasMultiplyFloat32x4(Lhs, Rhs));

        left += kVectorWidth;
        right += kVectorWidth;
        output += kVectorWidth;
    }

    // Scalar tail: at most three trailing elements.
    for (size_t i = 0; i < remaining; ++i) {
        output[i] = left[i] * right[i];
    }
}

template <>
void
Expand Down
Loading