Skip to content

Commit d55085f

Browse files
bythew3ijax authors
authored andcommitted
[Mosaic] Support tpu.concatenate along the tiling dims as long as the shapes are aligned to native tiling.
PiperOrigin-RevId: 574523694
1 parent cf65480 commit d55085f

File tree

3 files changed

+59
-10
lines changed

3 files changed

+59
-10
lines changed

jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1236,9 +1236,31 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op,
12361236
const VectorType res_ty = concatenate_op.getResult().getType();
12371237
const uint32_t dimension = concatenate_op.getDimension();
12381238
if (dimension - res_ty.getRank() >= -2) {
1239-
return op.emitOpError(
1240-
"Not implemented: Concatenation along the last two dimensions");
1239+
if (!layout.hasNaturalTopology(ctx.target_shape) ||
1240+
layout.offsets() != LayoutOffsets{0, 0}) {
1241+
return op.emitOpError(
1242+
"Only native tiling with offset (0, 0) is supported when "
1243+
"concatenation along tiling dims.");
1244+
}
1245+
// Check if shapes of src and res are aligned to native tiling.
1246+
auto check_aligned = [&](const VectorType &vty) {
1247+
return vty.getRank() >= 2 &&
1248+
*(vty.getShape().end() - 2) % *(layout.tiling().end() - 2) == 0 &&
1249+
*(vty.getShape().end() - 1) % *(layout.tiling().end() - 1) == 0;
1250+
};
1251+
bool is_aligned = check_aligned(res_ty);
1252+
int op_idx = 0;
1253+
while (is_aligned && op_idx < op.getNumOperands()) {
1254+
auto vty = dyn_cast<VectorType>(op.getOperand(op_idx++).getType());
1255+
is_aligned = check_aligned(vty);
1256+
}
1257+
if (!is_aligned) {
1258+
return op.emitOpError(
1259+
"Only aligned shapes are supported when concatenation along tiling "
1260+
"dims");
1261+
}
12411262
}
1263+
12421264
SmallVector<xla::Array<Value>> tiles;
12431265
tiles.reserve(concatenate_op->getNumOperands());
12441266
for (Value operand : concatenate_op.getOperands()) {

jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -615,14 +615,25 @@ class VectorLayoutInferer {
615615
}
616616

617617
LogicalResult infer(tpu::ConcatenateOp op) {
618-
TPU_CHECK_OP(op.getDimension() - op.getType().getRank() < -2,
619-
"Concatenation is not supported along the last two axes");
620618
TPU_CHECK_OP(!op.getSources().empty(),
621619
"Need at least one vector to concatenate");
622-
// Fix all the layouts to the layout of the first operand.
623-
// This might not be the best strategy, but it works.
624-
SmallVector<Layout> in_layouts(op.getNumOperands(),
625-
getLayout(op.getSources().front()));
620+
auto res_rank = op.getType().getRank();
621+
auto dimension = op.getDimension();
622+
TPU_CHECK_OP(0 <= dimension && dimension < res_rank,
623+
"Expect a valid concatenate dimension");
624+
if (res_rank == 1) {
625+
NYI("Support concatenation with 1D vectors");
626+
}
627+
auto res_ty = op.getResult().getType();
628+
int8_t bitwidth = res_ty.getElementTypeBitWidth();
629+
if (bitwidth != 32) {
630+
NYI("Support concatenation with non 32-bit data");
631+
}
632+
auto layout = (dimension >= res_rank - 2)
633+
? VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth),
634+
ImplicitDim::kNone)
635+
: getLayout(op.getSources().front());
636+
SmallVector<Layout> in_layouts(op->getNumOperands(), layout);
626637
setLayout(op, in_layouts, in_layouts.back());
627638
return success();
628639
}

jaxlib/mosaic/python/apply_vector_layout.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2093,8 +2093,24 @@ def _tpu_concatenate_rule(
20932093
raise NotImplementedError
20942094
res_ty = ir.VectorType(op.result.type)
20952095
dimension = ir.IntegerAttr(op.dimension).value
2096-
if dimension - res_ty.rank >= -2:
2097-
raise NotImplementedError("Concatenation along the last two dimensions")
2096+
2097+
if dimension >= res_ty.rank - 2:
2098+
if (not layout.has_natural_topology) or layout.offsets != (0, 0):
2099+
raise NotImplementedError(
2100+
"Only native tiling with offset (0, 0) is supported when"
2101+
" concatenation along tiling dims."
2102+
)
2103+
# Check if shapes of src and res are aligned to native tiling.
2104+
for vty in [res_ty] + [ir.VectorType(src.type) for src in op.operands]:
2105+
if (
2106+
vty.rank < 2
2107+
or vty.shape[-2] % layout.tiling[-2] != 0
2108+
or vty.shape[-1] % layout.tiling[-1] != 0
2109+
):
2110+
raise NotImplementedError(
2111+
"Only aligned shapes are supported when concatenation along tiling"
2112+
" dims."
2113+
)
20982114
tiles = [disassemble(layout, x) for x in op.operands]
20992115
res_tiles = np.concatenate(tiles, axis=dimension)
21002116
ctx.replace(op, assemble(res_ty, layout, res_tiles))

0 commit comments

Comments
 (0)