Skip to content

Commit 0964e46

Browse files
Use dispatch instead of conditional _maybe_reshape for type stability
Address review feedback: replace the type-unstable _maybe_reshape helper with two dispatch-based extract_jacobian! methods: - AbstractMatrix: skip reshape entirely (zero-alloc, hot path for DiffEq) - AbstractArray: reshape unconditionally (type-stable for non-matrix inputs) Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com> Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 84b9290 commit 0964e46

1 file changed

Lines changed: 12 additions & 11 deletions

File tree

src/jacobian.jl

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,19 @@ jacobian(f, x::Real) = throw(DimensionMismatch("jacobian(f, x) expects that x is
9292
# result extraction #
9393
#####################
9494

95+
# Specialized method for AbstractMatrix: no reshape needed, avoids ReshapedArray allocation
96+
# that cannot be elided under --check-bounds=yes.
97+
function extract_jacobian!(::Type{T}, result::AbstractMatrix, ydual::AbstractArray, n) where {T}
98+
ydual_reshaped = vec(ydual)
99+
# Use closure to avoid GPU broadcasting with Type
100+
partials_wrap(ydual, nrange) = partials(T, ydual, nrange)
101+
result .= partials_wrap.(ydual_reshaped, transpose(1:n))
102+
return result
103+
end
104+
105+
# General method for non-matrix arrays: reshape unconditionally (type-stable).
95106
function extract_jacobian!(::Type{T}, result::AbstractArray, ydual::AbstractArray, n) where {T}
96-
out_reshaped = _maybe_reshape(result, length(ydual), n)
107+
out_reshaped = reshape(result, length(ydual), n)
97108
ydual_reshaped = vec(ydual)
98109
# Use closure to avoid GPU broadcasting with Type
99110
partials_wrap(ydual, nrange) = partials(T, ydual, nrange)
@@ -117,16 +128,6 @@ function extract_jacobian_chunk!(::Type{T}, result, ydual, index, chunksize) whe
117128
return result
118129
end
119130

120-
# Avoid allocating a ReshapedArray wrapper when `result` already has the target shape.
121-
# reshape() always allocates a wrapper that cannot be elided under --check-bounds=yes.
122-
@inline function _maybe_reshape(result::AbstractArray, m, n)
123-
if size(result) == (m, n)
124-
return result
125-
else
126-
return reshape(result, m, n)
127-
end
128-
end
129-
130131
reshape_jacobian(result, ydual, xdual) = reshape(result, length(ydual), length(xdual))
131132
reshape_jacobian(result::DiffResult, ydual, xdual) = reshape_jacobian(DiffResults.jacobian(result), ydual, xdual)
132133

0 commit comments

Comments
 (0)