Skip to content

Commit 49d7a42

Browse files
committed
sync to Xavier's cpu refactors
1 parent 1e7d5ae commit 49d7a42

File tree

5 files changed

+189
-210
lines changed

5 files changed

+189
-210
lines changed

onnxruntime/core/providers/cpu/llm/attention.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Licensed under the MIT License.
33

44
#include "core/providers/cpu/llm/attention.h"
5+
#include "core/providers/cpu/llm/attention_helper.h"
56

67
#include "core/common/common.h"
78
#include "core/common/safeint.h"
@@ -140,10 +141,10 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
140141
const Tensor* past_value = context->Input<Tensor>(5);
141142

142143
AttentionParameters parameters;
143-
std::vector<int64_t> y_shape;
144-
std::vector<int64_t> present_key_shape;
145-
std::vector<int64_t> present_value_shape;
146-
std::vector<int64_t> output_qk_shape;
144+
TensorShape y_shape;
145+
TensorShape present_key_shape;
146+
TensorShape present_value_shape;
147+
TensorShape output_qk_shape;
147148

148149
ORT_ENFORCE(attention_helper::ComputeOutputShapeForAttention(
149150
Q,

onnxruntime/core/providers/cpu/llm/attention.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#include "core/common/common.h"
66
#include "core/framework/op_kernel.h"
77
#include "core/platform/threadpool.h"
8-
#include "core/providers/cpu/llm/attention_helper.h"
8+
#include "core/providers/cpu/llm/attention_parameters.h"
99

1010
namespace onnxruntime {
1111

@@ -95,4 +95,4 @@ class Attention final : public AttentionBase<T> {
9595
int softmax_precision_;
9696
};
9797

98-
} // namespace onnxruntime
98+
} // namespace onnxruntime

onnxruntime/core/providers/cpu/llm/attention_helper.cc

Lines changed: 0 additions & 156 deletions
This file was deleted.

0 commit comments

Comments
 (0)