Skip to content

Commit 3c505fb

Browse files
qihqipytorchmergebot
authored andcommittedApr 13, 2022
Expose some functions out of ENABLE_FLATBUFFER.
Summary: Because they don't depend on flatbuffer Test Plan: existing unittests Differential Revision: D35291748 Pull Request resolved: pytorch#75700 Approved by: https://github.com/iseeyuan
1 parent ac8d220 commit 3c505fb

File tree

5 files changed

+22
-16
lines changed

5 files changed

+22
-16
lines changed
 

‎torch/csrc/jit/mobile/import_data.cpp

+2-6
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <caffe2/serialize/inline_container.h>
88
#include <torch/csrc/jit/api/compilation_unit.h>
99
#include <torch/csrc/jit/mobile/file_format.h>
10+
#include <torch/csrc/jit/mobile/import_export_common.h>
1011
#include <torch/csrc/jit/mobile/module.h>
1112
#include <torch/csrc/jit/mobile/observer.h>
1213
#include <torch/csrc/jit/mobile/type_parser.h>
@@ -16,7 +17,6 @@
1617

1718
#if defined(ENABLE_FLATBUFFER)
1819
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
19-
#include <torch/csrc/jit/mobile/import_export_common.h>
2020
#endif // defined(ENABLE_FLATBUFFER)
2121

2222
#include <exception>
@@ -182,7 +182,7 @@ std::map<std::string, at::Tensor> load_parameters_from_zip(
182182
return map;
183183
}
184184

185-
#if defined(ENABLE_FLATBUFFER)
185+
} // namespace
186186

187187
/**
188188
* Extracts the parameter map stored in @p module. Expects a layout
@@ -238,10 +238,6 @@ std::map<std::string, at::Tensor> mobile_module_to_parameter_map(
238238
"' in deserialized mobile::Module");
239239
}
240240

241-
#endif // defined(ENABLE_FLATBUFFER)
242-
243-
} // namespace
244-
245241
std::map<std::string, at::Tensor> _load_parameters(
246242
std::istream& in,
247243
c10::optional<at::Device> device) {

‎torch/csrc/jit/mobile/import_data.h

+5
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <ATen/core/TensorBase.h>
44
#include <c10/core/Device.h>
55
#include <c10/util/Optional.h>
6+
#include <torch/csrc/jit/mobile/module.h>
67

78
#include <istream>
89
#include <map>
@@ -29,5 +30,9 @@ TORCH_API std::map<std::string, at::Tensor> _load_parameters(
2930
const std::string& filename,
3031
c10::optional<at::Device> device = c10::nullopt);
3132

33+
// NOTE: Please prefer using _load_parameters over using the function below.
34+
TORCH_API std::map<std::string, at::Tensor> mobile_module_to_parameter_map(
35+
const mobile::Module& module);
36+
3237
} // namespace jit
3338
} // namespace torch

‎torch/csrc/jit/mobile/import_export_common.h

-4
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@ namespace torch {
99
namespace jit {
1010
namespace mobile {
1111

12-
#if defined(ENABLE_FLATBUFFER)
13-
1412
namespace internal {
1513
/**
1614
* The name of the mobile::Module attribute which contains saved parameters, as
@@ -20,8 +18,6 @@ namespace internal {
2018
constexpr char kSavedParametersAttributeName[] = "data";
2119
} // namespace internal
2220

23-
#endif // defined(ENABLE_FLATBUFFER)
24-
2521
} // namespace mobile
2622
} // namespace jit
2723
} // namespace torch

‎torch/csrc/jit/mobile/train/export_data.cpp

+3-6
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include <torch/csrc/jit/mobile/train/export_data.h>
22

3+
#include <torch/csrc/jit/mobile/import_export_common.h>
34
#include <torch/csrc/jit/mobile/module.h>
45
#include <torch/csrc/jit/runtime/instruction.h>
56
#include <torch/csrc/jit/serialization/pickler.h>
@@ -12,7 +13,6 @@
1213

1314
#if defined(ENABLE_FLATBUFFER)
1415
#include <flatbuffers/flatbuffers.h>
15-
#include <torch/csrc/jit/mobile/import_export_common.h>
1616
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
1717
#endif // defined(ENABLE_FLATBUFFER)
1818

@@ -76,6 +76,8 @@ class IValuePickler final {
7676
TypeNameUniquer type_name_uniquer_;
7777
};
7878

79+
} // namespace
80+
7981
/**
8082
* Converts a map of named tensors to a c10::Dict.
8183
*/
@@ -88,8 +90,6 @@ c10::Dict<std::string, at::Tensor> tensor_map_to_dict(
8890
return dict;
8991
}
9092

91-
#if defined(ENABLE_FLATBUFFER)
92-
9393
/**
9494
* Returns a Module with a single attribute, with the attribute name specified
9595
* by #internal::kSavedParametersAttributeName, whose value is the provided
@@ -117,9 +117,6 @@ mobile::Module tensor_dict_to_mobile(
117117
return mobile::Module(object, mcu);
118118
}
119119

120-
#endif // defined(ENABLE_FLATBUFFER)
121-
122-
} // namespace
123120
} // namespace mobile
124121

125122
void _save_parameters(

‎torch/csrc/jit/mobile/train/export_data.h

+12
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,17 @@ TORCH_API void _save_parameters(
3333
const std::string& filename,
3434
bool use_flatbuffer = false);
3535

36+
namespace mobile {
37+
38+
// NOTE: Please prefer using _save_parameters directly over using the 2
39+
// functions below.
40+
TORCH_API mobile::Module tensor_dict_to_mobile(
41+
const c10::Dict<std::string, at::Tensor>& dict);
42+
43+
c10::Dict<std::string, at::Tensor> tensor_map_to_dict(
44+
const std::map<std::string, at::Tensor>& map);
45+
46+
} // namespace mobile
47+
3648
} // namespace jit
3749
} // namespace torch

0 commit comments

Comments
 (0)