@@ -252,15 +252,26 @@ struct RoPE::RoPEExecutorChatGLM : public RoPE::Executor {
252252 jcp.dst_prc = precision_of<T>::value;
253253 jcp.rotary_ndims = config.rotary_ndims ;
254254 jcp.interleave = true ;
255- jcp.mix_cos_sin = true ;
255+ // if use precomputed rope cache then it's mixed
256+ // otherwise rope will have separate cos/sin inputs
257+ jcp.mix_cos_sin = config.use_rope_cache ;
256258 m_rotaryKernel = createJitKernel (jcp, true );
257259 }
258260
259261 void execute ([[maybe_unused]] const dnnl::stream& strm,
260262 const std::vector<MemoryPtr>& inputs,
261263 const std::vector<MemoryPtr>& outputs) override {
262264 ov::intel_cpu::PlainTensor t_src (inputs[0 ]);
263- ov::intel_cpu::PlainTensor t_cos_sin (inputs[1 ]);
265+ ov::intel_cpu::PlainTensor t_cos;
266+ ov::intel_cpu::PlainTensor t_sin;
267+ ov::intel_cpu::PlainTensor t_cos_sin;
268+ if (!m_config.use_rope_cache ) {
269+ t_cos.reset (inputs[1 ]);
270+ t_sin.reset (inputs[2 ]);
271+ } else {
272+ t_cos_sin.reset (inputs[1 ]);
273+ }
274+
264275 ov::intel_cpu::PlainTensor t_dst (outputs[0 ]);
265276
266277 // [seq_len, batch_size, (hidden_states_q + hidden_states_k + hidden_states_v)]
@@ -277,27 +288,45 @@ struct RoPE::RoPEExecutorChatGLM : public RoPE::Executor {
277288
278289 auto rotary_dims = m_config.rotary_ndims ;
279290
280- parallel_for3d (batch_size, head_cnt, seq_len, [&](size_t b, size_t h, size_t p) {
281- // src [batch, length, H x S]
282- auto * src = t_src.ptr <T>(b, p, h * head_size);
283- // [batch_size, length, ndims//2, 2]
284- auto * cos_sin = &t_cos_sin.at <float >({b, p, 0 , 0 }, true );
285- auto * dst = t_dst.ptr <T>(b, h, p, 0 );
286-
287- if (m_rotaryKernel) {
288- execJitKernel (m_rotaryKernel, src, dst, cos_sin, nullptr );
289- } else {
290- size_t i = 0 ;
291- for (; i < rotary_dims; i += 2 ) {
292- auto cosv = cos_sin[i];
293- auto sinv = cos_sin[i + 1 ];
294- dst[i] = cosv * src[i] - sinv * src[i + 1 ];
295- dst[i + 1 ] = sinv * src[i] + cosv * src[i + 1 ];
291+ if (m_config.use_rope_cache ) {
292+ parallel_for3d (batch_size, head_cnt, seq_len, [&](size_t b, size_t h, size_t p) {
293+ // src [batch, length, H x S]
294+ auto * src = t_src.ptr <T>(b, p, h * head_size);
295+ // [batch_size, length, ndims//2, 2]
296+ auto * cos_sin = &t_cos_sin.at <float >({b, p, 0 , 0 }, true );
297+ auto * dst = t_dst.ptr <T>(b, h, p, 0 );
298+
299+ if (m_rotaryKernel) {
300+ execJitKernel (m_rotaryKernel, src, dst, cos_sin, nullptr );
301+ } else {
302+ size_t i = 0 ;
303+ for (; i < rotary_dims; i += 2 ) {
304+ auto cosv = cos_sin[i];
305+ auto sinv = cos_sin[i + 1 ];
306+ dst[i] = cosv * src[i] - sinv * src[i + 1 ];
307+ dst[i + 1 ] = sinv * src[i] + cosv * src[i + 1 ];
308+ }
296309 }
297- }
298310
299- memcpy (dst + rotary_dims, src + rotary_dims, (head_size - rotary_dims) * sizeof (T));
300- });
311+ memcpy (dst + rotary_dims, src + rotary_dims, (head_size - rotary_dims) * sizeof (T));
312+ });
313+ } else {
314+ parallel_for3d (batch_size, head_cnt, seq_len, [&](size_t b, size_t h, size_t p) {
315+ auto * src = t_src.ptr <T>(b, p, h * head_size);
316+ auto * dst = t_dst.ptr <T>(b, h, p);
317+ const auto * cos = t_cos.ptr <float >(b, 0 , 0 );
318+ const auto * sin = t_sin.ptr <float >(b, 0 , 0 );
319+ if (m_rotaryKernel) {
320+ execJitKernel (m_rotaryKernel, src, dst, cos, sin);
321+ } else {
322+ for (size_t i = 0 ; i < rotary_dims; i += 2 ) {
323+ dst[i] = cos[i / 2 ] * src[i] - sin[i / 2 ] * src[i + 1 ];
324+ dst[i + 1 ] = sin[i / 2 ] * src[i] + cos[i / 2 ] * src[i + 1 ];
325+ }
326+ }
327+ memcpy (dst + rotary_dims, src + rotary_dims, (head_size - rotary_dims) * sizeof (T));
328+ });
329+ }
301330 } else {
302331 auto seq_len = t_src.size (0 );
303332 auto batch_size = t_src.size (1 );
0 commit comments