@@ -265,21 +265,10 @@ that at least one of the argument is tracked.
265265
266266"""
267267function _make_fwd_args (func, args_l)
268- has_tracked_data = any (args_l) do arg
269- isa (arg, Expr) && arg. head == :(:: ) &&
270- arg. args[end ] in (:(ReverseDiff. TrackedReal), :(TrackedReal),
271- :(ReverseDiff. TrackedArray), :(TrackedArray),
272- :(ReverseDiff. TrackedVector), :(TrackedVector),
273- :(ReverseDiff. TrackedMatrix), :(TrackedMatrix),
274- :(ReverseDiff. TrackedVecOrMat), :(TrackedVecOrMat))
275- end
276-
277- has_tracked_data || error (" The rule should have at least one tracked argument." )
278-
279268 kwargs = :(())
280269 args_r = copy (args_l)
281270 args_track = copy (args_l)
282- if isa (args_r[1 ], Expr) && args_r[ 1 ] . head == :parameters # has kw args
271+ if Meta . isexpr (args_r[1 ], :parameters ) # has kw args
283272 insert! (args_r, 2 , func)
284273 insert! (args_track, 2 , :(:: typeof ($ func)))
285274 kwargs = gensym (:kwargs )
@@ -290,25 +279,24 @@ function _make_fwd_args(func, args_l)
290279 end
291280
292281 args_fixed = filter (copy (args_l)) do arg
293- ! ( isa ( arg, Expr) && arg . head == :parameters )
282+ ! Meta . isexpr ( arg, :parameters )
294283 end
295284
296285 arg_types = map (args_fixed) do arg
297- if isa (arg, Expr) && arg . head == :(... )
298- :(Vararg{Any})
299- elseif isa (arg, Expr) && arg . head == :(:: )
286+ if Meta . isexpr (arg, :(... ) )
287+ Meta . isexpr (arg . args[ 1 ], :( :: )) ? :(Vararg{ $ (arg . args[ 1 ] . args[ end ])}) : :(Vararg{Any})
288+ elseif Meta . isexpr (arg, :(:: ) )
300289 arg. args[end ]
301290 else
302291 :Any
303292 end
304293 end
305294
306-
307295 return args_l, args_r, args_track, args_fixed, arg_types, kwargs
308296end
309297
310298"""
311- ReverseDiff. @grad_from_chainrules Base.sin(x::TrackedReal )
299+ @grad_from_chainrules f(args...; kwargs... )
312300
313301The `@grad_from_chainrules` macro provides a way to import
314302adjoints(rrule) defined in ChainRules to ReverseDiff. One must provide
@@ -319,17 +307,17 @@ to which one wants to take derivatives with respect with
319307respectively. For example, we can import `rrule` of `f(x::Real,
320308y::Array)` like below:
321309
322-
323310```julia
324311ReverseDiff.@grad_from_chainrules f(x::TrackedReal, y::TrackedArray)
325312ReverseDiff.@grad_from_chainrules f(x::TrackedReal, y::Array)
326313ReverseDiff.@grad_from_chainrules f(x::Real, y::TrackedArray)
327314```
328-
329315"""
330316macro grad_from_chainrules (fcall)
331- fcall. head == :call || error (" The rule should be in format of a function call." )
332- @capture (fcall, f_ (xs__)) # extract information into f and xs
317+ Meta. isexpr (fcall, :call ) && length (fcall. args) >= 2 ||
318+ error (" `@grad_from_chainrules` has to be applied to a function signature" )
319+ f = fcall. args[1 ]
320+ xs = fcall. args[2 : end ]
333321 f = esc (f)
334322 args_l, args_r, args_track, args_fixed, arg_types, kwargs = _make_fwd_args (f, xs)
335323
0 commit comments