Skip to content

Commit 0e8211a

Browse files
committed
[CPU]Support RoPE for GLM4
Signed-off-by: Zhang Yi <[email protected]>
1 parent cb1ec75 commit 0e8211a

File tree

4 files changed

+63
-24
lines changed

4 files changed

+63
-24
lines changed

src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -783,7 +783,11 @@ ov::pass::RoPEFusionChatGLMHF::RoPEFusionChatGLMHF() {
783783
auto reshape = pattern::wrap_type<v1::Reshape>({qk_linear, pattern::any_input()},
784784
pattern::shape_matches("[?, head_cnt, 1, head_size]"),
785785
{{"special_zero", false}});
786-
auto slice_1 = NewGenSlice(reshape, 0, "ndims", 1, 3);
786+
787+
auto qkv_proj =
788+
pattern::wrap_type<v1::VariadicSplit>({reshape, 3, {"ndims", "ndims"}});
789+
qkv_proj->set_output_size(2);
790+
auto slice_1 = NewGenSlice(reshape, 0, "ndims", 1, 3) | qkv_proj->output(0);
787791

788792
auto const_idx =
789793
pattern::wrap_type<ov::opset1::Constant>(pattern::type_matches(ov::element::i32) && const_idx_predicate);
@@ -807,7 +811,7 @@ ov::pass::RoPEFusionChatGLMHF::RoPEFusionChatGLMHF() {
807811
auto multiply_1 = pattern::wrap_type<v1::Multiply>({flatten, repeat_interleave_sin}, {{"auto_broadcast", "numpy"}});
808812
auto add = pattern::wrap_type<v1::Add>({multiply, multiply_1}, {{"auto_broadcast", "numpy"}});
809813

810-
auto slice_5 = NewGenSlice(reshape, "ndims", INT_MAX, 1, 3);
814+
auto slice_5 = NewGenSlice(reshape, "ndims", INT_MAX, 1, 3) | qkv_proj->output(1);
811815
auto result = pattern::wrap_type<v0::Concat>({add, slice_5}, {{"axis", -1}});
812816

813817
matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {

src/plugins/intel_cpu/src/nodes/rope.cpp

Lines changed: 50 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -252,15 +252,26 @@ struct RoPE::RoPEExecutorChatGLM : public RoPE::Executor {
252252
jcp.dst_prc = precision_of<T>::value;
253253
jcp.rotary_ndims = config.rotary_ndims;
254254
jcp.interleave = true;
255-
jcp.mix_cos_sin = true;
255+
// if the precomputed rope cache is used, cos/sin are mixed in a single input;
256+
// otherwise rope will have separate cos/sin inputs
257+
jcp.mix_cos_sin = config.use_rope_cache;
256258
m_rotaryKernel = createJitKernel(jcp, true);
257259
}
258260

259261
void execute([[maybe_unused]] const dnnl::stream& strm,
260262
const std::vector<MemoryPtr>& inputs,
261263
const std::vector<MemoryPtr>& outputs) override {
262264
ov::intel_cpu::PlainTensor t_src(inputs[0]);
263-
ov::intel_cpu::PlainTensor t_cos_sin(inputs[1]);
265+
ov::intel_cpu::PlainTensor t_cos;
266+
ov::intel_cpu::PlainTensor t_sin;
267+
ov::intel_cpu::PlainTensor t_cos_sin;
268+
if (!m_config.use_rope_cache) {
269+
t_cos.reset(inputs[1]);
270+
t_sin.reset(inputs[2]);
271+
} else {
272+
t_cos_sin.reset(inputs[1]);
273+
}
274+
264275
ov::intel_cpu::PlainTensor t_dst(outputs[0]);
265276

266277
// [seq_len, batch_size, (hidden_states_q + hidden_states_k + hidden_states_v)]
@@ -277,27 +288,45 @@ struct RoPE::RoPEExecutorChatGLM : public RoPE::Executor {
277288

278289
auto rotary_dims = m_config.rotary_ndims;
279290

280-
parallel_for3d(batch_size, head_cnt, seq_len, [&](size_t b, size_t h, size_t p) {
281-
// src [batch, length, H x S]
282-
auto* src = t_src.ptr<T>(b, p, h * head_size);
283-
// [batch_size, length, ndims//2, 2]
284-
auto* cos_sin = &t_cos_sin.at<float>({b, p, 0, 0}, true);
285-
auto* dst = t_dst.ptr<T>(b, h, p, 0);
286-
287-
if (m_rotaryKernel) {
288-
execJitKernel(m_rotaryKernel, src, dst, cos_sin, nullptr);
289-
} else {
290-
size_t i = 0;
291-
for (; i < rotary_dims; i += 2) {
292-
auto cosv = cos_sin[i];
293-
auto sinv = cos_sin[i + 1];
294-
dst[i] = cosv * src[i] - sinv * src[i + 1];
295-
dst[i + 1] = sinv * src[i] + cosv * src[i + 1];
291+
if (m_config.use_rope_cache) {
292+
parallel_for3d(batch_size, head_cnt, seq_len, [&](size_t b, size_t h, size_t p) {
293+
// src [batch, length, H x S]
294+
auto* src = t_src.ptr<T>(b, p, h * head_size);
295+
// [batch_size, length, ndims//2, 2]
296+
auto* cos_sin = &t_cos_sin.at<float>({b, p, 0, 0}, true);
297+
auto* dst = t_dst.ptr<T>(b, h, p, 0);
298+
299+
if (m_rotaryKernel) {
300+
execJitKernel(m_rotaryKernel, src, dst, cos_sin, nullptr);
301+
} else {
302+
size_t i = 0;
303+
for (; i < rotary_dims; i += 2) {
304+
auto cosv = cos_sin[i];
305+
auto sinv = cos_sin[i + 1];
306+
dst[i] = cosv * src[i] - sinv * src[i + 1];
307+
dst[i + 1] = sinv * src[i] + cosv * src[i + 1];
308+
}
296309
}
297-
}
298310

299-
memcpy(dst + rotary_dims, src + rotary_dims, (head_size - rotary_dims) * sizeof(T));
300-
});
311+
memcpy(dst + rotary_dims, src + rotary_dims, (head_size - rotary_dims) * sizeof(T));
312+
});
313+
} else {
314+
parallel_for3d(batch_size, head_cnt, seq_len, [&](size_t b, size_t h, size_t p) {
315+
auto* src = t_src.ptr<T>(b, p, h * head_size);
316+
auto* dst = t_dst.ptr<T>(b, h, p);
317+
const auto* cos = t_cos.ptr<float>(b, 0, 0);
318+
const auto* sin = t_sin.ptr<float>(b, 0, 0);
319+
if (m_rotaryKernel) {
320+
execJitKernel(m_rotaryKernel, src, dst, cos, sin);
321+
} else {
322+
for (size_t i = 0; i < rotary_dims; i += 2) {
323+
dst[i] = cos[i / 2] * src[i] - sin[i / 2] * src[i + 1];
324+
dst[i + 1] = sin[i / 2] * src[i] + cos[i / 2] * src[i + 1];
325+
}
326+
}
327+
memcpy(dst + rotary_dims, src + rotary_dims, (head_size - rotary_dims) * sizeof(T));
328+
});
329+
}
301330
} else {
302331
auto seq_len = t_src.size(0);
303332
auto batch_size = t_src.size(1);

src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1094,7 +1094,6 @@ void Transformations::PostLpt() {
10941094
CPU_REGISTER_PASS_X64(postLPTPassManager, ov::pass::RoPEFusion, true);
10951095
CPU_REGISTER_PASS_ARM64(postLPTPassManager, ov::pass::RoPEFusion, true);
10961096
CPU_DISABLE_PASS_COMMON(postLPTPassManager, ov::pass::RoPEFusionFlux);
1097-
CPU_DISABLE_PASS_COMMON(postLPTPassManager, ov::pass::RoPEFusionChatGLMHF);
10981097
CPU_REGISTER_PASS_X64(postLPTPassManager, CausalMaskPreprocessFusion);
10991098

11001099
#if defined(OPENVINO_ARCH_X86_64)

src/plugins/intel_cpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,5 +79,12 @@ INSTANTIATE_TEST_SUITE_P(smoke_RoPETestQwenVL,
7979
::testing::ValuesIn(vit_param)),
8080
RoPETestQwenVL::getTestCaseName);
8181

82+
INSTANTIATE_TEST_SUITE_P(smoke_RoPETestChatGLM,
83+
RoPETestChatGLMHF,
84+
::testing::Combine(
85+
::testing::Values(ov::element::f32),
86+
::testing::Values(ov::test::utils::DEVICE_CPU)),
87+
RoPETestChatGLMHF::getTestCaseName);
88+
8289
} // namespace test
8390
} // namespace ov

0 commit comments

Comments
 (0)