Skip to content

Commit 89fc267

Browse files
Fix equals and hashCode methods for KNNQuery and KNNQueryBuilder (#1397)
Signed-off-by: panguixin <panguixin@bytedance.com>
1 parent fcbfef1 commit 89fc267

File tree

4 files changed

+22
-10
lines changed

4 files changed

+22
-10
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
2727
* Properly designate model state for actively training models when nodes crash or leave cluster [#1317](https://github.com/opensearch-project/k-NN/pull/1317)
2828
* Fix script score queries not getting cached [#1367](https://github.com/opensearch-project/k-NN/pull/1367)
2929
* Fix KNNScorer to apply boost [#1403](https://github.com/opensearch-project/k-NN/pull/1403)
30+
* Fix equals and hashCode methods for KNNQuery and KNNQueryBuilder [#1397](https://github.com/opensearch-project/k-NN/pull/1397)
3031
### Infrastructure
3132
* Upgrade gradle to 8.4 [1289](https://github.com/opensearch-project/k-NN/pull/1289)
3233
* Refactor security testing to install from individual components [#1307](https://github.com/opensearch-project/k-NN/pull/1307)

src/main/java/org/opensearch/knn/index/query/KNNQuery.java

+9-3
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
package org.opensearch.knn.index.query;
77

8+
import java.util.Arrays;
9+
import java.util.Objects;
810
import lombok.Getter;
911
import lombok.Setter;
1012
import org.apache.lucene.search.BooleanClause;
@@ -127,7 +129,7 @@ public String toString(String field) {
127129

128130
@Override
129131
public int hashCode() {
130-
return field.hashCode() ^ queryVector.hashCode() ^ k;
132+
return Objects.hash(field, Arrays.hashCode(queryVector), k, indexName, filterQuery);
131133
}
132134

133135
@Override
@@ -136,6 +138,10 @@ public boolean equals(Object other) {
136138
}
137139

138140
private boolean equalsTo(KNNQuery other) {
139-
return this.field.equals(other.getField()) && this.queryVector.equals(other.getQueryVector()) && this.k == other.getK();
141+
return Objects.equals(field, other.field)
142+
&& Arrays.equals(queryVector, other.queryVector)
143+
&& Objects.equals(k, other.k)
144+
&& Objects.equals(indexName, other.indexName)
145+
&& Objects.equals(filterQuery, other.filterQuery);
140146
}
141-
};
147+
}

src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java

+8-3
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
package org.opensearch.knn.index.query;
77

8+
import java.util.Arrays;
89
import lombok.extern.log4j.Log4j2;
910
import org.apache.lucene.search.MatchNoDocsQuery;
1011
import org.opensearch.core.common.Strings;
@@ -46,7 +47,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
4647
public static final ParseField K_FIELD = new ParseField("k");
4748
public static final ParseField FILTER_FIELD = new ParseField("filter");
4849
public static final ParseField IGNORE_UNMAPPED_FIELD = new ParseField("ignore_unmapped");
49-
public static int K_MAX = 10000;
50+
public static final int K_MAX = 10000;
5051
/**
5152
* The name for the knn query
5253
*/
@@ -346,12 +347,16 @@ private ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFie
346347

347348
@Override
348349
protected boolean doEquals(KNNQueryBuilder other) {
349-
return Objects.equals(fieldName, other.fieldName) && Objects.equals(vector, other.vector) && Objects.equals(k, other.k);
350+
return Objects.equals(fieldName, other.fieldName)
351+
&& Arrays.equals(vector, other.vector)
352+
&& Objects.equals(k, other.k)
353+
&& Objects.equals(filter, other.filter)
354+
&& Objects.equals(ignoreUnmapped, other.ignoreUnmapped);
350355
}
351356

352357
@Override
353358
protected int doHashCode() {
354-
return Objects.hash(fieldName, vector, k);
359+
return Objects.hash(fieldName, Arrays.hashCode(vector), k, filter, ignoreUnmapped);
355360
}
356361

357362
@Override

src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java

+4-4
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ public void testEmptyVector() {
9090
expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector1, K));
9191
}
9292

93-
public void testFromXcontent() throws Exception {
93+
public void testFromXContent() throws Exception {
9494
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
9595
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K);
9696
XContentBuilder builder = XContentFactory.jsonBuilder();
@@ -103,10 +103,10 @@ public void testFromXcontent() throws Exception {
103103
XContentParser contentParser = createParser(builder);
104104
contentParser.nextToken();
105105
KNNQueryBuilder actualBuilder = KNNQueryBuilder.fromXContent(contentParser);
106-
actualBuilder.equals(knnQueryBuilder);
106+
assertEquals(knnQueryBuilder, actualBuilder);
107107
}
108108

109-
public void testFromXcontent_WithFilter() throws Exception {
109+
public void testFromXContent_WithFilter() throws Exception {
110110
final ClusterService clusterService = mockClusterService(Version.CURRENT);
111111

112112
final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance();
@@ -125,7 +125,7 @@ public void testFromXcontent_WithFilter() throws Exception {
125125
XContentParser contentParser = createParser(builder);
126126
contentParser.nextToken();
127127
KNNQueryBuilder actualBuilder = KNNQueryBuilder.fromXContent(contentParser);
128-
actualBuilder.equals(knnQueryBuilder);
128+
assertEquals(knnQueryBuilder, actualBuilder);
129129
}
130130

131131
public void testFromXContent_invalidQueryVectorType() throws Exception {

0 commit comments

Comments
 (0)