|
2 | 2 | export derivative, derivative!, derivatives, make_seed |
3 | 3 |
|
4 | 4 | """ |
5 | | - derivative(f, x, l, ::Val{N}) |
6 | | - derivative(f!, y, x, l, ::Val{N}) |
| 5 | + derivative(f, x, l, ::Val{P}) |
| 6 | + derivative(f!, y, x, l, ::Val{P}) |
7 | 7 |
|
8 | | -Computes `N`-th directional derivative of `f` w.r.t. vector `x` in direction `l`. |
| 8 | +Computes `P`-th directional derivative of `f` w.r.t. vector `x` in direction `l`. |
9 | 9 | """ |
10 | 10 | function derivative end |
11 | 11 |
|
12 | 12 | """ |
13 | | - derivative!(result, f, x, l, ::Val{N}) |
14 | | - derivative!(result, f!, y, x, l, ::Val{N}) |
| 13 | + derivative!(result, f, x, l, ::Val{P}) |
| 14 | + derivative!(result, f!, y, x, l, ::Val{P}) |
15 | 15 |
|
16 | 16 | In-place derivative calculation APIs. `result` is expected to be pre-allocated and have the same shape as `y`. |
17 | 17 | """ |
18 | 18 | function derivative! end |
19 | 19 |
|
20 | 20 | """ |
21 | | - derivatives(f, x, l, ::Val{N}) |
22 | | - derivatives(f!, y, x, l, ::Val{N}) |
| 21 | + derivatives(f, x, l, ::Val{P}) |
| 22 | + derivatives(f!, y, x, l, ::Val{P}) |
23 | 23 |
|
24 | | -Computes all derivatives of `f` at `x` up to order `N`. |
| 24 | +Computes all derivatives of `f` at `x` up to order `P`. |
25 | 25 | """ |
26 | 26 | function derivatives end |
27 | 27 |
|
28 | 28 | # Convenience wrapper for adding unit seed to the input |
29 | 29 |
|
30 | | -@inline derivative(f, x, order::Int64) = derivative(f, x, one(eltype(x)), order) |
| 30 | +@inline derivative(f, x, p::Int64) = derivative(f, x, one(eltype(x)), p) |
31 | 31 |
|
32 | | -# Convenience wrappers for converting orders to value types |
| 32 | +# Convenience wrappers for converting ps to value types |
33 | 33 | # and forward work to core APIs |
34 | 34 |
|
35 | | -@inline derivative(f, x, l, order::Int64) = derivative(f, x, l, Val{order}()) |
36 | | -@inline derivative(f!, y, x, l, order::Int64) = derivative(f!, y, x, l, Val{order}()) |
37 | | -@inline derivative!(result, f, x, l, order::Int64) = derivative!( |
38 | | - result, f, x, l, Val{order}()) |
39 | | -@inline derivative!(result, f!, y, x, l, order::Int64) = derivative!( |
40 | | - result, f!, y, x, l, Val{order}()) |
| 35 | +@inline derivative(f, x, l, p::Int64) = derivative(f, x, l, Val{p}()) |
| 36 | +@inline derivative(f!, y, x, l, p::Int64) = derivative(f!, y, x, l, Val{p}()) |
| 37 | +@inline derivative!(result, f, x, l, p::Int64) = derivative!( |
| 38 | + result, f, x, l, Val{p}()) |
| 39 | +@inline derivative!(result, f!, y, x, l, p::Int64) = derivative!( |
| 40 | + result, f!, y, x, l, Val{p}()) |
41 | 41 |
|
42 | 42 | # Core APIs |
43 | 43 |
|
44 | 44 | # Added to help Zygote infer types |
45 | | -@inline function make_seed(x::T, l::S, ::Val{N}) where {T <: Real, S <: Real, N} |
46 | | - TaylorScalar{T, N}(x, convert(T, l)) |
| 45 | +@inline function make_seed(x::T, l::T, ::Val{P}) where {T <: Real, P} |
| 46 | + TaylorScalar{P}(x, convert(T, l)) |
47 | 47 | end |
48 | 48 |
|
49 | | -@inline function make_seed(x::AbstractArray{T}, l, vN::Val{N}) where {T <: Real, N} |
50 | | - broadcast(make_seed, x, l, vN) |
| 49 | +@inline function make_seed(x::AbstractArray{T}, l, p::Val{P}) where {T <: Real, P} |
| 50 | + broadcast(make_seed, x, l, p) |
51 | 51 | end |
52 | 52 |
|
53 | | -# `derivative` API: computes the `N - 1`-th derivative of `f` at `x` |
54 | | -@inline derivative(f, x, l, vN::Val{N}) where {N} = extract_derivative( |
55 | | - derivatives(f, x, l, vN), N) |
56 | | -@inline derivative(f!, y, x, l, vN::Val{N}) where {N} = extract_derivative( |
57 | | - derivatives(f!, y, x, l, vN), N) |
58 | | -@inline derivative!(result, f, x, l, vN::Val{N}) where {N} = extract_derivative!( |
59 | | - result, derivatives(f, x, l, vN), N) |
60 | | -@inline derivative!(result, f!, y, x, l, vN::Val{N}) where {N} = extract_derivative!( |
61 | | - result, derivatives(f!, y, x, l, vN), N) |
| 53 | +# `derivative` API: computes the `P - 1`-th derivative of `f` at `x` |
| 54 | +@inline derivative(f, x, l, p::Val{P}) where {P} = extract_derivative( |
| 55 | + derivatives(f, x, l, p), P) |
| 56 | +@inline derivative(f!, y, x, l, p::Val{P}) where {P} = extract_derivative( |
| 57 | + derivatives(f!, y, x, l, p), P) |
| 58 | +@inline derivative!(result, f, x, l, p::Val{P}) where {P} = extract_derivative!( |
| 59 | + result, derivatives(f, x, l, p), P) |
| 60 | +@inline derivative!(result, f!, y, x, l, p::Val{P}) where {P} = extract_derivative!( |
| 61 | + result, derivatives(f!, y, x, l, p), P) |
62 | 62 |
|
63 | | -# `derivatives` API: computes all derivatives of `f` at `x` up to order `N - 1` |
| 63 | +# `derivatives` API: computes all derivatives of `f` at `x` up to p `P - 1` |
64 | 64 |
|
65 | 65 | # Out-of-place function |
66 | | -@inline derivatives(f, x, l, vN::Val{N}) where {N} = f(make_seed(x, l, vN)) |
| 66 | +@inline derivatives(f, x, l, p::Val{P}) where {P} = f(make_seed(x, l, p)) |
67 | 67 |
|
68 | 68 | # In-place function |
69 | | -@inline function derivatives(f!, y::AbstractArray{T}, x, l, vN::Val{N}) where {T, N} |
70 | | - buffer = similar(y, TaylorScalar{T, N}) |
71 | | - f!(buffer, make_seed(x, l, vN)) |
| 69 | +@inline function derivatives(f!, y::AbstractArray{T}, x, l, p::Val{P}) where {T, P} |
| 70 | + buffer = similar(y, TaylorScalar{T, P}) |
| 71 | + f!(buffer, make_seed(x, l, p)) |
72 | 72 | map!(value, y, buffer) |
73 | 73 | return buffer |
74 | 74 | end |
0 commit comments