@@ -488,7 +488,8 @@ TEST(NNOpsTest, FractionalPool_ShapeFn) {
488
488
.Finalize (&op.node_def ));
489
489
};
490
490
491
- set_op (std::vector<float >{2 .0f , 1 , 1 / 1 .5f , 1 / 2 .0f });
491
+ // pooling_ratio must >= 1.0
492
+ set_op (std::vector<float >{2 .0f , 1 , 1 .5f , 4 .0f });
492
493
493
494
// Rank check.
494
495
INFER_ERROR (" must be rank 4" , op, " [?,?,?]" );
@@ -497,11 +498,11 @@ TEST(NNOpsTest, FractionalPool_ShapeFn) {
497
498
INFER_OK (op, " ?" , " [?,?,?,?];[?];[?]" );
498
499
INFER_OK (op, " [?,?,?,?]" , " [?,?,?,?];[?];[?]" );
499
500
500
- INFER_OK (op, " [10,20,30,40]" , " [5,20,45,80 ];[20];[45 ]" );
501
- INFER_OK (op, " [?,20,30,40]" , " [?,20,45,80 ];[20];[45 ]" );
502
- INFER_OK (op, " [10,?,30,40]" , " [5,?,45,80 ];[?];[45 ]" );
503
- INFER_OK (op, " [10,20,?,40]" , " [5,20,?,80 ];[20];[?]" );
504
- INFER_OK (op, " [10,20,30,?]" , " [5,20,45 ,?];[20];[45 ]" );
501
+ INFER_OK (op, " [10,20,30,40]" , " [5,20,20,10 ];[20];[20 ]" );
502
+ INFER_OK (op, " [?,20,30,40]" , " [?,20,20,10 ];[20];[20 ]" );
503
+ INFER_OK (op, " [10,?,30,40]" , " [5,?,20,10 ];[?];[20 ]" );
504
+ INFER_OK (op, " [10,20,?,40]" , " [5,20,?,10 ];[20];[?]" );
505
+ INFER_OK (op, " [10,20,30,?]" , " [5,20,20 ,?];[20];[20 ]" );
505
506
506
507
// Wrong number of values for pooling_ratio.
507
508
set_op (std::vector<float >{.5 , 1.0 , 1.5 });
0 commit comments