Skip to content

Commit bebfac2

Browse files
zhujch1tansongchen
authored andcommitted
derivative for matrix
1 parent d35f058 commit bebfac2

2 files changed

Lines changed: 13 additions & 1 deletion

File tree

src/chainrules.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,17 @@ function rrule(::typeof(*), A::AbstractMatrix{S},
4141
return A * t, gemv_pullback
4242
end
4343

44+
function rrule(::typeof(*), A::AbstractMatrix{S},
45+
B::AbstractMatrix{TaylorScalar{T, N}}) where {N, S, T}
46+
project_A = ProjectTo(A)
47+
project_B = ProjectTo(B)
48+
function gemm_pullback(x̄)
49+
= unthunk(x̄)
50+
NoTangent(), @thunk(project_A(X̄ * transpose(B))), @thunk(project_B(transpose(A) * X̄))
51+
end
52+
return A * B, gemm_pullback
53+
end
54+
4455
@adjoint function +(t::Vector{TaylorScalar{T, N}}, v::Vector{T}) where {N, T}
4556
project_v = ProjectTo(v)
4657
t + v, x̄ -> (x̄, project_v(x̄))

src/derivative.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,5 +61,6 @@ end
6161

6262
@inline function derivative(f, x::AbstractMatrix{T}, l::AbstractVector{S},
6363
vN::Val{N}) where {T <: TN, S <: TN, N}
64-
mapcols(u -> derivative(f, u, l, vN), x)
64+
t = make_taylor.(x, l, vN)
65+
return extract_derivative.(f(t), N)
6566
end

0 commit comments

Comments
 (0)