From d8fd414d5bfe942b486412e3f70ff4f64461fd39 Mon Sep 17 00:00:00 2001 From: Olivia Kotsopoulos Date: Mon, 15 Jul 2024 11:11:47 -0400 Subject: [PATCH] PR feedback MdcUtils: - Added public utility method callWithContext for running and returning a callable with MDC's context map temporarily overwritten - Remove public modifier on overwriteContext in favor of the above QueueMessageReady: - Use MdcUtils.callWithContext to simplify process definition - Revert process return strategy to return booleans directly rather than setting a boolean to return later - Remove unnecessary callingThreadContext setter (left existing setters untouched even though they can likely be removed: OOS) - Left public callingThreadContext getter unchanged: it is required to be public for serde operations Added unit test coverage for all changes. --- .../bio/terra/stairway/impl/MdcUtils.java | 29 +++++++++- .../stairway/queue/QueueMessageReady.java | 45 +++++++-------- .../bio/terra/stairway/impl/MdcUtilsTest.java | 57 +++++++++++++++++++ .../stairway/queue/QueueMessageTest.java | 16 +++--- 4 files changed, 113 insertions(+), 34 deletions(-) diff --git a/stairway/src/main/java/bio/terra/stairway/impl/MdcUtils.java b/stairway/src/main/java/bio/terra/stairway/impl/MdcUtils.java index 19edbd57..9b7a49c5 100644 --- a/stairway/src/main/java/bio/terra/stairway/impl/MdcUtils.java +++ b/stairway/src/main/java/bio/terra/stairway/impl/MdcUtils.java @@ -1,7 +1,9 @@ package bio.terra.stairway.impl; import bio.terra.stairway.FlightContext; +import bio.terra.stairway.exception.StairwayExecutionException; import java.util.Map; +import java.util.concurrent.Callable; import org.slf4j.MDC; /** @@ -9,7 +11,6 @@ * (MDC). */ public class MdcUtils { - /** ID of the flight */ static final String FLIGHT_ID_KEY = "flightId"; @@ -25,12 +26,36 @@ public class MdcUtils { /** The step's execution order */ static final String FLIGHT_STEP_NUMBER_KEY = "flightStepNumber"; + /** + * Run and return the result of the callable with MDC's context map temporarily overwritten during + * computation. The initial context map is then restored after computation. + * + * @param context to override MDC's context map + * @param callable to call and return + */ + public static T callWithContext(Map context, Callable callable) + throws InterruptedException { + // Save the initial thread context so that it can be restored + Map initialContext = MDC.getCopyOfContextMap(); + try { + MdcUtils.overwriteContext(context); + System.out.println(MDC.getCopyOfContextMap()); + return callable.call(); + } catch (InterruptedException ex) { + throw ex; + } catch (Exception ex) { + throw new StairwayExecutionException("Unexpected exception " + ex.getMessage(), ex); + } finally { + MdcUtils.overwriteContext(initialContext); + } + } + /** * Null-safe utility method for overwriting the current thread's MDC. * * @param context to set as MDC, if null then MDC will be cleared. */ - public static void overwriteContext(Map context) { + static void overwriteContext(Map context) { MDC.clear(); if (context != null) { MDC.setContextMap(context); diff --git a/stairway/src/main/java/bio/terra/stairway/queue/QueueMessageReady.java b/stairway/src/main/java/bio/terra/stairway/queue/QueueMessageReady.java index 0708b599..59414185 100644 --- a/stairway/src/main/java/bio/terra/stairway/queue/QueueMessageReady.java +++ b/stairway/src/main/java/bio/terra/stairway/queue/QueueMessageReady.java @@ -32,27 +32,26 @@ public QueueMessageReady(String flightId) { @Override public boolean process(StairwayImpl stairwayImpl) throws InterruptedException { - boolean processed = false; - // Save the initial thread context so that it can be restored - Map initialContext = MDC.getCopyOfContextMap(); - try { - MdcUtils.overwriteContext(callingThreadContext); - // Resumed is false if the flight is not found in the Ready state. We still call that - // a complete processing of the message and return true. We assume that some this is a - // duplicate message or that some other Stairway found the ready flight on recovery. - boolean resumed = stairwayImpl.resume(flightId); - logger.info( - "Stairway " - + stairwayImpl.getStairwayName() - + (resumed ? " resumed flight: " : " did not find flight to resume: ") - + flightId); - processed = true; - } catch (DatabaseOperationException ex) { - logger.error("Unexpected stairway error, leaving %s on the queue".formatted(flightId), ex); - } finally { - MdcUtils.overwriteContext(initialContext); - } - return processed; + return MdcUtils.callWithContext( + callingThreadContext, + () -> { + try { + // Resumed is false if the flight is not found in the Ready state. We still call that + // a complete processing of the message and return true. We assume that some this is a + // duplicate message or that some other Stairway found the ready flight on recovery. + boolean resumed = stairwayImpl.resume(flightId); + logger.info( + "Stairway " + + stairwayImpl.getStairwayName() + + (resumed ? " resumed flight: " : " did not find flight to resume: ") + + flightId); + return true; + } catch (DatabaseOperationException ex) { + logger.error( + "Unexpected stairway error, leaving %s on the queue".formatted(flightId), ex); + return false; + } + }); } public QueueMessageType getType() { @@ -74,8 +73,4 @@ public void setFlightId(String flightId) { public Map getCallingThreadContext() { return callingThreadContext; } - - public void setCallingThreadContext(Map callingThreadContext) { - this.callingThreadContext = callingThreadContext; - } } diff --git a/stairway/src/test/java/bio/terra/stairway/impl/MdcUtilsTest.java b/stairway/src/test/java/bio/terra/stairway/impl/MdcUtilsTest.java index be36c155..1182c43e 100644 --- a/stairway/src/test/java/bio/terra/stairway/impl/MdcUtilsTest.java +++ b/stairway/src/test/java/bio/terra/stairway/impl/MdcUtilsTest.java @@ -2,21 +2,27 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; +import static org.junit.jupiter.api.Assertions.assertThrows; import bio.terra.stairway.Direction; +import bio.terra.stairway.exception.StairwayExecutionException; import bio.terra.stairway.fixtures.TestFlightContext; +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.UUID; import java.util.stream.Stream; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Tag; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.slf4j.MDC; @Tag("unit") class MdcUtilsTest { + private static final Map INITIAL_CONTEXT = Map.of("initial", "context"); private static final Map FOO_BAR = Map.of("foo", "bar"); private static final String FLIGHT_ID = "flightId" + UUID.randomUUID(); private static final String FLIGHT_CLASS = "flightClass" + UUID.randomUUID(); @@ -52,6 +58,57 @@ static Stream> contextMap() { return Stream.of(null, Map.of(), FOO_BAR); } + @ParameterizedTest + @MethodSource("contextMap") + void callWithContext(Map newContext) throws InterruptedException { + MDC.setContextMap(INITIAL_CONTEXT); + Boolean result = + MdcUtils.callWithContext( + newContext, + () -> { + assertThat( + "Context is overwritten during computation", + MDC.getCopyOfContextMap(), + equalTo(newContext)); + return true; + }); + assertThat("Result of computation is returned", result, equalTo(true)); + assertThat("Initial context is restored", MDC.getCopyOfContextMap(), equalTo(INITIAL_CONTEXT)); + } + + static Stream callWithContext_exception() { + List arguments = new ArrayList<>(); + for (var newContext : contextMap().toList()) { + arguments.add( + Arguments.of( + newContext, new InterruptedException("interrupted"), InterruptedException.class)); + arguments.add( + Arguments.of( + newContext, new RuntimeException("unexpected"), StairwayExecutionException.class)); + } + return arguments.stream(); + } + + @ParameterizedTest + @MethodSource + void callWithContext_exception( + Map newContext, Exception exception, Class expectedExceptionClass) { + MDC.setContextMap(INITIAL_CONTEXT); + assertThrows( + expectedExceptionClass, + () -> + MdcUtils.callWithContext( + newContext, + () -> { + assertThat( + "Context is overwritten during computation", + MDC.getCopyOfContextMap(), + equalTo(newContext)); + throw exception; + })); + assertThat("Initial context is restored", MDC.getCopyOfContextMap(), equalTo(INITIAL_CONTEXT)); + } + @ParameterizedTest @MethodSource("contextMap") void overwriteContext(Map newContext) { diff --git a/stairway/src/test/java/bio/terra/stairway/queue/QueueMessageTest.java b/stairway/src/test/java/bio/terra/stairway/queue/QueueMessageTest.java index 54e3d564..02d350ac 100644 --- a/stairway/src/test/java/bio/terra/stairway/queue/QueueMessageTest.java +++ b/stairway/src/test/java/bio/terra/stairway/queue/QueueMessageTest.java @@ -37,15 +37,19 @@ void beforeEach() { MDC.clear(); } + private QueueMessageReady createQueueMessageWithContext(Map expectedMdc) + throws InterruptedException { + return MdcUtils.callWithContext(expectedMdc, () -> new QueueMessageReady(FLIGHT_ID)); + } + private static Stream> message_serde() { return Stream.of(null, CALLING_THREAD_CONTEXT); } @ParameterizedTest @MethodSource - void message_serde(Map expectedMdc) { - MdcUtils.overwriteContext(expectedMdc); - QueueMessageReady messageReady = new QueueMessageReady(FLIGHT_ID); + void message_serde(Map expectedMdc) throws InterruptedException { + QueueMessageReady messageReady = createQueueMessageWithContext(expectedMdc); WorkQueueProcessor workQueueProcessor = new WorkQueueProcessor(stairway); // Now we add something else to the MDC, but it won't show up in our deserialized queue message. @@ -69,8 +73,7 @@ void message_serde(Map expectedMdc) { @ParameterizedTest @ValueSource(booleans = {true, false}) void process(boolean resumeAnswer) throws InterruptedException { - QueueMessageReady messageReady = new QueueMessageReady(FLIGHT_ID); - messageReady.setCallingThreadContext(CALLING_THREAD_CONTEXT); + QueueMessageReady messageReady = createQueueMessageWithContext(CALLING_THREAD_CONTEXT); when(stairway.resume(FLIGHT_ID)) .thenAnswer( @@ -92,8 +95,7 @@ void process(boolean resumeAnswer) throws InterruptedException { @Test void process_DatabaseOperationException() throws InterruptedException { - QueueMessageReady messageReady = new QueueMessageReady(FLIGHT_ID); - messageReady.setCallingThreadContext(CALLING_THREAD_CONTEXT); + QueueMessageReady messageReady = createQueueMessageWithContext(CALLING_THREAD_CONTEXT); doThrow(DatabaseOperationException.class).when(stairway).resume(FLIGHT_ID);