Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2887,9 +2887,13 @@ case class TruncTimestamp(

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName)
// Per-task offset cache so the hot path avoids a transition-array binary search per row.
val cacheClass = classOf[org.apache.spark.sql.catalyst.util.ZoneOffsetCache].getName
val cache = ctx.addMutableState(cacheClass, "zoneOffsetCache",
v => s"$v = new $cacheClass($zid);", forceInline = true)
codeGenHelper(ctx, ev, minLevel = MIN_LEVEL_OF_TIMESTAMP_TRUNC, true) {
(date: String, fmt: String) =>
s"truncTimestamp($date, $fmt, $zid);"
s"truncTimestamp($date, $fmt, $cache);"
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -512,15 +512,29 @@ object DateTimeUtils extends SparkDateTimeUtils {
instantToMicros(truncated.toInstant)
}

/**
* Truncates the timestamp `micros` to `level` in `zoneId`. `level` should be generated using
* `parseTruncLevel()`, between 0 and 9.
*
* Convenience variant that resolves the zone offset from a fresh single-use [[ZoneOffsetCache]];
* intended for one-shot callers such as interpreted evaluation and constant folding. The codegen
* hot path instead calls the [[ZoneOffsetCache]]-based overload directly with a per-task cache so
* the offset is reused across rows. The result is identical to that overload.
*/
def truncTimestamp(micros: Long, level: Int, zoneId: ZoneId): Long =
truncTimestamp(micros, level, new ZoneOffsetCache(zoneId))

/**
* Returns the trunc date time from original date time and trunc level.
* Trunc level should be generated using `parseTruncLevel()`, should be between 0 and 9.
*
* Uses an offset-arithmetic fast path: the zone offset at `micros` is resolved once,
* truncation runs in the shifted-local frame, and the result is shifted back to UTC
* micros. Falls back to [[truncTimestampSlow]] when the offset at the candidate
* truncated instant differs from the offset at `micros` (DST/historical transition
* spans the candidate; SPARK-30766/30857) or on arithmetic overflow.
* Uses an offset-arithmetic fast path: the zone offset at `micros` is resolved once through
* `cache`, which memoizes it over the constant-offset interval around the last lookup, so for
* temporally clustered data the per-row transition-array binary search collapses to two
* comparisons. Truncation then runs in the shifted-local frame, and the result is shifted back
* to UTC micros. Falls back to [[truncTimestampSlow]] when the offset at the candidate truncated
* instant differs from the offset at `micros` (DST/historical transition spans the candidate;
* SPARK-30766/30857) or on arithmetic overflow.
*
* Sub-minute LMT offsets (e.g. America/Los_Angeles -07:52:58 pre-1883, see
* SPARK-33404) and 30/45-minute offsets (Asia/Kolkata +05:30, Asia/Kathmandu +05:45)
Expand All @@ -538,7 +552,7 @@ object DateTimeUtils extends SparkDateTimeUtils {
* - WEEK/MONTH/QUARTER/YEAR: convert local micros to local epoch-day, run
* [[truncDate]] in the local-day frame, multiply back to local micros.
*/
def truncTimestamp(micros: Long, level: Int, zoneId: ZoneId): Long = {
def truncTimestamp(micros: Long, level: Int, cache: ZoneOffsetCache): Long = {
// MICROSECOND / MILLISECOND / SECOND don't need zone information.
level match {
case TRUNC_TO_MICROSECOND => return micros
Expand All @@ -548,10 +562,8 @@ object DateTimeUtils extends SparkDateTimeUtils {
return Math.subtractExact(micros, Math.floorMod(micros, MICROS_PER_SECOND))
case _ =>
}
val rules = zoneId.getRules
val originalSec = Math.floorDiv(micros, MICROS_PER_SECOND)
val originalOffsetSec =
rules.getOffset(Instant.ofEpochSecond(originalSec)).getTotalSeconds.toLong
val originalOffsetSec = cache.offsetSeconds(originalSec)
val offsetMicros = originalOffsetSec * MICROS_PER_SECOND
try {
val local = Math.addExact(micros, offsetMicros)
Expand All @@ -567,17 +579,16 @@ object DateTimeUtils extends SparkDateTimeUtils {
Math.multiplyExact(truncDate(localDays, level).toLong, MICROS_PER_DAY)
}
val candidate = Math.subtractExact(truncatedLocal, offsetMicros)
if (!rules.isFixedOffset) {
if (!cache.isFixedOffset) {
val candidateSec = Math.floorDiv(candidate, MICROS_PER_SECOND)
val candidateOffsetSec =
rules.getOffset(Instant.ofEpochSecond(candidateSec)).getTotalSeconds.toLong
val candidateOffsetSec = cache.offsetSeconds(candidateSec)
if (candidateOffsetSec != originalOffsetSec) {
return truncTimestampSlow(micros, level, zoneId)
return truncTimestampSlow(micros, level, cache.zoneId)
}
}
candidate
} catch {
case _: ArithmeticException => truncTimestampSlow(micros, level, zoneId)
case _: ArithmeticException => truncTimestampSlow(micros, level, cache.zoneId)
}
}

Expand Down Expand Up @@ -1287,3 +1298,60 @@ object DateTimeUtils extends SparkDateTimeUtils {
c
}
}

/**
* Per-task memoization of a zone's UTC offset, used by the [[DateTimeUtils.truncTimestamp]] hot
* path. The session zone is constant for a query and the offset is piecewise-constant between DST
* transitions, so consecutive rows almost always resolve to the same offset. The cache holds the
* half-open epoch-second interval `[lo, hi)` on which the offset is provably constant -- derived
* from the surrounding zone transitions -- so a lookup that falls in the interval reduces to two
* comparisons instead of a transition-array binary search.
*
* Not thread-safe by design: a fresh instance is created per task (codegen mutable state) and used
* single-threaded, mirroring how stateful per-row helpers are scoped in generated code.
*/
class ZoneOffsetCache(val zoneId: ZoneId) {
private val rules = zoneId.getRules
val isFixedOffset: Boolean = rules.isFixedOffset

private var cached = false
private var lo = 0L
private var hi = 0L
private var offsetSec = 0L

/**
* Offset in seconds at `epochSec`, equal to
* `rules.getOffset(Instant.ofEpochSecond(epochSec)).getTotalSeconds`, memoized over the
* constant-offset interval containing the previous lookup.
*/
def offsetSeconds(epochSec: Long): Long = {
if (cached && epochSec >= lo && epochSec < hi) {
offsetSec
} else {
val instant = Instant.ofEpochSecond(epochSec)
val o = rules.getOffset(instant).getTotalSeconds.toLong
if (isFixedOffset) {
lo = Long.MinValue
hi = Long.MaxValue
} else {
val nextT = rules.nextTransition(instant)
if (nextT == null) {
// No transition after `instant`: offset is constant on [epochSec, +inf).
lo = epochSec
hi = Long.MaxValue
} else {
hi = nextT.toEpochSecond
// `hi - 1` lies strictly inside the constant-offset window ending at `hi` (zone
// transitions are always more than a second apart), so its previous transition is
// exactly that window's start. Anchoring on an interior point avoids an off-by-one
// when `epochSec` sits exactly on a transition instant.
val prevT = rules.previousTransition(Instant.ofEpochSecond(hi - 1))
lo = if (prevT == null) Long.MinValue else prevT.toEpochSecond
}
}
offsetSec = o
cached = true
o
}
}
}
Loading