@@ -533,36 +533,60 @@ def _quantize_and_scaled_conv3d(
533533 )
534534
535535 assert kernel_choice == "fbgemm" , "Only fbgemm kernel choice is supported currently"
536+ input_qdata = input_tensor .qdata
537+ weight_qdata = weight_tensor .qdata
538+
539+ is_input_channels_last = input_qdata .is_contiguous (
540+ memory_format = torch .channels_last_3d
541+ )
542+ is_weight_channels_last = weight_qdata .is_contiguous (
543+ memory_format = torch .channels_last_3d
544+ )
545+
546+ # Convert the input/weight to channels_last_3d memory_format here
547+ # so that the fbgemm conv kernel can be called; this is
548+ # a no-op if both the activation and the weight are already in
549+ # channels_last_3d memory_format
550+ input_qdata = input_qdata .contiguous (memory_format = torch .channels_last_3d )
551+ weight_qdata = weight_qdata .contiguous (memory_format = torch .channels_last_3d )
552+
536553 # move C_in to last dim
537554 # after permute: (N, D, H, W, C_in)
538- act_qdata = input_tensor . qdata .permute ([0 , 2 , 3 , 4 , 1 ])
555+ input_qdata = input_qdata .permute ([0 , 2 , 3 , 4 , 1 ])
539556
540557 # move C_in to last dim
541558 # after permute: (C_out, K1, K2, K3, C_in)
559+ weight_qdata = weight_qdata .permute ([0 , 2 , 3 , 4 , 1 ])
542560
543- weight_qdata = weight_tensor .qdata .permute ([0 , 2 , 3 , 4 , 1 ])
544-
545- assert act_qdata .is_contiguous () and weight_qdata .is_contiguous (), (
546- "Please make sure both activation and weights are in the `channels_last_3d` memory_format"
547- )
548-
549- act_scale = input_tensor .scale
561+ input_scale = input_tensor .scale
550562 weight_scale = weight_tensor .scale
551563 output = torch .ops .fbgemm .f8f8bf16_conv (
552- act_qdata ,
564+ input_qdata ,
553565 weight_qdata ,
554- act_scale * weight_scale ,
566+ input_scale * weight_scale ,
555567 padding ,
556568 stride ,
557569 dilation ,
558570 )
559571 # output shape after permute: N, C_out, D_out, H_out, W_out
560572 output = output .permute ([0 , 4 , 1 , 2 , 3 ])
573+
574+ # Align output memory_format semantics with the bfloat16 conv ops:
575+ # the output uses contiguous_format if neither the input nor the weight
576+ # was in channels_last_3d format; otherwise the output is already
577+ # in channels_last_3d format (as produced by the fbgemm kernel)
578+ if not (is_input_channels_last or is_weight_channels_last ):
579+ output = output .contiguous ()
561580 return output
562581
563582
564583@implements (aten .convolution .default )
565584def _ (func , types , args , kwargs ):
585+ """The memory_format semantics match the high precision counterparts:
586+ if either the input or the weight is in channels-last format
587+ the output will be in channels-last format, otherwise the output
588+ will be contiguous
589+ """
566590 (
567591 input_tensor ,
568592 weight_tensor ,
@@ -580,11 +604,6 @@ def _(func, types, args, kwargs):
580604 assert groups == 1 , f"Only 1 is supported for `groups`, got: { groups } "
581605
582606 if dim == 2 :
583- assert input_tensor .is_contiguous (
584- memory_format = torch .channels_last
585- ) and weight_tensor .qdata .is_contiguous (memory_format = torch .channels_last ), (
586- "Please make sure both activation and weights are in the `channels_last` memory_format"
587- )
588607 # (N, C, H, W) --> (N, C, 1, H, W)
589608 input_tensor = input_tensor .unsqueeze (2 )
590609 weight_tensor = weight_tensor .unsqueeze (2 )
@@ -606,11 +625,6 @@ def _(func, types, args, kwargs):
606625 res = res .squeeze (2 )
607626 return res
608627 else :
609- assert input_tensor .is_contiguous (
610- memory_format = torch .channels_last_3d
611- ) and weight_tensor .qdata .is_contiguous (memory_format = torch .channels_last_3d ), (
612- "Please make sure both activation and weights are in the `channels_last_3d` memory_format"
613- )
614628 assert tuple (output_padding ) == (0 , 0 , 0 ), (
615629 f"Only (0, 0, 0) is supported for `output_padding`, got: f{ output_padding } "
616630 )
@@ -626,6 +640,11 @@ def _(func, types, args, kwargs):
626640
627641@implements (aten .conv3d .default )
628642def _ (func , types , args , kwargs ):
643+ """The memory_format semantics match the high precision counterparts:
644+ if either the input or the weight is in channels_last_3d format
645+ the output will be in channels_last_3d format, otherwise the output
646+ will be contiguous
647+ """
629648 (
630649 input_tensor ,
631650 weight_tensor ,
@@ -635,23 +654,24 @@ def _(func, types, args, kwargs):
635654 dilation ,
636655 groups ,
637656 ) = fill_defaults (args , 7 , [None , [1 , 1 , 1 ], [0 , 0 , 0 ], [1 , 1 , 1 ], 1 ])
638- assert input_tensor .is_contiguous (
639- memory_format = torch .channels_last_3d
640- ) and weight_tensor .qdata .is_contiguous (memory_format = torch .channels_last_3d ), (
641- "Please make sure both activation and weights are in the `channels_last_3d` memory_format"
642- )
643- return _quantize_and_scaled_conv3d (
657+ conv3d_output = _quantize_and_scaled_conv3d (
644658 input_tensor ,
645659 weight_tensor ,
646660 bias ,
647661 stride ,
648662 padding ,
649663 dilation ,
650664 )
665+ return conv3d_output
651666
652667
653668@implements (aten .conv2d .default )
654669def _ (func , types , args , kwargs ):
670+ """The memory_format semantics match the high precision counterparts:
671+ i.e. if either the input or the weight is in channels_last format
672+ the output will be in channels_last format, otherwise the output
673+ will be contiguous
674+ """
655675 (
656676 input_tensor ,
657677 weight_tensor ,
@@ -662,20 +682,9 @@ def _(func, types, args, kwargs):
662682 groups ,
663683 ) = fill_defaults (args , 7 , [None , [1 , 1 ], [0 , 0 ], [1 , 1 ], 1 ])
664684 # (N, C, H, W) --> (N, C, 1, H, W)
665- # memory_format of both tensors should be torch.channels_last
666- # and it should be preserved with unsqueeze(2) (becoming torch.channels_last_3d)
667- assert input_tensor .is_contiguous (
668- memory_format = torch .channels_last
669- ) and weight_tensor .qdata .is_contiguous (memory_format = torch .channels_last ), (
670- "Please make sure both activation and weights are in the `channels_last` memory_format"
671- )
672685 input_tensor = input_tensor .unsqueeze (2 )
673686 weight_tensor = weight_tensor .unsqueeze (2 )
674687
675- assert input_tensor .is_contiguous (
676- memory_format = torch .channels_last_3d
677- ) and weight_tensor .qdata .is_contiguous (memory_format = torch .channels_last_3d )
678-
679688 padding = [0 , * padding ]
680689 stride = [1 , * stride ]
681690 dilation = [1 , * dilation ]
0 commit comments