@@ -154,6 +154,7 @@ def gen_call_func(self):
154154 # TODO check scalar input
155155 call_body = IndentedBuffer ()
156156 self .args = [self .args_dict [x .name ] for x in self .input_args ]
157+ call_body .writeline ("""symInputs = []""" )
157158 if len (self .args ) == 1 :
158159 call_body .writeline (f"{ self .args [0 ]} = args[0]" )
159160 else :
@@ -169,7 +170,11 @@ def gen_call_func(self):
169170 if len (self .sym_to_inputs ) > 0 :
170171 for key in self .sym_to_inputs .keys ():
171172 if not key .isdigit () and not self .operator_in_str (key ):
172- call_body .writeline (f"{ key } = { self .sym_to_inputs [key ]} " )
173+ value = self .sym_to_inputs [key ]
174+ call_body .writeline (f"{ key } = { value } " )
175+ call_body .writeline (
176+ f"""symInputs.append('{{ "name": "{ key } ", "value": ' + str({ key } ) + ' }}')"""
177+ )
173178
174179 # gen fixed output shape
175180 graph_input_names = self .atb_graph .inputs
@@ -192,7 +197,7 @@ def gen_call_func(self):
192197 input = create_info ["input" ]
193198 call_body .writeline (f"""{ output } = { input } """ )
194199
195- call_body .writeline ("""param_dict = {"hostTensors": []} """ )
200+ call_body .writeline ("""hostTensors = [] """ )
196201 call_body .writeline (f"""host_tensor_dict = {{}}""" )
197202 host_tensors = []
198203 for tensor in self .atb_graph .hosts :
@@ -205,11 +210,16 @@ def gen_call_func(self):
205210 f"""host_tensor_dict["{ tensor_name } "] = { tensor_name } .cpu().tolist()"""
206211 )
207212 host_tensors .append (tensor_name )
213+ call_body .writeline (
214+ f"""host_tensor_str_{ tensor_name } = str(host_tensor_dict["{ tensor_name } "])"""
215+ )
208216 call_body .writeline (
209- f"""param_dict[" hostTensors"] .append({{"nodeId": { node_id } , "tensorId": { tensor_id } , "value": host_tensor_dict[" { tensor_name } "] }} )"""
217+ f"""hostTensors.append(' {{"nodeId": { node_id } , "tensorId": { tensor_id } , "value": ' + str(host_tensor_str_ { tensor_name } ) + ' }}' )"""
210218 )
211219
212- call_body .writeline ("""param = json.dumps(param_dict)""" )
220+ call_body .writeline (
221+ """param = f'{{ \" symInputs\" : [{",".join(symInputs)}], \" hostTensors\" : [{",".join(hostTensors)}] }}'"""
222+ )
213223 call_body .writeline (f"""inputs = [{ ',' .join (graph_input_names )} ]""" )
214224
215225 call_body .writeline (f"""outputs = [{ ',' .join (graph_output_names )} ]""" )
0 commit comments