Skip to content

Commit 7103e56

Browse files
authored
Add ShardBatchCache to support caching for TransportNodesListGatewayStartedShardsBatch (opensearch-project#12504)
Signed-off-by: Aman Khare <amkhar@amazon.com>
1 parent 645b1f1 commit 7103e56

15 files changed

+621
-73
lines changed

server/src/internalClusterTest/java/org/opensearch/gateway/GatewayRecoveryTestUtils.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ public static Map<ShardId, ShardAttributes> prepareRequestMap(String[] indices,
5454
);
5555
for (int shardIdNum = 0; shardIdNum < primaryShardCount; shardIdNum++) {
5656
final ShardId shardId = new ShardId(index, shardIdNum);
57-
shardIdShardAttributesMap.put(shardId, new ShardAttributes(shardId, customDataPath));
57+
shardIdShardAttributesMap.put(shardId, new ShardAttributes(customDataPath));
5858
}
5959
}
6060
return shardIdShardAttributesMap;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* The OpenSearch Contributors require contributions made to
5+
* this file be licensed under the Apache-2.0 license or a
6+
* compatible open source license.
7+
*/
8+
9+
package org.opensearch.gateway;
10+
11+
import org.apache.logging.log4j.Logger;
12+
import org.opensearch.action.support.nodes.BaseNodeResponse;
13+
import org.opensearch.action.support.nodes.BaseNodesResponse;
14+
import org.opensearch.cluster.node.DiscoveryNode;
15+
import org.opensearch.common.logging.Loggers;
16+
import org.opensearch.core.index.shard.ShardId;
17+
import org.opensearch.indices.store.ShardAttributes;
18+
19+
import java.lang.reflect.Array;
20+
import java.util.HashMap;
21+
import java.util.Map;
22+
import java.util.Set;
23+
import java.util.function.Predicate;
24+
25+
import reactor.util.annotation.NonNull;
26+
27+
/**
28+
* Implementation of AsyncShardFetch with batching support. This class is responsible for executing the fetch
29+
* part using the base class {@link AsyncShardFetch}. Other functionalities needed for a batch are only written here.
30+
* This separation also takes care of the extra generic type V which is only needed for batch
31+
* transport actions like {@link TransportNodesListGatewayStartedShardsBatch} and
32+
* {@link org.opensearch.indices.store.TransportNodesListShardStoreMetadataBatch}.
33+
*
34+
* @param <T> Response type of the transport action.
35+
* @param <V> Data type of shard level response.
36+
*
37+
* @opensearch.internal
38+
*/
39+
public abstract class AsyncShardBatchFetch<T extends BaseNodeResponse, V> extends AsyncShardFetch<T> {
40+
41+
@SuppressWarnings("unchecked")
42+
AsyncShardBatchFetch(
43+
Logger logger,
44+
String type,
45+
Map<ShardId, ShardAttributes> shardAttributesMap,
46+
AsyncShardFetch.Lister<? extends BaseNodesResponse<T>, T> action,
47+
String batchId,
48+
Class<V> clazz,
49+
V emptyShardResponse,
50+
Predicate<V> emptyShardResponsePredicate,
51+
ShardBatchResponseFactory<T, V> responseFactory
52+
) {
53+
super(
54+
logger,
55+
type,
56+
shardAttributesMap,
57+
action,
58+
batchId,
59+
new ShardBatchCache<>(
60+
logger,
61+
type,
62+
shardAttributesMap,
63+
"BatchID=[" + batchId + "]",
64+
clazz,
65+
emptyShardResponse,
66+
emptyShardResponsePredicate,
67+
responseFactory
68+
)
69+
);
70+
}
71+
72+
/**
73+
* Remove a shard from the cache maintaining a full batch of shards. This is needed to clear the shard once it's
74+
* assigned or failed.
75+
*
76+
* @param shardId shardId to be removed from the batch.
77+
*/
78+
public synchronized void clearShard(ShardId shardId) {
79+
this.shardAttributesMap.remove(shardId);
80+
this.cache.deleteShard(shardId);
81+
}
82+
83+
/**
84+
* Cache implementation of transport actions returning batch of shards related data in the response.
85+
* Store node level responses of transport actions like {@link TransportNodesListGatewayStartedShardsBatch} or
86+
* {@link org.opensearch.indices.store.TransportNodesListShardStoreMetadataBatch} with memory efficient caching
87+
* approach. This cache class is not thread safe, all of its methods are being called from
88+
* {@link AsyncShardFetch} class which has synchronized blocks present to handle multiple threads.
89+
*
90+
* @param <T> Response type of transport action.
91+
* @param <V> Data type of shard level response.
92+
*/
93+
static class ShardBatchCache<T extends BaseNodeResponse, V> extends AsyncShardFetchCache<T> {
94+
private final Map<String, NodeEntry<V>> cache;
95+
private final Map<ShardId, Integer> shardIdToArray;
96+
private final int batchSize;
97+
private final Class<V> shardResponseClass;
98+
private final ShardBatchResponseFactory<T, V> responseFactory;
99+
private final V emptyResponse;
100+
private final Predicate<V> emptyShardResponsePredicate;
101+
private final Logger logger;
102+
103+
public ShardBatchCache(
104+
Logger logger,
105+
String type,
106+
Map<ShardId, ShardAttributes> shardAttributesMap,
107+
String logKey,
108+
Class<V> clazz,
109+
V emptyResponse,
110+
Predicate<V> emptyShardResponsePredicate,
111+
ShardBatchResponseFactory<T, V> responseFactory
112+
) {
113+
super(Loggers.getLogger(logger, "_" + logKey), type);
114+
this.batchSize = shardAttributesMap.size();
115+
this.emptyShardResponsePredicate = emptyShardResponsePredicate;
116+
cache = new HashMap<>();
117+
shardIdToArray = new HashMap<>();
118+
fillShardIdKeys(shardAttributesMap.keySet());
119+
this.shardResponseClass = clazz;
120+
this.emptyResponse = emptyResponse;
121+
this.logger = logger;
122+
this.responseFactory = responseFactory;
123+
}
124+
125+
@Override
126+
@NonNull
127+
public Map<String, ? extends BaseNodeEntry> getCache() {
128+
return cache;
129+
}
130+
131+
@Override
132+
public void deleteShard(ShardId shardId) {
133+
if (shardIdToArray.containsKey(shardId)) {
134+
Integer shardIdIndex = shardIdToArray.remove(shardId);
135+
for (String nodeId : cache.keySet()) {
136+
cache.get(nodeId).clearShard(shardIdIndex);
137+
}
138+
}
139+
}
140+
141+
@Override
142+
public void initData(DiscoveryNode node) {
143+
cache.put(node.getId(), new NodeEntry<>(node.getId(), shardResponseClass, batchSize, emptyShardResponsePredicate));
144+
}
145+
146+
/**
147+
* Put the response received from data nodes into the cache.
148+
* Get shard level data from batch, then filter out if any shards received failures.
149+
* After that complete storing the data at node level and mark fetching as done.
150+
*
151+
* @param node node from which we got the response.
152+
* @param response shard metadata coming from node.
153+
*/
154+
@Override
155+
public void putData(DiscoveryNode node, T response) {
156+
NodeEntry<V> nodeEntry = cache.get(node.getId());
157+
Map<ShardId, V> batchResponse = responseFactory.getShardBatchData(response);
158+
nodeEntry.doneFetching(batchResponse, shardIdToArray);
159+
}
160+
161+
@Override
162+
public T getData(DiscoveryNode node) {
163+
return this.responseFactory.getNewResponse(node, getBatchData(cache.get(node.getId())));
164+
}
165+
166+
private HashMap<ShardId, V> getBatchData(NodeEntry<V> nodeEntry) {
167+
V[] nodeShardEntries = nodeEntry.getData();
168+
boolean[] emptyResponses = nodeEntry.getEmptyShardResponse();
169+
HashMap<ShardId, V> shardData = new HashMap<>();
170+
for (Map.Entry<ShardId, Integer> shardIdEntry : shardIdToArray.entrySet()) {
171+
ShardId shardId = shardIdEntry.getKey();
172+
Integer arrIndex = shardIdEntry.getValue();
173+
if (emptyResponses[arrIndex]) {
174+
shardData.put(shardId, emptyResponse);
175+
} else if (nodeShardEntries[arrIndex] != null) {
176+
// ignore null responses here
177+
shardData.put(shardId, nodeShardEntries[arrIndex]);
178+
}
179+
}
180+
return shardData;
181+
}
182+
183+
private void fillShardIdKeys(Set<ShardId> shardIds) {
184+
int shardIdIndex = 0;
185+
for (ShardId shardId : shardIds) {
186+
this.shardIdToArray.putIfAbsent(shardId, shardIdIndex++);
187+
}
188+
}
189+
190+
/**
191+
* A node entry, holding the state of the fetched data for a specific shard
192+
* for a giving node.
193+
*/
194+
static class NodeEntry<V> extends BaseNodeEntry {
195+
private final V[] shardData;
196+
private final boolean[] emptyShardResponse; // we can not rely on null entries of the shardData array,
197+
// those null entries means that we need to ignore those entries. Empty responses on the other hand are
198+
// actually needed in allocation/explain API response. So instead of storing full empty response object
199+
// in cache, it's better to just store a boolean and create that object on the fly just before
200+
// decision-making.
201+
private final Predicate<V> emptyShardResponsePredicate;
202+
203+
NodeEntry(String nodeId, Class<V> clazz, int batchSize, Predicate<V> emptyShardResponsePredicate) {
204+
super(nodeId);
205+
this.shardData = (V[]) Array.newInstance(clazz, batchSize);
206+
this.emptyShardResponse = new boolean[batchSize];
207+
this.emptyShardResponsePredicate = emptyShardResponsePredicate;
208+
}
209+
210+
void doneFetching(Map<ShardId, V> shardDataFromNode, Map<ShardId, Integer> shardIdKey) {
211+
fillShardData(shardDataFromNode, shardIdKey);
212+
super.doneFetching();
213+
}
214+
215+
void clearShard(Integer shardIdIndex) {
216+
this.shardData[shardIdIndex] = null;
217+
emptyShardResponse[shardIdIndex] = false;
218+
}
219+
220+
V[] getData() {
221+
return this.shardData;
222+
}
223+
224+
boolean[] getEmptyShardResponse() {
225+
return emptyShardResponse;
226+
}
227+
228+
private void fillShardData(Map<ShardId, V> shardDataFromNode, Map<ShardId, Integer> shardIdKey) {
229+
for (Map.Entry<ShardId, V> shardData : shardDataFromNode.entrySet()) {
230+
if (shardData.getValue() != null) {
231+
ShardId shardId = shardData.getKey();
232+
if (emptyShardResponsePredicate.test(shardData.getValue())) {
233+
this.emptyShardResponse[shardIdKey.get(shardId)] = true;
234+
this.shardData[shardIdKey.get(shardId)] = null;
235+
} else {
236+
this.shardData[shardIdKey.get(shardId)] = shardData.getValue();
237+
}
238+
}
239+
}
240+
}
241+
}
242+
}
243+
}

server/src/main/java/org/opensearch/gateway/AsyncShardFetch.java

+6-5
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,10 @@ public interface Lister<NodesResponse extends BaseNodesResponse<NodeResponse>, N
8282
protected final String type;
8383
protected final Map<ShardId, ShardAttributes> shardAttributesMap;
8484
private final Lister<BaseNodesResponse<T>, T> action;
85-
private final AsyncShardFetchCache<T> cache;
85+
protected final AsyncShardFetchCache<T> cache;
8686
private final AtomicLong round = new AtomicLong();
8787
private boolean closed;
88-
private final String reroutingKey;
88+
final String reroutingKey;
8989
private final Map<ShardId, Set<String>> shardToIgnoreNodes = new HashMap<>();
9090

9191
@SuppressWarnings("unchecked")
@@ -99,7 +99,7 @@ protected AsyncShardFetch(
9999
this.logger = logger;
100100
this.type = type;
101101
shardAttributesMap = new HashMap<>();
102-
shardAttributesMap.put(shardId, new ShardAttributes(shardId, customDataPath));
102+
shardAttributesMap.put(shardId, new ShardAttributes(customDataPath));
103103
this.action = (Lister<BaseNodesResponse<T>, T>) action;
104104
this.reroutingKey = "ShardId=[" + shardId.toString() + "]";
105105
cache = new ShardCache<>(logger, reroutingKey, type);
@@ -120,14 +120,15 @@ protected AsyncShardFetch(
120120
String type,
121121
Map<ShardId, ShardAttributes> shardAttributesMap,
122122
Lister<? extends BaseNodesResponse<T>, T> action,
123-
String batchId
123+
String batchId,
124+
AsyncShardFetchCache<T> cache
124125
) {
125126
this.logger = logger;
126127
this.type = type;
127128
this.shardAttributesMap = shardAttributesMap;
128129
this.action = (Lister<BaseNodesResponse<T>, T>) action;
129130
this.reroutingKey = "BatchID=[" + batchId + "]";
130-
cache = new ShardCache<>(logger, reroutingKey, type);
131+
this.cache = cache;
131132
}
132133

133134
@Override

server/src/main/java/org/opensearch/gateway/AsyncShardFetchCache.java

+1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
* @opensearch.internal
4949
*/
5050
public abstract class AsyncShardFetchCache<K extends BaseNodeResponse> {
51+
5152
private final Logger logger;
5253
private final String type;
5354

server/src/main/java/org/opensearch/gateway/PrimaryShardBatchAllocator.java

+2-3
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.opensearch.cluster.routing.allocation.AllocateUnassignedDecision;
1616
import org.opensearch.cluster.routing.allocation.RoutingAllocation;
1717
import org.opensearch.gateway.AsyncShardFetch.FetchResult;
18+
import org.opensearch.gateway.TransportNodesGatewayStartedShardHelper.GatewayStartedShard;
1819
import org.opensearch.gateway.TransportNodesGatewayStartedShardHelper.NodeGatewayStartedShard;
1920
import org.opensearch.gateway.TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShardsBatch;
2021

@@ -132,9 +133,7 @@ private static List<NodeGatewayStartedShard> adaptToNodeShardStates(
132133

133134
// build data for a shard from all the nodes
134135
nodeResponses.forEach((node, nodeGatewayStartedShardsBatch) -> {
135-
TransportNodesGatewayStartedShardHelper.GatewayStartedShard shardData = nodeGatewayStartedShardsBatch
136-
.getNodeGatewayStartedShardsBatch()
137-
.get(unassignedShard.shardId());
136+
GatewayStartedShard shardData = nodeGatewayStartedShardsBatch.getNodeGatewayStartedShardsBatch().get(unassignedShard.shardId());
138137
nodeShardStates.add(
139138
new NodeGatewayStartedShard(
140139
shardData.allocationId(),
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* The OpenSearch Contributors require contributions made to
5+
* this file be licensed under the Apache-2.0 license or a
6+
* compatible open source license.
7+
*/
8+
9+
package org.opensearch.gateway;
10+
11+
import org.opensearch.action.support.nodes.BaseNodeResponse;
12+
import org.opensearch.cluster.node.DiscoveryNode;
13+
import org.opensearch.core.index.shard.ShardId;
14+
import org.opensearch.gateway.TransportNodesGatewayStartedShardHelper.GatewayStartedShard;
15+
import org.opensearch.gateway.TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShardsBatch;
16+
import org.opensearch.indices.store.TransportNodesListShardStoreMetadataBatch.NodeStoreFilesMetadata;
17+
import org.opensearch.indices.store.TransportNodesListShardStoreMetadataBatch.NodeStoreFilesMetadataBatch;
18+
19+
import java.util.Map;
20+
21+
/**
22+
* A factory class to create new responses of batch transport actions like
23+
* {@link TransportNodesListGatewayStartedShardsBatch} or {@link org.opensearch.indices.store.TransportNodesListShardStoreMetadataBatch}
24+
*
25+
* @param <T> Node level response returned by batch transport actions.
26+
* @param <V> Shard level metadata returned by batch transport actions.
27+
*/
28+
public class ShardBatchResponseFactory<T extends BaseNodeResponse, V> {
29+
private final boolean primary;
30+
31+
public ShardBatchResponseFactory(boolean primary) {
32+
this.primary = primary;
33+
}
34+
35+
public T getNewResponse(DiscoveryNode node, Map<ShardId, V> shardData) {
36+
if (primary) {
37+
return (T) new NodeGatewayStartedShardsBatch(node, (Map<ShardId, GatewayStartedShard>) shardData);
38+
} else {
39+
return (T) new NodeStoreFilesMetadataBatch(node, (Map<ShardId, NodeStoreFilesMetadata>) shardData);
40+
}
41+
}
42+
43+
public Map<ShardId, V> getShardBatchData(T response) {
44+
if (primary) {
45+
return (Map<ShardId, V>) ((NodeGatewayStartedShardsBatch) response).getNodeGatewayStartedShardsBatch();
46+
} else {
47+
return (Map<ShardId, V>) ((NodeStoreFilesMetadataBatch) response).getNodeStoreFilesMetadataBatch();
48+
}
49+
}
50+
51+
}

0 commit comments

Comments
 (0)