|
121 | 121 | layer_tp_plan = { |
122 | 122 | "attention_norm": SequenceParallel(), |
123 | 123 | "attention": PrepareModuleInput( |
124 | | - input_layouts=(Shard(1), None), |
125 | | - desired_input_layouts=(Replicate(), None), |
| 124 | + input_layouts=(Shard(1), Replicate()), |
| 125 | + desired_input_layouts=(Replicate(), Replicate()), |
126 | 126 | ), |
127 | | - "attention.wq": ColwiseParallel(), |
128 | | - "attention.wk": ColwiseParallel(), |
129 | | - "attention.wv": ColwiseParallel(), |
| 127 | + "attention.wq": ColwiseParallel(use_local_output=False), |
| 128 | + "attention.wk": ColwiseParallel(use_local_output=False), |
| 129 | + "attention.wv": ColwiseParallel(use_local_output=False), |
130 | 130 | "attention.wo": RowwiseParallel(output_layouts=Shard(1)), |
131 | 131 | "ffn_norm": SequenceParallel(), |
132 | 132 | "feed_forward": PrepareModuleInput( |
|
138 | 138 | "feed_forward.w3": ColwiseParallel(), |
139 | 139 | } |
140 | 140 |
|
141 | | - # Adjust attention module to use the local number of heads |
142 | | - attn_layer = transformer_block.attention |
143 | | - attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size() |
144 | | - attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size() |
145 | | - |
146 | 141 | # Custom parallelization plan for the model |
147 | 142 | parallelize_module( |
148 | 143 | module=transformer_block, |
|
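For context on why the manual `n_heads` adjustment above is dropped: with `use_local_output=False`, the q/k/v projections return DTensors whose reported shape is still the global one, so the reshape inside the attention forward can keep using the full head count. A minimal sketch of that reshape, assuming a standard Llama-style attention forward (the names below are illustrative, not taken from this diff):

    # Illustrative only -- `x`, `self.wq`, `self.n_heads`, `self.head_dim` stand in
    # for the attention module's own forward; they are not part of this change.
    bs, seqlen, _ = x.shape
    xq = self.wq(x)  # DTensor sharded over the head dimension (use_local_output=False)
    # The view is expressed against the DTensor's *global* shape, so the unmodified
    # n_heads works on every rank; DTensor derives the per-rank local shard itself.
    xq = xq.view(bs, seqlen, self.n_heads, self.head_dim)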