@@ -275,22 +275,13 @@ def register_op_input(self, arg: Any, node_id: int, port_id: int, op_meta: OpMet
275
275
:param node_id: Id if operation node.
276
276
:param port_id: Port id of input argument.
277
277
:param op_meta: Metadata about the operation.
278
- :return: Descriptor of the input. For a Tensor, this is a `TensorMeta` object.
279
- For a collection of Tensors, a collection of `TensorMeta` objects is returned .
280
- For other types, the original input `arg` is returned as-is.
278
+ :return: Descriptor of the input.
279
+ For a Tensor, this is a `TensorMeta` object .
280
+ For other types, the original input `arg` is returned as-is.
281
281
"""
282
282
if isinstance (arg , torch .Tensor ):
283
283
self .register_op_input_tensor (arg , node_id , port_id , op_meta )
284
284
return TensorMeta .from_tensor (arg )
285
- elif isinstance (arg , (list , tuple , set )):
286
- op_attr = []
287
- for x in arg :
288
- if isinstance (x , torch .Tensor ):
289
- self .register_op_input_tensor (x , node_id , port_id , op_meta )
290
- op_attr .append (TensorMeta .from_tensor (x ))
291
- else :
292
- op_attr .append (x )
293
- return op_attr
294
285
return arg
295
286
296
287
def register_op_node (self , args : Tuple [Any ], kwargs : Dict [str , Any ], op_meta : OpMeta ) -> None :
@@ -312,13 +303,30 @@ def register_op_node(self, args: Tuple[Any], kwargs: Dict[str, Any], op_meta: Op
312
303
313
304
op_attrs = []
314
305
op_kwargs = {}
315
- for port_id , arg in enumerate (args ):
316
- op_attr = self .register_op_input (arg , node_id , port_id , op_meta )
317
- op_attrs .append (op_attr )
318
-
319
- for port_id , (name , arg ) in enumerate (kwargs .items (), start = len (args )):
320
- op_attr = self .register_op_input (arg , node_id , port_id , op_meta )
321
- op_kwargs [name ] = op_attr
306
+ port_id = 0
307
+
308
+ for value in args :
309
+ if isinstance (value , (list , tuple )) and all (isinstance (v , torch .Tensor ) for v in value ):
310
+ list_attr = [None ] * len (value )
311
+ for idx , tensor in enumerate (value ):
312
+ op_attr = self .register_op_input (tensor , node_id , port_id , op_meta )
313
+ list_attr [idx ] = op_attr
314
+ port_id += 1
315
+ op_attrs .append (tuple (list_attr ) if isinstance (value , tuple ) else list_attr )
316
+ else :
317
+ op_attr = self .register_op_input (value , node_id , port_id , op_meta )
318
+ op_attrs .append (op_attr )
319
+ port_id += 1
320
+
321
+ for kw_name , value in kwargs .items ():
322
+ if isinstance (value , (list , tuple )) and all (isinstance (v , torch .Tensor ) for v in value ):
323
+ op_kwargs [kw_name ] = [None ] * len (value )
324
+ for tensor_idx , tensor in enumerate (value ):
325
+ op_kwargs [kw_name ][tensor_idx ] = self .register_op_input (value , node_id , port_id , op_meta )
326
+ port_id += 1
327
+ else :
328
+ op_kwargs [kw_name ] = self .register_op_input (value , node_id , port_id , op_meta )
329
+ port_id += 1
322
330
323
331
self .graph .add_node (
324
332
node_id ,
0 commit comments