Skip to content

Commit e187137

Browse files
Revert "Fix sum of tracked reals in Turing (#173)" (#177)
This reverts commit e0f515c.
1 parent 029805d commit e187137

2 files changed

Lines changed: 8 additions & 15 deletions

File tree

src/derivatives/linalg/reductions.jl

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,30 +5,24 @@
55
# basic sum #
66
#-----------#
77

8-
function Base.sum(x::TrackedArray{V,D}; dims=:) where {V,D}
8+
function Base.sum(x::TrackedArray{V,D}) where {V,D}
99
tp = tape(x)
10-
out = track(sum(value(x), dims = dims), D, tp)
11-
record!(tp, SpecialInstruction, sum, (x, dims), out)
10+
out = track(sum(value(x)), D, tp)
11+
record!(tp, SpecialInstruction, sum, x, out)
1212
return out
1313
end
1414

1515
@noinline function special_reverse_exec!(instruction::SpecialInstruction{typeof(sum)})
16-
input, dims = instruction.input
16+
input = instruction.input
1717
output = instruction.output
18-
if istracked(input)
19-
if dims === Colon()
20-
increment_deriv!(input, deriv(output))
21-
else
22-
increment_deriv!(input, zero(value(input)) .+ deriv(output))
23-
end
24-
end
18+
istracked(input) && increment_deriv!(input, deriv(output))
2519
unseed!(output)
2620
return nothing
2721
end
2822

2923
@noinline function special_forward_exec!(instruction::SpecialInstruction{typeof(sum)})
30-
input, dims = instruction.input
31-
value!(instruction.output, sum(value(input); dims = dims))
24+
input = instruction.input
25+
value!(instruction.output, sum(value(input)))
3226
return nothing
3327
end
3428

src/tracked.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ unseed!(x::AbstractArray, i) = unseed!(x[i])
222222
# `forward_pass!`/`reverse_pass!`.
223223
capture(t::TrackedReal) = ifelse(hastape(t), t, value(t))
224224
capture(t::TrackedArray) = t
225-
capture(t::AbstractArray) = istracked(t) ? map(capture, t) : copy(t)
225+
capture(t::AbstractArray) = istracked(t) ? map!(capture, similar(t), t) : copy(t)
226226

227227
########################
228228
# Conversion/Promotion #
@@ -415,7 +415,6 @@ Base.float(t::TrackedReal{V}) where {V<:AbstractFloat} = t
415415

416416
Base.one(::Type{TrackedReal{V,D,O}}) where {V,D,O} = TrackedReal{V,D,O}(one(V))
417417
Base.zero(::Type{TrackedReal{V,D,O}}) where {V,D,O} = TrackedReal{V,D,O}(zero(V))
418-
Base.zero(::Type{<:TrackedReal{V,D}}) where {V,D} = TrackedReal{V,D,Nothing}(zero(V))
419418

420419
Base.rand(::Type{TrackedReal{V,D,O}}) where {V,D,O} = TrackedReal{V,D,O}(rand(V))
421420
Base.rand(rng::Random.AbstractRNG, ::Type{TrackedReal{V,D,O}}) where {V,D,O} = TrackedReal{V,D,O}(rand(rng, V))

0 commit comments

Comments
 (0)