@@ -246,8 +246,8 @@ list, returns a tuple of argument lists whose elements are:
2462461. the`arg_list` untouched, 2. a new argument list with the function
247247as its first element and other elements in `arg_list` followed, 3. a
248248new argument for the definition of function `track`, 4. a new argument
249- list with all kwargs removed, 5 the kwargs name if any otherwise an
250- empty tuple. E.g.:
249+ list with all kwargs removed, 5, types of the arguments in the 4th
250+ element, 5 the kwargs name if any otherwise an empty tuple. E.g.:
251251
252252_make_fwd_args(:f, [:(a::String), :(b::TrackedReal), :(args...)])
253253
@@ -257,6 +257,7 @@ returns
257257 [:f, :(a::String), :(b::TrackedReal), :(args...)],
258258 [:(::typeof(f)), :(a::String), :(b::TrackedReal), :(args...)],
259259 [:(a::String), :(b::TrackedReal), :(args...)],
260+ [:String, :TrackedReal, :(Vararg{Any})],
260261 :kwargs)
261262
262263It also deals with varargs and variable keyword arguments, and ensures
@@ -292,7 +293,18 @@ function _make_fwd_args(func, args_l)
292293 ! (isa (arg, Expr) && arg. head == :parameters )
293294 end
294295
295- return args_l, args_r, args_track, args_fixed, kwargs
296+ arg_types = map (args_fixed) do arg
297+ if isa (arg, Expr) && arg. head == :(... )
298+ :(Vararg{Any})
299+ elseif isa (arg, Expr) && arg. head == :(:: )
300+ arg. args[end ]
301+ else
302+ :Any
303+ end
304+ end
305+
306+
307+ return args_l, args_r, args_track, args_fixed, arg_types, kwargs
296308end
297309
298310"""
@@ -319,7 +331,7 @@ macro grad_from_chainrules(fcall)
319331 fcall. head == :call || error (" The rule should be in format of a function call." )
320332 @capture (fcall, f_ (xs__)) # extract information into f and xs
321333 f = esc (f)
322- args_l, args_r, args_track, args_fixed, kwargs = _make_fwd_args (f, xs)
334+ args_l, args_r, args_track, args_fixed, arg_types, kwargs = _make_fwd_args (f, xs)
323335
324336 return quote
325337 $ f ($ (args_l... )) = ReverseDiff. track ($ (args_r... ))
@@ -340,36 +352,26 @@ macro grad_from_chainrules(fcall)
340352 return output
341353 end
342354
343- if ! hasmethod (
344- ReverseDiff. special_reverse_exec!,
345- Tuple{ReverseDiff. SpecialInstruction{typeof ($ f)}},
346- )
347- @noinline function ReverseDiff. special_reverse_exec! (instruction:: ReverseDiff.SpecialInstruction{typeof($f)} )
348- output = instruction. output
349- input = instruction. input
350- back = instruction. cache[1 ]
351- back_output = back (ReverseDiff. deriv (output))
352- input_derivs = back_output[2 : end ]
353- @assert input_derivs isa Tuple
354- ReverseDiff. _add_to_deriv! .(input, input_derivs)
355- ReverseDiff. unseed! (output)
356- return nothing
357- end
355+ @noinline function ReverseDiff. special_reverse_exec! (instruction:: ReverseDiff.SpecialInstruction{typeof($f), <:Tuple{$(arg_types...)}} )
356+ output = instruction. output
357+ input = instruction. input
358+ back = instruction. cache[1 ]
359+ back_output = back (ReverseDiff. deriv (output))
360+ input_derivs = back_output[2 : end ]
361+ @assert input_derivs isa Tuple
362+ ReverseDiff. _add_to_deriv! .(input, input_derivs)
363+ ReverseDiff. unseed! (output)
364+ return nothing
358365 end
359366
360- if ! hasmethod (
361- ReverseDiff. special_forward_exec!,
362- Tuple{ReverseDiff. SpecialInstruction{typeof ($ f)}},
363- )
364- @noinline function ReverseDiff. special_forward_exec! (instruction:: ReverseDiff.SpecialInstruction{typeof($f)} )
365- output, input = instruction. output, instruction. input
366- ReverseDiff. pull_value! .(input)
367- pullback = instruction. cache[2 ]
368- kwargs = instruction. cache[3 ]
369- out_value = pullback (input... ; kwargs... )[1 ]
370- ReverseDiff. value! (output, out_value)
371- return nothing
372- end
367+ @noinline function ReverseDiff. special_forward_exec! (instruction:: ReverseDiff.SpecialInstruction{typeof($f), <:Tuple{$(arg_types...)}} )
368+ output, input = instruction. output, instruction. input
369+ ReverseDiff. pull_value! .(input)
370+ pullback = instruction. cache[2 ]
371+ kwargs = instruction. cache[3 ]
372+ out_value = pullback (input... ; kwargs... )[1 ]
373+ ReverseDiff. value! (output, out_value)
374+ return nothing
373375 end
374376 end
375377end
0 commit comments