Skip to content

Commit 4e2f1bf

Browse files
author
KDr2
committed
keep arguments untouched when define track
And, add test cases for varargs and kwargs
1 parent 7aeee94 commit 4e2f1bf

2 files changed

Lines changed: 70 additions & 28 deletions

File tree

src/macros.jl

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -242,25 +242,29 @@ end
242242
_make_fwd_args(func, arg_list)
243243
244244
Function `_make_fwd_args` accepts a function name and an argument
245-
list, returns a tuple of argument lists whose elements are: 1. the
246-
`arg_list` untouched, 2. a new argument list with the function as its
247-
first element and other elements in `arg_list` followed, 3. a new
248-
argument list with all varargs removed. E.g.:
245+
list, returns a tuple of argument lists whose elements are:
246+
1. the`arg_list` untouched, 2. a new argument list with the function
247+
as its first element and other elements in `arg_list` followed, 3. a
248+
new argument for the definition of function `track`, 4. a new argument
249+
list with all kwargs removed, 5 the kwargs name if any otherwise an
250+
empty tuple. E.g.:
249251
250252
_make_fwd_args(:f, [:(a::String), :(b::TrackedReal), :(args...)])
251253
252254
returns
253255
254256
([:(a::String), :(b::TrackedReal), :(args...)],
255257
[:f, :(a::String), :(b::TrackedReal), :(args...)],
256-
[:(a::String), :(b::TrackedReal)])
258+
[:(::typeof(f)), :(a::String), :(b::TrackedReal), :(args...)],
259+
[:(a::String), :(b::TrackedReal), :(args...)],
260+
:kwargs)
257261
258262
It also deals with varargs and variable keyword arguments, and ensures
259263
that at least one of the argument is tracked.
260264
261265
"""
262-
function _make_fwd_args(func, xs_l)
263-
has_tracked_data = any(xs_l) do arg
266+
function _make_fwd_args(func, args_l)
267+
has_tracked_data = any(args_l) do arg
264268
isa(arg, Expr) && arg.head == :(::) &&
265269
arg.args[end] in (:(ReverseDiff.TrackedReal), :(TrackedReal),
266270
:(ReverseDiff.TrackedArray), :(TrackedArray),
@@ -271,18 +275,24 @@ function _make_fwd_args(func, xs_l)
271275

272276
has_tracked_data || error("The rule should have at least one tracked argument.")
273277

274-
xs_r = copy(xs_l)
275-
if isa(xs_r[1], Expr) && xs_r[1].head == :parameters # has kw args
276-
insert!(xs_r, 2, func)
278+
kwargs = :(())
279+
args_r = copy(args_l)
280+
args_track = copy(args_l)
281+
if isa(args_r[1], Expr) && args_r[1].head == :parameters # has kw args
282+
insert!(args_r, 2, func)
283+
insert!(args_track, 2, :(::typeof($func)))
284+
kwargs = gensym(:kwargs)
285+
args_track[1].args = [:($(kwargs)...)]
277286
else
278-
insert!(xs_r, 1, func)
287+
insert!(args_r, 1, func)
288+
insert!(args_track, 1, :(::typeof($func)))
279289
end
280290

281-
xs_t = filter(copy(xs_l)) do arg
291+
args_fixed = filter(copy(args_l)) do arg
282292
!(isa(arg, Expr) && arg.head == :parameters)
283293
end
284294

285-
return xs_l, xs_r, xs_t
295+
return args_l, args_r, args_track, args_fixed, kwargs
286296
end
287297

288298
"""
@@ -309,14 +319,14 @@ macro grad_from_chainrules(fcall)
309319
fcall.head == :call || error("The rule should be in format of a function call.")
310320
@capture(fcall, f_(xs__)) # extract information into f and xs
311321
f = esc(f)
312-
xs_l, xs_r, xs_t = _make_fwd_args(f, xs)
322+
args_l, args_r, args_track, args_fixed, kwargs = _make_fwd_args(f, xs)
313323

314324
return quote
315-
$f($(xs_l...)) = ReverseDiff.track($(xs_r...))
316-
function ReverseDiff.track(::typeof($f), $(xs_t...); kwargs...)
317-
args = ($(xs_t...),)
325+
$f($(args_l...)) = ReverseDiff.track($(args_r...))
326+
function ReverseDiff.track($(args_track...))
327+
args = ($(args_fixed...),)
318328
tp = ReverseDiff.tape(args...)
319-
output_value, back = ChainRulesCore.rrule($f, map(ReverseDiff.value, args)...; kwargs...)
329+
output_value, back = ChainRulesCore.rrule($f, map(ReverseDiff.value, args)...; $kwargs...)
320330
output = ReverseDiff.track(output_value, tp)
321331
closure(cls_args...; cls_kwargs...) = ChainRulesCore.rrule($f, map(ReverseDiff.value, cls_args)...; cls_kwargs...)
322332
ReverseDiff.record!(
@@ -325,7 +335,7 @@ macro grad_from_chainrules(fcall)
325335
$f,
326336
args,
327337
output,
328-
(back, closure, kwargs),
338+
(back, closure, $kwargs),
329339
)
330340
return output
331341
end

test/ChainRulesTests.jl

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ module ChainRulesTest
22

33
using LinearAlgebra
44
using ChainRulesCore
5-
using ChainRules
65
using DiffResults
76
using ReverseDiff
87
using Test
@@ -94,19 +93,52 @@ ReverseDiff.@grad_from_chainrules g(x::ReverseDiff.TrackedArray, y::ReverseDiff.
9493

9594
end
9695

97-
### Functions from ChainRules
96+
### Functions with varargs and kwargs
97+
# Varargs
98+
f_vararg(x, args...) = sum(4x .+ sum(args))
9899

99-
# import rrule from ChainRules
100-
ReverseDiff.@grad_from_chainrules LinearAlgebra.norm1(x::ReverseDiff.TrackedArray)
100+
function ChainRulesCore.rrule(::typeof(f_vararg), x, args...)
101+
r = f_vararg(x, args...)
102+
function back(d)
103+
return ChainRulesCore.NoTangent(), fill(3 * d, size(x))
104+
end
105+
return r, back
106+
end
107+
108+
ReverseDiff.@grad_from_chainrules f_vararg(x::ReverseDiff.TrackedArray, args...)
101109

102-
@testset "test imported rrules" begin
110+
@testset "Function with Varargs" begin
103111
inputs = (rand(3, 3), )
112+
104113
results = (similar(inputs[1]),)
114+
f_tape = ReverseDiff.GradientTape(x -> f_vararg(x, 1, 2, 3) + 2, (rand(3, 3),))
115+
ReverseDiff.gradient!(results, f_tape, inputs)
105116

106-
g = (x) -> LinearAlgebra.norm1(x)
107-
g_tape = ReverseDiff.GradientTape(g, (rand(3, 3),))
108-
ReverseDiff.gradient!(results, g_tape, inputs)
109-
@test results[1] == fill(1, size(inputs[1]))
117+
@test results[1] == fill(3, size(inputs[1]))
118+
end
119+
120+
121+
# Vargs and kwargs
122+
f_kw(x, args...; k=1, kwargs...) = sum(4x .+ sum(args) .+ (k + kwargs[:j]))
123+
124+
function ChainRulesCore.rrule(::typeof(f_kw), x, args...; k=1, kwargs...)
125+
r = f_kw(x, args...; k=k, kwargs...)
126+
function back(d)
127+
return ChainRulesCore.NoTangent(), fill(3 * d, size(x))
128+
end
129+
return r, back
130+
end
131+
132+
ReverseDiff.@grad_from_chainrules f_kw(x::ReverseDiff.TrackedArray, args...; k=1, kwargs...)
133+
134+
@testset "Function with Varargs and kwargs" begin
135+
inputs = (rand(3, 3), )
136+
137+
results = (similar(inputs[1]),)
138+
f_tape = ReverseDiff.GradientTape(x -> f_kw(x, 1, 2, 3; k=2, j=3) + 2, (rand(3, 3),))
139+
ReverseDiff.gradient!(results, f_tape, inputs)
140+
141+
@test results[1] == fill(3, size(inputs[1]))
110142
end
111143

112144
## Mix @grad and @grad_from_chainrules

0 commit comments

Comments
 (0)