11
2- export derivative
2+ export derivative, derivative!
33
44"""
55 derivative(f, x, order::Int64)
@@ -24,44 +24,62 @@ Batch mode derivative / directional derivative calculations, where each column o
2424"""
2525function derivative end
2626
27+ """
28+ derivative!(result, f, x, l, ::Val{N})
29+ derivative!(result, f!, y, x, l, ::Val{N})
30+
31+ In-place derivative calculation APIs. `result` is expected to be pre-allocated and have the same shape as `y`.
32+ """
33+ function derivative! end
34+
2735# Convenience wrappers for converting orders to value types
2836# and forward work to core APIs
2937
30- @inline function derivative (f, x, order:: Int64 )
31- derivative (f, x, Val {order + 1} ())
32- end
33-
34- @inline function derivative (f, x, l, order:: Int64 )
35- derivative (f, x, l, Val {order + 1} ())
36- end
38+ @inline derivative (f, x, order:: Int64 ) = derivative (f, x, one (eltype (x)), order)
39+ @inline derivative (f, x, l, order:: Int64 ) = derivative (f, x, l, Val {order + 1} ())
3740
3841# Core APIs
3942
4043# Added to help Zygote infer types
41- make_taylor (t0:: T , t1:: S , :: Val{N} ) where {T, S, N} = TaylorScalar {T, N} (t0, convert (T, t1))
44+ @inline function make_taylor (x:: T , l:: S , :: Val{N} ) where {T <: TN , S <: TN , N}
45+ TaylorScalar {T, N} (x, convert (T, l))
46+ end
4247
43- @inline function derivative (f, x:: T , :: Val{N} ) where {T <: TN , N}
44- t = TaylorScalar {T, N} (x, one (x))
45- return extract_derivative (f (t), N)
48+ @inline function make_taylor (x:: AbstractArray{T} , l, vN:: Val{N} ) where {T <: TN , N}
49+ broadcast (make_taylor, x, l, vN)
4650end
4751
48- @inline function derivative (f, x:: AbstractVector{T} , l:: AbstractVector{S} ,
49- vN:: Val{N} ) where {T <: TN , S <: TN , N}
50- t = map ((t0, t1) -> make_taylor (t0, t1, vN), x, l)
51- # equivalent to map(TaylorScalar{T, N}, x, l)
52+ # Out-of-place function, out-of-place derivative
53+ @inline function derivative (f, x, l, vN:: Val{N} ) where {N}
54+ t = make_taylor (x, l, vN)
5255 return extract_derivative (f (t), N)
5356end
5457
55- # shorthand notations for matrices
58+ # Below three advanced APIs do not have convenience wrappers
59+
60+ # In-place function, out-of-place derivative
61+ @inline function derivative (f!, y:: AbstractArray{T} , x, l, vN:: Val{N} ) where {T, N}
62+ s = similar (y, TaylorScalar{T, N})
63+ t = make_taylor (x, l, vN)
64+ f! (s, t)
65+ map! (primal, y, s)
66+ return extract_derivative (s, N)
67+ end
5668
57- @inline function derivative (f, x:: AbstractMatrix{T} , vN:: Val{N} ) where {T <: TN , N}
58- size (x)[1 ] != 1 && @warn " x is not a row vector."
59- t = make_taylor .(x, one (T), vN)
60- return extract_derivative .(f (t), N)
69+ # Out-of-place function, in-place derivative
70+ @inline function derivative! (result, f, x, l, vN:: Val{N} ) where {N}
71+ t = make_taylor (x, l, vN)
72+ s = f (t)
73+ extract_derivative! (result, s, N)
74+ return result
6175end
6276
63- @inline function derivative (f, x:: AbstractMatrix{T} , l:: AbstractVector{S} ,
64- vN:: Val{N} ) where {T <: TN , S <: TN , N}
65- t = make_taylor .(x, l, vN)
66- return extract_derivative .(f (t), N)
77+ # In-place function, in-place derivative
78+ @inline function derivative! (result, f!, y:: AbstractArray{T} , x, l, vN:: Val{N} ) where {T, N}
79+ s = similar (y, TaylorScalar{T, N})
80+ t = make_taylor (x, l, vN)
81+ f! (s, t)
82+ map! (primal, y, s)
83+ extract_derivative! (result, s, N)
84+ return result
6785end
0 commit comments