2
2
// SPDX-License-Identifier: Apache-2.0
3
3
//
4
4
5
-
6
-
7
5
#include " regex_normalization.hpp"
8
6
#include " utils.hpp"
9
7
8
+
10
9
using namespace ov ;
11
10
12
11
@@ -19,32 +18,61 @@ m_global_replace(global_replace) {
19
18
auto replace_pattern_const = as_type_ptr<Constant>(arguments[4 ].get_node_shared_ptr ());
20
19
auto search_pattern_buf = static_cast <const char *>(search_pattern_const->get_data_ptr ());
21
20
auto replace_pattern_buf = static_cast <const char *>(replace_pattern_const->get_data_ptr ());
22
- auto search_pattern = absl::string_view ((const char *)search_pattern_buf, search_pattern_const->get_byte_size ());
23
- m_replace_pattern = absl::string_view ((const char *)replace_pattern_buf, replace_pattern_const->get_byte_size ());
24
- m_search_pattern_re = std::make_shared<re2::RE2>(search_pattern);
21
+ auto search_pattern = absl::string_view (search_pattern_buf, search_pattern_const->get_byte_size ());
22
+ m_replace_pattern = absl::string_view (replace_pattern_buf, replace_pattern_const->get_byte_size ());
23
+
24
+ auto options = re2::RE2::Options ();
25
+ options.set_log_errors (false );
26
+ m_search_pattern_re = std::make_shared<re2::RE2>(search_pattern, options);
27
+
28
+ if (m_search_pattern_re->NumberOfCapturingGroups () == -1 ) {
29
+ // If RE2 was unable to process pattern.
30
+ m_search_pattern_pcre2 = std::make_shared<PCRE2Wrapper>(search_pattern);
31
+ m_search_pattern_re = nullptr ;
32
+ }
33
+
25
34
constructor_validate_and_infer_types ();
26
35
}
27
36
28
37
29
38
RegexNormalization::RegexNormalization (
30
39
const ov::OutputVector& arguments,
31
40
const std::shared_ptr<re2::RE2>& search_pattern_re,
41
+ const std::shared_ptr<PCRE2Wrapper>& search_pattern_pcre2,
32
42
const absl::string_view replace_pattern,
33
43
bool global_replace
34
44
) : ov::op::Op(arguments),
35
45
m_search_pattern_re(search_pattern_re),
46
+ m_search_pattern_pcre2(search_pattern_pcre2),
36
47
m_replace_pattern(replace_pattern),
37
48
m_global_replace(global_replace) {
38
49
39
- if (m_search_pattern_re == nullptr ) {
40
- auto search_pattern_const = as_type_ptr<Constant>(arguments[3 ].get_node_shared_ptr ());
41
- auto replace_pattern_const = as_type_ptr<Constant>(arguments[4 ].get_node_shared_ptr ());
42
- auto search_pattern_buf = static_cast <const char *>(search_pattern_const->get_data_ptr ());
43
- auto replace_pattern_buf = static_cast <const char *>(replace_pattern_const->get_data_ptr ());
44
- auto search_pattern = absl::string_view ((const char *)search_pattern_buf, search_pattern_const->get_byte_size ());
45
- m_replace_pattern = absl::string_view ((const char *)replace_pattern_buf, replace_pattern_const->get_byte_size ());
46
- m_search_pattern_re = std::make_shared<re2::RE2>(search_pattern);
50
+ auto search_pattern_const = as_type_ptr<Constant>(arguments[3 ].get_node_shared_ptr ());
51
+ auto replace_pattern_const = as_type_ptr<Constant>(arguments[4 ].get_node_shared_ptr ());
52
+ const char * search_pattern_buf;
53
+ const char * replace_pattern_buf;
54
+ absl::string_view search_pattern;
55
+
56
+ if (m_search_pattern_re == nullptr || m_search_pattern_pcre2 == nullptr ) {
57
+ search_pattern_buf = static_cast <const char *>(search_pattern_const->get_data_ptr ());
58
+ replace_pattern_buf = static_cast <const char *>(replace_pattern_const->get_data_ptr ());
59
+ search_pattern = absl::string_view (search_pattern_buf, search_pattern_const->get_byte_size ());
60
+ m_replace_pattern = absl::string_view (replace_pattern_buf, replace_pattern_const->get_byte_size ());
47
61
};
62
+
63
+ auto options = re2::RE2::Options ();
64
+ options.set_log_errors (false );
65
+ if (m_search_pattern_re == nullptr ) {
66
+ auto options = re2::RE2::Options ();
67
+ options.set_log_errors (false );
68
+ m_search_pattern_re = std::make_shared<re2::RE2>(search_pattern, options);
69
+ }
70
+
71
+ if (m_search_pattern_re->NumberOfCapturingGroups () == -1 && m_search_pattern_pcre2 == nullptr ) {
72
+ m_search_pattern_pcre2 = std::make_shared<PCRE2Wrapper>(search_pattern);
73
+ m_search_pattern_re = nullptr ;
74
+ }
75
+
48
76
constructor_validate_and_infer_types ();
49
77
}
50
78
@@ -58,24 +86,40 @@ void RegexNormalization::validate_and_infer_types() {
58
86
59
87
60
88
bool RegexNormalization::evaluate (ov::TensorVector& outputs, const ov::TensorVector& inputs) const {
61
- if (m_search_pattern_re == nullptr ) {
62
- auto search_pattern = absl::string_view (inputs[3 ].data <const char >(), inputs[3 ].get_size ());
89
+ absl::string_view search_pattern;
90
+ if (m_search_pattern_re == nullptr || m_search_pattern_pcre2 == nullptr ) {
91
+ search_pattern = absl::string_view (inputs[3 ].data <const char >(), inputs[3 ].get_size ());
63
92
m_replace_pattern = absl::string_view (inputs[4 ].data <const char >(), inputs[4 ].get_size ());
64
- m_search_pattern_re = std::make_shared<re2::RE2>(search_pattern);
65
- };
93
+ }
94
+
95
+ if (m_search_pattern_re == nullptr && m_search_pattern_pcre2 == nullptr ) {
96
+ auto options = re2::RE2::Options ();
97
+ options.set_log_errors (false );
98
+ m_search_pattern_re = std::make_shared<re2::RE2>(search_pattern, options);
99
+ }
100
+
101
+ if ((m_search_pattern_re == nullptr ) || (m_search_pattern_re->NumberOfCapturingGroups () == -1 && m_search_pattern_pcre2 == nullptr )) {
102
+ m_search_pattern_pcre2 = std::make_shared<PCRE2Wrapper>(search_pattern);
103
+ m_search_pattern_re = nullptr ;
104
+ }
105
+
66
106
return evaluate_normalization_helper (
67
107
outputs, inputs,
68
- [this ](const std::string& str) {
69
- // FIXME: if regex is not valid re2, return string without changing (use another regex engine)
70
- if (m_search_pattern_re->NumberOfCapturingGroups () == -1 )
71
- return str;
72
-
108
+ [this ](const std::string& str) -> std::string {
73
109
std::string result = str;
74
- if (m_global_replace) {
75
- re2::RE2::GlobalReplace (&result, *m_search_pattern_re, m_replace_pattern);
110
+
111
+ // Use RE2 where possible, and fallback to PCRE2 if RE2 was not able to process.
112
+ if (m_search_pattern_re) {
113
+ if (m_global_replace) {
114
+ re2::RE2::GlobalReplace (&result, *m_search_pattern_re, m_replace_pattern);
115
+ } else {
116
+ re2::RE2::Replace (&result, *m_search_pattern_re, m_replace_pattern);
117
+ };
118
+ return result;
119
+ } else if (m_search_pattern_pcre2) {
120
+ return m_search_pattern_pcre2->substitute (result, m_replace_pattern, m_global_replace);
76
121
} else {
77
- re2::RE2::Replace (&result, *m_search_pattern_re, m_replace_pattern);
78
- };
79
- return result;
122
+ return result;
123
+ }
80
124
});
81
125
}
0 commit comments