Commit 4f5bc7a

Align memory_format for conv2d/3d in Float8Tensor with hp Tensor (#3352)
Align memory_format for conv2d and conv3d in Float8Tensor with high precision Tensors

Summary: as titled, we want to make sure the output of `F.conv3d(input, weight, ...)` and `F.conv3d(input, fp8_weight, ...)` have the same memory_format.

Test Plan: python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_fp8_conv_variants
1 parent 7d5e2f6 commit 4f5bc7a
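
The docstrings added in float8_tensor.py below spell out the target semantics: if either the input or the weight is in a channels_last(_3d) memory_format, the conv output should be channels_last(_3d) too, otherwise it should be contiguous. As a reference, here is a minimal sketch (not part of the diff; shapes and dtype are chosen arbitrarily to mirror the bf16 baseline) that probes the high precision behavior the fp8 path is being aligned with:

import torch
import torch.nn.functional as F

if torch.cuda.is_available():
    inp = torch.randn(1, 16, 8, 32, 32, dtype=torch.bfloat16, device="cuda")
    weight = torch.randn(32, 16, 3, 3, 3, dtype=torch.bfloat16, device="cuda")
    for fmt_in, fmt_w in [
        (torch.contiguous_format, torch.contiguous_format),
        (torch.channels_last_3d, torch.contiguous_format),
        (torch.channels_last_3d, torch.channels_last_3d),
    ]:
        out = F.conv3d(inp.to(memory_format=fmt_in), weight.to(memory_format=fmt_w))
        # report whether the high precision output comes back in channels_last_3d layout
        print(fmt_in, fmt_w, out.is_contiguous(memory_format=torch.channels_last_3d))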

File tree

2 files changed: +96 -60 lines changed

test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 50 additions & 23 deletions
@@ -85,10 +85,6 @@ def __init__(
            dtype=dtype,
            device=device,
        )
-        if dim == 3:
-            self.conv = self.conv.to(memory_format=torch.channels_last_3d)
-        elif dim == 2:
-            self.conv = self.conv.to(memory_format=torch.channels_last)

    def forward(self, x):
        return self.conv(x)
@@ -340,41 +336,47 @@ def _test_fp8_matmul_model(
    @common_utils.parametrize("compile", [True, False])
    @common_utils.parametrize("inference_mode", [True, False])
    # test for 2D/3D conv
-    # Inputs are (N, C_in, C_out, (D, H, W) or
-    # (N, C_in, C_out, (H, W)
+    # Inputs are (N, C_in, C_out, (D, H, W), kernel_size or
+    # (N, C_in, C_out, (H, W), kernel_size
    @common_utils.parametrize(
        "sizes",
        [
-            (4, 16, 64, (32, 32, 32)),
-            (4, 16, 64, (32, 32)),
+            (1, 160, 320, (3, 194, 130), 3),
+            # Note: kernel_size can't be 1, otherwise
+            # the weight will be channels_last even though
+            # it's contiguous because of the value of
+            # stride
+            (1, 320, 640, (96, 64), 3),
        ],
    )
+    @common_utils.parametrize(
+        "is_input_channels_last",
+        [True, False],
+    )
+    @common_utils.parametrize(
+        "is_weight_channels_last",
+        [True, False],
+    )
    def test_fp8_conv_variants(
        self,
        dtype: torch.dtype,
        compile: bool,
        inference_mode: bool,
        sizes: Tuple,
+        is_input_channels_last: bool,
+        is_weight_channels_last: bool,
    ):
        torch.compiler.reset()
        granularity = PerTensor()
        kernel_preference = KernelPreference.AUTO

-        N, C_in, C_out, spatial_dims = sizes
+        N, C_in, C_out, spatial_dims, kernel_size = sizes
        dim = len(spatial_dims)
        convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
        assert dim in convs, f"Unsupported dim: {dim}"
        conv_class = convs[dim]

-        kernel_size = 3
-
-        # Note: this is channel last memory format
        input_tensor = torch.randn(N, C_in, *spatial_dims, dtype=dtype, device="cuda")
-        if dim == 3:
-            input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
-        else:
-            assert dim == 2
-            input_tensor = input_tensor.to(memory_format=torch.channels_last)

        model = ToyConvModel(
            dim,
@@ -387,6 +389,14 @@ def test_fp8_conv_variants(
            device="cuda",
        ).eval()

+        channels_last_memory_format = (
+            torch.channels_last_3d if dim == 3 else torch.channels_last
+        )
+        if is_input_channels_last:
+            input_tensor = input_tensor.to(memory_format=channels_last_memory_format)
+        if is_weight_channels_last:
+            model = model.to(memory_format=channels_last_memory_format)
+
        quantized_model = copy.deepcopy(model)

        config = Float8DynamicActivationFloat8WeightConfig(
@@ -406,6 +416,20 @@ def test_fp8_conv_variants(
        output_original = model(input_tensor)
        output_quantized = quantized_model(input_tensor)

+        # making sure quantized kernel produces tensor with memory_format
+        # that's aligned with bf16 kernel
+        is_bf16_output_channels_last = output_original.is_contiguous(
+            memory_format=channels_last_memory_format
+        )
+        is_quantized_output_channels_last = output_quantized.is_contiguous(
+            memory_format=channels_last_memory_format
+        )
+
+        assert is_bf16_output_channels_last == is_quantized_output_channels_last, (
+            "unexpected output strides for quantized model: "
+            f"{output_original.stride()} {output_quantized.stride()}"
+        )
+
        error = compute_error(output_original, output_quantized)
        assert compute_error(output_original, output_quantized) > 20, (
            f"Quantization error is too high got a SQNR of {error}"
@@ -452,13 +476,7 @@ def test_fp8_conv_skip_quant(

        kernel_size = 3

-        # Note: this is channel last memory format
        input_tensor = torch.randn(N, C_in, *spatial_dims, dtype=dtype, device="cuda")
-        if dim == 3:
-            input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
-        else:
-            input_tensor = input_tensor.to(memory_format=torch.channels_last)
-
        model = ToyConvModel(
            dim,
            C_in,
@@ -470,6 +488,13 @@ def test_fp8_conv_skip_quant(
            device="cuda",
        ).eval()

+        if dim == 3:
+            input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
+            model = model.to(memory_format=torch.channels_last_3d)
+        else:
+            input_tensor = input_tensor.to(memory_format=torch.channels_last)
+            model = model.to(memory_format=torch.channels_last)
+
        quantized_model = copy.deepcopy(model)

        config = Float8DynamicActivationFloat8WeightConfig(
@@ -932,6 +957,8 @@ def test_unsqueeze_conv2d_weight(self):
            device=device,
        ).eval()

+        model = model.to(memory_format=torch.channels_last)
+
        quantized_model = copy.deepcopy(model)

        config = Float8DynamicActivationFloat8WeightConfig(
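
One of the new test comments above notes that kernel_size can't be 1 in the 2D case. For context, this is the ambiguity it refers to (a sketch, not part of the diff): a contiguous 1x1 conv2d weight also satisfies the channels_last stride pattern, so the test could not tell the two memory_formats apart for such a weight.

import torch

# hypothetical 1x1 conv2d weight with the C_out/C_in sizes used in the test above
w = torch.randn(640, 320, 1, 1)
print(w.is_contiguous())                                   # plain contiguous layout
print(w.is_contiguous(memory_format=torch.channels_last))  # reported as channels_last too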

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 46 additions & 37 deletions
@@ -533,36 +533,60 @@ def _quantize_and_scaled_conv3d(
    )

    assert kernel_choice == "fbgemm", "Only fbgemm kernel choice is supported currently"
+    input_qdata = input_tensor.qdata
+    weight_qdata = weight_tensor.qdata
+
+    is_input_channels_last = input_qdata.is_contiguous(
+        memory_format=torch.channels_last_3d
+    )
+    is_weight_channels_last = weight_qdata.is_contiguous(
+        memory_format=torch.channels_last_3d
+    )
+
+    # convert the input/weight to channels_last_3d memory_format here
+    # to make sure we can call the fbgemm conv
+    # kernel, it should be a no-op if both activation and weight are in
+    # channels_last_3d memory_format
+    input_qdata = input_qdata.contiguous(memory_format=torch.channels_last_3d)
+    weight_qdata = weight_qdata.contiguous(memory_format=torch.channels_last_3d)
+
    # move C_in to last dim
    # after permute: (N, D, H, W, C_in)
-    act_qdata = input_tensor.qdata.permute([0, 2, 3, 4, 1])
+    input_qdata = input_qdata.permute([0, 2, 3, 4, 1])

    # move C_in to last dim
    # after permute: (C_out, K1, K2, K3, C_in)
+    weight_qdata = weight_qdata.permute([0, 2, 3, 4, 1])

-    weight_qdata = weight_tensor.qdata.permute([0, 2, 3, 4, 1])
-
-    assert act_qdata.is_contiguous() and weight_qdata.is_contiguous(), (
-        "Please make sure both activation and weights are in the `channels_last_3d` memory_format"
-    )
-
-    act_scale = input_tensor.scale
+    input_scale = input_tensor.scale
    weight_scale = weight_tensor.scale
    output = torch.ops.fbgemm.f8f8bf16_conv(
-        act_qdata,
+        input_qdata,
        weight_qdata,
-        act_scale * weight_scale,
+        input_scale * weight_scale,
        padding,
        stride,
        dilation,
    )
    # output shape after permute: N, C_out, D_out, H_out, W_out
    output = output.permute([0, 4, 1, 2, 3])
+
+    # aligning the semantics with bfloat16 conv ops, the
+    # output should use contiguous_format if none of the input/weight
+    # are in channels_last format, otherwise, the output is already
+    # in channels_last format (from fbgemm kernel)
+    if not (is_input_channels_last or is_weight_channels_last):
+        output = output.contiguous()
    return output


@implements(aten.convolution.default)
def _(func, types, args, kwargs):
+    """The semantics of memory_format will match high precision counterparts
+    i.e. if any of input or weight are in channels_last_3d format
+    the output will be in channels_last_3d format, otherwise the output
+    will be contiguous
+    """
    (
        input_tensor,
        weight_tensor,
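
A note on the permutes in the hunk above (a sketch, not part of the diff): based on the comments and permutes in this change, the fbgemm kernel appears to consume activations and weights in an (N, D, H, W, C) layout, and once a tensor is in channels_last_3d memory_format, permuting (N, C, D, H, W) to (N, D, H, W, C) is a stride-only view that is already dense, so no extra copy is needed before the kernel call.

import torch

x = torch.randint(0, 255, (2, 16, 4, 8, 8), dtype=torch.uint8)  # stand-in for fp8 qdata
x = x.contiguous(memory_format=torch.channels_last_3d)          # (N, C, D, H, W), C has stride 1
ndhwc = x.permute([0, 2, 3, 4, 1])                               # logical (N, D, H, W, C), no copy
print(ndhwc.is_contiguous())                                     # dense row-major view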
@@ -580,11 +604,6 @@ def _(func, types, args, kwargs):
    assert groups == 1, f"Only 1 is supported for `groups`, got: {groups}"

    if dim == 2:
-        assert input_tensor.is_contiguous(
-            memory_format=torch.channels_last
-        ) and weight_tensor.qdata.is_contiguous(memory_format=torch.channels_last), (
-            "Please make sure both activation and weights are in the `channels_last` memory_format"
-        )
        # (N, C, H, W) --> (N, C, 1, H, W)
        input_tensor = input_tensor.unsqueeze(2)
        weight_tensor = weight_tensor.unsqueeze(2)
@@ -606,11 +625,6 @@ def _(func, types, args, kwargs):
        res = res.squeeze(2)
        return res
    else:
-        assert input_tensor.is_contiguous(
-            memory_format=torch.channels_last_3d
-        ) and weight_tensor.qdata.is_contiguous(memory_format=torch.channels_last_3d), (
-            "Please make sure both activation and weights are in the `channels_last_3d` memory_format"
-        )
        assert tuple(output_padding) == (0, 0, 0), (
            f"Only (0, 0, 0) is supported for `output_padding`, got: f{output_padding}"
        )
@@ -626,6 +640,11 @@ def _(func, types, args, kwargs):

@implements(aten.conv3d.default)
def _(func, types, args, kwargs):
+    """The semantics of memory_format will match high precision counterparts
+    i.e. if any of input or weight are in channels_last_3d format
+    the output will be in channels_last_3d format, otherwise the output
+    will be contiguous
+    """
    (
        input_tensor,
        weight_tensor,
@@ -635,23 +654,24 @@ def _(func, types, args, kwargs):
        dilation,
        groups,
    ) = fill_defaults(args, 7, [None, [1, 1, 1], [0, 0, 0], [1, 1, 1], 1])
-    assert input_tensor.is_contiguous(
-        memory_format=torch.channels_last_3d
-    ) and weight_tensor.qdata.is_contiguous(memory_format=torch.channels_last_3d), (
-        "Please make sure both activation and weights are in the `channels_last_3d` memory_format"
-    )
-    return _quantize_and_scaled_conv3d(
+    conv3d_output = _quantize_and_scaled_conv3d(
        input_tensor,
        weight_tensor,
        bias,
        stride,
        padding,
        dilation,
    )
+    return conv3d_output


@implements(aten.conv2d.default)
def _(func, types, args, kwargs):
+    """The semantics of memory_format will match high precision counterparts
+    i.e. if any of input or weight are in channels_last_3d format
+    the output will be in channels_last_3d format, otherwise the output
+    will be contiguous
+    """
    (
        input_tensor,
        weight_tensor,
@@ -662,20 +682,9 @@ def _(func, types, args, kwargs):
        groups,
    ) = fill_defaults(args, 7, [None, [1, 1], [0, 0], [1, 1], 1])
    # (N, C, H, W) --> (N, C, 1, H, W)
-    # memory_format of both tensors should be torch.channels_last
-    # and it should be preserved with unsqueeze(2) (becoming torch.channels_last_3d)
-    assert input_tensor.is_contiguous(
-        memory_format=torch.channels_last
-    ) and weight_tensor.qdata.is_contiguous(memory_format=torch.channels_last), (
-        "Please make sure both activation and weights are in the `channels_last` memory_format"
-    )
    input_tensor = input_tensor.unsqueeze(2)
    weight_tensor = weight_tensor.unsqueeze(2)

-    assert input_tensor.is_contiguous(
-        memory_format=torch.channels_last_3d
-    ) and weight_tensor.qdata.is_contiguous(memory_format=torch.channels_last_3d)
-
    padding = [0, *padding]
    stride = [1, *stride]
    dilation = [1, *dilation]
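
The conv2d handler above reuses the conv3d path by inserting a dummy depth dimension. The asserts it removes were checking a property that holds anyway, as the removed comment itself noted (a sketch, not part of the diff): a channels_last 4D tensor stays channels_last_3d-contiguous after unsqueeze(2), so the 2D case inherits the same memory_format handling as the 3D case.

import torch

x = torch.randn(2, 16, 32, 32).to(memory_format=torch.channels_last)  # (N, C, H, W)
x5d = x.unsqueeze(2)                                                   # (N, C, 1, H, W)
print(x5d.is_contiguous(memory_format=torch.channels_last_3d))        # stays channels last in 3d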
