Skip to content

Commit cac7dac

Browse files
committed
[SPARK-56663][SQL] Restore fast path for date_trunc MINUTE/HOUR/DAY
1 parent f0b447e commit cac7dac

2 files changed

Lines changed: 108 additions & 3 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,43 @@ object DateTimeUtils extends SparkDateTimeUtils {
485485
instantToMicros(truncated.toInstant)
486486
}
487487

488+
/**
 * Fast path for truncating a timestamp to MINUTE/HOUR/DAY using plain offset arithmetic
 * instead of going through [[truncToUnit]] for every row.
 *
 * The zone offset is resolved once for `micros`, the value is shifted into local time,
 * floored to a multiple of `unitMicros`, and shifted back. The result is only trusted when
 * the zone offset at the resulting instant equals the offset at `micros`; a different offset
 * means the truncation crossed a DST/historical transition and the local-time alignment is
 * no longer valid (see SPARK-30766/30857), so we fall back to [[truncToUnit]]. The offset
 * comparison is skipped for fixed-offset zones.
 *
 * Sub-minute offsets (e.g. America/Los_Angeles LMT -07:52:58 — SPARK-33404) and 30/45-minute
 * offsets (Asia/Kolkata +05:30, Asia/Kathmandu +05:45) need no special handling because the
 * full offset takes part in the arithmetic.
 *
 * Every step of the arithmetic is overflow-checked. Any overflow is routed to
 * [[truncToUnit]] rather than being allowed to wrap around silently.
 *
 * @param micros the timestamp to truncate, in microseconds
 * @param zoneId the time zone whose local time defines the unit boundaries
 * @param unitMicros the length of the truncation unit in microseconds
 * @param fallbackUnit the unit handed to [[truncToUnit]] when the fast path cannot be used;
 *                     callers pass the `ChronoUnit` matching `unitMicros`
 * @return the truncated timestamp in microseconds
 */
private def truncToUnitFast(
    micros: Long, zoneId: ZoneId, unitMicros: Long, fallbackUnit: ChronoUnit): Long = {
  val rules = zoneId.getRules

  // Zone offset, in seconds, in effect at the instant `us` (microseconds).
  def offsetSecondsAt(us: Long): Long = {
    val epochSec = Math.floorDiv(us, MICROS_PER_SECOND)
    rules.getOffset(Instant.ofEpochSecond(epochSec)).getTotalSeconds.toLong
  }

  val originalOffsetSec = offsetSecondsAt(micros)
  val offsetMicros = originalOffsetSec * MICROS_PER_SECOND
  try {
    val local = Math.addExact(micros, offsetMicros)
    // Checked subtraction: for `local` close to Long.MinValue an unchecked `local - mod`
    // would wrap around to a large positive value and be returned as a bogus result.
    val truncatedLocal = Math.subtractExact(local, Math.floorMod(local, unitMicros))
    val candidate = Math.subtractExact(truncatedLocal, offsetMicros)
    if (rules.isFixedOffset || offsetSecondsAt(candidate) == originalOffsetSec) {
      candidate
    } else {
      // The candidate instant carries a different offset than `micros`.
      truncToUnit(micros, zoneId, fallbackUnit)
    }
  } catch {
    case _: ArithmeticException => truncToUnit(micros, zoneId, fallbackUnit)
  }
}
524+
488525
/**
489526
* Returns the trunc date time from original date time and trunc level.
490527
* Trunc level should be generated using `parseTruncLevel()`, should be between 0 and 9.
@@ -499,9 +536,12 @@ object DateTimeUtils extends SparkDateTimeUtils {
499536
micros - Math.floorMod(micros, MICROS_PER_MILLIS)
500537
case TRUNC_TO_SECOND =>
501538
micros - Math.floorMod(micros, MICROS_PER_SECOND)
502-
case TRUNC_TO_MINUTE => truncToUnit(micros, zoneId, ChronoUnit.MINUTES)
503-
case TRUNC_TO_HOUR => truncToUnit(micros, zoneId, ChronoUnit.HOURS)
504-
case TRUNC_TO_DAY => truncToUnit(micros, zoneId, ChronoUnit.DAYS)
539+
case TRUNC_TO_MINUTE =>
540+
truncToUnitFast(micros, zoneId, MICROS_PER_MINUTE, ChronoUnit.MINUTES)
541+
case TRUNC_TO_HOUR =>
542+
truncToUnitFast(micros, zoneId, MICROS_PER_HOUR, ChronoUnit.HOURS)
543+
case TRUNC_TO_DAY =>
544+
truncToUnitFast(micros, zoneId, MICROS_PER_DAY, ChronoUnit.DAYS)
505545
case _ => // Try to truncate date levels
506546
val dDays = microsToDays(micros, zoneId)
507547
daysToMicros(truncDate(dDays, level), zoneId)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -766,6 +766,71 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper {
766766
}
767767
}
768768

769+
test("truncTimestamp with sub-hour zone offsets") {
770+
// Asia/Kolkata (+05:30) and Asia/Kathmandu (+05:45) have offsets that are not a whole
771+
// number of hours, so local HOUR/DAY boundaries do not coincide with UTC hour boundaries.
772+
// The fast path applies the offset as part of its arithmetic, so truncation must still
// land on the local boundary (09:00 / 00:00 local time below).
773+
val kolkata = getZoneId("Asia/Kolkata")
774+
val ts = DateTimeUtils.stringToTimestamp(
775+
UTF8String.fromString("2024-01-15T09:42:17.123456+05:30"), kolkata).get
776+
val expectedHour = DateTimeUtils.stringToTimestamp(
777+
UTF8String.fromString("2024-01-15T09:00:00+05:30"), kolkata).get
778+
assert(DateTimeUtils.truncTimestamp(ts, DateTimeUtils.TRUNC_TO_HOUR, kolkata) === expectedHour)
779+
val expectedDay = DateTimeUtils.stringToTimestamp(
780+
UTF8String.fromString("2024-01-15T00:00:00+05:30"), kolkata).get
781+
assert(DateTimeUtils.truncTimestamp(ts, DateTimeUtils.TRUNC_TO_DAY, kolkata) === expectedDay)
782+
783+
// Asia/Kathmandu: only HOUR truncation is asserted for this zone.
val kathmandu = getZoneId("Asia/Kathmandu")
784+
val ts2 = DateTimeUtils.stringToTimestamp(
785+
UTF8String.fromString("2024-01-15T09:42:17.123456+05:45"), kathmandu).get
786+
val expectedHour2 = DateTimeUtils.stringToTimestamp(
787+
UTF8String.fromString("2024-01-15T09:00:00+05:45"), kathmandu).get
788+
assert(DateTimeUtils.truncTimestamp(
789+
ts2, DateTimeUtils.TRUNC_TO_HOUR, kathmandu) === expectedHour2)
790+
}
792+
test("truncTimestamp across DST transitions") {
793+
val la = getZoneId("America/Los_Angeles")
794+
// Spring-forward in LA: on 2024-03-10 local clocks jump from 02:00 (-08:00) straight to
795+
// 03:00 (-07:00), so local times 02:00-02:59 do not exist. The input is 03:30 -07:00,
796+
// i.e. 30 minutes after the transition. HOUR truncates to 03:00 -07:00, which still has
// the post-transition offset. DAY truncates to local midnight, which lies before the
// transition and therefore carries the -08:00 offset.
797+
val postSpring = DateTimeUtils.stringToTimestamp(
798+
UTF8String.fromString("2024-03-10T03:30:00-07:00"), la).get
799+
val expectedHour = DateTimeUtils.stringToTimestamp(
800+
UTF8String.fromString("2024-03-10T03:00:00-07:00"), la).get
801+
assert(DateTimeUtils.truncTimestamp(postSpring, DateTimeUtils.TRUNC_TO_HOUR, la)
802+
=== expectedHour)
803+
val expectedDay = DateTimeUtils.stringToTimestamp(
804+
UTF8String.fromString("2024-03-10T00:00:00-08:00"), la).get
805+
assert(DateTimeUtils.truncTimestamp(postSpring, DateTimeUtils.TRUNC_TO_DAY, la)
806+
=== expectedDay)
807+
808+
// Fall-back in LA: on 2024-11-03 local times 01:00-01:59 occur twice, first at -07:00
809+
// and then at -08:00. The input is the second 01:30 (-08:00). HOUR keeps the -08:00
// offset (01:00 -08:00); DAY truncates to local midnight, which is still at -07:00.
810+
val postFall = DateTimeUtils.stringToTimestamp(
811+
UTF8String.fromString("2024-11-03T01:30:00-08:00"), la).get
812+
val expectedHour2 = DateTimeUtils.stringToTimestamp(
813+
UTF8String.fromString("2024-11-03T01:00:00-08:00"), la).get
814+
assert(DateTimeUtils.truncTimestamp(postFall, DateTimeUtils.TRUNC_TO_HOUR, la)
815+
=== expectedHour2)
816+
val expectedDay2 = DateTimeUtils.stringToTimestamp(
817+
UTF8String.fromString("2024-11-03T00:00:00-07:00"), la).get
818+
assert(DateTimeUtils.truncTimestamp(postFall, DateTimeUtils.TRUNC_TO_DAY, la)
819+
=== expectedDay2)
820+
}
822+
test("SPARK-30766/30857: truncTimestamp before the epoch in HOUR/DAY") {
823+
val la = getZoneId("America/Los_Angeles")
824+
// The input is in 1960, i.e. before the epoch named in the test title. The strings carry
// no explicit offset; presumably they are resolved in the `la` zone passed to
// stringToTimestamp — TODO confirm.
val ts1 = DateTimeUtils.stringToTimestamp(
825+
UTF8String.fromString("1960-02-11T00:01:02.123"), la).get
826+
val expectedHour1 = DateTimeUtils.stringToTimestamp(
827+
UTF8String.fromString("1960-02-11T00:00:00"), la).get
828+
assert(DateTimeUtils.truncTimestamp(ts1, DateTimeUtils.TRUNC_TO_HOUR, la) === expectedHour1)
829+
// The input falls within the first hour of its day, so the expected DAY result is the
// same local midnight as the expected HOUR result.
val expectedDay1 = DateTimeUtils.stringToTimestamp(
830+
UTF8String.fromString("1960-02-11T00:00:00"), la).get
831+
assert(DateTimeUtils.truncTimestamp(ts1, DateTimeUtils.TRUNC_TO_DAY, la) === expectedDay1)
832+
}
833+
769834
test("SPARK-51554: time truncation using timeTrunc") {
770835
// 01:02:03.400500600
771836
val input = localTimeToNanos(LocalTime.of(1, 2, 3, 400500600))

0 commit comments

Comments
 (0)