
Commit 5f9b432

cyyever authored and pytorchmergebot committed
[2/N] Replace std::tie with structural binding (pytorch#119879)
This PR follows pytorch#119774; the code generated by the Python scripts was likewise changed to use structured bindings.

Pull Request resolved: pytorch#119879
Approved by: https://github.com/albanD
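As a quick illustration of the transformation applied throughout this commit (a minimal sketch, assuming a hypothetical tuple-returning helper get_content):

    #include <cstddef>
    #include <string>
    #include <tuple>

    std::tuple<std::string, std::size_t> get_content();  // hypothetical helper

    void before() {
      // Old style: default-construct the variables, then unpack via std::tie.
      std::string data;
      std::size_t size = 0;
      std::tie(data, size) = get_content();
    }

    void after() {
      // New style (C++17 structured binding): declare and initialize the
      // names in one step; no prior default construction needed.
      auto [data, size] = get_content();
    }

Besides being shorter, the structured-binding form works for types that are not default-constructible and avoids the NOLINT annotations for uninitialized variables that several hunks below delete.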
1 parent 9ff9798 commit 5f9b432

24 files changed (+60, -164 lines)

tools/autograd/gen_trace_type.py

+3 -15

@@ -2,7 +2,6 @@
 from typing import Dict, List, Sequence, Union

 from torchgen.api import cpp
-
 from torchgen.api.types import DispatcherSignature
 from torchgen.code_template import CodeTemplate
 from torchgen.context import with_native_function
@@ -376,22 +375,11 @@ def format_postrecord_trace(f: NativeFunction) -> str:
     return POST_RECORD_TRACE.substitute(add_trace_outputs=outputs)


-def declare_returned_variables(f: NativeFunction) -> str:
-    modifies_arguments = f.func.kind() in (SchemaKind.inplace, SchemaKind.out)
-    if modifies_arguments:
-        return ""
-    if len(f.func.returns) == 1:
-        return ""
-    types = [cpp.return_type(r, symint=True) for r in f.func.returns]
-    names = cpp.return_names(f)
-    return "\n".join(f"{type.cpp_type()} {name};" for type, name in zip(types, names))
-
-
 def tie_return_values(f: NativeFunction) -> str:
     if len(f.func.returns) == 1:
         return f'auto {f.func.returns[0].name or "result"}'
     names = cpp.return_names(f)
-    return f'std::tie({", ".join(names)})'
+    return f'auto [{", ".join(names)}]'


 def get_return_value(f: NativeFunction) -> str:
@@ -415,7 +403,6 @@ def emit_trace_body(f: NativeFunction) -> List[str]:
     trace_body: List[str] = []

     trace_body.append(format_prerecord_trace(f))
-    trace_body.append(declare_returned_variables(f))

     dispatcher_sig = DispatcherSignature.from_schema(f.func)
     dispatcher_exprs = dispatcher_sig.exprs()
@@ -433,7 +420,8 @@ def emit_trace_body(f: NativeFunction) -> List[str]:
     )

     # Note that this calls the slow, dispatching variants of manual_cpp_binding ops.
-    # We could probably work harder to ensure that the fast variants are called instead, but the perf benefit would be minimal.
+    # We could probably work harder to ensure that the fast variants are
+    # called instead, but the perf benefit would be minimal.
     trace_body.append(
         TRACE_DISPATCH.substitute(
             assign_return_values=assign_return_values,
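With std::tie, the codegen had to emit standalone declarations for the returned variables (via declare_returned_variables) ahead of the dispatch call; a structured binding declares its names itself, so that helper is deleted and tie_return_values now emits auto [...]. The effect on the generated C++ is roughly the following (an illustrative sketch; some_op and its return names are made up):

    // Before: declarations emitted by declare_returned_variables, then std::tie.
    at::Tensor output;
    at::Tensor indices;
    std::tie(output, indices) = at::_ops::some_op::call(self);

    // After: one structured binding declares and assigns both names.
    auto [output, indices] = at::_ops::some_op::call(self);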

tools/autograd/gen_variable_type.py

-2

@@ -96,7 +96,6 @@
     WRAPPER_REGISTRATION,
 )
 from .gen_trace_type import (
-    declare_returned_variables,
     get_return_value,
     MANUAL_AUTOGRAD_AND_TRACER,
     MANUAL_BACKEND,
@@ -2130,7 +2129,6 @@ def emit_forbid_fw_derivatives(is_out_fn: bool = False) -> str:
     body.extend(emit_check_inplace())
     body.extend(emit_original_self_definition())
     body.extend(setup_derivative(differentiable_inputs))
-    body.append(declare_returned_variables(f))

     body.append(emit_call(f, unpacked_bindings, try_jit_decomposition))
     if requires_derivative:

torch/csrc/jit/codegen/onednn/kernel.cpp

+1 -2

@@ -275,8 +275,7 @@ void LlgaKernel::run(Stack& stack) {
   GRAPH_DEBUG("Preparing runtime tensors");
 #endif
   TensorArgs outputs;
-  RunArgs runInputs, runOutputs;
-  std::tie(runInputs, runOutputs) = prepareRunArgs(inputs, outputs);
+  auto [runInputs, runOutputs] = prepareRunArgs(inputs, outputs);
 #ifdef GRAPH_DEBUG_ENABLED
   GRAPH_DEBUG("Executing partition");
 #endif

torch/csrc/jit/frontend/ir_emitter.cpp

+1 -4

@@ -2892,15 +2892,12 @@ struct to_ir {

     // If it's a tensor, copy the RHS data into it
     if (sliceable->type()->isSubtypeOf(*TensorType::get())) {
-      std::vector<Value*> tensorIndices;
-      // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-      Value* sliced;
       // Handle multi-dimensional slicing: first emit int/slice indexing
       // TODO: the Python equivalent code has special-cased copy_to
       // broadcasting to match NumPy semantics (see PR#4853). We can't
       // replicate that without knowing the size of the Tensor; so really that
       // code should be moved into the aten function
-      std::tie(sliced, tensorIndices) = emitIntAndSliceIndexing(
+      auto [sliced, tensorIndices] = emitIntAndSliceIndexing(
           lhs.range(), sliceable, lhs.subscript_exprs());

       const auto slicedArg = NamedValue(lhs.range(), sliced);

torch/csrc/jit/frontend/source_range.cpp

+1 -4

@@ -267,10 +267,7 @@ void SourceRange::print_with_context(

   // print out location information
   if (auto flc = file_line_col()) {
-    std::string filename;
-    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-    size_t line, col;
-    std::tie(filename, line, col) = *flc;
+    auto [filename, line, col] = *flc;
     out << " File \"" << filename << "\", line " << line;
     if (!funcname.empty()) {
       out << ", in " << funcname;

torch/csrc/jit/ir/ir.cpp

+1 -4

@@ -354,10 +354,7 @@ std::ostream& Node::print(
     }
   }
   if (auto file_line_col = r.file_line_col()) {
-    std::string filename;
-    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-    size_t line, col;
-    std::tie(filename, line, col) = *file_line_col;
+    auto [filename, line, col] = *file_line_col;
     out << " # " << filename << ":" << line << ":" << col;
   }
 }

torch/csrc/jit/mobile/compatibility/model_compatibility.cpp

+2 -6

@@ -75,9 +75,7 @@ static uint64_t _get_model_bytecode_version_from_bytes(char* data, size_t size);
 uint64_t _get_model_bytecode_version(std::istream& in) {
   auto orig_pos = in.tellg();
   in.seekg(0, in.beg);
-  std::shared_ptr<char> data;
-  size_t size = 0;
-  std::tie(data, size) = get_stream_content(in);
+  auto [data, size] = get_stream_content(in);
   in.seekg(orig_pos, in.beg);
   return _get_model_bytecode_version_from_bytes(data.get(), size);
 }
@@ -89,9 +87,7 @@ uint64_t _get_model_bytecode_version(const std::string& filename) {

 uint64_t _get_model_bytecode_version(
     std::shared_ptr<ReadAdapterInterface> rai) {
-  std::shared_ptr<char> data;
-  size_t size = 0;
-  std::tie(data, size) = get_rai_content(rai.get());
+  auto [data, size] = get_rai_content(rai.get());
   return _get_model_bytecode_version_from_bytes(data.get(), size);
 }

torch/csrc/jit/mobile/debug_info.cpp

+2 -6

@@ -119,9 +119,7 @@ MobileDebugTable::MobileDebugTable(
   const c10::string_view suffix(".debug_pkl");
   for (const auto& record_name : record_names) {
     if (c10::string_view(record_name).ends_with(suffix)) {
-      at::DataPtr debug_data;
-      size_t debug_size{0};
-      std::tie(debug_data, debug_size) = reader->getRecord(record_name);
+      auto [debug_data, debug_size] = reader->getRecord(record_name);
       auto ivalueTuple = jit::unpickle(
           reinterpret_cast<const char*>(debug_data.get()),
           debug_size,
@@ -157,9 +155,7 @@ MobileDebugTable::MobileDebugTable(
   }
   const std::string callstack_debug_file("callstack_debug_map.pkl");
   if (reader->hasRecord("callstack_debug_map.pkl")) {
-    at::DataPtr callstack_data;
-    size_t callstack_data_size{0};
-    std::tie(callstack_data, callstack_data_size) =
+    auto [callstack_data, callstack_data_size] =
         reader->getRecord(callstack_debug_file);
     CallStackDebugInfoUnpickler unpickler;
     callstack_ptr_map_ = unpickler.unpickle(

torch/csrc/jit/mobile/flatbuffer_loader.cpp

+4 -12

@@ -827,24 +827,18 @@ mobile::Module load_mobile_module_from_file(
     const std::string& filename,
     c10::optional<c10::Device> device,
     ExtraFilesMap* extra_files) {
-  std::shared_ptr<char> data;
-  size_t size = 0;
-  std::tie(data, size) = get_file_content(filename.c_str());
+  auto [data, size] = get_file_content(filename.c_str());
   return parse_and_initialize_mobile_module(
       std::move(data), size, device, extra_files);
 }

 uint64_t get_bytecode_version(std::istream& in) {
-  std::shared_ptr<char> data;
-  size_t size = 0;
-  std::tie(data, size) = get_stream_content(in);
+  auto [data, size] = get_stream_content(in);
   return get_bytecode_version_from_bytes(data.get());
 }

 uint64_t get_bytecode_version(const std::string& filename) {
-  std::shared_ptr<char> data;
-  size_t size = 0;
-  std::tie(data, size) = get_file_content(filename.c_str());
+  auto [data, size] = get_file_content(filename.c_str());
   return get_bytecode_version_from_bytes(data.get());
 }

@@ -893,9 +887,7 @@ mobile::Module load_mobile_module_from_stream_with_copy(
     std::istream& in,
     c10::optional<at::Device> device,
     ExtraFilesMap* extra_files) {
-  std::shared_ptr<char> data;
-  size_t size = 0;
-  std::tie(data, size) = get_stream_content(in);
+  auto [data, size] = get_stream_content(in);
   return parse_and_initialize_mobile_module(
       std::move(data), size, device, extra_files);
 }

torch/csrc/jit/mobile/import_data.cpp

+2 -6

@@ -269,18 +269,14 @@ static std::map<std::string, at::Tensor> _load_parameters_bytes(
 std::map<std::string, at::Tensor> _load_parameters(
     std::istream& in,
     c10::optional<at::Device> device) {
-  std::shared_ptr<char> data;
-  size_t size = 0;
-  std::tie(data, size) = get_stream_content(in);
+  auto [data, size] = get_stream_content(in);
   return _load_parameters_bytes(std::move(data), size, device);
 }

 std::map<std::string, at::Tensor> _load_parameters(
     const std::string& filename,
     c10::optional<at::Device> device) {
-  std::shared_ptr<char> data;
-  size_t size = 0;
-  std::tie(data, size) = get_file_content(filename.c_str());
+  auto [data, size] = get_file_content(filename.c_str());
   return _load_parameters_bytes(std::move(data), size, device);
 }
torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp

+1 -2

@@ -300,8 +300,7 @@ void MKLDNNLayerNormOp(Stack& stack, bool inplace) {
   auto shape = pop(stack).toDimVector();
   auto input = pop(stack).toTensor();

-  at::Tensor dst, mean, rstd;
-  std::tie(dst, mean, rstd) =
+  auto [dst, mean, rstd] =
       at::native::mkldnn_layer_norm_last_index_weight_bias_f32(
           input, shape, weight, bias, eps, inplace);
   push(stack, dst);

torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp

+2 -6

@@ -257,9 +257,7 @@ void FixupONNXLoopBlockInputs(Node* n) {
     Value* input_i = block->inputs().at(i);
     if (input_i->type()->cast<OptionalType>() &&
         !block->outputs().at(i)->type()->cast<OptionalType>()) {
-      TypePtr merged_type;
-      bool inferred = false;
-      std::tie(merged_type, inferred) = MergeInferredType(
+      auto [merged_type, inferred] = MergeInferredType(
          input_i->type()->cast<OptionalType>()->getElementType(),
          block->outputs().at(i)->type());
       if (inferred) {
@@ -336,9 +334,7 @@ void FixupONNXLoopNodeInputs(Node* node, int opset_version) {
     // vice-versa.
     if (!input->type()->cast<OptionalType>() && sub_block_input_optional) {
       if (!input->type()->cast<NoneType>()) {
-        TypePtr merged_type;
-        bool inferred = false;
-        std::tie(merged_type, inferred) = MergeInferredType(
+        auto [merged_type, inferred] = MergeInferredType(
            sub_block_input_optional->getElementType(), input->type());
         if (inferred) {
           sub_block_input_optional = OptionalType::create(merged_type);

torch/csrc/jit/passes/onnx/function_extraction.cpp

+1 -3

@@ -877,9 +877,7 @@ std::tuple<FunctionExtractor::scope_ctx_map, node_list> FunctionExtractor::
   }

   for (auto* sub_b : n->blocks()) {
-    scope_ctx_map subblock_scope_ctxs;
-    node_list subblock_no_scope_nlist;
-    std::tie(subblock_scope_ctxs, subblock_no_scope_nlist) =
+    auto [subblock_scope_ctxs, subblock_no_scope_nlist] =
        PartitionNodesByScope(sub_b);

    for (auto& it : subblock_scope_ctxs) {

torch/csrc/jit/passes/onnx/shape_type_inference.cpp

+1 -4

@@ -77,10 +77,7 @@ void MergeInferredTypeAndSetMap(
     Value* dest_v,
     TypePtr existing_type,
     TypePtr inferred_type) {
-  TypePtr mergedType;
-  bool inferred;
-  std::tie(mergedType, inferred) =
-      MergeInferredType(existing_type, inferred_type);
+  auto [mergedType, inferred] = MergeInferredType(existing_type, inferred_type);
   dest_v->setType(mergedType);
   ConstantValueMap::SetUseInferredType(dest_v->debugName(), inferred);
 }

torch/csrc/jit/python/init.cpp

+1 -3

@@ -1497,9 +1497,7 @@ void initJITBindings(PyObject* module) {
       .def(
           "get_record",
           [](PyTorchStreamReader& self, const std::string& key) {
-            at::DataPtr data;
-            size_t size = 0;
-            std::tie(data, size) = self.getRecord(key);
+            auto [data, size] = self.getRecord(key);
             return py::bytes(reinterpret_cast<const char*>(data.get()), size);
           })
       .def(

torch/csrc/jit/python/python_ir.cpp

+20 -25

@@ -259,31 +259,26 @@ void initPythonIRBindings(PyObject* module_) {
             const std::string& onnx_file_path,
             const NodeAttrNameMap& node_attr_to_name) {
           std::string graph;
-          std::shared_ptr<::ONNX_NAMESPACE::ModelProto> model_proto;
-          RawDataExportMap export_map;
-          SymbolDimMap symbol_map;
-          bool val_use_external_data_format = false;
-          NodeNameMap onnx_node_names;
-          std::tie(
-              model_proto,
-              export_map,
-              symbol_map,
-              val_use_external_data_format,
-              onnx_node_names) =
-              export_onnx(
-                  g,
-                  initializers,
-                  onnx_opset_version,
-                  dynamic_axes,
-                  defer_weight_export,
-                  operator_export_type,
-                  strip_doc_string,
-                  keep_initializers_as_inputs,
-                  custom_opsets,
-                  add_node_names,
-                  val_use_external_data_format,
-                  onnx_file_path,
-                  node_attr_to_name);
+          auto
+              [model_proto,
+               export_map,
+               symbol_map,
+               val_use_external_data_format,
+               onnx_node_names] =
+                  export_onnx(
+                      g,
+                      initializers,
+                      onnx_opset_version,
+                      dynamic_axes,
+                      defer_weight_export,
+                      operator_export_type,
+                      strip_doc_string,
+                      keep_initializers_as_inputs,
+                      custom_opsets,
+                      add_node_names,
+                      false,
+                      onnx_file_path,
+                      node_attr_to_name);
           std::unordered_map<std::string, py::bytes>
               python_serialized_export_map;
           for (auto& kv : export_map) {
torch/csrc/jit/python/python_sugared_value.cpp

+2-9
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,7 @@ FunctionSchema PythonValue::getSchema(
8282
rets.emplace_back(Argument("0", ret_type, {}, {}, false));
8383
} else {
8484
// Use the provided type signature
85-
std::vector<TypePtr> arg_types;
86-
TypePtr ret_type;
87-
std::tie(arg_types, ret_type) =
85+
auto [arg_types, ret_type] =
8886
py::cast<std::pair<std::vector<TypePtr>, TypePtr>>(signature);
8987

9088
// arg_types does not include self but param_names does, so adjust for that
@@ -1022,12 +1020,7 @@ TypePtr registerNamedTuple(
10221020
py::module::import("torch._jit_internal")
10231021
.attr("_get_named_tuple_properties")(obj, loc, py::cpp_function(rcb));
10241022

1025-
std::string unqualName;
1026-
std::vector<std::string> field_names;
1027-
std::vector<TypePtr> field_types;
1028-
std::vector<py::object> objects;
1029-
1030-
std::tie(unqualName, field_names, field_types, objects) = py::cast<std::tuple<
1023+
auto [unqualName, field_names, field_types, objects] = py::cast<std::tuple<
10311024
std::string,
10321025
std::vector<std::string>,
10331026
std::vector<TypePtr>,

torch/csrc/jit/python/python_tree_views.cpp

+1 -3

@@ -34,9 +34,7 @@ struct SourceRangeFactory {
         leading_whitespace_chars_(leading_whitespace_chars) {}

   SourceRange create(int line, int start_col, int end_col) {
-    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-    size_t start_byte_offset, end_byte_offset;
-    std::tie(start_byte_offset, end_byte_offset) = line_col_to_byte_offs(
+    auto [start_byte_offset, end_byte_offset] = line_col_to_byte_offs(
        line,
        // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
        start_col + leading_whitespace_chars_,

torch/csrc/jit/python/script_init.cpp

+1 -3

@@ -879,9 +879,7 @@ void initJitScriptBindings(PyObject* module) {
           },
           [](const std::tuple<py::object, std::string>& state_tup)
               -> Object {
-            py::object state;
-            std::string qualname;
-            std::tie(state, qualname) = state_tup;
+            auto [state, qualname] = state_tup;
             auto class_type = getCustomClass(qualname);
             TORCH_CHECK(
                 class_type,

torch/csrc/jit/runtime/autodiff.cpp

+1 -1

@@ -391,7 +391,7 @@ static ReverseDetails addReverseInline(Gradient& grad_desc) {
     auto it = grad_map.find(v);
     if (it == grad_map.end()) {
       auto autograd_zero = graph.insertNode(graph.createAutogradZero());
-      std::tie(it, std::ignore) = grad_map.emplace(v, autograd_zero->output());
+      it = grad_map.emplace(v, autograd_zero->output()).first;
     }
     return it->second;
   };
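This autodiff.cpp hunk is the one place a structured binding would not fit: `it` is declared earlier by grad_map.find(v) and must be reassigned, while a structured binding always introduces new names. Taking .first of the pair returned by emplace replaces the std::tie/std::ignore idiom cleanly. An illustrative sketch of the same shape (not the PyTorch code):

    #include <map>

    int lookup_or_insert(std::map<int, int>& m, int key) {
      auto it = m.find(key);
      if (it == m.end()) {
        // A structured binding can only declare new variables, so it cannot
        // reassign the existing `it`; emplace(...).first does the job.
        it = m.emplace(key, 0).first;
      }
      return it->second;
    }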
