Skip to content

Commit 9a037b3

Browse files
zhujch1tansongchen
authored andcommitted
Update tests
1 parent eff4ac7 commit 9a037b3

7 files changed

Lines changed: 30 additions & 27 deletions

File tree

Project.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ TaylorDiffSFExt = ["SpecialFunctions"]
2121
ChainRules = "1"
2222
ChainRulesCore = "1"
2323
ChainRulesOverloadGeneration = "0.1"
24-
IrrationalConstants = "0.2"
2524
SpecialFunctions = "2"
2625
SymbolicUtils = "1"
2726
Symbolics = "5"

src/derivative.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,25 @@
22
export derivative
33

44
"""
5-
derivative(f, x::T, order::Int64)
5+
derivative(f, x, order::Int64)
6+
derivative(f, x, l, order::Int64)
7+
8+
Wrapper functions for converting order from a number to a type. Actual APIs are detailed below:
9+
610
derivative(f, x::T, ::Val{N})
711
812
Computes `order`-th derivative of `f` w.r.t. scalar `x`.
913
10-
derivative(f, x::AbstractVector{T}, l::AbstractVector{T}, order::Int64)
1114
derivative(f, x::AbstractVector{T}, l::AbstractVector{T}, ::Val{N})
1215
1316
Computes `order`-th directional derivative of `f` w.r.t. vector `x` in direction `l`.
1417
15-
derivative(f, x::AbstractMatrix{T}, order::Int64)
1618
derivative(f, x::AbstractMatrix{T}, ::Val{N})
17-
derivative(f, x::AbstractMatrix{T}, l::AbstractVector{T}, order::Int64)
1819
derivative(f, x::AbstractMatrix{T}, l::AbstractVector{T}, ::Val{N})
1920
20-
Shorthand notations for multiple calculations.
21-
For a M-by-N matrix, calculate the directional derivative for each column.
22-
For a 1-by-N matrix (row vector), calculate the derivative for each scalar.
21+
Batch mode derivative / directional derivative calculations, where each column of `x` represents a scalar or a vector. `f` is expected to accept matrices as input.
22+
- For a M-by-N matrix, calculate the directional derivative for each column.
23+
- For a 1-by-N matrix (row vector), calculate the derivative for each scalar.
2324
"""
2425
function derivative end
2526

@@ -55,7 +56,7 @@ end
5556

5657
@inline function derivative(f, x::AbstractMatrix{T}, vN::Val{N}) where {T <: TN, N}
5758
size(x)[1] != 1 && @warn "x is not a row vector."
58-
t = make_taylor.(x, one(N), vN)
59+
t = make_taylor.(x, one(T), vN)
5960
return extract_derivative.(f(t), N)
6061
end
6162

test/derivative.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
2+
@testset "Derivative" begin
3+
g(x) = x^3
4+
@test derivative(g, 1.0, 1) 3
5+
6+
h(x) = x.^3
7+
@test derivative(h, [2.0 3.0], 1) [12.0 27.0]
8+
end
9+
10+
@testset "Directional derivative" begin
11+
g(x) = x[1] * x[1] + x[2] * x[2]
12+
@test derivative(g, [1.0, 2.0], [1.0, 0.0], 1) 2.0
13+
14+
h(x) = sum(x, dims=1)
15+
@test derivative(h, [1.0 2.0; 2.0 3.0], [1.0, 1.0], 1) [2. 2.]
16+
end

test/runtests.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
using TaylorDiff
22
using Test
33

4-
include("scalar.jl")
5-
include("vector.jl")
64
include("primitive.jl")
5+
include("derivative.jl")
76
include("zygote.jl")
87
# include("lux.jl")

test/scalar.jl

Lines changed: 0 additions & 6 deletions
This file was deleted.

test/vector.jl

Lines changed: 0 additions & 6 deletions
This file was deleted.

test/zygote.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,15 @@ using Zygote, LinearAlgebra
66
for f in (exp, log, sqrt, sin, asin, sinh, asinh)
77
@test gradient(x -> derivative(f, x, 2), some_number)[1]
88
derivative(f, some_number, 3)
9-
derivative_result = vec(derivative(f, some_numbers, 3))
10-
@test Zygote.jacobian(x -> derivative(f, x, 2), some_numbers)[1]
9+
derivative_result = vec(derivative.(f, some_numbers, 3))
10+
@test Zygote.jacobian(x -> derivative.(f, x, 2), some_numbers)[1]
1111
diagm(derivative_result)
1212
end
1313

1414
some_matrix = [0.7 0.1; 0.4 0.2]
1515
f = x -> sum(tanh.(x), dims = 1)
16-
dfdx1(m, x) = derivative(u -> sum(m(u)), x, [1.0, 0.0], 1)
17-
dfdx2(m, x) = derivative(u -> sum(m(u)), x, [0.0, 1.0], 1)
16+
dfdx1(m, x) = derivative(m, x, [1.0, 0.0], 1)
17+
dfdx2(m, x) = derivative(m, x, [0.0, 1.0], 1)
1818
res(m, x) = dfdx1(m, x) .+ 2 * dfdx2(m, x)
1919
grads = Zygote.gradient(some_matrix) do x
2020
sum(res(f, x))

0 commit comments

Comments
 (0)