Skip to content

Commit e0f515c

Browse files
authored
Fix sum of tracked reals in Turing (#173)
* fix sum for tracked reals * fix sum implementation * fix the fix * pull value * apply rule to trackedarrays only
1 parent 16b3596 commit e0f515c

2 files changed

Lines changed: 15 additions & 8 deletions

File tree

src/derivatives/linalg/reductions.jl

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,30 @@
55
# basic sum #
66
#-----------#
77

8-
function Base.sum(x::TrackedArray{V,D}) where {V,D}
8+
function Base.sum(x::TrackedArray{V,D}; dims=:) where {V,D}
99
tp = tape(x)
10-
out = track(sum(value(x)), D, tp)
11-
record!(tp, SpecialInstruction, sum, x, out)
10+
out = track(sum(value(x), dims = dims), D, tp)
11+
record!(tp, SpecialInstruction, sum, (x, dims), out)
1212
return out
1313
end
1414

1515
@noinline function special_reverse_exec!(instruction::SpecialInstruction{typeof(sum)})
16-
input = instruction.input
16+
input, dims = instruction.input
1717
output = instruction.output
18-
istracked(input) && increment_deriv!(input, deriv(output))
18+
if istracked(input)
19+
if dims === Colon()
20+
increment_deriv!(input, deriv(output))
21+
else
22+
increment_deriv!(input, zero(value(input)) .+ deriv(output))
23+
end
24+
end
1925
unseed!(output)
2026
return nothing
2127
end
2228

2329
@noinline function special_forward_exec!(instruction::SpecialInstruction{typeof(sum)})
24-
input = instruction.input
25-
value!(instruction.output, sum(value(input)))
30+
input, dims = instruction.input
31+
value!(instruction.output, sum(value(input); dims = dims))
2632
return nothing
2733
end
2834

src/tracked.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ unseed!(x::AbstractArray, i) = unseed!(x[i])
222222
# `forward_pass!`/`reverse_pass!`.
223223
capture(t::TrackedReal) = ifelse(hastape(t), t, value(t))
224224
capture(t::TrackedArray) = t
225-
capture(t::AbstractArray) = istracked(t) ? map!(capture, similar(t), t) : copy(t)
225+
capture(t::AbstractArray) = istracked(t) ? map(capture, t) : copy(t)
226226

227227
########################
228228
# Conversion/Promotion #
@@ -415,6 +415,7 @@ Base.float(t::TrackedReal{V}) where {V<:AbstractFloat} = t
415415

416416
Base.one(::Type{TrackedReal{V,D,O}}) where {V,D,O} = TrackedReal{V,D,O}(one(V))
417417
Base.zero(::Type{TrackedReal{V,D,O}}) where {V,D,O} = TrackedReal{V,D,O}(zero(V))
418+
Base.zero(::Type{<:TrackedReal{V,D}}) where {V,D} = TrackedReal{V,D,Nothing}(zero(V))
418419

419420
Base.rand(::Type{TrackedReal{V,D,O}}) where {V,D,O} = TrackedReal{V,D,O}(rand(V))
420421
Base.rand(rng::Random.AbstractRNG, ::Type{TrackedReal{V,D,O}}) where {V,D,O} = TrackedReal{V,D,O}(rand(rng, V))

0 commit comments

Comments
 (0)