@@ -242,25 +242,29 @@ end
242242 _make_fwd_args(func, arg_list)
243243
244244Function `_make_fwd_args` accepts a function name and an argument
245- list, returns a tuple of argument lists whose elements are: 1. the
246- `arg_list` untouched, 2. a new argument list with the function as its
247- first element and other elements in `arg_list` followed, 3. a new
248- argument list with all varargs removed. E.g.:
245+ list, returns a tuple of argument lists whose elements are:
246+ 1. the`arg_list` untouched, 2. a new argument list with the function
247+ as its first element and other elements in `arg_list` followed, 3. a
248+ new argument for the definition of function `track`, 4. a new argument
249+ list with all kwargs removed, 5 the kwargs name if any otherwise an
250+ empty tuple. E.g.:
249251
250252_make_fwd_args(:f, [:(a::String), :(b::TrackedReal), :(args...)])
251253
252254returns
253255
254256([:(a::String), :(b::TrackedReal), :(args...)],
255257 [:f, :(a::String), :(b::TrackedReal), :(args...)],
256- [:(a::String), :(b::TrackedReal)])
258+ [:(::typeof(f)), :(a::String), :(b::TrackedReal), :(args...)],
259+ [:(a::String), :(b::TrackedReal), :(args...)],
260+ :kwargs)
257261
258262It also deals with varargs and variable keyword arguments, and ensures
259263that at least one of the argument is tracked.
260264
261265"""
262- function _make_fwd_args (func, xs_l )
263- has_tracked_data = any (xs_l ) do arg
266+ function _make_fwd_args (func, args_l )
267+ has_tracked_data = any (args_l ) do arg
264268 isa (arg, Expr) && arg. head == :(:: ) &&
265269 arg. args[end ] in (:(ReverseDiff. TrackedReal), :(TrackedReal),
266270 :(ReverseDiff. TrackedArray), :(TrackedArray),
@@ -271,18 +275,24 @@ function _make_fwd_args(func, xs_l)
271275
272276 has_tracked_data || error (" The rule should have at least one tracked argument." )
273277
274- xs_r = copy (xs_l)
275- if isa (xs_r[1 ], Expr) && xs_r[1 ]. head == :parameters # has kw args
276- insert! (xs_r, 2 , func)
278+ kwargs = :(())
279+ args_r = copy (args_l)
280+ args_track = copy (args_l)
281+ if isa (args_r[1 ], Expr) && args_r[1 ]. head == :parameters # has kw args
282+ insert! (args_r, 2 , func)
283+ insert! (args_track, 2 , :(:: typeof ($ func)))
284+ kwargs = gensym (:kwargs )
285+ args_track[1 ]. args = [:($ (kwargs). .. )]
277286 else
278- insert! (xs_r, 1 , func)
287+ insert! (args_r, 1 , func)
288+ insert! (args_track, 1 , :(:: typeof ($ func)))
279289 end
280290
281- xs_t = filter (copy (xs_l )) do arg
291+ args_fixed = filter (copy (args_l )) do arg
282292 ! (isa (arg, Expr) && arg. head == :parameters )
283293 end
284294
285- return xs_l, xs_r, xs_t
295+ return args_l, args_r, args_track, args_fixed, kwargs
286296end
287297
288298"""
@@ -309,14 +319,14 @@ macro grad_from_chainrules(fcall)
309319 fcall. head == :call || error (" The rule should be in format of a function call." )
310320 @capture (fcall, f_ (xs__)) # extract information into f and xs
311321 f = esc (f)
312- xs_l, xs_r, xs_t = _make_fwd_args (f, xs)
322+ args_l, args_r, args_track, args_fixed, kwargs = _make_fwd_args (f, xs)
313323
314324 return quote
315- $ f ($ (xs_l ... )) = ReverseDiff. track ($ (xs_r ... ))
316- function ReverseDiff. track (:: typeof ( $ f), $ (xs_t ... ); kwargs ... )
317- args = ($ (xs_t ... ),)
325+ $ f ($ (args_l ... )) = ReverseDiff. track ($ (args_r ... ))
326+ function ReverseDiff. track ($ (args_track ... ))
327+ args = ($ (args_fixed ... ),)
318328 tp = ReverseDiff. tape (args... )
319- output_value, back = ChainRulesCore. rrule ($ f, map (ReverseDiff. value, args)... ; kwargs... )
329+ output_value, back = ChainRulesCore. rrule ($ f, map (ReverseDiff. value, args)... ; $ kwargs... )
320330 output = ReverseDiff. track (output_value, tp)
321331 closure (cls_args... ; cls_kwargs... ) = ChainRulesCore. rrule ($ f, map (ReverseDiff. value, cls_args)... ; cls_kwargs... )
322332 ReverseDiff. record! (
@@ -325,7 +335,7 @@ macro grad_from_chainrules(fcall)
325335 $ f,
326336 args,
327337 output,
328- (back, closure, kwargs),
338+ (back, closure, $ kwargs),
329339 )
330340 return output
331341 end
0 commit comments