Skip to content

Commit 0d77476

Browse files
authored
[ascend] refactor parse atb graph (#172)
1 parent f091f0c commit 0d77476

File tree

7 files changed

+549
-356
lines changed

7 files changed

+549
-356
lines changed

dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ def gen_call_func(self):
154154
# TODO check scalar input
155155
call_body = IndentedBuffer()
156156
self.args = [self.args_dict[x.name] for x in self.input_args]
157+
call_body.writeline("""symInputs = []""")
157158
if len(self.args) == 1:
158159
call_body.writeline(f"{self.args[0]} = args[0]")
159160
else:
@@ -169,7 +170,11 @@ def gen_call_func(self):
169170
if len(self.sym_to_inputs) > 0:
170171
for key in self.sym_to_inputs.keys():
171172
if not key.isdigit() and not self.operator_in_str(key):
172-
call_body.writeline(f"{key} = {self.sym_to_inputs[key]}")
173+
value = self.sym_to_inputs[key]
174+
call_body.writeline(f"{key} = {value}")
175+
call_body.writeline(
176+
f"""symInputs.append('{{ "name": "{key}", "value": ' + str({key}) + ' }}')"""
177+
)
173178

174179
# gen fixed output shape
175180
graph_input_names = self.atb_graph.inputs
@@ -192,7 +197,7 @@ def gen_call_func(self):
192197
input = create_info["input"]
193198
call_body.writeline(f"""{output} = {input}""")
194199

195-
call_body.writeline("""param_dict = {"hostTensors": []}""")
200+
call_body.writeline("""hostTensors = []""")
196201
call_body.writeline(f"""host_tensor_dict = {{}}""")
197202
host_tensors = []
198203
for tensor in self.atb_graph.hosts:
@@ -205,11 +210,16 @@ def gen_call_func(self):
205210
f"""host_tensor_dict["{tensor_name}"] = {tensor_name}.cpu().tolist()"""
206211
)
207212
host_tensors.append(tensor_name)
213+
call_body.writeline(
214+
f"""host_tensor_str_{tensor_name} = str(host_tensor_dict["{tensor_name}"])"""
215+
)
208216
call_body.writeline(
209-
f"""param_dict["hostTensors"].append({{"nodeId": {node_id}, "tensorId": {tensor_id}, "value": host_tensor_dict["{tensor_name}"] }})"""
217+
f"""hostTensors.append('{{"nodeId": {node_id}, "tensorId": {tensor_id}, "value": ' + str(host_tensor_str_{tensor_name}) + ' }}')"""
210218
)
211219

212-
call_body.writeline("""param = json.dumps(param_dict)""")
220+
call_body.writeline(
221+
"""param = f'{{ \"symInputs\": [{",".join(symInputs)}], \"hostTensors\": [{",".join(hostTensors)}] }}'"""
222+
)
213223
call_body.writeline(f"""inputs = [{','.join(graph_input_names)}]""")
214224

215225
call_body.writeline(f"""outputs = [{','.join(graph_output_names)}]""")

0 commit comments

Comments
 (0)