Skip to content

Commit 31bfa59

Browse files
shengfukevin and pytorchmergebot
authored and committed
Capture primitive data type arguments for profiling python_function (pytorch#120949)
RECORD_FUNCTION in python_function only captures arguments that are Tensors. However, it is very common for users to pass non-tensor arguments to custom ops, for example the sequence length in a GPT attention custom op. My previous PR tried to capture all non-tensor arguments, but that turned out to be very expensive in some cases. This PR adds support for capturing primitive arguments (or containers of primitives) in RECORD_FUNCTION. Pull Request resolved: pytorch#120949 Approved by: https://github.com/soulitzer
1 parent 5680f56 commit 31bfa59

File tree

2 files changed

+78
-19
lines changed

2 files changed

+78
-19
lines changed

torch/csrc/autograd/python_function.cpp

+18-2
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include <torch/csrc/jit/ir/ir.h>
3030
#include <torch/csrc/jit/python/pybind_utils.h>
3131
#include <torch/csrc/jit/python/python_tracer.h>
32+
#include <torch/csrc/profiler/api.h>
3233
#include <torch/csrc/utils/python_strings.h>
3334
#include <torch/csrc/utils/tensor_dtypes.h>
3435

@@ -857,6 +858,8 @@ static std::unordered_set<at::TensorImpl*> _parse_non_differentiable(
857858
struct UnpackedInput {
858859
THPObjectPtr input_tuple;
859860
variable_list input_vars;
861+
// record_function_inputs is handed to RECORD_FUNCTION only: every tensor argument, plus primitive (or container-of-primitive) arguments when the profiler is enabled and reports input shapes
862+
std::vector<c10::IValue> record_function_inputs;
860863
};
861864

862865
struct InputFlags {
@@ -874,6 +877,9 @@ std::pair<UnpackedInput, InputFlags> unpack_input(PyObject* args) {
874877
auto num_args = PyTuple_GET_SIZE(args);
875878
unpacked.input_tuple = PyTuple_New(num_args);
876879
flags.needs_input_grad = PyTuple_New(num_args);
880+
bool profiler_need_input = torch::autograd::profiler::profilerEnabled() &&
881+
torch::autograd::profiler::getProfilerConfig().report_input_shapes;
882+
877883
for (const auto i : c10::irange(num_args)) {
878884
PyObject* arg = PyTuple_GET_ITEM(args, i);
879885

@@ -889,12 +895,23 @@ std::pair<UnpackedInput, InputFlags> unpack_input(PyObject* args) {
889895
}
890896
Py_INCREF(Py_False);
891897
PyTuple_SET_ITEM(flags.needs_input_grad.get(), i, Py_False);
898+
899+
if (profiler_need_input) {
900+
// The following conversion from PyObject to IValue is expensive
901+
// Only do it if profiler is enabled and needs input shapes
902+
auto match = torch::jit::tryToInferPrimitiveType(arg);
903+
if (match.success()) {
904+
unpacked.record_function_inputs.push_back(
905+
torch::jit::toIValue(arg, match.type()));
906+
}
907+
}
892908
} else {
893909
const auto& tensor = THPVariable_Unpack(arg);
894910
unpacked.input_vars.push_back(tensor);
895911
PyObject* needs_grad = tensor.requires_grad() ? Py_True : Py_False;
896912
Py_INCREF(needs_grad);
897913
PyTuple_SET_ITEM(flags.needs_input_grad.get(), i, needs_grad);
914+
unpacked.record_function_inputs.emplace_back(tensor);
898915
}
899916
Py_INCREF(arg);
900917
PyTuple_SET_ITEM(unpacked.input_tuple.get(), i, arg);
@@ -1253,8 +1270,7 @@ PyObject* THPFunction_apply(PyObject* cls, PyObject* inputs) {
12531270
// before context has been allocated.
12541271
RECORD_FUNCTION(
12551272
((PyTypeObject*)cls)->tp_name,
1256-
std::vector<c10::IValue>(
1257-
unpacked_input.input_vars.begin(), unpacked_input.input_vars.end()),
1273+
unpacked_input.record_function_inputs,
12581274
seq_id);
12591275

12601276
const auto& functorch_tls = at::functorch::functorchTLSAccessor();

torch/csrc/jit/python/pybind_utils.h

+60-17
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ inline c10::optional<TypePtr> unifyOrInitializeType(
359359

360360
using InferredType = c10::InferredType;
361361

362-
InferredType tryToInferContainerType(py::handle input);
362+
InferredType tryToInferContainerType(py::handle input, bool primitiveTypeOnly);
363363

364364
// Try to infer the type of a Python object
365365
// The type cannot be inferred if:
@@ -496,17 +496,44 @@ inline InferredType tryToInferType(py::handle input) {
496496
}
497497

498498
// Try container types
499-
return tryToInferContainerType(input);
499+
return tryToInferContainerType(input, false);
500500
}
501501

502-
inline InferredType tryToInferContainerType(py::handle input) {
502+
// Like tryToInferType, but restricted to None, primitive values (bool, int,
503+
// float, complex), and tuple/list/dict containers (possibly nested) whose
504+
// elements this function can itself infer; otherwise the result carries a failure reason.
505+
inline InferredType tryToInferPrimitiveType(py::handle input) {
506+
if (input.is_none()) {
507+
return InferredType(NoneType::get());
508+
}
509+
510+
// Primitive scalars only; bool is tested before int because a Python bool (a subclass of int) would also pass the int check
511+
if (py::isinstance<py::bool_>(input)) {
512+
return InferredType(BoolType::get());
513+
// NOLINTNEXTLINE(bugprone-branch-clone)
514+
} else if (py::isinstance<py::int_>(input)) {
515+
return InferredType(IntType::get());
516+
} else if (py::isinstance<py::float_>(input)) {
517+
return InferredType(FloatType::get());
518+
} else if (PyComplex_CheckExact(input.ptr())) {
519+
return InferredType(ComplexType::get());
520+
}
521+
522+
// Not a scalar: fall back to container inference in primitive-only mode
523+
return tryToInferContainerType(input, true);
524+
}
525+
526+
inline InferredType tryToInferContainerType(
527+
py::handle input,
528+
bool primitiveTypeOnly = false) {
503529
if (six::isTuple(input)) {
504530
py::tuple tuple = py::cast<py::tuple>(input);
505531
std::vector<TypePtr> element_types;
506532
element_types.reserve(tuple.size());
507533

508534
for (py::handle elem : tuple) {
509-
auto type_match = tryToInferType(elem);
535+
auto type_match = primitiveTypeOnly ? tryToInferPrimitiveType(elem)
536+
: tryToInferType(elem);
510537
if (type_match.success()) {
511538
element_types.push_back(type_match.type());
512539
} else {
@@ -528,7 +555,9 @@ inline InferredType tryToInferContainerType(py::handle input) {
528555

529556
for (auto entry : dict) {
530557
// Try to infer the key type and unify it with the existing one
531-
auto entry_key_type_match = tryToInferType(entry.first);
558+
auto entry_key_type_match = primitiveTypeOnly
559+
? tryToInferPrimitiveType(entry.first)
560+
: tryToInferType(entry.first);
532561
if (!entry_key_type_match.success()) {
533562
return entry_key_type_match.reason();
534563
}
@@ -543,7 +572,9 @@ inline InferredType tryToInferContainerType(py::handle input) {
543572
}
544573

545574
// Try to infer the value type and unify it with the existing one
546-
auto entry_value_type_match = tryToInferType(entry.second);
575+
auto entry_value_type_match = primitiveTypeOnly
576+
? tryToInferPrimitiveType(entry.second)
577+
: tryToInferType(entry.second);
547578
if (!entry_value_type_match.success()) {
548579
return entry_value_type_match.reason();
549580
}
@@ -571,7 +602,9 @@ inline InferredType tryToInferContainerType(py::handle input) {
571602

572603
TypePtr element_type = nullptr;
573604
for (auto elem : list) {
574-
auto element_type_match = tryToInferType(elem);
605+
auto element_type_match = primitiveTypeOnly
606+
? tryToInferPrimitiveType(elem)
607+
: tryToInferType(elem);
575608
if (!element_type_match.success()) {
576609
return InferredType(c10::str(
577610
"Could not infer type of list element: ",
@@ -590,16 +623,26 @@ inline InferredType tryToInferContainerType(py::handle input) {
590623
}
591624
return InferredType(ListType::create(element_type));
592625
} else {
593-
// TODO: this message is not correct anymore, since this InferredType is
594-
// used from a bunch of circumstances unrelated to tracing. We can re-use
595-
// this instead of the attribute_failure stuff in concreteType
596-
return InferredType(c10::str(
597-
"Only tensors and (possibly nested) tuples of tensors, lists, or dicts",
598-
"are supported ",
599-
"as inputs or outputs of traced functions",
600-
", but instead got value of type ",
601-
py::str(input.get_type().attr("__name__")),
602-
"."));
626+
if (primitiveTypeOnly) {
627+
return InferredType(c10::str(
628+
"Only tuple, list, or dict (possibly nested) of primitive types (bool, float, int, complex)",
629+
"are supported ",
630+
"as inputs or outputs of traced functions",
631+
", but instead got value of type ",
632+
py::str(input.get_type().attr("__name__")),
633+
"."));
634+
} else {
635+
// TODO: this message is not correct anymore, since this InferredType is
636+
// used from a bunch of circumstances unrelated to tracing. We can re-use
637+
// this instead of the attribute_failure stuff in concreteType
638+
return InferredType(c10::str(
639+
"Only tensors and (possibly nested) tuples of tensors, lists, or dicts",
640+
"are supported ",
641+
"as inputs or outputs of traced functions",
642+
", but instead got value of type ",
643+
py::str(input.get_type().attr("__name__")),
644+
"."));
645+
}
603646
}
604647
}
605648

0 commit comments

Comments
 (0)