@@ -147,6 +147,8 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible {
147147
148148private:
149149 const size_t xmm_len = 16 ; // stack bytes reserved per saved XMM register
150+ const size_t ymm_len = 32 ; // stack bytes reserved per saved YMM register
151+ const size_t zmm_len = 64 ; // stack bytes reserved per saved ZMM register
150152#ifdef _WIN32
151153 const size_t xmm_to_preserve_start = 6 ;
152154 const size_t xmm_to_preserve = 10 ;
@@ -182,6 +184,91 @@ class jit_generator : public Xbyak::CodeGenerator, public c_compatible {
182184
183185 inline size_t get_size_of_abi_save_regs () { return size_of_abi_save_regs; }
184186
// Re-expose the push/pop overloads inherited from Xbyak::CodeGenerator:
// the vector-register push/pop overloads declared in this class would
// otherwise hide them from name lookup.
187+ using Xbyak::CodeGenerator::push;
188+ using Xbyak::CodeGenerator::pop;
189+
190+ inline void push (const Xbyak::Xmm &xmm) {
191+ if (xmm.isXMM ()) {
192+ sub (rsp, xmm_len);
193+ uni_vmovdqu (ptr[rsp], xmm);
194+ } else if (xmm.isYMM ()) {
195+ sub (rsp, ymm_len);
196+ uni_vmovdqu (ptr[rsp], Xbyak::Ymm{xmm.getIdx ()});
197+ } else if (xmm.isZMM ()) {
198+ sub (rsp, zmm_len);
199+ uni_vmovdqu (ptr[rsp], Xbyak::Zmm{xmm.getIdx ()});
200+ }
201+ }
202+
203+ inline void push (const std::vector<Xbyak::Xmm> &xmms) {
204+ std::vector<std::function<void ()>> deferred_movs{};
205+ size_t offset = 0 ;
206+ for (size_t i = 0 ; i < xmms.size (); ++i) {
207+ const auto & xmm = xmms[i];
208+ if (xmm.isXMM ()) {
209+ deferred_movs.emplace_back ([this , offset, &xmm]() {
210+ uni_vmovdqu (ptr[rsp + offset], xmm);
211+ });
212+ offset += xmm_len;
213+ } else if (xmm.isYMM ()) {
214+ deferred_movs.emplace_back ([this , offset, &xmm]() {
215+ uni_vmovdqu (ptr[rsp + offset], Xbyak::Ymm{xmm.getIdx ()});
216+ });
217+ offset += ymm_len;
218+ } else if (xmm.isZMM ()) {
219+ deferred_movs.emplace_back ([this , offset, &xmm]() {
220+ uni_vmovdqu (ptr[rsp + offset], Xbyak::Zmm{xmm.getIdx ()});
221+ });
222+ offset += zmm_len;
223+ }
224+ }
225+ sub (rsp, offset);
226+ for (const auto & def_mov : deferred_movs) {
227+ def_mov ();
228+ }
229+ }
230+
231+ inline void pop (const Xbyak::Xmm &xmm) {
232+ if (xmm.isXMM ()) {
233+ uni_vmovdqu (xmm, ptr[rsp]);
234+ add (rsp, xmm_len);
235+ } else if (xmm.isYMM ()) {
236+ uni_vmovdqu (Xbyak::Ymm{xmm.getIdx ()}, ptr[rsp]);
237+ add (rsp, ymm_len);
238+ } else if (xmm.isZMM ()) {
239+ uni_vmovdqu (Xbyak::Zmm{xmm.getIdx ()}, ptr[rsp]);
240+ add (rsp, zmm_len);
241+ }
242+ }
243+
244+ inline void pop (const std::vector<Xbyak::Xmm> &xmms) {
245+ std::vector<std::function<void ()>> deferred_movs{};
246+ size_t offset = 0 ;
247+ for (size_t i = 0 ; i < xmms.size (); ++i) {
248+ const auto & xmm = xmms[i];
249+ if (xmm.isXMM ()) {
250+ deferred_movs.emplace_back ([this , offset, &xmm]() {
251+ uni_vmovdqu (xmm, ptr[rsp + offset]);
252+ });
253+ offset += xmm_len;
254+ } else if (xmm.isYMM ()) {
255+ deferred_movs.emplace_back ([this , offset, &xmm]() {
256+ uni_vmovdqu (Xbyak::Ymm{xmm.getIdx ()}, ptr[rsp + offset]);
257+ });
258+ offset += ymm_len;
259+ } else if (xmm.isZMM ()) {
260+ deferred_movs.emplace_back ([this , offset, &xmm]() {
261+ uni_vmovdqu (Xbyak::Zmm{xmm.getIdx ()}, ptr[rsp + offset]);
262+ });
263+ offset += zmm_len;
264+ }
265+ }
266+ for (const auto & def_mov : deferred_movs) {
267+ def_mov ();
268+ }
269+ add (rsp, offset);
270+ }
271+
185272 void preamble () {
186273 if (xmm_to_preserve) {
187274 sub (rsp, xmm_to_preserve * xmm_len);
0 commit comments