@@ -43,7 +43,6 @@ async def main(
engine: DynamicInferenceEngine,
requests: List[Request],
port: int,
mp_port: int,
sampling_params: SamplingParams | None = None,
):
if sampling_params is not None:
@@ -58,7 +57,6 @@ async def main(

await engine.start_listening_to_data_parallel_coordinator(
inference_coordinator_port=port,
inference_mp_coordinator_port=mp_port,
launch_inference_coordinator=True,
verbose=True,
)
@@ -258,6 +256,5 @@ async def main(
engine,
requests,
args.inference_coordinator_port,
args.inference_mp_coordinator_port
)
)
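For orientation, a minimal sketch of how the example's entry point looks after this change: only the single coordinator port is passed, since the model-parallel endpoints are now negotiated inside the engine. The `serve` wrapper and its argument names are illustrative, not part of the diff.

```python
# Illustrative sketch only: after this PR the engine needs just the central
# coordinator port; model-parallel sockets bind to random ports internally.
import asyncio

async def serve(engine, requests, coordinator_port):
    await engine.start_listening_to_data_parallel_coordinator(
        inference_coordinator_port=coordinator_port,
        launch_inference_coordinator=True,
        verbose=True,
    )
    # ... submit `requests` and await results as in the example above ...

# asyncio.run(serve(engine, requests, args.inference_coordinator_port))
```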
120 changes: 59 additions & 61 deletions megatron/core/inference/engines/dynamic_engine.py
@@ -42,7 +42,7 @@
)
from megatron.core.inference.utils import Counter, await_process_event
from megatron.core.transformer.cuda_graphs import delete_cuda_graphs
from megatron.core.utils import get_asyncio_loop, trace_async_exceptions
from megatron.core.utils import get_asyncio_loop, internal_api, trace_async_exceptions

try:
from tqdm import tqdm
@@ -237,10 +236,6 @@ def reset(self) -> None:

# Coordinator state.
self.use_coordinator = False
self.is_tp0_and_pp0 = (
parallel_state.get_tensor_model_parallel_rank() == 0
and parallel_state.get_pipeline_model_parallel_rank() == 0
)

def create_cuda_graphs(self, reset_context: bool = True):
"""Create cuda graphs.
@@ -263,7 +259,7 @@ def create_cuda_graphs(self, reset_context: bool = True):

if moe_pad_experts and context.non_decode_cuda_graphs:
context.non_decode_cuda_graphs = False
if torch.distributed.get_rank() == 0:
if self.rank == 0:
warnings.warn(
"MoE models do not support non-decode cuda graphs. "
"Forcing non_decode_cuda_graphs to False."
@@ -348,10 +344,10 @@ def create_cuda_graphs(self, reset_context: bool = True):

self.capture_stats = capture_stats

@internal_api
async def start_listening_to_data_parallel_coordinator(
self,
inference_coordinator_port: int,
inference_mp_coordinator_port: int = 20000,
launch_inference_coordinator: bool = True,
verbose: bool = False,
*,
@@ -364,6 +360,8 @@ async def start_listening_to_data_parallel_coordinator(
`InferenceCoordinator`. It configures different ZMQ socket patterns
based on the rank's role within the distributed topology.

Note that this method must be called on all ranks, as it uses blocking torch broadcasts.

The setup involves two primary roles within each data-parallel group:
1. **MP Coordinator (TP_rank=0, PP_rank=0)**: This rank connects directly
to the central coordinator via a ZMQ `DEALER` socket. It receives
@@ -382,9 +380,6 @@ async def start_listening_to_data_parallel_coordinator(
Args:
inference_coordinator_port (int): The network port where the central
`InferenceCoordinator` is or will be listening.
inference_mp_coordinator_port (int): The base network port where each model parallel
coordinator will broadcast messages from. Each MP group will compute an independent
port offset from this base port.
launch_inference_coordinator (bool, optional): If True, the global rank 0
process will spawn and manage the `InferenceCoordinator`
process. Defaults to True.
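As a side note, the topology this docstring describes reduces to a small standalone pyzmq sketch; the hostnames, ports, and identities below are placeholders, not values used by the engine.

```python
# Standalone pyzmq sketch of the pattern described above; hostnames, ports,
# and identities are placeholders.
import zmq

ctx = zmq.Context.instance()

# MP-coordinator rank (TP=0, PP=0): a DEALER socket registers with the
# central coordinator by sending an empty frame.
dealer = ctx.socket(zmq.DEALER)
dealer.setsockopt(zmq.IDENTITY, b"mp-coord-0")
dealer.connect("tcp://127.0.0.1:12346")  # inference_coordinator_port
dealer.send(b"")

# The same rank re-broadcasts incoming work to its model-parallel group.
pub = ctx.socket(zmq.PUB)
pub.bind_to_random_port("tcp://127.0.0.1")
pub_addr = pub.getsockopt_string(zmq.LAST_ENDPOINT)

# Every other rank in the group subscribes to that endpoint.
sub = ctx.socket(zmq.SUB)
sub.connect(pub_addr)
sub.setsockopt_string(zmq.SUBSCRIBE, "")
```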
@@ -399,7 +394,25 @@ async def start_listening_to_data_parallel_coordinator(
"pip install msgpack"
)

if launch_inference_coordinator and torch.distributed.get_rank() == 0:
self.zmq_context = zmq.Context().instance()
self.zmq_sockets = [] # keep track of all sockets created by this engine

# Get world info.
dp_group = parallel_state.get_data_parallel_group()
dp_src = parallel_state.get_data_parallel_src_rank()
dp_size = parallel_state.get_data_parallel_world_size()
dp_rank = parallel_state.get_data_parallel_rank()

mp_group = parallel_state.get_model_parallel_group()
mp_src = parallel_state.get_model_parallel_src_rank()
tp_rank = parallel_state.get_tensor_model_parallel_rank()
pp_rank = parallel_state.get_pipeline_model_parallel_rank()

self.is_mp_coordinator = tp_rank == 0 and pp_rank == 0
self.is_dp_coordinator = (dp_rank == 0) and self.is_mp_coordinator

# Spawn a DP coordinator process and get the connection info.
if launch_inference_coordinator and self.is_dp_coordinator:
spawn_context = multiprocessing.get_context('spawn')
coordinator_ready_event = spawn_context.Event()
self.inference_coordinator_process = spawn_context.Process(
@@ -412,80 +425,67 @@
)
self.inference_coordinator_process.start()

# Todo [Siddharth]: can we move this code to another file?
self.zmq_context = zmq.Context()
self.zmq_sockets = [] # keep track of all sockets created by this engine

# We need to broadcast the hostname of the (TP=0, PP=0) rank
# to all other ranks in the same model parallel group.
tp_rank = parallel_state.get_tensor_model_parallel_rank()
pp_rank = parallel_state.get_pipeline_model_parallel_rank()

hostname_list = [None]
if tp_rank == 0 and pp_rank == 0:
hostname_list[0] = socket.gethostname()
# Find available ports for MP and bind to them.
if self.is_mp_coordinator:
local_ip = socket.gethostname()
mp_req_sock = self.zmq_context.socket(zmq.PUB)
mp_req_sock.bind_to_random_port(f"tcp://{local_ip}")
mp_req_addr = mp_req_sock.getsockopt_string(zmq.LAST_ENDPOINT)

# Find the global rank of the (TP=0, PP=0) rank in our MP group
src_global_rank = parallel_state.get_model_parallel_src_rank()

torch.distributed.broadcast_object_list(
hostname_list, src=src_global_rank, group=parallel_state.get_model_parallel_group()
)
bcast_hostname = hostname_list[0]
mp_len_sock = self.zmq_context.socket(zmq.PUB)
mp_len_sock.bind_to_random_port(f"tcp://{local_ip}")
mp_len_addr = mp_len_sock.getsockopt_string(zmq.LAST_ENDPOINT)
else:
mp_req_addr = None
mp_len_addr = None

# We need unique ports for each MP group, so we compute an offset using the DP rank.
dp_rank = parallel_state.get_data_parallel_rank()
req_port = inference_mp_coordinator_port + (dp_rank * 2)
len_port = inference_mp_coordinator_port + (dp_rank * 2) + 1
# Broadcast addresses to respective ranks.
bcast = [mp_req_addr, mp_len_addr]
torch.distributed.broadcast_object_list(bcast, src=mp_src, group=mp_group)
[mp_req_addr, mp_len_addr] = bcast

ip_address_of_dp_coordinator = os.getenv('MASTER_ADDR', '127.0.0.1')
identity = f'mp-coord-{parallel_state.get_data_parallel_rank()}'
if (
parallel_state.get_tensor_model_parallel_rank() == 0
and parallel_state.get_pipeline_model_parallel_rank() == 0
):
dp_addr = f"tcp://{ip_address_of_dp_coordinator}:{inference_coordinator_port}"
identity = f'mp-coord-{dp_rank}'
if self.is_mp_coordinator:
# 1. Create dealer sockets where tp_rank = 0 and pp_rank = 0
# These will receive requests from an InferenceCoordinator.
self.socket_for_receiving_requests = self.zmq_context.socket(zmq.DEALER)

self.socket_for_receiving_requests.setsockopt(zmq.IDENTITY, identity.encode('utf-8'))
self.socket_for_receiving_requests.connect(
f"tcp://{ip_address_of_dp_coordinator}:{inference_coordinator_port}"
)
self.socket_for_receiving_requests.connect(dp_addr)

# send empty string. this is used to register with the coordinator.
self.socket_for_receiving_requests.send(b"")

# 2. Create a publisher socket. This is used to publish or broadcast
# requests within the model parallel group
self.model_parallel_publisher_socket = self.zmq_context.socket(zmq.PUB)
self.model_parallel_publisher_socket.bind(f"tcp://*:{req_port}")
self.model_parallel_publisher_socket = mp_req_sock

# 3. Create another publisher socket to broadcast the number of messages to receive.
self.model_parallel_num_msgs_publisher_socket = self.zmq_context.socket(zmq.PUB)
self.model_parallel_num_msgs_publisher_socket.bind(f"tcp://*:{len_port}")
self.model_parallel_num_msgs_publisher_socket = mp_len_sock
self.zmq_sockets += [
self.socket_for_receiving_requests,
self.model_parallel_num_msgs_publisher_socket,
self.model_parallel_publisher_socket,
]
# All MP ranks subscribe to the two publisher sockets
self.model_parallel_subscriber_socket = self.zmq_context.socket(zmq.SUB)
self.model_parallel_subscriber_socket.connect(f"tcp://{bcast_hostname}:{req_port}")
self.model_parallel_subscriber_socket.connect(mp_req_addr)
self.model_parallel_subscriber_socket.setsockopt_string(zmq.SUBSCRIBE, "")

self.model_parallel_num_msgs_subscriber_socket = self.zmq_context.socket(zmq.SUB)
self.model_parallel_num_msgs_subscriber_socket.connect(f"tcp://{bcast_hostname}:{len_port}")
self.model_parallel_num_msgs_subscriber_socket.connect(mp_len_addr)
self.model_parallel_num_msgs_subscriber_socket.setsockopt_string(zmq.SUBSCRIBE, "")

self.zmq_sockets += [
self.model_parallel_subscriber_socket,
self.model_parallel_num_msgs_subscriber_socket,
]

torch.distributed.barrier(parallel_state.get_model_parallel_group())
torch.distributed.barrier(mp_group)

if launch_inference_coordinator and torch.distributed.get_rank() == 0:
if launch_inference_coordinator and self.is_dp_coordinator:
await await_process_event(coordinator_ready_event, self.inference_coordinator_process)
logging.info("Inference co-ordinator is ready to receive requests!")

@@ -697,7 +697,7 @@ def _add_request(
try:
eod = self.controller.tokenizer.eod
except AttributeError:
if torch.distributed.get_rank() == 0:
if self.rank == 0:
warnings.warn(
"Termination ID not specified, and tokenizer does not define eod."
"Defaulting to not using termination id."
@@ -1093,7 +1093,7 @@ async def async_bookkeep(
self.failed_request_ids.clear()

# Handle necessary ZMQ DP coordinator communication.
if self.use_coordinator and self.is_tp0_and_pp0 and finished_request_records:
if self.use_coordinator and self.is_mp_coordinator and finished_request_records:
payload = msgpack.packb(
[Headers.ENGINE_REPLY.value, [r.serialize() for r in finished_request_records]],
use_bin_type=True,
@@ -1277,11 +1277,9 @@ def schedule_requests(self) -> int:
int: The number of messages that were received and processed in this batch.
"""

tp_rank = parallel_state.get_tensor_model_parallel_rank()
pp_rank = parallel_state.get_pipeline_model_parallel_rank()
torch.cuda.nvtx.range_push("drain_zmq_socket")
all_messages = []
if tp_rank == 0 and pp_rank == 0:
if self.is_mp_coordinator:
while True:
try:
# Receive messages in a non-blocking way.
Expand All @@ -1297,8 +1295,8 @@ def schedule_requests(self) -> int:
struct.pack('!i', messages_to_dequeue)
)
# Now publish the actual messages to all model parallel ranks
for message in all_messages:
self.model_parallel_publisher_socket.send(message)
if messages_to_dequeue > 0:
self.model_parallel_publisher_socket.send_multipart(all_messages)
else:
# First, receive the number of messages to dequeue from mp-rank 0
messages_to_dequeue = struct.unpack(
Expand All @@ -1307,8 +1305,10 @@ def schedule_requests(self) -> int:
# Now, dequeue the same number of messages from the subscriber socket.
# Note that these receives are blocking, because the messages
# are guaranteed to be available after the tp-rank 0 has sent them.
for _ in range(messages_to_dequeue):
all_messages.append(self.model_parallel_subscriber_socket.recv())
if messages_to_dequeue > 0:
all_messages = self.model_parallel_subscriber_socket.recv_multipart()
else:
all_messages = []

torch.cuda.nvtx.range_pop()
for message in all_messages:
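On the request fan-out path, the publisher now announces the message count and then ships all drained messages in a single multipart send, and subscribers mirror that. A condensed sketch with hypothetical socket names standing in for the publisher/subscriber sockets created during setup:

```python
# Condensed sketch of the count-then-multipart exchange; socket names are
# hypothetical stand-ins for the publisher/subscriber sockets created above.
import struct

def publish_batch(count_pub, msg_pub, messages):
    count_pub.send(struct.pack('!i', len(messages)))  # always announce the count
    if messages:
        msg_pub.send_multipart(messages)              # one multipart send for all

def receive_batch(count_sub, msg_sub):
    (count,) = struct.unpack('!i', count_sub.recv())
    return msg_sub.recv_multipart() if count > 0 else []
```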
@@ -1347,7 +1347,6 @@ def stop(self):
for socket in self.zmq_sockets:
socket.close()
self.zmq_context.term()
parallel_state.destroy_model_parallel()

@trace_async_exceptions
async def run_engine(
Expand All @@ -1369,7 +1368,6 @@ async def run_engine(
)
)
)

await self.async_step(verbose=verbose)
except asyncio.CancelledError:
pass
2 changes: 0 additions & 2 deletions megatron/training/arguments.py
@@ -1510,8 +1510,6 @@ def _add_inference_args(parser):
'Default to 0 to disable inference wandb logging.')
group.add_argument("--inference-coordinator-port", type=int, default=12346,
help="This port will be used to setup the inference coordinator on node-0")
group.add_argument("--inference-mp-coordinator-port", type=int, default=20000,
help="This port will be used to setup the inference model parallel coordinators")
return parser


@@ -188,7 +188,6 @@ async def _run_test(cls, **test_config_kwargs):
env.timing_data["start_time"] = time.time()
await env.engine.start_listening_to_data_parallel_coordinator(
inference_coordinator_port=test_config.port,
inference_mp_coordinator_port=test_config.mp_port,
launch_inference_coordinator=test_config.launch_inference_coordinator,
)

@@ -232,7 +231,8 @@ async def _run_test(cls, **test_config_kwargs):
env.responses = all_results
if test_config.verify_results:
for batch in all_results:
for request in batch:
for record in batch:
request = record[-1]
assert request.status == Status.COMPLETED

return env