From 3a76d3a17d60de610044808733c40e343a01f5ca Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 31 Jul 2025 14:21:31 +0200 Subject: [PATCH 1/5] feat: make `FromPrimitive` wrappers public --- DifferentiationInterface/CHANGELOG.md | 5 +++++ DifferentiationInterface/docs/src/api.md | 7 ++++++ .../src/DifferentiationInterface.jl | 1 + .../src/misc/from_primitive.jl | 22 +++++++++++-------- .../test/Core/SimpleFiniteDiff/test.jl | 8 ++++++- 5 files changed, 33 insertions(+), 10 deletions(-) diff --git a/DifferentiationInterface/CHANGELOG.md b/DifferentiationInterface/CHANGELOG.md index 2b6fe4060..be3c2bef3 100644 --- a/DifferentiationInterface/CHANGELOG.md +++ b/DifferentiationInterface/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- Make `AutoForwardFromPrimitive` and `AutoReverseFromPrimitive` public ([#824]) + ## [0.7.3] ### Fixed @@ -62,6 +66,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [0.6.54]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.53...DifferentiationInterface-v0.6.54 [0.6.53]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.52...DifferentiationInterface-v0.6.53 +[#824]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/824 [#823]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/823 [#818]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/818 [#812]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/812 diff --git a/DifferentiationInterface/docs/src/api.md b/DifferentiationInterface/docs/src/api.md index 486a5830f..08e58e081 100644 --- a/DifferentiationInterface/docs/src/api.md +++ b/DifferentiationInterface/docs/src/api.md @@ -132,6 +132,13 @@ MixedMode DenseSparsityDetector ``` +### From primitive + +```@docs +AutoForwardFromPrimitive +AutoReverseFromPrimitive +``` + ## Internals The following is not part of the public API. diff --git a/DifferentiationInterface/src/DifferentiationInterface.jl b/DifferentiationInterface/src/DifferentiationInterface.jl index 32e699572..929abff94 100644 --- a/DifferentiationInterface/src/DifferentiationInterface.jl +++ b/DifferentiationInterface/src/DifferentiationInterface.jl @@ -126,6 +126,7 @@ export AutoSparse ## Public but not exported @public inner, outer +@public AutoForwardFromPrimitive, AutoReverseFromPrimitive include("init.jl") diff --git a/DifferentiationInterface/src/misc/from_primitive.jl b/DifferentiationInterface/src/misc/from_primitive.jl index 00dd6b445..091166fbf 100644 --- a/DifferentiationInterface/src/misc/from_primitive.jl +++ b/DifferentiationInterface/src/misc/from_primitive.jl @@ -12,17 +12,21 @@ function pick_batchsize(backend::FromPrimitive, N::Integer) end """ - AutoForwardFromPrimitive + AutoForwardFromPrimitive(backend::AbstractADType) -Wrapper which forces a given backend to act as a reverse-mode backend. +Wrapper which forces a given backend to act as a forward-mode backend, using only its native `value_and_pushforward` primitive and re-implementing the rest from scratch. -Used in internal testing. +!!! tip + This can be useful to circumvent high-level operators when they have impractical limitations. + For instance, ForwardDiff.jl's `jacobian` does not support GPU arrays but its `pushforward` does, so `AutoForwardFromPrimitive(AutoForwardDiff())` has a GPU-friendly `jacobian`. """ struct AutoForwardFromPrimitive{inplace,B<:AbstractADType} <: FromPrimitive{inplace} backend::B end -function AutoForwardFromPrimitive(backend::AbstractADType; inplace=true) +function AutoForwardFromPrimitive( + backend::AbstractADType; inplace::Bool=Bool(inplace_support(backend)) +) return AutoForwardFromPrimitive{inplace,typeof(backend)}(backend) end @@ -133,17 +137,17 @@ function value_and_pushforward!( end """ - AutoReverseFromPrimitive + AutoReverseFromPrimitive(backend::AbstractADType) -Wrapper which forces a given backend to act as a reverse-mode backend. - -Used in internal testing. +Wrapper which forces a given backend to act as a reverse-mode backend, using only its native `value_and_pullback` implementation and rebuilding the rest from scratch. """ struct AutoReverseFromPrimitive{inplace,B<:AbstractADType} <: FromPrimitive{inplace} backend::B end -function AutoReverseFromPrimitive(backend::AbstractADType; inplace=true) +function AutoReverseFromPrimitive( + backend::AbstractADType; inplace::Bool=Bool(inplace_support(backend)) +) return AutoReverseFromPrimitive{inplace,typeof(backend)}(backend) end diff --git a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl index beec9841f..573aae0af 100644 --- a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl +++ b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl @@ -141,6 +141,12 @@ end @testset "Weird arrays" begin test_differentiation( - AutoSimpleFiniteDiff(), vcat(static_scenarios(), gpu_scenarios()); logging=LOGGING + [ + AutoSimpleFiniteDiff(), + AutoForwardFromPrimitive(AutoSimpleFiniteDiff()), + AutoReverseFromPrimitive(AutoSimpleFiniteDiff()), + ], + vcat(static_scenarios(), gpu_scenarios()); + logging=LOGGING, ) end; From 2005fc6edfa800ef03707e171be7c93acdc4d3cd Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 31 Jul 2025 17:56:30 +0200 Subject: [PATCH 2/5] Fix --- DifferentiationInterface/docs/src/api.md | 4 ++-- DifferentiationInterface/test/Back/ForwardDiff/test.jl | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/DifferentiationInterface/docs/src/api.md b/DifferentiationInterface/docs/src/api.md index 08e58e081..ac50ab1ab 100644 --- a/DifferentiationInterface/docs/src/api.md +++ b/DifferentiationInterface/docs/src/api.md @@ -135,8 +135,8 @@ DenseSparsityDetector ### From primitive ```@docs -AutoForwardFromPrimitive -AutoReverseFromPrimitive +DifferentiationInterface.AutoForwardFromPrimitive +DifferentiationInterface.AutoReverseFromPrimitive ``` ## Internals diff --git a/DifferentiationInterface/test/Back/ForwardDiff/test.jl b/DifferentiationInterface/test/Back/ForwardDiff/test.jl index 8a99c08f4..d840f0024 100644 --- a/DifferentiationInterface/test/Back/ForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/ForwardDiff/test.jl @@ -8,6 +8,7 @@ import DifferentiationInterface as DI import DifferentiationInterfaceTest as DIT using ForwardDiff: ForwardDiff using StaticArrays: StaticArrays, @SVector +using JLArrays: JLArrays using Test using ExplicitImports @@ -75,6 +76,9 @@ end @testset "Weird" begin test_differentiation(AutoForwardDiff(), component_scenarios(); logging=LOGGING) test_differentiation(AutoForwardDiff(), static_scenarios(); logging=LOGGING) + test_differentiation( + DI.AutoForwardFromPrimitive(AutoForwardDiff()), gpu_scenarios(); logging=LOGGING + ) @testset "Batch size" begin @test DI.pick_batchsize(AutoForwardDiff(), rand(7)) isa DI.BatchSizeSettings{7} From 6894ad7b6ffc0fc3eab02088f808fd29636e01c1 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 1 Aug 2025 08:49:34 +0200 Subject: [PATCH 3/5] Fix overload --- .../src/first_order/mixed_mode.jl | 7 +++++++ .../src/misc/from_primitive.jl | 19 +++++++++++++++++-- .../src/utils/batchsize.jl | 7 ------- 3 files changed, 24 insertions(+), 9 deletions(-) diff --git a/DifferentiationInterface/src/first_order/mixed_mode.jl b/DifferentiationInterface/src/first_order/mixed_mode.jl index 839cb941f..5951b456c 100644 --- a/DifferentiationInterface/src/first_order/mixed_mode.jl +++ b/DifferentiationInterface/src/first_order/mixed_mode.jl @@ -41,3 +41,10 @@ Appropriate mode type for `MixedMode` backends. """ struct ForwardAndReverseMode <: ADTypes.AbstractMode end ADTypes.mode(::MixedMode) = ForwardAndReverseMode() + +function threshold_batchsize(backend::MixedMode, B::Integer) + return MixedMode( + threshold_batchsize(forward_backend(backend), B), + threshold_batchsize(reverse_backend(backend), B), + ) +end diff --git a/DifferentiationInterface/src/misc/from_primitive.jl b/DifferentiationInterface/src/misc/from_primitive.jl index 091166fbf..81f7ea0b6 100644 --- a/DifferentiationInterface/src/misc/from_primitive.jl +++ b/DifferentiationInterface/src/misc/from_primitive.jl @@ -3,14 +3,29 @@ abstract type FromPrimitive{inplace} <: AbstractADType end check_available(backend::FromPrimitive) = check_available(backend.backend) inplace_support(::FromPrimitive{true}) = InPlaceSupported() inplace_support(::FromPrimitive{false}) = InPlaceNotSupported() -function inner_preparation_behavior(backend::FromPrimitive) - return inner_preparation_behavior(backend.backend) + +function pick_batchsize(backend::FromPrimitive, x_or_y::AbstractArray) + return pick_batchsize(backend.backend, x_or_y) end function pick_batchsize(backend::FromPrimitive, N::Integer) return pick_batchsize(backend.backend, N) end +function inner_preparation_behavior(backend::FromPrimitive) + return inner_preparation_behavior(backend.backend) +end + +function overloaded_input(::typeof(pushforward), f, backend::FromPrimitive, x, tx::NTuple) + return overloaded_input(pushforward, f, backend.backend, x, tx) +end + +function overloaded_input( + ::typeof(pushforward), f!, y, backend::FromPrimitive, x, tx::NTuple +) + return overloaded_input(pushforward, f!, y, backend.backend, x, tx) +end + """ AutoForwardFromPrimitive(backend::AbstractADType) diff --git a/DifferentiationInterface/src/utils/batchsize.jl b/DifferentiationInterface/src/utils/batchsize.jl index b3d370d05..054d5c9b9 100644 --- a/DifferentiationInterface/src/utils/batchsize.jl +++ b/DifferentiationInterface/src/utils/batchsize.jl @@ -112,13 +112,6 @@ function threshold_batchsize(backend::SecondOrder, B::Integer) ) end -function threshold_batchsize(backend::MixedMode, B::Integer) - return MixedMode( - threshold_batchsize(forward_backend(backend), B), - threshold_batchsize(reverse_backend(backend), B), - ) -end - """ reasonable_batchsize(N::Integer, Bmax::Integer) From be7d595e1544f0512e0068511b34f830d2482a7d Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 1 Aug 2025 14:55:06 +0200 Subject: [PATCH 4/5] Update CHANGELOG.md --- DifferentiationInterface/CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterface/CHANGELOG.md b/DifferentiationInterface/CHANGELOG.md index 6def57c51..dc87ecf38 100644 --- a/DifferentiationInterface/CHANGELOG.md +++ b/DifferentiationInterface/CHANGELOG.md @@ -70,7 +70,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [0.6.54]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.53...DifferentiationInterface-v0.6.54 [0.6.53]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.52...DifferentiationInterface-v0.6.53 -[#825]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/825 +[#826]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/826 [#824]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/824 [#823]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/823 [#818]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/818 From c018bd2274898ed0dc91d38552d5af900b0ebbca Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 1 Aug 2025 14:55:11 +0200 Subject: [PATCH 5/5] Update CHANGELOG.md --- DifferentiationInterface/CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterface/CHANGELOG.md b/DifferentiationInterface/CHANGELOG.md index dc87ecf38..341caa0ed 100644 --- a/DifferentiationInterface/CHANGELOG.md +++ b/DifferentiationInterface/CHANGELOG.md @@ -71,7 +71,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [0.6.53]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.52...DifferentiationInterface-v0.6.53 [#826]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/826 -[#824]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/824 +[#825]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/825 [#823]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/823 [#818]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/818 [#812]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/812