diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java index 91e419bb8..45aacc540 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java @@ -96,9 +96,11 @@ import java.lang.reflect.Field; import java.util.ArrayList; import java.util.HashMap; +import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import static org.apache.flink.agents.api.configuration.AgentConfigOptions.ACTION_STATE_STORE_BACKEND; import static org.apache.flink.agents.api.configuration.AgentConfigOptions.BASE_LOG_DIR; @@ -935,14 +937,19 @@ private void tryResumeProcessActionTasks() throws Exception { int maxParallelism = getRuntimeContext().getTaskInfo().getMaxNumberOfParallelSubtasks(); KeyGroupRange currentSubtaskKeyGroupRange = getCurrentSubtaskKeyGroupRange(maxParallelism); + Set ownedKeys = new LinkedHashSet<>(); for (Object key : keys) { if (!isKeyOwnedByCurrentSubtask(key, maxParallelism, currentSubtaskKeyGroupRange)) { continue; } + if (!ownedKeys.add(key)) { + continue; + } keySegmentQueue.addKeyToLastSegment(key); mailboxExecutor.submit( () -> tryProcessActionTaskForKey(key), "process action task"); } + currentProcessingKeysOpState.update(new ArrayList<>(ownedKeys)); } getKeyedStateBackend() diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java index 729f106eb..6cf872179 100644 --- a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java @@ -217,6 +217,31 @@ void testRestoreOnlyResumesKeysOwnedByCurrentSubtask() throws Exception { assertThat(ownerHarness.getTaskMailbox().size()).isEqualTo(1); assertThat(nonOwnerHarness.getTaskMailbox().size()).isZero(); + + OperatorSubtaskState secondCheckpoint = + AbstractStreamOperatorTestHarness.repackageState( + ownerHarness.snapshot(2L, 2L), nonOwnerHarness.snapshot(2L, 2L)); + OperatorSubtaskState secondRestoreOwnerState = + AbstractStreamOperatorTestHarness.repartitionOperatorState( + secondCheckpoint, + maxParallelism, + newParallelism, + newParallelism, + ownerSubtask); + + try (KeyedOneInputStreamOperatorTestHarness restoredOwnerHarness = + new KeyedOneInputStreamOperatorTestHarness<>( + new ActionExecutionOperatorFactory(TestAgent.getAgentPlan(false), true), + (KeySelector) value -> value, + TypeInformation.of(Long.class), + maxParallelism, + newParallelism, + ownerSubtask)) { + restoredOwnerHarness.initializeState(secondRestoreOwnerState); + restoredOwnerHarness.open(); + + assertThat(restoredOwnerHarness.getTaskMailbox().size()).isEqualTo(1); + } } }