Skip to content

Commit 4f2da6f

Browse files
author
KDr2
committed
revise code
1 parent cc1e503 commit 4f2da6f

1 file changed

Lines changed: 10 additions & 22 deletions

File tree

src/macros.jl

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -265,21 +265,10 @@ that at least one of the argument is tracked.
265265
266266
"""
267267
function _make_fwd_args(func, args_l)
268-
has_tracked_data = any(args_l) do arg
269-
isa(arg, Expr) && arg.head == :(::) &&
270-
arg.args[end] in (:(ReverseDiff.TrackedReal), :(TrackedReal),
271-
:(ReverseDiff.TrackedArray), :(TrackedArray),
272-
:(ReverseDiff.TrackedVector), :(TrackedVector),
273-
:(ReverseDiff.TrackedMatrix), :(TrackedMatrix),
274-
:(ReverseDiff.TrackedVecOrMat), :(TrackedVecOrMat))
275-
end
276-
277-
has_tracked_data || error("The rule should have at least one tracked argument.")
278-
279268
kwargs = :(())
280269
args_r = copy(args_l)
281270
args_track = copy(args_l)
282-
if isa(args_r[1], Expr) && args_r[1].head == :parameters # has kw args
271+
if Meta.isexpr(args_r[1], :parameters) # has kw args
283272
insert!(args_r, 2, func)
284273
insert!(args_track, 2, :(::typeof($func)))
285274
kwargs = gensym(:kwargs)
@@ -290,25 +279,24 @@ function _make_fwd_args(func, args_l)
290279
end
291280

292281
args_fixed = filter(copy(args_l)) do arg
293-
!(isa(arg, Expr) && arg.head == :parameters)
282+
!Meta.isexpr(arg, :parameters)
294283
end
295284

296285
arg_types = map(args_fixed) do arg
297-
if isa(arg, Expr) && arg.head == :(...)
298-
:(Vararg{Any})
299-
elseif isa(arg, Expr) && arg.head == :(::)
286+
if Meta.isexpr(arg, :(...))
287+
Meta.isexpr(arg.args[1], :(::)) ? :(Vararg{$(arg.args[1].args[end])}) : :(Vararg{Any})
288+
elseif Meta.isexpr(arg, :(::))
300289
arg.args[end]
301290
else
302291
:Any
303292
end
304293
end
305294

306-
307295
return args_l, args_r, args_track, args_fixed, arg_types, kwargs
308296
end
309297

310298
"""
311-
ReverseDiff.@grad_from_chainrules Base.sin(x::TrackedReal)
299+
@grad_from_chainrules f(args...; kwargs...)
312300
313301
The `@grad_from_chainrules` macro provides a way to import
314302
adjoints(rrule) defined in ChainRules to ReverseDiff. One must provide
@@ -319,17 +307,17 @@ to which one wants to take derivatives with respect with
319307
respectively. For example, we can import `rrule` of `f(x::Real,
320308
y::Array)` like below:
321309
322-
323310
```julia
324311
ReverseDiff.@grad_from_chainrules f(x::TrackedReal, y::TrackedArray)
325312
ReverseDiff.@grad_from_chainrules f(x::TrackedReal, y::Array)
326313
ReverseDiff.@grad_from_chainrules f(x::Real, y::TrackedArray)
327314
```
328-
329315
"""
330316
macro grad_from_chainrules(fcall)
331-
fcall.head == :call || error("The rule should be in format of a function call.")
332-
@capture(fcall, f_(xs__)) # extract information into f and xs
317+
Meta.isexpr(fcall, :call) && length(fcall.args) >= 2 ||
318+
error("`@grad_from_chainrules` has to be applied to a function signature")
319+
f = fcall.args[1]
320+
xs = fcall.args[2:end]
333321
f = esc(f)
334322
args_l, args_r, args_track, args_fixed, arg_types, kwargs = _make_fwd_args(f, xs)
335323

0 commit comments

Comments
 (0)