Skip to content

Commit 954d310

Browse files
authored
feature(xjx): cli in new pipeline (#160)
* Cli ditask * Import ditask in init * Add current path as default package path * Fix style * Add topology on ditask
1 parent 92d973c commit 954d310

File tree

5 files changed

+135
-11
lines changed

5 files changed

+135
-11
lines changed

ding/entry/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .cli import cli
2+
from .cli_ditask import cli_ditask
23
from .serial_entry import serial_pipeline
34
from .serial_entry_onpolicy import serial_pipeline_onpolicy
45
from .serial_entry_offline import serial_pipeline_offline

ding/entry/cli_ditask.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import click
2+
import os
3+
import sys
4+
import importlib
5+
import importlib.util
6+
from click.core import Context, Option
7+
8+
from ding import __TITLE__, __VERSION__, __AUTHOR__, __AUTHOR_EMAIL__
9+
from ding.framework import Parallel
10+
11+
12+
def print_version(ctx: Context, param: Option, value: bool) -> None:
    """
    Overview:
        Click option callback that prints the package title, version and author \
        information, then exits through ``ctx.exit()``.
    Arguments:
        - ctx (:obj:`Context`): Click context of the running command.
        - param (:obj:`Option`): The option that triggered the callback (unused).
        - value (:obj:`bool`): Value of the flag; nothing is printed when it is falsy.
    """
    # Do nothing when the flag is absent or click is in resilient parsing mode.
    should_print = value and not ctx.resilient_parsing
    if should_print:
        version_line = '{title}, version {version}.'.format(title=__TITLE__, version=__VERSION__)
        author_line = 'Developed by {author}, {email}.'.format(author=__AUTHOR__, email=__AUTHOR_EMAIL__)
        click.echo(version_line)
        click.echo(author_line)
        ctx.exit()
18+
19+
20+
# Context settings handed to click.command below: the names accepted for the help option.
CONTEXT_SETTINGS = {'help_option_names': ['-h', '--help']}
21+
22+
23+
def _split_csv(raw: str) -> list:
    # Split a comma separated option value, trimming whitespace and dropping empty items
    # (so a trailing comma or "a,,b" does not produce empty entries).
    return [piece.strip() for piece in raw.split(",") if piece.strip()]


def _parse_int_csv(raw: str, option_name: str) -> list:
    # Parse a comma separated list of integers. Malformed input is reported as a click
    # usage error that names the offending option instead of a bare ValueError traceback.
    try:
        return [int(piece) for piece in _split_csv(raw)]
    except ValueError:
        raise click.BadParameter(
            "expected comma separated integers, got {!r}".format(raw), param_hint=option_name
        )


# Command line entry of ditask: import the user's entry function and run it with Parallel.runner.
#
# NOTE: documentation is kept in comments on purpose. Click uses a command's docstring as
# its help text, so adding a docstring here would change the output of `--help`.
#
# Arguments (all filled in by click from the options below):
#   - package: Code package path (directory or zip file); defaults to the current directory.
#   - main: Dotted path "<module>.<function>" of the entry function. When omitted, the module
#           is named after the package's basename and the function is "main".
#   - parallel_workers / protocol / ports / attach_to / address / labels / node_ids / topology:
#           parsed here where needed and forwarded to Parallel.runner.
# Raises:
#   - click.BadParameter: if --main is not of the form "<module>.<function>", or if --ports /
#           --node-ids contain non-integer items, or --ports contains no port at all.
@click.command(context_settings=CONTEXT_SETTINGS)
@click.option(
    '-v',
    '--version',
    is_flag=True,
    callback=print_version,
    expose_value=False,
    is_eager=True,
    help="Show package's version information."
)
@click.option('-p', '--package', type=str, help="Your code package path, could be a directory or a zip file.")
@click.option('--parallel-workers', type=int, default=1, help="Parallel worker number, default: 1")
@click.option(
    '--protocol',
    type=click.Choice(["tcp", "ipc"]),
    default="tcp",
    help="Network protocol in parallel mode, default: tcp"
)
@click.option(
    "--ports",
    type=str,
    default="50515",
    help="The port addresses that the tasks listen to, e.g. 50515,50516, default: 50515"
)
@click.option("--attach-to", type=str, help="The addresses to connect to.")
@click.option("--address", type=str, help="The address to listen to (without port).")
@click.option("--labels", type=str, help="Labels.")
@click.option("--node-ids", type=str, help="Candidate node ids.")
@click.option(
    "--topology",
    type=click.Choice(["alone", "mesh", "star"]),
    default="alone",
    help="Network topology, default: alone."
)
@click.option("-m", "--main", type=str, help="Main function of entry module.")
def cli_ditask(
    package: str, main: str, parallel_workers: int, protocol: str, ports: str, attach_to: str, address: str,
    labels: str, node_ids: str, topology: str
):
    # Parse entry point
    if not package:
        package = os.getcwd()
    sys.path.append(package)
    if main is None:
        # normpath drops a trailing separator, otherwise basename("proj/") would be ""
        mod_name, _ = os.path.splitext(os.path.basename(os.path.normpath(package)))
        func_name = "main"
    else:
        mod_name, dot, func_name = main.rpartition(".")
        if not (dot and mod_name and func_name):
            raise click.BadParameter(
                "expected '<module>.<function>', got {!r}".format(main), param_hint="--main"
            )
    # Also expose the directory of the root module, so its siblings can be imported
    root_mod_name = mod_name.split(".", 1)[0]
    sys.path.append(os.path.join(package, root_mod_name))
    mod = importlib.import_module(mod_name)
    main_func = getattr(mod, func_name)

    # Parse arguments
    port_list = _parse_int_csv(ports, "--ports")
    if not port_list:
        raise click.BadParameter("at least one port is required", param_hint="--ports")
    # A single port is forwarded as a plain int, several ports as a list
    ports = port_list[0] if len(port_list) == 1 else port_list
    if attach_to:
        attach_to = _split_csv(attach_to)
    if labels:
        labels = set(_split_csv(labels))
    if node_ids:
        node_ids = _parse_int_csv(node_ids, "--node-ids")

    Parallel.runner(
        n_parallel_workers=parallel_workers,
        ports=ports,
        protocol=protocol,
        topology=topology,
        attach_to=attach_to,
        address=address,
        labels=labels,
        node_ids=node_ids
    )(main_func)

ding/framework/parallel.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import tempfile
1010
import socket
1111
from os import path
12-
from typing import Callable, Dict, List, Optional, Tuple, Union
12+
from typing import Callable, Dict, List, Optional, Tuple, Union, Set
1313
from threading import Thread
1414
from pynng.nng import Bus0, Socket
1515
from ding.utils.design_helper import SingletonMetaclass
@@ -30,10 +30,18 @@ def __init__(self) -> None:
3030
self.attach_to = None
3131
self.finished = False
3232
self.node_id = None
33+
self.labels = set()
3334

34-
def run(self, node_id: int, listen_to: str, attach_to: List[str] = None) -> None:
35+
def run(
36+
self,
37+
node_id: int,
38+
listen_to: str,
39+
attach_to: Optional[List[str]] = None,
40+
labels: Optional[Set[str]] = None
41+
) -> None:
3542
self.node_id = node_id
3643
self.attach_to = attach_to = attach_to or []
44+
self.labels = labels or set()
3745
self._listener = Thread(
3846
target=self.listen,
3947
kwargs={
@@ -52,7 +60,9 @@ def runner(
5260
protocol: str = "ipc",
5361
address: Optional[str] = None,
5462
ports: Optional[List[int]] = None,
55-
topology: str = "mesh"
63+
topology: str = "mesh",
64+
labels: Optional[Set[str]] = None,
65+
node_ids: Optional[List[int]] = None
5666
) -> Callable:
5767
"""
5868
Overview:
@@ -66,6 +76,9 @@ def runner(
6676
- topology (:obj:`str`): Network topology, includes:
6777
`mesh` (default): fully connected between each other;
6878
`star`: only connect to the first node;
79+
`alone`: do not connect to any node, except the node attached to;
80+
- labels (:obj:`Optional[Set[str]]`): Labels.
81+
- node_ids (:obj:`Optional[List[int]]`): Candidate node ids.
6982
Returns:
7083
- _runner (:obj:`Callable`): The wrapper function for main.
7184
"""
@@ -91,21 +104,29 @@ def cleanup_nodes():
91104

92105
atexit.register(cleanup_nodes)
93106

94-
def topology_network(node_id: int) -> List[str]:
107+
def topology_network(i: int) -> List[str]:
95108
if topology == "mesh":
96-
return nodes[:node_id] + attach_to
109+
return nodes[:i] + attach_to
97110
elif topology == "star":
98-
return nodes[:min(1, node_id)]
111+
return nodes[:min(1, i)] + attach_to
112+
elif topology == "alone":
113+
return attach_to
99114
else:
100115
raise ValueError("Unknown topology: {}".format(topology))
101116

102117
params_group = []
103-
for node_id in range(n_parallel_workers):
118+
candidate_node_ids = node_ids or range(n_parallel_workers)
119+
assert len(candidate_node_ids) == n_parallel_workers, \
120+
"The number of workers must be the same as the number of node_ids, \
121+
now there are {} workers and {} nodes"\
122+
.format(n_parallel_workers, len(candidate_node_ids))
123+
for i in range(n_parallel_workers):
104124
runner_args = []
105125
runner_kwargs = {
106-
"node_id": node_id,
107-
"listen_to": nodes[node_id],
108-
"attach_to": topology_network(node_id) + attach_to
126+
"node_id": candidate_node_ids[i],
127+
"listen_to": nodes[i],
128+
"attach_to": topology_network(i) + attach_to,
129+
"labels": labels
109130
}
110131
params = [(runner_args, runner_kwargs), (main_process, args, kwargs)]
111132
params_group.append(params)
@@ -151,6 +172,8 @@ def get_node_addrs(
151172
elif protocol == "tcp":
152173
address = address or Parallel.get_ip()
153174
ports = ports or range(50515, 50515 + n_workers)
175+
if isinstance(ports, int):
176+
ports = range(ports, ports + n_workers)
154177
assert len(ports) == n_workers, "The number of ports must be the same as the number of workers, \
155178
now there are {} ports and {} workers".format(len(ports), n_workers)
156179
nodes = ["tcp://{}:{}".format(address, port) for port in ports]

ding/framework/task.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ def init_labels(self):
9898
if self.router.is_active:
9999
self.labels.add("distributed")
100100
self.labels.add("node.{}".format(self.router.node_id))
101+
for label in self.router.labels:
102+
self.labels.add(label)
101103
else:
102104
self.labels.add("standalone")
103105

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@
158158
'kubernetes',
159159
]
160160
},
161-
entry_points={'console_scripts': ['ding=ding.entry.cli:cli']},
161+
entry_points={'console_scripts': ['ding=ding.entry.cli:cli', 'ditask=ding.entry.cli_ditask:cli_ditask']},
162162
classifiers=[
163163
'Development Status :: 5 - Production/Stable',
164164
"Intended Audience :: Science/Research",

0 commit comments

Comments
 (0)