99import tempfile
1010import socket
1111from os import path
12- from typing import Callable , Dict , List , Optional , Tuple , Union
12+ from typing import Callable , Dict , List , Optional , Tuple , Union , Set
1313from threading import Thread
1414from pynng .nng import Bus0 , Socket
1515from ding .utils .design_helper import SingletonMetaclass
@@ -30,10 +30,18 @@ def __init__(self) -> None:
3030 self .attach_to = None
3131 self .finished = False
3232 self .node_id = None
33+ self .labels = set ()
3334
34- def run (self , node_id : int , listen_to : str , attach_to : List [str ] = None ) -> None :
35+ def run (
36+ self ,
37+ node_id : int ,
38+ listen_to : str ,
39+ attach_to : Optional [List [str ]] = None ,
40+ labels : Optional [Set [str ]] = None
41+ ) -> None :
3542 self .node_id = node_id
3643 self .attach_to = attach_to = attach_to or []
44+ self .labels = labels or set ()
3745 self ._listener = Thread (
3846 target = self .listen ,
3947 kwargs = {
@@ -52,7 +60,9 @@ def runner(
5260 protocol : str = "ipc" ,
5361 address : Optional [str ] = None ,
5462 ports : Optional [List [int ]] = None ,
55- topology : str = "mesh"
63+ topology : str = "mesh" ,
64+ labels : Optional [Set [str ]] = None ,
65+ node_ids : Optional [List [int ]] = None
5666 ) -> Callable :
5767 """
5868 Overview:
@@ -66,6 +76,9 @@ def runner(
6676 - topology (:obj:`str`): Network topology, includes:
6777 `mesh` (default): fully connected between each other;
6878 `star`: only connect to the first node;
79+ `alone`: do not connect to any of the other nodes started here; only connect to the nodes given in `attach_to`;
80+ - labels (:obj:`Optional[Set[str]]`): Labels passed to every node started by this runner; each node stores them (an empty set if omitted).
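81+ - node_ids (:obj:`Optional[List[int]]`): Node ids assigned to the workers in order; must contain exactly `n_parallel_workers` entries. Defaults to `range(n_parallel_workers)`.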
6982 Returns:
7083 - _runner (:obj:`Callable`): The wrapper function for main.
7184 """
@@ -91,21 +104,29 @@ def cleanup_nodes():
91104
92105 atexit .register (cleanup_nodes )
93106
94- def topology_network (node_id : int ) -> List [str ]:
107+ def topology_network (i : int ) -> List [str ]:
95108 if topology == "mesh" :
96- return nodes [:node_id ] + attach_to
109+ return nodes [:i ] + attach_to
97110 elif topology == "star" :
98- return nodes [:min (1 , node_id )]
111+ return nodes [:min (1 , i )] + attach_to
112+ elif topology == "alone" :
113+ return attach_to
99114 else :
100115 raise ValueError ("Unknown topology: {}" .format (topology ))
101116
102117 params_group = []
103- for node_id in range (n_parallel_workers ):
118+ candidate_node_ids = node_ids or range (n_parallel_workers )
119+ assert len (candidate_node_ids ) == n_parallel_workers , \
120+ "The number of workers must be the same as the number of node_ids, \
121+ now there are {} workers and {} nodes" \
122+ .format (n_parallel_workers , len (candidate_node_ids ))
123+ for i in range (n_parallel_workers ):
104124 runner_args = []
105125 runner_kwargs = {
106- "node_id" : node_id ,
107- "listen_to" : nodes [node_id ],
108- "attach_to" : topology_network (node_id ) + attach_to
126+ "node_id" : candidate_node_ids [i ],
127+ "listen_to" : nodes [i ],
128+ "attach_to" : topology_network (i ) + attach_to ,
129+ "labels" : labels
109130 }
110131 params = [(runner_args , runner_kwargs ), (main_process , args , kwargs )]
111132 params_group .append (params )
@@ -151,6 +172,8 @@ def get_node_addrs(
151172 elif protocol == "tcp" :
152173 address = address or Parallel .get_ip ()
153174 ports = ports or range (50515 , 50515 + n_workers )
175+ if isinstance (ports , int ):
176+ ports = range (ports , ports + n_workers )
154177 assert len (ports ) == n_workers , "The number of ports must be the same as the number of workers, \
155178 now there are {} ports and {} workers" .format (len (ports ), n_workers )
156179 nodes = ["tcp://{}:{}" .format (address , port ) for port in ports ]
0 commit comments