[CELEBORN-2230] SparkUtils#shouldReportShuffleFetchFailure method should retrieve the number of task failures from TaskSetManager #3650
Conversation
cc @turboFei
Overall, this change looks good to me. The approach of retrieving failure counts directly from TaskSetManager is sound. One suggestion: it would be helpful to add a test case that verifies the failure count is correctly incremented after an actual task failure (e.g., simulate a task failure and then verify the count).
Added a UT, PTAL.
Pull request overview
Updates Celeborn’s Spark integration to determine whether to report shuffle fetch failures using Spark’s internal task failure counters from TaskSetManager, improving accuracy for cases where Spark does not count certain failed attempts (e.g., preemption-related failures).
Changes:
- Add reflective access to `TaskSetManager.numFailures` and a helper method to read per-task failure counts (see the sketch after this list).
- Update `SparkUtils#shouldReportShuffleFetchFailure` (Spark 2 / Spark 3) to use `numFailures` instead of inferring failures from attempt statuses.
- Add Spark IT coverage for `getTaskFailureCount`, including a scenario with real task retries/failures.
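For orientation, here is a minimal sketch of what such a reflective read might look like, written in Scala for brevity (the real helper is Java code in SparkUtils.java). The object name `TaskFailureCounts` is illustrative; the sketch assumes Spark's `TaskSetManager` keeps a package-private `numFailures` array indexed by task index, and uses -1 as a "count unavailable" sentinel, matching the fail-safe discussed in the review threads below.

```scala
// A sketch only, not the actual Celeborn helper: read the per-task-index
// failure counters Spark keeps in TaskSetManager#numFailures via reflection.
import java.lang.reflect.Field

import org.apache.spark.scheduler.TaskSetManager

object TaskFailureCounts {
  // Resolve the field once; None means this Spark build has a different layout.
  private val numFailuresField: Option[Field] =
    try {
      val f = classOf[TaskSetManager].getDeclaredField("numFailures")
      f.setAccessible(true)
      Some(f)
    } catch {
      case _: NoSuchFieldException => None
    }

  // Returns the failure count for the task index, or -1 if it cannot be read.
  def getTaskFailureCount(tsm: TaskSetManager, taskIndex: Int): Int =
    numFailuresField
      .flatMap { f =>
        try {
          val counts = f.get(tsm).asInstanceOf[Array[Int]]
          if (taskIndex >= 0 && taskIndex < counts.length) Some(counts(taskIndex)) else None
        } catch {
          case _: Exception => None
        }
      }
      .getOrElse(-1)
}
```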
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala | Adds integration tests validating getTaskFailureCount behavior and retry scenarios. |
| client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java | Reads TaskSetManager.numFailures and uses it in fetch-failure reporting decisions (Spark 3). |
| client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java | Same as Spark 3 changes, adapted for Spark 2 module. |
```java
int previousFailureCount = getTaskFailureCount(taskSetManager, taskInfo.index());
// Fail-safe: if failure count cannot be determined, conservatively trigger
// FetchFailed to avoid silently swallowing the error.
if (previousFailureCount < 0) {
  return true;
}
```
When getTaskFailureCount() fails (returns < 0), shouldReportShuffleFetchFailure() immediately returns true. This makes the pre-check aggressively report FetchFailed even if there are other running attempts and the retry limit has not been reached, which can reintroduce premature stage reruns in exactly the scenarios this change is trying to avoid (e.g., if reflective access to TaskSetManager.numFailures breaks on some Spark builds). Consider falling back to the previous attempt/status-based counting (or at least gating on !hasRunningAttempt) instead of unconditional true, and log at WARN once to avoid error spam if the field is unavailable.
I think this is unreasonable. It shouldn't check !hasRunningAttempt, but should directly return true. Otherwise, FetchFailed won't be triggered and the app will fail.
```java
// Fail-safe: if failure count cannot be determined, conservatively trigger
// FetchFailed to avoid silently swallowing the error.
if (previousFailureCount < 0) {
  return true;
```
When getTaskFailureCount() fails (returns < 0), shouldReportShuffleFetchFailure() immediately returns true. This makes the pre-check aggressively report FetchFailed even if there are other running attempts and the retry limit has not been reached, which can reintroduce premature stage reruns if reflective access to TaskSetManager.numFailures fails on some Spark builds. Consider falling back to the previous attempt/status-based counting (or at least gating on !hasRunningAttempt) rather than unconditional true, and avoid logging this as an error on every call if the field is unavailable.
Suggested change:

```diff
-    // Fail-safe: if failure count cannot be determined, conservatively trigger
-    // FetchFailed to avoid silently swallowing the error.
-    if (previousFailureCount < 0) {
-      return true;
+    // If failure count cannot be determined, fall back to attempt status based
+    // behavior instead of aggressively reporting FetchFailed. This avoids
+    // premature stage reruns when reflective access to failure counts is
+    // unavailable, while still reporting the failure when no other attempt is
+    // running.
+    if (previousFailureCount < 0) {
+      if (!hasRunningAttempt) {
+        logger.warn(
+            "StageId={}, index={}, taskId={}, attemptNumber={}: Unable to determine "
+                + "previous failure count, and no other running attempt exists. "
+                + "Reporting shuffle fetch failure.",
+            stageId,
+            taskInfo.index(),
+            taskId,
+            taskInfo.attemptNumber());
+        return true;
+      } else {
+        logger.warn(
+            "StageId={}, index={}, taskId={}, attemptNumber={}: Unable to determine "
+                + "previous failure count, but another attempt is still running. "
+                + "Deferring shuffle fetch failure report.",
+            stageId,
+            taskInfo.index(),
+            taskId,
+            taskInfo.attemptNumber());
+        return false;
+      }
+    }
```
I think this is unreasonable. It shouldn't check !hasRunningAttempt, but should directly return true. Otherwise, FetchFailed won't be triggered and the app will fail.
```scala
val taskScheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl]
eventually(timeout(10.seconds), interval(100.milliseconds)) {
  // taskId 0,1 failed and removed; taskId 2 is the surviving 3rd attempt
  val taskSetManager = SparkUtils.getTaskSetManager(taskScheduler, 2)
  assert(taskSetManager != null)
  assert(SparkUtils.getTaskFailureCount(taskSetManager, 0) == 2)
```
This test assumes the third attempt’s taskId will be exactly 2 (after two failures). Spark task IDs are globally assigned within a SparkContext and aren’t guaranteed to align with attempt count if any other tasks/stages run (including internal ones), which can make the test brittle across Spark versions/configs. Consider deriving the taskId dynamically (e.g., capturing TaskContext.taskAttemptId() via an accumulator/Promise, or scanning taskScheduler’s taskIdToTaskSetManager for the active TaskSetManager) instead of hardcoding 2.
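As a hedged sketch of one such alternative (names like `survivingTaskId` are illustrative, and it assumes the test injects two failures for partition 0 so that attempt number 2 survives): register a SparkListener before running the job to capture the third attempt's globally assigned task ID, then use it in place of the hardcoded 2.

```scala
import java.util.concurrent.atomic.AtomicLong

import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart}

val survivingTaskId = new AtomicLong(-1L)
sc.addSparkListener(new SparkListener {
  override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
    // Capture the globally assigned task ID of the third attempt (attempt 2).
    if (taskStart.taskInfo.index == 0 && taskStart.taskInfo.attemptNumber == 2) {
      survivingTaskId.set(taskStart.taskInfo.taskId)
    }
  }
})

// ... run the job that fails the first two attempts of partition 0, then:
eventually(timeout(10.seconds), interval(100.milliseconds)) {
  assert(survivingTaskId.get() >= 0)
  val taskSetManager = SparkUtils.getTaskSetManager(taskScheduler, survivingTaskId.get())
  assert(taskSetManager != null)
  assert(SparkUtils.getTaskFailureCount(taskSetManager, 0) == 2)
}
```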
I think Copilot's suggestion makes sense here. When the failure count cannot be determined, returning true unconditionally can report FetchFailed even though other attempts are still running. Could you elaborate on why the "app will fail" if we check !hasRunningAttempt? This seems consistent with the original logic, where the fetch failure was only reported once no other attempt was still running.
I made a mistake, you're right.
This PR is stale because it has been open 20 days with no activity. Remove stale label or comment or this will be closed in 10 days. |
What changes were proposed in this pull request?
Reopened from #3556: retrieve the number of task failures from TaskSetManager in the SparkUtils#shouldReportShuffleFetchFailure method.
Why are the changes needed?
At https://github.com/apache/celeborn/blob/main/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java#L514, Celeborn counts task attempts in the "UNKNOWN" and "FAILED" states as failures, but Spark itself may not count some FAILED attempts, for example attempts that fail due to container preemption. This is a common occurrence in our production environment: such attempts should not be counted as failures, yet the existing logic counts them, which makes premature stage reruns more likely. Retrieving the failure counts from the TaskSetManager is therefore more accurate.
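For context on why the two counts diverge (as I understand Spark's TaskSetManager semantics): Spark increments numFailures only when the failure reason counts toward task failures, and an executor lost to preemption reports countTowardsTaskFailures = false. A small illustration, not Celeborn code:

```scala
import org.apache.spark.{ExecutorLostFailure, TaskFailedReason}

// Spark skips the numFailures increment when the reason does not count
// toward task failures.
def countedBySpark(reason: TaskFailedReason): Boolean = reason.countTowardsTaskFailures

// An executor lost to preemption is not caused by the app, so Spark does not
// count it, while the old status-based logic still saw a FAILED attempt.
val preempted = ExecutorLostFailure("exec-1", exitCausedByApp = false, Some("preempted"))
assert(!countedBySpark(preempted))
```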
Does this PR resolve a correctness bug?
No.
Does this PR introduce any user-facing change?
No.
How was this patch tested?
UTs.