diff --git a/csrc/quickreduce/base.h b/csrc/quickreduce/base.h
index 8138eda82875..cf48cc770837 100644
--- a/csrc/quickreduce/base.h
+++ b/csrc/quickreduce/base.h
@@ -51,9 +51,6 @@ static constexpr int kWavefront = 64;
 // 256 thread, 4 wavefronts.
 static dim3 constexpr kBlockTwoShot = {kWavefront, kBlockSize / kWavefront, 1};
 
-static constexpr int kThreadsOneShot = 512;
-static dim3 constexpr kBlockOneShot = {kThreadsOneShot, 1, 1};
-
 // Number of threads in a group for quantization
 // It corresponds to 32 F16 elements in quantization block
 static constexpr int kThreadGroupSize = 8;
diff --git a/csrc/quickreduce/quick_reduce.h b/csrc/quickreduce/quick_reduce.h
index 0a345772bd3c..1e674f07f0e0 100644
--- a/csrc/quickreduce/quick_reduce.h
+++ b/csrc/quickreduce/quick_reduce.h
@@ -18,23 +18,6 @@ namespace quickreduce {
 using fptr_t = int64_t;
 static_assert(sizeof(void*) == sizeof(fptr_t));
 
-static constexpr unsigned int kOneShotAllreduceMaxElemsWorldSize2 = 8192 * 12;
-static constexpr unsigned int kOneShotAllreduceMaxElemsWorldSize4 = 8192 * 8;
-static constexpr unsigned int kOneShotAllreduceMaxElemsWorldSize8 = 8192 * 4;
-static constexpr unsigned int kOneShotAllreduceMaxSize =
-    std::max(kOneShotAllreduceMaxElemsWorldSize2 * 2,
-             std::max(kOneShotAllreduceMaxElemsWorldSize4 * 4,
-                      kOneShotAllreduceMaxElemsWorldSize8 * 8)) *
-    sizeof(half);
-
-template <typename AllReduceKernel, typename T>
-__global__ __quickreduce_launch_bounds_one_shot__ static void
-allreduce_prototype_oneshot(T const* A, T* B, uint32_t N, int rank,
-                            uint8_t** dbuffer_list, uint32_t data_offset,
-                            uint32_t flag_color) {
-  AllReduceKernel::run(A, B, N, rank, dbuffer_list, data_offset, flag_color);
-}
-
 template <typename AllReduceKernel, typename T>
 __global__ __quickreduce_launch_bounds_two_shot__ static void
 allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, int num_blocks,
@@ -50,24 +33,6 @@ allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, int num_blocks,
   }
 }
 
-#define ONESHOT_DISPATCH()                                                    \
-  if (world_size == 2) {                                                      \
-    using AllReduceKernel = AllReduceOneshot<T, 2>;                           \
-    hipLaunchKernelGGL((allreduce_prototype_oneshot<AllReduceKernel, T>),     \
-                       dim3(grid), dim3(kBlockOneShot), 0, stream, A, B, N,   \
-                       rank, dbuffer_list, data_offset, flag_color);          \
-  } else if (world_size == 4) {                                               \
-    using AllReduceKernel = AllReduceOneshot<T, 4>;                           \
-    hipLaunchKernelGGL((allreduce_prototype_oneshot<AllReduceKernel, T>),     \
-                       dim3(grid), dim3(kBlockOneShot), 0, stream, A, B, N,   \
-                       rank, dbuffer_list, data_offset, flag_color);          \
-  } else if (world_size == 8) {                                               \
-    using AllReduceKernel = AllReduceOneshot<T, 8>;                           \
-    hipLaunchKernelGGL((allreduce_prototype_oneshot<AllReduceKernel, T>),     \
-                       dim3(grid), dim3(kBlockOneShot), 0, stream, A, B, N,   \
-                       rank, dbuffer_list, data_offset, flag_color);          \
-  }
-
 #define TWOSHOT_DISPATCH(__codec)                                             \
   if (world_size == 2) {                                                      \
     using LineCodec = __codec<T, 2>;                                          \
@@ -132,8 +97,7 @@ struct DeviceComms {
 
     // Allocate buffer size for worst case: Twoshot FP16 2-stage buffer.
     uint32_t flags_buffer_size = 2 * world_size * kMaxTiles * sizeof(int);
-    static constexpr int64_t data_buffer_size = std::max(
-        2 * kMaxProblemSize, static_cast<int64_t>(kOneShotAllreduceMaxSize));
+    static constexpr int64_t data_buffer_size = 2 * kMaxProblemSize;
     int64_t total_buffer_size = flags_buffer_size + data_buffer_size;
     data_offset = flags_buffer_size;
     HIP_CHECK(hipExtMallocWithFlags((void**)&dbuffer, total_buffer_size,
@@ -204,33 +168,22 @@ struct DeviceComms {
 
     // Configuration.
     uint32_t msg_size = N * sizeof(T);
-    bool use_one_shot_allreduce =
-        (world_size == 2 and N <= kOneShotAllreduceMaxElemsWorldSize2) or
-        (world_size == 4 and N <= kOneShotAllreduceMaxElemsWorldSize4) or
-        (world_size == 8 and N <= kOneShotAllreduceMaxElemsWorldSize8);
-    if (use_one_shot_allreduce) {
-      // Each thread processes blocks out of 4 elements
-      uint64_t num_blocks = divceil(N, (4 * kThreadsOneShot));
-      uint64_t grid = min(kMaxNumBlocks, num_blocks);
-      ONESHOT_DISPATCH()
-    } else {
-      uint64_t num_blocks = divceil(msg_size, kTileSize);
-      uint64_t grid = min(kMaxNumBlocks, num_blocks);
-      auto quant_level_ = static_cast<QuickReduceQuantLevel>(quant_level);
-      switch (quant_level_) {
-        case QuickReduceQuantLevel::INT8:
-          TWOSHOT_DISPATCH(CodecQ8)
-          break;
-        case QuickReduceQuantLevel::INT6:
-          TWOSHOT_DISPATCH(CodecQ6)
-          break;
-        case QuickReduceQuantLevel::INT4:
-          TWOSHOT_DISPATCH(CodecQ4)
-          break;
-        default:
-          TWOSHOT_DISPATCH(CodecFP)
-          break;
-      }
+    uint64_t num_blocks = divceil(msg_size, kTileSize);
+    uint64_t grid = min(kMaxNumBlocks, num_blocks);
+    auto quant_level_ = static_cast<QuickReduceQuantLevel>(quant_level);
+    switch (quant_level_) {
+      case QuickReduceQuantLevel::INT8:
+        TWOSHOT_DISPATCH(CodecQ8)
+        break;
+      case QuickReduceQuantLevel::INT6:
+        TWOSHOT_DISPATCH(CodecQ6)
+        break;
+      case QuickReduceQuantLevel::INT4:
+        TWOSHOT_DISPATCH(CodecQ4)
+        break;
+      default:
+        TWOSHOT_DISPATCH(CodecFP)
+        break;
     }
     HIP_CHECK(cudaGetLastError());
     // Rotate the flag color.
diff --git a/csrc/quickreduce/quick_reduce_impl.cuh b/csrc/quickreduce/quick_reduce_impl.cuh
index 89a07629d713..92be8ab8f127 100644
--- a/csrc/quickreduce/quick_reduce_impl.cuh
+++ b/csrc/quickreduce/quick_reduce_impl.cuh
@@ -677,108 +677,4 @@ struct AllReduceTwoshot {
   }
 };
 
-// Oneshot AllReduce
-template <typename T, int world_size>
-struct AllReduceOneshot {
-  static_assert(sizeof(T) == 2);
-
-  __device__ static void run(
-      T const* __restrict__ A,             // input
-      T* __restrict__ B,                   // output
-      uint32_t const N,                    // number of elements
-      uint32_t const rank,                 // rank index
-      uint8_t** __restrict__ buffer_list,  // communication buffers
-      long const data_offset,  // offset to start of the data buffer
-      uint32_t flag_color) {
-    BufferResource src_buffer(const_cast<T*>(A), N * sizeof(T));
-    BufferResource dst_buffer(B, N * sizeof(T));
-
-    uint8_t* rank_buffer = buffer_list[rank];
-
-    const int block_size = blockDim.x;
-    const int thread = threadIdx.x;
-    const int block = blockIdx.x;
-    const uint32_t problem_size = (N + 3) / 4;
-
-    int32x4_t tA, tB;
-    long grid = gridDim.x;
-    long data_stride = grid * block_size * sizeof(int32x4_t);
-    long comm_flags0_offset = block * (world_size * sizeof(int));
-    long comm_flags1_offset =
-        comm_flags0_offset + grid * (world_size * sizeof(int));
-
-    for (int idx = block * block_size + thread; idx < problem_size;
-         idx += grid * block_size) {
-      // load values
-      tA = buffer_load_dwordx4(src_buffer.descriptor, idx * sizeof(int32x4_t),
-                               0, 0);
-
-      // Write rank data into this rank segment of every rank's communication
-      // buffer.
-#pragma unroll
-      for (int r = 0; r < world_size; r++) {
-        int32x4_t* send_buffer = reinterpret_cast<int32x4_t*>(
-            buffer_list[r] + data_offset + rank * data_stride +
-            idx * sizeof(int32x4_t));
-        __builtin_nontemporal_store(tA, send_buffer);
-      }
-    }
-
-    __syncthreads();
-    if (thread < world_size) {
-      int r = thread;
-      int* peer_flag_ptr = reinterpret_cast<int*>(
-          buffer_list[r] + comm_flags0_offset + rank * sizeof(int));
-      __atomic_store_n(peer_flag_ptr, flag_color, __ATOMIC_RELEASE);
-      int* self_flag_ptr = reinterpret_cast<int*>(
-          rank_buffer + comm_flags0_offset + r * sizeof(int));
-
-      // Wait for the flags to be set.
-      while (__atomic_load_n(self_flag_ptr, __ATOMIC_ACQUIRE) != flag_color) {
-      }
-    }
-    __syncthreads();
-
-    for (int idx = block * block_size + thread; idx < problem_size;
-         idx += grid * block_size) {
-      {
-        int r = 0;
-        // Read posted data from the rank's communication buffer.
-        int32x4_t* recv_buffer = reinterpret_cast<int32x4_t*>(
-            rank_buffer + data_offset + r * data_stride +
-            idx * sizeof(int32x4_t));
-        tA = __builtin_nontemporal_load(recv_buffer);
-      }
-#pragma unroll
-      for (int r = 1; r < world_size; r++) {
-        // Read posted data from the rank's communication buffer.
-        int32x4_t* recv_buffer = reinterpret_cast<int32x4_t*>(
-            rank_buffer + data_offset + r * data_stride +
-            idx * sizeof(int32x4_t));
-        tB = __builtin_nontemporal_load(recv_buffer);
-
-        // Reduce the local data with the read data
-        packed_assign_add(&tA, &tB);
-      }
-
-      buffer_store_dwordx4(tA, dst_buffer.descriptor, idx * sizeof(int32x4_t),
-                           0, 0);
-    }
-
-    __syncthreads();
-    if (thread < world_size) {
-      int r = thread;
-      int* peer_flag_ptr = reinterpret_cast<int*>(
-          buffer_list[r] + comm_flags1_offset + rank * sizeof(int));
-      __atomic_store_n(peer_flag_ptr, flag_color, __ATOMIC_RELAXED);
-      int* self_flag_ptr = reinterpret_cast<int*>(
-          rank_buffer + comm_flags1_offset + r * sizeof(int));
-
-      // Wait for the flags to be set.
-      while (__atomic_load_n(self_flag_ptr, __ATOMIC_RELAXED) != flag_color) {
-      }
-    }
-  }
-};
-
 }  // namespace quickreduce
\ No newline at end of file
diff --git a/tests/distributed/test_quick_reduce.py b/tests/distributed/test_quick_reduce.py
deleted file mode 100644
index 32731de9a648..000000000000
--- a/tests/distributed/test_quick_reduce.py
+++ /dev/null
@@ -1,127 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-
-import random
-
-import pytest
-import ray
-import torch
-import torch.distributed as dist
-
-from vllm.distributed.communication_op import (  # noqa
-    tensor_model_parallel_all_reduce)
-from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
-                                             get_tp_group, graph_capture)
-from vllm.platforms import current_platform
-
-from ..utils import init_test_distributed_environment, multi_process_parallel
-
-random.seed(42)
-test_sizes = [random.randint(256 * 8 * 4, 2048 * 1024) for _ in range(8)]
-for i, v in enumerate(test_sizes):
-    test_sizes[i] -= v % 8
-
-
-# Same as in custom all-reduce
-# Only enable QuickReduce
-@ray.remote(num_gpus=1, max_calls=1)
-def graph_allreduce(
-    monkeypatch: pytest.MonkeyPatch,
-    tp_size,
-    pp_size,
-    rank,
-    distributed_init_port,
-):
-    with monkeypatch.context() as m:
-        m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
-        device = torch.device(f"cuda:{rank}")
-        torch.cuda.set_device(device)
-        init_test_distributed_environment(tp_size, pp_size, rank,
-                                          distributed_init_port)
-
-        group = get_tensor_model_parallel_group().device_group
-
-        # A small all_reduce for warmup.
-        # this is needed because device communicators might be created lazily
-        # (e.g. NCCL). This will ensure that the communicator is initialized
-        # before any communication happens, so that this group can be used for
-        # graph capture immediately.
-        data = torch.zeros(1)
-        data = data.to(device=device)
-        torch.distributed.all_reduce(data, group=group)
-        torch.cuda.synchronize()
-        del data
-
-        # we use the first group to communicate once
-        # and the second group to communicate twice
-        # and so on
-        # this is used to demonstrate that each group can
-        # communicate independently
-        num_communication = rank // tp_size + 1
-
-        for sz in test_sizes:
-            for dtype in [torch.float16, torch.bfloat16]:
-                with graph_capture(device=device) as graph_capture_context:
-                    # use integers so result matches NCCL exactly
-                    inp1 = torch.randint(1,
-                                         16, (sz, ),
-                                         dtype=dtype,
-                                         device=torch.cuda.current_device())
-                    inp2 = torch.randint(1,
-                                         16, (sz, ),
-                                         dtype=dtype,
-                                         device=torch.cuda.current_device())
-                    torch.cuda.synchronize()
-                    graph = torch.cuda.CUDAGraph()
-                    with torch.cuda.graph(graph,
-                                          stream=graph_capture_context.stream):
-                        for i in range(num_communication):
-                            out1 = tensor_model_parallel_all_reduce(inp1)
-                            # the input buffer is immediately modified to test
-                            # synchronization
-                            dist.all_reduce(inp1, group=group)
-                            out2 = tensor_model_parallel_all_reduce(inp2)
-                            dist.all_reduce(inp2, group=group)
-                    graph.replay()
-                    torch.testing.assert_close(out1, inp1)
-                    torch.testing.assert_close(out2, inp2)
-
-
-@ray.remote(num_gpus=1, max_calls=1)
-def eager_quick_allreduce(
-    monkeypatch: pytest.MonkeyPatch,
-    tp_size,
-    pp_size,
-    rank,
-    distributed_init_port,
-):
-    with monkeypatch.context() as m:
-        m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
-        device = torch.device(f"cuda:{rank}")
-        torch.cuda.set_device(device)
-        init_test_distributed_environment(tp_size, pp_size, rank,
-                                          distributed_init_port)
-        for dtype in [torch.float16, torch.bfloat16]:
-
-            num_communication = rank // tp_size + 1
-            sz = 256 * 8 * 8
-            qr_comm = get_tp_group().device_communicator.qr_comm
-            inp = torch.ones(sz, dtype=dtype, device=device)
-            out = inp
-            for _ in range(num_communication):
-                out = qr_comm.all_reduce(out)
-            torch.testing.assert_close(out, inp * (tp_size**num_communication))
-
-
-@pytest.mark.skipif(not current_platform.is_rocm(),
-                    reason="Quick reduce is only supported on RocM.")
-@pytest.mark.parametrize("tp_size", [2])
-@pytest.mark.parametrize("pipeline_parallel_size", [1, 2])
-@pytest.mark.parametrize("test_target",
-                         [eager_quick_allreduce, graph_allreduce])
-def test_quick_reduce_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size,
-                                pipeline_parallel_size, test_target):
-    world_size = tp_size * pipeline_parallel_size
-    if world_size > torch.cuda.device_count():
-        pytest.skip("Not enough GPUs to run the test.")
-    multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size,
-                           test_target)
diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py
index 4d4ba4a8b4b6..5b724b33d2ec 100644
--- a/vllm/distributed/device_communicators/cuda_communicator.py
+++ b/vllm/distributed/device_communicators/cuda_communicator.py
@@ -8,7 +8,6 @@
 import vllm.envs as envs
 from vllm.logger import init_logger
-from vllm.platforms import current_platform
 
 from .base_device_communicator import DeviceCommunicatorBase
 
 logger = init_logger(__name__)
@@ -42,8 +41,6 @@ def __init__(self,
             CustomAllreduce)
         from vllm.distributed.device_communicators.pynccl import (
             PyNcclCommunicator)
-        from vllm.distributed.device_communicators.quick_all_reduce import (
-            QuickAllReduce)
 
         self.pynccl_comm: Optional[PyNcclCommunicator] = None
         if use_pynccl and self.world_size > 1:
@@ -55,16 +52,10 @@ def __init__(self,
         self.ca_comm: Optional[CustomAllreduce] = None
         if use_custom_allreduce and self.world_size > 1:
             # Initialize a custom fast all-reduce implementation.
-            self.ca_comm = CustomAllreduce(group=self.cpu_group,
-                                           device=self.device)
-
-        self.qr_comm: Optional[QuickAllReduce] = None
-        if (use_custom_allreduce and current_platform.is_rocm()
-                and self.world_size > 1):
-            # Initialize a custom fast all-reduce implementation for AMD
-            # based on quick reduce (https://github.com/mk1-project/quickreduce).
-            self.qr_comm = QuickAllReduce(group=self.cpu_group,
-                                          device=self.device)
+            self.ca_comm = CustomAllreduce(
+                group=self.cpu_group,
+                device=self.device,
+            )
 
         if self.use_all2all:
             all2all_backend = envs.VLLM_ALL2ALL_BACKEND
@@ -88,15 +79,8 @@ def __init__(self,
                 raise ValueError(f"Unknown all2all backend: {all2all_backend}")
 
     def all_reduce(self, input_):
-        # always try quick reduce first, then custom allreduce,
-        # and then pynccl. (quick reduce just for ROCM MI3*)
-        qr_comm = self.qr_comm
-        if qr_comm is not None and not qr_comm.disabled and \
-            qr_comm.should_quick_allreduce(input_):
-            out = qr_comm.all_reduce(input_)
-            assert out is not None
-            return out
-
+        # always try custom allreduce first,
+        # and then pynccl.
         ca_comm = self.ca_comm
         if ca_comm is not None and not ca_comm.disabled and \
             ca_comm.should_custom_ar(input_):
@@ -189,4 +173,4 @@ def dispatch(
     def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
         assert self.all2all_manager is not None
         hidden_states = self.all2all_manager.combine(hidden_states)
-        return hidden_states
+        return hidden_states
\ No newline at end of file
diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py
index 7dd104a4fcc4..40adaf891171 100644
--- a/vllm/distributed/device_communicators/custom_all_reduce.py
+++ b/vllm/distributed/device_communicators/custom_all_reduce.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from contextlib import contextmanager
+from enum import Enum
 from typing import Optional, Union
 
 import torch
@@ -10,6 +11,7 @@
 
 import vllm.envs as envs
 from vllm import _custom_ops as ops
+from vllm.config import get_current_vllm_config
 from vllm.distributed.device_communicators.custom_all_reduce_utils import (
     gpu_p2p_access_check)
 from vllm.distributed.parallel_state import in_the_same_node_as
@@ -23,10 +25,24 @@
 except Exception:
     # For CPUs
     custom_ar = False
+try:
+    ops.qr_max_size()
+    quick_ar = True
+except Exception:
+    # For CPUs
+    quick_ar = False
 
 logger = init_logger(__name__)
 
 
+class QuickReduceRegime(Enum):
+    FP = 0
+    INT8 = 1
+    INT6 = 2
+    INT4 = 3
+    NONE = 4
+
+
 def _can_p2p(rank: int, world_size: int) -> bool:
     for i in range(world_size):
         if i == rank:
@@ -49,32 +65,58 @@ def is_weak_contiguous(inp: torch.Tensor):
 
 class CustomAllreduce:
 
     _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
+    _QR_SUPPORTED_WORLD_SIZES = [2, 4, 8]
 
     # max_size: max supported allreduce size
     def __init__(self,
                  group: ProcessGroup,
                  device: Union[int, str, torch.device],
-                 max_size=8192 * 1024) -> None:
+                 cr_max_size=8192 * 1024,
+                 qr_max_size=512 * 1024 * 1024,
+                 qr_min_size=2 * 1024 * 1024) -> None:
         """
+        Custom allreduce (cr) is a non-destructive acceleration path,
+        available on CUDA and on ROCm MI300-series GPUs.
+        Custom quick allreduce (qr) is accelerated by quantization and
+        currently supports FP16 as well as Q8, Q6 and Q4 quantization.
+        We view qr as complementary to cr: qr's requirements are a strict
+        superset of cr's, so whenever qr is initialized, cr is also
+        initialized; if cr's conditions are not met, qr is not
+        initialized either.
+        Due to instruction set limitations, only the ROCm MI300 series
+        is supported for the time being.
+
         Args:
             group: the process group to work on. If None, it will use the
                 default process group.
            device: the device to bind the CustomAllreduce to. If None,
                it will be bind to f"cuda:{local_rank}".
+            cr_max_size: max supported size of cr.
+            qr_max_size: max supported size of qr.
+            qr_min_size: min supported size of qr. Below this size,
+                cr performs better.
         It is the caller's responsibility to make sure each communicator
        is bind to a unique device, and all communicators in this group
        are in the same node.
         """
+        self._QR_SHOULD_INIT = True
         self._IS_CAPTURING = False
         self.disabled = True
+        self.cr_max_size = cr_max_size
+        self.qr_max_size = qr_max_size
+        self.qr_min_size = qr_min_size
 
         if not custom_ar:
             # disable because of missing custom allreduce library
             # e.g. in a non-GPU environment
             logger.info("Custom allreduce is disabled because "
                         "of missing custom allreduce library")
-            return
+        if not quick_ar:
+            logger.info("Custom quick allreduce is disabled because "
+                        "of missing quick allreduce library")
+            self._QR_SHOULD_INIT = False
+        if not quick_ar and not custom_ar:
+            return
 
         self.group = group
         assert dist.get_backend(group) != dist.Backend.NCCL, (
@@ -88,10 +130,12 @@ def __init__(self,
             return
 
         rank = dist.get_rank(group=self.group)
-        self.rank = rank
         world_size = dist.get_world_size(group=self.group)
+        self.rank = rank
+        self.world_size = world_size
         if world_size == 1:
-            # No need to initialize custom allreduce for single GPU case.
+            # No need to initialize custom allreduce or custom quick
+            # allreduce for single GPU case.
            return
 
         if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES:
@@ -102,6 +146,13 @@ def __init__(self,
                 world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES))
             return
 
+        if self._QR_SHOULD_INIT and \
+                world_size not in CustomAllreduce._QR_SUPPORTED_WORLD_SIZES:
+            self._QR_SHOULD_INIT = False
+            logger.warning(
+                "Custom quick allreduce is disabled due to an unsupported "
+                "world size: %d.", world_size)
+
         if isinstance(device, int):
             device = torch.device(f"cuda:{device}")
         elif isinstance(device, str):
@@ -131,9 +182,9 @@ def __init__(self,
         # where custom allreduce is not supported
         # this checks hardware and driver support for NVLink
         assert current_platform.is_cuda_alike()
-        fully_connected = current_platform.is_fully_connected(
+        self.fully_connected = current_platform.is_fully_connected(
             physical_device_ids)
-        if world_size > 2 and not fully_connected:
+        if world_size > 2 and not self.fully_connected:
             logger.warning(
                 "Custom allreduce is disabled because it's not supported on"
                 " more than two PCIe-only GPUs. To silence this warning, "
@@ -143,23 +194,36 @@ def __init__(self,
         # this is expensive to compute at the first time
         # then we cache the result
         # On AMD GPU, p2p is always enabled between XGMI connected GPUs
-        if not current_platform.is_rocm() and not _can_p2p(rank, world_size):
-            logger.warning(
-                "Custom allreduce is disabled because your platform lacks "
-                "GPU P2P capability or P2P test failed. To silence this "
-                "warning, specify disable_custom_all_reduce=True explicitly.")
-            return
-
+        if not current_platform.is_rocm():
+            # qr is only supported on the ROCm MI300 series; on any other
+            # platform disable qr here (cr may still be usable).
+            self._QR_SHOULD_INIT = False
+            if not _can_p2p(rank, world_size):
+                logger.warning(
+                    "Custom allreduce is disabled because your platform lacks "
+                    "GPU P2P capability or P2P test failed. To silence this "
+                    "warning, specify disable_custom_all_reduce=True "
+                    "explicitly.")
+                return
         self.disabled = False
+        self.init_custom_allreduce()
+        self.init_custom_quick_allreduce()
+
+    def init_custom_allreduce(self):
+        """
+        Initialize custom allreduce.
+        """
         # Buffers memory are owned by this Python class and passed to C++.
         # Meta data composes of two parts: meta data for synchronization and a
         # temporary buffer for storing intermediate allreduce results.
-        self.meta_ptrs = self.create_shared_buffer(ops.meta_size() + max_size,
-                                                   group=group,
+        self.meta_ptrs = self.create_shared_buffer(ops.meta_size() +
+                                                   self.cr_max_size,
+                                                   group=self.group,
                                                    uncached=True)
         # This is a pre-registered IPC buffer. In eager mode, input tensors
         # are first copied into this buffer before allreduce is performed
-        self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
+        self.buffer_ptrs = self.create_shared_buffer(self.cr_max_size,
+                                                     group=self.group)
         # This is a buffer for storing the tuples of pointers pointing to
         # IPC buffers from all ranks. Each registered tuple has size of
         # 8*world_size bytes where world_size is at most 8. Allocating 8MB
@@ -168,13 +232,62 @@ def __init__(self,
         self.rank_data = torch.empty(8 * 1024 * 1024,
                                      dtype=torch.uint8,
                                      device=self.device)
-        self.max_size = max_size
-        self.rank = rank
-        self.world_size = world_size
-        self.fully_connected = fully_connected
-        self._ptr = ops.init_custom_ar(self.meta_ptrs, self.rank_data, rank,
-                                       self.fully_connected)
-        ops.register_buffer(self._ptr, self.buffer_ptrs)
+        self.cr_max_size = self.cr_max_size
+
+        self._cr_ptr = ops.init_custom_ar(self.meta_ptrs, self.rank_data,
+                                          self.rank, self.fully_connected)
+        ops.register_buffer(self._cr_ptr, self.buffer_ptrs)
+
+    def init_custom_quick_allreduce(self):
+        """
+        Initialize a custom quick allreduce implementation for AMD
+        based on quick reduce (https://github.com/mk1-project/quickreduce).
+        """
+        vllm_config = get_current_vllm_config()
+        dtype = vllm_config.model_config.dtype
+        if dtype not in [torch.float16, torch.bfloat16]:
+            self._QR_SHOULD_INIT = False
+
+        # On ROCm, bfloat16 kernels are slower than fp16
+        # due to slower math operations.
+        # If the env var is not set to 1, we convert the input to fp16.
+        self.use_fp16_kernels: bool = envs.VLLM_ROCM_QR_CAST_BF16_TO_FP16
+        regime_str = envs.VLLM_ROCM_QR_QUANT_REGIME
+
+        if self._QR_SHOULD_INIT:
+            if regime_str not in QuickReduceRegime.__members__:
+                logger.warning(
+                    "Custom quick allreduce: "
+                    f"Invalid quantization level: {regime_str}. "
+                    "Supported levels: "
+                    f"{list(QuickReduceRegime.__members__.keys())}")
+                return
+
+            if regime_str == "NONE":
+                logger.debug("Custom quick allreduce is disabled based "
+                             "on env variable VLLM_ROCM_QR_QUANT_REGIME")
+                return
+
+            self.qr_quant_level = QuickReduceRegime[regime_str]
+            # These numbers are based on kernel tests.
+            # TODO: We need the full kernel test to guide the
+            # size adjustment here
+            if self.world_size == 2:
+                self.qr_min_size = 1 * 1024 * 1024
+            else:
+                self.qr_min_size = 2 * 1024 * 1024
+            self._qr_ptr = ops.init_custom_qr(self.rank, self.world_size)
+            self.create_qr_shared_buffer()
+            if dtype == torch.bfloat16 and self.use_fp16_kernels:
+                logger.info(
+                    "Custom quick allreduce: due to the lack of bf16 assembly "
+                    "instruction set, the performance gain of bf16 is "
+                    "limited. We convert bfloat16 to float16 to speed "
+                    "up quick allreduce. You can set "
+                    "VLLM_ROCM_QR_CAST_BF16_TO_FP16=0 to turn "
+                    "this conversion off.")
+        # There is no case where qr is initialized and
+        # cr is not initialized
 
     @contextmanager
     def capture(self):
@@ -192,7 +305,7 @@ def capture(self):
             self.register_graph_buffers()
 
     def register_graph_buffers(self):
-        handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
+        handle, offset = ops.get_graph_buffer_ipc_meta(self._cr_ptr)
         logger.info("Registering %d cuda graph addresses", len(offset))
         # We cannot directly use `dist.all_gather_object` here
         # because it is incompatible with `gloo` backend under inference mode.
@@ -209,9 +322,37 @@ def register_graph_buffers(self):
         # Unpack list of tuples to tuple of lists.
         handles = [d[0] for d in all_data]  # type: ignore
         offsets = [d[1] for d in all_data]  # type: ignore
-        ops.register_graph_buffers(self._ptr, handles, offsets)
+        ops.register_graph_buffers(self._cr_ptr, handles, offsets)
 
-    def should_custom_ar(self, inp: torch.Tensor):
+    def should_quick_allreduce(self, inp: torch.Tensor):
+        """
+        Check whether custom quick allreduce can handle this input.
+        """
+        if self.disabled and not self._QR_SHOULD_INIT:
+            return False
+        inp_size = inp.numel() * inp.element_size()
+        # custom quick allreduce requires input byte size to be
+        # multiples of 16
+        if inp_size % 16 != 0:
+            return False
+        if not is_weak_contiguous(inp):
+            return False
+        # the supported size range depends on the input dtype
+        if inp.dtype == torch.float16:
+            return inp_size <= self.qr_max_size and inp_size >= self.qr_min_size
+        elif inp.dtype == torch.bfloat16:
+            if self.use_fp16_kernels:
+                # cast to fp16, so the same condition applies
+                return inp_size <= self.qr_max_size and \
+                    inp_size >= self.qr_min_size
+            else:
+                # TODO: check bf16 condition for mi300
+                return (inp_size <= self.qr_max_size
+                        and inp_size > 1024 * 1024 * 16
+                        and self.world_size == 2)
+        return False
+
+    def should_custom_allreduce(self, inp: torch.Tensor):
         if self.disabled:
             return False
         inp_size = inp.numel() * inp.element_size()
@@ -223,15 +364,20 @@ def should_custom_ar(self, inp: torch.Tensor):
         # for 4 or more non NVLink-capable GPUs, custom allreduce provides
         # little performance improvement over NCCL.
         if self.world_size == 2 or self.fully_connected:
-            return inp_size < self.max_size
+            return inp_size < self.cr_max_size
         return False
 
-    def all_reduce(self,
-                   inp: torch.Tensor,
-                   *,
-                   out: torch.Tensor = None,
-                   registered: bool = False):
-        """Performs an out-of-place all reduce.
+    def should_custom_ar(self, inp: torch.Tensor):
+        # Determine whether to use qr, cr, or neither.
+        return self.should_quick_allreduce(
+            inp) or self.should_custom_allreduce(inp)
+
+    def cr_all_reduce(self,
+                      inp: torch.Tensor,
+                      *,
+                      out: torch.Tensor = None,
+                      registered: bool = False):
+        """Performs an out-of-place custom all reduce.
 
         If registered is True, this assumes inp's pointer is already
         IPC-registered. Otherwise, inp is first copied into a pre-registered
@@ -240,37 +386,69 @@ def all_reduce(self,
         if out is None:
             out = torch.empty_like(inp)
         if registered:
-            ops.all_reduce(self._ptr, inp, out, 0, 0)
+            ops.all_reduce(self._cr_ptr, inp, out, 0, 0)
         else:
-            ops.all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank],
-                           self.max_size)
+            ops.all_reduce(self._cr_ptr, inp, out, self.buffer_ptrs[self.rank],
+                           self.cr_max_size)
         return out
 
+    def qr_all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None):
+        """Performs an out-of-place custom quick all reduce."""
+        inp_dtype = inp.dtype
+        if inp_dtype == torch.bfloat16 and self.use_fp16_kernels:
+            inp = inp.to(torch.float16)
+        if out is None:
+            out = torch.empty_like(inp)
+        ops.qr_all_reduce(self._qr_ptr, inp, out, self.qr_quant_level.value)
+        return out.to(inp_dtype)
+
     def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
         """The main allreduce API that provides support for cuda graph."""
         # When custom allreduce is disabled, this will be None.
-        if self.disabled or not self.should_custom_ar(input):
+        if self.disabled:
             return None
-        if self._IS_CAPTURING:
-            if torch.cuda.is_current_stream_capturing():
-                return self.all_reduce(input, registered=True)
-            else:
+        # try custom quick allreduce first, then custom allreduce
+        if self.should_quick_allreduce(input):
+            # We don't need the context of quick allreduce to do graph capture
+            # because the ipc access is already collected in init() and
+            # we can capture the quick allreduce directly.
+            if self._IS_CAPTURING and \
+                    not torch.cuda.is_current_stream_capturing():
                 # If warm up, mimic the allocation pattern since custom
                 # allreduce is out-of-place.
                 return torch.empty_like(input)
-        else:
-            # Note: outside of cuda graph context, custom allreduce incurs a
-            # cost of cudaMemcpy, which should be small (<=1% of overall
-            # latency) compared to the performance gain of using custom kernels
-            return self.all_reduce(input, registered=False)
+            else:
+                return self.qr_all_reduce(input)
+
+        if self.should_custom_allreduce(input):
+            if self._IS_CAPTURING:
+                if torch.cuda.is_current_stream_capturing():
+                    return self.cr_all_reduce(input, registered=True)
+                else:
+                    # If warm up, mimic the allocation pattern since custom
+                    # allreduce is out-of-place.
+                    return torch.empty_like(input)
+            else:
+                # Note: outside of cuda graph context, custom allreduce
+                # incurs a cost of cudaMemcpy, which should be small
+                # (<=1% of overall latency) compared to the performance
+                # gain of using custom kernels
+                return self.cr_all_reduce(input, registered=False)
+
+        return None
 
     def close(self):
-        if not self.disabled and self._ptr:
-            if ops is not None:
-                ops.dispose(self._ptr)
-            self._ptr = 0
-            self.free_shared_buffer(self.meta_ptrs, rank=self.rank)
-            self.free_shared_buffer(self.buffer_ptrs, rank=self.rank)
+        if not self.disabled:
+            if self._cr_ptr:
+                if ops is not None:
+                    ops.dispose(self._cr_ptr)
+                self._cr_ptr = 0
+                self.free_shared_buffer(self.meta_ptrs, rank=self.rank)
+                self.free_shared_buffer(self.buffer_ptrs, rank=self.rank)
+            if self._qr_ptr:
+                if ops is not None:
+                    ops.qr_destroy(self._qr_ptr)
+                self._qr_ptr = 0
 
     def __del__(self):
         self.close()
@@ -294,6 +472,17 @@ def create_shared_buffer(size_in_bytes: int,
                 pointers.append(ops.open_mem_handle(h))
         return pointers
 
+    def create_qr_shared_buffer(self):
+        """
+        Creates a shared buffer for quickreduce.
+        Has to be called after qr_init_device_collectives
+        """
+        handle = ops.qr_get_handle(self._qr_ptr)
+        world_size = dist.get_world_size(group=self.group)
+        handles = [None] * world_size
+        dist.all_gather_object(handles, handle, group=self.group)
+        ops.qr_open_handles(self._qr_ptr, handles)
+
     @staticmethod
     def free_shared_buffer(pointers: list[int],
                            group: Optional[ProcessGroup] = None,
diff --git a/vllm/distributed/device_communicators/quick_all_reduce.py b/vllm/distributed/device_communicators/quick_all_reduce.py
deleted file mode 100644
index 322633c220a4..000000000000
--- a/vllm/distributed/device_communicators/quick_all_reduce.py
+++ /dev/null
@@ -1,138 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-import logging
-from enum import Enum
-from typing import Union
-
-import torch
-import torch.distributed as dist
-from torch.distributed import ProcessGroup
-
-import vllm.envs as envs
-from vllm import _custom_ops as ops
-from vllm.platforms import current_platform
-
-logger = logging.getLogger(__name__)
-
-try:
-    ops.qr_max_size()
-    ops_available = True
-except Exception:
-    # For CPUs
-    ops_available = False
-
-
-class QuickReduceRegime(Enum):
-    FP = 0
-    INT8 = 1
-    INT6 = 2
-    INT4 = 3
-    NONE = 4
-
-
-class QuickAllReduce:
-    _SUPPORTED_WORLD_SIZES = [2, 4, 8]
-    _SUPPORTED_DTYPES = [torch.float16, torch.bfloat16]
-
-    def __init__(self, group: ProcessGroup,
-                 device: Union[int, str, torch.device]) -> None:
-        self.disabled = True
-        if not ops_available:
-            # disable because of missing quick reduce library
-            # e.g. in a non-cuda environment
-            logger.info("Custom quick allreduce is disabled because "
-                        "of missing custom quick allreduce library")
-            return
-
-        self.max_size = ops.qr_max_size()
-        self.group = group
-        regime_str = envs.VLLM_ROCM_CA_QUANT_REGIME
-        assert regime_str in QuickReduceRegime.__members__, (
-            f"Invalid quantization level: {regime_str}. "
-            "Supported levels: "
-            f"{list(QuickReduceRegime.__members__.keys())}")
-        if regime_str == "NONE":
-            logger.debug("Custom quick allreduce is disabled based "
-                         "on env variable VLLM_ROCM_CA_QUANT_REGIME")
-            return
-        self.quant_level = QuickReduceRegime[regime_str]
-        # On RocM bfloat16 kernels are slower than fp16
-        # due to slower match operations
-        # If environment is not set to 1 we convert input to fp16
-        self.use_fp16_kernels = envs.VLLM_ROCM_CA_CAST_BF16_TO_FP16
-        assert dist.get_backend(group) != dist.Backend.NCCL, (
-            "QuickReduce should be attached to a non-NCCL group.")
-        rank = dist.get_rank(group=self.group)
-        world_size = dist.get_world_size(group=self.group)
-        if world_size == 1:
-            # No need to initialize QuickReduce for single GPU case.
-            return
-
-        if world_size not in QuickAllReduce._SUPPORTED_WORLD_SIZES:
-            logger.warning(
-                "QuickReduce is disabled due to an unsupported world"
-                " size: %d. Supported world sizes: %s."
-                " To disable this warning set VLLM_ROCM_CA_BACKEND"
-                " to None", world_size,
-                str(QuickAllReduce._SUPPORTED_WORLD_SIZES))
-            return
-
-        assert current_platform.is_rocm(), (
-            "QuickReduce is only supported on ROCm platform.")
-        if isinstance(device, int):
-            device = torch.device(f"cuda:{device}")
-        elif isinstance(device, str):
-            device = torch.device(device)
-        # now `device` is a `torch.device` object
-        assert isinstance(device, torch.device)
-        self.device = device
-        torch.cuda.set_device(self.device)
-
-        self._ptr = ops.init_custom_qr(rank, world_size)
-        self.create_shared_buffer()
-        self.disabled = False
-
-    def create_shared_buffer(self):
-        """
-        Creates a shared buffer for quickreduce.
-        Has to be called after qr_init_device_collectives
-        """
-        handle = ops.qr_get_handle(self._ptr)
-        world_size = dist.get_world_size(group=self.group)
-        handles = [None] * world_size
-        dist.all_gather_object(handles, handle, group=self.group)
-        ops.qr_open_handles(self._ptr, handles)
-
-    def all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None):
-        """
-        Performs an out-of-place all reduce.
-        """
-        inp_size = inp.numel() * inp.element_size()
-        if inp_size > self.max_size:
-            return None
-
-        inp_dtype = inp.dtype
-        if inp_dtype == torch.bfloat16 and self.use_fp16_kernels:
-            inp = inp.to(torch.float16)
-        if out is None:
-            out = torch.empty_like(inp)
-
-        ops.qr_all_reduce(self._ptr, inp, out, self.quant_level.value)
-        return out.to(inp_dtype)
-
-    def close(self):
-        if not self.disabled and getattr(self, "_ptr", None):
-            ops.qr_destroy(self._ptr)
-            self._ptr = 0
-
-    def __del__(self):
-        self.close()
-
-    def should_quick_allreduce(self, inp: torch.Tensor):
-        if self.disabled:
-            return False
-        inp_size = inp.numel() * inp.element_size()
-        # QuickReduce requires input byte size to be multiples of 16
-        if inp_size % 16 != 0:
-            return False
-        return inp.dtype in QuickAllReduce._SUPPORTED_DTYPES and \
-            inp_size < self.max_size
diff --git a/vllm/envs.py b/vllm/envs.py
index 56443666e584..3d8d1e4d06ff 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -128,8 +128,8 @@
     VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
     VLLM_SLEEP_WHEN_IDLE: bool = False
     VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16
-    VLLM_ROCM_CA_QUANT_REGIME: str = "FP"
-    VLLM_ROCM_CA_CAST_BF16_TO_FP16: bool = True
+    VLLM_ROCM_QR_QUANT_REGIME: str = "FP"
+    VLLM_ROCM_QR_CAST_BF16_TO_FP16: bool = True
 
 
 def get_default_cache_root():
@@ -675,15 +675,15 @@ def get_vllm_port() -> Optional[int]:
     # Custom quick allreduce kernel for MI3* cards
    # Choice of quantization level: FP, INT8, INT6, INT4 or NONE
    # Recommended for large models to get allreduce
-    "VLLM_ROCM_CA_QUANT_REGIME":
-    lambda: os.getenv("VLLM_ROCM_CA_QUANT_REGIME", "FP").upper(),
+    "VLLM_ROCM_QR_QUANT_REGIME":
+    lambda: os.getenv("VLLM_ROCM_QR_QUANT_REGIME", "FP").upper(),
 
     # Custom quick allreduce kernel for MI3* cards
     # Due to the lack of the bfloat16 asm instruction, bfloat16
     # kernels are slower than fp16,
     # If environment is not set to 1, we convert input to fp16
-    "VLLM_ROCM_CA_CAST_BF16_TO_FP16":
-    lambda: (os.getenv("VLLM_ROCM_CA_CAST_BF16_TO_FP16", "True").lower() in
+    "VLLM_ROCM_QR_CAST_BF16_TO_FP16":
+    lambda: (os.getenv("VLLM_ROCM_QR_CAST_BF16_TO_FP16", "True").lower() in
              ("true", "1")),
 
     # If set, when running in Quark emulation mode, do not dequantize the
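Reviewer note (not part of the patch): a minimal, standalone sketch of the size/dtype gating that the merged CustomAllreduce applies before the caller falls back to pynccl. The constants mirror the defaults introduced above (qr_min_size = 2 MiB, qr_max_size = 512 MiB, cr_max_size = 8192 * 1024). The helper name pick_path, the world_size default, and the omission of the fully_connected and weak-contiguity checks are illustrative assumptions, not vLLM API.

    # Simplified sketch of the qr -> cr -> pynccl dispatch order in this PR.
    # Assumes world_size == 2 and ignores checks the real
    # should_quick_allreduce()/should_custom_allreduce() also perform.
    import torch

    QR_MIN = 2 * 1024 * 1024        # qr_min_size default
    QR_MAX = 512 * 1024 * 1024      # qr_max_size default
    CR_MAX = 8192 * 1024            # cr_max_size default

    def pick_path(inp: torch.Tensor, world_size: int = 2) -> str:
        size = inp.numel() * inp.element_size()
        if (size % 16 == 0 and inp.dtype == torch.float16
                and QR_MIN <= size <= QR_MAX):
            return "qr"       # quantized quick allreduce
        if world_size == 2 and size < CR_MAX:
            return "cr"       # classic custom allreduce
        return "pynccl"       # neither applies; caller falls back

    print(pick_path(torch.empty(4 * 1024 * 1024, dtype=torch.float16)))  # qr
    print(pick_path(torch.empty(1024, dtype=torch.float16)))             # cr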