Skip to content

Commit 39d0c76

Browse files
committed
Add more modes
1 parent b7e97c7 commit 39d0c76

4 files changed

Lines changed: 77 additions & 32 deletions

File tree

src/derivative.jl

Lines changed: 43 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11

2-
export derivative
2+
export derivative, derivative!
33

44
"""
55
derivative(f, x, order::Int64)
@@ -24,44 +24,62 @@ Batch mode derivative / directional derivative calculations, where each column o
2424
"""
2525
function derivative end
2626

27+
"""
28+
derivative!(result, f, x, l, ::Val{N})
29+
derivative!(result, f!, y, x, l, ::Val{N})
30+
31+
In-place derivative calculation APIs. `result` is expected to be pre-allocated and have the same shape as `y`.
32+
"""
33+
function derivative! end
34+
2735
# Convenience wrappers for converting orders to value types
2836
# and forward work to core APIs
2937

30-
@inline function derivative(f, x, order::Int64)
31-
derivative(f, x, Val{order + 1}())
32-
end
33-
34-
@inline function derivative(f, x, l, order::Int64)
35-
derivative(f, x, l, Val{order + 1}())
36-
end
38+
@inline derivative(f, x, order::Int64) = derivative(f, x, one(eltype(x)), order)
39+
@inline derivative(f, x, l, order::Int64) = derivative(f, x, l, Val{order + 1}())
3740

3841
# Core APIs
3942

4043
# Added to help Zygote infer types
41-
make_taylor(t0::T, t1::S, ::Val{N}) where {T, S, N} = TaylorScalar{T, N}(t0, convert(T, t1))
44+
@inline function make_taylor(x::T, l::S, ::Val{N}) where {T <: TN, S <: TN, N}
45+
TaylorScalar{T, N}(x, convert(T, l))
46+
end
4247

43-
@inline function derivative(f, x::T, ::Val{N}) where {T <: TN, N}
44-
t = TaylorScalar{T, N}(x, one(x))
45-
return extract_derivative(f(t), N)
48+
@inline function make_taylor(x::AbstractArray{T}, l, vN::Val{N}) where {T <: TN, N}
49+
broadcast(make_taylor, x, l, vN)
4650
end
4751

48-
@inline function derivative(f, x::AbstractVector{T}, l::AbstractVector{S},
49-
vN::Val{N}) where {T <: TN, S <: TN, N}
50-
t = map((t0, t1) -> make_taylor(t0, t1, vN), x, l)
51-
# equivalent to map(TaylorScalar{T, N}, x, l)
52+
# Out-of-place function, out-of-place derivative
53+
@inline function derivative(f, x, l, vN::Val{N}) where {N}
54+
t = make_taylor(x, l, vN)
5255
return extract_derivative(f(t), N)
5356
end
5457

55-
# shorthand notations for matrices
58+
# Below three advanced APIs do not have convenience wrappers
59+
60+
# In-place function, out-of-place derivative
61+
@inline function derivative(f!, y::AbstractArray{T}, x, l, vN::Val{N}) where {T, N}
62+
s = similar(y, TaylorScalar{T, N})
63+
t = make_taylor(x, l, vN)
64+
f!(s, t)
65+
map!(primal, y, s)
66+
return extract_derivative(s, N)
67+
end
5668

57-
@inline function derivative(f, x::AbstractMatrix{T}, vN::Val{N}) where {T <: TN, N}
58-
size(x)[1] != 1 && @warn "x is not a row vector."
59-
t = make_taylor.(x, one(T), vN)
60-
return extract_derivative.(f(t), N)
69+
# Out-of-place function, in-place derivative
70+
@inline function derivative!(result, f, x, l, vN::Val{N}) where {N}
71+
t = make_taylor(x, l, vN)
72+
s = f(t)
73+
extract_derivative!(result, s, N)
74+
return result
6175
end
6276

63-
@inline function derivative(f, x::AbstractMatrix{T}, l::AbstractVector{S},
64-
vN::Val{N}) where {T <: TN, S <: TN, N}
65-
t = make_taylor.(x, l, vN)
66-
return extract_derivative.(f(t), N)
77+
# In-place function, in-place derivative
78+
@inline function derivative!(result, f!, y::AbstractArray{T}, x, l, vN::Val{N}) where {T, N}
79+
s = similar(y, TaylorScalar{T, N})
80+
t = make_taylor(x, l, vN)
81+
f!(s, t)
82+
map!(primal, y, s)
83+
extract_derivative!(result, s, N)
84+
return result
6785
end

src/scalar.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ end
6262
map(t -> extract_derivative(t, i), v)
6363
end
6464
@inline extract_derivative(r, i::Integer) = false
65+
@inline function extract_derivative!(result::AbstractArray, v::AbstractArray{T},
66+
i::Integer) where {T <: TaylorScalar}
67+
map!(t -> extract_derivative(t, i), result, v)
68+
end
6569
@inline primal(t::TaylorScalar) = extract_derivative(t, 1)
6670

6771
@inline zero(::Type{TaylorScalar{T, N}}) where {T, N} = TaylorScalar{T, N}(zero(T))

test/derivative.jl

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,39 @@
11

2-
@testset "Derivative" begin
2+
@testset "O-function, O-derivative" begin
33
g(x) = x^3
44
@test derivative(g, 1.0, 1) 3
55

66
h(x) = x .^ 3
77
@test derivative(h, [2.0 3.0], 1) [12.0 27.0]
8+
9+
g1(x) = x[1] * x[1] + x[2] * x[2]
10+
@test derivative(g1, [1.0, 2.0], [1.0, 0.0], 1) 2.0
11+
12+
h1(x) = sum(x, dims = 1)
13+
@test derivative(h1, [1.0 2.0; 2.0 3.0], [1.0, 1.0], 1) [2.0 2.0]
814
end
915

10-
@testset "Directional derivative" begin
11-
g(x) = x[1] * x[1] + x[2] * x[2]
12-
@test derivative(g, [1.0, 2.0], [1.0, 0.0], 1) 2.0
16+
@testset "I-function, O-derivative" begin
17+
g!(y, x) = begin
18+
y[1] = x * x
19+
y[2] = x + 1
20+
end
21+
x = 2.0
22+
y = [0.0, 0.0]
23+
@test derivative(g!, y, x, 1.0, Val{2}()) [4.0, 1.0]
24+
end
25+
26+
@testset "O-function, I-derivative" begin
27+
g(x) = x .^ 2
28+
@test derivative!(zeros(2), g, [1.0, 2.0], [1.0, 0.0], Val{2}()) [2.0, 0.0]
29+
end
1330

14-
h(x) = sum(x, dims = 1)
15-
@test derivative(h, [1.0 2.0; 2.0 3.0], [1.0, 1.0], 1) [2.0 2.0]
31+
@testset "I-function, I-derivative" begin
32+
g!(y, x) = begin
33+
y[1] = x[1] * x[1]
34+
y[2] = x[2] * x[2]
35+
end
36+
x = [2.0, 3.0]
37+
y = [0.0, 0.0]
38+
@test derivative!(y, g!, zeros(2), x, [1.0, 0.0], Val{2}()) [4.0, 0.0]
1639
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@ using Test
33

44
include("primitive.jl")
55
include("derivative.jl")
6-
include("zygote.jl")
6+
# include("zygote.jl")
77
# include("lux.jl")

0 commit comments

Comments
 (0)