@@ -359,7 +359,7 @@ inline c10::optional<TypePtr> unifyOrInitializeType(
359
359
360
360
using InferredType = c10::InferredType;
361
361
362
- InferredType tryToInferContainerType (py::handle input);
362
+ InferredType tryToInferContainerType (py::handle input, bool primitiveTypeOnly );
363
363
364
364
// Try to infer the type of a Python object
365
365
// The type cannot be inferred if:
@@ -496,17 +496,44 @@ inline InferredType tryToInferType(py::handle input) {
496
496
}
497
497
498
498
// Try container types
499
- return tryToInferContainerType (input);
499
+ return tryToInferContainerType (input, false );
500
500
}
501
501
502
- inline InferredType tryToInferContainerType (py::handle input) {
502
+ // This function is similar to tryToInferType, but it only tries to infer
503
+ // primitive types (int, float, bool, complex) or nested container of primitive
504
+ // types.
505
+ inline InferredType tryToInferPrimitiveType (py::handle input) {
506
+ if (input.is_none ()) {
507
+ return InferredType (NoneType::get ());
508
+ }
509
+
510
+ // Only primitive data type
511
+ if (py::isinstance<py::bool_>(input)) {
512
+ return InferredType (BoolType::get ());
513
+ // NOLINTNEXTLINE(bugprone-branch-clone)
514
+ } else if (py::isinstance<py::int_>(input)) {
515
+ return InferredType (IntType::get ());
516
+ } else if (py::isinstance<py::float_>(input)) {
517
+ return InferredType (FloatType::get ());
518
+ } else if (PyComplex_CheckExact (input.ptr ())) {
519
+ return InferredType (ComplexType::get ());
520
+ }
521
+
522
+ // Try container types
523
+ return tryToInferContainerType (input, true );
524
+ }
525
+
526
+ inline InferredType tryToInferContainerType (
527
+ py::handle input,
528
+ bool primitiveTypeOnly = false ) {
503
529
if (six::isTuple (input)) {
504
530
py::tuple tuple = py::cast<py::tuple>(input);
505
531
std::vector<TypePtr> element_types;
506
532
element_types.reserve (tuple.size ());
507
533
508
534
for (py::handle elem : tuple) {
509
- auto type_match = tryToInferType (elem);
535
+ auto type_match = primitiveTypeOnly ? tryToInferPrimitiveType (elem)
536
+ : tryToInferType (elem);
510
537
if (type_match.success ()) {
511
538
element_types.push_back (type_match.type ());
512
539
} else {
@@ -528,7 +555,9 @@ inline InferredType tryToInferContainerType(py::handle input) {
528
555
529
556
for (auto entry : dict) {
530
557
// Try to infer the key type and unify it with the existing one
531
- auto entry_key_type_match = tryToInferType (entry.first );
558
+ auto entry_key_type_match = primitiveTypeOnly
559
+ ? tryToInferPrimitiveType (entry.first )
560
+ : tryToInferType (entry.first );
532
561
if (!entry_key_type_match.success ()) {
533
562
return entry_key_type_match.reason ();
534
563
}
@@ -543,7 +572,9 @@ inline InferredType tryToInferContainerType(py::handle input) {
543
572
}
544
573
545
574
// Try to infer the value type and unify it with the existing one
546
- auto entry_value_type_match = tryToInferType (entry.second );
575
+ auto entry_value_type_match = primitiveTypeOnly
576
+ ? tryToInferPrimitiveType (entry.second )
577
+ : tryToInferType (entry.second );
547
578
if (!entry_value_type_match.success ()) {
548
579
return entry_value_type_match.reason ();
549
580
}
@@ -571,7 +602,9 @@ inline InferredType tryToInferContainerType(py::handle input) {
571
602
572
603
TypePtr element_type = nullptr ;
573
604
for (auto elem : list) {
574
- auto element_type_match = tryToInferType (elem);
605
+ auto element_type_match = primitiveTypeOnly
606
+ ? tryToInferPrimitiveType (elem)
607
+ : tryToInferType (elem);
575
608
if (!element_type_match.success ()) {
576
609
return InferredType (c10::str (
577
610
" Could not infer type of list element: " ,
@@ -590,16 +623,26 @@ inline InferredType tryToInferContainerType(py::handle input) {
590
623
}
591
624
return InferredType (ListType::create (element_type));
592
625
} else {
593
- // TODO: this message is not correct anymore, since this InferredType is
594
- // used from a bunch of circumstances unrelated to tracing. We can re-use
595
- // this instead of the attribute_failure stuff in concreteType
596
- return InferredType (c10::str (
597
- " Only tensors and (possibly nested) tuples of tensors, lists, or dicts" ,
598
- " are supported " ,
599
- " as inputs or outputs of traced functions" ,
600
- " , but instead got value of type " ,
601
- py::str (input.get_type ().attr (" __name__" )),
602
- " ." ));
626
+ if (primitiveTypeOnly) {
627
+ return InferredType (c10::str (
628
+ " Only tuple, list, or dict (possibly nested) of primitive types (bool, float, int, complex)" ,
629
+ " are supported " ,
630
+ " as inputs or outputs of traced functions" ,
631
+ " , but instead got value of type " ,
632
+ py::str (input.get_type ().attr (" __name__" )),
633
+ " ." ));
634
+ } else {
635
+ // TODO: this message is not correct anymore, since this InferredType is
636
+ // used from a bunch of circumstances unrelated to tracing. We can re-use
637
+ // this instead of the attribute_failure stuff in concreteType
638
+ return InferredType (c10::str (
639
+ " Only tensors and (possibly nested) tuples of tensors, lists, or dicts" ,
640
+ " are supported " ,
641
+ " as inputs or outputs of traced functions" ,
642
+ " , but instead got value of type " ,
643
+ py::str (input.get_type ().attr (" __name__" )),
644
+ " ." ));
645
+ }
603
646
}
604
647
}
605
648
0 commit comments