Skip to content

Commit f40381e

Browse files
erifanJatin Bhateja
authored andcommitted
8356760: VectorAPI: Optimize VectorMask.fromLong for all-true/all-false cases
Reviewed-by: xgong, jbhateja
1 parent e801e51 commit f40381e

File tree

9 files changed

+1093
-13
lines changed

9 files changed

+1093
-13
lines changed

src/hotspot/share/opto/vectorIntrinsics.cpp

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -686,11 +686,20 @@ bool LibraryCallKit::inline_vector_frombits_coerced() {
686686
int opc = bcast_mode == VectorSupport::MODE_BITS_COERCED_LONG_TO_MASK ? Op_VectorLongToMask : Op_Replicate;
687687

688688
if (!arch_supports_vector(opc, num_elem, elem_bt, checkFlags, true /*has_scalar_args*/)) {
689-
log_if_needed(" ** not supported: arity=0 op=broadcast vlen=%d etype=%s ismask=%d bcast_mode=%d",
690-
num_elem, type2name(elem_bt),
691-
is_mask ? 1 : 0,
692-
bcast_mode);
693-
return false; // not supported
689+
// If the input long sets or unsets all lanes and Replicate is supported,
690+
// generate a MaskAll or Replicate instead.
691+
692+
// The "maskAll" API uses the corresponding integer types for floating-point data.
693+
BasicType maskall_bt = elem_bt == T_DOUBLE ? T_LONG : (elem_bt == T_FLOAT ? T_INT: elem_bt);
694+
if (!(opc == Op_VectorLongToMask &&
695+
VectorNode::is_maskall_type(bits_type, num_elem) &&
696+
arch_supports_vector(Op_Replicate, num_elem, maskall_bt, checkFlags, true /*has_scalar_args*/))) {
697+
log_if_needed(" ** not supported: arity=0 op=broadcast vlen=%d etype=%s ismask=%d bcast_mode=%d",
698+
num_elem, type2name(elem_bt),
699+
is_mask ? 1 : 0,
700+
bcast_mode);
701+
return false; // not supported
702+
}
694703
}
695704

696705
Node* broadcast = nullptr;

src/hotspot/share/opto/vectornode.cpp

Lines changed: 107 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,16 @@ bool VectorNode::implemented(int opc, uint vlen, BasicType bt) {
434434
return false;
435435
}
436436

437+
bool VectorNode::is_maskall_type(const TypeLong* type, int vlen) {
438+
assert(type != nullptr, "type must not be null");
439+
if (!type->is_con()) {
440+
return false;
441+
}
442+
long mask = (-1ULL >> (64 - vlen));
443+
long bit = type->get_con() & mask;
444+
return bit == 0 || bit == mask;
445+
}
446+
437447
bool VectorNode::is_muladds2i(const Node* n) {
438448
return n->Opcode() == Op_MulAddS2I;
439449
}
@@ -1503,6 +1513,45 @@ Node* ReductionNode::Ideal(PhaseGVN* phase, bool can_reshape) {
15031513
return nullptr;
15041514
}
15051515

1516+
// Convert fromLong to maskAll if the input sets or unsets all lanes.
1517+
Node* convertFromLongToMaskAll(PhaseGVN* phase, const TypeLong* bits_type, bool is_mask, const TypeVect* vt) {
1518+
uint vlen = vt->length();
1519+
BasicType bt = vt->element_basic_type();
1520+
// The "maskAll" API uses the corresponding integer types for floating-point data.
1521+
BasicType maskall_bt = (bt == T_FLOAT) ? T_INT : (bt == T_DOUBLE) ? T_LONG : bt;
1522+
1523+
if (VectorNode::is_maskall_type(bits_type, vlen) &&
1524+
Matcher::match_rule_supported_vector(Op_Replicate, vlen, maskall_bt)) {
1525+
Node* con = nullptr;
1526+
jlong con_value = bits_type->get_con() == 0L ? 0L : -1L;
1527+
if (maskall_bt == T_LONG) {
1528+
con = phase->longcon(con_value);
1529+
} else {
1530+
con = phase->intcon(con_value);
1531+
}
1532+
Node* res = VectorNode::scalar2vector(con, vlen, maskall_bt, is_mask);
1533+
// Convert back to the original floating-point data type.
1534+
if (is_floating_point_type(bt)) {
1535+
res = new VectorMaskCastNode(phase->transform(res), vt);
1536+
}
1537+
return res;
1538+
}
1539+
return nullptr;
1540+
}
1541+
1542+
Node* VectorLoadMaskNode::Ideal(PhaseGVN* phase, bool can_reshape) {
1543+
// VectorLoadMask(VectorLongToMask(-1/0)) => Replicate(-1/0)
1544+
if (in(1)->Opcode() == Op_VectorLongToMask) {
1545+
const TypeVect* vt = bottom_type()->is_vect();
1546+
Node* res = convertFromLongToMaskAll(phase, in(1)->in(1)->bottom_type()->isa_long(), false, vt);
1547+
if (res != nullptr) {
1548+
return res;
1549+
}
1550+
}
1551+
1552+
return VectorNode::Ideal(phase, can_reshape);
1553+
}
1554+
15061555
Node* VectorLoadMaskNode::Identity(PhaseGVN* phase) {
15071556
BasicType out_bt = type()->is_vect()->element_basic_type();
15081557
if (!Matcher::has_predicated_vectors() && out_bt == T_BOOLEAN) {
@@ -1918,6 +1967,45 @@ Node* VectorMaskOpNode::Ideal(PhaseGVN* phase, bool can_reshape) {
19181967
return nullptr;
19191968
}
19201969

1970+
Node* VectorMaskCastNode::Identity(PhaseGVN* phase) {
1971+
Node* in1 = in(1);
1972+
// VectorMaskCast (VectorMaskCast x) => x
1973+
if (in1->Opcode() == Op_VectorMaskCast &&
1974+
vect_type()->eq(in1->in(1)->bottom_type())) {
1975+
return in1->in(1);
1976+
}
1977+
return this;
1978+
}
1979+
1980+
// This function does the following optimization:
1981+
// VectorMaskToLong(MaskAll(l)) => (l & (-1ULL >> (64 - vlen)))
1982+
// VectorMaskToLong(VectorStoreMask(Replicate(l))) => (l & (-1ULL >> (64 - vlen)))
1983+
// l is -1 or 0.
1984+
Node* VectorMaskToLongNode::Ideal_MaskAll(PhaseGVN* phase) {
1985+
Node* in1 = in(1);
1986+
// VectorMaskToLong follows a VectorStoreMask if predicate is not supported.
1987+
if (in1->Opcode() == Op_VectorStoreMask) {
1988+
assert(!in1->in(1)->bottom_type()->isa_vectmask(), "sanity");
1989+
in1 = in1->in(1);
1990+
}
1991+
if (VectorNode::is_all_ones_vector(in1)) {
1992+
int vlen = in1->bottom_type()->is_vect()->length();
1993+
return new ConLNode(TypeLong::make(-1ULL >> (64 - vlen)));
1994+
}
1995+
if (VectorNode::is_all_zeros_vector(in1)) {
1996+
return new ConLNode(TypeLong::ZERO);
1997+
}
1998+
return nullptr;
1999+
}
2000+
2001+
Node* VectorMaskToLongNode::Ideal(PhaseGVN* phase, bool can_reshape) {
2002+
Node* res = Ideal_MaskAll(phase);
2003+
if (res != nullptr) {
2004+
return res;
2005+
}
2006+
return VectorMaskOpNode::Ideal(phase, can_reshape);
2007+
}
2008+
19212009
Node* VectorMaskToLongNode::Identity(PhaseGVN* phase) {
19222010
if (in(1)->Opcode() == Op_VectorLongToMask) {
19232011
return in(1)->in(1);
@@ -1927,28 +2015,41 @@ Node* VectorMaskToLongNode::Identity(PhaseGVN* phase) {
19272015

19282016
Node* VectorLongToMaskNode::Ideal(PhaseGVN* phase, bool can_reshape) {
19292017
const TypeVect* dst_type = bottom_type()->is_vect();
2018+
uint vlen = dst_type->length();
2019+
const TypeVectMask* is_mask = dst_type->isa_vectmask();
2020+
19302021
if (in(1)->Opcode() == Op_AndL &&
19312022
in(1)->in(1)->Opcode() == Op_VectorMaskToLong &&
19322023
in(1)->in(2)->bottom_type()->isa_long() &&
19332024
in(1)->in(2)->bottom_type()->is_long()->is_con() &&
1934-
in(1)->in(2)->bottom_type()->is_long()->get_con() == ((1L << dst_type->length()) - 1)) {
2025+
in(1)->in(2)->bottom_type()->is_long()->get_con() == ((1L << vlen) - 1)) {
19352026
// Different src/dst mask length represents a re-interpretation operation,
19362027
// we can however generate a mask casting operation if length matches.
19372028
Node* src = in(1)->in(1)->in(1);
1938-
if (dst_type->isa_vectmask() == nullptr) {
2029+
if (is_mask == nullptr) {
19392030
if (src->Opcode() != Op_VectorStoreMask) {
19402031
return nullptr;
19412032
}
19422033
src = src->in(1);
19432034
}
19442035
const TypeVect* src_type = src->bottom_type()->is_vect();
1945-
if (src_type->length() == dst_type->length() &&
1946-
((src_type->isa_vectmask() == nullptr && dst_type->isa_vectmask() == nullptr) ||
1947-
(src_type->isa_vectmask() && dst_type->isa_vectmask()))) {
2036+
if (src_type->length() == vlen &&
2037+
((src_type->isa_vectmask() == nullptr && is_mask == nullptr) ||
2038+
(src_type->isa_vectmask() && is_mask))) {
19482039
return new VectorMaskCastNode(src, dst_type);
19492040
}
19502041
}
1951-
return nullptr;
2042+
2043+
// VectorLongToMask(-1/0) => MaskAll(-1/0)
2044+
const TypeLong* bits_type = in(1)->bottom_type()->isa_long();
2045+
if (bits_type && is_mask) {
2046+
Node* res = convertFromLongToMaskAll(phase, bits_type, true, dst_type);
2047+
if (res != nullptr) {
2048+
return res;
2049+
}
2050+
}
2051+
2052+
return VectorNode::Ideal(phase, can_reshape);
19522053
}
19532054

19542055
Node* FmaVNode::Ideal(PhaseGVN* phase, bool can_reshape) {

src/hotspot/share/opto/vectornode.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ class VectorNode : public TypeNode {
104104
static bool implemented(int opc, uint vlen, BasicType bt);
105105
static bool is_shift(Node* n);
106106
static bool is_vshift_cnt(Node* n);
107+
static bool is_maskall_type(const TypeLong* type, int vlen);
107108
static bool is_muladds2i(const Node* n);
108109
static bool is_roundopD(Node* n);
109110
static bool is_scalar_rotate(Node* n);
@@ -1383,6 +1384,8 @@ class VectorMaskToLongNode : public VectorMaskOpNode {
13831384
VectorMaskToLongNode(Node* mask, const Type* ty):
13841385
VectorMaskOpNode(mask, ty, Op_VectorMaskToLong) {}
13851386
virtual int Opcode() const;
1387+
Node* Ideal(PhaseGVN* phase, bool can_reshape);
1388+
Node* Ideal_MaskAll(PhaseGVN* phase);
13861389
virtual uint ideal_reg() const { return Op_RegL; }
13871390
virtual Node* Identity(PhaseGVN* phase);
13881391
};
@@ -1776,6 +1779,7 @@ class VectorLoadMaskNode : public VectorNode {
17761779

17771780
virtual int Opcode() const;
17781781
virtual Node* Identity(PhaseGVN* phase);
1782+
Node* Ideal(PhaseGVN* phase, bool can_reshape);
17791783
};
17801784

17811785
class VectorStoreMaskNode : public VectorNode {
@@ -1795,6 +1799,7 @@ class VectorMaskCastNode : public VectorNode {
17951799
const TypeVect* in_vt = in->bottom_type()->is_vect();
17961800
assert(in_vt->length() == vt->length(), "vector length must match");
17971801
}
1802+
Node* Identity(PhaseGVN* phase);
17981803
virtual int Opcode() const;
17991804
};
18001805

test/hotspot/jtreg/compiler/lib/ir_framework/IRNode.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1402,6 +1402,21 @@ public class IRNode {
14021402
vectorNode(UMAX_VL, "UMaxV", TYPE_LONG);
14031403
}
14041404

1405+
public static final String MASK_ALL = PREFIX + "MASK_ALL" + POSTFIX;
1406+
static {
1407+
beforeMatchingNameRegex(MASK_ALL, "MaskAll");
1408+
}
1409+
1410+
public static final String VECTOR_LONG_TO_MASK = PREFIX + "VECTOR_LONG_TO_MASK" + POSTFIX;
1411+
static {
1412+
beforeMatchingNameRegex(VECTOR_LONG_TO_MASK, "VectorLongToMask");
1413+
}
1414+
1415+
public static final String VECTOR_MASK_TO_LONG = PREFIX + "VECTOR_MASK_TO_LONG" + POSTFIX;
1416+
static {
1417+
beforeMatchingNameRegex(VECTOR_MASK_TO_LONG, "VectorMaskToLong");
1418+
}
1419+
14051420
// Can only be used if avx512_vnni is available.
14061421
public static final String MUL_ADD_VS2VI_VNNI = PREFIX + "MUL_ADD_VS2VI_VNNI" + POSTFIX;
14071422
static {
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
/*
2+
* Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4+
*
5+
* This code is free software; you can redistribute it and/or modify it
6+
* under the terms of the GNU General Public License version 2 only, as
7+
* published by the Free Software Foundation.
8+
*
9+
* This code is distributed in the hope that it will be useful, but WITHOUT
10+
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11+
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12+
* version 2 for more details (a copy is included in the LICENSE file that
13+
* accompanied this code).
14+
*
15+
* You should have received a copy of the GNU General Public License version
16+
* 2 along with this work; if not, write to the Free Software Foundation,
17+
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18+
*
19+
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20+
* or visit www.oracle.com if you need additional information or have any
21+
* questions.
22+
*/
23+
24+
/*
25+
* @test
26+
* @bug 8356760
27+
* @library /test/lib /
28+
* @summary Optimize VectorMask.fromLong for all-true/all-false cases
29+
* @modules jdk.incubator.vector
30+
*
31+
* @run driver compiler.vectorapi.VectorMaskCastIdentityTest
32+
*/
33+
34+
package compiler.vectorapi;
35+
36+
import compiler.lib.ir_framework.*;
37+
import java.util.Random;
38+
import jdk.incubator.vector.*;
39+
import jdk.test.lib.Asserts;
40+
import jdk.test.lib.Utils;
41+
42+
public class VectorMaskCastIdentityTest {
43+
private static final boolean[] mr = new boolean[128]; // 128 is large enough
44+
private static final Random rd = Utils.getRandomInstance();
45+
static {
46+
for (int i = 0; i < mr.length; i++) {
47+
mr[i] = rd.nextBoolean();
48+
}
49+
}
50+
51+
@Test
52+
@IR(counts = { IRNode.VECTOR_MASK_CAST, "= 2" }, applyIfCPUFeatureOr = {"asimd", "true"})
53+
public static int testTwoCastToDifferentType() {
54+
// The types before and after the two casts are not the same, so the cast cannot be eliminated.
55+
VectorMask<Float> mFloat64 = VectorMask.fromArray(FloatVector.SPECIES_64, mr, 0);
56+
VectorMask<Double> mDouble128 = mFloat64.cast(DoubleVector.SPECIES_128);
57+
VectorMask<Integer> mInt64 = mDouble128.cast(IntVector.SPECIES_64);
58+
return mInt64.trueCount();
59+
}
60+
61+
@Run(test = "testTwoCastToDifferentType")
62+
public static void testTwoCastToDifferentType_runner() {
63+
int count = testTwoCastToDifferentType();
64+
VectorMask<Float> mFloat64 = VectorMask.fromArray(FloatVector.SPECIES_64, mr, 0);
65+
Asserts.assertEquals(count, mFloat64.trueCount());
66+
}
67+
68+
@Test
69+
@IR(counts = { IRNode.VECTOR_MASK_CAST, "= 2" }, applyIfCPUFeatureOr = {"avx2", "true"})
70+
public static int testTwoCastToDifferentType2() {
71+
// The types before and after the two casts are not the same, so the cast cannot be eliminated.
72+
VectorMask<Integer> mInt128 = VectorMask.fromArray(IntVector.SPECIES_128, mr, 0);
73+
VectorMask<Double> mDouble256 = mInt128.cast(DoubleVector.SPECIES_256);
74+
VectorMask<Short> mShort64 = mDouble256.cast(ShortVector.SPECIES_64);
75+
return mShort64.trueCount();
76+
}
77+
78+
@Run(test = "testTwoCastToDifferentType2")
79+
public static void testTwoCastToDifferentType2_runner() {
80+
int count = testTwoCastToDifferentType2();
81+
VectorMask<Integer> mInt128 = VectorMask.fromArray(IntVector.SPECIES_128, mr, 0);
82+
Asserts.assertEquals(count, mInt128.trueCount());
83+
}
84+
85+
@Test
86+
@IR(counts = { IRNode.VECTOR_MASK_CAST, "= 0" }, applyIfCPUFeatureOr = {"avx2", "true", "asimd", "true"})
87+
public static int testTwoCastToSameType() {
88+
// The types before and after the two casts are the same, so the cast will be eliminated.
89+
VectorMask<Integer> mInt128 = VectorMask.fromArray(IntVector.SPECIES_128, mr, 0);
90+
VectorMask<Float> mFloat128 = mInt128.cast(FloatVector.SPECIES_128);
91+
VectorMask<Integer> mInt128_2 = mFloat128.cast(IntVector.SPECIES_128);
92+
return mInt128_2.trueCount();
93+
}
94+
95+
@Run(test = "testTwoCastToSameType")
96+
public static void testTwoCastToSameType_runner() {
97+
int count = testTwoCastToSameType();
98+
VectorMask<Integer> mInt128 = VectorMask.fromArray(IntVector.SPECIES_128, mr, 0);
99+
Asserts.assertEquals(count, mInt128.trueCount());
100+
}
101+
102+
@Test
103+
@IR(counts = { IRNode.VECTOR_MASK_CAST, "= 1" }, applyIfCPUFeatureOr = {"avx2", "true", "asimd", "true"})
104+
public static int testOneCastToDifferentType() {
105+
// The types before and after the only cast are different, the cast will not be eliminated.
106+
VectorMask<Float> mFloat128 = VectorMask.fromArray(FloatVector.SPECIES_128, mr, 0).not();
107+
VectorMask<Integer> mInt128 = mFloat128.cast(IntVector.SPECIES_128);
108+
return mInt128.trueCount();
109+
}
110+
111+
@Run(test = "testOneCastToDifferentType")
112+
public static void testOneCastToDifferentType_runner() {
113+
int count = testOneCastToDifferentType();
114+
VectorMask<Float> mInt128 = VectorMask.fromArray(FloatVector.SPECIES_128, mr, 0).not();
115+
Asserts.assertEquals(count, mInt128.trueCount());
116+
}
117+
118+
public static void main(String[] args) {
119+
TestFramework testFramework = new TestFramework();
120+
testFramework.setDefaultWarmup(10000)
121+
.addFlags("--add-modules=jdk.incubator.vector")
122+
.start();
123+
}
124+
}

0 commit comments

Comments
 (0)