diff --git a/coordinator-common/src/main/java/org/apache/kafka/coordinator/common/runtime/CoordinatorExecutorImpl.java b/coordinator-common/src/main/java/org/apache/kafka/coordinator/common/runtime/CoordinatorExecutorImpl.java index bcd8fc795fb89..7971d6a22f81c 100644 --- a/coordinator-common/src/main/java/org/apache/kafka/coordinator/common/runtime/CoordinatorExecutorImpl.java +++ b/coordinator-common/src/main/java/org/apache/kafka/coordinator/common/runtime/CoordinatorExecutorImpl.java @@ -24,7 +24,6 @@ import org.slf4j.Logger; import java.time.Duration; -import java.util.Iterator; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; @@ -128,9 +127,6 @@ public void cancel(String key) { } public void cancelAll() { - Iterator iterator = tasks.keySet().iterator(); - while (iterator.hasNext()) { - iterator.remove(); - } + tasks.clear(); } } diff --git a/coordinator-common/src/test/java/org/apache/kafka/coordinator/common/runtime/CoordinatorExecutorImplTest.java b/coordinator-common/src/test/java/org/apache/kafka/coordinator/common/runtime/CoordinatorExecutorImplTest.java index 4f5e917f1795b..b2a82a6d70732 100644 --- a/coordinator-common/src/test/java/org/apache/kafka/coordinator/common/runtime/CoordinatorExecutorImplTest.java +++ b/coordinator-common/src/test/java/org/apache/kafka/coordinator/common/runtime/CoordinatorExecutorImplTest.java @@ -23,11 +23,13 @@ import org.junit.jupiter.api.Test; import java.time.Duration; +import java.util.ArrayList; import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -36,6 +38,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -314,4 +317,70 @@ public void testTaskSchedulingWriteOperationFailed() { assertFalse(operationCalled.get()); assertFalse(executor.isScheduled(TASK_KEY)); } + + @Test + public void testCancelAllTasks() { + CoordinatorShard coordinatorShard = mock(CoordinatorShard.class); + CoordinatorRuntime, String> runtime = mock(CoordinatorRuntime.class); + ExecutorService executorService = mock(ExecutorService.class); + CoordinatorExecutorImpl, String> executor = new CoordinatorExecutorImpl<>( + LOG_CONTEXT, + SHARD_PARTITION, + runtime, + executorService, + WRITE_TIMEOUT + ); + + List, Void, String>> writeOperations = new ArrayList<>(); + List> writeFutures = new ArrayList<>(); + when(runtime.scheduleWriteOperation( + anyString(), + eq(SHARD_PARTITION), + eq(WRITE_TIMEOUT), + any() + )).thenAnswer(args -> { + writeOperations.add(args.getArgument(3)); + CompletableFuture writeFuture = new CompletableFuture<>(); + writeFutures.add(writeFuture); + return writeFuture; + }); + + when(executorService.submit(any(Runnable.class))).thenAnswer(args -> { + Runnable op = args.getArgument(0); + op.run(); + return CompletableFuture.completedFuture(null); + }); + + AtomicInteger taskCallCount = new AtomicInteger(0); + CoordinatorExecutor.TaskRunnable taskRunnable = () -> { + taskCallCount.incrementAndGet(); + return "Hello!"; + }; + + AtomicInteger operationCallCount = new AtomicInteger(0); + CoordinatorExecutor.TaskOperation taskOperation = (result, exception) -> { + operationCallCount.incrementAndGet(); + return null; + }; + + for (int i = 0; i < 2; i++) { + executor.schedule( + TASK_KEY + i, + taskRunnable, + taskOperation + ); + } + + executor.cancelAll(); + + for (int i = 0; i < writeOperations.size(); i++) { + CoordinatorRuntime.CoordinatorWriteOperation, Void, String> writeOperation = writeOperations.get(i); + CompletableFuture writeFuture = writeFutures.get(i); + Throwable ex = assertThrows(RejectedExecutionException.class, () -> writeOperation.generateRecordsAndResult(coordinatorShard)); + writeFuture.completeExceptionally(ex); + } + + assertEquals(2, taskCallCount.get()); + assertEquals(0, operationCallCount.get()); + } }