[onnx_importer.py] Fix dim_value None not correctly processed and missing Float8E4M3FNUZType. #4037
Conversation
Fix Float8E4M3FNUZType.
```diff
@@ -707,13 +704,14 @@ def type_proto_to_type(self, tp: onnx.TypeProto) -> IrType:
         tt = tp.tensor_type
         if tt.elem_type:
             if not tt.shape:
```
This is never None (protobuf default-initializes message fields on access), but rather an empty TensorShapeProto, which corresponds to a valid (rank-0) shape.
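A minimal sketch of the distinction, assuming plain onnx protobuf objects (the helper name is illustrative):

```python
import onnx

def shape_is_explicit(tp: onnx.TypeProto) -> bool:
    tt = tp.tensor_type
    # Reading tt.shape always yields a (possibly empty) TensorShapeProto,
    # so None/truthiness checks cannot detect an absent shape.
    # HasField reports whether the submessage was explicitly set.
    return tt.HasField("shape")
```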
I'm also going to take a look at the CI failure and let you know what I figure out.
```python
temp_inferred_file = temp_dir / "inferred.onnx"
onnx.save(raw_model, temp_raw_file, save_as_external_data=False)
```
This path is hit precisely when the provided model is large, so saving a temp file would be expensive. I'd prefer to only do this if the model had actually been modified, so perhaps add a bool to track whether the model got modified by previous arg specifications, and only save if so.
I'm also concerned about not saving external data in this case, since this is exactly when we would be exceeding the 2GB protobuf limit.
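A sketch of that suggestion, with hypothetical names (`model_modified`, `save_if_modified`) standing in for the importer's actual bookkeeping:

```python
from pathlib import Path

import onnx

def save_if_modified(raw_model: onnx.ModelProto,
                     temp_raw_file: Path,
                     model_modified: bool) -> None:
    # Skip the expensive round-trip entirely if earlier arg
    # specifications never touched the model.
    if not model_modified:
        return
    # Write tensors as external data so a model over the 2GB protobuf
    # limit can still be serialized.
    onnx.save(
        raw_model,
        str(temp_raw_file),
        save_as_external_data=True,
        location=temp_raw_file.name + ".data",
    )
```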
Running the reproducer with this patch crashes. IR dump:

```mlir
module {
func.func @main_graph(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.7.0"} {
%none = torch.constant.none
%0 = torch.operator "onnx.Shape"(%arg1) : (!torch.vtensor<[],f32>) -> !torch.vtensor<[?],si64>
%1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
%2 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__1> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
%3 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__2> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
%4 = torch.operator "onnx.Slice"(%0, %2, %3, %1) : (!torch.vtensor<[?],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[?],si64>
%5 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__3> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
%6 = torch.operator "onnx.Concat"(%4, %5) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[?],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64>
%7 = torch.operator "onnx.Reshape"(%arg1, %6) {torch.onnx.allowzero = 0 : si64} : (!torch.vtensor<[],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[?],f32>
%8 = torch.operator "onnx.Mul"(%arg0, %7) : (!torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32>
return %8 : !torch.vtensor<[?],f32>
}
}
{-#
dialect_resources: {
builtin: {
_: "0x080000000000000000000000",
__1: "0x080000000000000000000000",
__2: "0x08000000FFFFFFFFFFFFFFFF",
__3: "0x08000000FFFFFFFFFFFFFFFF"
}
}
#-}
```

On main:

```mlir
module {
func.func @main_graph(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.7.0"} {
%none = torch.constant.none
%0 = torch.operator "onnx.Shape"(%arg1) : (!torch.vtensor<[],f32>) -> !torch.vtensor<[0],si64>
%1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
%2 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__1> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
%3 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__2> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
%4 = torch.operator "onnx.Slice"(%0, %2, %3, %1) : (!torch.vtensor<[0],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[0],si64>
%5 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__3> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
%6 = torch.operator "onnx.Concat"(%4, %5) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[0],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64>
%7 = torch.operator "onnx.Reshape"(%arg1, %6) {torch.onnx.allowzero = 0 : si64} : (!torch.vtensor<[],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[?],f32>
%8 = torch.operator "onnx.Mul"(%arg0, %7) : (!torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32>
return %8 : !torch.vtensor<[?],f32>
}
}
{-#
dialect_resources: {
builtin: {
_: "0x080000000000000000000000",
__1: "0x080000000000000000000000",
__2: "0x08000000FFFFFFFFFFFFFFFF",
__3: "0x08000000FFFFFFFFFFFFFFFF"
}
}
#-}
```

I'm going to print out the onnx model and take a look at the shape information to see what is going on.
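One quick way to inspect that, assuming the model is loaded from disk (the path is illustrative):

```python
import onnx
from onnx import shape_inference

model = onnx.load("repro.onnx")
inferred = shape_inference.infer_shapes(model)
for vi in inferred.graph.value_info:
    # Print the inferred type/shape of each intermediate value.
    print(vi.name, vi.type.tensor_type.shape)
```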
Looking at a printout of the value info for the first shape node:

```
value_info {
name: "/Shape_output_0"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 0
}
}
}
}
}
```

So it looks like since `dim_value` defaults to 0 in protobuf, an unset dimension can't be told apart from a genuine size-0 dimension. I'm not sure I have a good idea of how to fix this. Do you know if shape inference will fill the value info fields with explicit values? Or is there a python protobuf API for checking if a field is not explicitly set?
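On the last question: protobuf does expose presence checks for oneof members, and ONNX's `Dimension` keeps `dim_value`/`dim_param` in a oneof named `value`, so a check along these lines should distinguish the cases (a sketch, not tested against this importer):

```python
import onnx

def dim_kind(dim: onnx.TensorShapeProto.Dimension) -> str:
    # WhichOneof reports which oneof member, if any, was explicitly set.
    which = dim.WhichOneof("value")
    if which == "dim_value":
        return f"static size {dim.dim_value}"  # includes a genuine 0
    if which == "dim_param":
        return f"symbolic {dim.dim_param!r}"
    return "unset / dynamic"
```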
As per title. Changes tested on SHARK-TestSuite's `alt_e2eshark`.