@@ -60,6 +60,7 @@ def LinearAllReduce(name, x, weight, bias, group):
6060 param .hasResidual = False
6161 param .parallelType = infer_param .ParallelType .LINEAR_ALL_REDUCE
6262 param .commMode = infer_param .CommMode .COMM_MULTI_PROCESS
63+ param .commDomain = group if group is not None else ""
6364
6465 # LCCL has issues with multiple communication domains. By default,
6566 # in single-machine multi-card scenarios, LCCL is enabled.
@@ -75,7 +76,6 @@ def LinearAllReduce(name, x, weight, bias, group):
7576 param .backend = "lccl"
7677 else :
7778 param .backend = "hccl"
78- param .commDomain = group if group is not None else ""
7979 if rank_table_file is not None :
8080 param .rankTableFile = rank_table_file
8181
@@ -95,6 +95,7 @@ def AllReduce(name, x, reduce_type, group):
9595 param .rankRoot = 0
9696 param .allReduceType = reduce_type
9797 param .commMode = infer_param .CommMode .COMM_MULTI_PROCESS
98+ param .commDomain = group if group is not None else ""
9899
99100 # LCCL has issues with multiple communication domains. By default,
100101 # in single-machine multi-card scenarios, LCCL is enabled.
@@ -110,7 +111,6 @@ def AllReduce(name, x, reduce_type, group):
110111 param .backend = "lccl"
111112 else :
112113 param .backend = "hccl"
113- param .commDomain = group if group is not None else ""
114114 if rank_table_file is not None :
115115 param .rankTableFile = rank_table_file
116116
0 commit comments