Skip to content

Commit cc1e503

Browse files
author
KDr2
committed
take types of inputs into consideration
1 parent d4c95b1 commit cc1e503

1 file changed

Lines changed: 34 additions & 32 deletions

File tree

src/macros.jl

Lines changed: 34 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -246,8 +246,8 @@ list, returns a tuple of argument lists whose elements are:
246246
1. the`arg_list` untouched, 2. a new argument list with the function
247247
as its first element and other elements in `arg_list` followed, 3. a
248248
new argument for the definition of function `track`, 4. a new argument
249-
list with all kwargs removed, 5 the kwargs name if any otherwise an
250-
empty tuple. E.g.:
249+
list with all kwargs removed, 5, types of the arguments in the 4th
250+
element, 5 the kwargs name if any otherwise an empty tuple. E.g.:
251251
252252
_make_fwd_args(:f, [:(a::String), :(b::TrackedReal), :(args...)])
253253
@@ -257,6 +257,7 @@ returns
257257
[:f, :(a::String), :(b::TrackedReal), :(args...)],
258258
[:(::typeof(f)), :(a::String), :(b::TrackedReal), :(args...)],
259259
[:(a::String), :(b::TrackedReal), :(args...)],
260+
[:String, :TrackedReal, :(Vararg{Any})],
260261
:kwargs)
261262
262263
It also deals with varargs and variable keyword arguments, and ensures
@@ -292,7 +293,18 @@ function _make_fwd_args(func, args_l)
292293
!(isa(arg, Expr) && arg.head == :parameters)
293294
end
294295

295-
return args_l, args_r, args_track, args_fixed, kwargs
296+
arg_types = map(args_fixed) do arg
297+
if isa(arg, Expr) && arg.head == :(...)
298+
:(Vararg{Any})
299+
elseif isa(arg, Expr) && arg.head == :(::)
300+
arg.args[end]
301+
else
302+
:Any
303+
end
304+
end
305+
306+
307+
return args_l, args_r, args_track, args_fixed, arg_types, kwargs
296308
end
297309

298310
"""
@@ -319,7 +331,7 @@ macro grad_from_chainrules(fcall)
319331
fcall.head == :call || error("The rule should be in format of a function call.")
320332
@capture(fcall, f_(xs__)) # extract information into f and xs
321333
f = esc(f)
322-
args_l, args_r, args_track, args_fixed, kwargs = _make_fwd_args(f, xs)
334+
args_l, args_r, args_track, args_fixed, arg_types, kwargs = _make_fwd_args(f, xs)
323335

324336
return quote
325337
$f($(args_l...)) = ReverseDiff.track($(args_r...))
@@ -340,36 +352,26 @@ macro grad_from_chainrules(fcall)
340352
return output
341353
end
342354

343-
if !hasmethod(
344-
ReverseDiff.special_reverse_exec!,
345-
Tuple{ReverseDiff.SpecialInstruction{typeof($f)}},
346-
)
347-
@noinline function ReverseDiff.special_reverse_exec!(instruction::ReverseDiff.SpecialInstruction{typeof($f)})
348-
output = instruction.output
349-
input = instruction.input
350-
back = instruction.cache[1]
351-
back_output = back(ReverseDiff.deriv(output))
352-
input_derivs = back_output[2:end]
353-
@assert input_derivs isa Tuple
354-
ReverseDiff._add_to_deriv!.(input, input_derivs)
355-
ReverseDiff.unseed!(output)
356-
return nothing
357-
end
355+
@noinline function ReverseDiff.special_reverse_exec!(instruction::ReverseDiff.SpecialInstruction{typeof($f), <:Tuple{$(arg_types...)}})
356+
output = instruction.output
357+
input = instruction.input
358+
back = instruction.cache[1]
359+
back_output = back(ReverseDiff.deriv(output))
360+
input_derivs = back_output[2:end]
361+
@assert input_derivs isa Tuple
362+
ReverseDiff._add_to_deriv!.(input, input_derivs)
363+
ReverseDiff.unseed!(output)
364+
return nothing
358365
end
359366

360-
if !hasmethod(
361-
ReverseDiff.special_forward_exec!,
362-
Tuple{ReverseDiff.SpecialInstruction{typeof($f)}},
363-
)
364-
@noinline function ReverseDiff.special_forward_exec!(instruction::ReverseDiff.SpecialInstruction{typeof($f)})
365-
output, input = instruction.output, instruction.input
366-
ReverseDiff.pull_value!.(input)
367-
pullback = instruction.cache[2]
368-
kwargs = instruction.cache[3]
369-
out_value = pullback(input...; kwargs...)[1]
370-
ReverseDiff.value!(output, out_value)
371-
return nothing
372-
end
367+
@noinline function ReverseDiff.special_forward_exec!(instruction::ReverseDiff.SpecialInstruction{typeof($f), <:Tuple{$(arg_types...)}})
368+
output, input = instruction.output, instruction.input
369+
ReverseDiff.pull_value!.(input)
370+
pullback = instruction.cache[2]
371+
kwargs = instruction.cache[3]
372+
out_value = pullback(input...; kwargs...)[1]
373+
ReverseDiff.value!(output, out_value)
374+
return nothing
373375
end
374376
end
375377
end

0 commit comments

Comments
 (0)