11
2- export derivative, derivative!
2+ export derivative, derivative!, derivatives, make_seed
33
44"""
5- derivative(f, x, order::Int64)
6- derivative(f, x, l, order::Int64)
7-
8- Wrapper functions for converting order from a number to a type. Actual APIs are detailed below:
9-
10- derivative(f, x::T, ::Val{N})
11-
12- Computes `order`-th derivative of `f` w.r.t. scalar `x`.
13-
14- derivative(f, x::AbstractVector{T}, l::AbstractVector{T}, ::Val{N})
5+ derivative(f, x, l, ::Val{N})
6+ derivative(f!, y, x, l, ::Val{N})
157
168Computes `order`-th directional derivative of `f` w.r.t. vector `x` in direction `l`.
17-
18- derivative(f, x::AbstractMatrix{T}, ::Val{N})
19- derivative(f, x::AbstractMatrix{T}, l::AbstractVector{T}, ::Val{N})
20-
21- Batch mode derivative / directional derivative calculations, where each column of `x` represents a scalar or a vector. `f` is expected to accept matrices as input.
22- - For a M-by-N matrix, calculate the directional derivative for each column.
23- - For a 1-by-N matrix (row vector), calculate the derivative for each scalar.
249"""
2510function derivative end
2611
@@ -32,54 +17,58 @@ In-place derivative calculation APIs. `result` is expected to be pre-allocated a
3217"""
3318function derivative! end
3419
20+ """
21+ derivatives(f, x, l, ::Val{N})
22+ derivatives(f!, y, x, l, ::Val{N})
23+
24+ Computes all derivatives of `f` at `x` up to order `N - 1`.
25+ """
26+ function derivatives end
27+
28+ # Convenience wrapper for adding unit seed to the input
29+
30+ @inline derivative (f, x, order:: Int64 ) = derivative (f, x, one (eltype (x)), order)
31+
3532# Convenience wrappers for converting orders to value types
3633# and forward work to core APIs
3734
38- @inline derivative (f, x, order:: Int64 ) = derivative (f, x, one (eltype (x)), order)
3935@inline derivative (f, x, l, order:: Int64 ) = derivative (f, x, l, Val {order + 1} ())
36+ @inline derivative (f!, y, x, l, order:: Int64 ) = derivative (f!, y, x, l, Val {order + 1} ())
37+ @inline derivative! (result, f, x, l, order:: Int64 ) = derivative! (
38+ result, f, x, l, Val {order + 1} ())
39+ @inline derivative! (result, f!, y, x, l, order:: Int64 ) = derivative! (
40+ result, f!, y, x, l, Val {order + 1} ())
4041
4142# Core APIs
4243
4344# Added to help Zygote infer types
44- @inline function make_taylor (x:: T , l:: S , :: Val{N} ) where {T <: TN , S <: TN , N}
45+ @inline function make_seed (x:: T , l:: S , :: Val{N} ) where {T <: TN , S <: TN , N}
4546 TaylorScalar {T, N} (x, convert (T, l))
4647end
4748
48- @inline function make_taylor (x:: AbstractArray{T} , l, vN:: Val{N} ) where {T <: TN , N}
49- broadcast (make_taylor, x, l, vN)
50- end
51-
52- # Out-of-place function, out-of-place derivative
53- @inline function derivative (f, x, l, vN:: Val{N} ) where {N}
54- t = make_taylor (x, l, vN)
55- return extract_derivative (f (t), N)
56- end
57-
58- # Below three advanced APIs do not have convenience wrappers
59-
60- # In-place function, out-of-place derivative
61- @inline function derivative (f!, y:: AbstractArray{T} , x, l, vN:: Val{N} ) where {T, N}
62- s = similar (y, TaylorScalar{T, N})
63- t = make_taylor (x, l, vN)
64- f! (s, t)
65- map! (primal, y, s)
66- return extract_derivative (s, N)
67- end
68-
69- # Out-of-place function, in-place derivative
70- @inline function derivative! (result, f, x, l, vN:: Val{N} ) where {N}
71- t = make_taylor (x, l, vN)
72- s = f (t)
73- extract_derivative! (result, s, N)
74- return result
49+ @inline function make_seed (x:: AbstractArray{T} , l, vN:: Val{N} ) where {T <: TN , N}
50+ broadcast (make_seed, x, l, vN)
7551end
7652
77- # In-place function, in-place derivative
78- @inline function derivative! (result, f!, y:: AbstractArray{T} , x, l, vN:: Val{N} ) where {T, N}
79- s = similar (y, TaylorScalar{T, N})
80- t = make_taylor (x, l, vN)
81- f! (s, t)
82- map! (primal, y, s)
83- extract_derivative! (result, s, N)
84- return result
53+ # `derivative` API: computes the `N - 1`-th derivative of `f` at `x`
54+ @inline derivative (f, x, l, vN:: Val{N} ) where {N} = extract_derivative (
55+ derivatives (f, x, l, vN), N)
56+ @inline derivative (f!, y, x, l, vN:: Val{N} ) where {N} = extract_derivative (
57+ derivatives (f!, y, x, l, vN), N)
58+ @inline derivative! (result, f, x, l, vN:: Val{N} ) where {N} = extract_derivative! (
59+ result, derivatives (f, x, l, vN), N)
60+ @inline derivative! (result, f!, y, x, l, vN:: Val{N} ) where {N} = extract_derivative! (
61+ result, derivatives (f!, y, x, l, vN), N)
62+
63+ # `derivatives` API: computes all derivatives of `f` at `x` up to order `N - 1`
64+
65+ # Out-of-place function
66+ @inline derivatives (f, x, l, vN:: Val{N} ) where {N} = f (make_seed (x, l, vN))
67+
68+ # In-place function
69+ @inline function derivatives (f!, y:: AbstractArray{T} , x, l, vN:: Val{N} ) where {T, N}
70+ buffer = similar (y, TaylorScalar{T, N})
71+ f! (buffer, make_seed (x, l, vN))
72+ map! (primal, y, buffer)
73+ return buffer
8574end
0 commit comments