Skip to content

Commit 5b8ffab

Browse files
authored
Merge branch 'master' into ap/gpu_arrays
2 parents 2c97323 + 463e830 commit 5b8ffab

7 files changed

Lines changed: 95 additions & 7 deletions

File tree

ext/ForwardDiffStaticArraysExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ ForwardDiff._lyap_div!!(A::StaticArrays.MMatrix, λ::AbstractVector) = ForwardDi
4949
result = Expr(:tuple, [:(partials(T, y, $i)) for i in 1:length(x)]...)
5050
return quote
5151
$(Expr(:meta, :inline))
52-
V = StaticArrays.similar_type(S, valtype($y))
52+
V = StaticArrays.similar_type(S, valtype(T, $y))
5353
return V($result)
5454
end
5555
end
@@ -79,7 +79,7 @@ end
7979
result = Expr(:tuple, [:(partials(T, ydual[$i], $j)) for i in 1:M, j in 1:N]...)
8080
return quote
8181
$(Expr(:meta, :inline))
82-
V = StaticArrays.similar_type(S, valtype(eltype($ydual)), Size($M, $N))
82+
V = StaticArrays.similar_type(S, valtype(T, eltype($ydual)), Size($M, $N))
8383
return V($result)
8484
end
8585
end
@@ -90,7 +90,7 @@ end
9090
end
9191

9292
function extract_jacobian(::Type{T}, ydual::AbstractArray, x::StaticArray) where T
93-
result = similar(ydual, valtype(eltype(ydual)), length(ydual), length(x))
93+
result = similar(ydual, valtype(T, eltype(ydual)), length(ydual), length(x))
9494
return extract_jacobian!(T, result, ydual, length(x))
9595
end
9696

src/dual.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,17 @@ end
128128
@inline valtype(::Dual{T,V,N}) where {T,V,N} = V
129129
@inline valtype(::Type{Dual{T,V,N}}) where {T,V,N} = V
130130

131+
@inline valtype(::Type{T}, ::V) where {T,V} = valtype(T, V)
132+
@inline valtype(::Type, ::Type{V}) where {V} = V
133+
@inline valtype(::Type{T}, ::Type{Dual{T,V,N}}) where {T,V,N} = V
134+
@inline function valtype(::Type{T}, ::Type{Dual{S,V,N}}) where {T,S,V,N}
135+
if S T
136+
Dual{S,V,N}
137+
else
138+
throw(DualMismatchError(T,S))
139+
end
140+
end
141+
131142
@inline tagtype(::V) where {V} = Nothing
132143
@inline tagtype(::Type{V}) where {V} = Nothing
133144
@inline tagtype(::Dual{T,V,N}) where {T,V,N} = T

src/gradient.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ const GRAD_ERROR = DimensionMismatch("gradient(f, x) expects that f(x) is a real
109109
function vector_mode_gradient(f::F, x, cfg::GradientConfig{T}) where {T, F}
110110
ydual = vector_mode_dual_eval!(f, cfg, x)
111111
ydual isa Real || throw(GRAD_ERROR)
112-
result = similar(x, valtype(ydual))
112+
result = similar(x, valtype(T, ydual))
113113
return extract_gradient!(T, result, ydual)
114114
end
115115

@@ -168,7 +168,7 @@ function chunk_mode_gradient_expr(result_definition::Expr)
168168
end
169169

170170
@eval function chunk_mode_gradient(f::F, x, cfg::GradientConfig{T,V,N}) where {F,T,V,N}
171-
$(chunk_mode_gradient_expr(:(result = similar(x, valtype(ydual)))))
171+
$(chunk_mode_gradient_expr(:(result = similar(x, valtype(T, ydual)))))
172172
end
173173

174174
@eval function chunk_mode_gradient!(result, f::F, x, cfg::GradientConfig{T,V,N}) where {F,T,V,N}

src/jacobian.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ function vector_mode_jacobian(f::F, x, cfg::JacobianConfig{T}) where {F,T}
128128
N = chunksize(cfg)
129129
ydual = vector_mode_dual_eval!(f, cfg, x)
130130
ydual isa AbstractArray || throw(JACOBIAN_ERROR)
131-
result = similar(ydual, valtype(eltype(ydual)), length(ydual), N)
131+
result = similar(ydual, valtype(T, eltype(ydual)), length(ydual), N)
132132
extract_jacobian!(T, result, ydual, N)
133133
extract_value!(T, result, ydual)
134134
return result
@@ -217,7 +217,7 @@ end
217217
seed!(xdual, x)
218218
end,
219219
:(ydual = f(xdual)),
220-
:(result = similar(ydual, valtype(eltype(ydual)), length(ydual), xlen)),
220+
:(result = similar(ydual, valtype(T, eltype(ydual)), length(ydual), xlen)),
221221
:()))
222222
end
223223

test/DualTest.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,21 @@ ForwardDiff.:≺(::Type{OuterTestTag}, ::Type{TestTag}) = false
101101
@test ForwardDiff.valtype(NESTED_FDNUM) == Dual{TestTag,V,M}
102102
@test ForwardDiff.valtype(typeof(NESTED_FDNUM)) == Dual{TestTag,V,M}
103103

104+
@test ForwardDiff.valtype(TestTag, FDNUM) == V
105+
@test ForwardDiff.valtype(TestTag, typeof(FDNUM)) == V
106+
@test ForwardDiff.valtype(TestTag, NESTED_FDNUM) == Dual{TestTag,V,M}
107+
@test ForwardDiff.valtype(TestTag, typeof(NESTED_FDNUM)) == Dual{TestTag,V,M}
108+
109+
@test ForwardDiff.valtype(OuterTestTag, FDNUM) == Dual{TestTag,V,N}
110+
@test ForwardDiff.valtype(OuterTestTag, typeof(FDNUM)) == Dual{TestTag,V,N}
111+
@test ForwardDiff.valtype(OuterTestTag, NESTED_FDNUM) == Dual{TestTag,Dual{TestTag,V,M},N}
112+
@test ForwardDiff.valtype(OuterTestTag, typeof(NESTED_FDNUM)) == Dual{TestTag,Dual{TestTag,V,M},N}
113+
114+
@test_throws ForwardDiff.DualMismatchError(TestTag, OuterTestTag) ForwardDiff.valtype(TestTag, Dual{OuterTestTag}(PRIMAL, PARTIALS))
115+
@test_throws ForwardDiff.DualMismatchError(TestTag, OuterTestTag) ForwardDiff.valtype(TestTag, typeof(Dual{OuterTestTag}(PRIMAL, PARTIALS)))
116+
@test_throws ForwardDiff.DualMismatchError(TestTag, OuterTestTag) ForwardDiff.valtype(TestTag, Dual{OuterTestTag}(Dual{TestTag}(PRIMAL, M_PARTIALS), NESTED_PARTIALS))
117+
@test_throws ForwardDiff.DualMismatchError(TestTag, OuterTestTag) ForwardDiff.valtype(TestTag, typeof(Dual{OuterTestTag}(Dual{TestTag}(PRIMAL, M_PARTIALS), NESTED_PARTIALS)))
118+
104119
#####################
105120
# Generic Functions #
106121
#####################

test/GradientTest.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@ using JLArrays
1313

1414
include(joinpath(dirname(@__FILE__), "utils.jl"))
1515

16+
struct TestTag end
17+
struct OuterTestTag end
18+
ForwardDiff.:(::Type{TestTag}, ::Type{OuterTestTag}) = true
19+
ForwardDiff.:(::Type{OuterTestTag}, ::Type{<:Tag}) = true
20+
1621
##################
1722
# hardcoded test #
1823
##################
@@ -282,4 +287,30 @@ end
282287
@test !ForwardDiff.supports_fast_scalar_indexing(UnitLowerTriangular(view(JLArray(rand(6, 6)), 1:3, 1:3)))
283288
end
284289

290+
# issue #769
291+
@testset "functions with `Dual` output" begin
292+
x = [Dual{OuterTestTag}(Dual{TestTag}(1.3, 2.1), Dual{TestTag}(0.3, -2.4))]
293+
f(x) = sum(ForwardDiff.value, x)
294+
der = ForwardDiff.derivative(ForwardDiff.value, only(x))
295+
296+
# Vector mode
297+
grad = ForwardDiff.gradient(f, x)
298+
@test grad isa Vector{typeof(der)}
299+
@test grad == [der]
300+
grad = ForwardDiff.gradient(f, SVector{1}(x))
301+
@test grad isa SVector{1,typeof(der)}
302+
@test grad == SVector{1}(der)
303+
304+
# Chunk mode
305+
y = repeat(x, 3)
306+
cfg = ForwardDiff.GradientConfig(f, y, ForwardDiff.Chunk{2}())
307+
grad = ForwardDiff.gradient(f, y, cfg)
308+
@test grad isa Vector{typeof(der)}
309+
@test grad == [der, der, der]
310+
cfg = ForwardDiff.GradientConfig(f, SVector{3}(y), ForwardDiff.Chunk{2}())
311+
grad = ForwardDiff.gradient(f, SVector{3}(y), cfg)
312+
@test grad isa SVector{3,typeof(der)}
313+
@test grad == SVector{3}(der, der, der)
314+
end
315+
285316
end # module

test/JacobianTest.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@ using JLArrays
1212

1313
include(joinpath(dirname(@__FILE__), "utils.jl"))
1414

15+
struct TestTag end
16+
struct OuterTestTag end
17+
ForwardDiff.:(::Type{TestTag}, ::Type{OuterTestTag}) = true
18+
ForwardDiff.:(::Type{OuterTestTag}, ::Type{<:Tag}) = true
19+
1520
##################
1621
# hardcoded test #
1722
##################
@@ -322,4 +327,30 @@ end
322327
@test Array(jac_jl) jac
323328
end
324329

330+
# issue #769
331+
@testset "functions with `Dual` output" begin
332+
x = [Dual{OuterTestTag}(Dual{TestTag}(1.3, 2.1), Dual{TestTag}(0.3, -2.4))]
333+
f(x) = map(ForwardDiff.value, x)
334+
der = ForwardDiff.derivative(ForwardDiff.value, only(x))
335+
336+
# Vector mode
337+
jac = ForwardDiff.jacobian(f, x)
338+
@test jac isa Matrix{typeof(der)}
339+
@test jac == [der;;]
340+
jac = ForwardDiff.jacobian(f, SVector{1}(x))
341+
@test jac isa SMatrix{1,1,typeof(der)}
342+
@test jac == SMatrix{1,1}(der)
343+
344+
# Chunk mode
345+
y = repeat(x, 3)
346+
cfg = ForwardDiff.JacobianConfig(f, y, ForwardDiff.Chunk{2}())
347+
jac = ForwardDiff.jacobian(f, y, cfg)
348+
@test jac isa Matrix{typeof(der)}
349+
@test jac == Diagonal([der, der, der])
350+
cfg = ForwardDiff.JacobianConfig(f, SVector{3}(y), ForwardDiff.Chunk{2}())
351+
jac = ForwardDiff.jacobian(f, SVector{3}(y), cfg)
352+
@test jac isa SMatrix{3,3,typeof(der)}
353+
@test jac == Diagonal([der, der, der])
354+
end
355+
325356
end # module

0 commit comments

Comments
 (0)