Skip to content

Commit

Permalink
Adding functionality to override the vanilla KCL tasks (#1440)
Browse files Browse the repository at this point in the history
  • Loading branch information
gguptp authored Feb 21, 2025
1 parent 68a7a9b commit 8deebe4
Show file tree
Hide file tree
Showing 8 changed files with 278 additions and 93 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@
import software.amazon.kinesis.leases.exceptions.DependencyException;
import software.amazon.kinesis.leases.exceptions.InvalidStateException;
import software.amazon.kinesis.leases.exceptions.ProvisionedThroughputException;
import software.amazon.kinesis.lifecycle.ConsumerTaskFactory;
import software.amazon.kinesis.lifecycle.KinesisConsumerTaskFactory;
import software.amazon.kinesis.lifecycle.LifecycleConfig;
import software.amazon.kinesis.lifecycle.ShardConsumer;
import software.amazon.kinesis.lifecycle.ShardConsumerArgument;
Expand Down Expand Up @@ -188,6 +190,7 @@ public class Scheduler implements Runnable {
private final SchemaRegistryDecoder schemaRegistryDecoder;

private final DeletedStreamListProvider deletedStreamListProvider;
private final ConsumerTaskFactory taskFactory;

@Getter(AccessLevel.NONE)
private final MigrationStateMachine migrationStateMachine;
Expand Down Expand Up @@ -264,6 +267,33 @@ protected Scheduler(
@NonNull final ProcessorConfig processorConfig,
@NonNull final RetrievalConfig retrievalConfig,
@NonNull final DiagnosticEventFactory diagnosticEventFactory) {
this(
checkpointConfig,
coordinatorConfig,
leaseManagementConfig,
lifecycleConfig,
metricsConfig,
processorConfig,
retrievalConfig,
diagnosticEventFactory,
new KinesisConsumerTaskFactory());
}

/**
* Customers do not currently have the ability to customize the DiagnosticEventFactory, but this visibility
* is desired for testing. This constructor is only used for testing to provide a mock DiagnosticEventFactory.
*/
@VisibleForTesting
protected Scheduler(
@NonNull final CheckpointConfig checkpointConfig,
@NonNull final CoordinatorConfig coordinatorConfig,
@NonNull final LeaseManagementConfig leaseManagementConfig,
@NonNull final LifecycleConfig lifecycleConfig,
@NonNull final MetricsConfig metricsConfig,
@NonNull final ProcessorConfig processorConfig,
@NonNull final RetrievalConfig retrievalConfig,
@NonNull final DiagnosticEventFactory diagnosticEventFactory,
@NonNull final ConsumerTaskFactory taskFactory) {
this.checkpointConfig = checkpointConfig;
this.coordinatorConfig = coordinatorConfig;
this.leaseManagementConfig = leaseManagementConfig;
Expand Down Expand Up @@ -371,6 +401,7 @@ protected Scheduler(
this.schemaRegistryDecoder = this.retrievalConfig.glueSchemaRegistryDeserializer() == null
? null
: new SchemaRegistryDecoder(this.retrievalConfig.glueSchemaRegistryDeserializer());
this.taskFactory = taskFactory;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,19 @@ interface ConsumerState {
* the consumer to use to build the task, or execute state.
* @param input
* the process input received, this may be null if it's a control message
* @param taskFactory
* a factory for creating tasks
* @return a valid task for this state or null if there is no task required.
*/
ConsumerTask createTask(ShardConsumerArgument consumerArgument, ShardConsumer consumer, ProcessRecordsInput input);
ConsumerTask createTask(
ShardConsumerArgument consumerArgument,
ShardConsumer consumer,
ProcessRecordsInput input,
ConsumerTaskFactory taskFactory);

/**
* Provides the next state of the consumer upon success of the task returned by
* {@link ConsumerState#createTask(ShardConsumerArgument, ShardConsumer, ProcessRecordsInput)}.
* {@link ConsumerState#createTask(ShardConsumerArgument, ShardConsumer, ProcessRecordsInput, ConsumerTaskFactory)}.
*
* @return the next state that the consumer should transition to, this may be the same object as the current
* state.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import lombok.Getter;
import lombok.experimental.Accessors;
import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
import software.amazon.kinesis.retrieval.ThrottlingReporter;

/**
* Top level container for all the possible states a {@link ShardConsumer} can be in. The logic for creation of tasks,
Expand Down Expand Up @@ -121,11 +120,11 @@ static class BlockedOnParentState implements ConsumerState {

@Override
public ConsumerTask createTask(
ShardConsumerArgument consumerArgument, ShardConsumer consumer, ProcessRecordsInput input) {
return new BlockOnParentShardTask(
consumerArgument.shardInfo(),
consumerArgument.leaseCoordinator().leaseRefresher(),
consumerArgument.parentShardPollIntervalMillis());
ShardConsumerArgument consumerArgument,
ShardConsumer consumer,
ProcessRecordsInput input,
ConsumerTaskFactory taskFactory) {
return taskFactory.createBlockOnParentTask(consumerArgument);
}

@Override
Expand Down Expand Up @@ -187,16 +186,11 @@ static class InitializingState implements ConsumerState {

@Override
public ConsumerTask createTask(
ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) {
return new InitializeTask(
argument.shardInfo(),
argument.shardRecordProcessor(),
argument.checkpoint(),
argument.recordProcessorCheckpointer(),
argument.initialPositionInStream(),
argument.recordsPublisher(),
argument.taskBackoffTimeMillis(),
argument.metricsFactory());
ShardConsumerArgument argument,
ShardConsumer consumer,
ProcessRecordsInput input,
ConsumerTaskFactory taskFactory) {
return taskFactory.createInitializeTask(argument);
}

@Override
Expand Down Expand Up @@ -250,24 +244,11 @@ static class ProcessingState implements ConsumerState {

@Override
public ConsumerTask createTask(
ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) {
ThrottlingReporter throttlingReporter =
new ThrottlingReporter(5, argument.shardInfo().shardId());
return new ProcessTask(
argument.shardInfo(),
argument.shardRecordProcessor(),
argument.recordProcessorCheckpointer(),
argument.taskBackoffTimeMillis(),
argument.skipShardSyncAtWorkerInitializationIfLeasesExist(),
argument.shardDetector(),
throttlingReporter,
input,
argument.shouldCallProcessRecordsEvenForEmptyRecordList(),
argument.idleTimeInMilliseconds(),
argument.aggregatorUtil(),
argument.metricsFactory(),
argument.schemaRegistryDecoder(),
argument.leaseCoordinator().leaseStatsRecorder());
ShardConsumerArgument argument,
ShardConsumer consumer,
ProcessRecordsInput input,
ConsumerTaskFactory taskFactory) {
return taskFactory.createProcessTask(argument, input);
}

@Override
Expand Down Expand Up @@ -331,14 +312,12 @@ static class ShutdownNotificationState implements ConsumerState {

@Override
public ConsumerTask createTask(
ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) {
ShardConsumerArgument argument,
ShardConsumer consumer,
ProcessRecordsInput input,
ConsumerTaskFactory taskFactory) {
// TODO: notify shutdownrequested
return new ShutdownNotificationTask(
argument.shardRecordProcessor(),
argument.recordProcessorCheckpointer(),
consumer.shutdownNotification(),
argument.shardInfo(),
consumer.shardConsumerArgument().leaseCoordinator());
return taskFactory.createShutdownNotificationTask(argument, consumer);
}

@Override
Expand Down Expand Up @@ -405,7 +384,10 @@ static class ShutdownNotificationCompletionState implements ConsumerState {

@Override
public ConsumerTask createTask(
ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) {
ShardConsumerArgument argument,
ShardConsumer consumer,
ProcessRecordsInput input,
ConsumerTaskFactory taskFactory) {
return null;
}

Expand Down Expand Up @@ -483,25 +465,12 @@ static class ShuttingDownState implements ConsumerState {

@Override
public ConsumerTask createTask(
ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) {
ShardConsumerArgument argument,
ShardConsumer consumer,
ProcessRecordsInput input,
ConsumerTaskFactory taskFactory) {
// TODO: set shutdown reason
return new ShutdownTask(
argument.shardInfo(),
argument.shardDetector(),
argument.shardRecordProcessor(),
argument.recordProcessorCheckpointer(),
consumer.shutdownReason(),
argument.initialPositionInStream(),
argument.cleanupLeasesOfCompletedShards(),
argument.ignoreUnexpectedChildShards(),
argument.leaseCoordinator(),
argument.taskBackoffTimeMillis(),
argument.recordsPublisher(),
argument.hierarchicalShardSyncer(),
argument.metricsFactory(),
input == null ? null : input.childShards(),
argument.streamIdentifier(),
argument.leaseCleanupManager());
return taskFactory.createShutdownTask(argument, consumer, input);
}

@Override
Expand Down Expand Up @@ -569,7 +538,10 @@ static class ShutdownCompleteState implements ConsumerState {

@Override
public ConsumerTask createTask(
ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) {
ShardConsumerArgument argument,
ShardConsumer consumer,
ProcessRecordsInput input,
ConsumerTaskFactory taskFactory) {
return null;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright 2019 Amazon.com, Inc. or its affiliates.
* Licensed under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package software.amazon.kinesis.lifecycle;

import software.amazon.kinesis.annotations.KinesisClientInternalApi;
import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;

/**
* Factory for the {@link ConsumerTask} instances that a {@link ShardConsumer} runs.
*
* <p>Every method receives the {@link ShardConsumerArgument} of the consumer and returns the task to run;
* some methods additionally receive the {@link ShardConsumer} itself and/or the
* {@link ProcessRecordsInput} that triggered the task. {@link KinesisConsumerTaskFactory} is the
* implementation shipped in this package.
*/
@KinesisClientInternalApi
public interface ConsumerTaskFactory {
/**
* Creates a shutdown task.
*
* @param argument
*            the shard consumer argument the task is built from
* @param consumer
*            the consumer being shut down; {@link KinesisConsumerTaskFactory} reads its shutdown reason
* @param input
*            the process input received, if any; {@link KinesisConsumerTaskFactory} tolerates {@code null}
*            here, so implementations should presumably do the same
* @return the shutdown task
*/
ConsumerTask createShutdownTask(ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input);

/**
* Creates a process task.
*
* @param argument
*            the shard consumer argument the task is built from
* @param processRecordsInput
*            the process input to hand to the task
* @return the process task
*/
ConsumerTask createProcessTask(ShardConsumerArgument argument, ProcessRecordsInput processRecordsInput);

/**
* Creates an initialize task.
*
* @param argument
*            the shard consumer argument the task is built from
* @return the initialize task
*/
ConsumerTask createInitializeTask(ShardConsumerArgument argument);

/**
* Creates a block on parent task.
*
* @param argument
*            the shard consumer argument the task is built from
* @return the block on parent task
*/
ConsumerTask createBlockOnParentTask(ShardConsumerArgument argument);

/**
* Creates a shutdown notification task.
*
* @param argument
*            the shard consumer argument the task is built from
* @param consumer
*            the consumer being notified; {@link KinesisConsumerTaskFactory} reads its shutdown notification
* @return the shutdown notification task
*/
ConsumerTask createShutdownNotificationTask(ShardConsumerArgument argument, ShardConsumer consumer);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Copyright 2019 Amazon.com, Inc. or its affiliates.
* Licensed under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package software.amazon.kinesis.lifecycle;

import software.amazon.kinesis.annotations.KinesisClientInternalApi;
import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
import software.amazon.kinesis.retrieval.ThrottlingReporter;

/**
 * {@link ConsumerTaskFactory} that builds the tasks bundled with this package:
 * {@link ShutdownTask}, {@link ProcessTask}, {@link InitializeTask},
 * {@link BlockOnParentShardTask} and {@link ShutdownNotificationTask}.
 *
 * <p>Each method only reads values from the supplied {@link ShardConsumerArgument} (and, where given, the
 * {@link ShardConsumer} and {@link ProcessRecordsInput}) and forwards them to the task constructor.
 */
@KinesisClientInternalApi
public class KinesisConsumerTaskFactory implements ConsumerTaskFactory {

    /**
     * Builds a {@link ShutdownTask}.
     *
     * @param argument source of the shard, lease, retrieval and metrics collaborators
     * @param consumer consumer whose shutdown reason is passed to the task
     * @param input    process input whose child shards are passed to the task; when {@code null},
     *                 {@code null} child shards are passed instead
     * @return the new shutdown task
     */
    @Override
    public ConsumerTask createShutdownTask(
            final ShardConsumerArgument argument, final ShardConsumer consumer, final ProcessRecordsInput input) {
        return new ShutdownTask(
                argument.shardInfo(),
                argument.shardDetector(),
                argument.shardRecordProcessor(),
                argument.recordProcessorCheckpointer(),
                consumer.shutdownReason(),
                argument.initialPositionInStream(),
                argument.cleanupLeasesOfCompletedShards(),
                argument.ignoreUnexpectedChildShards(),
                argument.leaseCoordinator(),
                argument.taskBackoffTimeMillis(),
                argument.recordsPublisher(),
                argument.hierarchicalShardSyncer(),
                argument.metricsFactory(),
                // A missing input means there are no child shards to report.
                input == null ? null : input.childShards(),
                argument.streamIdentifier(),
                argument.leaseCleanupManager());
    }

    /**
     * Builds a {@link ProcessTask} together with a fresh {@link ThrottlingReporter} for the shard.
     *
     * @param argument            source of the record processor, checkpointer and related settings
     * @param processRecordsInput input forwarded unchanged to the task
     * @return the new process task
     */
    @Override
    public ConsumerTask createProcessTask(
            final ShardConsumerArgument argument, final ProcessRecordsInput processRecordsInput) {
        // One reporter per task; the literal 5 is ThrottlingReporter's first constructor
        // argument -- TODO confirm its meaning against ThrottlingReporter before changing it.
        final ThrottlingReporter reporter =
                new ThrottlingReporter(5, argument.shardInfo().shardId());
        return new ProcessTask(
                argument.shardInfo(),
                argument.shardRecordProcessor(),
                argument.recordProcessorCheckpointer(),
                argument.taskBackoffTimeMillis(),
                argument.skipShardSyncAtWorkerInitializationIfLeasesExist(),
                argument.shardDetector(),
                reporter,
                processRecordsInput,
                argument.shouldCallProcessRecordsEvenForEmptyRecordList(),
                argument.idleTimeInMilliseconds(),
                argument.aggregatorUtil(),
                argument.metricsFactory(),
                argument.schemaRegistryDecoder(),
                argument.leaseCoordinator().leaseStatsRecorder());
    }

    /**
     * Builds an {@link InitializeTask}.
     *
     * @param argument source of the shard info, record processor, checkpoint and publisher
     * @return the new initialize task
     */
    @Override
    public ConsumerTask createInitializeTask(final ShardConsumerArgument argument) {
        return new InitializeTask(
                argument.shardInfo(),
                argument.shardRecordProcessor(),
                argument.checkpoint(),
                argument.recordProcessorCheckpointer(),
                argument.initialPositionInStream(),
                argument.recordsPublisher(),
                argument.taskBackoffTimeMillis(),
                argument.metricsFactory());
    }

    /**
     * Builds a {@link BlockOnParentShardTask}.
     *
     * @param argument source of the shard info, lease refresher and parent-shard poll interval
     * @return the new block-on-parent task
     */
    @Override
    public ConsumerTask createBlockOnParentTask(final ShardConsumerArgument argument) {
        return new BlockOnParentShardTask(
                argument.shardInfo(),
                argument.leaseCoordinator().leaseRefresher(),
                argument.parentShardPollIntervalMillis());
    }

    /**
     * Builds a {@link ShutdownNotificationTask}.
     *
     * @param argument source of the record processor, checkpointer and shard info
     * @param consumer consumer supplying the shutdown notification and the lease coordinator
     * @return the new shutdown notification task
     */
    @Override
    public ConsumerTask createShutdownNotificationTask(
            final ShardConsumerArgument argument, final ShardConsumer consumer) {
        // The lease coordinator is deliberately taken from the consumer's own argument,
        // not from the 'argument' parameter.
        final ShardConsumerArgument consumerArgument = consumer.shardConsumerArgument();
        return new ShutdownNotificationTask(
                argument.shardRecordProcessor(),
                argument.recordProcessorCheckpointer(),
                consumer.shutdownNotification(),
                argument.shardInfo(),
                consumerArgument.leaseCoordinator());
    }
}
Loading

0 comments on commit 8deebe4

Please sign in to comment.