40
40
import org .apache .lucene .util .automaton .RegExp ;
41
41
import org .opensearch .common .lucene .BytesRefs ;
42
42
import org .opensearch .common .lucene .Lucene ;
43
- import org .opensearch .common .regex .Regex ;
44
43
import org .opensearch .common .unit .Fuzziness ;
45
44
import org .opensearch .core .xcontent .XContentParser ;
46
45
import org .opensearch .index .analysis .IndexAnalyzers ;
@@ -430,22 +429,27 @@ public Query wildcardQuery(String value, MultiTermQuery.RewriteMethod method, bo
430
429
finalValue = value ;
431
430
}
432
431
Predicate <String > matchPredicate ;
433
- if (value .contains ("?" )) {
434
- Automaton automaton = WildcardQuery .toAutomaton (new Term (name (), finalValue ));
435
- CompiledAutomaton compiledAutomaton = new CompiledAutomaton (automaton );
432
+ Automaton automaton = WildcardQuery .toAutomaton (new Term (name (), finalValue ));
433
+ CompiledAutomaton compiledAutomaton = new CompiledAutomaton (automaton );
434
+ if (compiledAutomaton .type == CompiledAutomaton .AUTOMATON_TYPE .SINGLE ) {
435
+ // when type equals SINGLE, #compiledAutomaton.runAutomaton is null
436
436
matchPredicate = s -> {
437
437
if (caseInsensitive ) {
438
438
s = s .toLowerCase (Locale .ROOT );
439
439
}
440
- BytesRef valueBytes = BytesRefs .toBytesRef (s );
441
- return compiledAutomaton .runAutomaton .run (valueBytes .bytes , valueBytes .offset , valueBytes .length );
440
+ return s .equals (finalValue );
442
441
};
442
+ } else if (compiledAutomaton .type == CompiledAutomaton .AUTOMATON_TYPE .ALL ) {
443
+ return existsQuery (context );
444
+ } else if (compiledAutomaton .type == CompiledAutomaton .AUTOMATON_TYPE .NONE ) {
445
+ return new MatchNoDocsQuery ("Wildcard expression matches nothing" );
443
446
} else {
444
447
matchPredicate = s -> {
445
448
if (caseInsensitive ) {
446
449
s = s .toLowerCase (Locale .ROOT );
447
450
}
448
- return Regex .simpleMatch (finalValue , s );
451
+ BytesRef valueBytes = BytesRefs .toBytesRef (s );
452
+ return compiledAutomaton .runAutomaton .run (valueBytes .bytes , valueBytes .offset , valueBytes .length );
449
453
};
450
454
}
451
455
@@ -468,22 +472,30 @@ public Query wildcardQuery(String value, MultiTermQuery.RewriteMethod method, bo
468
472
// Package-private for testing
469
473
static Set <String > getRequiredNGrams (String value ) {
470
474
Set <String > terms = new HashSet <>();
475
+
476
+ if (value .isEmpty ()) {
477
+ return terms ;
478
+ }
479
+
471
480
int pos = 0 ;
481
+ String rawSequence = null ;
472
482
String currentSequence = null ;
473
483
if (!value .startsWith ("?" ) && !value .startsWith ("*" )) {
474
484
// Can add prefix term
475
- currentSequence = getNonWildcardSequence (value , 0 );
485
+ rawSequence = getNonWildcardSequence (value , 0 );
486
+ currentSequence = performEscape (rawSequence );
476
487
if (currentSequence .length () == 1 ) {
477
488
terms .add (new String (new char [] { 0 , currentSequence .charAt (0 ) }));
478
489
} else {
479
490
terms .add (new String (new char [] { 0 , currentSequence .charAt (0 ), currentSequence .charAt (1 ) }));
480
491
}
481
492
} else {
482
493
pos = findNonWildcardSequence (value , pos );
483
- currentSequence = getNonWildcardSequence (value , pos );
494
+ rawSequence = getNonWildcardSequence (value , pos );
484
495
}
485
496
while (pos < value .length ()) {
486
- boolean isEndOfValue = pos + currentSequence .length () == value .length ();
497
+ boolean isEndOfValue = pos + rawSequence .length () == value .length ();
498
+ currentSequence = performEscape (rawSequence );
487
499
if (!currentSequence .isEmpty () && currentSequence .length () < 3 && !isEndOfValue && pos > 0 ) {
488
500
// If this is a prefix or suffix of length < 3, then we already have a longer token including the anchor.
489
501
terms .add (currentSequence );
@@ -502,16 +514,16 @@ static Set<String> getRequiredNGrams(String value) {
502
514
terms .add (new String (new char [] { a , b , 0 }));
503
515
}
504
516
}
505
- pos = findNonWildcardSequence (value , pos + currentSequence .length ());
506
- currentSequence = getNonWildcardSequence (value , pos );
517
+ pos = findNonWildcardSequence (value , pos + rawSequence .length ());
518
+ rawSequence = getNonWildcardSequence (value , pos );
507
519
}
508
520
return terms ;
509
521
}
510
522
511
523
private static String getNonWildcardSequence (String value , int startFrom ) {
512
524
for (int i = startFrom ; i < value .length (); i ++) {
513
525
char c = value .charAt (i );
514
- if (c == '?' || c == '*' ) {
526
+ if (( c == '?' || c == '*' ) && ( i == 0 || value . charAt ( i - 1 ) != '\\' ) ) {
515
527
return value .substring (startFrom , i );
516
528
}
517
529
}
@@ -529,6 +541,22 @@ private static int findNonWildcardSequence(String value, int startFrom) {
529
541
return value .length ();
530
542
}
531
543
544
+ private static String performEscape (String str ) {
545
+ StringBuilder sb = new StringBuilder ();
546
+ for (int i = 0 ; i < str .length (); i ++) {
547
+ if (str .charAt (i ) == '\\' && (i + 1 ) < str .length ()) {
548
+ char c = str .charAt (i + 1 );
549
+ if (c == '*' || c == '?' ) {
550
+ i ++;
551
+ }
552
+ }
553
+ sb .append (str .charAt (i ));
554
+ }
555
+ assert !sb .toString ().contains ("\\ *" );
556
+ assert !sb .toString ().contains ("\\ ?" );
557
+ return sb .toString ();
558
+ }
559
+
532
560
@ Override
533
561
public Query regexpQuery (
534
562
String value ,
0 commit comments