Skip to content

Commit 396c0df

Browse files
author
Prabhakar Sithanandam
committed
Cache the shard routings with no weight for faster access
The list of shards to run a query on is determined for every request, and the weight of the nodes guides the shard selection. Currently, IndexShardRoutingTable caches the shard routings with weight for faster access. However, when the fail-open option is enabled, shards with no weight are also returned, lower in the order, along with the shards with weights. They are used as a fallback if the shards with weights can't be used due to some error. The shard routings with no weight are not cached, so a full loop is performed for every request; this impacts search latency when the number of shards to query or the number of nodes in the cluster is high. The latency impact is very high when both the number of shards and the number of nodes are high. This change introduces a caching mechanism for shard routings with no weight, similar to the existing cache for shard routings with weights. Signed-off-by: Prabhakar Sithanandam <prabhakar.s87@gmail.com>
1 parent 8426e14 commit 396c0df

File tree

2 files changed

+101
-41
lines changed

2 files changed

+101
-41
lines changed

server/src/main/java/org/opensearch/cluster/routing/IndexShardRoutingTable.java

+85-36
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434

3535
import org.apache.logging.log4j.LogManager;
3636
import org.apache.logging.log4j.Logger;
37-
import org.opensearch.cluster.metadata.WeightedRoutingMetadata;
3837
import org.opensearch.cluster.node.DiscoveryNode;
3938
import org.opensearch.cluster.node.DiscoveryNodes;
4039
import org.opensearch.common.Nullable;
@@ -63,7 +62,6 @@
6362
import java.util.Set;
6463
import java.util.function.Predicate;
6564
import java.util.stream.Collectors;
66-
import java.util.stream.Stream;
6765

6866
import static java.util.Collections.emptyMap;
6967

@@ -96,8 +94,8 @@ public class IndexShardRoutingTable implements Iterable<ShardRouting> {
9694
private volatile Map<AttributesKey, AttributesRoutings> initializingShardsByAttributes = emptyMap();
9795
private final Object shardsByAttributeMutex = new Object();
9896
private final Object shardsByWeightMutex = new Object();
99-
private volatile Map<WeightedRoutingKey, List<ShardRouting>> activeShardsByWeight = emptyMap();
100-
private volatile Map<WeightedRoutingKey, List<ShardRouting>> initializingShardsByWeight = emptyMap();
97+
private volatile Map<WeightedRoutingKey, WeightedShardRoutings> activeShardsByWeight = emptyMap();
98+
private volatile Map<WeightedRoutingKey, WeightedShardRoutings> initializingShardsByWeight = emptyMap();
10199

102100
private static final Logger logger = LogManager.getLogger(IndexShardRoutingTable.class);
103101

@@ -249,7 +247,7 @@ public List<ShardRouting> assignedShards() {
249247
return this.assignedShards;
250248
}
251249

252-
public Map<WeightedRoutingKey, List<ShardRouting>> getActiveShardsByWeight() {
250+
public Map<WeightedRoutingKey, WeightedShardRoutings> getActiveShardsByWeight() {
253251
return activeShardsByWeight;
254252
}
255253

@@ -338,23 +336,7 @@ public ShardIterator activeInitializingShardsWeightedIt(
338336
// append shards for attribute value with weight zero, so that shard search requests can be tried on
339337
// shard copies in case of request failure from other attribute values.
340338
if (isFailOpenEnabled) {
341-
try {
342-
Stream<String> keys = weightedRouting.weights()
343-
.entrySet()
344-
.stream()
345-
.filter(entry -> entry.getValue().intValue() == WeightedRoutingMetadata.WEIGHED_AWAY_WEIGHT)
346-
.map(Map.Entry::getKey);
347-
keys.forEach(key -> {
348-
ShardIterator iterator = onlyNodeSelectorActiveInitializingShardsIt(weightedRouting.attributeName() + ":" + key, nodes);
349-
while (iterator.remaining() > 0) {
350-
ordered.add(iterator.nextOrNull());
351-
}
352-
});
353-
} catch (IllegalArgumentException e) {
354-
// this exception is thrown by {@link onlyNodeSelectorActiveInitializingShardsIt} in case count of shard
355-
// copies found is zero
356-
logger.debug("no shard copies found for shard id [{}] for node attribute with weight zero", shardId);
357-
}
339+
ordered.addAll(activeInitializingShardsWithoutWeights(weightedRouting, nodes, defaultWeight));
358340
}
359341

360342
return new PlainShardIterator(shardId, ordered);
@@ -378,6 +360,18 @@ private List<ShardRouting> activeInitializingShardsWithWeights(
378360
return orderedListWithDistinctShards;
379361
}
380362

363+
private List<ShardRouting> activeInitializingShardsWithoutWeights(
364+
WeightedRouting weightedRouting,
365+
DiscoveryNodes nodes,
366+
double defaultWeight
367+
) {
368+
List<ShardRouting> ordered = new ArrayList<>(getActiveShardsWithoutWeight(weightedRouting, nodes, defaultWeight));
369+
if (!allInitializingShards.isEmpty()) {
370+
ordered.addAll(getInitializingShardsWithoutWeight(weightedRouting, nodes, defaultWeight));
371+
}
372+
return ordered.stream().distinct().collect(Collectors.toList());
373+
}
374+
381375
/**
382376
* Returns a list containing shard routings ordered using weighted round-robin scheduling.
383377
*/
@@ -949,20 +943,55 @@ public int hashCode() {
949943
}
950944
}
951945

946+
@PublicApi(since = "2.14.0")
947+
public static class WeightedShardRoutings {
948+
private final List<ShardRouting> shardRoutingsWithWeight;
949+
private final List<ShardRouting> shardRoutingWithoutWeight;
950+
951+
public WeightedShardRoutings(List<ShardRouting> shardRoutingsWithWeight, List<ShardRouting> shardRoutingWithoutWeight) {
952+
this.shardRoutingsWithWeight = Collections.unmodifiableList(shardRoutingsWithWeight);
953+
this.shardRoutingWithoutWeight = Collections.unmodifiableList(shardRoutingWithoutWeight);
954+
}
955+
956+
public List<ShardRouting> getShardRoutingsWithWeight() {
957+
return shardRoutingsWithWeight;
958+
}
959+
960+
public List<ShardRouting> getShardRoutingWithoutWeight() {
961+
return shardRoutingWithoutWeight;
962+
}
963+
}
964+
952965
/**
953966
* *
954967
* Gets active shard routing from memory if available, else calculates and put it in memory.
955968
*/
956969
private List<ShardRouting> getActiveShardsByWeight(WeightedRouting weightedRouting, DiscoveryNodes nodes, double defaultWeight) {
957970
WeightedRoutingKey key = new WeightedRoutingKey(weightedRouting);
958-
List<ShardRouting> shardRoutings = activeShardsByWeight.get(key);
959-
if (shardRoutings == null) {
960-
synchronized (shardsByWeightMutex) {
961-
shardRoutings = shardsOrderedByWeight(activeShards, weightedRouting, nodes, defaultWeight);
962-
activeShardsByWeight = new MapBuilder().put(key, shardRoutings).immutableMap();
963-
}
971+
if (activeShardsByWeight.get(key) == null) {
972+
populateActiveShardWeightsMap(weightedRouting, nodes, defaultWeight);
973+
}
974+
return activeShardsByWeight.get(key).getShardRoutingsWithWeight();
975+
}
976+
977+
private List<ShardRouting> getActiveShardsWithoutWeight(WeightedRouting weightedRouting, DiscoveryNodes nodes, double defaultWeight) {
978+
WeightedRoutingKey key = new WeightedRoutingKey(weightedRouting);
979+
if (activeShardsByWeight.get(key) == null) {
980+
populateActiveShardWeightsMap(weightedRouting, nodes, defaultWeight);
981+
}
982+
return activeShardsByWeight.get(key).getShardRoutingWithoutWeight();
983+
}
984+
985+
private void populateActiveShardWeightsMap(WeightedRouting weightedRouting, DiscoveryNodes nodes, double defaultWeight) {
986+
WeightedRoutingKey key = new WeightedRoutingKey(weightedRouting);
987+
List<ShardRouting> weightedRoutings = shardsOrderedByWeight(activeShards, weightedRouting, nodes, defaultWeight);
988+
List<ShardRouting> nonWeightedRoutings = activeShards.stream()
989+
.filter(shard -> !weightedRoutings.contains(shard))
990+
.collect(Collectors.toUnmodifiableList());
991+
synchronized (shardsByWeightMutex) {
992+
activeShardsByWeight = new MapBuilder().put(key, new WeightedShardRoutings(weightedRoutings, nonWeightedRoutings))
993+
.immutableMap();
964994
}
965-
return shardRoutings;
966995
}
967996

968997
/**
@@ -971,14 +1000,34 @@ private List<ShardRouting> getActiveShardsByWeight(WeightedRouting weightedRouti
9711000
*/
9721001
private List<ShardRouting> getInitializingShardsByWeight(WeightedRouting weightedRouting, DiscoveryNodes nodes, double defaultWeight) {
9731002
WeightedRoutingKey key = new WeightedRoutingKey(weightedRouting);
974-
List<ShardRouting> shardRoutings = initializingShardsByWeight.get(key);
975-
if (shardRoutings == null) {
976-
synchronized (shardsByWeightMutex) {
977-
shardRoutings = shardsOrderedByWeight(activeShards, weightedRouting, nodes, defaultWeight);
978-
initializingShardsByWeight = new MapBuilder().put(key, shardRoutings).immutableMap();
979-
}
1003+
if (initializingShardsByWeight.get(key) == null) {
1004+
populateInitializingShardWeightsMap(weightedRouting, nodes, defaultWeight);
1005+
}
1006+
return initializingShardsByWeight.get(key).getShardRoutingsWithWeight();
1007+
}
1008+
1009+
private List<ShardRouting> getInitializingShardsWithoutWeight(
1010+
WeightedRouting weightedRouting,
1011+
DiscoveryNodes nodes,
1012+
double defaultWeight
1013+
) {
1014+
WeightedRoutingKey key = new WeightedRoutingKey(weightedRouting);
1015+
if (initializingShardsByWeight.get(key) == null) {
1016+
populateInitializingShardWeightsMap(weightedRouting, nodes, defaultWeight);
1017+
}
1018+
return initializingShardsByWeight.get(key).getShardRoutingWithoutWeight();
1019+
}
1020+
1021+
private void populateInitializingShardWeightsMap(WeightedRouting weightedRouting, DiscoveryNodes nodes, double defaultWeight) {
1022+
WeightedRoutingKey key = new WeightedRoutingKey(weightedRouting);
1023+
List<ShardRouting> weightedRoutings = shardsOrderedByWeight(allInitializingShards, weightedRouting, nodes, defaultWeight);
1024+
List<ShardRouting> nonWeightedRoutings = allInitializingShards.stream()
1025+
.filter(shard -> !weightedRoutings.contains(shard))
1026+
.collect(Collectors.toUnmodifiableList());
1027+
synchronized (shardsByWeightMutex) {
1028+
initializingShardsByWeight = new MapBuilder().put(key, new WeightedShardRoutings(weightedRoutings, nonWeightedRoutings))
1029+
.immutableMap();
9801030
}
981-
return shardRoutings;
9821031
}
9831032

9841033
/**

server/src/test/java/org/opensearch/cluster/structure/RoutingIteratorTests.java

+16-5
Original file line numberDiff line numberDiff line change
@@ -700,9 +700,18 @@ public void testWeightedRoutingWithDifferentWeights() {
700700
.shard(0)
701701
.activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1, false, null);
702702
assertEquals(1, shardIterator.size());
703-
shardRouting = shardIterator.nextOrNull();
704-
assertNotNull(shardRouting);
705-
assertFalse(Arrays.asList("node2", "node1").contains(shardRouting.currentNodeId()));
703+
assertEquals("node3", shardIterator.nextOrNull().currentNodeId());
704+
705+
weights = Map.of("zone1", -1.0, "zone2", 0.0, "zone3", 1.0);
706+
weightedRouting = new WeightedRouting("zone", weights);
707+
shardIterator = clusterState.routingTable()
708+
.index("test")
709+
.shard(0)
710+
.activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1, true, null);
711+
assertEquals(3, shardIterator.size());
712+
assertEquals("node3", shardIterator.nextOrNull().currentNodeId());
713+
assertNotEquals("node3", shardIterator.nextOrNull().currentNodeId());
714+
assertNotEquals("node3", shardIterator.nextOrNull().currentNodeId());
706715

707716
weights = Map.of("zone1", 3.0, "zone2", 2.0, "zone3", 0.0);
708717
weightedRouting = new WeightedRouting("zone", weights);
@@ -711,8 +720,9 @@ public void testWeightedRoutingWithDifferentWeights() {
711720
.shard(0)
712721
.activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1, true, null);
713722
assertEquals(3, shardIterator.size());
714-
shardRouting = shardIterator.nextOrNull();
715-
assertNotNull(shardRouting);
723+
assertNotEquals("node3", shardIterator.nextOrNull().currentNodeId());
724+
assertNotEquals("node3", shardIterator.nextOrNull().currentNodeId());
725+
assertEquals("node3", shardIterator.nextOrNull().currentNodeId());
716726
} finally {
717727
terminate(threadPool);
718728
}
@@ -887,6 +897,7 @@ public void testWeightedRoutingShardState() {
887897
shardRouting = shardIterator.nextOrNull();
888898
assertNotNull(shardRouting);
889899
requestCount.put(shardRouting.currentNodeId(), requestCount.getOrDefault(shardRouting.currentNodeId(), 0) + 1);
900+
890901
}
891902
assertEquals(3, requestCount.get("node1").intValue());
892903
assertEquals(2, requestCount.get("node2").intValue());

0 commit comments

Comments
 (0)